在一般的 seq2seq 問(wèn)題中,如機(jī)器翻譯(第 10.5 節(jié)),輸入和輸出的長(zhǎng)度不同且未對(duì)齊。處理這類數(shù)據(jù)的標(biāo)準(zhǔn)方法是設(shè)計(jì)一個(gè)編碼器-解碼器架構(gòu)(圖 10.6.1),它由兩個(gè)主要組件組成:一個(gè) 編碼器,它以可變長(zhǎng)度序列作為輸入,以及一個(gè) 解碼器,作為一個(gè)條件語(yǔ)言模型,接收編碼輸入和目標(biāo)序列的向左上下文,并預(yù)測(cè)目標(biāo)序列中的后續(xù)標(biāo)記。
圖 10.6.1編碼器-解碼器架構(gòu)。
讓我們以從英語(yǔ)到法語(yǔ)的機(jī)器翻譯為例。給定一個(gè)英文輸入序列:“They”、“are”、“watching”、“.”,這種編碼器-解碼器架構(gòu)首先將可變長(zhǎng)度輸入編碼為一個(gè)狀態(tài),然后對(duì)該狀態(tài)進(jìn)行解碼以生成翻譯后的序列,token通過(guò)標(biāo)記,作為輸出:“Ils”、“regardent”、“.”。由于編碼器-解碼器架構(gòu)構(gòu)成了后續(xù)章節(jié)中不同 seq2seq 模型的基礎(chǔ),因此本節(jié)將此架構(gòu)轉(zhuǎn)換為稍后將實(shí)現(xiàn)的接口。
from torch import nn from d2l import torch as d2l
from mxnet.gluon import nn from d2l import mxnet as d2l
from flax import linen as nn from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf from d2l import tensorflow as d2l
10.6.1。編碼器
在編碼器接口中,我們只是指定編碼器將可變長(zhǎng)度序列作為輸入X。實(shí)現(xiàn)將由繼承此基類的任何模型提供Encoder。
class Encoder(nn.Module): #@save """The base encoder interface for the encoder-decoder architecture.""" def __init__(self): super().__init__() # Later there can be additional arguments (e.g., length excluding padding) def forward(self, X, *args): raise NotImplementedError
class Encoder(nn.Block): #@save """The base encoder interface for the encoder-decoder architecture.""" def __init__(self): super().__init__() # Later there can be additional arguments (e.g., length excluding padding) def forward(self, X, *args): raise NotImplementedError
class Encoder(nn.Module): #@save """The base encoder interface for the encoder-decoder architecture.""" def setup(self): raise NotImplementedError # Later there can be additional arguments (e.g., length excluding padding) def __call__(self, X, *args): raise NotImplementedError
class Encoder(tf.keras.layers.Layer): #@save """The base encoder interface for the encoder-decoder architecture.""" def __init__(self): super().__init__() # Later there can be additional arguments (e.g., length excluding padding) def call(self, X, *args): raise NotImplementedError
10.6.2。解碼器
在下面的解碼器接口中,我們添加了一個(gè)額外的init_state 方法來(lái)將編碼器輸出 ( enc_all_outputs) 轉(zhuǎn)換為編碼狀態(tài)。請(qǐng)注意,此步驟可能需要額外的輸入,例如輸入的有效長(zhǎng)度,這在 第 10.5 節(jié)中有解釋。為了逐個(gè)令牌生成可變長(zhǎng)度序列令牌,每次解碼器都可以將輸入(例如,在先前時(shí)間步生成的令牌)和編碼狀態(tài)映射到當(dāng)前時(shí)間步的輸出令牌。
class Decoder(nn.Module): #@save """The base decoder interface for the encoder-decoder architecture.""" def __init__(self): super().__init__() # Later there can be additional arguments (e.g., length excluding padding) def init_state(self, enc_all_outputs, *args): raise NotImplementedError def forward(self, X, state): raise NotImplementedError
class Decoder(nn.Block): #@save """The base decoder interface for the encoder-decoder architecture.""" def __init__(self): super().__init__() # Later there can be additional arguments (e.g., length excluding padding) def init_state(self, enc_all_outputs, *args): raise NotImplementedError def forward(self, X, state): raise NotImplementedError
class Decoder(nn.Module): #@save """The base decoder interface for the encoder-decoder architecture.""" def setup(self): raise NotImplementedError # Later there can be additional arguments (e.g., length excluding padding) def init_state(self, enc_all_outputs, *args): raise NotImplementedError def __call__(self, X, state): raise NotImplementedError
class Decoder(tf.keras.layers.Layer): #@save """The base decoder interface for the encoder-decoder architecture.""" def __init__(self): super().__init__() # Later there can be additional arguments (e.g., length excluding padding) def init_state(self, enc_all_outputs, *args): raise NotImplementedError def call(self, X, state): raise NotImplementedError
10.6.3。將編碼器和解碼器放在一起
在前向傳播中,編碼器的輸出用于產(chǎn)生編碼狀態(tài),解碼器將進(jìn)一步使用該狀態(tài)作為其輸入之一。
class EncoderDecoder(d2l.Classifier): #@save """The base class for the encoder-decoder architecture.""" def __init__(self, encoder, decoder): super().__init__() self.encoder = encoder self.decoder = decoder def forward(self, enc_X, dec_X, *args): enc_all_outputs = self.encoder(enc_X, *args) dec_state = self.decoder.init_state(enc_all_outputs, *args) # Return decoder output only return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier): #@save """The base class for the encoder-decoder architecture.""" def __init__(self, encoder, decoder): super().__init__() self.encoder = encoder self.decoder = decoder def forward(self, enc_X, dec_X, *args): enc_all_outputs = self.encoder(enc_X, *args) dec_state = self.decoder.init_state(enc_all_outputs, *args) # Return decoder output only return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier): #@save """The base class for the encoder-decoder architecture.""" encoder: nn.Module decoder: nn.Module training: bool def __call__(self, enc_X, dec_X, *args): enc_all_outputs = self.encoder(enc_X, *args, training=self.training) dec_state = self.decoder.init_state(enc_all_outputs, *args) # Return decoder output only return self.decoder(dec_X, dec_state, training=self.training)[0]
class EncoderDecoder(d2l.Classifier): #@save """The base class for the encoder-decoder architecture.""" def __init__(self, encoder, decoder): super().__init__() self.encoder = encoder self.decoder = decoder def call(self, enc_X, dec_X, *args): enc_all_outputs = self.encoder(enc_X, *args, training=True) dec_state = self.decoder.init_state(enc_all_outputs, *args) # Return decoder output only return self.decoder(dec_X, dec_state, training=True)[0]
在下一節(jié)中,我們將看到如何應(yīng)用 RNN 來(lái)設(shè)計(jì)基于這種編碼器-解碼器架構(gòu)的 seq2seq 模型。
10.6.4。概括
編碼器-解碼器架構(gòu)可以處理由可變長(zhǎng)度序列組成的輸入和輸出,因此適用于機(jī)器翻譯等 seq2seq 問(wèn)題。編碼器將可變長(zhǎng)度序列作為輸入,并將其轉(zhuǎn)換為具有固定形狀的狀態(tài)。解碼器將固定形狀的編碼狀態(tài)映射到可變長(zhǎng)度序列。
10.6.5。練習(xí)
假設(shè)我們使用神經(jīng)網(wǎng)絡(luò)來(lái)實(shí)現(xiàn)編碼器-解碼器架構(gòu)。編碼器和解碼器必須是同一類型的神經(jīng)網(wǎng)絡(luò)嗎?
除了機(jī)器翻譯,你能想到另一個(gè)可以應(yīng)用編碼器-解碼器架構(gòu)的應(yīng)用程序嗎?
-
解碼器
+關(guān)注
關(guān)注
9文章
1165瀏覽量
41825 -
編碼器
+關(guān)注
關(guān)注
45文章
3786瀏覽量
137580 -
pytorch
+關(guān)注
關(guān)注
2文章
809瀏覽量
13860
發(fā)布評(píng)論請(qǐng)先 登錄
怎么理解真正的編碼器和解碼器?
編碼器和解碼器的區(qū)別是什么,編碼器用軟件還是硬件好
詳解編碼器和解碼器電路:定義/工作原理/應(yīng)用/真值表

PyTorch教程10.6之編碼器-解碼器架構(gòu)

PyTorch教程10.7之用于機(jī)器翻譯的編碼器-解碼器Seq2Seq

PyTorch教程-10.7. 用于機(jī)器翻譯的編碼器-解碼器 Seq2Seq
基于transformer的編碼器-解碼器模型的工作原理

基于 RNN 的解碼器架構(gòu)如何建模

基于 Transformers 的編碼器-解碼器模型

神經(jīng)編碼器-解碼器模型的歷史

詳解編碼器和解碼器電路

視頻編碼器與解碼器的應(yīng)用方案

YXC丨視頻編碼器與解碼器的應(yīng)用方案

視頻編碼器與解碼器的應(yīng)用方案

評(píng)論