提交 6ee6ca69 编写于 作者: Q Qiao Longfei

clean code

上级 4c03cf0f
import argparse import argparse
import time
import numpy as np import numpy as np
import paddle import paddle
...@@ -8,6 +9,11 @@ import reader ...@@ -8,6 +9,11 @@ import reader
from network_conf import ctr_dnn_model from network_conf import ctr_dnn_model
def print_log(log_str):
time_stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
print(str(time_stamp) + " " + log_str)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example") parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
parser.add_argument( parser.add_argument(
...@@ -49,7 +55,6 @@ def infer(): ...@@ -49,7 +55,6 @@ def infer():
with fluid.scope_guard(inference_scope): with fluid.scope_guard(inference_scope):
[inference_program, _, fetch_targets] = fluid.io.load_inference_model(args.model_path, exe) [inference_program, _, fetch_targets] = fluid.io.load_inference_model(args.model_path, exe)
print(fetch_targets)
def set_zero(var_name): def set_zero(var_name):
param = inference_scope.var(var_name).get_tensor() param = inference_scope.var(var_name).get_tensor()
...@@ -60,14 +65,12 @@ def infer(): ...@@ -60,14 +65,12 @@ def infer():
for name in auc_states_names: for name in auc_states_names:
set_zero(name) set_zero(name)
batch_id = 0 for batch_id, data in enumerate(test_reader()):
for data in test_reader():
loss_val, auc_val = exe.run(inference_program, loss_val, auc_val = exe.run(inference_program,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=fetch_targets) fetch_list=fetch_targets)
if batch_id % 100 == 0: if batch_id % 100 == 0:
print("loss: " + str(loss_val) + " auc_val:" + str(auc_val)) print_log("TEST --> batch: {} loss: {} auc: {}".format(batch_id, loss_val, auc_val))
batch_id += 1
if __name__ == '__main__': if __name__ == '__main__':
......
import os
import logging
import argparse import argparse
import os
import time
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from network_conf import ctr_dnn_model
import reader import reader
import paddle from network_conf import ctr_dnn_model
logging.basicConfig() def print_log(log_str):
logger = logging.getLogger("paddle") time_stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
logger.setLevel(logging.INFO) print(str(time_stamp) + " " + log_str)
def parse_args(): def parse_args():
...@@ -73,17 +74,14 @@ def train(): ...@@ -73,17 +74,14 @@ def train():
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for pass_id in range(args.num_passes): for pass_id in range(args.num_passes):
batch_id = 0 for batch_id, data in enumerate(train_reader()):
for data in train_reader():
loss_val, auc_val, batch_auc_val = exe.run( loss_val, auc_val, batch_auc_val = exe.run(
fluid.default_main_program(), fluid.default_main_program(),
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[loss, auc_var, batch_auc_var] fetch_list=[loss, auc_var, batch_auc_var]
) )
print('pass:' + str(pass_id) + ' batch:' + str(batch_id) + print_log("TRAIN --> pass: {} batch: {} loss: {} auc: {}, batch_auc: {}"
' loss: ' + str(loss_val) + " auc: " + str(auc_val) + .format(pass_id, batch_id, loss_val, auc_val, batch_auc_val))
" batch_auc: " + str(batch_auc_val))
batch_id += 1
if batch_id % 1000 == 0 and batch_id != 0: if batch_id % 1000 == 0 and batch_id != 0:
model_dir = args.model_output_dir + '/batch-' + str(batch_id) model_dir = args.model_output_dir + '/batch-' + str(batch_id)
fluid.io.save_inference_model(model_dir, data_name_list, [loss, auc_var], exe) fluid.io.save_inference_model(model_dir, data_name_list, [loss, auc_var], exe)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册