未验证 提交 5c9deb5e 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #1505 from typhoonzero/dist_train_fixes

refine dist train
...@@ -26,6 +26,7 @@ import six ...@@ -26,6 +26,7 @@ import six
import sys import sys
sys.path.append("..") sys.path.append("..")
import models import models
import utils
from reader import train, val from reader import train, val
def parse_args(): def parse_args():
...@@ -149,13 +150,15 @@ def get_model(args, is_train, main_prog, startup_prog): ...@@ -149,13 +150,15 @@ def get_model(args, is_train, main_prog, startup_prog):
lr = [] lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
# NOTE: weight decay is set in the layers config, and is removed
# for bn (batch norm) layers, so do not add weight decay again in
# the optimizer config.
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
learning_rate=models.learning_rate.lr_warmup( learning_rate=utils.learning_rate.lr_warmup(
fluid.layers.piecewise_decay( fluid.layers.piecewise_decay(
boundaries=bd, values=lr), boundaries=bd, values=lr),
warmup_steps, start_lr, end_lr), warmup_steps, start_lr, end_lr),
momentum=0.9, momentum=0.9)
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
batched_reader = None batched_reader = None
...@@ -175,6 +178,7 @@ def append_nccl2_prepare(trainer_id, startup_prog): ...@@ -175,6 +178,7 @@ def append_nccl2_prepare(trainer_id, startup_prog):
for ip in worker_ips.split(","): for ip in worker_ips.split(","):
worker_endpoints.append(':'.join([ip, port])) worker_endpoints.append(':'.join([ip, port]))
current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port
num_trainers = len(worker_endpoints)
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2" config.mode = "nccl2"
...@@ -182,6 +186,7 @@ def append_nccl2_prepare(trainer_id, startup_prog): ...@@ -182,6 +186,7 @@ def append_nccl2_prepare(trainer_id, startup_prog):
t.transpile(trainer_id, trainers=','.join(worker_endpoints), t.transpile(trainer_id, trainers=','.join(worker_endpoints),
current_endpoint=current_endpoint, current_endpoint=current_endpoint,
startup_program=startup_prog) startup_program=startup_prog)
return num_trainers, trainer_id
def dist_transpile(trainer_id, args, train_prog, startup_prog): def dist_transpile(trainer_id, args, train_prog, startup_prog):
...@@ -281,12 +286,12 @@ def test_single(exe, test_args, args, test_prog): ...@@ -281,12 +286,12 @@ def test_single(exe, test_args, args, test_prog):
def train_parallel(train_args, test_args, args, train_prog, test_prog, def train_parallel(train_args, test_args, args, train_prog, test_prog,
startup_prog, nccl_id_var, num_trainers, trainer_id): startup_prog, num_trainers, trainer_id):
over_all_start = time.time() over_all_start = time.time()
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0) place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
if nccl_id_var and trainer_id == 0: if args.update_method == "nccl2" and trainer_id == 0:
#FIXME(wuyi): wait for other trainers to start listening #FIXME(typhoonzero): wait for other trainers to start listening
time.sleep(30) time.sleep(30)
startup_exe = fluid.Executor(place) startup_exe = fluid.Executor(place)
...@@ -398,8 +403,8 @@ def main(): ...@@ -398,8 +403,8 @@ def main():
# the unique trainer id, starting from 0, needed by trainer # the unique trainer id, starting from 0, needed by trainer
# only # only
nccl_id_var, num_trainers, trainer_id = ( num_trainers, trainer_id = (
None, 1, int(os.getenv("PADDLE_TRAINER_ID", "0"))) 1, int(os.getenv("PADDLE_TRAINER_ID", "0")))
train_prog = fluid.Program() train_prog = fluid.Program()
test_prog = fluid.Program() test_prog = fluid.Program()
...@@ -418,7 +423,7 @@ def main(): ...@@ -418,7 +423,7 @@ def main():
"Must configure correct environments to run dist train.") "Must configure correct environments to run dist train.")
all_args.extend([train_prog, test_prog, startup_prog]) all_args.extend([train_prog, test_prog, startup_prog])
if os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER": if os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
all_args.extend([nccl_id_var, num_trainers, trainer_id]) all_args.extend([num_trainers, trainer_id])
train_parallel(*all_args) train_parallel(*all_args)
elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER": elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
# start pserver with Executor # start pserver with Executor
...@@ -431,10 +436,10 @@ def main(): ...@@ -431,10 +436,10 @@ def main():
all_args.extend([train_prog, test_prog, startup_prog]) all_args.extend([train_prog, test_prog, startup_prog])
if args.update_method == "nccl2": if args.update_method == "nccl2":
nccl_id_var, num_trainers, trainer_id = append_nccl2_prepare( num_trainers, trainer_id = append_nccl2_prepare(
trainer_id, startup_prog) trainer_id, startup_prog)
all_args.extend([nccl_id_var, num_trainers, trainer_id]) all_args.extend([num_trainers, trainer_id])
train_parallel(*all_args) train_parallel(*all_args)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册