如何在 Keras/TensorFlow 中可视化 RNN/LSTM 权重?

How to visualize RNN/LSTM weights in Keras/TensorFlow?

我看过一些研究出版物和问答讨论了检查 RNN 权重的必要性;一些相关的答案是正确的,建议 get_weights() - 但我如何真正可视化权重 有意义 ?也就是说,LSTM 和 GRU 有 ,所有 RNN 都有 通道 作为独立的特征提取器 - 那么我如何 ( 1) 获取 per-gate 权重,并 (2) 以提供信息的方式绘制它们?

Keras/TF 以明确定义的顺序构建 RNN 权重,可以从源代码或直接通过 layer.__dict__ 检查 - 然后用于获取 per- kernelper-gate 权重; per-channel 然后可以在给定张量形状的情况下进行处理。下面的代码和解释涵盖 Keras/TF RNN 的所有可能情况,应该很容易扩展到任何未来的 API 变化。

另请参阅可视化 RNN 梯度,以及对 RNN regularization 的应用;与前者不同的是 post,我不会在这里包括一个简化的变体,因为根据权重提取和组织的性质,它仍然相当大和复杂;相反,只需查看存储库中的相关源代码(请参阅下一节)。


代码源See RNN(这个post包括更大的图像),我的存储库;包括:

  • 激活可视化
  • 权重可视化
  • 激活梯度可视化
  • 权重梯度可视化
  • 解释所有功能的文档字符串
  • 支持 Eager、Graph、TF1、TF2 和 from keras & from tf.keras
  • 比示例中显示的视觉可定制性更好

可视化方法:

  • 2D 热图:绘制每个门、每个内核、每个方向的权重分布; 清楚地显示内核与隐藏的关系
  • 直方图:绘制每个门、每个内核、每个方向的权重分布; 丢失上下文信息

EX 1:uni-LSTM,256 个单位,权重 -- batch_shape = (16, 100, 20)(输入)
rnn_histogram(model, 'lstm', equate_axes=False, show_bias=False)
rnn_histogram(model, 'lstm', equate_axes=True, show_bias=False)
rnn_heatmap(model, 'lstm')

  • 顶部图是直方图子图网格,显示每个内核的权重分布,以及每个内核内每个门的权重分布
  • 第二个绘图集 equate_axes=True 用于在内核和门之间进行均匀比较,提高比较质量,但可能会降低视觉吸引力
  • 最后一张图是相同权重的热图,门分隔由垂直线标记,还包括偏差权重
  • 与直方图不同,热图保留channel/context信息:可以清楚地区分输入到隐藏和隐藏到隐藏的转换矩阵
  • 请注意 Forget 门处的大量集中的最大值;作为琐事,在 Keras 中(通常),偏置门都被初始化为零,除了 Forget 偏置,它被初始化为 1


EX 2:bi-CuDNNLSTM,256 个单位,权重 -- batch_shape = (16, 100, 16)(输入)
rnn_histogram(model, 'bidir', equate_axes=2)
rnn_heatmap(model, 'bidir', norm=(-.8, .8))

  • 两者都支持双向;此示例中包含的直方图偏差
  • 再次注意偏差热图;它们似乎不再与 EX 1 位于同一位置。确实,CuDNNLSTM(和 CuDNNGRU)偏差的定义和初始化不同 - 无法从直方图推断的东西


EX 3:uni-CuDNNGRU,64 个单元,权重梯度 -- batch_shape = (16, 100, 16)(输入)
rnn_heatmap(model, 'gru', mode='grads', input_data=x, labels=y, cmap=None, absolute_value=True)

  • 我们可能希望可视化 梯度强度 ,这可以通过 absolute_value=True 和灰度色图
  • 来完成
  • 在此示例中,即使没有明确的分隔线,门分隔也很明显:
    • New是最活跃的内核门(输入到隐藏),建议在允许信息流
    • 上进行更多纠错
    • Reset 是最不活跃的循环门(隐藏到隐藏),建议在记忆保持方面进行最少的纠错


额外奖励:LSTM NaN 检测,512 个单位,权重 -- batch_shape = (16, 100, 16)(输入)

  • 热图和直方图都带有内置的 NaN 检测 - 内核、门和方向
  • 热图会将 NaN 打印到控制台,而直方图会直接在图上标记它们
  • 两者都会在绘图前将 NaN 值设置为零;在下面的示例中,所有相关的非 NaN 权重都已经为零