未验证 提交 3ab53043 编写于 作者: W wangxinxin08 提交者: GitHub

[Dygraph] modify code to be suited for launch in 2.0rc (#1587)

* modify code to be suited for launch in 2.0rc

* modify code according to review
上级 a375f671
...@@ -52,7 +52,7 @@ TrainReader: ...@@ -52,7 +52,7 @@ TrainReader:
drop_last: true drop_last: true
worker_num: 4 worker_num: 4
bufsize: 4 bufsize: 4
use_process: false #true use_process: true
EvalReader: EvalReader:
......
...@@ -87,14 +87,7 @@ def parse_args(): ...@@ -87,14 +87,7 @@ def parse_args():
return args return args
def run(): def run(FLAGS, cfg):
FLAGS = parse_args()
cfg = load_config(FLAGS.config)
merge_config(FLAGS.opt)
check_config(cfg)
check_gpu(cfg.use_gpu)
check_version()
env = os.environ env = os.environ
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
...@@ -108,6 +101,9 @@ def run(): ...@@ -108,6 +101,9 @@ def run():
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
if dist.ParallelEnv().nranks > 1:
paddle.distributed.init_parallel_env()
# Model # Model
main_arch = cfg.architecture main_arch = cfg.architecture
model = create(cfg.architecture) model = create(cfg.architecture)
...@@ -126,8 +122,7 @@ def run(): ...@@ -126,8 +122,7 @@ def run():
# Parallel Model # Parallel Model
if dist.ParallelEnv().nranks > 1: if dist.ParallelEnv().nranks > 1:
strategy = paddle.distributed.init_parallel_env() model = paddle.DataParallel(model)
model = paddle.DataParallel(model, strategy)
# Data Reader # Data Reader
start_iter = 0 start_iter = 0
...@@ -137,7 +132,9 @@ def run(): ...@@ -137,7 +132,9 @@ def run():
devices_num = int(os.environ.get('CPU_NUM', 1)) devices_num = int(os.environ.get('CPU_NUM', 1))
train_reader = create_reader( train_reader = create_reader(
cfg.TrainReader, (cfg.max_iters - start_iter), cfg, devices_num=1) cfg.TrainReader, (cfg.max_iters - start_iter),
cfg,
devices_num=devices_num)
time_stat = deque(maxlen=cfg.log_iter) time_stat = deque(maxlen=cfg.log_iter)
start_time = time.time() start_time = time.time()
...@@ -193,7 +190,15 @@ def run(): ...@@ -193,7 +190,15 @@ def run():
def main(): def main():
dist.spawn(run) FLAGS = parse_args()
cfg = load_config(FLAGS.config)
merge_config(FLAGS.opt)
check_config(cfg)
check_gpu(cfg.use_gpu)
check_version()
run(FLAGS, cfg)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册