如何从 tf.tensor 中获取字符串值,其中 dtype 是字符串

Posted

技术标签:

【中文标题】如何从 tf.tensor 中获取字符串值,其中 dtype 是字符串【英文标题】:how to get string value out of tf.tensor which dtype is string 【发布时间】:2019-10-01 00:33:45 【问题描述】:

我想使用 tf.data.Dataset.list_files 函数来提供我的数据集。 但是因为文件不是图片,所以需要手动加载。 问题是 tf.data.Dataset.list_files 将变量作为 tf.tensor 传递,而我的 python 代码无法处理张量。

如何从 tf.tensor 获取字符串值。 dtype 是字符串。

train_dataset = tf.data.Dataset.list_files(PATH+'clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: load_audio_file(x))

def load_audio_file(file_path):
  print("file_path: ", file_path)
  # i want do something like string_path = convert_tensor_to_string(file_path)

file_path 是Tensor("arg0:0", shape=(), dtype=string)

我使用 tensorflow 1.13.1 和 Eager 模式。

提前致谢

【问题讨论】:

【参考方案1】:

您可以使用tf.py_func 包裹load_audio_file()

import tensorflow as tf

tf.enable_eager_execution()

def load_audio_file(file_path):
    # you should decode bytes type to string type
    print("file_path: ",bytes.decode(file_path),type(bytes.decode(file_path)))
    return file_path

train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))

for one_element in train_dataset:
    print(one_element)

file_path:  clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path:  clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path:  clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)

更新 TF 2

上述解决方案不适用于 TF 2(使用 2.2.0 测试),即使将 tf.py_func 替换为 tf.py_function,也会给出

InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'

要使其在 TF 2 中工作,请进行以下更改:

删除tf.enable_eager_execution()(在TF 2中eager是enabled by default,您可以通过tf.executing_eagerly()返回True进行验证) 将tf.py_func 替换为tf.py_functionfile_path 的所有函数内引用替换为file_path.numpy()

【讨论】:

对于 TF V2.x.x 使用 tf.py_function 或 tf.numpy_function 而不是 tf.py_func。 嗯,这是一个救生员,谢谢 (+1)!我冒昧地更新了您对 TF 2 的答案(添加另一个答案感觉不对),但是如果您反对,我很抱歉 - 只是回滚到以前的版本。再次感谢... @desertnaut 干得好!感谢您改进答案。 @ElegantCode 解决了我的问题,天才!!!不知道为什么tensorflow函数会变变变。【参考方案2】:

如果您想做一些完全自定义的事情,那么将您的代码包装在tf.py_function 中是您应该做的。请记住,这将导致性能不佳。在此处查看文档和示例:

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map

另一方面,如果您正在做一些通用的事情,那么您不需要将代码包装在py_function 中,而是使用tf.strings 模块中提供的任何方法。这些方法适用于字符串张量,并提供许多常用方法,如 split、join、len 等。这些方法不会对性能产生负面影响,它们将直接作用于张量并返回修改后的张量。

在此处查看tf.strings 的文档:https://www.tensorflow.org/api_docs/python/tf/strings

例如,假设您想从文件名中提取标签名称,您可以编写如下代码:

ds.map(lambda x: tf.strings.split(x, sep='$')[1])

以上假设标签由$分隔。

【讨论】:

以上是关于如何从 tf.tensor 中获取字符串值,其中 dtype 是字符串的主要内容,如果未能解决你的问题,请参考以下文章

仅评估 tf.Tensor 的非零值

在 tf.data 中切片导致“在图形执行中不允许迭代 `tf.Tensor`”错误

Tensorflow 类型错误:不允许使用 `tf.Tensor` 作为 Python `bool`。

如何使用提供的需要 tf.Tensor 的 preprocess_input 函数预处理 tf.data.Dataset?

在 Tensorflow.js 中获取张量中项目的值

Tensorflow:TypeError:预期的二进制或 unicode 字符串,得到 <tf.Tensor 'Placeholder:0' shape=<unknown> dtyp