提交 29b293d2 编写于 作者: u010070587's avatar u010070587 提交者: kolinwei

modify similarity_net ce (#4126)

上级 89316571
......@@ -92,11 +92,6 @@ def train(conf_dict, args):
"""
train processic
"""
if args.enable_ce:
SEED = 102
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
# loading vocabulary
vocab = utils.load_vocab(args.vocab_path)
# get vocab size
......@@ -124,6 +119,12 @@ def train(conf_dict, args):
startup_prog = fluid.Program()
train_program = fluid.Program()
# used for continuous evaluation
if args.enable_ce:
SEED = 102
startup_prog.random_seed = SEED
train_program.random_seed = SEED
simnet_process = reader.SimNetProcessor(args, vocab)
if args.task_mode == "pairwise":
# Build network
......@@ -219,11 +220,15 @@ def train(conf_dict, args):
ce_info = []
train_exe = exe
#for epoch_id in range(args.epoch):
train_batch_data = fluid.io.batch(
fluid.io.shuffle(
get_train_examples, buf_size=10000),
args.batch_size,
drop_last=False)
# used for continuous evaluation
if args.enable_ce:
train_batch_data = fluid.io.batch(get_train_examples, args.batch_size, drop_last=False)
else:
train_batch_data = fluid.io.batch(
fluid.io.shuffle(
get_train_examples, buf_size=10000),
args.batch_size,
drop_last=False)
train_pyreader.decorate_paddle_reader(train_batch_data)
train_pyreader.start()
exe.run(startup_prog)
......@@ -295,14 +300,14 @@ def train(conf_dict, args):
target_vars, exe,
test_prog)
logging.info("saving infer model in %s" % model_path)
# used for continuous evaluation
if args.enable_ce:
card_num = get_cards()
ce_loss = 0
ce_time = 0
try:
ce_loss = ce_info[-2][0]
ce_time = ce_info[-2][1]
ce_loss = ce_info[-1][0]
ce_time = ce_info[-1][1]
except:
logging.info("ce info err!")
print("kpis\teach_step_duration_%s_card%s\t%s" %
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册