wenet概述#
wenet處理流程#
wav語音經(jīng)過一系列前處理之后送入encoder,encoder的輸出會給到ctc decoder和attention decoder。其中ctc decoder是深度優(yōu)先的搜索程序,負(fù)責(zé)搜索出n段候選預(yù)測序列,然后ctc的結(jié)構(gòu)與encoder的結(jié)果一起送入attention decoder進(jìn)行預(yù)測序列的打分,選出最好的預(yù)測序列,輸出結(jié)果
模型結(jié)構(gòu)#
encoder#
encoder是12個conformer模塊堆疊
decoder#
decoder基本結(jié)構(gòu)是transformer,先經(jīng)過字的自注意力再和mel特征進(jìn)行交叉注意力,wenet的decoder是雙向decoder,有一個預(yù)測序列從左到右的打分,還一個從右到左的打分,每個decoder堆疊3個transformer(取決于訓(xùn)練時的配置文件,也可以是6個)
模型優(yōu)化#
圖優(yōu)化#
結(jié)合netron觀察模型圖結(jié)構(gòu),考慮以下幾個優(yōu)化方向
刪除冗余算子
算子融合
改變算子執(zhí)行順序
Where(MaskedFill)的優(yōu)化#
這段網(wǎng)絡(luò)是在做掩碼操作,即輸入一個掩碼 Tensor對數(shù)據(jù)做Mask,第一個Where把不需要的數(shù)據(jù)設(shè)置為-inf。經(jīng)過Softmax之后這些數(shù)據(jù)已經(jīng)變成了0,但是后面又增加了一個Where,把相同位置再次設(shè)置為了0。
這段結(jié)構(gòu)在網(wǎng)絡(luò)中出現(xiàn)了12次,單Where算子耗時30 ms相當(dāng)于多了一倍的計(jì)算時間,可以在編譯階段使用圖優(yōu)化進(jìn)行消除,減少模型計(jì)算量。
耗時從30ms降低到15ms
MatMul的優(yōu)化#
1.這段結(jié)構(gòu)在Decoder中出現(xiàn)6次,且屬于計(jì)算集中的attention部分。但是C維度在transpose之后只有8,我們的TPU有64個lane,算力沒有完全利用起來
2.有transpose隔斷了layer group,增加了數(shù)據(jù)搬運(yùn)
優(yōu)化:
1.可以使用hdim_is_batch優(yōu)化,把a(bǔ)ttention的head放在h維。為了保證網(wǎng)絡(luò)變換前后等效,需要在matmul后面新生成一個transpose。
2.生成新的transpose之后,實(shí)現(xiàn)masked fill算子、softmax算子的transpose move down的優(yōu)化pattern,使得tranpose的執(zhí)行順序可以放到該段結(jié)構(gòu)結(jié)束處,同時與結(jié)束處原本有的tranpose做抵消,達(dá)到減少數(shù)據(jù)搬運(yùn)的目的。
3.由于消除了transpose,使得這段網(wǎng)絡(luò)可以做到local layer,同時因?yàn)榘?49放到c維度了,又可以充分利用64個lane的計(jì)算資源了
其余的算子經(jīng)過transpose move down,可以實(shí)現(xiàn)transpose的一路下移,在局部網(wǎng)絡(luò)中讓C可以保持349,使得64個lane可以獲得更充分的利用
算子層面優(yōu)化#
Where(MaskedFill)的優(yōu)化#
MaskedFill如果全走Global耗時30ms,即便減少一半的算子數(shù)量還是15ms。而Select算子有l(wèi)ocal的實(shí)現(xiàn),同時可以通過參數(shù)配置完成MaskedFill的功能,但不支持廣播。所以在編譯階段加入Tile完成廣播,從而支持Local Layer。
但引入了Tile,Tile操作本身耗時3.8ms,代價可接受,后續(xù)可以進(jìn)一步優(yōu)化
MaskedFill算子從30ms 減半數(shù)量后到15ms,引入tile之后減少到3.8ms(Tile)+127us(MaskedFill)
后續(xù)考慮使用bdc完成tile操作,完成進(jìn)一步優(yōu)化
CPU Layer的優(yōu)化#
兩個CPU Gather PT操作占用456ms,可以使用dma的Gather操作在TPU實(shí)現(xiàn)算子
Gather PT算子從456ms減少到68us
優(yōu)化結(jié)果#
WeNet Decoder | 耗時 |
---|---|
原始模型 | 611ms |
CPU Layer替換 | 156ms |
MaskedFill 減半 | 141ms |
MatMul hdim_is_batch優(yōu)化+Permute Move優(yōu)化+MaskedFill支持Local | 71ms |
-
cpu
+關(guān)注
關(guān)注
68文章
11074瀏覽量
216924 -
模型
+關(guān)注
關(guān)注
1文章
3516瀏覽量
50366 -
算子
+關(guān)注
關(guān)注
0文章
16瀏覽量
7347
發(fā)布評論請先 登錄






評論