WDL-训练模型
Posted 我家大宝最可爱
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了WDL-训练模型相关的知识,希望对你有一定的参考价值。
import warnings
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
import tensorflow as tf
from FeatureGenerator import FeatureGenerator
from DataLoad import DataLoad
from WideDeep import WDL
import numpy as np
# FLAGS: placeholder for the parsed command-line flags; it is rebound to
#   tf.app.flags.FLAGS under `if __name__ == "__main__"` before main() runs.
# is_train: module-level switch; not read by any code in this file.
FLAGS, is_train = None, True
def main(unused_argv):
    """Build wide & deep feature columns, load the data, and train a WDL model.

    Configuration comes from the module-level ``FLAGS``, which must already
    hold the parsed ``tf.app.flags`` values when ``tf.app.run()`` invokes
    this function.

    Args:
        unused_argv: Leftover command-line arguments handed over by
            ``tf.app.run()``; ignored.

    Raises:
        ValueError: If ``FLAGS.tables`` does not contain exactly two
            comma-separated entries (train file, eval file).
    """
    print('....................start generat feature from json file...................')
    fg = FeatureGenerator('fg.json')
    # fg.cross_feature(["user_id","cate_id"], 1000000)
    print('....................end generat feature from json file...................')
    print(fg.features_name)

    # Wide part: one column per 'sparse_indicator' feature; deep part: one
    # column per 'numeric' feature. The 's' / 'n' codes presumably select the
    # sparse vs. numeric column kind in FeatureGenerator.add_feature (confirm).
    wide_columns = [fg.add_feature('s', k)
                    for k in fg.features_name['sparse_indicator']]
    deep_columns = [fg.add_feature('n', k)
                    for k in fg.features_name['numeric']]

    # --tables is "train_file,eval_file"; fail with a clear message instead of
    # an opaque unpacking error when it is malformed.
    tables = [t.strip() for t in FLAGS.tables.split(',')]
    if len(tables) != 2:
        raise ValueError(
            "--tables must be 'train_file,eval_file', got %r" % FLAGS.tables)
    train_file, eval_file = tables

    # NOTE(review): the pasted source listed `key: value` pairs without
    # braces (a SyntaxError); reconstructed here as one config dict.
    # Confirm DataLoad takes a single dict argument.
    dl = DataLoad({
        "train_file": train_file,
        "eval_file": eval_file,
        "num_epochs": 100,
        "batch_size": 128,
        "_CSV_COLUMNS": fg._CSV_COLUMNS,
        "_CSV_COLUMN_DEFAULTS": fg._CSV_COLUMN_DEFAULTS,
        "worker_hosts": FLAGS.worker_hosts,
        "task_index": FLAGS.task_index,
        "remote_train": FLAGS.remote_train,
    })

    # The step number of the saved checkpoint must be greater than
    # train_max_steps, otherwise the evaluator cannot finish.
    max_steps = dl.get_cur_max_steps(FLAGS.checkpointDir)

    # NOTE(review): same brace reconstruction as above; confirm WDL's third
    # positional argument is a config dict.
    model = WDL(wide_columns, deep_columns, {
        'checkpointDir': FLAGS.checkpointDir,
        'rtp_dfs_path': FLAGS.rtp_dfs_path,
        'rtp_model_name': FLAGS.rtp_model_name,
        'learning_rate': 0.1,
        'max_steps': max_steps,
        'save_checkpoints_steps': 100,
        'save_summary_steps': 10,
        'protocol': FLAGS.protocol,
    })
    model.train(dl.train_input_fn, dl.eval_input_fn)

    # Model export is currently disabled:
    # model.save_model(fg.save_model_feature())
if __name__ == "__main__":
    # Flags are registered only when run as a script, so importing this module
    # has no flag side effects. main() reads them through the module-level
    # FLAGS, which is rebound below before tf.app.run() calls main().
    tf.app.flags.DEFINE_string("tables", "train.csv,test.csv", "tables info")
    # NOTE(review): job_name is not read by main() in this file.
    tf.app.flags.DEFINE_string("job_name", None, "job name: worker or ps")
    tf.app.flags.DEFINE_integer("task_index", None, "Worker or server index")
    tf.app.flags.DEFINE_string("worker_hosts", "", "worker hosts")
    tf.app.flags.DEFINE_string("checkpointDir", '../model', "oss checkpointDir")
    tf.app.flags.DEFINE_string("protocol", None, "protocol option passed to WDL")
    tf.app.flags.DEFINE_string("rtp_dfs_path", "../model_rtp", "rtp_dfs_path")
    tf.app.flags.DEFINE_string("rtp_model_name", "wdl", "rtp_model_name")
    tf.app.flags.DEFINE_string("remote_train", "localhost", "remote_train option passed to DataLoad")
    FLAGS = tf.app.flags.FLAGS
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()
以上是关于WDL-训练模型的主要内容,如果未能解决你的问题,请参考以下文章