Commit 2c52abfa authored by gongweibao

fix

Parent 1d63dafd
@@ -13,6 +13,8 @@ import reader
 from config import *
 from model import transformer, position_encoding_init
+from paddle.fluid.transpiler.details import program_to_code
 import logging
 import sys
 import copy
@@ -642,8 +644,13 @@ def train(args):
         if args.sync:
             lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
                 ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
+            print("before adam")
+            with fluid.default_main_program()._lr_schedule_guard():
+                learning_rate = lr_decay * TrainTaskConfig.learning_rate
             optimizer = fluid.optimizer.Adam(
-                learning_rate=lr_decay * TrainTaskConfig.learning_rate,
+                learning_rate=learning_rate,
                 beta1=TrainTaskConfig.beta1,
                 beta2=TrainTaskConfig.beta2,
                 epsilon=TrainTaskConfig.eps)
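The hunk above does two things: it adds a debug print, and it moves the multiplication lr_decay * TrainTaskConfig.learning_rate under _lr_schedule_guard(), a private fluid helper that tags the ops built inside it as belonging to the learning-rate schedule rather than the model, so the DistributeTranspiler can keep the whole schedule together when it splits the program. A minimal, self-contained sketch of the pattern, with placeholder values standing in for ModelHyperParams / TrainTaskConfig:

import paddle.fluid as fluid

# Placeholders standing in for ModelHyperParams / TrainTaskConfig.
D_MODEL, WARMUP_STEPS, BASE_LR = 512, 4000, 2.0

x = fluid.layers.data(name="x", shape=[8], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
loss = fluid.layers.mean(
    fluid.layers.square_error_cost(fluid.layers.fc(input=x, size=1), y))

lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
    D_MODEL, WARMUP_STEPS)
# _lr_schedule_guard() (a private API) marks the scale op below as part
# of the LR schedule instead of the model graph.
with fluid.default_main_program()._lr_schedule_guard():
    learning_rate = lr_decay * BASE_LR

fluid.optimizer.Adam(learning_rate=learning_rate).minimize(loss)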
@@ -688,6 +695,13 @@ def train(args):
         trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
         current_endpoint = os.getenv("POD_IP") + ":" + port
         trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
+        print("pserver_endpoints", pserver_endpoints)
+        print("current_endpoint", current_endpoint)
+        print("trainer_id", trainer_id)
+        print("pserver_ips", pserver_ips)
+        print("port", port)
         t = fluid.DistributeTranspiler()
         t.transpile(
             trainer_id,
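The five print calls added here just echo the endpoint configuration assembled from environment variables before it is handed to the transpiler. A rough sketch of how that configuration is typically assembled and used (the names PADDLE_PORT and PADDLE_PSERVER_IPS are assumptions drawn from common Paddle launch scripts; only PADDLE_TRAINERS_NUM, POD_IP and PADDLE_TRAINER_ID are visible in this diff):

import os
import paddle.fluid as fluid

# Tiny network so the transpiler has parameters and optimize ops to split.
x = fluid.layers.data(name="x", shape=[4], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
loss = fluid.layers.mean(
    fluid.layers.square_error_cost(fluid.layers.fc(input=x, size=1), y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

# Assumed env-var names (not visible in this diff): PADDLE_PORT,
# PADDLE_PSERVER_IPS. Defaults keep the sketch runnable on one machine.
port = os.getenv("PADDLE_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "127.0.0.1")
pserver_endpoints = ",".join(
    ip + ":" + port for ip in pserver_ips.split(","))
trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
current_endpoint = os.getenv("POD_IP", "127.0.0.1") + ":" + port
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id,
    pservers=pserver_endpoints,
    trainers=trainers)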
@@ -707,11 +721,26 @@ def train(args):
             pserver_startup = t.get_startup_program(current_endpoint,
                                                     pserver_prog)
+            print("pserver start:")
+            program_to_code(pserver_startup)
+            print("pserver train:")
+            program_to_code(pserver_prog)
+            #sys.exit(0)
             exe.run(pserver_startup)
             exe.run(pserver_prog)
         elif training_role == "TRAINER":
             logging.info("distributed: trainer started")
             trainer_prog = t.get_trainer_program()
+            '''
+            print("trainer start:")
+            program_to_code(pserver_startup)
+            print("trainer train:")
+            program_to_code(trainer_prog)
+            sys.exit(0)
+            '''
             train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                        avg_cost, token_num, predict, pyreader)
         else:
...
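The program_to_code dumps inserted above (and the commented-out trainer-side block) are a debugging aid: program_to_code prints a pseudo-code listing of a fluid Program's blocks, ops and variables to stdout, so the transpiled pserver programs can be inspected before exe.run. A minimal sketch of the same pattern on a toy program:

import paddle.fluid as fluid
from paddle.fluid.transpiler.details import program_to_code

# Build a toy program, then print its ops/vars as pseudo-code.
prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.layers.data(name="x", shape=[4], dtype="float32")
    fluid.layers.fc(input=x, size=2)

program_to_code(prog)  # writes the listing to stdout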