TensorFlow Seq2seq库(contrib)

2019-01-31 18:12 更新

用于构建 seq2seq 模型和动态解码的模块,建立在 tf.contrib.rnn 库的顶部。

该库由两个主要组件组成:

  • tf.contrib.rnn.RNNCell 对象的新 attention 包装。
  • 一种新的面向对象的动态解码框架。

注意

attention 包装是 RNNCell 包装其他 RNNCell 对象并实现 attention 的对象。attention 的形式由一个子类 tf.contrib.seq2seq.AttentionMechanism 决定。这些子类描述了在创建包装时要使用的 attention 形式(例如,加法与乘法)。AttentionMechanism 的一个实例是由一个 memory 张量构成,从中创建查询键和值。

attention 机制

两个基本的 attention 机制是:tf.contrib.seq2seq.BahdanauAttention (附加的 attention,参考)和 tf.contrib.seq2seq.LuongAttention(增加的 attention,参考)

该 memory 张量传递的 attention 机制的构造,预计将被塑造 [batch_size, memory_max_time, memory_depth];并且通常一个额外的 memory_sequence_length 向量被接受。如果提供的话,memory 张量的行被零掩蔽,超过其真正的序列长度。

attention 机制也具有深度概念,通常被确定为构造参数 num_units。对于某些类型的 attention(如BahdanauAttention),查询和内存都将投射到深度 num_units 的张量。对于其他类型(如LuongAttention),num_units 应该匹配查询的深度;memory 张量将被投射到这个深度。

attention 包装器

基本的 attention 包装是 tf.contrib.seq2seq.DynamicAttentionWrapper。这个包装器接受一个 RNNCell 实例,一个实例 AttentionMechanism 和一个 attention 深度参数(attention_size);以及允许自定义中间计算的几个可选参数。

在每个时间步骤,这个包装器执行的基本计算是:

cell_inputs = concat([inputs, prev_state.attention], -1)
cell_output, next_cell_state = cell(cell_inputs, prev_state.cell_state)
score = attention_mechanism(cell_output)
alignments = softmax(score)
context = matmul(alignments, attention_mechanism.values)
attention = tf.layers.Dense(attention_size)(concat([cell_output, context], 1))
next_state = AttentionWrapperState(
  cell_state=next_cell_state,
  attention=attention)
output = attention
return output, next_state

在实践中,许多中间计算是可配置的。例如,初始连接 inputs 和 prev_state.attention 可以用另一种混合功能代替。在从分数计算对齐时,可以用其他选项替换函数 softmax。最后,包装器返回的输出可以配置为值 cell_output 而不是 attention。

使用 DynamicAttentionWrapper 的好处是它能很好地与其他包装器和下面描述的动态解码器一起播放。例如,你可以写:

cell = tf.contrib.rnn.DeviceWrapper(LSTMCell(512), "/device:GPU:0")
attention_mechanism = tf.contrib.seq2seq.LuongAttention(512, encoder_outputs)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
  cell, attention_mechanism, attention_size=256)
attn_cell = tf.contrib.rnn.DeviceWrapper(attn_cell, "/device:GPU:1")
top_cell = tf.contrib.rnn.DeviceWrapper(LSTMCell(512), "/device:GPU:1")
multi_cell = MultiRNNCell([attn_cell, top_cell])

所述 multi_rnn 单元将执行对 GPU 0 底层计算; attention 计算将在 GPU 1 上执行,并立即传递到也在 GPU 1 上计算的顶层。attention 也在时间上传递到下一个时间步,并在下一个时间步骤复制到 GPU 0单元。(注意:这只是一个使用的例子,而不是建议的设备分区策略。)

TensorFlow 动态解码

解码器基类和功能

  • tf.contrib.seq2seq.Decoder
  • tf.contrib.seq2seq.dynamic_decode

基本解码器

  • tf.contrib.seq2seq.BasicDecoderOutput
  • tf.contrib.seq2seq.BasicDecoder

解码器助手

  • tf.contrib.seq2seq.Helper
  • tf.contrib.seq2seq.CustomHelper
  • tf.contrib.seq2seq.GreedyEmbeddingHelper
  • tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper
  • tf.contrib.seq2seq.ScheduledOutputTrainingHelper
  • tf.contrib.seq2seq.TrainingHelper
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号