提交 ba239be3 编写于 作者: Y Yibing Liu

Merge branch 'develop' of upstream into agg_nlp_models

...@@ -145,14 +145,15 @@ def parse_args(): ...@@ -145,14 +145,15 @@ def parse_args():
return args return args
def append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint): def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
current_endpoint):
assert (trainer_id >= 0 and len(worker_endpoints) > 1 and assert (trainer_id >= 0 and len(worker_endpoints) > 1 and
current_endpoint in worker_endpoints) current_endpoint in worker_endpoints)
eps = copy.deepcopy(worker_endpoints) eps = copy.deepcopy(worker_endpoints)
eps.remove(current_endpoint) eps.remove(current_endpoint)
nccl_id_var = fluid.default_startup_program().global_block().create_var( nccl_id_var = startup_prog.global_block().create_var(
name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW) name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW)
fluid.default_startup_program().global_block().append_op( startup_prog.global_block().append_op(
type="gen_nccl_id", type="gen_nccl_id",
inputs={}, inputs={},
outputs={"NCCLID": nccl_id_var}, outputs={"NCCLID": nccl_id_var},
...@@ -403,7 +404,7 @@ def test_context(exe, train_exe, dev_count): ...@@ -403,7 +404,7 @@ def test_context(exe, train_exe, dev_count):
TrainTaskConfig.label_smooth_eps, TrainTaskConfig.label_smooth_eps,
use_py_reader=args.use_py_reader, use_py_reader=args.use_py_reader,
is_test=True) is_test=True)
test_prog = test_prog.clone(for_test=True)
test_data = prepare_data_generator( test_data = prepare_data_generator(
args, is_test=True, count=dev_count, pyreader=pyreader) args, is_test=True, count=dev_count, pyreader=pyreader)
...@@ -680,10 +681,11 @@ def train(args): ...@@ -680,10 +681,11 @@ def train(args):
logging.info("trainers_num:{}".format(trainers_num)) logging.info("trainers_num:{}".format(trainers_num))
logging.info("worker_endpoints:{}".format(worker_endpoints)) logging.info("worker_endpoints:{}".format(worker_endpoints))
logging.info("current_endpoint:{}".format(current_endpoint)) logging.info("current_endpoint:{}".format(current_endpoint))
append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint) append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
train_loop(exe, current_endpoint)
fluid.default_main_program(), dev_count, sum_cost, train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
avg_cost, token_num, predict, trainers_num, trainer_id) avg_cost, token_num, predict, pyreader, trainers_num,
trainer_id)
return return
port = os.getenv("PADDLE_PORT", "6174") port = os.getenv("PADDLE_PORT", "6174")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册