Data generated with TensorFlow Dataset.from_generator causes an error when calling iterator.get_next()
I am new to TensorFlow. I followed a few online posts and wrote code to get data from a generator. The code looks like this:
def gen(my_list_of_files):
    for fl in my_list_of_files:
        with open(fl) as f:
            for line in f.readlines():
                json_line = json.loads(line)
                features = json_line['features']
                labels = json_line['labels']
                yield features, labels

def get_dataset():
    generator = lambda: gen()
    return tf.data.Dataset.from_generator(generator, (tf.float32, tf.float32))

def get_input():
    dataset = get_dataset()
    dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.repeat().unbatch(tf.contrib.data.unbatch())
    dataset = dataset.batch(batch_size, drop_remainder=False)
    # This is where the problem is
    features, labels = dataset.make_one_shot_iterator().get_next()
    return features, labels
When I run it, I get this error:
InvalidArgumentError (see above for traceback): Input element must have a non-scalar value in each component.
[[node IteratorGetNext (defined at /blah/blah/blah) ]]
The values I am yielding look like this:
[1, 2, 3, 4, 5, 6] # features
7 # label
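My understanding of the error is that it cannot iterate over the dataset because the data is not a vector. Is my understanding correct? How do I fix this?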
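One way to read the message: as far as I can tell, it is raised by the unbatch transformation rather than by the iterator itself. Unbatching splits every component of an element along its first dimension, so a scalar component such as the label 7 has nothing to split. The sketch below is a pure-Python toy model of that check, not TensorFlow's implementation; the function name `toy_unbatch` is hypothetical.

```python
def toy_unbatch(element):
    """Toy model of unbatching: split each component along its first axis.

    Hypothetical helper for illustration only; this is not TensorFlow code.
    """
    # A component without a leading dimension (a bare number) cannot be split.
    for component in element:
        if not isinstance(component, (list, tuple)):
            raise ValueError(
                "Input element must have a non-scalar value in each component."
            )
    # Pair up the i-th slice of every component, like unbatching a batch.
    return list(zip(*element))


# A scalar label, as in the question, is rejected by the check.
try:
    toy_unbatch(([1, 2, 3, 4, 5, 6], 7))
    scalar_label_ok = True
except ValueError:
    scalar_label_ok = False

# When every component has a leading dimension of the same size, it works.
rows = toy_unbatch(([[1, 2, 3], [4, 5, 6]], [7, 8]))
```

Under this reading, the generator's output can be iterated just fine; it is the unbatch step that objects to the scalar label.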
Answer
{
    "features": ["1", "2"],
    "labels": "2"
}
With the JSON above saved as jsondataset, I do not see your error when I run this code:
import json

import tensorflow as tf

def gen():
    # Read a single JSON object and yield it as one (features, labels) pair.
    with open('jsondataset') as f:
        data = json.load(f)
        features = data['features']
        labels = data['labels']
        print(features)
        yield features, labels

def get_dataset():
    generator = lambda: gen()
    return tf.data.Dataset.from_generator(generator, (tf.float32, tf.float32))

def get_input():
    dataset = get_dataset()
    dataset = dataset.shuffle(buffer_size=5)
    dataset = dataset.batch(5, drop_remainder=False)
    # This is where the problem is
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    with tf.Session() as sess:
        print(sess.run([features, labels]))

def main():
    get_input()

if __name__ == "__main__":
    main()
[array([[1., 2.]], dtype=float32), array([2.], dtype=float32)]
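Note that the snippet above leaves out the repeat and unbatch steps from the question, which may be why the error does not show up. Since the generator already yields one example at a time, the unbatch step looks unnecessary, and the pipeline could probably go straight from shuffle to batch. Separately, the question's `lambda: gen()` calls `gen` without its file-list argument. Below is a minimal sketch of one way to bind that argument, assuming each input file holds one JSON object per line; the helper name `make_generator` is hypothetical, and only the standard library is used so the generator can be checked without TensorFlow.

```python
import json
import os
import tempfile


def make_generator(file_paths):
    """Return a zero-argument callable that yields (features, label) pairs.

    Hypothetical helper; assumes one JSON object per line with "features"
    and "labels" keys, as in the question.
    """
    def gen():
        for path in file_paths:
            with open(path) as f:
                for line in f:
                    record = json.loads(line)
                    features = [float(x) for x in record["features"]]
                    label = float(record["labels"])
                    # One example per yield, so no unbatching is needed later.
                    yield features, label
    return gen


# Small self-check using a temporary JSON-lines file.
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as tmp:
    tmp.write(json.dumps({"features": [1, 2, 3, 4, 5, 6], "labels": 7}) + "\n")
    tmp_path = tmp.name

examples = list(make_generator([tmp_path])())
os.remove(tmp_path)
print(examples)  # [([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 7.0)]
```

The returned callable could then be passed in place of the lambda, for example `tf.data.Dataset.from_generator(make_generator(files), (tf.float32, tf.float32))`, where `files` stands for your own list of paths.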