01 ? ? ? 序
很久沒(méi)有寫(xiě)過(guò)博客了,最近忙于做項(xiàng)目,閉了后簡(jiǎn)單寫(xiě)一下心得體會(huì)。近期主要是在做服飾場(chǎng)景相關(guān)的項(xiàng)目,今天簡(jiǎn)單寫(xiě)寫(xiě)其中做的一個(gè)服飾分割。
初版代碼已經(jīng)提交,歡迎大家提issue和pr
https://github.com/FlyEgle/segmentationlight
02 ? ? ? 背景
主要場(chǎng)景就是對(duì)模特進(jìn)行服飾摳圖,要求邊緣處理相對(duì)平滑,扣取召回和準(zhǔn)確率比較高,能夠覆蓋95%以上的場(chǎng)景case。同時(shí)需要考慮模型FLOPs以及結(jié)構(gòu)便宜性,便于后期有壓縮的需求。
03 ? ? ? 模型選擇
有考慮過(guò)如下三種模型:
DeepLabV3
U2Net
HRNet-seg
這里DeepLabV3有空洞卷積存在,對(duì)于細(xì)致的扣圖,效果不是很好,更加適用于連通性比較強(qiáng)的物體分割以及多類(lèi)別分割。
HRNet-seg存在一個(gè)問(wèn)題,最后輸出的featuremap分別是[1/4, 1/8, 1/16, 1/32],雖然是有不斷的高低分辨率的交互,但是1/4還是有點(diǎn)捉襟見(jiàn)肘,會(huì)影響一些小的pixel,空洞以及邊緣效果。所以做了簡(jiǎn)單的修該如下:
FPN+upsmaple形式
FPN+upsample
upsmaple+cat
upsmaple+cat 相對(duì)來(lái)說(shuō)FPN的收斂速度會(huì)更快一些,計(jì)算量更小,性能略高,相比原始HRseg的輸出來(lái)說(shuō),細(xì)致化了很多。 3. U2net的計(jì)算量要比HRnet-fpn更小,同時(shí),U2net更加注重刻畫(huà)細(xì)節(jié)。由于場(chǎng)景只有一個(gè)類(lèi)別,所以U2Net不太需要考慮類(lèi)別的關(guān)系,對(duì)于模型本身來(lái)說(shuō)更加適配。 4. 有嘗試過(guò)修改U2net,包括增加attention,增加refine Module,多監(jiān)督約束以及修改結(jié)構(gòu)等,不過(guò)最終都比較雞肋了,寫(xiě)paper還是可以的,從實(shí)際case效果上看幾乎無(wú)差。也嘗試過(guò)x2,x4channel,性能上也沒(méi)明顯提升。 不過(guò)對(duì)于專(zhuān)一場(chǎng)景來(lái)說(shuō),模型本身不是重點(diǎn)。
u2net模型結(jié)構(gòu)
04 ? ? ? 損失設(shè)計(jì)
任務(wù)只要求區(qū)分前景和背景,自然可以理解為二分類(lèi)或者是1分類(lèi)問(wèn)題,所以基礎(chǔ)loss的選擇就可以是softmax+CE(二分類(lèi)),sigmoid+bce(前景)。sigmoid相比softmax對(duì)于邊緣效果更佳友好(可以調(diào)節(jié)閾值),為了保證連通區(qū)域,采用了bce+3*dice作為baseline損失。 這里在320x320尺寸下,做了一些對(duì)比實(shí)驗(yàn),可以看到bce+iou指標(biāo)最高,不過(guò)case by case的話(huà)視覺(jué)效果沒(méi)有bce+dice好。降低dice的系數(shù),也是因?yàn)橛袝r(shí)候dice過(guò)強(qiáng)忽略了bce判別正負(fù)樣本的情況。
損失函數(shù) 也嘗試過(guò)一些其他的loss,如focalloss,tv, L1等損失組合,意義不是很大, 代碼如下:
# ----------------- DICE Loss--------------------class DiceLoss(nn.Module): def __init__(self): super(DiceLoss, self).__init__() def forward(self, logits, targets, mask=False): num = targets.size(0) smooth = 1. probs = torch.sigmoid(logits) m1 = probs.view(num, -1) m2 = targets.view(num, -1) intersection = (m1 * m2) score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth) score = 1 - score.sum() / num return score# -------------------- BCELoss -----------------------class BCELoss(nn.Module): """binary bceloss with sigmoid""" def __init__(self): super(BCELoss, self).__init__() def forward(self, inputs, targets, weights=None, mask=False): assert len(inputs.shape) == 4, "inputs shape must be NCHW" if len(targets.shape) != 4: targets = targets.unsqueeze(1).float() else: targets = targets.float() if mask: inputs = inputs * targets losses = F.binary_cross_entropy_with_logits(inputs, targets, weights) return losses# ----------------- DICE+BCE Loss--------------------class DiceWithBCELoss(nn.Module): def __init__(self, weights, mining=False): super(DiceWithBCELoss, self).__init__() self.dice_loss = DiceLoss() if mining: self.bce_loss = BalanceCrossEntropyLoss() else: self.bce_loss = BCELoss() self.weights = weights def forward(self, preds, targets): bceloss = self.bce_loss(preds, targets) diceloss = self.dice_loss(preds, targets) ????????return?self.weights['bce']?*?bceloss?+?self.weights['dice']*diceloss
05 ? ? ? 訓(xùn)練優(yōu)化
1. 分辨率
baseline模型的訓(xùn)練尺寸為320x320,隨之提升到了640x640,這里采用兩種方法,一個(gè)是from strach訓(xùn)練一個(gè)是load 320的pretrain 進(jìn)行訓(xùn)練。相比于strach,pretrain的效果會(huì)更好,隨著數(shù)據(jù)的迭代和累積,不斷的采用上一個(gè)最好效果的weights來(lái)做下一次訓(xùn)練模型的pretrain,最終訓(xùn)練尺寸為800x800。 嘗試過(guò)采用更大的分辨率960和1024來(lái)進(jìn)行訓(xùn)練,在個(gè)人的場(chǎng)景上基本沒(méi)有顯著提升。(ps: 1024尺寸下的bs太小了,加了accumulate grad后性能下降的明顯)
2. 數(shù)據(jù)增強(qiáng)
數(shù)據(jù)增強(qiáng)采用基本都是常規(guī)的,隨機(jī)crop,隨機(jī)翻轉(zhuǎn),隨機(jī)旋轉(zhuǎn),隨機(jī)blur,這里colorjitter會(huì)影響性能就沒(méi)有用了。
def build_transformers(crop_size=(320, 320)): if isinstance(crop_size, int): crop_size = (crop_size, crop_size) data_aug = [ # RandomCropScale(scale_size=crop_size, scale=(0.4, 1.0)), RandomCropScale2(scale_size=crop_size, scale=(0.3, 1.2), prob=0.5), RandomHorizionFlip(p=0.5), RandomRotate(degree=15, mode=0), RandomGaussianBlur(p=0.2), ] to_tensor = [ Normalize(normalize=True, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), ToTensor(channel_first=True) ] final_aug = data_aug + to_tensor return Compose(final_aug)
比較重要的一點(diǎn)的是randomcrop,調(diào)整了crop的區(qū)域?yàn)閇0.3,1.2],一方面是因?yàn)閏rop區(qū)域太小,容易忽視整體性,另一方面是crop大一些可以相應(yīng)的對(duì)應(yīng)大分辨率。要注意的是,crop的區(qū)域是需要包含前景,可以通過(guò)設(shè)定前景占比來(lái)進(jìn)行調(diào)整,也可以理解為手動(dòng)balance數(shù)據(jù)。
class RandomCropScale2: """RandomCrop with Scale the images & targets, if not crop fit size, need to switch the prob to do reisze to keep the over figure scale_size : (list) a sequence of scale scale : default is (0.08, 1.0), crop region areas ratio : default is (3. / 4., 4. / 3.), ratio for width / height Returns: scale_image : (ndarray) crop and scale image scale_target: (ndarray) crop and scale target, shape is same with image """ def __init__(self, scale_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), prob=0.5): self.scale_size = scale_size self.scale = scale self.ratio = ratio # self.prob = np.random.uniform(0, 1) > prob self.prob = prob self.scale_func = Scale(self.scale_size) # center crop # self.centercrop = CenterCrop(self.scale_size) if (self.scale[0] > self.scale[1]) or (self.ratio[0] > self.ratio[1]): warnings.warn("Scale and ratio should be of kind (min, max)") def _isBG(self, tgts): """If the targets all is 0, 0 is background """ if np.sum(tgts) == 0: return True else: return False # TODO: fix empty bug def _crop_imgs(self, imgs, tgts): height, width, _ = imgs.shape area = height * width for _ in range(10): target_area = area * np.random.uniform(self.scale[0], self.scale[1]) aspect_ratio = np.random.uniform(self.ratio[0], self.ratio[1]) w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) if 0 < w < width and 0 < h < height: random_y = np.random.randint(0, height - h + 1) random_x = np.random.randint(0, width - w + 1) crop_image = imgs[random_y:random_y+h, random_x:random_x+w] crop_target = tgts[random_y:random_y+h, random_x:random_x+w] if not self._isBG(crop_target): crop_image, crop_target = self.scale_func(crop_image, crop_target) return crop_image, crop_target # switch prob or center crop if np.random.uniform(0, 1) > self.prob: # center crop in_ratio = float(width) / float(height) if in_ratio < min(self.ratio): w = width h = int(round(w / min(self.ratio))) elif in_ratio > max(self.ratio): h = height w = int(round(h * max(self.ratio))) else: w = width h = height # navie center crop crop_x = max((width - w) // 2, 0) crop_y = max((height - h) // 2, 0) imgs = imgs[crop_y:crop_y+height, crop_x:crop_x+width] tgts = tgts[crop_y:crop_y+height, crop_x:crop_x+width] # scale crop_image, crop_target = self.scale_func(imgs, tgts) return crop_image, crop_target def __call__(self, imgs, tgts): crop_image, crop_target = self._crop_imgs(imgs, tgts) ????????return?crop_image,?crop_target
3. 數(shù)據(jù)
這個(gè)就仁者見(jiàn)仁智者見(jiàn)智了,查缺補(bǔ)漏就好,一般新數(shù)據(jù),我會(huì)用模型過(guò)濾一遍,卡個(gè)0.98或者0.99的miou,小于這個(gè)閾值的用于訓(xùn)練,大于閾值的采樣訓(xùn)練。 訓(xùn)練這里采用的是ADAMW優(yōu)化器,1e-2的weights decay,5e-4到1e-4調(diào)整學(xué)習(xí)率,視情況而定。(ADAMW偶爾會(huì)出現(xiàn)nan的問(wèn)題,要查找是否數(shù)據(jù)有nan,如果沒(méi)有大概率是因?yàn)橛衎n導(dǎo)致的數(shù)值溢出,可以調(diào)小LR或者更換優(yōu)化器)采用了CircleLR進(jìn)行衰減,效果還算ok,跑相同300個(gè)epoch,比CosineLR要好一點(diǎn)點(diǎn)。最終場(chǎng)景驗(yàn)證數(shù)據(jù)可以到達(dá)99%+的miou。
06 ? ? ? 邊緣優(yōu)化
Sigmoid訓(xùn)練后,可以簡(jiǎn)單的卡個(gè)閾值來(lái)進(jìn)行邊緣平滑處理,可以二值也可以過(guò)渡。
output[output >= thre] = 1 or None output[output < thre] = 0
?
邊緣
粗看邊緣還算可以,但是細(xì)看就發(fā)現(xiàn)鋸齒很明顯了,還需要進(jìn)一步處理,這里簡(jiǎn)單做了一個(gè)算法,縮放現(xiàn)有的mask(這里縮放可以用contour,也可以用腐蝕,也可以用shapely),把原始圖像做blur,把外圈的blur貼回來(lái)。
def edgePostProcess(mask, image): """Edge post Process Args: mask: a ndarray map, value is [0,255], shape is (h, w, 3) image: a ndarray map, value is 0-255, shape is(h, w, 3) Returns: outputs: edge blur image """ mask[mask==255] = 1 mask = getShrink(mask) image = image * mask image[image==0] = 255 blur_image = cv2.GaussianBlur(image, (5, 5), 0) new_mask = np.zeros(image.shape, np.uint8) contours, hierachy = cv2.findContours( mask[:,:,0], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE ) cv2.drawContours(new_mask, contours, -1, (255, 255, 255), 5) output = np.where(new_mask==np.array([255, 255, 255]), blur_image, image) return output
?
邊緣 其實(shí)可以看到,邊緣只是blur了,但是貼到白背景上可以發(fā)現(xiàn)視覺(jué)效果會(huì)好很多,這也是欺騙人眼的一個(gè)方法。 放一張高圓圓的照片吧,看一下分割后的結(jié)果
07 ? ? ? 代碼
這套代碼框架寫(xiě)了個(gè)把個(gè)月,包括了FCNs,SegNets,DeepLab,UNet,U2Net,HRNet等一些常用模型的實(shí)現(xiàn),loss,aug,lrshedule等,以及VOC上的一些pretrain。整體代碼簡(jiǎn)單明了,模塊分明,如果有需要后面可以考慮開(kāi)源。
model zoo 最后 ,本人不是主要做分割的,只是項(xiàng)目需要了就寫(xiě)了一套代碼框架,做了一些相關(guān)的實(shí)驗(yàn)探索,有一定的場(chǎng)景調(diào)優(yōu),不一定具備共性,歡迎大家討論~
編輯:黃飛
?
評(píng)論