From bb696e5bd1da9e8eb2381e88fb24947e9ad3b408 Mon Sep 17 00:00:00 2001
From: gongweibao
Date: Wed, 10 Oct 2018 02:41:25 +0000
Subject: [PATCH] fix bug

---
 fluid/neural_machine_translation/transformer/train.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index 08a0488f..c4b6d6d9 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -14,6 +14,7 @@ from model import transformer, position_encoding_init
 import logging
 import sys
+import copy
 
 
 def parse_args():
     parser = argparse.ArgumentParser("Training for Transformer.")
@@ -623,7 +624,7 @@ def train(args):
         is_test=False)
 
     optimizer=None
-    if args.local:
+    if args.sync:
         lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
             ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
         optimizer = fluid.optimizer.Adam(
@@ -631,7 +632,7 @@
             beta1=TrainTaskConfig.beta1,
             beta2=TrainTaskConfig.beta2,
             epsilon=TrainTaskConfig.eps)
-    elif args.sync == False:
+    else:
         optimizer = fluid.optimizer.SGD(0.003)
     optimizer.minimize(avg_cost)
 
@@ -681,7 +682,7 @@ def train(args):
             startup_program=startup_prog)
 
         if training_role == "PSERVER":
-            loggin.info("distributed: pserver started")
+            logging.info("distributed: pserver started")
             current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
                 "PADDLE_PORT")
             if not current_endpoint:
@@ -694,7 +695,7 @@
             exe.run(pserver_startup)
             exe.run(pserver_prog)
         elif training_role == "TRAINER":
-            loggin.info("distributed: trainer started")
+            logging.info("distributed: trainer started")
             trainer_prog = t.get_trainer_program()
             train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                        avg_cost, token_num, predict, pyreader)
-- 
GitLab
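
Note on the change: before this patch, the Adam branch was gated on args.local
and the SGD branch on `elif args.sync == False`, so a run with args.local unset
and args.sync set left optimizer as None and crashed at
optimizer.minimize(avg_cost); the two loggin.info calls would likewise raise
NameError as soon as a pserver or trainer started. After the patch, the
optimizer selection in train() reduces to the sketch below. It is stitched
together from the hunks above rather than quoted verbatim: ModelHyperParams and
TrainTaskConfig come from the repo's config module, the learning_rate=lr_decay
argument falls between the two hunks and is assumed, and build_optimizer is a
hypothetical name used only for this illustration.

    import paddle.fluid as fluid
    from config import ModelHyperParams, TrainTaskConfig

    def build_optimizer(args, avg_cost):
        # build_optimizer is a hypothetical wrapper for illustration only.
        if args.sync:
            # Sync training: Adam with Noam warmup/decay of the learning rate.
            lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
                ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
            optimizer = fluid.optimizer.Adam(
                learning_rate=lr_decay,  # assumed: sits between the two hunks
                beta1=TrainTaskConfig.beta1,
                beta2=TrainTaskConfig.beta2,
                epsilon=TrainTaskConfig.eps)
        else:
            # Async training: plain SGD with a fixed learning rate.
            optimizer = fluid.optimizer.SGD(0.003)
        # Every path now assigns an optimizer, so minimize() cannot hit None.
        optimizer.minimize(avg_cost)
        return optimizer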