作為最為流行的深度學(xué)習(xí)資源庫,TensorFlow 是幫助深度學(xué)習(xí)新方法走向?qū)崿F(xiàn)的強(qiáng)大工具。它為大多數(shù)深度學(xué)習(xí)領(lǐng)域中使用的常用語言提供了大量應(yīng)用程序接口。對于開發(fā)者和研究人員來說,在開啟新的項(xiàng)目前首先面臨的問題是:如何構(gòu)建一個(gè)簡單明了的結(jié)構(gòu),本文或許可以為你帶來幫助。
TensorFlow 項(xiàng)目模板
簡潔而精密的結(jié)構(gòu)對于深度學(xué)習(xí)項(xiàng)目來說是必不可少的,在經(jīng)過多次練習(xí)和 TensorFlow 項(xiàng)目開發(fā)之后,本文作者提出了一個(gè)結(jié)合簡便性、優(yōu)化文件結(jié)構(gòu)和良好 OOP 設(shè)計(jì)的 TensorFlow 項(xiàng)目模板。該模板可以幫助你快速啟動(dòng)自己的 TensorFlow 項(xiàng)目,直接從實(shí)現(xiàn)自己的核心思想開始。
這個(gè)簡單的模板可以幫助你直接從構(gòu)建模型、訓(xùn)練等任務(wù)開始工作。
目錄
概述
詳述
項(xiàng)目架構(gòu)
文件夾結(jié)構(gòu)
主要組件
模型
訓(xùn)練器
數(shù)據(jù)加載器
記錄器
配置
Main
未來工作
概述
簡言之,本文介紹的是這一模板的使用方法,例如,如果你希望實(shí)現(xiàn) VGG 模型,那么你應(yīng)該:
在模型文件夾中創(chuàng)建一個(gè)名為 VGG 的類,由它繼承「base_model」類
classVGGModel(BaseModel):
def __init__(self, config):
super(VGGModel, self).__init__(config)
#call the build_model and init_saver functions.
self.build_model()
self.init_saver()
覆寫這兩個(gè)函數(shù) "build_model",在其中執(zhí)行你的 VGG 模型;以及定義 TensorFlow 保存的「init_saver」,隨后在 initalizer 中調(diào)用它們。
def build_model(self):
# here you build the tensorflow graph of any model you want and also define the loss.
pass
def init_saver(self):
#here you initalize the tensorflow saver that will be used in saving the checkpoints.
self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
在 trainers 文件夾中創(chuàng)建 VGG 訓(xùn)練器,繼承「base_train」類。
classVGGTrainer(BaseTrain):
def __init__(self, sess, model, data, config, logger):
super(VGGTrainer, self).__init__(sess, model, data, config, logger)
覆寫這兩個(gè)函數(shù)「train_step」、「train_epoch」,在其中寫入訓(xùn)練過程的邏輯。
def train_epoch(self):
"""
implement the logic of epoch:
-loop ever the number of iteration in the config and call teh train step
-add any summaries you want using the sammary
"""
pass
def train_step(self):
"""
implement the logic of the train step
- run the tensorflow session
- return any metrics you need to summarize
"""
pass
在主文件中創(chuàng)建會(huì)話,創(chuàng)建以下對象:「Model」、「Logger」、「Data_Generator」、「Trainer」與配置:
sess = tf.Session()
# create instance of the model you want
model =VGGModel(config)
# create your data generator
data =DataGenerator(config)
# create tensorboard logger
logger =Logger(sess, config)
向所有這些對象傳遞訓(xùn)練器對象,通過調(diào)用「trainer.train()」開始訓(xùn)練。
trainer =VGGTrainer(sess, model, data, config, logger)
# here you train your model
trainer.train()
你會(huì)看到模板文件、一個(gè)示例模型和訓(xùn)練文件夾,向你展示如何快速開始你的第一個(gè)模型。
詳述
模型架構(gòu)
文件夾結(jié)構(gòu)
├── base
│ ├── base_model.py - this file contains the abstract class of the model.
│ └── ease_train.py - this file contains the abstract class of the trainer.
│
│
├── model -This folder contains any model of your project.
│ └── example_model.py
│
│
├── trainer -this folder contains trainers of your project.
│ └── example_trainer.py
│
├── mains - here's the main/s of your project (you may need more than one main.
│
│
├── data _loader
│ └── data_generator.py - here's the data_generator that responsible for all data handling.
│
└── utils
├── logger.py
└── any_other_utils_you_need
主要組件
模型
基礎(chǔ)模型
基礎(chǔ)模型是一個(gè)必須由你所創(chuàng)建的模型繼承的抽象類,其背后的思路是:絕大多數(shù)模型之間都有很多東西是可以共享的。基礎(chǔ)模型包含:
Save-此函數(shù)可保存 checkpoint 至桌面。
Load-此函數(shù)可加載桌面上的 checkpoint。
Cur-epoch、Global_step counters-這些變量會(huì)跟蹤訓(xùn)練 epoch 和全局步。
Init_Saver-一個(gè)抽象函數(shù),用于初始化保存和加載 checkpoint 的操作,注意:請?jiān)谝獙?shí)現(xiàn)的模型中覆蓋此函數(shù)。
Build_model-是一個(gè)定義模型的抽象函數(shù),注意:請?jiān)谝獙?shí)現(xiàn)的模型中覆蓋此函數(shù)。
你的模型
以下是你在模型中執(zhí)行的地方。因此,你應(yīng)該:
創(chuàng)建你的模型類并繼承 base_model 類。
覆寫 "build_model",在其中寫入你想要的 tensorflow 模型。
覆寫"init_save",在其中你創(chuàng)建 tensorflow 保存器,以用它保存和加載檢查點(diǎn)。
在 initalizer 中調(diào)用"build_model" 和 "init_saver"
訓(xùn)練器
基礎(chǔ)訓(xùn)練器
基礎(chǔ)訓(xùn)練器(Base trainer)是一個(gè)只包裝訓(xùn)練過程的抽象的類。
你的訓(xùn)練器
以下是你應(yīng)該在訓(xùn)練器中執(zhí)行的。
創(chuàng)建你的訓(xùn)練器類,并繼承 base_trainer 類。
覆寫這兩個(gè)函數(shù),在其中你執(zhí)行每一步和每一 epoch 的訓(xùn)練過程。
數(shù)據(jù)加載器
這些類負(fù)責(zé)所有的數(shù)據(jù)操作和處理,并提供一個(gè)可被訓(xùn)練器使用的易用接口。
記錄器(Logger)
這個(gè)類負(fù)責(zé) tensorboard 總結(jié)。在你的訓(xùn)練器中創(chuàng)建一個(gè)有關(guān)所有你想要的 tensorflow 變量的詞典,并將其傳遞給 logger.summarize()。
配置
我使用 Json 作為配置方法,接著解析它,因此寫入所有你想要的配置,然后用"utils/config/process_config"解析它,并把這個(gè)配置對象傳遞給所有其他對象。
Main
以下是你整合的所有之前的部分。
1. 解析配置文件。
2. 創(chuàng)建一個(gè) TensorFlow 會(huì)話。
3. 創(chuàng)建 "Model"、"Data_Generator" 和 "Logger"實(shí)例,并解析所有它們的配置。
4. 創(chuàng)建一個(gè)"Trainer"實(shí)例,并把之前所有的對象傳遞給它。
5. 現(xiàn)在你可通過調(diào)用"Trainer.train()"訓(xùn)練你的模型。
-
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5547瀏覽量
122306 -
tensorflow
+關(guān)注
關(guān)注
13文章
330瀏覽量
60934
原文標(biāo)題:快速開啟你的第一個(gè)項(xiàng)目:TensorFlow項(xiàng)目架構(gòu)模板
文章出處:【微信號:CAAI-1981,微信公眾號:中國人工智能學(xué)會(huì)】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
如何使用tensorflow快速搭建起一個(gè)深度學(xué)習(xí)項(xiàng)目
【分享】制作專屬的LabVIEW 項(xiàng)目模板
使用 TensorFlow, 你必須明白 TensorFlow
TensorFlow是什么
TensorFlow的特點(diǎn)和基本的操作方式
《AI 概論》教師手冊(第一篇)——活用Excel模板
如何建立一個(gè)KEIL工程模板
TensorFlow是什么?如何啟動(dòng)并運(yùn)行TensorFlow?
一個(gè)實(shí)用的GitHub項(xiàng)目:TensorFlow-Cookbook
Makefile的項(xiàng)目模板免費(fèi)下載

TensorFlow Community Spotlight獲獎(jiǎng)項(xiàng)目
TensorFlow和PyTorch的實(shí)際應(yīng)用比較
搭建基于Vue3+Vite2+Arco+Typescript+Pinia后臺管理系統(tǒng)模板

評論