Tensorflow 的 K-Means - 图形断开错误
Posted
技术标签:
【中文标题】Tensorflow 的 K-Means - 图形断开错误【英文标题】:K-Means of Tensorflow - Graph disconnected error 【发布时间】:2021-07-09 07:35:00 【问题描述】:我正在尝试编写一个在数据集上运行 KMeans
并输出集群质心的函数。我的目标是在自定义的keras
层中使用它,所以我使用TensorFlow
的 KMeans 实现,它将张量作为输入数据集。
但是,我的问题是,即使作为独立功能,我也无法使其工作。问题来自KMeans
接受一个 generator function 提供小批量而不是普通张量的事实,但是当我使用闭包来执行此操作时,我得到一个 graph disconnected
错误:
import tensorflow as tf # version: 2.4.1
from tensorflow.compat.v1.estimator.experimental import KMeans
@tf.function
def KMeansCentroids(inputs, num_clusters, steps, use_mini_batch=False):
# `inputs` is a 2D tensor
def input_fn():
# Each one of the lines below results in the same "Graph Disconnected" error. Tuples don't really needed but just to be consistent with the documentation
return (inputs, None)
return (tf.data.Dataset.from_tensor_slices(inputs), None)
return (tf.convert_to_tensor(inputs), None)
kmeans = KMeans(
num_clusters=num_clusters,
use_mini_batch=use_mini_batch)
kmeans.train(input_fn, steps=steps) # This is where the error happens
return kmeans.cluster_centers()
>>> x = tf.random.uniform((100, 2))
>>> c = KMeansCentroids(x, 5, 10)
确切的错误是:
如果我要使用值错误:
Tensor("strided_slice:0", shape=(), dtype=int32)
必须来自同一图表Tensor("Equal:0", shape=(), dtype=bool)
(图表为FuncGraph(name=KMeansCentroids, id=..)
和<tensorflow.python.framework.ops.Graph object at ...>
)。
numpy
数据集并在函数内转换为张量,代码就可以正常工作。
另外,使input_fn()
直接返回tf.random.uniform((100, 2))
(忽略输入参数)将再次起作用。这就是为什么我猜测 tensorflow 不支持闭包,因为它需要在一开始就构建计算图。
但我不知道如何解决这个问题。
由于 KMeans 是 compat.v1.experimental
模块,会不会是版本错误?
请注意,documentation of KMeans 状态为 input_fn()
:
该函数应构造并返回以下之一:
tf.data.Dataset 对象:Dataset 对象的输出必须是具有以下相同约束的元组(特征、标签)。 元组(特征、标签):其中 features 是 tf.Tensor 或字符串特征名称到 Tensor 的字典,标签是 Tensor 或字符串标签名称到 Tensor 的字典。特征和标签都由 model_fn 使用。它们应该满足输入对 model_fn 的期望。
【问题讨论】:
【参考方案1】:您面临的问题更多是在创建的图形之外调用张量。基本上,当您调用.train
函数时,将创建一个新图,即在input_fn
中定义的图和在model_fn
中定义的图。
kmeans.train(input_fn, steps=steps)
然后,所有超出这些函数的张量都将被视为局外人,不会成为这个新图表的一部分。这就是您尝试使用外部张量时收到graph disconnected
错误的原因。要解决此问题,您需要在这些图中创建必要的张量。
import tensorflow as tf
from tensorflow.compat.v1.estimator.experimental import KMeans
@tf.function
def KMeansCentroids(num_clusters, steps, use_mini_batch=False):
def input_fn(batch_size):
pinputs = tf.random.uniform((100, 2))
dataset = tf.data.Dataset.from_tensor_slices((pinputs))
dataset = dataset.shuffle(1000).repeat()
return dataset.batch(batch_size)
kmeans = KMeans(
num_clusters=num_clusters,
use_mini_batch=use_mini_batch)
kmeans.train(input_fn = lambda: input_fn(5),
steps=steps)
return kmeans.cluster_centers()
c = KMeansCentroids(5, 10)
这里有更多信息可供阅读,1。仅供参考,我用tf > 2
的几个版本测试了您的代码,我认为这与版本错误或其他问题无关。
在这里为未来的读者重新提及。在Keras
层中使用KMeans
的替代方法:
【讨论】:
感谢您的回答。您所说的是正确的,我提到在input_fn()
中生成数据会起作用。但不幸的是,我想将KMeansCentroids()
函数合并到自定义keras.Layer
中。所以我不能事先知道输入。它们在Layer.call()
中作为参数可用。所以我需要以某种方式组合图表。
您是否尝试过在Layer. build()
中进行延迟初始化?
是的。由于需要在input_fn()
中捕获“全局”变量,问题总是相同的。似乎应该有一种标准方法将张量/操作/函数“附加”到现有图形
我明白了。然后你能添加一些你尝试过的自定义层的最小代码吗?你不需要写任何“那是什么”的东西,其他人可以从这里找到信息。
你检查过这些tf_kmeans.py、ClusteringLayer - 有什么好处吗?以上是关于Tensorflow 的 K-Means - 图形断开错误的主要内容,如果未能解决你的问题,请参考以下文章