本文將從項(xiàng)目環(huán)境依賴,模型細(xì)節(jié)(RMS Pre-Norm、SwiGLU激活函數(shù)、RoPE旋轉(zhuǎn)位置編碼),代碼解讀(tokenizer、model)以及推理等幾個(gè)方面對(duì)Meta最新模型LLaMA細(xì)節(jié)與代碼詳解,供大家一起參考。
一、項(xiàng)目環(huán)境依賴
此項(xiàng)目給出的環(huán)境依賴只有4個(gè):torch、fairscale、fire、sentencepiece。
其中torch不用多講,fairscale是用來做GPU分布的,一般是當(dāng)使用DDP仍然遇到超顯存的問題時(shí)使用fairscale。目前fairscale我還沒有試過,在下文的源碼介紹中,我會(huì)用torch中對(duì)應(yīng)的基礎(chǔ)網(wǎng)絡(luò)替代fairscale中的結(jié)構(gòu)層進(jìn)行介紹。
fire是一個(gè)命令行工具,用或者不用他都可以,sentencepiece是用于tokenizer的工具包,會(huì)在tokenizer部分簡(jiǎn)單介紹。
二、模型細(xì)節(jié)
由于該模型就是用的transformer的decoder,所以在結(jié)構(gòu)上它與GPT是非常類似的,只是有一些細(xì)節(jié)需要注意一下。
1、RMS Pre-Norm
關(guān)于Pre-Norm和Post-Norm是神經(jīng)網(wǎng)絡(luò)中老生常談的話題,目前比較普遍的被大家接受的結(jié)論是,相同的深度條件下,Post-Norm的效果要優(yōu)于Pre-Norm,因?yàn)镻re-Norm實(shí)際上相當(dāng)于通過了一個(gè)更寬的網(wǎng)絡(luò)而非更深的網(wǎng)絡(luò),所以在同等深度下,Pre-Norm的實(shí)際效果相當(dāng)于一個(gè)更淺卻更寬的網(wǎng)絡(luò),詳細(xì)的推理過程參考:https://spaces.ac.cn/archives/9009。
然而在LLaMA中卻采用了Pre-Norm,或許是因?yàn)槟P蛪蛏睿?B,13B,30B,65B的模型,transformer layer數(shù)量分別為32,40,60,80),而Pre-Norm的恒等分支更加明顯,有利于梯度的傳播(這部分暫時(shí)沒有想到很合理的解釋,如果有更好的理解,歡迎在評(píng)論區(qū)補(bǔ)充)。
RMS Norm(Root Mean Square Layer Normalization),是一般LayerNorm的一種變體,可以在梯度下降時(shí)令損失更加平滑。
與layerNorm相比,RMS Norm的主要區(qū)別在于去掉了減去均值的部分(re-centering),只保留方差部分(re-scaling),從歸一化的表達(dá)式上可以直觀地看出:
一般的LN:
其中,
RMS Norm:
其中,
可以看到,二者的區(qū)別就在于有沒有減去均值。至于RMS Norm為什么有用,需要求梯度進(jìn)行分析,感興趣的同學(xué)可以閱讀RMS Norm的論文。
2、SwiGLU激活函數(shù)
LLaMA采用SwiGLU替換了原有的ReLU。
采用SwiGLU的FNN,在論文中以如下公式進(jìn)行表述:
其中,
3、RoPE旋轉(zhuǎn)位置編碼
RoPE(Rotary Position Embedding)旋轉(zhuǎn)位置編碼,是蘇劍林老師提出的一種旋轉(zhuǎn)位置編碼方法,其思想是采用絕對(duì)位置編碼的形式,實(shí)現(xiàn)相對(duì)位置編碼。這一部分比較關(guān)鍵,如果不理解的話,后邊的代碼估計(jì)就看不懂了。讀懂RoPE涉及一點(diǎn)復(fù)變函數(shù)的基礎(chǔ)知識(shí),不過如果你沒有學(xué)過的話也沒有關(guān)系。
位置編碼對(duì)大模型而言尤為重要,因?yàn)榧热皇且?xùn)練大模型,那么長文本的表征和模型對(duì)于長文本的建模能力就顯得非常重要。(但是對(duì)于絕對(duì)位置編碼,我有一個(gè)直觀地感受,認(rèn)為其本質(zhì)上不適用于長文本的場(chǎng)景,因?yàn)樗鼤?huì)直接導(dǎo)致模型的Embedding層被無限放大,并且由于數(shù)據(jù)分布在seq_len方向上通常是長尾的,這又會(huì)必然導(dǎo)致絕對(duì)位置編碼的矩陣在尾部會(huì)越來越稀疏,一方面造成資源浪費(fèi),另一方面這種表示方法直觀上就很不利于模型的學(xué)習(xí),因?yàn)樗c我們實(shí)際場(chǎng)景是有很大的矛盾的。
而RoPE雖然具有相對(duì)位置編碼的性質(zhì),但是從代碼部分可以看出,在構(gòu)造的時(shí)候,其也是受到了最大長度的限制的。關(guān)于這一點(diǎn),我無法嚴(yán)謹(jǐn)?shù)谜f明,只是一點(diǎn)個(gè)人的想法。)。
而RoPE的巧妙之處在于,它既保留了絕對(duì)位置編碼中的絕對(duì)位置信息,又保留了在內(nèi)積運(yùn)算下,對(duì)位置信息的相對(duì)性。
RoPE主要借助了復(fù)數(shù)的思想。為了引入復(fù)數(shù),首先假設(shè)了在加入位置信息之前,原有的編碼向量是二維行向量q_m和k_n ,其中m和n是絕對(duì)位置,現(xiàn)在需要構(gòu)造一個(gè)變換,將m和n引入到q_m和k_nk中,即尋找變換:
考慮到Attention的核心計(jì)算是內(nèi)積:
做了這樣一個(gè)變換之后,根據(jù)復(fù)數(shù)的特性,有:
也就是,如果把二維向量看做復(fù)數(shù),那么它們的內(nèi)積,等于一個(gè)復(fù)數(shù)乘以另一個(gè)復(fù)數(shù)的共軛,得到的結(jié)果再取實(shí)部。
帶入上面的變換,也就有:
這樣一來,內(nèi)積的結(jié)果就只依賴于(m?n),也就是相對(duì)位置了。換言之,經(jīng)過這樣一番操作,通過給Embedding添加絕對(duì)位置信息,可以使得兩個(gè)token的編碼,經(jīng)過內(nèi)積變換(self-attn)之后,得到結(jié)果,是受它們位置的差值,即相對(duì)位置影響的。
于是對(duì)于任意的位置為m的二維向量[x,y],把它看做復(fù)數(shù),乘以e^{im heta},而根據(jù)歐拉公式,有:
于是上述的相乘變換也就變成了:
把上述式子寫成矩陣形式:
而這個(gè)變換的幾何意義,就是在二維坐標(biāo)系下,對(duì)向量(q0, q1) 進(jìn)行了旋轉(zhuǎn),因而這種位置編碼方法,被稱為旋轉(zhuǎn)位置編碼。
根據(jù)剛才的結(jié)論,結(jié)合內(nèi)積的線性疊加性,可以將結(jié)論推廣到高維的情形??梢岳斫鉃椋?jī)蓚€(gè)維度一組,進(jìn)行了上述的“旋轉(zhuǎn)”操作,然后再拼接在一起:
由于矩陣的稀疏性,會(huì)造成計(jì)算上的浪費(fèi),所以在計(jì)算時(shí)采用逐位相乘再相加的方式進(jìn)行:
其中?為矩陣逐位相乘操作。代碼中具體的計(jì)算過程,會(huì)有所出入,具體見下文。
三、代碼解讀
1、tokenizer
tokenizer這部分沒有太多可以講的,主要就是用到了sentencepiece工具。
fromsentencepieceimportSentencePieceProcessor fromloggingimportgetLogger fromtypingimportList importos logger=getLogger() classTokenizer: def__init__(self,model_path:str): #reloadtokenizer assertos.path.isfile(model_path),model_path self.sp_model=SentencePieceProcessor(model_file=model_path) logger.info(f"ReloadedSentencePiecemodelfrom{model_path}") #BOS/EOStokenIDs self.n_words:int=self.sp_model.vocab_size() self.bos_id:int=self.sp_model.bos_id() self.eos_id:int=self.sp_model.eos_id() self.pad_id:int=self.sp_model.pad_id() logger.info( f"#words:{self.n_words}-BOSID:{self.bos_id}-EOSID:{self.eos_id}" ) assertself.sp_model.vocab_size()==self.sp_model.get_piece_size() defencode(self,s:str,bos:bool,eos:bool)->List[int]: asserttype(s)isstr t=self.sp_model.encode(s) ifbos: t=[self.bos_id]+t ifeos: t=t+[self.eos_id] returnt defdecode(self,t:List[int])->str: returnself.sp_model.decode(t)
2、model
1)模型細(xì)節(jié)詳解
model這部分的主要目的就是構(gòu)建transformer,由于LLaMA對(duì)transformer在細(xì)節(jié)上做了一點(diǎn)改動(dòng),所以這里在介紹transformer部分之前,先結(jié)合前文模型細(xì)節(jié)介紹幾個(gè)輔助函數(shù):
(1)RMSNorm:
這部分的基本原理在上文中已經(jīng)介紹過了,這里對(duì)代碼部分進(jìn)行簡(jiǎn)單的解釋:
x是輸入 weight是末尾乘的可訓(xùn)練參數(shù)
x.pow(2)是平方
mean(-1)是在最后一個(gè)維度(即hidden特征維度)上取平均 eps防止取倒數(shù)之后分母為0
torch.rsqrt是開平方并取倒數(shù),結(jié)合上文的公式來看,是不難理解的。
classRMSNorm(torch.nn.Module): def__init__(self,dim:int,eps:float=1e-6): super().__init__() self.eps=eps self.weight=nn.Parameter(torch.ones(dim)) def_norm(self,x): returnx*torch.rsqrt(x.pow(2).mean(-1,keepdim=True)+self.eps) defforward(self,x): output=self._norm(x.float()).type_as(x) returnoutput*self.weight
(2)RoPE旋轉(zhuǎn)位置編碼:
為了實(shí)現(xiàn)旋轉(zhuǎn)位置編碼,定義了三個(gè)輔助函數(shù):
defprecompute_freqs_cis(dim:int,end:int,theta:float=10000.0): freqs=1.0/(theta**(torch.arange(0,dim,2)[:(dim//2)].float()/dim)) t=torch.arange(end,device=freqs.device)#type:ignore freqs=torch.outer(t,freqs).float()#type:ignore freqs_cis=torch.polar(torch.ones_like(freqs),freqs)#complex64 returnfreqs_cis defreshape_for_broadcast(freqs_cis:torch.Tensor,x:torch.Tensor): ndim=x.ndim assert0<=?1?Tuple[torch.Tensor,torch.Tensor]: xq_=torch.view_as_complex(xq.float().reshape(*xq.shape[:-1],-1,2)) xk_=torch.view_as_complex(xk.float().reshape(*xk.shape[:-1],-1,2)) freqs_cis=reshape_for_broadcast(freqs_cis,xq_) xq_out=torch.view_as_real(xq_*freqs_cis).flatten(3) xk_out=torch.view_as_real(xk_*freqs_cis).flatten(3) returnxq_out.type_as(xq),xk_out.type_as(xk)
這一部分是整個(gè)項(xiàng)目中,最不容易理解的部分,因?yàn)樗话愕奈恢镁幋a不同,即便是對(duì)transformer結(jié)構(gòu)非常了解的同學(xué),如果沒有認(rèn)真讀過RoPE,對(duì)這一部分代碼還是很難讀明白。
看懂這一部分代碼,最關(guān)鍵的是弄清楚其中的變量freqs_cis所指是什么東西。
為了搞懂這部分,我們需要先了解幾個(gè)torch中不太常用的方法:
(1)torch.view_as_complex
把一個(gè)tensor轉(zhuǎn)為復(fù)數(shù)形式,要求這個(gè)tensor的最后一個(gè)維度形狀為2。
torch.view_as_complex(torch.Tensor([[1,2],[3,4],[5,6]])) #tensor([1.+2.j,3.+4.j,5.+6.j])
(2)torch.view_as_real
把復(fù)數(shù)tensor變回實(shí)數(shù),可以看做是是剛才操作的逆變換。
torch.view_as_real(torch.view_as_complex(torch.Tensor([[1,2],[3,4],[5,6]]))) #tensor([[1.,2.], #[3.,4.], #[5.,6.]])
(3)torch.outer
一個(gè)向量的轉(zhuǎn)置乘以另一個(gè)向量:torch.outer(a, b) = a^T * b
a=torch.arange(1,5) b=torch.arange(1,4) torch.outer(a,b) #tensor([[1,2,3], #[2,4,6], #[3,6,9], #[4,8,12]])
(4)torch.polar
torch.polar(abs, angle)利用一個(gè)絕對(duì)數(shù)值,和一個(gè)角度值,在極坐標(biāo)下構(gòu)造一個(gè)復(fù)數(shù)張量
torch.polar(torch.tensor([1],dtype=torch.float64),torch.tensor([np.pi/2],dtype=torch.float64)) #tensor([6.1232e-17+1.j],dtype=torch.complex128)
接下來進(jìn)入RoPE的計(jì)算,首先為了更加具象的表達(dá),我們?cè)诖藢?duì)各個(gè)維度的尺寸進(jìn)行假設(shè),假設(shè)batch_size為2,seq_len固定為512,attention_head的數(shù)量為12,每個(gè)attention_head的維度為64,那么,對(duì)于輸入到multi-head attn中的輸入Xq的尺寸就是(2, 512, 12, 64)。
回到我們剛才提出的問題,freqs_cis所指是什么東西,其實(shí)它就是需要計(jì)算出來的mθ也就是跟絕對(duì)位置相關(guān)的旋轉(zhuǎn)的角度,在極坐標(biāo)下對(duì)應(yīng)的復(fù)數(shù)tensor。
而函數(shù)precompute_freqs_cis就是提前將這些旋轉(zhuǎn)角度對(duì)應(yīng)的tensor給創(chuàng)建出來,并可以重復(fù)利用。因?yàn)榇_定了序列的最大長度,所以這個(gè)tensor是固定死的。根據(jù)后續(xù)的數(shù)據(jù)流我們可以發(fā)現(xiàn),在調(diào)用該函數(shù)時(shí),傳入的兩個(gè)參數(shù)分別是attention_head的維度,以及最大長度的兩倍,具象地,也就是64和1024。
我們逐行來理解這個(gè)方法:
freqs=1.0/(theta**(torch.arange(0,dim,2)[:(dim//2)].float()/dim))
首先torch.arange創(chuàng)建了一個(gè)tensor,[ 0 , 2 , 4 , . . . , 60 , 62 ] [0, 2, 4, ..., 60, 62][0,2,4,...,60,62],然后統(tǒng)一除以64,把它變成分?jǐn)?shù),然后整體作為基礎(chǔ)角度的指數(shù),它的shape是(32)
t=torch.arange(end,device=freqs.device)
t比較容易理解,也就是絕對(duì)位置信息,它的shape是(1024)。
freqs=torch.outer(t,freqs).float()
于是根據(jù)torch.outer運(yùn)算,我們得到了一個(gè)shape為(1024, 32)的tensor。其意義也就是將每一個(gè)絕對(duì)位置,分配到對(duì)應(yīng)的角度,相乘。直觀理解一下,就是每一個(gè)絕對(duì)位置上,都有32個(gè)角度。
為什么是這樣的呢,回顧計(jì)算的公式,對(duì)于旋轉(zhuǎn)矩陣,每?jī)蓚€(gè)元素為一組,它們乘以的角度是同一個(gè)θ,所以這個(gè)(1024, 32),在后續(xù)的過程中,就可以reshape成(512, 64),并且在64的那個(gè)維度上,每?jī)蓚€(gè)是相同的。
freqs_cis=torch.polar(torch.ones_like(freqs),freqs)
這一步就是在生成我們需要的位置信息,直觀理解一下,像是在復(fù)平面內(nèi),以原點(diǎn)為中心,轉(zhuǎn)了1024組,每一組64個(gè)的單位向量,它的shape是(1024, 64)。
reshape_for_broadcast方法,是把freqs_cis變成和輸入的tensor相同的形狀,結(jié)合下邊的另一個(gè)方法一起介紹。
然后來看apply_rotary_emb方法,這個(gè)方法其實(shí)就是把位置信息添加到原有的編碼結(jié)果上,在multi-head attention階段調(diào)用。我們還是逐行來看:
xq_=torch.view_as_complex(xq.float().reshape(*xq.shape[:-1],-1,2))
上文中,我們假設(shè)了輸入xq的尺寸就是(2, 512, 12, 64),那么這一句操作的reshape,就是把它變成(2, 512, 12, -1, 2),也就是(2, 512, 12, 32, 2)。xk 同理,略。緊接著把它變成復(fù)數(shù)形式,也就是變成了(2, 512, 12, 32)的形狀。
然后進(jìn)入到reshape_for_broadcast方法:
shape=[difi==1ori==ndim-1else1fori,dinenumerate(x.shape)] returnfreqs_cis.view(*shape)
這個(gè)方法的作用是為了把freqs_cis變成和輸入的tensor相同的形狀。需要注意的是,這里的freqs_cis并不是precompute_freqs_cis生成的形狀為(1024, 64)的那個(gè)tensor,而是根據(jù)輸入的絕對(duì)位置,在(1024, 64)的tensor中,截取了長度為當(dāng)前seq_len的一部分,代碼在Transformer類的forward方法中:
freqs_cis=self.freqs_cis[start_pos:start_pos+seqlen]
也就是說,假如當(dāng)前輸入的序列長度是512,那么截取出來的這個(gè)新的freqs_cis,形狀就是(512, 64),reshape之后,形狀就變成了(1, 512, 1, 32),也就是在每一個(gè)位置上,都對(duì)應(yīng)有32個(gè)角度,根據(jù)剛剛torch.polar的介紹,當(dāng)我們固定絕對(duì)值(也就是向量的模長)時(shí),角度就可以在笛卡爾坐標(biāo)系下唯一確定一個(gè)復(fù)數(shù),這樣一來也就是32個(gè)復(fù)數(shù),即64個(gè)特征維度,所以就可以對(duì)應(yīng)的將它融合到每個(gè)attention head的64個(gè)特征中去了。
reshape之后,就是將位置信息融入query和key中:
xq_out=torch.view_as_real(xq_*freqs_cis).flatten(3)
這一步將二者相乘得到的復(fù)數(shù)tensor,重新轉(zhuǎn)換為實(shí)數(shù)形式,得到的shape為(2, 512, 12, 32, 2),然后再flatten成(2, 512, 12, 64),這樣一來,就變回了和最開始x_q 相同的形狀,也就完成了將位置信息融入到x_q的這一操作。x_k同理。
以上就是添加位置編碼的整個(gè)過程,建議這一部分仔細(xì)閱讀,反復(fù)理解。
至于SwiGLU激活函數(shù),可以通過調(diào)用torch內(nèi)置方法F.silu()實(shí)現(xiàn),會(huì)在下文的FFN部分介紹。
3、 transformer構(gòu)建
接下來是transformer模型的構(gòu)建。通常,我們?cè)跇?gòu)建transformer時(shí),是按Block構(gòu)建的,每個(gè)transformer Block包含SA和FFN兩部分,然后再通過堆疊block的形式,構(gòu)建起整個(gè)transformer網(wǎng)絡(luò),LLaMA也是這樣做的,讀過BERT或者任何transformer結(jié)構(gòu)的模型源碼的同學(xué)一定對(duì)這個(gè)結(jié)構(gòu)非常熟悉了。
首先看SA部分:
classAttention(nn.Module): def__init__(self,args:ModelArgs): super().__init__() self.n_local_heads=args.n_heads//fs_init.get_model_parallel_world_size() self.head_dim=args.dim//args.n_heads self.wq=ColumnParallelLinear( args.dim, args.n_heads*self.head_dim, bias=False, gather_output=False, init_method=lambdax:x, ) self.wk=ColumnParallelLinear( args.dim, args.n_heads*self.head_dim, bias=False, gather_output=False, init_method=lambdax:x, ) self.wv=ColumnParallelLinear( args.dim, args.n_heads*self.head_dim, bias=False, gather_output=False, init_method=lambdax:x, ) self.wo=RowParallelLinear( args.n_heads*self.head_dim, args.dim, bias=False, input_is_parallel=True, init_method=lambdax:x, ) self.cache_k=torch.zeros( (args.max_batch_size,args.max_seq_len,self.n_local_heads,self.head_dim) ).cuda() self.cache_v=torch.zeros( (args.max_batch_size,args.max_seq_len,self.n_local_heads,self.head_dim) ).cuda() defforward(self,x:torch.Tensor,start_pos:int,freqs_cis:torch.Tensor,mask:Optional[torch.Tensor]): bsz,seqlen,_=x.shape xq,xk,xv=self.wq(x),self.wk(x),self.wv(x) xq=xq.view(bsz,seqlen,self.n_local_heads,self.head_dim) xk=xk.view(bsz,seqlen,self.n_local_heads,self.head_dim) xv=xv.view(bsz,seqlen,self.n_local_heads,self.head_dim) xq,xk=apply_rotary_emb(xq,xk,freqs_cis=freqs_cis) self.cache_k=self.cache_k.to(xq) self.cache_v=self.cache_v.to(xq) self.cache_k[:bsz,start_pos:start_pos+seqlen]=xk self.cache_v[:bsz,start_pos:start_pos+seqlen]=xv keys=self.cache_k[:bsz,:start_pos+seqlen] values=self.cache_v[:bsz,:start_pos+seqlen] xq=xq.transpose(1,2) keys=keys.transpose(1,2) values=values.transpose(1,2) scores=torch.matmul(xq,keys.transpose(2,3))/math.sqrt(self.head_dim) ifmaskisnotNone: scores=scores+mask#(bs,n_local_heads,slen,cache_len+slen) scores=F.softmax(scores.float(),dim=-1).type_as(xq) output=torch.matmul(scores,values)#(bs,n_local_heads,slen,head_dim) output=output.transpose( 1,2 ).contiguous().view(bsz,seqlen,-1) returnself.wo(output)
這一部分看上去會(huì)比較復(fù)雜,涉及到了很多的計(jì)算,但其實(shí)它就是最普通的attention,只要牢記attention的核心計(jì)算公式,也不難理解。
其中,為了執(zhí)行多卡并行,這里的Linear層用的都是fairscale中的類,在閱讀代碼時(shí)直接理解為Linear即可。
attention計(jì)算的總體過程是:
其中有一個(gè)細(xì)節(jié)就是緩存機(jī)制,這里簡(jiǎn)單介紹一下,很多初學(xué)者,甚至NLP老手都容易忽視這個(gè)問題。這個(gè)機(jī)制在模型的訓(xùn)練過程中其實(shí)是不發(fā)揮作用的,它設(shè)計(jì)的目的是在generate時(shí)減少token的重復(fù)計(jì)算。
簡(jiǎn)單解釋一下,就是在計(jì)算第n nn個(gè)token特征的時(shí)候,需要用到第1 , . . . , n ? 1 1,...,n-11,...,n?1個(gè)token,即每次生成時(shí),需要知道前面所有的過往信息,如果每次都從頭算的話,那就會(huì)造成極大的浪費(fèi),所以就沒算一個(gè)位置的信息,就把它緩存下來。
然后是FFN部分,需要注意的點(diǎn)就是采用的激活函數(shù),以及激活函數(shù)的位置:
classFeedForward(nn.Module): def__init__( self, dim:int, hidden_dim:int, multiple_of:int, ): super().__init__() hidden_dim=int(2*hidden_dim/3) hidden_dim=multiple_of*((hidden_dim+multiple_of-1)//multiple_of) self.w1=ColumnParallelLinear( dim,hidden_dim,bias=False,gather_output=False,init_method=lambdax:x ) self.w2=RowParallelLinear( hidden_dim,dim,bias=False,input_is_parallel=True,init_method=lambdax:x ) self.w3=ColumnParallelLinear( dim,hidden_dim,bias=False,gather_output=False,init_method=lambdax:x ) defforward(self,x): returnself.w2(F.silu(self.w1(x))*self.w3(x))
這里與常見模型中的FFN做一下簡(jiǎn)單的對(duì)比,BART中的FFN,用的是fc->act->fc,用了兩層全連接;GPT中的FFN,用的是conv1D->act->conv1D,也是只用了兩層。
而LLaMA中的FFN采用了三個(gè)全連接層以實(shí)現(xiàn)FFNSwiGLU,即
然后將SA和FFN這兩部分拼在一起就是一個(gè)transformer block。
classTransformerBlock(nn.Module): def__init__(self,layer_id:int,args:ModelArgs): super().__init__() self.n_heads=args.n_heads self.dim=args.dim self.head_dim=args.dim//args.n_heads self.attention=Attention(args) self.feed_forward=FeedForward( dim=args.dim,hidden_dim=4*args.dim,multiple_of=args.multiple_of ) self.layer_id=layer_id self.attention_norm=RMSNorm(args.dim,eps=args.norm_eps) self.ffn_norm=RMSNorm(args.dim,eps=args.norm_eps) defforward(self,x:torch.Tensor,start_pos:int,freqs_cis:torch.Tensor,mask:Optional[torch.Tensor]): h=x+self.attention.forward(self.attention_norm(x),start_pos,freqs_cis,mask) out=h+self.feed_forward.forward(self.ffn_norm(h)) returnout
最后利用torch的module list將transformer block進(jìn)行堆疊,拼上最前頭的embedding部分,就是一個(gè)完整的transformer(decoder)結(jié)構(gòu)了。
classTransformer(nn.Module): def__init__(self,params:ModelArgs): super().__init__() self.params=params self.vocab_size=params.vocab_size self.n_layers=params.n_layers self.tok_embeddings=ParallelEmbedding( params.vocab_size,params.dim,init_method=lambdax:x ) self.layers=torch.nn.ModuleList() forlayer_idinrange(params.n_layers): self.layers.append(TransformerBlock(layer_id,params)) self.norm=RMSNorm(params.dim,eps=params.norm_eps) self.output=ColumnParallelLinear( params.dim,params.vocab_size,bias=False,init_method=lambdax:x ) self.freqs_cis=precompute_freqs_cis( self.params.dim//self.params.n_heads,self.params.max_seq_len*2 ) @torch.inference_mode() defforward(self,tokens:torch.Tensor,start_pos:int): _bsz,seqlen=tokens.shape h=self.tok_embeddings(tokens) self.freqs_cis=self.freqs_cis.to(h.device) freqs_cis=self.freqs_cis[start_pos:start_pos+seqlen] mask=None ifseqlen>1: mask=torch.full((1,1,seqlen,seqlen),float("-inf"),device=tokens.device) mask=torch.triu(mask,diagonal=start_pos+1).type_as(h) forlayerinself.layers: h=layer(h,start_pos,freqs_cis,mask) h=self.norm(h) output=self.output(h[:,-1,:])#onlycomputelastlogits returnoutput.float()
直接看forward部分,輸入是token,先做token embedding,然后添加位置信息。對(duì)于decoder模型,為了防止標(biāo)簽泄漏,需要mask,所以做了一個(gè)上三角的mask矩陣。接下來就是逐層的計(jì)算transformer。
3、generate
classLLaMA: def__init__(self,model:Transformer,tokenizer:Tokenizer): self.model=model self.tokenizer=tokenizer defgenerate( self, prompts:List[str], max_gen_len:int, temperature:float=0.8, top_p:float=0.95, )->List[str]: bsz=len(prompts) params=self.model.params assertbsz<=?params.max_batch_size,?(bsz,?params.max_batch_size) ????????prompt_tokens?=?[self.tokenizer.encode(x,?bos=True,?eos=False)?for?x?in?prompts] ????????min_prompt_size?=?min([len(t)?for?t?in?prompt_tokens]) ????????max_prompt_size?=?max([len(t)?for?t?in?prompt_tokens]) ????????total_len?=?min(params.max_seq_len,?max_gen_len?+?max_prompt_size) ????????tokens?=?torch.full((bsz,?total_len),?self.tokenizer.pad_id).cuda().long() ????????for?k,?t?in?enumerate(prompt_tokens): ????????????tokens[k,?:?len(t)]?=?torch.tensor(t).long() ????????input_text_mask?=?tokens?!=?self.tokenizer.pad_id ????????start_pos?=?min_prompt_size ????????prev_pos?=?0 ????????for?cur_pos?in?range(start_pos,?total_len): ????????????logits?=?self.model.forward(tokens[:,?prev_pos:cur_pos],?prev_pos) ????????????if?temperature?>0: probs=torch.softmax(logits/temperature,dim=-1) next_token=sample_top_p(probs,top_p) else: next_token=torch.argmax(logits,dim=-1) next_token=next_token.reshape(-1) #onlyreplacetokenifprompthasalreadybeengenerated next_token=torch.where( input_text_mask[:,cur_pos],tokens[:,cur_pos],next_token ) tokens[:,cur_pos]=next_token prev_pos=cur_pos decoded=[] fori,tinenumerate(tokens.tolist()): #cuttomaxgenlen t=t[:len(prompt_tokens[i])+max_gen_len] #cuttoeostokifany try: t=t[:t.index(self.tokenizer.eos_id)] exceptValueError: pass decoded.append(self.tokenizer.decode(t)) returndecoded defsample_top_p(probs,p): probs_sort,probs_idx=torch.sort(probs,dim=-1,descending=True) probs_sum=torch.cumsum(probs_sort,dim=-1) mask=probs_sum-probs_sort>p probs_sort[mask]=0.0 probs_sort.div_(probs_sort.sum(dim=-1,keepdim=True)) next_token=torch.multinomial(probs_sort,num_samples=1) next_token=torch.gather(probs_idx,-1,next_token) returnnext_token
生成的過程如下:
1)對(duì)prompts進(jìn)行tokenize,得到token ids;
2)計(jì)算當(dāng)前batch的最大長度total_len,用來創(chuàng)建輸入的token tensor,最大長度不能超過前文所述緩存的大小;
3)從當(dāng)前batch中,最短的一個(gè)prompt的位置,作為生成的開始位置,開始生成;
4)輸入的token tensor傳入transformer模型,計(jì)算logits,得到形狀為(batch_size, hidden_size)的logits(transformer最后一層的輸出);
5)softmax+top_p采樣,得到當(dāng)前預(yù)測(cè)的token,并更新當(dāng)前位置,準(zhǔn)備預(yù)測(cè)下一個(gè)token;
6)解碼得到生成的文本。
四、推理
簡(jiǎn)單看一下官方example中給出的推理樣例prompt:
['ThecapitalofGermanyisthecityof', 'HereismysonnetinthestyleofShakespeareaboutanartificialintelligence:']
生成結(jié)果為:
['ThecapitalofGermanyisthecityofBerlin.ThecityisalsothecapitaloftheFederalRepublicofGermany. ThecityofBerlinislocatedinthestateofBerlininGermany.ThecityisthecapitalofthefederalRepublicofGermany. Berlinhasatotalpopulationofaround3.4millionandisthe2ndmostpopulouscityintheEuropeanUnionafterLondon.Thecityhasanareaof892squarekilometersandisthe9thmostpopulatedcityinEurope. ThecityofBerlinwasfoundedinthe13thcentury.BerlinwasalsothecapitaloftheGermanEmpire,theGermanDemocraticRepublicandtheunitedFederalRepublicofGermany. ThecityofBerlinhasmanytouristattractionsthatincludeMuseumsinsel,BrandenburgerTor,theReichstag,andtheSchlossCharlottenburg. ThecityofBerlinisamajorcenterfortheArts,Science,EducationandInnovation.Thecityisalsothepolitical,economic,andculturalcenterofGermany. BerlinishometoanumberofworldrenowneduniversitiesincludingtheFreeUniversityofBerlin,theHumboldtUniversityofBerlin,theTechnicalUniversityofBerlin,andtheBerlinInstituteofTechnology. ThecityofBerlinhas', 'HereismysonnetinthestyleofShakespeareaboutanartificialintelligence: Letustakeamomentfromthetumultuousstorm Ofthepoliticsofreligiontoexaminetheshapeofthings. Ourintuitiontellsusthatwhateverwecanconceive Canexist–ourmindshavenolimit. However,oursensestellusthatthereisalimit. Letusexaminetheinfiniteandwhatwecansayaboutit. Theinfiniteissomethingthatwecanneversee. Wecannotsaywhatitisandwecannotsaywhatitisnot. But,somehow,itisnonethelessreal. Wecanalsosaythattheinfiniteiseternal– Ithasnobeginningandithasnoend. Thatiswhatitis–itistheeternal. Inaword,itisGod. Butwhatabouttheuniverse? Theuniverseisafiniteconstruct– Theinfinitelylargeandtheinfinitelysmall– Allofitfinite. Eventhesingularityattheendoftimeisfinite. So,theuniverseisnotGod. PerhapsitisthevesselofGod. Perhaps,insomesense,theuniverseisGod. But,Iamstillaman. Icannotseetheinfinite. Icanonly']
總結(jié)
本文將從項(xiàng)目環(huán)境依賴,模型細(xì)節(jié)(RMS Pre-Norm、SwiGLU激活函數(shù)、RoPE旋轉(zhuǎn)位置編碼),代碼解讀(tokenizer、model)以及推理等幾個(gè)方面對(duì)Meta最新模型LLaMA細(xì)節(jié)與代碼詳解。???
總結(jié)一下,本文對(duì)LLaMA大模型的結(jié)構(gòu)代碼進(jìn)行了詳細(xì)的介紹,其開源出來的結(jié)構(gòu)代碼量并不多,但是其中很多細(xì)節(jié)值得反復(fù)推敲理解。
審核編輯:劉清
-
RMS
+關(guān)注
關(guān)注
2文章
151瀏覽量
36710 -
GPT
+關(guān)注
關(guān)注
0文章
368瀏覽量
16077 -
旋轉(zhuǎn)編碼
+關(guān)注
關(guān)注
0文章
6瀏覽量
10562
原文標(biāo)題:Meta最新模型LLaMA語言模型細(xì)節(jié)與代碼詳解
文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
評(píng)論