在下面的解码器接口中,我们新增一个init_state
函数,用于将编码器的输出(enc_outputs
)转换为编码后的状态。注意,此步骤可能需要额外的输入,例如:输入序列的有效长度,这在机器翻译与数据集中进行了解释。为了逐个地生成长度可变的词元序列,解码器在每个时间步都会将输入(例如:在前一时间步生成的词元)和编码后的状态映射成当前时间步的输出词元。
#@save
class Decoder(nn.Module):
"""编码器-解码器架构的基本解码器接口"""
def __init__(self, **kwargs):
super(Decoder, self).__init__(**kwargs)
def init_state(self, enc_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError