未验证 提交 096fa39f 编写于 作者: Q qingqing01 提交者: GitHub

Fix logging in transformer dygraph (#4827)

上级 f9f0d30e
@@ -29,6 +29,10 @@ from utils.check import check_gpu, check_version
import reader import reader
from model import Transformer, CrossEntropyCriterion, NoamDecay from model import Transformer, CrossEntropyCriterion, NoamDecay
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def do_train(args): def do_train(args):
if args.use_cuda: if args.use_cuda:
@@ -180,7 +184,7 @@ def do_train(args):
total_avg_cost = avg_cost.numpy() * trainer_count total_avg_cost = avg_cost.numpy() * trainer_count
if step_idx == 0: if step_idx == 0:
- logging.info(
+ logger.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, " "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" % "normalized loss: %f, ppl: %f" %
(step_idx, pass_id, batch_id, total_avg_cost, (step_idx, pass_id, batch_id, total_avg_cost,
@@ -189,7 +193,7 @@ def do_train(args):
else: else:
train_avg_batch_cost = args.print_step / ( train_avg_batch_cost = args.print_step / (
time.time() - batch_start) time.time() - batch_start)
- logging.info(
+ logger.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, " "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, avg_speed: %.2f step/s" "normalized loss: %f, ppl: %f, avg_speed: %.2f step/s"
% (step_idx, pass_id, batch_id, total_avg_cost, % (step_idx, pass_id, batch_id, total_avg_cost,
@@ -216,11 +220,11 @@ def do_train(args):
total_sum_cost += sum_cost.numpy() total_sum_cost += sum_cost.numpy()
total_token_num += token_num.numpy() total_token_num += token_num.numpy()
total_avg_cost = total_sum_cost / total_token_num total_avg_cost = total_sum_cost / total_token_num
- logging.info("validation, step_idx: %d, avg loss: %f, "
+ logger.info("validation, step_idx: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" % "normalized loss: %f, ppl: %f" %
(step_idx, total_avg_cost, (step_idx, total_avg_cost,
total_avg_cost - loss_normalizer, total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]))) np.exp([min(total_avg_cost, 100)])))
transformer.train() transformer.train()
if args.save_model and ( if args.save_model and (
@@ -242,8 +246,8 @@ def do_train(args):
train_epoch_cost = time.time() - epoch_start train_epoch_cost = time.time() - epoch_start
ce_time.append(train_epoch_cost) ce_time.append(train_epoch_cost)
- logging.info("train epoch: %d, epoch_cost: %.5f s" %
+ logger.info("train epoch: %d, epoch_cost: %.5f s" %
(pass_id, train_epoch_cost)) (pass_id, train_epoch_cost))
if args.save_model: if args.save_model:
model_dir = os.path.join(args.save_model, "step_final") model_dir = os.path.join(args.save_model, "step_final")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册