Tensorflow+kerasKeras 用Class类封装的模型如何调试call子函数的模型内部变量
Posted Better Bench
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow+kerasKeras 用Class类封装的模型如何调试call子函数的模型内部变量相关的知识,希望对你有一定的参考价值。
1 引言
keras搭建神经网络模型有三种方式,第一种是使用sequential,第二种函数API,第三种是Class。第二种在IDE直接家断点就可以调试。但是在Class封装的神经网络中,如下,添加断点后,运行是不会进入到调试的。
# 模型
class test_layer(keras.layers.Layer):
def __init__(self, **kwargs):
super(test_layer, self).__init__(**kwargs)
def build(self, input_shape):
self.w = K.variable(1.)
self._trainable_weights.append(self.w)
super(test_layer, self).build(input_shape)
def call(self, x, **kwargs):
m = x * x # 在这设置断点
n = self.w * K.sqrt(x)
return m + n
# 主函数
import tensorflow as tf
import keras
import keras.backend as K
input = keras.layers.Input((100,1))
y = test_layer()(input)
model = keras.Model(input,y)
model.predict(np.ones((100,1)))
2 实现
添加断点后,通过单独调用Class中的call类,并传入实参,就可以进入到call函数进行调试查看
# 主函数
import tensorflow as tf
import keras
import keras.backend as K
test_input = np.ones((100,1)
model = test_layer()
test = model.call(test_input)
以上是关于Tensorflow+kerasKeras 用Class类封装的模型如何调试call子函数的模型内部变量的主要内容,如果未能解决你的问题,请参考以下文章
Tensorflow+kerasKeras API两种训练GAN网络的方式
Tensorflow+kerasKeras API两种训练GAN网络的方式
Tensorflow+kerasKeras API三种搭建神经网络的方式及以mnist举例实现
Tensorflow+kerasKeras API三种搭建神经网络的方式及以mnist举例实现
Tensorflow+Keraskeras实现条件生成对抗网络DCGAN--以Minis和fashion_mnist数据集为例
Tensorflow+Keraskeras实现条件生成对抗网络DCGAN--以Minis和fashion_mnist数据集为例