将张量作为参数传递给函数

Posted

技术标签:

【中文标题】将张量作为参数传递给函数【英文标题】:Passing tensors as an argument to a function 【发布时间】:2021-08-14 22:55:00 【问题描述】:

我正在尝试标准化 tf.data.Dataset,如下所示:

def normalization(image):
    print(image['label'])
    
    return 1
    

z = val.map(normalization) 

val 数据集是这样的:

<TakeDataset shapes:  id: (), image: (32, 32, 3), label: (), types: id: tf.string, image: tf.uint8, label: tf.int64>

如果我打印一个元素,我可以看到:

   'id': <tf.Tensor: shape=(), dtype=string, numpy=b'train_31598'>, 'image': <tf.Tensor: shape=(32, 32, 3), dtype=uint8, 
 numpy=    array([[[151, 130, 106],
            .....,
            [104,  95,  77]]], dtype=uint8)>, 'label': <tf.Tensor: shape=(), dtype=int64, numpy=50>

但是在我的函数输出中打印这个:

 'id': <tf.Tensor 'args_1:0' shape=() dtype=string>, 'image': <tf.Tensor 'args_2:0' shape=(32, 32, 3) dtype=uint8>, 'label': <tf.Tensor 'args_3:0' shape=() dtype=int64>

所以我不能对我的图像数组执行任何转换,因为我有'args_2:0'而不是张量数组

如何将每个元素正确传递给我的规范化函数?

【问题讨论】:

【参考方案1】:

我在标准数据集上尝试了您的代码,但它不起作用。 image['label'] 不正确,因为您应该给出一个整数。这是我对您的代码的修改:

def normalization(image,label):
print(image[0])

return tf.cast(image, tf.float32) / 255., label


z = ds_train.map(normalization) 

【讨论】:

以上是关于将张量作为参数传递给函数的主要内容,如果未能解决你的问题,请参考以下文章

如何将带有 args 的成员函数作为参数传递给另一个成员函数?

C++ 将列表作为参数传递给函数

当我们将对象作为参数传递给方法时,为啥会调用复制构造函数?

如何将子类作为期望基类的函数的参数传递,然后将该对象传递给指向这些抽象类对象的指针向量?

将取消引用的指针作为函数参数传递给结构

如何将关键字参数作为参数传递给函数?