WDL数据加载
Posted 我家大宝最可爱
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了WDL数据加载相关的知识,希望对你有一定的参考价值。
import re
import tensorflow as tf
class DataLoad:
    """Build ``tf.data`` input pipelines and test fixtures from CSV text files.

    All settings come from the ``params`` dict passed to the constructor.
    The TensorFlow calls used here (``tf.decode_csv``, ``tf.Session``,
    ``make_one_shot_iterator``) are TensorFlow 1.x APIs.
    """

    def __init__(self, params: dict):
        """Read loader settings from ``params`` and initialise the test fixture.

        Required keys: ``train_file``, ``eval_file``, ``_CSV_COLUMNS``,
        ``_CSV_COLUMN_DEFAULTS``.
        Optional keys (default): ``remote_train`` (""), ``num_epochs`` (1),
        ``batch_size`` (32), ``worker_hosts`` (""), ``task_index`` (1).

        Raises:
            ValueError: if ``train_file`` or ``eval_file`` is missing or None.
            Exception: if ``_CSV_COLUMNS`` or ``_CSV_COLUMN_DEFAULTS`` is
                missing or None.
        """
        self.train_file = params.get("train_file", None)
        self.eval_file = params.get("eval_file", None)
        if self.train_file is None or self.eval_file is None:
            raise ValueError("input data is None")

        self.remote_train = params.get("remote_train", "")
        self.num_epochs = params.get("num_epochs", 1)
        self.batch_size = params.get('batch_size', 32)

        self._CSV_COLUMNS = params.get("_CSV_COLUMNS", None)
        if self._CSV_COLUMNS is None:
            raise Exception('_CSV_COLUMNS is None')
        self._CSV_COLUMN_DEFAULTS = params.get("_CSV_COLUMN_DEFAULTS", None)
        if self._CSV_COLUMN_DEFAULTS is None:
            raise Exception('_CSV_COLUMN_DEFAULTS is None')

        self.worker_hosts = params.get("worker_hosts", "")
        self.task_index = params.get("task_index", 1)

        # Echo every supplied setting so a run log records its configuration.
        print('--------------------------- dataLoad params -------------------------------------')
        for key, value in params.items():
            print('{} : {}'.format(key, value))

        self.set_test()

    def _parse_line(self, line):
        """Decode one CSV text line into ``(features, label)``.

        Columns are named after ``self._CSV_COLUMNS``; the ``clk_label``
        column is removed from the features and turned into an int32 tensor
        that is 1 where the raw label equals the string ``'>50K'``.
        """
        columns = tf.decode_csv(line, record_defaults=self._CSV_COLUMN_DEFAULTS)
        features = dict(zip(self._CSV_COLUMNS, columns))
        # pop() both removes the label from the features and returns it.
        clk_labels = features.pop('clk_label')
        return features, tf.cast(tf.equal(clk_labels, '>50K'), tf.int32)

    def _csv_dataset(self, data_file, num_epochs, batch_size):
        """Return a parsed, repeated and batched dataset over ``data_file``."""
        dataset = tf.data.TextLineDataset(data_file).map(
            self._parse_line, num_parallel_calls=5)
        dataset = dataset.repeat(num_epochs)
        return dataset.batch(batch_size)

    def local_input_fn(self, data_file):
        """Return the dataset for ``data_file`` using the configured
        ``num_epochs`` and ``batch_size``."""
        return self._csv_dataset(data_file, self.num_epochs, self.batch_size)

    def get_cur_max_steps(self, model_dir):
        """Return the step count at which the current training run should stop.

        This is the number of steps needed for the configured epochs over the
        training file, plus the step of the latest checkpoint found in
        ``model_dir`` (0 when there is none).
        """
        if self.remote_train == "remote":
            # Imported lazily: only needed (and only expected to be installed)
            # when reading a remote table.
            import common_io
            with common_io.table.TableReader(self.train_file) as f:
                total_records_num = f.get_row_count()
        else:
            with open(self.train_file, "r") as f:
                total_records_num = sum(1 for _ in f)

        cur_train_steps = total_records_num * self.num_epochs // self.batch_size

        last_ckp_step = 0
        last_ckp_path = tf.train.latest_checkpoint(model_dir)
        if last_ckp_path is not None:
            # The step is taken from the text after the last "-" in the path.
            last_ckp_step = int(last_ckp_path.split("-")[-1])
            print('last chceckpoint steps is {}'.format(last_ckp_step))
        print('current train steps is {0}, total_records_num is {1}'.format(
            cur_train_steps, total_records_num))
        return cur_train_steps + last_ckp_step

    def train_input_fn(self):
        """Return the training dataset.

        The "remote" and local modes currently share the same pipeline.
        """
        return self.local_input_fn(self.train_file)

    def eval_input_fn(self):
        """Return the evaluation dataset.

        The "remote" and local modes currently share the same pipeline.
        """
        return self.local_input_fn(self.eval_file)

    def sample(self, date_file, batch_size):
        """Evaluate one batch of features from ``date_file``.

        Returns a dict mapping each feature name to a list in which every
        value of the batch is wrapped in its own single-element list.
        Labels are read but discarded.
        """
        dataset = self._csv_dataset(date_file, 1, batch_size)
        iterator = dataset.make_one_shot_iterator()
        batch_features, _ = iterator.get_next()
        with tf.Session() as sess:
            features_res = sess.run(batch_features)
        return {
            name: [[value] for value in values]
            for name, values in features_res.items()
        }

    def set_test(self):
        """Install the raw text fixture consumed by ``date_from_people``.

        The fixture is a raw triple-quoted string so pasted rows keep their
        backslashes; here it holds only a newline.
        """
        self.test_data = r"""
"""

    def date_from_people(self, delimiter='\\t', default_value='\\\\N'):
        """Convert ``self.test_data`` into a feed dict and a Java snippet file.

        The first record of ``self.test_data`` is the header; the remaining
        records are rows. Column ``i`` is typed by
        ``self._CSV_COLUMN_DEFAULTS[i][0]``: cells equal to ``default_value``
        are replaced by that default, and cells of float-typed columns are
        converted with ``float``.

        Side effect: writes ``rtp_test.txt`` in the current directory with
        one ``List<List<Object>>`` declaration and one ``builder.put`` call
        per remaining column.

        Returns:
            dict mapping header names (without ``clk_label`` and
            ``buy_label``) to a list of single-element lists.

        Raises:
            KeyError: if the header lacks ``clk_label`` or ``buy_label``.
        """
        # NOTE(review): the record separator and both defaults match literal
        # backslash sequences (a backslash followed by "n" / "t"), not control
        # characters. This source was recovered from a web page that appears
        # to have doubled backslashes -- confirm whether a real newline / tab
        # and a single-backslash N were intended before changing them.
        record_separator = '\\n'

        def _add_type(t, x):
            # Render one cell as a Java single-element list literal.
            if isinstance(t, float):
                return "Collections.singletonList({}f)".format(x)
            elif isinstance(t, str):
                return 'Collections.singletonList("{}")'.format(x)
            else:
                return "Collections.singletonList({})".format(x)

        header, *rows = self.test_data.strip().split(record_separator)
        header = header.split(delimiter)
        tables = [row.strip().split(delimiter) for row in rows]  # row-major cells
        tables = list(zip(*tables))  # transpose to column-major

        feed_tables = []
        rtp_tables = []
        for i, default in enumerate(self._CSV_COLUMN_DEFAULTS):
            typed_default = default[0]
            is_float = isinstance(typed_default, float)
            cells = [typed_default if cell == default_value else cell
                     for cell in tables[i]]
            feed_tables.append([[float(cell)] if is_float else [cell]
                                for cell in cells])
            rtp_tables.append([_add_type(typed_default, cell) for cell in cells])

        feed_dict = dict(zip(header, feed_tables))
        feed_dict.pop("clk_label")
        feed_dict.pop("buy_label")

        rtp_dict = dict(zip(header, rtp_tables))
        rtp_dict.pop("clk_label")
        rtp_dict.pop("buy_label")

        snippet = (
            '\n'
            'List<List<Object>> {0} = Arrays.asList({1});\n'
            'builder.put("{0}", {0});\n'
        )
        rtp_res = [snippet.format(name, ','.join(cells))
                   for name, cells in rtp_dict.items()]
        with open("rtp_test.txt", 'w') as f:
            f.writelines(rtp_res)
        return feed_dict
if __name__ == "__main__":
    # Split adult.data into a 70% training file and a 30% test file,
    # preserving the original line order.
    with open("adult.data", "r") as csvFile:
        lines = csvFile.readlines()

    # Computed once so both slices are guaranteed to use the same boundary.
    split_at = int(len(lines) * 0.7)

    with open("train.csv", 'w') as train:
        train.writelines(lines[:split_at])
    with open("test.csv", 'w') as test:
        test.writelines(lines[split_at:])
以上是关于WDL数据加载的主要内容,如果未能解决你的问题,请参考以下文章