在每个时间步计算可变长度输出的成本

Calculate cost of variable length output at each timestep

我正在尝试使用 LSTM 单元和 Tensorflow 创建文本生成神经网络。我正在用时间主要格式 [time_steps、batch_size、input_size] 的句子训练网络,我希望每个时间步都能预测序列中的下一个单词。该序列在时间步长之前用空值填充,并且一个单独的占位符包含批处理中每个序列的长度。

有很多关于时间反向传播概念的信息,但是我找不到任何关于在 tensorflow 中实际实现可变长度序列成本计算的信息。由于序列的末尾被填充,我假设我不想计算填充部分的成本。所以我需要一种方法来剪辑从第一个输出到序列末尾的输出。

这是我目前拥有的代码:

    outputs = []
    states = []
    cost = 0
    for i in range(time_steps+1):
        output, state = cell(X[i], state)
        z1 = tf.matmul(output, dec_W1) + dec_b1
        a1 = tf.nn.sigmoid(z1)
        z2 = tf.matmul(a1, dec_W2) + dec_b2
        a2 = tf.nn.softmax(z2)
        outputs.append(a2)
        states.append(state)
        #== calculate cost
        cost = cost + tf.nn.softmax_cross_entropy_with_logits(logits=z2, labels=y[i])
    optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)

此代码在没有可变长度序列的情况下工作。但是,如果我在末尾添加了填充值,那么它也会计算填充部分的成本,这没有多大意义。

如何只计算序列长度上限之前的输出成本?

解决了!

在研究了很多示例之后(大多数都在更高级别的框架中,例如 Keras,这很痛苦)我发现您必须创建一个掩码!回想起来似乎很简单。

这是创建 1 和 0 掩码的代码,然后按元素将其与矩阵相乘(这将是成本值)

x = tf.placeholder(tf.float32)
seq = tf.placeholder(tf.int32)

def mask_by_length(input_matrix, length):
    '''
        Input matrix is a 2d tensor [batch_size, time_steps]
        length is a 1d tensor
        length refers to the length of input matrix axis 1
    '''
    length_transposed = tf.expand_dims(length, 1)

    # Create range in order to compare length to
    range = tf.range(tf.shape(input_matrix)[1])
    range_row = tf.expand_dims(range, 0)

    # Use the logical operations to create a mask
    mask = tf.less(range_row, length_transposed)

    # cast boolean to int to finalize mask
    mask_result = tf.cast(mask, dtype=tf.float32)

    # Element-wise multiplication to cancel out values in the mask
    result = tf.multiply(mask_result, input_matrix)

    return result

mask_values = mask_by_length(x, seq)

输入值(时间主要)[time_steps、batch_size]

[[ 0.71, 0.22, 1.42, -0.28, 0.99] [ 0.41, 2.24, 0.09, 0.74, 0.65]]

序列值 [batch_size]

[2, 3]

输出(时间主要)[time_steps、batch_size]

[[ 0.71, 0.22, 0, 0, 0, ] [ 0.41, 2.24, 0.09, 0, 0, ]]