屹立不倒的 Transformer 迎來了一個(gè)強(qiáng)勁競爭者。
在別的領(lǐng)域,如果你想形容一個(gè)東西非常重要,你可能將其形容為「撐起了某領(lǐng)域的半壁江山」。但在 AI 大模型領(lǐng)域,Transformer 架構(gòu)不能這么形容,因?yàn)樗鼛缀鯎纹鹆恕刚麄€(gè)江山」。
自 2017 年被提出以來,Transformer 已經(jīng)成為 AI 大模型的主流架構(gòu),但隨著模型規(guī)模的擴(kuò)展和需要處理的序列不斷變長,Transformer 的局限性也逐漸凸顯。一個(gè)很明顯的缺陷是:Transformer 模型中自注意力機(jī)制的計(jì)算量會隨著上下文長度的增加呈平方級增長,比如上下文增加 32 倍時(shí),計(jì)算量可能會增長 1000 倍,計(jì)算效率非常低。
為了克服這些缺陷,研究者們開發(fā)出了很多注意力機(jī)制的高效變體,但這往往以犧牲其有效性特為代價(jià)。到目前為止,這些變體都還沒有被證明能在不同領(lǐng)域發(fā)揮有效作用。
最近,一項(xiàng)名為「Mamba」的研究似乎打破了這一局面。
在這篇論文中,研究者提出了一種新的架構(gòu) ——「選擇性狀態(tài)空間模型( selective state space model)」。它在多個(gè)方面改進(jìn)了先前的工作。
作者表示,「Mamba」在語言建模方面可以媲美甚至擊敗 Transformer。而且,它可以隨上下文長度的增加實(shí)現(xiàn)線性擴(kuò)展,其性能在實(shí)際數(shù)據(jù)中可提高到百萬 token 長度序列,并實(shí)現(xiàn) 5 倍的推理吞吐量提升。
消息一出,人們紛紛點(diǎn)贊,有人表示已經(jīng)迫不及待想要把它用在大模型上了。
作為通用序列模型的骨干,Mamba 在語言、音頻和基因組學(xué)等多種模態(tài)中都達(dá)到了 SOTA 性能。在語言建模方面,無論是預(yù)訓(xùn)練還是下游評估,他們的 Mamba-3B 模型都優(yōu)于同等規(guī)模的 Transformer 模型,并能與兩倍于其規(guī)模的 Transformer 模型相媲美。
這篇論文的作者只有兩位,一位是卡內(nèi)基梅隆大學(xué)機(jī)器學(xué)習(xí)系助理教授 Albert Gu,另一位是 Together.AI 首席科學(xué)家、普林斯頓大學(xué)計(jì)算機(jī)科學(xué)助理教授(即將上任)Tri Dao。
Albert Gu 表示,這項(xiàng)研究的一個(gè)重要?jiǎng)?chuàng)新是引入了一個(gè)名為「選擇性 SSM」的架構(gòu),該架構(gòu)是 Albert Gu 此前主導(dǎo)研發(fā)的 S4 架構(gòu)(Structured State Spaces for Sequence Modeling ,用于序列建模的結(jié)構(gòu)化狀態(tài)空間)的一個(gè)簡單泛化,可以有選擇地決定關(guān)注還是忽略傳入的輸入。一個(gè)「小小的改變」—— 讓某些參數(shù)成為輸入的函數(shù),結(jié)果卻非常有效。
值得一提的是,S4 是一個(gè)非常成功的架構(gòu)。此前,它成功地對? Long Range Arena (LRA) 中的長程依賴進(jìn)行了建模,并成為首個(gè)在 Path-X 上獲得高于平均性能的模型。更具體地說,S4 是一類用于深度學(xué)習(xí)的序列模型,與 RNN、CNN 和經(jīng)典的狀態(tài)空間模型(State Space Model,SSM)廣泛相關(guān)。SSM 是獨(dú)立的序列轉(zhuǎn)換,可被整合到端到端神經(jīng)網(wǎng)絡(luò)架構(gòu)中( SSM 架構(gòu)有時(shí)也稱 SSNN,它與 SSM 層的關(guān)系就像 CNN 與線性卷積層的關(guān)系一樣)。Mamba 論文也討論了一些著名的 SSM 架構(gòu),比如 Linear attention、H3、Hyena、RetNet、RWKV,其中許多也將作為論文研究的基線。Mamba 的成功讓 Albert Gu 對 SSM 的未來充滿了信心。
Tri Dao 則是 FlashAttention、Flash Attention v2、Flash-Decoding的作者。FlashAttention 是一種對注意力計(jì)算進(jìn)行重新排序并利用經(jīng)典技術(shù)(平鋪、重新計(jì)算)加快速度并將內(nèi)存使用從序列長度的二次減少到線性的算法。Flash Attention v2、Flash-Decoding 都是建立在 Flash Attention 基礎(chǔ)上的后續(xù)工作,把大模型的長文本推理效率不斷推向極限。在 Mamba 之前,Tri Dao 和 Albert Gu 也有過合作。
另外,這項(xiàng)研究的模型代碼和預(yù)訓(xùn)練的檢查點(diǎn)是開源的,參見以下鏈接:https://github.com/state-spaces/mamba.
論文鏈接:https://arxiv.org/ftp/arxiv/papers/2312/2312.00752.pdf
https://github.com/state-spaces/mamba
方法創(chuàng)新
論文第 3.1 節(jié)介紹了如何利用合成任務(wù)的直覺來啟發(fā)選擇機(jī)制,第 3.2 節(jié)解釋了如何將這一機(jī)制納入狀態(tài)空間模型。由此產(chǎn)生的時(shí)變 SSM 不能使用卷積,導(dǎo)致了高效計(jì)算的技術(shù)難題。研究者采用了一種硬件感知算法,利用當(dāng)前硬件的內(nèi)存層次結(jié)構(gòu)來克服這一難題(第 3.3 節(jié))。第 3.4 節(jié)描述了一個(gè)簡單的 SSM 架構(gòu),不需要注意力,甚至不需要 MLP 塊。第 3.5 節(jié)討論了選擇機(jī)制的一些其他特性。
選擇機(jī)制
研究者發(fā)現(xiàn)了此前模型的一個(gè)關(guān)鍵局限:以依賴輸入的方式高效選擇數(shù)據(jù)的能力(即關(guān)注或忽略特定輸入)。
序列建模的一個(gè)基本方法是將上下文壓縮到更小的狀態(tài),我們可以從這個(gè)角度來看待當(dāng)下流行的序列模型。例如,注意力既高效又低效,因?yàn)樗緵]有明確壓縮上下文。這一點(diǎn)可以從自回歸推理需要明確存儲整個(gè)上下文(即 KV 緩存)這一事實(shí)中看出,這直接導(dǎo)致了 Transformer 緩慢的線性時(shí)間推理和二次時(shí)間訓(xùn)練。
遞歸模型的效率很高,因?yàn)樗鼈兊臓顟B(tài)是有限的,這意味著恒定時(shí)間推理和線性時(shí)間訓(xùn)練。然而,它們的高效性受限于這種狀態(tài)對上下文的壓縮程度。
為了理解這一原理,下圖展示了兩個(gè)合成任務(wù)的運(yùn)行示例:
研究者設(shè)計(jì)了一種簡單的選擇機(jī)制,根據(jù)輸入對 SSM 參數(shù)進(jìn)行參數(shù)化。這樣,模型就能過濾掉無關(guān)信息,并無限期地記住相關(guān)信息。
將選擇機(jī)制納入模型的一種方法是讓影響序列交互的參數(shù)(如 RNN 的遞歸動力學(xué)或 CNN 的卷積核)與輸入相關(guān)。算法 1 和 2 展示了本文使用的主要選擇機(jī)制。其主要區(qū)別在于,該方法只需將幾個(gè)參數(shù) ?,B,C 設(shè)置為輸入函數(shù),并在整個(gè)過程中改變張量形狀。這些參數(shù)現(xiàn)在都有一個(gè)長度維度 L ,意味著模型已經(jīng)從時(shí)間不變變?yōu)闀r(shí)間可變。
硬件感知算法
上述變化對模型的計(jì)算提出了技術(shù)挑戰(zhàn)。所有先前的 SSM 模型都必須是時(shí)間和輸入不變的,這樣才能提高計(jì)算效率。為此,研究者采用了一種硬件感知算法,通過掃描而不是卷積來計(jì)算模型,但不會將擴(kuò)展?fàn)顟B(tài)具體化,以避免在 GPU 存儲器層次結(jié)構(gòu)的不同級別之間進(jìn)行 IO 訪問。由此產(chǎn)生的實(shí)現(xiàn)方法在理論上(與所有基于卷積的 SSM 的偽線性相比,在序列長度上呈線性縮放)和現(xiàn)有硬件上都比以前的方法更快(在 A100 GPU 上可快達(dá) 3 倍)。
架構(gòu)
研究者將先前的 SSM 架構(gòu)設(shè)計(jì)與 Transformer 的 MLP 塊合并為一個(gè)塊,從而簡化了深度序列模型架構(gòu),形成了一種包含選擇性狀態(tài)空間的簡單、同質(zhì)的架構(gòu)設(shè)計(jì)(Mamba)。
與結(jié)構(gòu)化 SSM 一樣,選擇性 SSM 也是一種獨(dú)立的序列變換,可以靈活地融入神經(jīng)網(wǎng)絡(luò)。H3 架構(gòu)是著名的同質(zhì)化架構(gòu)設(shè)計(jì)的基礎(chǔ),通常由線性注意力啟發(fā)的塊和 MLP(多層感知器)塊交錯(cuò)組成。
研究者簡化了這一架構(gòu),將這兩個(gè)部分合二為一,均勻堆疊,如圖 3。他們受到門控注意力單元(GAU)的啟發(fā),該單元也對注意力做了類似的處理。
選擇性 SSM 以及 Mamba 架構(gòu)的擴(kuò)展是完全遞歸模型,幾個(gè)關(guān)鍵特性使其適合作為在序列上運(yùn)行的通用基礎(chǔ)模型的骨干:
高質(zhì)量:選擇性為語言和基因組學(xué)等密集模型帶來了強(qiáng)大的性能。
快速訓(xùn)練和推理:在訓(xùn)練過程中,計(jì)算量和內(nèi)存與序列長度成線性關(guān)系,而在推理過程中,由于不需要緩存以前的元素,自回歸展開模型每一步只需要恒定的時(shí)間。
長上下文:質(zhì)量和效率共同提高了實(shí)際數(shù)據(jù)的性能,序列長度可達(dá) 100 萬。
實(shí)驗(yàn)評估
實(shí)證驗(yàn)證了 Mamba 作為通用序列基礎(chǔ)模型骨干的潛力,無論是在預(yù)訓(xùn)練質(zhì)量還是特定領(lǐng)域的任務(wù)性能方面,Mamba 都能在多種類型的模態(tài)和環(huán)境中發(fā)揮作用:
合成任務(wù)。在復(fù)制和感應(yīng)頭等重要的語言模型合成任務(wù)上,Mamba 不僅能輕松解決,而且能推斷出無限長的解決方案(>100 萬 token)。
音頻和基因組學(xué)。在音頻波形和 DNA 序列建模方面,Mamba 在預(yù)訓(xùn)練質(zhì)量和下游指標(biāo)方面都優(yōu)于 SaShiMi、Hyena、Transformer 等先前的 SOTA 模型(例如,在具有挑戰(zhàn)性的語音生成數(shù)據(jù)集上將 FID 降低了一半以上)。在這兩種情況下,它的性能隨著上下文長度的增加而提高,最高可達(dá)百萬長度的序列。
語言建模。Mamba 是首個(gè)線性時(shí)間序列模型,在預(yù)訓(xùn)練復(fù)雜度和下游評估方面都真正達(dá)到了 Transformer 質(zhì)量的性能。通過多達(dá) 1B 參數(shù)的縮放規(guī)律,研究者發(fā)現(xiàn) Mamba 的性能超過了大量基線模型,包括 LLaMa 這種非常強(qiáng)大的現(xiàn)代 Transformer 訓(xùn)練配方。
與類似規(guī)模的 Transformer 相比,Mamba 具有 5 倍的生成吞吐量,而且 Mamba-3B 的質(zhì)量與兩倍于其規(guī)模的 Transformer 相當(dāng)(例如,與 Pythia-3B 相比,常識推理的平均值高出 4 分,甚至超過 Pythia-7B)。
審核編輯:黃飛
評論