提交 e124898f 编写于 作者: Z zhengya01 提交者: JesseyXujin

Ce elmo (#3258)

* update .run_ce.sh

* chmod +x .run_ce.sh

* add elmo ce
上级 9105cc02
......@@ -118,6 +118,7 @@ def parse_args():
parser.add_argument('--update_method', type=str, default='nccl2')
parser.add_argument('--random_seed', type=int, default=0)
parser.add_argument('--n_negative_samples_batch', type=int, default=8000)
parser.add_argument('--enable_ce', action='store_true', help='whether print log for ce')
args = parser.parse_args()
return args
......@@ -264,10 +264,17 @@ def train():
vocab_size = vocab.size
logger.info("finished load vocab")
if args.enable_ce:
random.seed(args.random_seed)
np.random.seed(args.random_seed)
logger.info('build the model...')
# build model
train_prog = fluid.Program()
train_startup_prog = fluid.Program()
if args.enable_ce:
train_prog.random_seed = args.random_seed
train_startup_prog.random_seed = args.random_seed
# build infer model
infer_prog = fluid.Program()
infer_startup_prog = fluid.Program()
......@@ -559,6 +566,19 @@ def train_loop(args,
os.makedirs(model_path)
fluid.io.save_persistables(
executor=exe, dirname=model_path, main_program=train_prog)
if args.enable_ce:
card_num = get_cards()
ce_loss = 0
ce_time = 0
try:
ce_loss = ce_info[-2][0]
ce_time = ce_info[-2][1]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (card_num, ce_time))
print("kpis\ttrain_loss_card%s\t%f" % (card_num, ce_loss))
end_time = time.time()
total_time += end_time - start_time
epoch_id = int(final_batch_id / n_batches_per_epoch)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册