diff --git a/configs/yolov3_reader.yml b/configs/yolov3_reader.yml index d167c70bc1f97db6dc997c9d535aeca92d4e3473..5e11364b0ad831e68513f60ea68ca511e1018a8b 100644 --- a/configs/yolov3_reader.yml +++ b/configs/yolov3_reader.yml @@ -52,7 +52,7 @@ TrainReader: drop_last: true worker_num: 4 bufsize: 4 - use_process: false #true + use_process: true EvalReader: diff --git a/tools/train.py b/tools/train.py index 5eba5f30c58d9027941f450ce83e1c3507571c62..c37d45f20ecc037ac4bf71200a361041799b1894 100755 --- a/tools/train.py +++ b/tools/train.py @@ -87,14 +87,7 @@ def parse_args(): return args -def run(): - FLAGS = parse_args() - - cfg = load_config(FLAGS.config) - merge_config(FLAGS.opt) - check_config(cfg) - check_gpu(cfg.use_gpu) - check_version() +def run(FLAGS, cfg): env = os.environ FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env @@ -108,6 +101,9 @@ def run(): random.seed(0) np.random.seed(0) + if dist.ParallelEnv().nranks > 1: + paddle.distributed.init_parallel_env() + # Model main_arch = cfg.architecture model = create(cfg.architecture) @@ -126,8 +122,7 @@ def run(): # Parallel Model if dist.ParallelEnv().nranks > 1: - strategy = paddle.distributed.init_parallel_env() - model = paddle.DataParallel(model, strategy) + model = paddle.DataParallel(model) # Data Reader start_iter = 0 @@ -137,7 +132,9 @@ def run(): devices_num = int(os.environ.get('CPU_NUM', 1)) train_reader = create_reader( - cfg.TrainReader, (cfg.max_iters - start_iter), cfg, devices_num=1) + cfg.TrainReader, (cfg.max_iters - start_iter), + cfg, + devices_num=devices_num) time_stat = deque(maxlen=cfg.log_iter) start_time = time.time() @@ -193,7 +190,15 @@ def run(): def main(): - dist.spawn(run) + FLAGS = parse_args() + + cfg = load_config(FLAGS.config) + merge_config(FLAGS.opt) + check_config(cfg) + check_gpu(cfg.use_gpu) + check_version() + + run(FLAGS, cfg) if __name__ == "__main__":