How to convert a NumPy array image to a TensorFlow image?

Posted: 2018-07-21 11:44:14

Question:

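After using TensorFlow's retrain.py:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py

I successfully generated the "retrained_labels.txt" and "retrained_graph.pb" files. For anyone unfamiliar with this process, I am essentially following this tutorial:

https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0

which is effectively the same set of steps as this popular video:

https://www.youtube.com/watch?v=QfNvhPx5Px8

After the retraining process, I am trying to write a Python script that opens every image in a test-image directory, shows each one in turn in an OpenCV window, and runs TensorFlow to classify it.

The problem is that I can't figure out how to open the image as a NumPy array (the format the Python OpenCV wrapper uses) and then convert it into a format I can pass to TensorFlow's sess.run().

Currently I open the image with cv2.imread() and then open it again with tf.gfile.FastGFile(). This is very poor practice; I would much rather open the image once and then convert it.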
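One way to open each file only once, sketched here under the assumption that the test images are JPEG files: read the raw bytes from disk a single time, let OpenCV decode those bytes for display, and pass the same bytes to the graph's 'DecodeJpeg/contents:0' input. The helper names read_image_bytes and decode_for_display are illustrative and are not part of the original script.

```python
import numpy as np


def read_image_bytes(path):
    """Read the encoded image file from disk exactly once."""
    with open(path, "rb") as image_file:
        return image_file.read()


def decode_for_display(raw_bytes):
    """Decode already-loaded JPEG bytes into a BGR NumPy array for cv2.imshow."""
    # Imported here so read_image_bytes stays usable without OpenCV installed.
    import cv2

    # View the bytes as a 1-D uint8 buffer, which is what cv2.imdecode expects.
    encoded = np.frombuffer(raw_bytes, dtype=np.uint8)
    # Returns None if the bytes cannot be decoded, mirroring cv2.imread.
    return cv2.imdecode(encoded, cv2.IMREAD_COLOR)


# Hypothetical usage inside the loop of the script (TensorFlow 1.x session API):
#   rawBytes = read_image_bytes(imageFileWithPath)
#   openCVImage = decode_for_display(rawBytes)
#   predictions = sess.run(finalTensor, {'DecodeJpeg/contents:0': rawBytes})
```

This keeps the feed to 'DecodeJpeg/contents:0' unchanged, so the graph still does its own JPEG decoding; only the second disk read goes away.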

Here is the relevant part of the code where I am stuck:

# open the image with OpenCV
openCVImage = cv2.imread(imageFileWithPath)

# show the OpenCV image
cv2.imshow(fileName, openCVImage)

# get the final tensor from the graph
finalTensor = sess.graph.get_tensor_by_name('final_result:0')

# open the image in TensorFlow
tfImage = tf.gfile.FastGFile(imageFileWithPath, 'rb').read()

# run the network to get the predictions
predictions = sess.run(finalTensor, {'DecodeJpeg/contents:0': tfImage})

After reading these posts:

How to convert numpy arrays to standard TensorFlow format?

Feeding image data in tensorflow for transfer learning

I tried the following:

# show the OpenCV image
cv2.imshow(fileName, openCVImage)

# get the final tensor from the graph
finalTensor = sess.graph.get_tensor_by_name('final_result:0')

# convert the NumPy array / OpenCV image to a TensorFlow image
openCVImageAsArray = np.asarray(openCVImage, np.float32)
tfImage = tf.convert_to_tensor(openCVImageAsArray, np.float32)

# run the network to get the predictions
predictions = sess.run(finalTensor, {'DecodeJpeg/contents:0': tfImage})

This results in the following error on the sess.run() line:

TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.

I also tried this:

# show the OpenCV image
cv2.imshow(fileName, openCVImage)

# get the final tensor from the graph
finalTensor = sess.graph.get_tensor_by_name('final_result:0')

# convert the NumPy array / OpenCV image to a TensorFlow image
tfImage = np.array(openCVImage)[:, :, 0:3]

# run the network to get the predictions
predictions = sess.run(finalTensor, {'DecodeJpeg/contents:0': tfImage})

which results in this error:

ValueError: Cannot feed value of shape (257, 320, 3) for Tensor 'DecodeJpeg/contents:0', which has shape '()'

--- Edit ---

I also tried this:

# show the OpenCV image
cv2.imshow(fileName, openCVImage)

# get the final tensor from the graph
finalTensor = sess.graph.get_tensor_by_name('final_result:0')

# convert the NumPy array / OpenCV image to a TensorFlow image
tfImage = np.expand_dims(openCVImage, axis=0)

# run the network to get the predictions
predictions = sess.run(finalTensor, feed_dict={finalTensor: tfImage})

which results in this error:

ValueError: Cannot feed value of shape (1, 669, 1157, 3) for Tensor 'final_result:0', which has shape '(?, 2)'

I also tried this:

# show the OpenCV image
cv2.imshow(fileName, openCVImage)

# get the final tensor from the graph
finalTensor = sess.graph.get_tensor_by_name('final_result:0')

# convert the NumPy array / OpenCV image to a TensorFlow image
tfImage = np.expand_dims(openCVImage, axis=0)

# run the network to get the predictions
predictions = sess.run(finalTensor, feed_dict={'DecodeJpeg/contents:0': tfImage})

which results in this error:

ValueError: Cannot feed value of shape (1, 669, 1157, 3) for Tensor 'DecodeJpeg/contents:0', which has shape '()'

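I'm not sure whether it's necessary, but in case anyone is curious, here is the whole script. Note that it works fine, apart from having to open the image twice: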

# test.py

import os
import tensorflow as tf
import numpy as np
import cv2

# module-level variables ##############################################################################################
RETRAINED_LABELS_TXT_FILE_LOC = os.getcwd() + "/" + "retrained_labels.txt"
RETRAINED_GRAPH_PB_FILE_LOC = os.getcwd() + "/" + "retrained_graph.pb"

TEST_IMAGES_DIR = os.getcwd() + "/test_images"

#######################################################################################################################
def main():
    # get a list of classifications from the labels file
    classifications = []
    # for each line in the label file . . .
    for currentLine in tf.gfile.GFile(RETRAINED_LABELS_TXT_FILE_LOC):
        # remove the carriage return
        classification = currentLine.rstrip()
        # and append to the list
        classifications.append(classification)
    # end for

    # show the classifications to prove out that we were able to read the label file successfully
    print("classifications = " + str(classifications))

    # load the graph from file
    with tf.gfile.FastGFile(RETRAINED_GRAPH_PB_FILE_LOC, 'rb') as retrainedGraphFile:
        # instantiate a GraphDef object
        graphDef = tf.GraphDef()
        # read in retrained graph into the GraphDef object
        graphDef.ParseFromString(retrainedGraphFile.read())
        # import the graph into the current default Graph, note that we don't need to be concerned with the return value
        _ = tf.import_graph_def(graphDef, name='')
    # end with

    # if the test image directory listed above is not valid, show an error message and bail
    if not os.path.isdir(TEST_IMAGES_DIR):
        print("the test image directory does not seem to be a valid directory, check file / directory paths")
        return
    # end if

    with tf.Session() as sess:
        # for each file in the test images directory . . .
        for fileName in os.listdir(TEST_IMAGES_DIR):
            # if the file does not end in .jpg or .jpeg (case-insensitive), continue with the next iteration of the for loop
            if not (fileName.lower().endswith(".jpg") or fileName.lower().endswith(".jpeg")):
                continue
            # end if

            # show the file name on std out
            print(fileName)

            # get the file name and full path of the current image file
            imageFileWithPath = os.path.join(TEST_IMAGES_DIR, fileName)
            # attempt to open the image with OpenCV
            openCVImage = cv2.imread(imageFileWithPath)

            # if we were not able to successfully open the image, continue with the next iteration of the for loop
            if openCVImage is None:
                print("unable to open " + fileName + " as an OpenCV image")
                continue
            # end if

            # show the OpenCV image
            cv2.imshow(fileName, openCVImage)

            # get the final tensor from the graph
            finalTensor = sess.graph.get_tensor_by_name('final_result:0')

            # ToDo: find a way to convert from a NumPy array / OpenCV image to a TensorFlow image
            # instead of opening the file twice, these attempts don't work
            # attempt 1:
            # openCVImageAsArray = np.asarray(openCVImage, np.float32)
            # tfImage = tf.convert_to_tensor(openCVImageAsArray, np.float32)
            # attempt 2:
            # tfImage = np.array(openCVImage)[:, :, 0:3]

            # open the image in TensorFlow
            tfImage = tf.gfile.FastGFile(imageFileWithPath, 'rb').read()

            # run the network to get the predictions
            predictions = sess.run(finalTensor, {'DecodeJpeg/contents:0': tfImage})

            # sort predictions from most confidence to least confidence
            sortedPredictions = predictions[0].argsort()[-len(predictions[0]):][::-1]

            print("---------------------------------------")

            # keep track of if we're going through the next for loop for the first time so we can show more info about
            # the first prediction, which is the most likely prediction (they were sorted descending above)
            onMostLikelyPrediction = True
            # for each prediction . . .
            for prediction in sortedPredictions:
                strClassification = classifications[prediction]

                # if the classification (obtained from the directory name) ends with the letter "s", remove the "s" to change from plural to singular
                if strClassification.endswith("s"):
                    strClassification = strClassification[:-1]
                # end if

                # get confidence, then get confidence rounded to 2 places after the decimal
                confidence = predictions[0][prediction]

                # if we're on the first (most likely) prediction, state what the object appears to be and show a % confidence to two decimal places
                if onMostLikelyPrediction:
                    scoreAsAPercent = confidence * 100.0
                    print("the object appears to be a " + strClassification + ", " + "{0:.2f}".format(scoreAsAPercent) + "% confidence")
                    onMostLikelyPrediction = False
                # end if

                # for any prediction, show the confidence as a ratio to five decimal places
                print(strClassification + " (" +  "{0:.5f}".format(confidence) + ")")
            # end for

            # pause until a key is pressed so the user can see the current image (shown above) and the prediction info
            cv2.waitKey()
            # after a key is pressed, close the current window to prep for the next time around
            cv2.destroyAllWindows()
        # end for
    # end with

    # write the graph to file so we can view with TensorBoard
    tfFileWriter = tf.summary.FileWriter(os.getcwd())
    tfFileWriter.add_graph(sess.graph)
    tfFileWriter.close()

# end main

#######################################################################################################################
if __name__ == "__main__":
    main()
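The sorting step in the script, predictions[0].argsort()[-len(predictions[0]):][::-1], slices out the whole array before reversing it, so it can be written more simply. A small sketch with made-up scores (the values below are placeholders, not real model output):

```python
import numpy as np

# Made-up softmax output for a two-class graph, shaped (1, 2) like 'final_result:0'.
predictions = np.array([[0.25, 0.75]], dtype=np.float32)

# argsort is ascending, so reverse it to list class indices from most to least likely.
sorted_indices = predictions[0].argsort()[::-1]

print(sorted_indices.tolist())  # [1, 0]
```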

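Answer 1:

You're very close:

{'DecodeJpeg/contents:0': tfImage} feeds a binary, still-encoded JPEG image, which the graph then decodes.

If the image has already been decoded, you need to use {'DecodeJpeg:0': tfImage} instead. Read more here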
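The two feed points expect different things, which is also what the shape errors in the question are reporting. 'DecodeJpeg/contents:0' holds a single string of encoded bytes (a scalar, shape ()), while 'DecodeJpeg:0' holds the decoded pixels (height x width x channels). The NumPy-only sketch below illustrates the difference in shape; the byte string and the zero-filled array are placeholder data, not real images.

```python
import numpy as np

# Placeholder standing in for the encoded contents of a JPEG file.
encoded_jpeg = b"\xff\xd8\xff\xe0 placeholder bytes"
# Placeholder standing in for a decoded image such as the one cv2.imread returns.
decoded_image = np.zeros((257, 320, 3), dtype=np.uint8)

# A bytes object is one scalar value, so its shape is ().
encoded_shape = np.asarray(encoded_jpeg, dtype=object).shape
# A decoded image is a 3-D array of pixels.
decoded_shape = decoded_image.shape

print(encoded_shape)  # ()
print(decoded_shape)  # (257, 320, 3)
```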

So your code should look like this:

tfImage = np.array(openCVImage)[:, :, 0:3]
# run the network to get the predictions
predictions = sess.run(finalTensor, {'DecodeJpeg:0': tfImage})
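One caveat worth checking: cv2.imread returns channels in BGR order, whereas TensorFlow's JPEG decoder produces RGB, so feeding the OpenCV array unchanged would swap the red and blue channels. A minimal sketch of the conversion follows, assuming a 3-channel uint8 image; bgr_to_rgb is a hypothetical helper name, not something from the original answer.

```python
import numpy as np


def bgr_to_rgb(bgr_image):
    """Return a copy of a (height, width, 3) BGR image with channels in RGB order."""
    # Reverse the last axis (B, G, R -> R, G, B) and make the result contiguous.
    return np.ascontiguousarray(bgr_image[:, :, ::-1])


# Tiny 1x1 example: a pure-blue pixel in BGR order.
blue_bgr = np.array([[[255, 0, 0]]], dtype=np.uint8)
blue_rgb = bgr_to_rgb(blue_bgr)
print(blue_rgb.tolist())  # [[[0, 0, 255]]]

# Hypothetical usage with the session from the question (TensorFlow 1.x):
#   predictions = sess.run(finalTensor, {'DecodeJpeg:0': bgr_to_rgb(openCVImage)})
```

cv2.cvtColor(openCVImage, cv2.COLOR_BGR2RGB) should give the same result if you prefer to stay within OpenCV.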
