Introduction
MMPose is an open-source pose analysis toolbox based on PyTorch and a member of the OpenMMLab project. Its main features:
Support for a wide range of pose analysis tasks: 2D multi-person pose estimation, 2D hand pose estimation, animal keypoint detection, and more
Higher accuracy and faster speed: covers both "top-down" and "bottom-up" families of algorithms
Support for diverse datasets: preparation and construction of many mainstream datasets such as COCO and MPII
Modular design: the unified pose analysis framework is decoupled into separate module components; by combining different components you can easily build a custom pose analysis model
This article fine-tunes and tests an animal keypoint detection model, starting from dataset construction and explaining the role of each module in detail, with notes on mistakes that beginners are likely to make.
Environment Setup
The installation of mmcv is covered in my earlier mmdetection and mmsegmentation tutorials, so I won't repeat it here.
The best way to install MMPose is via git; if you don't have the git tool, you can use mim install mmpose instead.
Finally, create checkpoint, outputs, and data folders under the project directory to hold the pretrained model weights, the model outputs, and the training data respectively.
from IPython import display

!pip install openmim
!pip install -q /kaggle/input/frozen-packages-mmdetection/mmcv-2.0.1-cp310-cp310-linux_x86_64.whl

!git clone https://github.com/open-mmlab/mmdetection.git
%cd mmdetection
!pip install -e .
%cd ..

!git clone https://github.com/open-mmlab/mmpose.git
%cd mmpose
!pip install -e .

!mkdir checkpoint
!mkdir outputs
!mkdir data

display.clear_output()
With the installation complete, let's check the environment and verify the installed versions.
from IPython import display
import mmcv
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
print('MMCV version', mmcv.__version__)

%cd /kaggle/working/mmdetection
import mmdet
print('mmdetection version', mmdet.__version__)

%cd /kaggle/working/mmpose
import mmpose
print('mmpose version', mmpose.__version__)

print('CUDA version', get_compiling_cuda_version())
print('Compiler version', get_compiler_version())
Output:
MMCV version 2.0.1
/kaggle/working/mmdetection
mmdetection version 3.1.0
/kaggle/working/mmpose
mmpose version 1.1.0
CUDA version 11.8
Compiler version GCC 11.3
To make the subsequent file operations easier, import some common libraries.
import os
import io
import json
import shutil
import random
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from mmengine import Config
from pycocotools.coco import COCO
Inference with Pretrained Models
Before pose estimation, an object detector is needed to find the individual targets; pose estimation is then run on each detected target separately. So we first need to choose an object detection model.
Here we choose the RTMDet model from the mmdetection toolbox, specifically RTMDet-l. The config file is located at mmdetection/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py; we copy the model weight URL and download it.
For pose estimation we choose the RTMPose model. Opening the projects/rtmpose/README.md document in the mmpose project folder, we find that RTMPose provides only one pretrained model for animal pose estimation (Animal 2d (17 Keypoints)).
Its config file is located at projects/rtmpose/rtmpose/animal_2d_keypoint/rtmpose-m_8xb64-210e_ap10k-256x256.py; we copy the model weight URL and download it.
Put all pretrained weights into the checkpoint folder of the mmpose project directory.
# Download the RTMDet-L model for object detection
!wget https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth -P checkpoint
# Download the RTMPose model for pose estimation
!wget https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-ap10k_pt-aic-coco_210e-256x256-7a041aa1_20230206.pth -P checkpoint

display.clear_output()
MMPose provides a comprehensive inference API called MMPoseInferencer. It lets users run inference on images and videos with any model supported by MMPose, automates the handling of inference results, and makes it easy to save predictions.
We use an image from the Cat Dataset as an example for model inference. Inference parameters:
det_model: config file of the object detection model from the mmdetection toolbox
det_weights: pretrained weight file corresponding to the object detection model
pose2d: config file of the pose estimation model from the mmpose toolbox
pose2d_weights: pretrained weight file corresponding to the pose estimation model
out_dir: output folder for the generated images
from mmpose.apis import MMPoseInferencer

img_path = '/kaggle/input/cat-dataset/CAT_00/00000001_012.jpg'

# Build the inferencer from config files and weight files
inferencer = MMPoseInferencer(det_model='/kaggle/working/mmdetection/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py',
                              det_weights='checkpoint/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth',
                              pose2d='projects/rtmpose/rtmpose/animal_2d_keypoint/rtmpose-m_8xb64-210e_ap10k-256x256.py',
                              pose2d_weights='checkpoint/rtmpose-m_simcc-ap10k_pt-aic-coco_210e-256x256-7a041aa1_20230206.pth')

# MMPoseInferencer is lazy: given an input, it returns a generator of predictions
result_generator = inferencer(img_path, out_dir='outputs', show=False)
result = next(result_generator)
display.clear_output()
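Besides the files written to out_dir, the returned result is a dict whose 'predictions' entry holds the numeric outputs. A minimal sketch for inspecting it; the field names follow the MMPose 1.x inferencer documentation and should be treated as assumptions:

# Sketch: 'predictions' holds one list per input image,
# with one dict per detected instance (field names assume MMPose 1.x).
for instance in result['predictions'][0]:
    print('bbox:', instance['bbox'])                # detection box from RTMDet
    print('keypoints:', instance['keypoints'])      # [[x, y], ...] in image coordinates
    print('scores:', instance['keypoint_scores'])   # per-keypoint confidence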
Visualize the inference result:
import matplotlib.pyplot as plt

img_og = mmcv.imread(img_path)
img_fuse = mmcv.imread('outputs/visualizations/00000001_012.jpg')

fig, axes = plt.subplots(1, 2, figsize=(15, 10))
axes[0].imshow(mmcv.bgr2rgb(img_og))
axes[0].set_title('Original Image')
axes[0].axis('off')
axes[1].imshow(mmcv.bgr2rgb(img_fuse))
axes[1].set_title('Keypoint Image')
axes[1].axis('off')
plt.show()
Data Processing
Dataset Details
The Cat Dataset contains over 9,000 cat images. Each image is annotated with 9 points on the cat's head: 2 for the eyes, 1 for the mouth, and 6 for the ears.
The annotation data for each image is stored in a single file whose name is the corresponding image name with ".cat" appended. Within each annotation file, the data is stored in the following order:
○ Number of points
○ Left Eye
○ Right Eye
○ Mouth
○ Left Ear-1
○ Left Ear-2
○ Left Ear-3
○ Right Ear-1
○ Right Ear-2
○ Right Ear-3
The dataset was originally found on the Internet Archive (https://archive.org/details/CAT_DATASET).
The directory hierarchy looks like this:
- CAT_00
    - 00000001_000.jpg
    - 00000001_000.jpg.cat
    - 00000001_005.jpg
    - 00000001_005.jpg.cat
    - ...
- CAT_01
    - 00000100_002.jpg
    - 00000100_002.jpg.cat
    - 00000100_003.jpg
    - 00000100_003.jpg.cat
- CAT_02
- CAT_03
- CAT_04
- CAT_05
- CAT_06
In total there are 7 folders. Each contains a number of .jpg image files, each with a corresponding .cat annotation file. A .cat file can be treated as a plain-text file; example content:
9 435 322 593 315 524 446 318 285 283 118 430 195 568 186 701 81 703 267
Apart from the first number, 9, which says there are 9 keypoints, every following pair of numbers is the (x, y) coordinate of one part, so the file holds 1 + 2 * 9 = 19 numbers in total.
Reorganizing the Folders
We separate the images and the annotation files from the 7 dataset folders and store them in the data folder of the mmpose project directory, in subfolders named images and ann respectively.
def separate_files(og_folder, trans_folder):
    image_folder = os.path.join(trans_folder, 'images')
    ann_folder = os.path.join(trans_folder, 'ann')
    os.makedirs(image_folder, exist_ok=True)
    os.makedirs(ann_folder, exist_ok=True)

    # Copy .jpg files into images/ and .cat files into ann/
    for folder in os.listdir(og_folder):
        folder_path = os.path.join(og_folder, folder)
        if os.path.isdir(folder_path):
            for file in os.listdir(folder_path):
                if file.endswith('.jpg'):
                    source_path = os.path.join(folder_path, file)
                    target_path = os.path.join(image_folder, file)
                    shutil.copy(source_path, target_path)
                elif file.endswith('.cat'):
                    source_path = os.path.join(folder_path, file)
                    target_path = os.path.join(ann_folder, file)
                    shutil.copy(source_path, target_path)

data_folder = '/kaggle/input/cat-dataset'
trans_folder = './data'
separate_files(data_folder, trans_folder)
Building the COCO Annotation File
Essentially, a COCO annotation file is just a dictionary whose top-level keys are images, annotations, and categories.
○ images contains id (the image's unique identifier, which must be numeric, with no characters), file_name (image name), height (image height), and width (image width)
○ annotations contains category_id (the category of the annotated instance), segmentation (instance segmentation mask), iscrowd (whether the mask is in RLE or polygon format), image_id (the image id, matching id under images), id (annotation id), bbox (detection box as [x, y, width, height]), area (detection box area), num_keypoints (number of keypoints), and keypoints (keypoint coordinates)
○ categories contains supercategory, id (category id), name (category name), keypoints (part names), and skeleton (part connectivity)
○ For a more detailed breakdown of the COCO annotation format, I recommend the blog posts "The COCO dataset annotation format" (https://zhuanlan.zhihu.com/p/29393415) and "How to convert VOC XML files into COCO format" (https://www.cnblogs.com/marsggbo/p/11152462.html)
○ We write a read_file_as_list function that turns the coordinates in an annotation file into [x, y, v] triples, where v = 0 means the keypoint is not annotated, v = 1 means it is annotated but not visible (occluded), and v = 2 means it is annotated and visible. Since every part in this dataset is annotated and visible, we insert a 2 after each (x, y) pair.
def read_file_as_list(file_path):
    with open(file_path, 'r') as file:
        content = file.read()
        key_point = [int(num) for num in content.split()]
        key_num = key_point[0]
        key_point.pop(0)
        # Insert the visibility flag v = 2 after every (x, y) pair
        for i in range(2, len(key_point) + len(key_point)//2, 2 + 1):
            key_point.insert(i, 2)
    return key_num, key_point
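As a quick check, we can run the function on the example annotation shown earlier, written to a temporary file first (a minimal sketch):

# Sketch: verify read_file_as_list on the example annotation from above.
sample = '9 435 322 593 315 524 446 318 285 283 118 430 195 568 186 701 81 703 267'
with open('/tmp/sample.cat', 'w') as f:
    f.write(sample)

key_num, key_point = read_file_as_list('/tmp/sample.cat')
print(key_num)         # 9
print(key_point[:6])   # [435, 322, 2, 593, 315, 2] -> Left Eye, Right Eye with v = 2
print(len(key_point))  # 27 = 9 keypoints * (x, y, v)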
Next, a get_image_size function to obtain an image's width and height.
def get_image_size(image_path):
    with Image.open(image_path) as img:
        width, height = img.size
    return width, height
Since the dataset does not provide bounding box information, and the images contain essentially no distracting elements, we set the detection box to [0, 0, width, height], i.e. the whole image. Accordingly, the box area is area = width * height.
# Convert to COCO data format
def coco_structure(ann_dir, image_dir):
    coco = dict()
    coco['images'] = []
    coco['annotations'] = []
    coco['categories'] = []
    coco['categories'].append(dict(supercategory='cat', id=1, name='cat',
                              keypoints=['Left Eye', 'Right Eye', 'Mouth', 'Left Ear-1', 'Left Ear-2', 'Left Ear-3', 'Right Ear-1', 'Right Ear-2', 'Right Ear-3'],
                              skeleton=[[0, 1], [0, 2], [1, 2], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [3, 8]]))
    ann_list = os.listdir(ann_dir)
    id = 0
    for file_name in tqdm(ann_list):
        key_num, key_point = read_file_as_list(os.path.join(ann_dir, file_name))
        if key_num == 9:
            image_name = os.path.splitext(file_name)[0]
            # get_image_size returns (width, height); keep the unpacking order consistent
            width, height = get_image_size(os.path.join(image_dir, image_name))
            image = {"id": id, "file_name": image_name, "height": height, "width": width}
            coco['images'].append(image)
            key_dict = dict(category_id=1, segmentation=[], iscrowd=0, image_id=id,
                            id=id, bbox=[0, 0, width, height], area=width * height,
                            num_keypoints=key_num, keypoints=key_point)
            coco['annotations'].append(key_dict)
            id = id + 1
    return coco
Write out the annotation information, saving it as data/annotations_all.json in the mmpose project folder.
ann_file = coco_structure('./data/ann', './data/images')

output_file_path = './data/annotations_all.json'
with open(output_file_path, "w", encoding="utf-8") as output_file:
    json.dump(ann_file, output_file, ensure_ascii=True, indent=4)
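Optionally, sanity-check the generated file by loading it back with pycocotools (a small sketch; COCO was imported earlier):

# Sketch: reload the generated annotation file and check its basic structure.
coco_check = COCO('./data/annotations_all.json')
print('images:', len(coco_check.getImgIds()))
print('annotations:', len(coco_check.getAnnIds()))
print('categories:', coco_check.loadCats(coco_check.getCatIds()))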
Splitting Training and Test Data
We split the annotation file into training and test files at a ratio of 0.85 to 0.15.
def split_coco_dataset(coco_json_path: str, save_dir: str, ratios: list,
                       shuffle: bool, seed: int):
    if not Path(coco_json_path).exists():
        raise FileNotFoundError(f'Can not find {coco_json_path}')

    if not Path(save_dir).exists():
        Path(save_dir).mkdir(parents=True)

    ratios = np.array(ratios) / np.array(ratios).sum()

    if len(ratios) == 2:
        ratio_train, ratio_test = ratios
        ratio_val = 0
        train_type = 'trainval'
    elif len(ratios) == 3:
        ratio_train, ratio_val, ratio_test = ratios
        train_type = 'train'
    else:
        raise ValueError('ratios must set 2 or 3 group!')

    coco = COCO(coco_json_path)

    coco_image_ids = coco.getImgIds()

    val_image_num = int(len(coco_image_ids) * ratio_val)
    test_image_num = int(len(coco_image_ids) * ratio_test)
    train_image_num = len(coco_image_ids) - val_image_num - test_image_num

    print('Split info: ====== '
          f'Train ratio = {ratio_train}, number = {train_image_num} '
          f'Val ratio = {ratio_val}, number = {val_image_num} '
          f'Test ratio = {ratio_test}, number = {test_image_num}')

    seed = int(seed)
    if seed != -1:
        print(f'Set the global seed: {seed}')
        np.random.seed(seed)

    if shuffle:
        print('shuffle dataset.')
        random.shuffle(coco_image_ids)

    train_image_ids = coco_image_ids[:train_image_num]
    if val_image_num != 0:
        val_image_ids = coco_image_ids[train_image_num:train_image_num +
                                       val_image_num]
    else:
        val_image_ids = None
    test_image_ids = coco_image_ids[train_image_num + val_image_num:]

    categories = coco.loadCats(coco.getCatIds())

    for img_id_list in [train_image_ids, val_image_ids, test_image_ids]:
        if img_id_list is None:
            continue

        img_dict = {
            'images': coco.loadImgs(ids=img_id_list),
            'categories': categories,
            'annotations': coco.loadAnns(coco.getAnnIds(imgIds=img_id_list))
        }

        if img_id_list == train_image_ids:
            json_file_path = Path(save_dir, f'{train_type}.json')
        elif img_id_list == val_image_ids:
            json_file_path = Path(save_dir, 'val.json')
        elif img_id_list == test_image_ids:
            json_file_path = Path(save_dir, 'test.json')
        else:
            raise ValueError('img_id_list ERROR!')

        print(f'Saving json to {json_file_path}')
        with open(json_file_path, 'w') as f_json:
            json.dump(img_dict, f_json, ensure_ascii=False, indent=2)

    print('All done!')
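The snippet above only defines the function. Judging from the printed split info below, it was invoked along these lines (a reconstruction, not shown in the original):

# Reconstructed call: 0.85/0.15 train/test split, shuffled with seed 2023
split_coco_dataset('./data/annotations_all.json', './data',
                   ratios=[0.85, 0.15], shuffle=True, seed=2023)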
Output:
loading annotations into memory...
Done (t=0.13s)
creating index...
index created!
Split info: ====== Train ratio = 0.85, number = 8495 Val ratio = 0, number = 0 Test ratio = 0.15, number = 1498
Set the global seed: 2023
shuffle dataset.
Saving json to data/trainval.json
Saving json to data/test.json
All done!
As we can see, the training set has 8,495 images and the test set has 1,498 images.
Model Configuration File
Opening projects/rtmpose/rtmpose/animal_2d_keypoint/rtmpose-m_8xb64-210e_ap10k-256x256.py in the project folder, we see that this model config inherits only the _base_/default_runtime.py file.
The main things to modify are dataset_type, data_mode, dataset_info, codec, train_dataloader, test_dataloader, val_evaluator, base_lr, max_epochs, and default_hooks. A few other details are annotated in the code below and can be changed accordingly.
Once modified, the file is written to ./configs/animal_2d_keypoint/cat_keypoint.py.
custom_config = """
_base_ = ['mmpose::_base_/default_runtime.py']

# Dataset type and paths
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = './data/'
work_dir = './work_dir'

# Metadata for the cat keypoint detection dataset
dataset_info = {
    'dataset_name': 'Keypoint_cat',
    'classes': 'cat',
    'paper_info': {
        'author': 'Luck',
        'title': 'Cat Keypoints Detection',
    },
    'keypoint_info': {
        0: {'name': 'Left Eye', 'id': 0, 'color': [255, 0, 0], 'type': '', 'swap': ''},
        1: {'name': 'Right Eye', 'id': 1, 'color': [255, 127, 0], 'type': '', 'swap': ''},
        2: {'name': 'Mouth', 'id': 2, 'color': [255, 255, 0], 'type': '', 'swap': ''},
        3: {'name': 'Left Ear-1', 'id': 3, 'color': [0, 255, 0], 'type': '', 'swap': ''},
        4: {'name': 'Left Ear-2', 'id': 4, 'color': [0, 255, 255], 'type': '', 'swap': ''},
        5: {'name': 'Left Ear-3', 'id': 5, 'color': [0, 0, 255], 'type': '', 'swap': ''},
        6: {'name': 'Right Ear-1', 'id': 6, 'color': [139, 0, 255], 'type': '', 'swap': ''},
        7: {'name': 'Right Ear-2', 'id': 7, 'color': [255, 0, 255], 'type': '', 'swap': ''},
        8: {'name': 'Right Ear-3', 'id': 8, 'color': [160, 82, 45], 'type': '', 'swap': ''}
    },
    'skeleton_info': {
        0: {'link': ('Left Eye', 'Right Eye'), 'id': 0, 'color': [255, 0, 0]},
        1: {'link': ('Left Eye', 'Mouth'), 'id': 1, 'color': [255, 0, 0]},
        2: {'link': ('Right Eye', 'Mouth'), 'id': 2, 'color': [255, 0, 0]},
        3: {'link': ('Left Ear-1', 'Left Ear-2'), 'id': 3, 'color': [255, 0, 0]},
        4: {'link': ('Left Ear-2', 'Left Ear-3'), 'id': 4, 'color': [255, 0, 0]},
        5: {'link': ('Left Ear-3', 'Right Ear-1'), 'id': 5, 'color': [255, 0, 0]},
        6: {'link': ('Right Ear-1', 'Right Ear-2'), 'id': 6, 'color': [255, 0, 0]},
        7: {'link': ('Right Ear-2', 'Right Ear-3'), 'id': 7, 'color': [255, 0, 0]},
        8: {'link': ('Left Ear-1', 'Right Ear-3'), 'id': 8, 'color': [255, 0, 0]},
    }
}

# Number of keypoints
NUM_KEYPOINTS = len(dataset_info['keypoint_info'])
dataset_info['joint_weights'] = [1.0] * NUM_KEYPOINTS
dataset_info['sigmas'] = [0.025] * NUM_KEYPOINTS

# Training hyperparameters
max_epochs = 100
val_interval = 5
train_cfg = {'max_epochs': max_epochs, 'val_begin': 20, 'val_interval': val_interval}
train_batch_size = 32
val_batch_size = 32
stage2_num_epochs = 10
base_lr = 4e-3 / 16
randomness = dict(seed=2023)

# Optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# Learning rate schedule
param_scheduler = [
    dict(type='LinearLR', start_factor=1.0e-5, by_epoch=False, begin=0, end=600),
    dict(
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=1024)

# codec settings
# input_size can be any multiple of 128
# sigma is the std of the Gaussian: larger is easier to learn but less precise.
# For high-precision scenarios it can be reduced; the original RTMPose paper uses 5.66.
# input_size, sigma, and in_featuremap_size in the model below must scale proportionally
codec = dict(
    type='SimCCLabel',
    input_size=(512, 512),
    sigma=(24, 24),
    simcc_split_ratio=2.0,
    normalize=False,
    use_dark=False)

# Model: RTMPose-M
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        _scope_='mmdet',
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=0.67,
        widen_factor=0.75,
        out_indices=(4, ),
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU'),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint='https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-m_8xb256-rsb-a1-600e_in1k-ecb3bbd9.pth'
        )),
    head=dict(
        type='RTMCCHead',
        in_channels=768,
        out_channels=NUM_KEYPOINTS,
        input_size=codec['input_size'],
        in_featuremap_size=(16, 16),
        simcc_split_ratio=codec['simcc_split_ratio'],
        final_layer_kernel_size=7,
        gau_cfg=dict(
            hidden_dims=256,
            s=128,
            expansion_factor=2,
            dropout_rate=0.,
            drop_path=0.,
            act_fn='SiLU',
            use_rel_bias=False,
            pos_enc=False),
        loss=dict(
            type='KLDiscretLoss',
            use_target_weight=True,
            beta=10.,
            label_softmax=True),
        decoder=codec),
    test_cfg=dict(flip_test=True))

backend_args = dict(backend='local')

# pipelines
train_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    # dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform', scale_factor=[0.8, 1.2], rotate_factor=30),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='ChannelShuffle', p=0.5),
            dict(type='CLAHE', p=0.5),
            # dict(type='Downscale', scale_min=0.7, scale_max=0.9, p=0.2),
            dict(type='ColorJitter', p=0.5),
            dict(
                type='CoarseDropout',
                max_holes=4,
                max_height=0.3,
                max_width=0.3,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=0.5),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]
train_pipeline_stage2 = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform',
        shift_factor=0.,
        scale_factor=[0.75, 1.25],
        rotate_factor=60),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=0.5),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
    batch_size=train_batch_size,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        metainfo=dataset_info,
        data_mode=data_mode,
        ann_file='trainval.json',
        data_prefix=dict(img='images/'),
        pipeline=train_pipeline,
    ))
val_dataloader = dict(
    batch_size=val_batch_size,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        metainfo=dataset_info,
        data_mode=data_mode,
        ann_file='test.json',
        data_prefix=dict(img='images/'),
        pipeline=val_pipeline,
    ))
test_dataloader = val_dataloader

default_hooks = {
    'checkpoint': {'save_best': 'PCK', 'rule': 'greater', 'max_keep_ckpts': 2},
    'logger': {'interval': 50}
}
custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]

# evaluators
val_evaluator = [
    dict(type='CocoMetric', ann_file=data_root + 'test.json'),
    dict(type='PCKAccuracy'),
    dict(type='AUC'),
    dict(type='NME', norm_mode='keypoint_distance', keypoint_indices=[0, 1])
]
test_evaluator = val_evaluator
"""

config = './configs/animal_2d_keypoint/cat_keypoint.py'

with io.open(config, 'w', encoding='utf-8') as f:
    f.write(custom_config)
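Before launching training, it can help to confirm that the written config actually parses. A minimal sketch using mmengine's Config (imported earlier); the printed values simply echo the settings we wrote:

# Sketch: load the config back and spot-check a few fields.
cfg = Config.fromfile(config)
print(cfg.train_cfg)                           # max_epochs, val_begin, val_interval
print(cfg.codec['input_size'])                 # (512, 512)
print(len(cfg.dataset_info['keypoint_info']))  # 9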
Model Training
Launch training with the training script:
!python tools/train.py {config}
Since the training output is very long, here is an excerpt around the model's best accuracy on the test set:
08/06 19:15:56 - mmengine - INFO - Evaluating CocoMetric...
Loading and preparing results...
DONE (t=0.07s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *keypoints*
DONE (t=0.57s).
Accumulating evaluation results...
DONE (t=0.03s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] =  0.943
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets= 20 ] =  0.979
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets= 20 ] =  0.969
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] =  0.944
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] =  0.953
 Average Recall     (AR) @[ IoU=0.50      | area=   all | maxDets= 20 ] =  0.987
 Average Recall     (AR) @[ IoU=0.75      | area=   all | maxDets= 20 ] =  0.977
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] =  0.953
08/06 19:15:56 - mmengine - INFO - Evaluating PCKAccuracy (normalized by ``"bbox_size"``)...
08/06 19:15:56 - mmengine - INFO - Evaluating AUC...
08/06 19:15:56 - mmengine - INFO - Evaluating NME...
08/06 19:15:57 - mmengine - INFO - Epoch(val) [60][47/47]    coco/AP: 0.943453  coco/AP .5: 0.979424  coco/AP .75: 0.969202  coco/AP (M): -1.000000  coco/AP (L): 0.944082  coco/AR: 0.953471  coco/AR .5: 0.987316  coco/AR .75: 0.977303  coco/AR (M): -1.000000  coco/AR (L): 0.953471  PCK: 0.978045  AUC: 0.801710  NME: 0.121770  data_time: 0.101005  time: 0.435133
08/06 19:15:57 - mmengine - INFO - The previous best checkpoint /kaggle/working/mmpose/work_dir/best_PCK_epoch_55.pth is removed
08/06 19:16:01 - mmengine - INFO - The best checkpoint with 0.9780 PCK at 60 epoch is saved to best_PCK_epoch_60.pth.
The model reaches 0.978 PCK and 0.8017 AUC, and the mAP numbers are also high; the model performs very well!
Model Slimming
mmpose provides a model-slimming (publish) script that roughly halves the size of the trained checkpoint file without affecting accuracy or inference.
We slim the model weights that performed best on the validation set.
import glob

ckpt_path = glob.glob('./work_dir/best_PCK_*.pth')[0]
ckpt_sim = './work_dir/cat_pose_sim.pth'

# Slim the checkpoint (strips training-only states such as the optimizer)
!python tools/misc/publish_model.py \
        {ckpt_path} \
        {ckpt_sim}
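Note that publish_model.py appends a content hash to the output filename, which is why the next section globs for the file. A quick sketch to confirm the size reduction claimed above:

# Sketch: compare checkpoint sizes before and after slimming.
# publish_model.py adds a hash suffix, so glob for the published file.
pub_path = glob.glob('./work_dir/cat_pose_sim*.pth')[0]
print(f'original:  {os.path.getsize(ckpt_path) / 1024**2:.1f} MB')
print(f'published: {os.path.getsize(pub_path) / 1024**2:.1f} MB')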
Model Inference
This follows the same approach as the inference above: RTMDet for object detection, then our own trained RTMPose model for pose estimation.
The difference is that pose2d is now the config file we saved above, ./configs/animal_2d_keypoint/cat_keypoint.py, and pose2d_weights is the slimmed weight file of the best checkpoint, glob.glob('./work_dir/cat_pose_sim*.pth')[0].
img_path = '/kaggle/input/cat-dataset/CAT_00/00000001_012.jpg'

inferencer = MMPoseInferencer(det_model='/kaggle/working/mmdetection/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py',
                              det_weights='checkpoint/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth',
                              pose2d='./configs/animal_2d_keypoint/cat_keypoint.py',
                              pose2d_weights=glob.glob('./work_dir/cat_pose_sim*.pth')[0])

result_generator = inferencer(img_path, out_dir='outputs', show=False)
result = next(result_generator)
display.clear_output()
Visualize the result of our fine-tuned model:
img_og = mmcv.imread(img_path)
img_fuse = mmcv.imread('outputs/visualizations/00000001_012.jpg')

fig, axes = plt.subplots(1, 2, figsize=(15, 10))
axes[0].imshow(mmcv.bgr2rgb(img_og))
axes[0].set_title('Original Image')
axes[0].axis('off')
axes[1].imshow(mmcv.bgr2rgb(img_fuse))
axes[1].set_title('Keypoint Image')
axes[1].axis('off')
plt.show()