FlashAttention新升級!斯坦福博士一人重寫算法,第二代實(shí)現(xiàn)了最高9倍速提升。
繼超快且省內(nèi)存的注意力算法FlashAttention爆火后,升級版的2代來了。FlashAttention-2是一種從頭編寫的算法,可以加快注意力并減少其內(nèi)存占用,且沒有任何近似值。比起第一代,F(xiàn)lashAttention-2速度提升了2倍。甚至,相較于PyTorch的標(biāo)準(zhǔn)注意力,其運(yùn)行速度最高可達(dá)9倍。
一年前,StanfordAILab博士Tri Dao發(fā)布了FlashAttention,讓注意力快了2到4倍,如今,F(xiàn)lashAttention已經(jīng)被許多企業(yè)和研究室采用,廣泛應(yīng)用于大多數(shù)LLM庫。如今,隨著長文檔查詢、編寫故事等新用例的需要,大語言模型的上下文以前比過去變長了許多——GPT-4的上下文長度是32k,MosaicML的MPT上下文長度是65k,Anthropic的Claude上下文長度是100k。但是,擴(kuò)大Transformer的上下文長度是一項(xiàng)極大的挑戰(zhàn),因?yàn)樽鳛槠浜诵牡淖⒁饬拥倪\(yùn)行時(shí)間和內(nèi)存要求,是輸入序列長度的二次方。Tri Dao一直在研究FlashAttention-2,它比v1快2倍,比標(biāo)準(zhǔn)的注意力快5到9倍,在A100上已經(jīng)達(dá)到了225 TFLOP/s的訓(xùn)練速度!
項(xiàng)目鏈接:
https://github.com/Dao-AILab/flash-attention

FlashAttention-2:更好的算法、并行性和工作分區(qū)
端到端訓(xùn)練GPT模型,速度高達(dá)225 TFLOP/s
雖說FlashAttention在發(fā)布時(shí)就已經(jīng)比優(yōu)化的基線快了2-4倍,但還是有相當(dāng)大的進(jìn)步空間。比方說,F(xiàn)lashAttention仍然不如優(yōu)化矩陣乘法(GEMM)運(yùn)算快,僅能達(dá)到理論最大FLOPs/s的25-40%(例如,在A100 GPU上的速度可達(dá)124 TFLOPs/s)。
對注意力計(jì)算重新排序
我們知道,F(xiàn)lashAttention是一種對注意力計(jì)算進(jìn)行重新排序的算法,利用平鋪、重新計(jì)算來顯著加快計(jì)算速度,并將序列長度的內(nèi)存使用量從二次減少到線性。

然而,F(xiàn)lashAttention仍然存在一些低效率的問題,這是由于不同線程塊之間的工作劃分并不理想,以及GPU上的warp——導(dǎo)致低占用率或不必要的共享內(nèi)存讀寫。
更少的non-matmulFLOP(非矩陣乘法浮點(diǎn)計(jì)算數(shù))
研究人員通過調(diào)整FlashAttention的算法來減少non-matmul FLOP的次數(shù)。這非常重要,因?yàn)楝F(xiàn)代GPU有專門的計(jì)算單元(比如英偉達(dá)GPU上的張量核心),這就使得matmul的速度更快。例如,A100 GPU FP16/BF16 matmul的最大理論吞吐量為312 TFLOPs/s,但non-matmul FP32的理論吞吐量僅為 19.5 TFLOPs/s。另外,每個(gè)非matmul FLOP比matmul FLOP要貴16倍。所以為了保持高吞吐量,研究人員希望在matmul FLOP上花盡可能多的時(shí)間。研究人員還重新編寫了FlashAttention中使用的在線softmax技巧,以減少重新縮放操作的數(shù)量,以及邊界檢查和因果掩碼操作,而無需更改輸出。更好的并行性
FlashAttention v1在批大小和部數(shù)量上進(jìn)行并行化處理。研究人員使用1個(gè)線程塊來處理一個(gè)注意力頭,共有 (batch_size * head number) 個(gè)線程塊。
每個(gè)線程塊都在流式多處理器 (SM)運(yùn)行,例如,A100 GPU上有108個(gè)這樣的處理器。當(dāng)這個(gè)數(shù)字很大(比如 ≥80)時(shí),這種調(diào)度是有效的,因?yàn)樵谶@種情況下,可以有效地使用GPU上幾乎所有的計(jì)算資源。在長序列的情況下(通常意味著更小批或更少的頭),為了更好地利用GPU上的多處理器,研究人員在序列長度的維度上另外進(jìn)行了并行化,使得該機(jī)制獲得了顯著加速。
更好的工作分區(qū)
即使在每個(gè)線程塊內(nèi),研究人員也必須決定如何在不同的warp(線程束)之間劃分工作(一組32個(gè)線程一起工作)。研究人員通常在每個(gè)線程塊使用4或8個(gè)warp,分區(qū)方案如下圖所示。研究人員在FlashAttention-2中改進(jìn)了這種分區(qū),減少了不同warp之間的同步和通信量,從而減少共享內(nèi)存讀/寫。

FlashAttention僅支持最大128的頭的維度,雖說適用于大多數(shù)模型,但還是有一些模型被排除在外。FlashAttention-2現(xiàn)在支持256的頭的維度,這意味著GPT-J、CodeGen、CodeGen2以及Stable Diffusion 1.x等模型都可以使用FlashAttention-2來獲得加速和節(jié)省內(nèi)存。v2還支持多查詢注意力(MQA)以及分組查詢注意力(GQA)。
▲GQA為每組查詢頭共享單個(gè)key和value的頭,在多頭和多查詢注意之間進(jìn)行插值
這些都是注意力的變體,其中多個(gè)查詢頭會(huì)指向key和value的同一個(gè)頭,以減少推理過程中KV緩存的大小,并可以顯著提高推理的吞吐量。

注意力基準(zhǔn)
研究人員人員在A100 80GB SXM4 GPU 上測量不同設(shè)置(有無因果掩碼、頭的維度是64或128)下不同注意力方法的運(yùn)行時(shí)間。

▲A100 GPU上的前向+后向速度
只需在H100 GPU上運(yùn)行相同的實(shí)現(xiàn)(不需要使用特殊指令來利用TMA和第四代Tensor Core等新硬件功能),研究人員就可以獲得高達(dá)335 TFLOPs/s的速度。
▲H100 GPU上的前向+后向速度
當(dāng)用于端到端訓(xùn)練GPT類模型時(shí),F(xiàn)lashAttention-2能在A100 GPU上實(shí)現(xiàn)高達(dá)225TFLOPs/s的速度(模型FLOPs利用率為72%)。與已經(jīng)非常優(yōu)化的FlashAttention模型相比,端到端的加速進(jìn)一步提高了1.3倍。


未來的工作
速度上快2倍,意味著研究人員可以用與之前訓(xùn)練8k上下文模型相同的成本,來訓(xùn)練16k上下文長度的模型。這些模型可以理解長篇書籍和報(bào)告、高分辨率圖像、音頻和視頻。同時(shí),F(xiàn)lashAttention-2還將加速現(xiàn)有模型的訓(xùn)練、微調(diào)和推理。在不久的將來,研究人員還計(jì)劃擴(kuò)大合作,使FlashAttention廣泛適用于不同類型的設(shè)備(例如H100 GPU、AMD GPU)以及新的數(shù)據(jù)類型(例如fp8)。下一步,研究人員計(jì)劃針對H100 GPU進(jìn)一步優(yōu)化FlashAttention-2,以使用新的硬件功能(TMA、第四代Tensor Core、fp8等等)。將FlashAttention-2中的低級優(yōu)化與高級算法更改(例如局部、擴(kuò)張、塊稀疏注意力)相結(jié)合,可以讓研究人員用更長的上下文來訓(xùn)練AI模型。研究人員也很高興與編譯器研究人員合作,使這些優(yōu)化技術(shù)更好地應(yīng)用于編程。

Tri Dao曾在斯坦福大學(xué)獲得了計(jì)算機(jī)博士學(xué)位,導(dǎo)師是Christopher Ré和Stefano Ermon。根據(jù)主頁介紹,他將從2024年9月開始,任職普林斯頓大學(xué)計(jì)算機(jī)科學(xué)助理教授。
Tri Dao的研究興趣在于機(jī)器學(xué)習(xí)和系統(tǒng),重點(diǎn)關(guān)注高效訓(xùn)練和長期環(huán)境:- 高效Transformer訓(xùn)練和推理 - 遠(yuǎn)程記憶的序列模型 - 緊湊型深度學(xué)習(xí)模型的結(jié)構(gòu)化稀疏性。
值得一提的是,Tri Dao今天正式成為生成式AI初創(chuàng)公司Together AI的首席科學(xué)家。
原文標(biāo)題:讓Attention提速9倍!FlashAttention燃爆顯存,Transformer上下文長度史詩級提升
文章出處:【微信公眾號:智能感知與物聯(lián)網(wǎng)技術(shù)研究所】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
-
物聯(lián)網(wǎng)
+關(guān)注
關(guān)注
2923文章
45693瀏覽量
385759
原文標(biāo)題:讓Attention提速9倍!FlashAttention燃爆顯存,Transformer上下文長度史詩級提升
文章出處:【微信號:tyutcsplab,微信公眾號:智能感知與物聯(lián)網(wǎng)技術(shù)研究所】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
S32K在AUTOSAR中使用CAT1 ISR,是否需要執(zhí)行上下文切換?
摩爾線程Round Attention優(yōu)化AI對話

DeepSeek推出NSA機(jī)制,加速長上下文訓(xùn)練與推理
《具身智能機(jī)器人系統(tǒng)》第7-9章閱讀心得之具身智能機(jī)器人與大模型
阿里通義千問發(fā)布Qwen2.5-Turbo開源AI模型
SystemView上下文統(tǒng)計(jì)窗口識別阻塞原因
鴻蒙Ability Kit(程序框架服務(wù))【應(yīng)用上下文Context】

編寫一個(gè)任務(wù)調(diào)度程序,在上下文切換后遇到了一些問題求解
鴻蒙開發(fā)接口Ability框架:【ServiceExtensionContext】

MiniMax推出“海螺AI”,支持超長文本處理
鴻蒙開發(fā)接口Ability框架:【AbilityStageContext】

評論