聊一聊TensorFlow的數(shù)據(jù)導(dǎo)入機(jī)制
今天我們要講的是TensorFlow中的數(shù)據(jù)導(dǎo)入機(jī)制,傳統(tǒng)的做法是習(xí)慣于先構(gòu)建好TF圖模型,然后開(kāi)啟一個(gè)會(huì)話(Session),在運(yùn)行圖模型之前將數(shù)據(jù)feed到圖中,這種做法的缺點(diǎn)是數(shù)據(jù)IO帶來(lái)的時(shí)間消耗很大,那么在訓(xùn)練非常龐大的數(shù)據(jù)集的時(shí)候,不提倡采用這種做法,TensorFlow中取而代之的是tf.data.Dataset模塊,今天我們重點(diǎn)介紹這個(gè)。
tf.data是一個(gè)十分強(qiáng)大的可以用于構(gòu)建復(fù)雜的數(shù)據(jù)導(dǎo)入機(jī)制的API,例如,如果你要處理的是圖像,那么tf.data可以幫助你把分布在不同位置的文件整合到一起,并且對(duì)每幅圖片添加微小的隨機(jī)噪聲,以及隨機(jī)選取一部分圖片作為一個(gè)batch進(jìn)行訓(xùn)練;又或者是你要處理文本,那么tf.data可以幫助從文本中解析符號(hào)并且轉(zhuǎn)換成embedding矩陣,然后將不同長(zhǎng)度的序列變成一個(gè)個(gè)batch。
我們可以用tf.data.Dataset來(lái)構(gòu)建一個(gè)數(shù)據(jù)集,數(shù)據(jù)集的來(lái)源可以有多種方式,例如如果你的數(shù)據(jù)集是預(yù)先以TFRecord格式寫(xiě)在硬盤(pán)上的,那么你可以用tf.data.TFRecordDataset來(lái)構(gòu)建;如果你的數(shù)據(jù)集是內(nèi)存中的tensor變量,那么可以用tf.data.Dataset.from_tensors() 或 tf.data.Dataset.from_tensor_slices()來(lái)構(gòu)建。下面我將通過(guò)代碼來(lái)演示它們。
首先,我們來(lái)看從內(nèi)存中的tensor變量來(lái)構(gòu)建數(shù)據(jù)集,如下代碼所示,首先構(gòu)建了一個(gè)0~10的數(shù)據(jù)集,然后構(gòu)建迭代器,迭代器可以每次從數(shù)據(jù)集中提取一個(gè)元素:
import tensorflow as tf dataset=tf.data.Dataset.range(10) iterator=dataset.make_one_shot_iterator() next_element = iterator.get_next()with tf.Session() as sess: for _ in range(10): print(sess.run(next_element))
如上代碼所示,range()是tf.data.Dataset類的一個(gè)靜態(tài)函數(shù),用于產(chǎn)生一段序列。需要注意的是,構(gòu)建的數(shù)據(jù)集需要是同一種數(shù)據(jù)類型以及內(nèi)部結(jié)構(gòu)。除此之外,由于range(10)代表0~9一共十個(gè)數(shù),因此,這里的iterator只能運(yùn)行10次,超過(guò)以后將會(huì)拋出tf.errors.OutOfRangeError異常。如果希望不拋出異常,則可以調(diào)用dataset.repeat(count)即可實(shí)現(xiàn)count次自動(dòng)重復(fù)的迭代器。
range的范圍我們也可以在運(yùn)行時(shí)才確定,即定義max_range為placeholder變量,這個(gè)時(shí)候需要調(diào)用Dataset的make_initializable_iterator方法來(lái)構(gòu)建迭代器,并且這個(gè)迭代器的operation需要在迭代之前被運(yùn)行,代碼如下所示:
max_range=tf.placeholder(tf.int64, shape=[]) dataset = tf.data.Dataset.range(max_range) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next()with tf.Session() as sess: sess.run(iterator.initializer, feed_dict={max_range: 10}) for _ in range(10): print(sess.run(next_element))
也可以為不同的數(shù)據(jù)集創(chuàng)建同一個(gè)迭代器,為了使得這個(gè)迭代器可以被重復(fù)使用,需要保證不同數(shù)據(jù)集的類型和維度是一致的。例如,下面的代碼演示了如何使用同一個(gè)迭代器來(lái)構(gòu)建訓(xùn)練集和驗(yàn)證集,可以看到,當(dāng)我們開(kāi)始訓(xùn)練訓(xùn)練集的時(shí)候,就需要先執(zhí)行training_init_op,目的是使得迭代器開(kāi)始加載訓(xùn)練數(shù)據(jù);而當(dāng)進(jìn)行驗(yàn)證的時(shí)候,則需要先執(zhí)行validation_init_op,道理一樣。
training_data = tf.data.Dataset.range(100).map(lambda x: x+tf.random_uniform([], -10, 10, tf.int64)) validation_data = tf.data.Dataset.range(50) iterator = tf.Iterator.from_structure(training_data.output_types, training_data.output_shapes) iterator = tf.data.Iterator.from_structure(training_data.output_types, training_data.output_shapes) next_element = iterator.get_next() training_init_op=iterator.make_initializer(training_data) validation_init_op=iterator.make_initializer(validation_data)with tf.Session() as sess: for epoch in range(10): sess.run(training_init_op) for _ in range(100): sess.run(next_element) sess.run(validation_init_op) for _ in range(50): sess.run(next_element)
也可以通過(guò)Tensor變量構(gòu)建tf.data.Dataset,如下代碼所示,需要注意的是,這里的Tensor的維度是4×10,因此,傳入到迭代器中就是可以運(yùn)行4次,每次運(yùn)行生成一個(gè)長(zhǎng)度為10的向量。
import tensorflow as tf dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10])) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next()with tf.Session() as sess: sess.run(iterator.initializer) for i in range(4): value = sess.run(next_element) print(value)
最后,還有一種比較常見(jiàn)的讀取數(shù)據(jù)的方式,就是從TFRecord文件中去讀取,這里再介紹一下之前在語(yǔ)音識(shí)別項(xiàng)目里采取的TFRecord的讀寫(xiě)代碼。
首先是將音頻特征寫(xiě)入到TFRecord文件之中,在語(yǔ)音識(shí)別中,我們最常用的兩個(gè)特征就是MFCC和LogFBank,要寫(xiě)入文件中的不僅僅是這兩個(gè)變量,還要有文本標(biāo)簽Label以及特征序列的長(zhǎng)度sequence_legnth,這四個(gè)變量中,只有sequence_length是整數(shù)標(biāo)量,其他三個(gè)都是列表格式,所以這里對(duì)于列表使用字節(jié)來(lái)保存,而對(duì)于標(biāo)量,使用整型來(lái)保存。
def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))class RecordWriter(object): def __init__(self): pass def write(self, content, tfrecords_filename): writer = tf.python_io.TFRecordWriter(tfrecords_filename) if isinstance(content, list): feature_dict = {} for i in range(len(content)): feature = content[i] if i==0: feature_raw = np.array(feature).tostring() feature_dict['mfccFeat']=_bytes_feature(feature_raw) elif i==1: feature_raw = np.array(feature).tostring() feature_dict['logfbankFeat']=_bytes_feature(feature_raw) elif i==2: feature_raw = np.array(feature).tostring() feature_dict['label']=_bytes_feature(feature_raw) else: feature_dict['sequence_length']=_int64_feature(feature) features_to_write = tf.train.Example(features=tf.train.Features(feature=feature_dict)) writer.write(features_to_write.SerializeToString()) writer.close() print('Record has been writen:'+tfrecords_filename)
寫(xiě)好TFRecord以后,在讀取的時(shí)候首先需要對(duì)TFRecord格式文件進(jìn)行解析,解析函數(shù)如下:
def parse(self, serialized): feature_dict={} feature_dict['mfccFeat']=tf.FixedLenFeature([], tf.string) feature_dict['logfbankFeat']=tf.FixedLenFeature([], tf.string) feature_dict['label']=tf.FixedLenFeature([], tf.string) feature_dict['sequence_length']=tf.FixedLenFeature([1], tf.int64) features = tf.parse_single_example( serialized, features=feature_dict) mfcc = tf.reshape(tf.decode_raw(features['mfccFeat'], tf.float32), [-1, self.feature_num]) logfbank = tf.reshape(tf.decode_raw(features['logfbankFeat'], tf.float32), [-1, self.feature_num]) label = tf.decode_raw(features['label'], tf.int64) return mfcc, logfbank, label, features['sequence_length']
然后我們可以直接通過(guò)調(diào)用tf.data.TFRecordDataset來(lái)導(dǎo)入TFRecord文件列表,以及對(duì)每個(gè)文件調(diào)用parse函數(shù)進(jìn)行解析,并且由于每個(gè)文件的特征矩陣長(zhǎng)度不一,所以需要對(duì)齊進(jìn)行padding操作,最終可以獲得迭代器,代碼如下:
self.fileNameList = tf.placeholder(tf.string, [None, ]) padded_shapes= ([-1,feature_num],[-1,feature_num],[-1],[1]) padded_values = (0.0,0.0,np.int64(-1),np.int64(0)) dataset = tf.data.TFRecordDataset(self.fileNameList, buffer_size=self.buffer_size).map(self.parse, num_parallel_call).padded_batch(batch_size, padded_shapes, padded_values) self.iterator = tf.data.Iterator.from_structure((tf.float32, tf.float32, tf.int64, tf.int64), (tf.TensorShape([None, None, 60]), tf.TensorShape([None, None, 60]), tf.TensorShape([None, None]), tf.TensorShape([None, None]))) self.initializer = self.iterator.make_initializer(dataset)
于是,關(guān)于TFRecord文件的讀寫(xiě)就介紹完了,并且,基于TensorFlow的數(shù)據(jù)導(dǎo)入機(jī)制也介紹完了。
-
數(shù)據(jù)集
+關(guān)注
關(guān)注
4文章
1224瀏覽量
25461 -
tensorflow
+關(guān)注
關(guān)注
13文章
330瀏覽量
61187
原文標(biāo)題:聊一聊TensorFlow的數(shù)據(jù)導(dǎo)入機(jī)制
文章出處:【微信號(hào):DeepLearningDigest,微信公眾號(hào):深度學(xué)習(xí)每日摘要】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
關(guān)于 TensorFlow
使用 TensorFlow, 你必須明白 TensorFlow
TensorFlow運(yùn)行時(shí)無(wú)法加載本機(jī)
導(dǎo)入tensorflow時(shí)未找到“GLIBC_2.23”錯(cuò)誤
情地使用Tensorflow吧!
TensorFlow是什么
TensorFlow教程|常見(jiàn)問(wèn)題
TensorFlow csv文件讀取數(shù)據(jù)(代碼實(shí)現(xiàn))詳解
TensorFlow實(shí)現(xiàn)簡(jiǎn)單線性回歸
TensorFlow實(shí)現(xiàn)多元線性回歸(超詳細(xì))
TensorFlow邏輯回歸處理MNIST數(shù)據(jù)集
TensorFlow邏輯回歸處理MNIST數(shù)據(jù)集
圖文詳解tensorflow數(shù)據(jù)讀取機(jī)制
TensorFlow數(shù)據(jù)讀取機(jī)制分析

評(píng)論