在H100發(fā)布之際,英偉達(dá)還帶來一個(gè)“重磅產(chǎn)品”——Transformer Engine。在Transformer大火之際推出這么一個(gè)產(chǎn)品,無疑是煉丹師福音。
當(dāng)時(shí)我還在猜測它會以怎么樣的一種形式呈現(xiàn)給用戶,直到最近公開了倉庫 NVIDIA/TransformerEngine
這其實(shí)就是PyTorch的一個(gè)拓展,為了利用FP8的特性,針對Transformer里面的Kernel進(jìn)行了重寫,包含了一系列LayerNorm, GeLU, ScaledSoftmax等。
使用方式也是比較簡單,使用該拓展額外包的一層Module來搭建網(wǎng)絡(luò),即可,最后再包一層混合精度訓(xùn)練作用域:
importtorch importtransformer_engine.pytorchaste fromtransformer_engine.commonimportrecipe #Setdimensions. in_features=768 out_features=3072 hidden_size=2048 #Initializemodelandinputs. model=te.Linear(in_features,out_features,use_bias=True) inp=torch.randn(hidden_size,in_features,device="cuda") #創(chuàng)建FP8訓(xùn)練的配置 fp8_recipe=recipe.DelayedScaling(margin=0,interval=1,fp8_format=recipe.Format.E4M3) #FP8的autocast withte.fp8_autocast(enabled=True,fp8_recipe=fp8_recipe): out=model(inp) loss=out.sum() loss.backward()
本篇博客就簡單介紹下Transformer Engine及其對應(yīng)實(shí)現(xiàn)原理
官方文檔:https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Transfromer Engine 是干啥的?
在各種以Transformer為基礎(chǔ)的語言模型如GPT3大火后,語言模型的參數(shù)量還在以指數(shù)形式增長:
那么優(yōu)化Transformer性能就顯得格外重要了,其中混合精度訓(xùn)練是一個(gè)很實(shí)用的技巧
在FP16下,其數(shù)據(jù)范圍還是足夠大的,因此在AMP下,我們只在最后的Loss做了一個(gè)scaling,這個(gè)步驟足以保證在整個(gè)模型運(yùn)算過程中不會產(chǎn)生溢出
而FP8相比FP16減少了更多有效位,因此不能簡單地復(fù)用FP16下的策略,需要給每個(gè)FP8 Tensor單獨(dú)設(shè)置一個(gè)合適的scale factor。Transformer Engine 需要?jiǎng)討B(tài)地對輸入范圍進(jìn)行調(diào)整,如圖所示:
上圖來自H100白皮書內(nèi)(當(dāng)時(shí)我還天真的以為有一個(gè)專門的硬件做這個(gè)處理。。。)
下面我們簡單看下其代碼和實(shí)現(xiàn)原理
Kernel實(shí)現(xiàn)
具體到每一個(gè)算子實(shí)現(xiàn)動態(tài)范圍調(diào)整的原理其實(shí)很簡單,通過記錄歷史的abs max值,來去調(diào)整最終縮放的范圍。
其主要的Kernel實(shí)現(xiàn)都放在了 common 目錄下,我們以gelu這個(gè)kernel為例,最終它會調(diào)用到 vectorized_pointwise.h這個(gè)文件,我們主要看 unary_kernel
unary_kernel
這個(gè)核函數(shù)模板跟常規(guī)的elementwise向量化模板是類似的。
首先會讓每個(gè)線程獲取到scale值
ComputeTypes=0; ifconstexpr(is_fp8::value){ //獲取scale值 if(scale!=nullptr)s=*scale; //將scale取倒數(shù)寫回scale_inv if(blockIdx.x==0&&threadIdx.x==0&&scale_inv!=nullptr){ reciprocal (scale_inv,s); } }
其中在循環(huán)里,線程會不斷更新他運(yùn)算結(jié)果的最大值,并且最終運(yùn)算結(jié)果要乘上scale值:
//實(shí)際運(yùn)算發(fā)生 ComputeTypetemp=OP(val,p); ifconstexpr(is_fp8::value){ __builtin_assume(max>=0); max=fmaxf(fabsf(temp),max); //縮放 temp=temp*s; }
當(dāng)Kernel主體運(yùn)算完畢后,再也warp為單位做一個(gè)reduce_max,獲取到線程束內(nèi)的最大值,再通過atomicMax原子操作,不斷更新全局最大值:
ifconstexpr(is_fp8::value){ /*warptileamaxreduce*/ max=reduce_max (max,warp_id); if(threadIdx.x==0&&amax!=nullptr){ static_assert(std::is_same ::value); //更新全局最大值 atomicMaxFloat(amax,max); } }
其他layernorm等Kernel也是諸如類似的邏輯,這里就不再展開了
(1) DelayedScaling
從前面的示例代碼我們可以看到一個(gè)比較重要的API是DelayedScaling,我們可以根據(jù)官方文檔查看各個(gè)參數(shù)含義:
margin 計(jì)算scale的偏移量
interval 控制計(jì)算scale factor的頻率
fp8_format 使用FP8的格式,F(xiàn)P8有E4M3和E5M2,但是現(xiàn)在不支持純E5M2的格式訓(xùn)練
amax_history_len 記錄abs maxval的歷史窗口大小
amax_compute_algo 在窗口里選擇absmax的算法,'max'則是選擇歷史窗口里最大值,'most_recent'則是選擇近期的值,當(dāng)然你也可以傳一個(gè)自定義的函數(shù)
相關(guān)代碼為:
@torch.jit.script def_default_get_amax( amax_history:torch.Tensor, amax_compute_algo:str, )->Tuple[torch.Tensor,torch.Tensor]: """Defaultfunctiontoobtainamaxfromhistory.""" ifamax_compute_algo=="max": amax=torch.max(amax_history,dim=0).values else:#amax_compute_algo=="most_recent" amax=amax_history[0] amax_history=update_amax_history(amax_history) returnamax_history,amax
scaling_factor_compute_algo 計(jì)算scale factor的算法
@torch.jit.script def_default_sf_compute( amax:torch.Tensor, scale:torch.Tensor, fp8_max:float, margin:int, )->torch.Tensor: """Defaultfunctiontoconvertamaxtoscalingfactor.""" exp=torch.floor(torch.log2(fp8_max/amax))-margin sf=torch.round(torch.pow(2,torch.abs(exp))) sf=torch.where(amax>0.0,sf,scale) sf=torch.where(torch.isfinite(amax),sf,scale) sf=torch.where(exp0,?1?/?sf,?sf) ????return?sf
override_linear_precision 由3個(gè)bool值,分別控制fprop前向,dgrad,wgrad三個(gè)矩陣乘是否用更高的精度來計(jì)算,默認(rèn)都為False
(2) TransformerEngineBaseModule
相關(guān)的Kernel除了要完成自己的計(jì)算任務(wù),也得實(shí)時(shí)維護(hù)amax這些值,因此也需要對應(yīng)修改nn.Module,這里TransformerEngine繼承了nn.Module,并且增加了一些buffer維護(hù)的機(jī)制,這些buffer用于存儲動態(tài)scale的信息:
classTransformerEngineBaseModule(torch.nn.Module,ABC): def__init__(self)->None: ... self.fp8=False self.fp8_meta={} self.fp8_meta["fp8_group"]=None self.fp8_meta["recipe"]=get_default_fp8_recipe() deffp8_init(self,num_gemms:int=1)->None: """Initializefp8relatedmetadataandtensorsduringfprop.""" #Iffp8isn'tenabled,turnoffandreturn. ifnotis_fp8_enabled(): self.fp8=False return #FP8isalreadyenabledandrecipeisthesame,don'tdoanything. ifself.fp8andget_fp8_recipe()==self.fp8_meta["recipe"]: return #SetFP8,recipe,andotherFP8metadata self.fp8=True self.fp8_meta["recipe"]=get_fp8_recipe() self.fp8_meta["num_gemms"]=num_gemms self.fp8_meta["fp8_group"]=get_fp8_group() #SetFP8_MAXpertensoraccordingtorecipe self.fp8_meta["fp8_max_fwd"]=self.fp8_meta["recipe"].fp8_format.value.max_fwd self.fp8_meta["fp8_max_bwd"]=self.fp8_meta["recipe"].fp8_format.value.max_bwd #Allocatescalesandamaxes self.init_fp8_meta_tensors()
而相關(guān)Module如LayerNormMLP繼承該Module,并且傳入fp8_meta信息更新:
classLayerNormMLP(TransformerEngineBaseModule): defforward(...): out=_LayerNormMLP.apply( ..., self.fp8, self.fp8_meta, )
總結(jié)
大致瀏覽完其實(shí)思路不復(fù)雜,但感覺還是FP8技術(shù)的不穩(wěn)定,整個(gè)項(xiàng)目還是加入了很多限制。得益于PyTorch靈活的外部擴(kuò)展形式,只要不去觸碰框架底層運(yùn)行機(jī)制,僅僅在算子層面上的修改還是相當(dāng)簡單。雖然不具備通用性,但是運(yùn)算主體就這幾個(gè)算子,為了性能也是可以接受的
審核編輯:湯梓紅
-
NVIDIA
+關(guān)注
關(guān)注
14文章
5309瀏覽量
106452 -
英偉達(dá)
+關(guān)注
關(guān)注
22文章
3953瀏覽量
93830 -
Transformer
+關(guān)注
關(guān)注
0文章
151瀏覽量
6524 -
pytorch
+關(guān)注
關(guān)注
2文章
809瀏覽量
13971 -
H100
+關(guān)注
關(guān)注
0文章
33瀏覽量
425
原文標(biāo)題:詳解 NVIDIA H100 TransformerEngine
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
英偉達(dá)a100和h100哪個(gè)強(qiáng)?英偉達(dá)A100和H100的區(qū)別
NVIDIA發(fā)布新一代產(chǎn)品—NVIDIA H100

GTC2022大會黃仁勛:NVIDIA H100的5項(xiàng)突破性創(chuàng)新

GTC2022大會亮點(diǎn):NVIDIA發(fā)布全新AI計(jì)算系統(tǒng)—DGX H100

NVIDIA發(fā)布DGX H100系統(tǒng) 羅德與施瓦茨提供O-RAN無線電單元方案
NVIDIA發(fā)布最新Hopper架構(gòu)的H100系列GPU和Grace CPU超級芯片
藍(lán)海大腦服務(wù)器全力支持NVIDIA H100 GPU
用NVIDIA H100 CNX構(gòu)建人工智能系統(tǒng)

利用NVIDIA HGX H100加速計(jì)算數(shù)據(jù)中心平臺應(yīng)用

評論