微軟研究院最近提出了一個(gè)新的 LLM 自回歸基礎(chǔ)架構(gòu) Retentive Networks (RetNet)[1,4],該架構(gòu)相對(duì)于 Transformer 架構(gòu)的優(yōu)勢(shì)是同時(shí)具備:訓(xùn)練可并行、推理成本低和良好的性能,不可能三角。
論文中給出一個(gè)很形象的示意圖,RetNet 在正中間表示同時(shí)具備三個(gè)優(yōu)點(diǎn),而其他的架構(gòu) Linear Transformer、Recurrent Network 和 Transformer 都只能同時(shí)具備其中兩個(gè)有點(diǎn)。
接下來(lái)看一下論文給出的 RetNet 和 Transformer 的對(duì)比實(shí)驗(yàn)結(jié)果:
當(dāng)輸入序列長(zhǎng)度增加的時(shí)候,RetNet 的 GPU 顯存占用一直是穩(wěn)定的和權(quán)值差不多,而 Transformer 則是和輸入長(zhǎng)度成正比。
首先看紅色線和紫色線,都是輸入長(zhǎng)度在 8192 下,RetNet 和 Transformer 推理延時(shí)的對(duì)比。
可以看到當(dāng) batch size 增加的時(shí)候, RetNet 的推理延時(shí)也還是很穩(wěn)定,而 Transformer 的推理延時(shí)則是和 batch size 成正比。
而 Transformer 即使是輸入長(zhǎng)度縮小到 1024 ,推理延時(shí)也還是比 RetNet 要高。
RetNet 架構(gòu)解讀
RetNet 架構(gòu)和 Transformer 類似,也是堆疊 層同樣的模塊,每個(gè)模塊內(nèi)部包含兩個(gè)子模塊:一個(gè) multi-scale retention(MSR)和一個(gè) feed-forward network (FFN)。
下面詳細(xì)解讀一下這個(gè) retention 子模塊。
首先給定一個(gè)輸入序列 :
其中 表示序列的長(zhǎng)度。然后輸入序列首先經(jīng)過(guò) embedding 層得到詞嵌入向量:
其中 表示隱含層的維度。
Retention 機(jī)制
首先對(duì)給定輸入詞嵌入向量序列 中的每個(gè)時(shí)間步 的向量 都乘以權(quán)值 得到 :
然后同樣有類似 Transformer 架構(gòu)的 Q 和 K 的投影:
其中 是需要學(xué)習(xí)的權(quán)值。
接著假設(shè)現(xiàn)在有一個(gè)序列建模的問(wèn)題,通過(guò)狀態(tài) 將 映射為 向量。首先來(lái)看論文中給出的映射方式定義:
其中 是一個(gè)矩陣, 表示時(shí)間步 對(duì)應(yīng)的 投影則 。同樣 表示時(shí)間步 對(duì)應(yīng)的 投影。
那么上面公式中的 計(jì)算公式是怎么得出來(lái)呢,下面詳細(xì)解釋一下,首先將 展開(kāi):
其中 表示單位矩陣(主對(duì)角線元素為1,其余元素為0的方陣)。然后我們假定 為初始狀態(tài)元素為全0的矩陣,則有:
再繼續(xù)上述推導(dǎo)過(guò)程:
所以根據(jù)上述推導(dǎo)過(guò)程和條件歸納可得:
然后我們來(lái)看一下 矩陣是什么,論文中定義了 是一個(gè)可對(duì)角化的矩陣,具體定義為:
其中 都是 維的向量, 是一個(gè)可逆矩陣,而要理解 首先得復(fù)習(xí)一下歐拉公式 [2]:
其中 表示任意實(shí)數(shù), 是自然對(duì)數(shù)的底數(shù), 是復(fù)數(shù)中的虛數(shù)單位,也可以表示為實(shí)部 ,虛部 的一個(gè)復(fù)數(shù),歐拉公式[2]建立了指數(shù)函數(shù)、三角函數(shù)和復(fù)數(shù)之間的橋梁。
而這里 是一個(gè) 維向量:
則 也就是將向量元素兩兩一組表示分別表示為復(fù)數(shù)的實(shí)部和虛部:
然后 就是一個(gè)對(duì)角矩陣,對(duì)角元素的值就對(duì)應(yīng)將 和 轉(zhuǎn)成復(fù)數(shù)向量相乘再將結(jié)果轉(zhuǎn)回實(shí)數(shù)向量的結(jié)果。
關(guān)于復(fù)數(shù)向量相乘可以參考文章:?
一文看懂 LLaMA 中的旋轉(zhuǎn)式位置編碼(Rotary Position Embedding)
現(xiàn)在我們知道了矩陣 的構(gòu)成就能得到:
這里因?yàn)?是可逆矩陣則有性質(zhì)
其中 為單位矩陣,則將 次方展開(kāi):
就是 個(gè) 矩陣相乘,中間相鄰的 都消掉了,所以可得:
然后我們回到計(jì)算 的公式:
接著論文中提出把 吸收進(jìn) 和 也就是 和 分別用 和 替代當(dāng)作學(xué)習(xí)的權(quán)值,那么可得:
接著將公式簡(jiǎn)化,將 改為一個(gè)實(shí)數(shù)常量,那么可得:
在繼續(xù)推導(dǎo)前,先來(lái)仔細(xì)看一下 ,借助歐拉公式展開(kāi):
然后復(fù)習(xí)一下三角函數(shù)的性質(zhì)[3]:
則有:
轉(zhuǎn)為復(fù)數(shù)形式表示就是:
剛好就對(duì)應(yīng) 的共軛
所以可得:
其中 表示共軛轉(zhuǎn)置操作。
Retention 的訓(xùn)練并行表示
首先回顧單個(gè)時(shí)間步 的輸出 的計(jì)算公式如下:
而所有時(shí)間步的輸出是可以并行計(jì)算的,用矩陣形式表達(dá)如下:
其中 ,而 表示兩個(gè)矩陣逐元素相乘, 和 每一行對(duì)應(yīng)一個(gè)時(shí)間步的 q 和 k 向量。
而 每一行對(duì)應(yīng)向量 。 就是對(duì)應(yīng) 矩陣的共軛,也就是將 矩陣每一行改為復(fù)數(shù)的共軛形式。
而 矩陣是一個(gè)下三角矩陣,其中第 行第 列的元素計(jì)算方式:
Retention 的推理循環(huán)表示
推理階段的循環(huán)表示論文中定義如下:
怎么理解呢,還是先回顧單個(gè)時(shí)間步 的輸出 的計(jì)算公式:
上述公式最后一步和推理階段循環(huán)表示公式中各個(gè)元素的對(duì)應(yīng)關(guān)系是:
對(duì)應(yīng)論文中的圖示:
圖中的 表示 GroupNorm。
可以看到在推理階段,RetNet 在計(jì)算當(dāng)前時(shí)間步 的輸出 只依賴于上一個(gè)時(shí)間步產(chǎn)出的狀態(tài)矩陣 。
其實(shí)就是把計(jì)算順序改了一下,先計(jì)算的 和 的相乘然后一直累加到狀態(tài)矩陣 上,最后再和 相乘。
而不是像 Transformer 架構(gòu)那樣,每個(gè)時(shí)間步的計(jì)算要先算 和前面所有時(shí)間步的 相乘得到 attention 權(quán)值再和 相乘求和,這樣就需要一直保留歷史的 和 。
Gated Multi-Scale Retention
然后 RetNet 每一層中的 Retention 子模塊其實(shí)也是分了 個(gè)頭,每個(gè)頭用不同的 參數(shù),同時(shí)每個(gè)頭都采用不同的 常量,這也是 ?Multi-Scale Retention 名稱的來(lái)由。
則對(duì)輸入 , MSR 層的輸出是:
其中, , 是激活函數(shù)用來(lái)生成門(mén)控閾值,還有由于每個(gè)頭均采用不同的 ,所以每個(gè)頭的輸出要單獨(dú)做 normalize 之后再 concat。
編輯:黃飛
?
評(píng)論