大家在訓(xùn)練深度學(xué)習(xí)模型的時(shí)候,有沒(méi)有遇到這樣的場(chǎng)景:分類任務(wù)的準(zhǔn)確率比較高,但是模型輸出的預(yù)測(cè)概率和實(shí)際預(yù)測(cè)準(zhǔn)確率存在比較大的差異?這就是現(xiàn)代深度學(xué)習(xí)模型面臨的校準(zhǔn)問(wèn)題。在很多場(chǎng)景中,我們不僅關(guān)注分類效果或者排序效果(auc),還希望模型預(yù)測(cè)的概率也是準(zhǔn)的。例如在自動(dòng)駕駛場(chǎng)景中,如果模型無(wú)法以置信度較高的水平檢測(cè)行人或障礙物,就應(yīng)該通過(guò)輸出概率反映出來(lái),并讓模型依賴其他信息進(jìn)行決策。再比如在廣告場(chǎng)景中,ctr預(yù)測(cè)除了給廣告排序外,還會(huì)用于確定最終的扣費(fèi)價(jià)格,如果ctr的概率預(yù)測(cè)的不準(zhǔn),會(huì)導(dǎo)致廣告主的扣費(fèi)偏高或偏低。
那么,為什么深度學(xué)習(xí)模型經(jīng)常出現(xiàn)預(yù)測(cè)概率和真實(shí)情況差異大的問(wèn)題?又該如何進(jìn)行校準(zhǔn)呢?這篇文章首先給大家介紹模型輸出預(yù)測(cè)概率不可信的原因,再為大家通過(guò)10篇頂會(huì)論文介紹經(jīng)典的校準(zhǔn)方法,可以適用于非常廣泛的場(chǎng)景。
1 為什么會(huì)出現(xiàn)校準(zhǔn)差的問(wèn)題
最早進(jìn)行系統(tǒng)性的分析深度學(xué)習(xí)輸出概率偏差問(wèn)題的是2017年在ICML發(fā)表的一篇文章On calibration of modern neural networks(ICML 2017)。文中發(fā)現(xiàn),相比早期的簡(jiǎn)單神經(jīng)網(wǎng)絡(luò)模型,現(xiàn)在的模型越來(lái)越大,效果越來(lái)越好,但同時(shí)模型的校準(zhǔn)性越來(lái)越差。文中對(duì)比了簡(jiǎn)單模型LeNet和現(xiàn)代模型ResNet的校準(zhǔn)情況,LeNet的輸出結(jié)果校準(zhǔn)性很好,而ResNet則出現(xiàn)了比較嚴(yán)重的過(guò)自信問(wèn)題(over-confidence),即模型輸出的置信度很高,但實(shí)際的準(zhǔn)確率并沒(méi)有那么高。
造成這個(gè)現(xiàn)象的最本質(zhì)原因,是模型對(duì)分類問(wèn)題通常使用的交叉熵?fù)p失過(guò)擬合。并且模型越復(fù)雜,擬合能力越強(qiáng),越容易過(guò)擬合交叉熵?fù)p失,帶來(lái)校準(zhǔn)效果變差。這也解釋了為什么隨著深度學(xué)習(xí)模型的發(fā)展,校準(zhǔn)問(wèn)題越來(lái)越凸顯出來(lái)。
那么為什么過(guò)擬合交叉熵?fù)p失,就會(huì)導(dǎo)致校準(zhǔn)問(wèn)題呢?因?yàn)楦鶕?jù)交叉熵?fù)p失的公式可以看出,即使模型已經(jīng)在正確類別上的輸出概率值最大(也就是分類已經(jīng)正確了),繼續(xù)增大對(duì)應(yīng)的概率值仍然能使交叉熵進(jìn)一步減小。因此模型會(huì)傾向于over-confident,即對(duì)于樣本盡可能的讓模型預(yù)測(cè)為正確的label對(duì)應(yīng)的概率接近1。模型過(guò)擬合交叉熵,帶來(lái)了分類準(zhǔn)確率的提升,但是犧牲的是模型輸出概率的可信度。
如何解決校準(zhǔn)性差的問(wèn)題,讓模型輸出可信的概率值呢?業(yè)內(nèi)的主要方法包括后處理和在模型中聯(lián)合優(yōu)化校準(zhǔn)損失兩個(gè)方向,下面給大家分別進(jìn)行介紹。
2 后處理校準(zhǔn)方法
后處理校準(zhǔn)方法指的是,先正常訓(xùn)練模型得到初始的預(yù)測(cè)結(jié)果,再對(duì)這些預(yù)測(cè)概率值進(jìn)行后處理,讓校準(zhǔn)后的預(yù)測(cè)概率更符合真實(shí)情況。典型的方法包括Histogram binning(2001)、Isotonic regression(2002)和Platt scaling(1999)。
Histogram binning是一種比較簡(jiǎn)單的校準(zhǔn)方法,根據(jù)初始預(yù)測(cè)結(jié)果進(jìn)行排序后分桶,每個(gè)桶內(nèi)求解一個(gè)校準(zhǔn)后的結(jié)果,落入這個(gè)桶內(nèi)的預(yù)測(cè)結(jié)果,都會(huì)被校準(zhǔn)成這個(gè)值。每個(gè)桶校準(zhǔn)值的求解方法是利用一個(gè)驗(yàn)證集進(jìn)行擬合,求解桶內(nèi)平均誤差最小的值,其實(shí)也就是落入該桶內(nèi)正樣本的比例。
Isotonic regression是Histogram binning一種擴(kuò)展,通過(guò)學(xué)習(xí)一個(gè)單調(diào)增函數(shù),輸入初始預(yù)測(cè)結(jié)果,輸出校準(zhǔn)后的預(yù)測(cè)結(jié)果,利用這個(gè)單調(diào)增函數(shù)最小化預(yù)測(cè)值和label之間的誤差。保序回歸就是在不改變預(yù)測(cè)結(jié)果的排序(即不影響模型的排序能力),通過(guò)修改每個(gè)元素的值讓整體的誤差最小,進(jìn)而實(shí)現(xiàn)模型糾偏。
Platt scaling則直接使用一個(gè)邏輯回歸模型學(xué)習(xí)基礎(chǔ)預(yù)測(cè)值到校準(zhǔn)預(yù)測(cè)值的函數(shù),利用這個(gè)函數(shù)實(shí)現(xiàn)預(yù)測(cè)結(jié)果校準(zhǔn)。在獲得基礎(chǔ)預(yù)估結(jié)果后,以此作為輸入,訓(xùn)練一個(gè)邏輯回歸模型,擬合校準(zhǔn)后的結(jié)果,也是在一個(gè)單獨(dú)的驗(yàn)證集上進(jìn)行訓(xùn)練。這個(gè)方法的問(wèn)題在于對(duì)校準(zhǔn)前的預(yù)測(cè)值和真實(shí)值之間的關(guān)系做了比較強(qiáng)分布假設(shè)。
3 在模型中進(jìn)行校準(zhǔn)
除了后處理的校準(zhǔn)方法外,一些在模型訓(xùn)練過(guò)程中實(shí)現(xiàn)校準(zhǔn)的方法獲得越來(lái)越多的關(guān)注。在模型中進(jìn)行校準(zhǔn)避免了后處理的兩階段方式,主要包括在損失函數(shù)中引入校準(zhǔn)項(xiàng)、label smoothing以及數(shù)據(jù)增強(qiáng)三種方式。
基于損失函數(shù)的校準(zhǔn)方法最基礎(chǔ)的是On calibration of modern neural networks(ICML 2017)這篇文章提出的temperature scaling方法。Temperature scaling的實(shí)現(xiàn)方式很簡(jiǎn)單,把模型最后一層輸出的logits(softmax的輸入)除以一個(gè)常數(shù)項(xiàng)。這里的temperature起到了對(duì)logits縮放的作用,讓輸出的概率分布熵更大(溫度系數(shù)越大越接近均勻分布)。同時(shí),這樣又不會(huì)改變?cè)瓉?lái)預(yù)測(cè)類別概率值的相對(duì)排序,因此理論上不會(huì)對(duì)模型準(zhǔn)確率產(chǎn)生負(fù)面影響。
在Trainable calibration measures for neural networks from kernel mean embeddings(2018)這篇文章中,作者直接定義了一個(gè)可導(dǎo)的校準(zhǔn)loss,作為一個(gè)輔助loss在模型中和交叉熵loss聯(lián)合學(xué)習(xí)。本文定義的MMCE原理來(lái)自評(píng)估模型校準(zhǔn)度的指標(biāo),即模型輸出類別概率值與模型正確預(yù)測(cè)該類別樣本占比的差異。
在Calibrating deep neural networks using focal loss(NIPS 2020)中,作者提出直接使用focal loss替代交叉熵?fù)p失,就可以起到校準(zhǔn)作用。Focal loss是表示學(xué)習(xí)中的常用函數(shù),對(duì)focal loss不了解的同學(xué)可以參考之前的文章:表示學(xué)習(xí)中的7大損失函數(shù)梳理。作者對(duì)focal loss進(jìn)行推倒,可以拆解為如下兩項(xiàng),分別是預(yù)測(cè)分布與真實(shí)分布的KL散度,以及預(yù)測(cè)分布的熵。KL散度和一般的交叉熵作用相同,而第二項(xiàng)在約束模型輸出的預(yù)測(cè)概率值熵盡可能大,其實(shí)和temperature scaling的原理類似,都是緩解模型在某個(gè)類別上打分太高而帶來(lái)的過(guò)自信問(wèn)題:
除了修改損失函數(shù)實(shí)現(xiàn)校準(zhǔn)的方法外,label smoothing也是一種常用的校準(zhǔn)方法,最早在Regularizing neural networks by penalizing confident output distributions(ICLR 2017)中提出了label smoothing在模型校準(zhǔn)上的應(yīng)用,后來(lái)又在When does label smoothing help? (NIPS 2019)進(jìn)行了更加深入的探討。Label smoothing通過(guò)如下公式對(duì)原始的label進(jìn)行平滑操作,其原理也是增大輸出概率分布的熵:
此外,一些研究也研究了數(shù)據(jù)增強(qiáng)手段對(duì)模型校準(zhǔn)的影響。On mixup training: Improved calibration and predictive uncertainty for deep neural networks(NIPS 2019)提出mixup方法可以有效提升模型校準(zhǔn)程度。Mixup是一種簡(jiǎn)單有效的數(shù)據(jù)增強(qiáng)策略,具體實(shí)現(xiàn)上,隨機(jī)從數(shù)據(jù)集中抽取兩個(gè)樣本,將它們的特征和label分別進(jìn)行加權(quán)融合,得到一個(gè)新的樣本用于訓(xùn)練:
文中作者提出,上面融合過(guò)程中對(duì)label的融合對(duì)取得校準(zhǔn)效果好的預(yù)測(cè)結(jié)果是非常重要的,這和上面提到的label smoothing思路比較接近,讓label不再是0或1的超低熵分布,來(lái)緩解模型過(guò)自信問(wèn)題。
類似的方法還包括CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features(ICCV 2019)提出的一種對(duì)Mixup方法的擴(kuò)展,隨機(jī)選擇兩個(gè)圖像和label后,對(duì)每個(gè)patch隨機(jī)選擇是否使用另一個(gè)圖像相應(yīng)的patch進(jìn)行替換,也起到了和Mixup類似的效果。文中也對(duì)比了Mixup和CutMix的效果,Mixup由于每個(gè)位置都進(jìn)行插值,容易造成區(qū)域信息的混淆,而CutMix直接進(jìn)行替換,不同區(qū)域的差異更加明確。
4 總結(jié)
本文梳理了深度學(xué)習(xí)模型的校準(zhǔn)方法,包含10篇經(jīng)典論文的工作。通過(guò)校準(zhǔn),可以讓模型輸出的預(yù)測(cè)概率更加可信,可以應(yīng)用于各種類型、各種場(chǎng)景的深度學(xué)習(xí)模型中,適用場(chǎng)景非常廣泛。
審核編輯:劉清
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4814瀏覽量
103531
原文標(biāo)題:不要相信模型輸出的概率打分......
文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語(yǔ)言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
ADS1291測(cè)試中經(jīng)常會(huì)出現(xiàn)R波變小的情況,為什么?
使用ADUM4121ARIZ輸出15V電壓驅(qū)動(dòng)mos/IGBT時(shí),經(jīng)常出現(xiàn)10ohm電阻損壞的情況,為什么?
ADS1299利用信號(hào)發(fā)生器發(fā)出的正弦信號(hào)讀到的數(shù)據(jù)經(jīng)常出現(xiàn)毛刺,怎么解決?
深度學(xué)習(xí)模型的魯棒性優(yōu)化
用tas5630驅(qū)動(dòng)容性負(fù)載,經(jīng)常出現(xiàn)損壞芯片的現(xiàn)象,怎么解決?
GPU深度學(xué)習(xí)應(yīng)用案例
FPGA加速深度學(xué)習(xí)模型的案例
AI大模型與深度學(xué)習(xí)的關(guān)系
FPGA做深度學(xué)習(xí)能走多遠(yuǎn)?
tvp5150am1 RST腳經(jīng)常出現(xiàn)復(fù)位不正常,為什么?
使用OPA129構(gòu)建了一個(gè)電荷放大器,6腳輸出經(jīng)常出現(xiàn)尖峰的原因?
THS4500 RG和RF的選值對(duì)輸出波形的影響怎么解決?
【《大語(yǔ)言模型應(yīng)用指南》閱讀體驗(yàn)】+ 基礎(chǔ)知識(shí)學(xué)習(xí)
深度學(xué)習(xí)模型有哪些應(yīng)用場(chǎng)景
深度學(xué)習(xí)模型量化方法

評(píng)論