为啥 Keras (Tensorflow 2.0) 模型在绘制时不包含矩阵乘法的变量?

Posted

技术标签:

【中文标题】为啥 Keras (Tensorflow 2.0) 模型在绘制时不包含矩阵乘法的变量?【英文标题】:Why does the Keras (Tensorflow 2.0) model does not include the variables of matrix multiplication, when plotted?为什么 Keras (Tensorflow 2.0) 模型在绘制时不包含矩阵乘法的变量? 【发布时间】:2022-01-09 06:02:13 【问题描述】:

我正在使用以下代码在 Keras (Tensorflow 2.0) 中构建深度学习模型。

import tensorflow as tf
keras = tf.keras
from keras.layers import Input, Dense
from keras.models import Model

a = Input(shape=(138,7), name='inputP')
b = Input(shape=(138,7), name='inputQ')
c = tf.transpose(b, [0,2,1])
d = tf.matmul(c,a)
e = Dense(15,activation = 'relu')(d)
model = Model([a,b],e)
keras.utils.plot_model(model)

生成以下输出:

这里,输入 P 不包含在绘制的图形中。可能是什么原因?

【问题讨论】:

【参考方案1】:

要使plot_model 正常工作,请将所有这些 tensorflow 操作(如 tf.transposetf.matmul)替换为 lambda 层,这样函数式 API 中的每个节点都是一个 keras 层,即,

import tensorflow as tf

a = tf.keras.Input((138,7), name='inputP')
b = tf.keras.Input((138,7), name='inputQ')
c = tf.keras.layers.Lambda(lambda x: tf.transpose(x, [0,2,1]),name='transpose')(b)
d = tf.keras.layers.Lambda(lambda x: tf.matmul(x[0],x[1]),name='matmul')([c,a])
e = tf.keras.layers.Dense(15,activation = 'relu')(d)
model = tf.keras.Model([a,b],e)
tf.keras.utils.plot_model(model)

【讨论】:

这似乎适用于一般情况。但是当我使用keras的MultiHeadAttention层时,它就不起作用了。 代码可以在这里找到:colab.research.google.com/drive/… 看来tensorflow不支持MHA这个。解决方法是创建一个包含 MHA 的自定义 keras 层并调用它。代码:pastebin.com/qmmURT3P 一般来说,如果图和写的代码不匹配,会不会有问题?训练仍然顺利进行(没有错误),可训练参数的数量也没有变化。两种情况下的指标似乎也相似,当图表连接时,当它有悬挂节点时。 到目前为止您提到的应该是plot_model 函数中的问题,并且计算应该是正确的。但是一般情况下,请把tensorflow操作(tf.transpose,tf.matmul,etc...)封装在keras层里面,避免奇怪的事情发生。

以上是关于为啥 Keras (Tensorflow 2.0) 模型在绘制时不包含矩阵乘法的变量?的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow 2.0 Keras 的训练速度比 2.0 Estimator 慢 4 倍

翻译: Keras 标准化:TensorFlow 2.0 中高级 API 指南

Keras 中的像素加权损失函数 - TensorFlow 2.0

Keras TensorFlow 2.0 精华资源

TensorFlow 2.0 Keras:如何为 TensorBoard 编写图像摘要

在没有 Keras 的情况下使用 Tensorflow 2.0 和急切执行