How can I fix this problem in tensorflow.fit in Python?
Posted: 2022-01-15 17:07:46

Can you tell me what is wrong with this code? The last line,

`history = model.fit(partial_x_train, partial_y_train, epochs=20, batch_size=512, validation_data=(x_val, y_val))`

raises an error, but I don't understand where the problem is.
```python
from tensorflow.keras.datasets import imdb
from tensorflow.keras import models
from tensorflow.keras import layers
from keras import optimizers
from keras import losses
from keras import metrics
import matplotlib.pyplot as plt
import numpy as np

(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

def vectorsize_sequeces(sequences, dimension=10000):
    # Multi-hot encode: row i gets 1.0 at every word index present in review i.
    results = np.zeros((len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.
    return results

x_train = vectorsize_sequeces(train_data)
x_test = vectorsize_sequeces(test_data)
y_train = np.asarray(train_labels).astype('float32')
y_test = np.asarray(test_labels).astype('float32')

model = models.Sequential()
model.add(layers.Dense(16, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(16, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentopy', metrics=['accuracy'])

# Hold out the first 10,000 samples for validation.
x_val = x_train[:10000]
partial_x_train = x_train[10000:]
y_val = y_train[:10000]
partial_y_train = y_train[10000:]

history = model.fit(partial_x_train, partial_y_train, epochs=20, batch_size=512, validation_data=(x_val, y_val))
```
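For context on the input shape `(10000,)` expected by the first `Dense` layer: `vectorsize_sequeces` multi-hot encodes each review, turning a list of word indices into a fixed-length 0/1 vector. A minimal sketch on toy data (the spelling-corrected name `vectorize_sequences` and the tiny 6-dimensional input are illustrative assumptions, not from the question) shows what it produces:

```python
import numpy as np

def vectorize_sequences(sequences, dimension=10000):
    # One row per review; column j is 1.0 if word index j occurs in that review.
    results = np.zeros((len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        # Fancy indexing sets all listed indices at once; repeated indices are harmless.
        results[i, sequence] = 1.0
    return results

# Two toy "reviews" given as lists of word indices, encoded into 6 dimensions.
toy = [[1, 3, 3], [0, 5]]
print(vectorize_sequences(toy, dimension=6).astype(int).tolist())
# → [[0, 1, 0, 1, 0, 0], [1, 0, 0, 0, 0, 1]]
```

This encoding is not related to the error; it only clarifies what `partial_x_train` looks like when it reaches `fit()`.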
The error we get:
```
Epoch 1/20
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-23-be6266211430> in <module>()
----> 1 history = model.fit(partial_x_train, partial_y_train, epochs=20, batch_size=512, validation_data=(x_val, y_val))
1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
1127 except Exception as e: # pylint:disable=broad-except
1128 if hasattr(e, "ag_error_metadata"):
-> 1129 raise e.ag_error_metadata.to_exception(e)
1130 else:
1131 raise
ValueError: in user code:
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 878, in train_function *
    return step_function(self, iterator)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 867, in step_function **
    outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 860, in run_step **
    outputs = model.train_step(data)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 810, in train_step
    y, y_pred, sample_weight, regularization_losses=self.losses)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 184, in __call__
    self.build(y_pred)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 133, in build
    self._losses = tf.nest.map_structure(self._get_loss_object, self._losses)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 273, in _get_loss_object
    loss = losses_mod.get(loss)
File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 2134, in get
    return deserialize(identifier)
File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 2093, in deserialize
    printable_module_name='loss function')
File "/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py", line 709, in deserialize_keras_object
    f'Unknown {printable_module_name}: {object_name}. Please ensure '
ValueError: Unknown loss function: binary_crossentopy. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
```
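Answer 1: It's a spelling mistake. The loss is named

binary_crossentropy

but you wrote:

binary_crossentopy

(the "r" in "entropy" is missing). Use this:

```python
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])
```

For reference, see `tf.keras.metrics.binary_crossentropy` (the same function is also exposed as `tf.keras.losses.binary_crossentropy`).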
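Note from the traceback that `compile()` accepted the misspelled string: Keras only resolves it when `fit()` builds the loss (`compile_utils.build` → `losses.get` → `deserialize`), which is why the error appears right after `Epoch 1/20`. A minimal sketch of a pre-flight check that catches such typos earlier and suggests the nearest valid name, assuming a hand-maintained list of names (`KNOWN_LOSSES` and `check_loss_name` are hypothetical helpers, not part of Keras), could use `difflib`:

```python
import difflib

# Hypothetical, hand-picked subset of built-in Keras loss names (not the full registry).
KNOWN_LOSSES = [
    "binary_crossentropy",
    "categorical_crossentropy",
    "sparse_categorical_crossentropy",
    "mean_squared_error",
    "mean_absolute_error",
    "hinge",
]

def check_loss_name(name, known=KNOWN_LOSSES):
    """Return name unchanged if known; otherwise raise ValueError with a suggestion."""
    if name in known:
        return name
    # Pick the single closest known name, if any is similar enough.
    close = difflib.get_close_matches(name, known, n=1)
    hint = f" Did you mean '{close[0]}'?" if close else ""
    raise ValueError(f"Unknown loss function: {name}.{hint}")

try:
    check_loss_name("binary_crossentopy")
except ValueError as err:
    print(err)
# → Unknown loss function: binary_crossentopy. Did you mean 'binary_crossentropy'?
```

Alternatively, in TensorFlow 2.x you can pass a loss object such as `tf.keras.losses.BinaryCrossentropy()` instead of a string; a misspelled class name then fails immediately with an `AttributeError` at the `compile()` line rather than at the first training step.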