访问在 TF 2.0 中未显式公开为层的 Keras 模型的中间张量

Posted

技术标签:

【中文标题】访问在 TF 2.0 中未显式公开为层的 Keras 模型的中间张量【英文标题】:Accessing intermediate tensors of a Keras Model that were not explicitly exposed as layers in TF 2.0 【发布时间】:2021-01-25 22:31:55 【问题描述】:

是否可以在 Keras 模型中访问预激活张量?例如,给定这个模型:

import tensorflow as tf
image_ = tf.keras.Input(shape=[224, 224, 3], batch_size=1)
vgg19 = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_tensor=image_, input_shape=image_.shape[1:], pooling=None)

访问图层的常用方法是:

intermediate_layer_model = tf.keras.models.Model(inputs=image_, outputs=[vgg19.get_layer('block1_conv2').output])
intermediate_layer_model.summary()

这给出了一层的 ReLU 输出,而我想要 ReLU 输入。我试过这样做:

graph = tf.function(vgg19, [tf.TensorSpec.from_tensor(image_)]).get_concrete_function().graph
outputs = [graph.get_tensor_by_name(tname) for tname in [
    'vgg19/block4_conv3/BiasAdd:0',
    'vgg19/block4_conv4/BiasAdd:0',
    'vgg19/block5_conv1/BiasAdd:0'
]]
intermediate_layer_model = tf.keras.models.Model(inputs=image_, outputs=outputs)
intermediate_layer_model.summary()

但我得到了错误

ValueError: Unknown graph. Aborting.

我发现的唯一解决方法是编辑模型文件以手动公开中间体,像这样转动每一层:

x = layers.Conv2D(256, (3, 3), activation="relu", padding="same", name="block3_conv1")(x)

分为 2 层,其中第一层可以在激活之前访问:

x = layers.Conv2D(256, (3, 3), activation=None, padding="same", name="block3_conv1")(x)
x = layers.ReLU(name="block3_conv1_relu")(x)

有没有一种方法可以访问模型中的预激活张量,而无需实质上编辑 Tensorflow 2 源代码,或者恢复到具有完全灵活性访问中间体的 Tensorflow 1?

【问题讨论】:

this 回答你的问题了吗? 该解决方案正是对我不起作用的解决方案。它对回复该帖子的人也不起作用,但他们会收到不同的错误。 【参考方案1】:

获取每一层的输出。您必须定义一个 keras 函数并为每一层评估它。

请参考如下代码

from tensorflow.keras import backend as K

inp = model.input                                           # input 
outputs = [layer.output for layer in model.layers]          # all layer outputs
functors = [K.function([inp], [out]) for out in outputs]    # evaluation functions

更多详情请参考SO Answer。

【讨论】:

以上是关于访问在 TF 2.0 中未显式公开为层的 Keras 模型的中间张量的主要内容,如果未能解决你的问题,请参考以下文章

如果在根中初始化对象时未显式声明,则组件 v 无法识别道具

即使未显式修改数组值也已修改[重复]

怎么证明未显式定义构造方法时,编译器会自动生成无参的构造方法?

浅析SQL查询语句未显式指定排序方式,无法保证同样的查询每次排序结果都一致的原因

bss 概念

Kera高层API