Tensorflow+kerasKeras 用Class类封装的模型如何调试call子函数的模型内部变量

Posted Better Bench

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow+kerasKeras 用Class类封装的模型如何调试call子函数的模型内部变量相关的知识,希望对你有一定的参考价值。

1 引言

keras搭建神经网络模型有三种方式,第一种是使用sequential,第二种函数API,第三种是Class。第二种在IDE直接家断点就可以调试。但是在Class封装的神经网络中,如下,添加断点后,运行是不会进入到调试的。

# 模型
class test_layer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.w = K.variable(1.)
        self._trainable_weights.append(self.w)
        super(test_layer, self).build(input_shape)

    def call(self, x, **kwargs):
        m = x * x            # 在这设置断点
        n = self.w * K.sqrt(x)
        return m + n
# 主函数
import tensorflow as tf
import keras
import keras.backend as K

input = keras.layers.Input((100,1))
y = test_layer()(input)

model = keras.Model(input,y)
model.predict(np.ones((100,1)))

2 实现

添加断点后,通过单独调用Class中的call类,并传入实参,就可以进入到call函数进行调试查看

# 主函数
import tensorflow as tf
import keras
import keras.backend as K

test_input = np.ones((100,1)
model = test_layer()
test = model.call(test_input)

以上是关于Tensorflow+kerasKeras 用Class类封装的模型如何调试call子函数的模型内部变量的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow+kerasKeras API两种训练GAN网络的方式

Tensorflow+kerasKeras API两种训练GAN网络的方式

Tensorflow+kerasKeras API三种搭建神经网络的方式及以mnist举例实现

Tensorflow+kerasKeras API三种搭建神经网络的方式及以mnist举例实现

Tensorflow+Keraskeras实现条件生成对抗网络DCGAN--以Minis和fashion_mnist数据集为例

Tensorflow+Keraskeras实现条件生成对抗网络DCGAN--以Minis和fashion_mnist数据集为例