提交 bb696e5b 编写于 作者: G gongweibao

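Fix bugs in train(): select the Adam/noam-decay optimizer on `args.sync` (previously `args.local`) and fall back to SGD otherwise, correct the `loggin` typo to `logging` in the PSERVER/TRAINER log calls, and add `import copy`.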

上级 4a0d8b13
...@@ -14,6 +14,7 @@ from model import transformer, position_encoding_init ...@@ -14,6 +14,7 @@ from model import transformer, position_encoding_init
import logging import logging
import sys import sys
import copy
def parse_args(): def parse_args():
parser = argparse.ArgumentParser("Training for Transformer.") parser = argparse.ArgumentParser("Training for Transformer.")
...@@ -623,7 +624,7 @@ def train(args): ...@@ -623,7 +624,7 @@ def train(args):
is_test=False) is_test=False)
optimizer=None optimizer=None
if args.local: if args.sync:
lr_decay = fluid.layers.learning_rate_scheduler.noam_decay( lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
ModelHyperParams.d_model, TrainTaskConfig.warmup_steps) ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
optimizer = fluid.optimizer.Adam( optimizer = fluid.optimizer.Adam(
...@@ -631,7 +632,7 @@ def train(args): ...@@ -631,7 +632,7 @@ def train(args):
beta1=TrainTaskConfig.beta1, beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2, beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps) epsilon=TrainTaskConfig.eps)
elif args.sync == False: else:
optimizer = fluid.optimizer.SGD(0.003) optimizer = fluid.optimizer.SGD(0.003)
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
...@@ -681,7 +682,7 @@ def train(args): ...@@ -681,7 +682,7 @@ def train(args):
startup_program=startup_prog) startup_program=startup_prog)
if training_role == "PSERVER": if training_role == "PSERVER":
loggin.info("distributed: pserver started") logging.info("distributed: pserver started")
current_endpoint = os.getenv("POD_IP") + ":" + os.getenv( current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
"PADDLE_PORT") "PADDLE_PORT")
if not current_endpoint: if not current_endpoint:
...@@ -694,7 +695,7 @@ def train(args): ...@@ -694,7 +695,7 @@ def train(args):
exe.run(pserver_startup) exe.run(pserver_startup)
exe.run(pserver_prog) exe.run(pserver_prog)
elif training_role == "TRAINER": elif training_role == "TRAINER":
loggin.info("distributed: trainer started") logging.info("distributed: trainer started")
trainer_prog = t.get_trainer_program() trainer_prog = t.get_trainer_program()
train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
avg_cost, token_num, predict, pyreader) avg_cost, token_num, predict, pyreader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册