Tensorflow 中带有 model.fit 的 InvalidArgumentError
Posted
技术标签:
【中文标题】Tensorflow 中带有 model.fit 的 InvalidArgumentError【英文标题】:InvalidArgumentError with model.fit in Tensorflow 【发布时间】:2021-04-07 14:08:40 【问题描述】:使用 CNN 进行图像分类。当model.fit()
被调用时,它开始训练模型一段时间,在执行过程中被中断并返回错误信息。
错误信息如下
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Input size should match (header_size + row_size * abs_height) but they differ by 2
[[node decode_image/DecodeImage]]
[[IteratorGetNext]]
[[IteratorGetNext/_4]]
(1) Invalid argument: Input size should match (header_size + row_size * abs_height) but they differ by 2
[[node decode_image/DecodeImage]]
[[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_8873]
Function call stack:
train_function -> train_function
更新:我的建议是检查数据集的元数据。它帮助解决了我的问题。
【问题讨论】:
代码存在一些问题,但我注意到的主要问题是您正在为训练数据集和测试数据集加载相同的目录。 @yudhiesh 你的意思是训练集和验证集?是的,它们是使用image_dataset_from_directory()
和不同子集从同一目录加载的。测试集在另一个文件夹中分离。由于它与问题关系不大,所以我没有包括它。
很抱歉,这实际上是正确的。我将添加一个包含更改的答案。
@yudhiesh 没关系。稍后我会尝试分享访问数据集的链接。
你没有具体说明你是如何修复它的?你提到检查元数据,但要寻找什么?你发现了什么?你究竟做了什么来修复它?
【参考方案1】:
您不必指定参数 label_mode
。为了使用SparseCategoricalCrossentropy
作为损失函数,您需要将其设置为int
。
如果您不指定它,则将其设置为None
为per the documentation。
您还需要根据从中读取图像的目录结构将参数labels
指定为inferred
。
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
labels="inferred",
label_mode="int",
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
labels="inferred",
label_mode="int",
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
【讨论】:
感谢您的提醒。我只是尝试运行它,但仍然返回相同的错误... 您能打印出img_height, img_width, IMG_SHAPE
的值并将其添加到问题中吗?
就在那儿。 img_height
和 img_width
分别为 180。 IMG_SHAPE
是 (180, 180, 3)
我无法检查数据集中的图像,因为它太大了,但我猜测图像中的输入形状与您在创建模型时指定的大小不同。
我想这可能不是我在导入数据集时指定输入形状inputs = tf.keras.Input(shape=(180, 180, 3))
的原因,这与img_height
和img_width
相同。而且它也没有解释为什么当其中一个类被删除时它工作得很好......以上是关于Tensorflow 中带有 model.fit 的 InvalidArgumentError的主要内容,如果未能解决你的问题,请参考以下文章
Python Tensorflow - 多次运行 model.fit 而不重新实例化模型
如何重构/重新格式化包含要输入到 Tensorflow 的 model.fit() 中的图像的 Pandas 数据帧?
如何在 Python 的 tensorflow.fit 中解决这个问题?
IndexError:在model.fit()中列出超出范围的索引