提交 29b293d2 编写于 作者: u010070587's avatar u010070587 提交者: kolinwei

modify similarity_net ce (#4126)

上级 89316571
...@@ -92,11 +92,6 @@ def train(conf_dict, args): ...@@ -92,11 +92,6 @@ def train(conf_dict, args):
""" """
train processic train processic
""" """
if args.enable_ce:
SEED = 102
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
# loading vocabulary # loading vocabulary
vocab = utils.load_vocab(args.vocab_path) vocab = utils.load_vocab(args.vocab_path)
# get vocab size # get vocab size
...@@ -124,6 +119,12 @@ def train(conf_dict, args): ...@@ -124,6 +119,12 @@ def train(conf_dict, args):
startup_prog = fluid.Program() startup_prog = fluid.Program()
train_program = fluid.Program() train_program = fluid.Program()
# used for continuous evaluation
if args.enable_ce:
SEED = 102
startup_prog.random_seed = SEED
train_program.random_seed = SEED
simnet_process = reader.SimNetProcessor(args, vocab) simnet_process = reader.SimNetProcessor(args, vocab)
if args.task_mode == "pairwise": if args.task_mode == "pairwise":
# Build network # Build network
...@@ -219,11 +220,15 @@ def train(conf_dict, args): ...@@ -219,11 +220,15 @@ def train(conf_dict, args):
ce_info = [] ce_info = []
train_exe = exe train_exe = exe
#for epoch_id in range(args.epoch): #for epoch_id in range(args.epoch):
train_batch_data = fluid.io.batch( # used for continuous evaluation
fluid.io.shuffle( if args.enable_ce:
get_train_examples, buf_size=10000), train_batch_data = fluid.io.batch(get_train_examples, args.batch_size, drop_last=False)
args.batch_size, else:
drop_last=False) train_batch_data = fluid.io.batch(
fluid.io.shuffle(
get_train_examples, buf_size=10000),
args.batch_size,
drop_last=False)
train_pyreader.decorate_paddle_reader(train_batch_data) train_pyreader.decorate_paddle_reader(train_batch_data)
train_pyreader.start() train_pyreader.start()
exe.run(startup_prog) exe.run(startup_prog)
...@@ -295,14 +300,14 @@ def train(conf_dict, args): ...@@ -295,14 +300,14 @@ def train(conf_dict, args):
target_vars, exe, target_vars, exe,
test_prog) test_prog)
logging.info("saving infer model in %s" % model_path) logging.info("saving infer model in %s" % model_path)
# used for continuous evaluation
if args.enable_ce: if args.enable_ce:
card_num = get_cards() card_num = get_cards()
ce_loss = 0 ce_loss = 0
ce_time = 0 ce_time = 0
try: try:
ce_loss = ce_info[-2][0] ce_loss = ce_info[-1][0]
ce_time = ce_info[-2][1] ce_time = ce_info[-1][1]
except: except:
logging.info("ce info err!") logging.info("ce info err!")
print("kpis\teach_step_duration_%s_card%s\t%s" % print("kpis\teach_step_duration_%s_card%s\t%s" %
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册