From c86f79bd99cee86c569c83e8822c76f6a4c93e21 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Thu, 14 Jan 2021 11:27:19 +0800 Subject: [PATCH] Transformer-XL update train (#5200) * update train * update eval --- PaddleNLP/examples/language_model/transformer-xl/eval.py | 4 ++-- PaddleNLP/examples/language_model/transformer-xl/train.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/PaddleNLP/examples/language_model/transformer-xl/eval.py b/PaddleNLP/examples/language_model/transformer-xl/eval.py index a6c49cf8..cd50c2d7 100644 --- a/PaddleNLP/examples/language_model/transformer-xl/eval.py +++ b/PaddleNLP/examples/language_model/transformer-xl/eval.py @@ -118,9 +118,9 @@ def do_eval(args): logger_info = '' if valid_loss is not None: - logger_info = logger_info + _logger(valid_loss) + logger_info = logger_info + _logger(valid_loss) + " | " if test_loss is not None: - logger_info = logger_info + _logger(test_loss) + logger_info = logger_info + _logger(test_loss) + " | " logger.info(logger_info) diff --git a/PaddleNLP/examples/language_model/transformer-xl/train.py b/PaddleNLP/examples/language_model/transformer-xl/train.py index 116a9352..91297212 100644 --- a/PaddleNLP/examples/language_model/transformer-xl/train.py +++ b/PaddleNLP/examples/language_model/transformer-xl/train.py @@ -242,8 +242,9 @@ def do_train(args): logger.info(logger_info) if args.save_model and rank == 0: - model_dir = os.path.join(args.save_model, - "step_" + str(step_idx)) + model_dir = os.path.join( + args.save_model, + "step_" + str(step_idx) + "_" + str(eval_loss)) if not os.path.exists(model_dir): os.makedirs(model_dir) paddle.save( -- GitLab