上一期文章帶大家認(rèn)識了一個(gè)名為GRU的新朋友, GRU本身自帶處理時(shí)序數(shù)據(jù)的屬性,特別擅長對于時(shí)間序列的識別和檢測(例如音頻、傳感器信號等)。GRU其實(shí)是RNN模型的一個(gè)衍生形式,巧妙地設(shè)計(jì)了兩個(gè)門控單元:reset門和更新門。reset門負(fù)責(zé)針對歷史遺留的狀態(tài)進(jìn)行重置,丟棄掉無用信息;更新門負(fù)責(zé)對歷史狀態(tài)進(jìn)行更新,將新的輸入與歷史數(shù)據(jù)集進(jìn)行整合。通過模型訓(xùn)練,讓模型能夠自動調(diào)整這兩個(gè)門控單元的狀態(tài),以期達(dá)到歷史數(shù)據(jù)與最新數(shù)據(jù)和諧共存的目的。
理論知識掌握了,下面就來看看如何訓(xùn)練一個(gè)GRU模型吧。
訓(xùn)練平臺選用Keras,請?zhí)崆白孕邪惭bKeras開發(fā)工具。直接上代碼,首先是數(shù)據(jù)導(dǎo)入部分,我們直接使用mnist手寫字體數(shù)據(jù)集:
import numpy as np import tensorflow as tf from tensorflow.keras.models import Sequential from tensorflow.keras.layers import GRU, Dense from tensorflow.keras.datasets import mnist from tensorflow.keras.utils import to_categorical from tensorflow.keras.models import load_model # 準(zhǔn)備數(shù)據(jù)集 (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train = x_train.astype('float32') / 255.0 x_test = x_test.astype('float32') / 255.0 y_train = to_categorical(y_train, 10) y_test = to_categorical(y_test, 10)
模型構(gòu)建與訓(xùn)練:
# 構(gòu)建GRU模型 model = Sequential() model.add(GRU(128, input_shape=(28, 28), stateful=False, unroll=False)) model.add(Dense(10, activation='softmax')) # 編譯模型 model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) # 模型訓(xùn)練 model.fit(x_train, y_train, batch_size=128, epochs=10, validation_data=(x_test, y_test))
這里,眼尖的伙伴應(yīng)該是注意到了,GRU模型構(gòu)建的時(shí)候,有兩個(gè)參數(shù),分別是stateful以及unroll,這兩個(gè)參數(shù)是什么意思呢?
GRU層的stateful和unroll是兩個(gè)重要的參數(shù),它們對GRU模型的行為和性能有著重要影響:
stateful參數(shù):默認(rèn)情況下,stateful參數(shù)為False。當(dāng)stateful設(shè)置為True時(shí),表示在處理連續(xù)的數(shù)據(jù)時(shí),GRU層的狀態(tài)會被保留并傳遞到下一個(gè)時(shí)間步,而不是每個(gè)batch都重置狀態(tài)。這對于處理時(shí)間序列數(shù)據(jù)時(shí)非常有用,例如在處理長序列時(shí),可以保持模型的狀態(tài)信息,而不是在每個(gè)batch之間重置。需要注意的是,在使用stateful時(shí),您需要手動管理狀態(tài)的重置。
unroll參數(shù):默認(rèn)情況下,unroll參數(shù)為False。當(dāng)unroll設(shè)置為True時(shí),表示在計(jì)算時(shí)會展開RNN的循環(huán),這樣可以提高計(jì)算性能,但會增加內(nèi)存消耗。通常情況下,對于較短的序列,unroll設(shè)置為True可以提高計(jì)算速度,但對于較長的序列,可能會導(dǎo)致內(nèi)存消耗過大。
通過合理設(shè)置stateful和unroll參數(shù),可以根據(jù)具體的數(shù)據(jù)和模型需求來平衡模型的狀態(tài)管理和計(jì)算性能。而我們這里用到的mnist數(shù)據(jù)集實(shí)際上并不是時(shí)間序列數(shù)據(jù),而只是將其當(dāng)作一個(gè)時(shí)序數(shù)據(jù)集來用。因此,每個(gè)batch之間實(shí)際上是沒有顯示的前后關(guān)系的,不建議使用stateful。而是每一個(gè)batch之后都要將其狀態(tài)清零。即stateful=False。而unroll參數(shù),大家就可以自行測試了。
模型評估與轉(zhuǎn)換:
# 模型評估 score = model.evaluate(x_test, y_test, verbose=0) print('Test loss:', score[0]) print('Test accuracy:', score[1]) # 保存模型 model.save("mnist_gru_model.h5") # 加載模型并轉(zhuǎn)換 converter = tf.lite.TFLiteConverter.from_keras_model(load_model("mnist_gru_model.h5")) tflite_model = converter.convert() # 保存tflite格式模型 with open('mnist_gru_model.tflite', 'wb') as f: f.write(tflite_model)
便寫好程序后,運(yùn)行等待訓(xùn)練完畢,可以看到經(jīng)過10個(gè)epoch之后,模型即達(dá)到了98.57%的測試精度:
來看看最終的模型樣子,參數(shù)stateful=False,unroll=True:
這里,我們就會發(fā)現(xiàn),模型的輸入好像被拆分成了很多份,這是因?yàn)槲覀冎付溯斎胧?8*28。第一個(gè)28表示有28個(gè)時(shí)間步,后面的28則表示每一個(gè)時(shí)間步的維度。這里的時(shí)間步,指代的就是歷史的數(shù)據(jù)。
現(xiàn)在,GRU模型訓(xùn)練就全部介紹完畢了,對于機(jī)器學(xué)習(xí)和深度學(xué)習(xí)感興趣的伙伴們,不妨親自動手嘗試一下,搭建并訓(xùn)練一個(gè)屬于自己的GRU模型吧!
希望每一位探索者都能在機(jī)器學(xué)習(xí)的道路上不斷前行,收獲滿滿的知識和成果!
-
Gru
+關(guān)注
關(guān)注
0文章
12瀏覽量
7580 -
機(jī)器學(xué)習(xí)
+關(guān)注
關(guān)注
66文章
8481瀏覽量
133858 -
rnn
+關(guān)注
關(guān)注
0文章
89瀏覽量
7035
原文標(biāo)題:GRU模型實(shí)戰(zhàn)訓(xùn)練,智能決策更精準(zhǔn)!
文章出處:【微信號:NXP_SMART_HARDWARE,微信公眾號:恩智浦MCU加油站】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
訓(xùn)練好的ai模型導(dǎo)入cubemx不成功怎么處理?
AI賦能邊緣網(wǎng)關(guān):開啟智能時(shí)代的新藍(lán)海
【「大模型啟示錄」閱讀體驗(yàn)】如何在客服領(lǐng)域應(yīng)用大模型
什么是大模型、大模型是怎么訓(xùn)練出來的及大模型作用

谷東科技民航維修智能決策大模型榮獲華為昇騰技術(shù)認(rèn)證
大語言模型的預(yù)訓(xùn)練
人臉識別模型訓(xùn)練流程
人臉識別模型訓(xùn)練是什么意思
預(yù)訓(xùn)練模型的基本原理和應(yīng)用
深度學(xué)習(xí)模型訓(xùn)練過程詳解
深入GRU:解鎖模型測試新維度

GRU是什么?GRU模型如何讓你的神經(jīng)網(wǎng)絡(luò)更聰明 掌握時(shí)間 掌握未來

評論