如何在 Keras 中解释 model.predict() 的输出

Posted

技术标签:

【中文标题】如何在 Keras 中解释 model.predict() 的输出【英文标题】:How to interpret output of model.predict() in Keras 【发布时间】:2020-10-24 20:11:27 【问题描述】:

当我尝试执行预测图像时,我的代码有问题。使用 keras 等。

我正在寻找如何输出数组的方法

例如[1,0,0]然后输出rock

import numpy as np
from google.colab import files
from keras.preprocessing import image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.applications.vgg16 import VGG16

%matplotlib inline

uploaded = files.upload()

for fn in uploaded.keys():
 
  # predicting images
  path = fn
  img = image.load_img(path, target_size=(150,150))
  imgplot = plt.imshow(img)
  x = image.img_to_array(img)
  x = np.expand_dims(x, axis=0)
  x = preprocess_input(x)

  #images = np.vstack([x])
  classes = model.predict(x, batch_size=10)
  print(classes)

  print(fn)
  if classes==[[1,0,0]]:
    print('paper')
  else:
    print('rock')

然后是这样的输出

Saving 0a3UtNzl5Ll3sq8K.png to 0a3UtNzl5Ll3sq8K (4).png
[[1. 0. 0.]]
0a3UtNzl5Ll3sq8K.png
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-69-863494647f7a> in <module>()
     28 
     29   print(fn)
---> 30   if classes==[[1,0,0]]:
     31     print('paper')
     32   else:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

截图程序: enter image description here

【问题讨论】:

【参考方案1】:

始终检查您正在使用的对象的类型。

返回类型是张量数组,不是列表;它实际上是每个标签的概率数组。要将其转换为 numpy 数组,您需要使用 prediction.numpy()

在您的情况下,混淆来自这样一个事实,即第一个标签的概率确实为 100%,其余的概率为 0%。

除此之外,注意比较的方式:

[[1. 0. 0.]][[1,0,0]]

您需要使用argmax() 才能正确获取标签。

【讨论】:

以上是关于如何在 Keras 中解释 model.predict() 的输出的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Keras Regressor 中解释 MSE

如何在 Keras Regressor 中解释 MSE

如何在 Keras 中解释 LSTM 层中的权重 [关闭]

在 Keras 中实现模型。如何解释填充/步幅值?

model.evaluate() 和 model.predict() 的 F1 不同

你能用 BatchNormalization 解释神经网络中的 Keras get_weights() 函数吗?