提交 bb696e5b 编写于 作者: G gongweibao

fix bug

上级 4a0d8b13
......@@ -14,6 +14,7 @@ from model import transformer, position_encoding_init
import logging
import sys
import copy
def parse_args():
parser = argparse.ArgumentParser("Training for Transformer.")
......@@ -623,7 +624,7 @@ def train(args):
is_test=False)
optimizer=None
if args.local:
if args.sync:
lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
optimizer = fluid.optimizer.Adam(
......@@ -631,7 +632,7 @@ def train(args):
beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps)
elif args.sync == False:
else:
optimizer = fluid.optimizer.SGD(0.003)
optimizer.minimize(avg_cost)
......@@ -681,7 +682,7 @@ def train(args):
startup_program=startup_prog)
if training_role == "PSERVER":
loggin.info("distributed: pserver started")
logging.info("distributed: pserver started")
current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
"PADDLE_PORT")
if not current_endpoint:
......@@ -694,7 +695,7 @@ def train(args):
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
loggin.info("distributed: trainer started")
logging.info("distributed: trainer started")
trainer_prog = t.get_trainer_program()
train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
avg_cost, token_num, predict, pyreader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册