diff --git a/PaddleNLP/language_representations_kit/ELMo/args.py b/PaddleNLP/language_representations_kit/ELMo/args.py index 67537befa9335c6a410203d874f03388891043a2..dbcac53ef50ad2a34ebb4b8b88b6c92a1272020c 100755 --- a/PaddleNLP/language_representations_kit/ELMo/args.py +++ b/PaddleNLP/language_representations_kit/ELMo/args.py @@ -118,6 +118,7 @@ def parse_args(): parser.add_argument('--update_method', type=str, default='nccl2') parser.add_argument('--random_seed', type=int, default=0) parser.add_argument('--n_negative_samples_batch', type=int, default=8000) + parser.add_argument('--enable_ce', action='store_true', help='whether print log for ce') args = parser.parse_args() return args diff --git a/PaddleNLP/language_representations_kit/ELMo/train.py b/PaddleNLP/language_representations_kit/ELMo/train.py index 25beb4b1aba1a05109375f4e88ec79887d389218..62f96eb1a5e4b9311cfce17c2ae6f2b637b7b2d5 100755 --- a/PaddleNLP/language_representations_kit/ELMo/train.py +++ b/PaddleNLP/language_representations_kit/ELMo/train.py @@ -264,10 +264,17 @@ def train(): vocab_size = vocab.size logger.info("finished load vocab") + if args.enable_ce: + random.seed(args.random_seed) + np.random.seed(args.random_seed) + logger.info('build the model...') # build model train_prog = fluid.Program() train_startup_prog = fluid.Program() + if args.enable_ce: + train_prog.random_seed = args.random_seed + train_startup_prog.random_seed = args.random_seed # build infer model infer_prog = fluid.Program() infer_startup_prog = fluid.Program() @@ -559,6 +566,19 @@ def train_loop(args, os.makedirs(model_path) fluid.io.save_persistables( executor=exe, dirname=model_path, main_program=train_prog) + + if args.enable_ce: + card_num = get_cards() + ce_loss = 0 + ce_time = 0 + try: + ce_loss = ce_info[-2][0] + ce_time = ce_info[-2][1] + except: + print("ce info error") + print("kpis\ttrain_duration_card%s\t%s" % (card_num, ce_time)) + print("kpis\ttrain_loss_card%s\t%f" % (card_num, ce_loss)) + end_time = time.time() total_time += end_time - start_time epoch_id = int(final_batch_id / n_batches_per_epoch)