tf.data.Dataset.from_tensor_slices中的shuffle()repeat()batch()用法

Posted gengyi

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tf.data.Dataset.from_tensor_slices中的shuffle()repeat()batch()用法相关的知识,希望对你有一定的参考价值。

引用库文件

from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import feature_column
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split

加载数据集,生成数据帧资源句柄

# 将heart.csv数据集下载并加载到数据帧中
path_data = "E:/pre_data/heart.csv"
dataframe = pd.read_csv(path_data)

将pandas dataframe 数据格式转变为 tf.data 格式的数据集形式

# 拷贝数据帧,id(dataframe)!=id(dataframe_new)
 dataframe_new = dataframe.copy()
 # 从dataframe_new数据中获取target属性
labels = dataframe_new.pop(target)
# 要构建Dataset内存中的数据
ds = tf.data.Dataset.from_tensor_slices((dict(dataframe_new), labels))
# 将数据打乱的混乱程度
ds = ds.shuffle(buffer_size=len(dataframe_new))
# 从数据集中取出数据集的个数
ds = ds.batch(100)
# 指定数据集重复的次数
ds = ds.repeat(2)

ds 中有shuffle、batch、repeat三个方法;具体区别如下

shuffle:

tensorflow中的数据集类Dataset有一个shuffle方法,用来打乱数据集中数据顺序,训练时非常常用。其中shuffle方法有一个参数buffer_size,非常令人费解,文档的解释如下:

 

以上是关于tf.data.Dataset.from_tensor_slices中的shuffle()repeat()batch()用法的主要内容,如果未能解决你的问题,请参考以下文章