WDL-训练模型

Posted 我家大宝最可爱

tags:

篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了WDL-训练模型相关的知识，希望对你有一定的参考价值。

import warnings
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)

import tensorflow as tf
from FeatureGenerator import FeatureGenerator
from DataLoad import DataLoad
from WideDeep import WDL
import numpy as np
# Parsed command-line flags. Starts as None and is rebound to
# tf.app.flags.FLAGS inside the `if __name__ == "__main__":` guard at the
# bottom of this file, before tf.app.run() invokes main().
FLAGS = None

# NOTE(review): is_train is not referenced by any code visible in this file;
# confirm whether another module imports it before removing.
is_train = True

def main(unused_argv):
    """Build Wide & Deep feature columns, the input pipeline and the model, then train.

    Configuration comes from the module-level ``FLAGS`` object, which the
    ``__main__`` guard binds to ``tf.app.flags.FLAGS`` before ``tf.app.run()``
    calls this function.

    Args:
        unused_argv: Leftover command-line arguments forwarded by
            ``tf.app.run()``; ignored.

    Raises:
        ValueError: If ``FLAGS.tables`` is not exactly two comma-separated
            entries (training table, evaluation table).
    """
    print('....................start generat feature from json file...................')
    # Feature definitions are read from the JSON config 'fg.json'.
    fg = FeatureGenerator('fg.json')

    # Optional feature cross, currently disabled:
    # fg.cross_feature(["user_id","cate_id"], 1000000)
    print('....................end generat feature from json file...................')

    # Build the wide & deep feature columns.
    print(fg.features_name)
    # Wide part: one column per 'sparse_indicator' feature (type code 's').
    wide_columns = [fg.add_feature('s', k) for k in fg.features_name['sparse_indicator']]
    # Deep part: one column per 'numeric' feature (type code 'n').
    deep_columns = [fg.add_feature('n', k) for k in fg.features_name['numeric']]

    # Read the input tables: FLAGS.tables must be "train_path,eval_path".
    tables = FLAGS.tables.split(',')
    if len(tables) != 2:
        raise ValueError(
            "--tables must be two comma-separated entries 'train,eval', got %r"
            % FLAGS.tables)
    train_file, eval_file = tables

    # NOTE(review): the published snippet listed bare "key": value pairs
    # directly inside the call parentheses, which is a SyntaxError. The
    # enclosing braces were evidently lost when the post was rendered, so they
    # are restored here as a single dict argument. Confirm against
    # DataLoad.__init__.
    dl = DataLoad({
        "train_file": train_file,
        "eval_file": eval_file,
        "num_epochs": 100,
        'batch_size': 128,
        "_CSV_COLUMNS": fg._CSV_COLUMNS,
        "_CSV_COLUMN_DEFAULTS": fg._CSV_COLUMN_DEFAULTS,
        "worker_hosts": FLAGS.worker_hosts,
        "task_index": FLAGS.task_index,
        "remote_train": FLAGS.remote_train,
    })

    # The saved checkpoint's step number must be greater than train_max_steps,
    # otherwise the evaluator never finishes.
    max_steps = dl.get_cur_max_steps(FLAGS.checkpointDir)

    # NOTE(review): same lost-braces repair as for DataLoad above; the model
    # settings are passed as a third, dict-valued positional argument.
    # Confirm against WDL.__init__.
    model = WDL(wide_columns, deep_columns, {
        'checkpointDir': FLAGS.checkpointDir,
        'rtp_dfs_path': FLAGS.rtp_dfs_path,
        'rtp_model_name': FLAGS.rtp_model_name,
        'learning_rate': 0.1,
        'max_steps': max_steps,
        'save_checkpoints_steps': 100,
        'save_summary_steps': 10,
        'protocol': FLAGS.protocol,
    })

    # Train (and evaluate) using the loader's input functions.
    model.train(dl.train_input_fn, dl.eval_input_fn)

    # Model export is currently disabled:
    # model.save_model(fg.save_model_feature())




if __name__ == "__main__":
    # Define the command-line flags (TF 1.x tf.app.flags). tf.app.run() parses
    # them and then calls main(). FLAGS is rebound here at module scope, so
    # main() reads the parsed values through the global name.
    tf.app.flags.DEFINE_string("tables", "train.csv,test.csv",
                               "comma-separated input tables: train,eval")
    tf.app.flags.DEFINE_string("job_name", None, "job name: worker or ps")
    tf.app.flags.DEFINE_integer("task_index", None, "Worker or server index")
    tf.app.flags.DEFINE_string("worker_hosts", "", "worker hosts")
    tf.app.flags.DEFINE_string("checkpointDir", '../model', "oss checkpointDir")
    # Previously mis-described as "oss checkpointDir" (copy-paste error).
    tf.app.flags.DEFINE_string("protocol", None, "protocol passed to the WDL model")
    tf.app.flags.DEFINE_string("rtp_dfs_path", "../model_rtp", "rtp_dfs_path")
    tf.app.flags.DEFINE_string("rtp_model_name", "wdl", "rtp_model_name")
    # Previously mis-described as "rtp_model_name" (copy-paste error).
    tf.app.flags.DEFINE_string("remote_train", "localhost",
                               "remote_train setting passed to DataLoad")
    FLAGS = tf.app.flags.FLAGS
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()

以上是关于WDL-训练模型的主要内容，如果未能解决你的问题，请参考以下文章：

WDL-特征生成

WDL数据加载

sh wdl设置别名/ ln

WDL数据加载

WDL-生成配置文件

WDL-生成配置文件