未验证 提交 5c9deb5e 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #1505 from typhoonzero/dist_train_fixes

refine dist train
......@@ -26,6 +26,7 @@ import six
import sys
sys.path.append("..")
import models
import utils
from reader import train, val
def parse_args():
......@@ -149,13 +150,15 @@ def get_model(args, is_train, main_prog, startup_prog):
lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
# NOTE: we put weight decay in layers config, and remove
# weight decay on bn layers, so don't add weight decay in
# optimizer config.
optimizer = fluid.optimizer.Momentum(
learning_rate=models.learning_rate.lr_warmup(
learning_rate=utils.learning_rate.lr_warmup(
fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
warmup_steps, start_lr, end_lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
momentum=0.9)
optimizer.minimize(avg_cost)
batched_reader = None
......@@ -175,6 +178,7 @@ def append_nccl2_prepare(trainer_id, startup_prog):
for ip in worker_ips.split(","):
worker_endpoints.append(':'.join([ip, port]))
current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port
num_trainers = len(worker_endpoints)
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
......@@ -182,6 +186,7 @@ def append_nccl2_prepare(trainer_id, startup_prog):
t.transpile(trainer_id, trainers=','.join(worker_endpoints),
current_endpoint=current_endpoint,
startup_program=startup_prog)
return num_trainers, trainer_id
def dist_transpile(trainer_id, args, train_prog, startup_prog):
......@@ -281,12 +286,12 @@ def test_single(exe, test_args, args, test_prog):
def train_parallel(train_args, test_args, args, train_prog, test_prog,
startup_prog, nccl_id_var, num_trainers, trainer_id):
startup_prog, num_trainers, trainer_id):
over_all_start = time.time()
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
if nccl_id_var and trainer_id == 0:
#FIXME(wuyi): wait other trainer to start listening
if args.update_method == "nccl2" and trainer_id == 0:
#FIXME(typhoonzero): wait other trainer to start listening
time.sleep(30)
startup_exe = fluid.Executor(place)
......@@ -398,8 +403,8 @@ def main():
# the unique trainer id, starting from 0, needed by trainer
# only
nccl_id_var, num_trainers, trainer_id = (
None, 1, int(os.getenv("PADDLE_TRAINER_ID", "0")))
num_trainers, trainer_id = (
1, int(os.getenv("PADDLE_TRAINER_ID", "0")))
train_prog = fluid.Program()
test_prog = fluid.Program()
......@@ -418,7 +423,7 @@ def main():
"Must configure correct environments to run dist train.")
all_args.extend([train_prog, test_prog, startup_prog])
if os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
all_args.extend([nccl_id_var, num_trainers, trainer_id])
all_args.extend([num_trainers, trainer_id])
train_parallel(*all_args)
elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
# start pserver with Executor
......@@ -431,10 +436,10 @@ def main():
all_args.extend([train_prog, test_prog, startup_prog])
if args.update_method == "nccl2":
nccl_id_var, num_trainers, trainer_id = append_nccl2_prepare(
num_trainers, trainer_id = append_nccl2_prepare(
trainer_id, startup_prog)
all_args.extend([nccl_id_var, num_trainers, trainer_id])
all_args.extend([num_trainers, trainer_id])
train_parallel(*all_args)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册