1. 前言
最近,OpenAI推出的ChatGPT展現(xiàn)出了卓越的性能,引發(fā)了大規(guī)模語言模型(Large Language Model,LLM)的研究熱潮。大規(guī)模語言模型的“大”體現(xiàn)在兩個方面:模型參數(shù)規(guī)模大,訓(xùn)練數(shù)據(jù)規(guī)模大。以GPT3為例,GPT3的參數(shù)量為1750億,訓(xùn)練數(shù)據(jù)量達到了570GB。進而,訓(xùn)練大規(guī)模語言模型面臨兩個主要挑戰(zhàn):顯存效率和計算效率。
現(xiàn)在業(yè)界的大語言模型都是基于transformer模型的,模型結(jié)構(gòu)主要有兩大類:encoder-decoder(代表模型是T5)和decoder-only,具體的,decoder-only結(jié)構(gòu)又可以分為Causal LM(代表模型是GPT系列)和PrefixLM(代表模型是GLM)。歸因于GPT系列取得的巨大成功,大多數(shù)的主流大語言模型都采用Causal LM結(jié)構(gòu)。因此,針對decoder-only框架,為了更好地理解訓(xùn)練訓(xùn)練大語言模型的顯存效率和計算效率,本文分析采用decoder-only框架transformer模型的模型參數(shù)量、計算量、中間激活值、KV cache。
為了方便分析,先定義好一些數(shù)學(xué)符號。記transformer模型的層數(shù)為?,隱藏層維度為
?,注意力頭數(shù)為
?。詞表大小為
?,訓(xùn)練數(shù)據(jù)的批次大小為
?,序列長度為
?。
2. 模型參數(shù)量
transformer模型由個相同的層組成,每個層分為兩部分:self-attention塊和MLP塊。
self-attention塊的模型參數(shù)有?的權(quán)重矩陣
和偏置,輸出權(quán)重矩陣?
?和偏置,4個權(quán)重矩陣的形狀為
?,4個偏置的形狀為
?。self- attention塊的參數(shù)量為
?。
MLP塊由2個線性層組成,一般地,第一個線性層是先將維度從?映射到
,第二個線性層再將維度從
映射到
。第一個線性層的權(quán)重矩陣
?的形狀為
?,偏置的形狀為
?。第二個線性層權(quán)重矩陣
?的形狀為
?,偏置形狀為
?。MLP塊的參數(shù)量為
?。
self-attention塊和MLP塊各有一個layer normalization,包含了2個可訓(xùn)練模型參數(shù):縮放參數(shù)?和平移參數(shù)
?,形狀都是
?。2個layernormalization的參數(shù)量為?
?。
總的,每個transformer層的參數(shù)量為?。
除此之外,詞嵌入矩陣的參數(shù)量也較多,詞向量維度通常等于隱藏層維度,詞嵌入矩陣的參數(shù)量為?
。最后的輸出層的權(quán)重矩陣通常與詞嵌入矩陣是參數(shù)共享的。
關(guān)于位置編碼,如果采用可訓(xùn)練式的位置編碼,會有一些可訓(xùn)練模型參數(shù),數(shù)量比較少。如果采用相對位置編碼,例如RoPE和ALiBi,則不包含可訓(xùn)練的模型參數(shù)。我們忽略這部分參數(shù)。
綜上,層transformer模型的可訓(xùn)練模型參數(shù)量為
。當(dāng)隱藏維度?
?較大時,可以忽略一次項,?模型參數(shù)量近似為
?。
接下來,我們估計不同版本LLaMA模型的參數(shù)量。
實際參數(shù)量 | 隱藏維度h | 層數(shù)l | 12lh^2 |
---|---|---|---|
6.7B | 4096 | 32 | 6,442,450,944 |
13.0B | 5120 | 40 | 12,582,912,000 |
32.5B | 6656 | 60 | 31,897,681,920 |
65.2B | 8192 | 80 | 64,424,509,440 |
2.1 訓(xùn)練過程中的顯存占用分析
在訓(xùn)練神經(jīng)網(wǎng)絡(luò)的過程中,占用顯存的大頭主要分為四部分:模型參數(shù)、前向計算過程中產(chǎn)生的中間激活、后向傳遞計算得到的梯度、優(yōu)化器狀態(tài)。這里著重分析參數(shù)、梯度和優(yōu)化器狀態(tài)的顯存占用,中間激活的顯存占用后面會詳細介紹。訓(xùn)練大模型時通常會采用AdamW優(yōu)化器,并用混合精度訓(xùn)練來加速訓(xùn)練,基于這個前提分析顯存占用。
在一次訓(xùn)練迭代中,每個可訓(xùn)練模型參數(shù)都會對應(yīng)1個梯度,并對應(yīng)2個優(yōu)化器狀態(tài)(Adam優(yōu)化器梯度的一階動量和二階動量)。設(shè)模型參數(shù)量為?,那么梯度的元素數(shù)量為
?,AdamW優(yōu)化器的元素數(shù)量為
。float16數(shù)據(jù)類型的元素占2個bytes,float32數(shù)據(jù)類型的元素占4個bytes。在混合精度訓(xùn)練中,會使用float16的模型參數(shù)進行前向傳遞和后向傳遞,計算得到float16的梯度;在優(yōu)化器更新模型參數(shù)時,會使用float32的優(yōu)化器狀態(tài)、float32的梯度、float32的模型參數(shù)來更新模型參數(shù)。因此,對于每個可訓(xùn)練模型參數(shù),占用了
。使用AdamW優(yōu)化器和混合精度訓(xùn)練來訓(xùn)練參數(shù)量為?
的大模型,?模型參數(shù)、梯度和優(yōu)化器狀態(tài)占用的顯存大小為
?。
2.2 推理過程中的顯存占用分析
在神經(jīng)網(wǎng)絡(luò)的推理階段,沒有優(yōu)化器狀態(tài)和梯度,也不需要保存中間激活。少了梯度、優(yōu)化器狀態(tài)、中間激活,模型推理階段占用的顯存要遠小于訓(xùn)練階段。模型推理階段,占用顯存的大頭主要是模型參數(shù),如果使用float16來進行推理,推理階段模型參數(shù)占用的顯存大概是?。如果使用KVcache來加速推理過程,?KV cache也需要占用顯存,KVcache占用的顯存下文會詳細介紹。此外,輸入數(shù)據(jù)也需要放到GPU上,還有一些中間結(jié)果(推理過程中的中間結(jié)果用完會盡快釋放掉),不過這部分占用的顯存是很小的,可以忽略。
3. 計算量FLOPs估計
FLOPs,floating point operations,表示浮點數(shù)運算次數(shù),衡量了計算量的大小。
如何計算矩陣乘法的FLOPs呢?
對于?,計算?
?需要進行?
?次乘法運算和?
?次加法運算,共計?
?次浮點數(shù)運算,需要?
?的FLOPs。對于?
?,計算?
?需要的浮點數(shù)運算次數(shù)為?
?。
在一次訓(xùn)練迭代中,假設(shè)輸入數(shù)據(jù)的形狀為?。我們?先分析self-attention塊的計算,計算公式如下:
1. 計算?:矩陣乘法的輸入和輸出形狀為
。計算量為
。
2.?矩陣乘法的輸入和輸出形狀為
。計算量為?
?。
3. 計算在?上的加權(quán)?
?,矩陣乘法的輸入和輸出形狀為
。計算量為?
?。
4. attention后的線性映射,矩陣乘法的輸入和輸出形狀為。計算量為?
?。
接下來分析MLP塊的計算,計算公式如下:
1. 第一個線性層,矩陣乘法的輸入和輸出形狀為。計算量為?
?。
2. 第二個線性層,矩陣乘法的輸入和輸出形狀為。計算量為?
?。
將上述計算量相加,得到每個transformer層的計算量大約為?。
此外,另一個計算量的大頭是logits的計算,將隱藏向量映射為詞表大小。矩陣乘法的輸入和輸出形狀為,計算量為?
?。
因此,對于一個?層的transformer模型,輸入數(shù)據(jù)形狀為
?的情況下,一次訓(xùn)練迭代的計算量為
。
3.1 計算量與參數(shù)量的關(guān)聯(lián)
當(dāng)隱藏維度?比較大,且遠大于序列長度
?時,我們可以忽略一次項,計算量可以近似為
?。前面提到當(dāng)模型參數(shù)量為
?,輸入的tokens數(shù)為
?,存在等式
。我們可以近似認為:?在一次前向傳遞中,對于每個token,每個模型參數(shù),需要進行2次浮點數(shù)運算,即一次乘法法運算和一次加法運算。
一次訓(xùn)練迭代包含了前向傳遞和后向傳遞,后向傳遞的計算量是前向傳遞的2倍。因此,前向傳遞 + 后向傳遞的系數(shù)。一次訓(xùn)練迭代中,對于每個token,每個模型參數(shù),需要進行
?次浮點數(shù)運算。
接下來,我們可以估計訓(xùn)練GPT3-175B所需要的計算量。對于GPT3,每個token,每個參數(shù)進行了6次浮點數(shù)運算,再乘以參數(shù)量和總tokens數(shù)就得到了總的計算量。GPT3的模型參數(shù)量為?,訓(xùn)練數(shù)據(jù)量為?
?tokens。
3.2 訓(xùn)練時間估計
模型參數(shù)量和訓(xùn)練總tokens數(shù)決定了訓(xùn)練transformer模型需要的計算量。給定硬件GPU類型的情況下,可以估計所需要的訓(xùn)練時間。給定計算量,訓(xùn)練時間(也就是GPU算完這么多flops的計算時間)不僅跟GPU類型有關(guān),還與GPU利用率有關(guān)。計算端到端訓(xùn)練的GPU利用率時,不僅要考慮前向傳遞和后向傳遞的計算時間,還要**考慮CPU加載數(shù)據(jù)、優(yōu)化器更新、多卡通信和記錄日志的時間。一般來講,GPU利用率一般在之間。
上文講到一次前向傳遞中,對于每個token,每個模型參數(shù),進行2次浮點數(shù)計算。使用激活重計算技術(shù)來減少中間激活顯存(下文會詳細介紹)需要進行一次額外的前向傳遞,因此前向傳遞+ 后向傳遞 + 激活重計算的系數(shù)=1+2+1=4。使用激活重計算的一次訓(xùn)練迭代中,對于每個token,每個模型參數(shù),需要進行?次浮點數(shù)運算。在給定訓(xùn)練tokens數(shù)、硬件環(huán)境配置的情況下,訓(xùn)練transformer模型的計算時間為:
以GPT3-175B為例,在1024張40GB顯存的A100上,在300Btokens的數(shù)據(jù)上訓(xùn)練175B參數(shù)量的GPT3。40GB顯存A100的峰值性能為312TFLOPS,設(shè)GPU利用率為0.45,則所需要的訓(xùn)練時間為34天,這與[7]中的訓(xùn)練時間是對得上的。
以LLaMA-65B為例,在2048張80GB顯存的A100上,在1.4TBtokens的數(shù)據(jù)上訓(xùn)練了65B參數(shù)量的模型。80GB顯存A100的峰值性能為624TFLOPS,設(shè)GPU利用率為0.3,則所需要的訓(xùn)練時間為21天,這與[4]中的實際訓(xùn)練時間是對得上的。
4. 中間激活值分析
除了模型參數(shù)、梯度、優(yōu)化器狀態(tài)外,占用顯存的大頭就是前向傳遞過程中計算得到的中間激活值了,需要保存中間激活以便在后向傳遞計算梯度時使用。這里的激活(activations)指的是:前向傳遞過程中計算得到的,并在后向傳遞過程中需要用到的所有張量。這里的激活不包含模型參數(shù)和優(yōu)化器狀態(tài),但包含了dropout操作需要用到的mask矩陣。
在分析中間激活的顯存占用時,只考慮激活占用顯存的大頭,忽略掉一些小的buffers。比如,對于layernormalization,計算梯度時需要用到層的輸入、輸入的均值?和方差
?。輸入包含了
?個元素,而輸入的均值和方差分別包含了
?個元素。由于
?通常是比較大的(千數(shù)量級),有?
?。因此,對于layernormalization,中間激活近似估計為?
?,而不是
?。
大模型在訓(xùn)練過程中通常采用混合精度訓(xùn)練,中間激活值一般是float16或者bfloat16數(shù)據(jù)類型的。在分析中間激活的顯存占用時,假設(shè)中間激活值是以float16或bfloat16數(shù)據(jù)格式來保存的,每個元素占了2個bytes。唯一例外的是,dropout操作的mask矩陣,每個元素只占1個bytes。在下面的分析中,單位是bytes,而不是元素個數(shù)。
每個transformer層包含了一個self-attention塊和MLP塊,并分別對應(yīng)了一個layer normalization連接。
先分析self-attention塊的中間激活。self-attention塊的計算公式如下:
1. 對于?,需要保存它們共同的輸入
?,這就是中間激活。輸入
?的形狀為
?,元素個數(shù)為
?,占用顯存大小為
?。
2. 對于?矩陣乘法,需要保存中間激活
?,兩個張量的形狀都是
?,占用顯存大小合計為
?。
3. 對于函數(shù),需要保存函數(shù)的輸入?
?,占用顯存大小為
?,這里的
?表示注意力頭數(shù)。
?的形狀為:?
?的形狀為:
?的形狀為:
,元素個數(shù)為?
?,占用顯存大小為
?。
4. 計算完函數(shù)后,會進行dropout操作。需要保存一個mask矩陣,mask矩陣的形狀與
?相同,占用顯存大小為
?。
5. 計算在?上的attention,即?
?,需要保存
?,大小為
?;以及
?,大小為
?。二者占用顯存大小合計為
?。
6. 計算輸出映射以及一個dropout操作。輸入映射需要保存其輸入,大小為?;dropout需要保存mask矩陣,大小為
?。二者占用顯存大小合計為
?。
因此,將上述中間激活相加得到,self-attention塊的中間激活占用顯存大小為?。
接下來看MLP塊的中間激活。MLP塊的計算公式如下:
1. 第一個線性層需要保存其輸入,占用顯存大小為?。
2. 激活函數(shù)需要保存其輸入,占用顯存大小為?。
3. 第二個線性層需要保存其輸入,占用顯存大小為?。
4. 最后有一個dropout操作,需要保存mask矩陣,占用顯存大小為?。
對于MLP塊,需要保存的中間激活值為?。
另外,self-attention塊和MLP塊分別對應(yīng)了一個layer normalization。每個layer norm需要保存其輸入,大小為?。2個layer norm需要保存的中間激活為
?。
綜上,每個transformer層需要保存的中間激活占用顯存大小為?。對于
層transformer模型,還有embedding層、最后的輸出層。embedding層不需要中間激活??偟亩?,當(dāng)隱藏維度
?比較大,層數(shù)
?較深時,這部分的中間激活是很少的,可以忽略。因此,對于
?層transformer模型,中間激活占用的顯存大小可以近似為
。
4.1 對比中間激活與模型參數(shù)的顯存大小
在一次訓(xùn)練迭代中,模型參數(shù)(或梯度)占用的顯存大小只與模型參數(shù)量和參數(shù)數(shù)據(jù)類型有關(guān),與輸入數(shù)據(jù)的大小是沒有關(guān)系的。優(yōu)化器狀態(tài)占用的顯存大小也是一樣,與優(yōu)化器類型有關(guān),與模型參數(shù)量有關(guān),但與輸入數(shù)據(jù)的大小無關(guān)。而中間激活值與輸入數(shù)據(jù)的大?。ㄅ未笮?/strong>?和序列長度
?)是成正相關(guān)的,隨著批次大小
?和序列長度
的增大,中間激活占用的顯存會同步增大。當(dāng)我們訓(xùn)練神經(jīng)網(wǎng)絡(luò)遇到顯存不足OOM(Out OfMemory)問題時,通常會嘗試減小批次大小來避免顯存不足的問題,這種方式減少的其實是中間激活占用的顯存,而不是模型參數(shù)、梯度和優(yōu)化器的顯存。
以GPT3-175B為例,我們來直觀地對比下模型參數(shù)與中間激活的顯存大小。GPT3的模型配置如下。我們假設(shè)采用混合精度訓(xùn)練,模型參數(shù)和中間激活都采用float16數(shù)據(jù)類型,每個元素占2個bytes。
模型名 | 參數(shù)量 | 層數(shù) | 隱藏維度 | 注意力頭數(shù) |
---|---|---|---|---|
GPT3 | 175B | 96 | 12288 | 96 |
GPT3的模型參數(shù)量為175B,占用的顯存大小為。GPT3模型需要占用350GB的顯存。
GPT3的序列長度?為
?。對比不同的批次大小
?占用的中間激活:
當(dāng)?時,中間激活占用顯存為
,大約是模型參數(shù)顯存的0.79倍。
當(dāng)?時,中間激活占用顯存為
,大約是模型參數(shù)顯存的50倍。
當(dāng)?時,中間激活占用顯存為
,大約是模型參數(shù)顯存的101倍。
可以看到隨著批次大小的增大,中間激活占用的顯存遠遠超過了模型參數(shù)顯存。通常會采用?激活重計算技術(shù)來減少中間激活,理論上可以將中間激活顯存從
?減少到
,代價是增加了一次額外前向計算的時間,本質(zhì)上是“時間換空間”。
5. KV cache
在推斷階段,transformer模型加速推斷的一個常用策略就是使用 KV cache。一個典型的大模型生成式推斷包含了兩個階段:
1.預(yù)填充階段:輸入一個prompt序列,為每個transformer層生成 key cache和value cache(KV cache)。
2.解碼階段:使用并更新KV cache,一個接一個地生成詞,當(dāng)前生成的詞依賴于之前已經(jīng)生成的詞。
第?個transformer層的權(quán)重矩陣為
。其中,self-attention塊的4個權(quán)重矩陣?
,并且MLP塊的2個權(quán)重矩陣?
。
預(yù)填充階段
假設(shè)第?個transformer層的輸入為
?,self-attention塊的key、value、query和output表示為
,其中,?
。
key cache和value cache的計算過程為:
第?個transformer層剩余的計算過程為:
解碼階段
給定當(dāng)前生成詞在第?個transformer層的向量表示為
。推斷計算分兩部分:更新KV cache和計算第?
個transformer層的輸出。
更新key cache和value cache的計算過程如下:
第?個transformer層剩余的計算過程為:
5.1 KV cache的顯存占用分析
假設(shè)輸入序列的長度為?,輸出序列的長度為
?,以float16來保存KV cache,那么?KVcache的峰值顯存占用大小為
。這里第一個2表示K/V cache,第二個2表示float16占2個bytes。
以GPT3為例,對比KV cache與模型參數(shù)占用顯存的大小。GPT3模型占用顯存大小為350GB。假設(shè)批次大小?,輸入序列長度
?,輸出序列長度
?,則KV cache占用顯存為
,大約是模型參數(shù)顯存的0.5倍。
6. 總結(jié)
本文首先介紹了如何計算transformer模型的參數(shù)量,基于參數(shù)量可以進一步估計模型參數(shù)、梯度和優(yōu)化器狀態(tài)占用的顯存大小。接著,本文估計了訓(xùn)練迭代中,在給定訓(xùn)練tokens數(shù)的情況下transformer模型的計算量,給予計算量和顯卡性能可以進一步估計訓(xùn)練迭代的計算耗時。然后,本文分析了transformer模型前向計算過程中產(chǎn)生的中間激活值的顯存大小,中間激活的顯存大小與輸入數(shù)據(jù)大小正相關(guān),甚至?xí)h超過模型參數(shù)占用的顯存。最后,本文介紹了transformer模型推理過程常用的加速策略:使用KVcache。總的來說,分析transformer模型的參數(shù)量、計算量、中間激活和KV cache,有助于理解大模型訓(xùn)練和推斷過程中的顯存效率和計算效率。
-
模型
+關(guān)注
關(guān)注
1文章
3464瀏覽量
49817 -
Transformer
+關(guān)注
關(guān)注
0文章
148瀏覽量
6323 -
ChatGPT
+關(guān)注
關(guān)注
29文章
1584瀏覽量
8662
原文標題:分析transformer模型的參數(shù)量、計算量、中間激活、KV cache
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
基于卷積的基礎(chǔ)模型InternImage網(wǎng)絡(luò)技術(shù)分析

基于Transformer做大模型預(yù)訓(xùn)練基本的并行范式

如何使用MATLAB構(gòu)建Transformer模型

你了解在單GPU上就可以運行的Transformer模型嗎
Google科學(xué)家設(shè)計簡化稀疏架構(gòu)Switch Transformer,語言模型的參數(shù)量可擴展至 1.6 萬億
一個GPU訓(xùn)練一個130億參數(shù)的模型

超大Transformer語言模型的分布式訓(xùn)練框架

Microsoft使用NVIDIA Triton加速AI Transformer模型應(yīng)用
在X3派上玩轉(zhuǎn)一億參數(shù)量超大Transformer,DIY專屬你的離線語音識別

基于Transformer的大型語言模型(LLM)的內(nèi)部機制

transformer模型詳解:Transformer 模型的壓縮方法

盤古大模型參數(shù)量有多少
基于Transformer模型的壓縮方法

評論