WDL数据加载

Posted 我家大宝最可爱

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了WDL数据加载相关的知识,希望对你有一定的参考价值。

import re
import tensorflow as tf


class DataLoad:
    def __init__(self, params:dict):
        self.train_file = params.get("train_file" ,None)
        self.eval_file = params.get("eval_file" ,None)

        if self.train_file is None or self.eval_file is None:
            raise ValueError("input data is None")
        self.remote_train = params.get("remote_train" , "")
        self.num_epochs = params.get("num_epochs" ,1)
        self.batch_size = params.get('batch_size' ,32)
        self._CSV_COLUMNS = params.get("_CSV_COLUMNS" ,None)
        if self._CSV_COLUMNS is None: raise Exception('_CSV_COLUMNS is None')
        self._CSV_COLUMN_DEFAULTS = params.get("_CSV_COLUMN_DEFAULTS" ,None)
        if self._CSV_COLUMN_DEFAULTS is None: raise Exception('_CSV_COLUMN_DEFAULTS is None')
        self.worker_hosts = params.get("worker_hosts" ,"")
        self.task_index = params.get("task_index" ,1)
        print('--------------------------- dataLoad params -------------------------------------')
        for k,v in params.items():
            print(' : '.format(k, v))

        self.set_test()

    def local_input_fn(self, data_file):
        def _parse_line(line):
            columns = tf.decode_csv(
                line, record_defaults=self._CSV_COLUMN_DEFAULTS)
            features = dict(zip(self._CSV_COLUMNS, columns))
            # pop函数提取label
            clk_labels = features.pop('clk_label')
            return features, tf.cast(tf.equal(clk_labels, '>50K'),tf.int32)

        dataset = tf.data.TextLineDataset(data_file).map(_parse_line, num_parallel_calls=5)

        dataset = dataset.repeat(self.num_epochs)
        dataset = dataset.batch(self.batch_size)
        return dataset

    def get_cur_max_steps(self, model_dir):
        last_ckp_step,total_records_num = 0,0
        if self.remote_train == "remote":
            import common_io
            with common_io.table.TableReader(self.train_file) as f:
                total_records_num = f.get_row_count()
        else:
            with open(self.train_file, "r") as f:
                total_records_num = sum(1 for _ in f)
        cur_train_steps = total_records_num*self.num_epochs//self.batch_size
        last_ckp_path = tf.train.latest_checkpoint(model_dir)
        if last_ckp_path is not None:
            last_ckp_step = int(last_ckp_path.split("-")[-1])
        print('last chceckpoint steps is '.format(last_ckp_step))
        print('current train steps is 0, total_records_num is 1'.format(cur_train_steps, total_records_num))
        return cur_train_steps + last_ckp_step

    def train_input_fn(self):
        if self.remote_train == "remote":
            dataset = self.local_input_fn(self.train_file)
        else:
            dataset = self.local_input_fn(self.train_file)
        return dataset
    
    def eval_input_fn(self):
        if self.remote_train == "remote":
            dataset = self.local_input_fn(self.eval_file)
        else:
            dataset = self.local_input_fn(self.eval_file)
        return dataset

    def sample(self,date_file,batch_size):
        def _parse_line(line):
            columns = tf.decode_csv(
                line, record_defaults=self._CSV_COLUMN_DEFAULTS)
            features = dict(zip(self._CSV_COLUMNS, columns))
            # pop函数提取label
            clk_labels = features.pop('clk_label')

            return features, tf.cast(tf.equal(clk_labels, '>50K'),tf.int32)

        dataset = tf.data.TextLineDataset(date_file).map(_parse_line, num_parallel_calls=5)

        dataset = dataset.repeat(1)
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
        batch_features, clk_labels = iterator.get_next()
        with tf.Session() as sess:
            features_res = sess.run(batch_features)

        features_rtp = 

        for k in features_res.keys():
            v = features_res[k]

            vv = [[_] for _ in v]
            features_rtp[k] = vv

        return features_rtp
    
    def set_test(self):
        self.test_data = r"""
            
       """

    def date_from_people(self,delimiter='\\t',default_value='\\\\N'):

        def _add_type(t, x):
            if isinstance(t, float):
                return "Collections.singletonList(f)".format(x)
            elif isinstance(t, str):
                return "Collections.singletonList(\\"\\")".format(x)
            else:
                return "Collections.singletonList()".format(x)

        header,*rows = self.test_data.strip().split('\\n')
        header = header.split(delimiter)
        tables = [row.strip().split(delimiter) for row in rows] # 二维矩阵
        tables = list(zip(*tables)) # 行列转换
        feed_tables = []
        rtp_tables = []
        for i, d in enumerate(self._CSV_COLUMN_DEFAULTS) :
            tmp = list(map(lambda x: d[0] if x == default_value else x,tables[i]))
            feed = map(lambda x: [float(x)] if isinstance(d[0], (float)) else [x], tmp)
            rtp = map(lambda x: _add_type(d[0],x), tmp)
            feed_tables.append(list(feed))
            rtp_tables.append(list(rtp))

        feed_dict = dict(zip(header,feed_tables))
        feed_dict.pop("clk_label")
        feed_dict.pop("buy_label")

        rtp_dict = dict(zip(header,rtp_tables))
        rtp_dict.pop("clk_label")
        rtp_dict.pop("buy_label")


        rtp_res = []
        for k,v in rtp_dict.items():
            s = """
            List<List<Object>> 0 = Arrays.asList(1);
            builder.put("0", 0);
            """.format(k,','.join(v))
            rtp_res.append(s)

        with open("rtp_test.txt",'w') as f:
            f.writelines(rtp_res)

        return feed_dict

if __name__ == "__main__":

    # Split adult.data 70/30 into train.csv and test.csv.
    with open("adult.data", "r") as src:
        all_lines = src.readlines()
        split_at = int(len(all_lines) * 0.7)

        with open("train.csv", 'w') as train_out:
            train_out.writelines(all_lines[:split_at])

        with open("test.csv", 'w') as test_out:
            test_out.writelines(all_lines[split_at:])




以上是关于WDL数据加载的主要内容,如果未能解决你的问题,请参考以下文章

WDL-生成配置文件

WDL-生成配置文件

重新加载时刷新片段

用于数据加载的 Android 活动/片段职责

如何在android中将json数据加载到片段中

WDL-训练模型