谷歌團(tuán)隊(duì)(非官方發(fā)布)打造了一個(gè)名為JAX的系統(tǒng),今日在Reddit引發(fā)了熱議。網(wǎng)友紛紛為它叫好——“說不定能夠取代TensorFlow”。本文便帶領(lǐng)讀者一覽JAX的廬山真面目。
這個(gè)工具說不定比TensorFlow還好用!
它就是JAX,一款由谷歌團(tuán)隊(duì)打造(非官方發(fā)布),用于從純Python和Numpy機(jī)器學(xué)習(xí)程序中生成高性能加速器(accelerator)代碼,且特定于域的跟蹤JIT編譯器。
那么JAX到底有哪些威力呢?
JAX使用XLA編譯器基礎(chǔ)結(jié)構(gòu),來為子程序生成最有利于加速的優(yōu)化代碼,這些優(yōu)化子程序可以由任意Python調(diào)用和編排;
由于JAX與Autograd完全兼容,它允許Python函數(shù)的正、反向模式(forward- and reverse-mode)自動(dòng)區(qū)分為任意順序;
由于JAX支持結(jié)構(gòu)化控制流,所以它可以在保持高性能的同時(shí)為復(fù)雜的機(jī)器學(xué)習(xí)算法生成代碼;
通過將JAX與Autograd和Numpy相結(jié)合,可得到一個(gè)易于編程且高性能的ML系統(tǒng),該系統(tǒng)面向CPU,GPU和TPU,且能擴(kuò)展到多核Cloud TPU。
此“神器”在Reddit上引發(fā)了熱烈的討論,網(wǎng)友紛紛為它叫好:
我的天,“可微分的numpy”實(shí)在是太棒了!我對(duì)pytorch有一點(diǎn)不是很滿意,他們基本上重新做了numpy所做的一切,但存在一些愚蠢的差異,比如“dim”,而不是“axis”,等等。
JAX系統(tǒng)設(shè)計(jì)一覽
谷歌團(tuán)隊(duì)通過觀察發(fā)現(xiàn),JAX的ML工作負(fù)載通常由PSC子程序控制。
JAX的設(shè)計(jì)便因此利用了函數(shù)通??梢灾苯釉跈C(jī)器學(xué)習(xí)代碼中識(shí)別的特性,使機(jī)器學(xué)習(xí)研究人員可以使用JAX的jit_ps修飾符進(jìn)行注釋。
雖然手工注釋對(duì)非專業(yè)用戶和“零工作量知識(shí)”優(yōu)化提出了挑戰(zhàn),但它為專家提供了直接的好處,而且作為一個(gè)系統(tǒng)研究項(xiàng)目,它展示了PSC假設(shè)的威力。
JAX跟蹤緩存為跟蹤計(jì)算的參數(shù)創(chuàng)建了一個(gè)monomorphic signature,以便新遇到的數(shù)組元素類型、數(shù)組維度或元組成員觸發(fā)重新編譯。在跟蹤緩存丟失時(shí),JAX執(zhí)行相應(yīng)的Python函數(shù),并將其執(zhí)行跟蹤到具有靜態(tài)數(shù)據(jù)依賴關(guān)系的原始函數(shù)圖中。
現(xiàn)有的原語不僅包括數(shù)組級(jí)別的數(shù)字內(nèi)核,包括Numpy函數(shù)和其他函數(shù),它們?cè)试S用戶通過保留PSC屬性將控制流分段到編譯后的計(jì)算中。最后,JAX包含一些用于功能分布式編程的原語,如iterated_map_reduce。
為了生成代碼,JAX將跟蹤轉(zhuǎn)換為XLA HLO,這是一種中間語言,可以對(duì)高度可加速的數(shù)組級(jí)數(shù)值程序進(jìn)行建模。從廣義上講,JAX可以被看作是一個(gè)系統(tǒng),它將XLA編程模型提升到Python中,并支持使用可加速的子程序,同時(shí)仍然允許動(dòng)態(tài)編排。
defxla_add(xla_builder,xla_args,np_x,np_y):returnxla_builder.Add(xla_args[0],xla_args[1])defxla_sinh(xla_builder,xla_args,np_x):b,xla_x=xla_builder,xla_args[0]returnb.Div(b.Sub(b.Exp(xla_x),b.Exp(b.Neg(xla_x))),b.Const(2))defxla_while(xla_builder,xla_args,cond_fun,body_fun,init_val):xla_cond=trace_computation(cond_fun,args=(init_val,))xla_body=trace_computation(body_fun,args=(init_val,))returnxla_builder.While(xla_cond,xla_body,xla_args[-1])jax.register_translation_rule(numpy.add,xla_add)jax.register_translation_rule(numpy.sinh,xla_sinh)jax.register_translation_rule(while_loop,xla_while)
JAX從原語到XLA HLO的翻譯規(guī)則
另外,JAX和Autograd完全兼容。
importautograd.numpyasnpfromautogradimportgradfromjaximportjit_psdefpredict(params,inputs):forW,binparamsoutputs=np.dot(inputs,W)+binputs=np.tanh(outputs)returnoutputsdefloss(params,inputs,targets):preds=predict(params,inputs)returnnp.sum((preds-targets)**2)grad_fun=jit_ps(grad(loss))#Compiledgradient-of-lossfunction
一個(gè)與JAX完全連接的基本神經(jīng)網(wǎng)絡(luò)
實(shí)驗(yàn)、性能結(jié)果比較
為了演示JAX和XLA提供的數(shù)組級(jí)代碼優(yōu)化和操作融合,谷歌團(tuán)隊(duì)編譯了一個(gè)具有SeLU非線性的完全連接神經(jīng)網(wǎng)絡(luò)層,并在圖1中顯示JAX trace和XLA HLO圖形。
圖1:XLA HLO對(duì)具有SeLU非線性的層進(jìn)行融合?;疑虮硎舅械牟僮鞫既诤系紾EMM中。
使用一個(gè)線程和幾個(gè)小的示例優(yōu)化問題(包括凸二次型、隱馬爾科夫模型(HMM)邊緣似然性和邏輯回歸)將Python執(zhí)行時(shí)間與CPU上的JAX編譯運(yùn)行時(shí)進(jìn)行了比較。
對(duì)于某些CPU示例來說,XLA的編譯時(shí)間比較慢,但將來可能會(huì)有顯著的改進(jìn),對(duì)于經(jīng)過warmed-up代碼(表1),XLA的編譯速度非常快。
表1:在CPU上Truncated Newton-CG的計(jì)時(shí)(秒)
在GPU上訓(xùn)練卷積網(wǎng)絡(luò)。谷歌團(tuán)隊(duì)實(shí)現(xiàn)了一個(gè)all-conv CIFAR-10網(wǎng)絡(luò),只涉及卷積和ReLU激活。谷歌編寫了一個(gè)單獨(dú)的隨機(jī)梯度下降(SGD)更新步驟,并從一個(gè)純Python循環(huán)中調(diào)用它,結(jié)果如表2所示。
作為參考,谷歌在TensorFlow中實(shí)現(xiàn)了相同的算法,并在類似的Python循環(huán)中調(diào)用它。
表2:GPU上JAX convnet步驟的計(jì)時(shí)(msec)
云TPU可擴(kuò)展性。云TPU核心上的全局批處理的JAX并行化呈現(xiàn)線性加速(圖2,左)。在固定的minibatch / replica中,texec受復(fù)制計(jì)數(shù)的影響最?。ㄔ?ms內(nèi),右邊)
圖2:為ConvNet訓(xùn)練步驟在云TPU上進(jìn)行擴(kuò)展。
-
谷歌
+關(guān)注
關(guān)注
27文章
6231瀏覽量
107856 -
編譯器
+關(guān)注
關(guān)注
1文章
1659瀏覽量
50062 -
機(jī)器學(xué)習(xí)
+關(guān)注
關(guān)注
66文章
8499瀏覽量
134337
原文標(biāo)題:試試谷歌這個(gè)新工具:說不定比TensorFlow還好用!
文章出處:【微信號(hào):AI_era,微信公眾號(hào):新智元】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
評(píng)論