警告:tensorflow:sample_weight 模式被强制从 ... 到 ['...']
Posted
技术标签:
【中文标题】警告:tensorflow:sample_weight 模式被强制从 ... 到 [\'...\']【英文标题】:WARNING:tensorflow:sample_weight modes were coerced from ... to ['...']警告:tensorflow:sample_weight 模式被强制从 ... 到 ['...'] 【发布时间】:2020-04-06 15:09:55 【问题描述】:使用.fit_generator()
或.fit()
训练图像分类器,并将字典作为参数传递给class_weight=
。
我在 TF1.x 中从未出现错误,但在 2.1 中,我在开始训练时得到以下输出:
WARNING:tensorflow:sample_weight modes were coerced from
...
to
['...']
从...
强制转换为['...']
是什么意思?
tensorflow
的 repo 上此警告的来源是 here,放置的 cmets 是:
尝试将 sample_weight_modes 强制为目标结构。这隐含地取决于 Model 为内部表示扁平化输出这一事实。
【问题讨论】:
很高兴看到这样一个最近的问题也是我自己警告的唯一搜索结果。 @jorijnsmit 你能提供代码来复制问题/警告吗? 其实用%tensorflow_version 2.x
切换到TF2就足以让这个警告出现:colab.research.google.com/gist/jorijnsmit/…
@jorijnsmit,不,我收到相同的警告,但实际上已将 TF2.1 安装为 pip install tensorflow
(在 pyenv/virtualenv 环境中)
确实是@lurix66,产生此错误的代码在2.1.0rc0
中介绍。
【参考方案1】:
这似乎是一条虚假消息。升级到 TensorFlow 2.1 后,我收到相同的警告消息,但我根本不使用任何类权重或样本权重。我确实使用了一个返回这样的元组的生成器:
return inputs, targets
现在我只是将其更改为以下内容以使警告消失:
return inputs, targets, [None]
我不知道这是否相关,但我的模型使用 3 个输入,所以我的 inputs
变量实际上是 3 个 numpy 数组的列表。 targets
只是一个 numpy 数组。
无论如何,这只是一个警告。无论哪种方式,训练都很好。
为 TensorFlow 2.2 编辑:
这个错误似乎已经在 TensorFlow 2.2 中得到修复,非常棒。然而,上面的修复将在 TF 2.2 中失败,因为它会尝试获取样本权重的形状,这显然会因AttributeError: 'NoneType' object has no attribute 'shape'
而失败。所以升级到 2.2 时撤消上述修复。
【讨论】:
天哪,这AttributeError
杀了我……非常感谢!【参考方案2】:
我相信这是 tensorflow 的一个错误,当您使用默认参数 sample_weight_mode=None
调用 model.compile()
然后使用指定的 sample_weight
或 class_weight
调用 model.fit()
时会发生这种错误。
来自 tensorflow 存储库:
fit()
最终调用_process_training_inputs()
_process_training_inputs()
sets sample_weight_modes = [None]
基于model.sample_weight_mode = None
然后用sample_weight_modes = [None]
创建一个DataAdapter
DataAdapter
在initialization 期间用sample_weight_modes = [None]
调用broadcast_sample_weight_modes()
broadcast_sample_weight_modes()
seems to expect sample_weight_modes = None
但收到 [None]
它断言[None]
是与sample_weight
/ class_weight
不同的结构,通过拟合sample_weight
/ class_weight
的结构将其覆盖回None
并输出警告
除了警告之外,这对fit()
没有影响,因为DataAdapter
中的sample_weight_modes
被设置回None
。
请注意,tensorflow documentation 声明 sample_weight
必须是一个 numpy 数组。如果您使用sample_weight.tolist()
调用fit()
,则不会收到警告,但当_process_numpy_inputs()
在preprocessing 中调用并接收长度大于1 的输入时,sample_weight
会被静默覆盖为None
。
【讨论】:
非常详尽的解释,谢谢。我唯一不明白的是警告描述...
被强制转换为[...]
,而在你的情况下[None]
被强制转换为None
...【参考方案3】:
我采用了您的 Gist 并安装了 Tensorflow 2.0,而不是 TFA,并且它在没有任何此类警告的情况下工作。
这里是完整代码的Gist。安装Tensorflow的代码如下:
!pip install tensorflow==2.0
执行成功的截图如下:
更新:此错误已在 Tensorflow Version 2.2.
【讨论】:
感谢您的回复。你是对的,直到版本2.1.0rc0
才引入警告消息。但是,恐怕我的问题仍然存在:“将 ...
强制转换为 ['...']
是什么意思?”
我注意到,当sample_weight_mode=None
和target_structure
的类型为dict
,sample_weight_modes
然后[None]
和broadcast_sample_weight_modes
中的异常是由于dict
。这可以被认为是一个错误吗?
不。问题不断收集意见和支持,但没有答案。
@gkennos:如果你觉得这是一个错误,你能在 Github Tensorflow 存储库中提交一个错误。
这绝对是一个错误,但它现在已在 TensorFlow 2.2 中修复【参考方案4】:
而不是提供字典
weights = '0': 42.0, '1': 1.0
我尝试了一个列表
weights = [42.0, 1.0]
警告消失了。
【讨论】:
谢谢你!我正在尝试(不成功)使用字典。通过使用列表,错误得到修复! 虽然这确实消除了错误,但对我来说,这打破了每个类的权重会产生更糟糕的结果。在切换到列表之前,我会检查一致性。以上是关于警告:tensorflow:sample_weight 模式被强制从 ... 到 ['...']的主要内容,如果未能解决你的问题,请参考以下文章