diff --git a/tools/eval.py b/tools/eval.py index bec9b8fc0fc2043689e05886ca940cf33ad3f11f..6300710f4a087fb06e29f2d41ce1754c9d3a3548 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -41,9 +41,8 @@ def run(FLAGS, cfg): model = create(cfg.architecture, mode='infer', open_debug=cfg.open_debug) # Init Model - if os.path.isfile(cfg.weights): - param_state_dict, opti_state_dict = fluid.load_dygraph(cfg.weights) - model.set_dict(param_state_dict) + param_state_dict = fluid.dygraph.load_dygraph(cfg.weights)[0] + model.set_dict(param_state_dict) # Data Reader if FLAGS.use_gpu: diff --git a/tools/train.py b/tools/train.py index 3c1865ede657b92a19ffa6968c2ce6dbfb5e9b98..5c4c87f4fd9bdd9350152ad78c2bf26684c86f4f 100755 --- a/tools/train.py +++ b/tools/train.py @@ -89,26 +89,13 @@ def run(FLAGS, cfg): np.random.seed(local_seed) if FLAGS.enable_ce or cfg.open_debug: - fluid.default_startup_program().random_seed = 1000 - fluid.default_main_program().random_seed = 1000 random.seed(0) np.random.seed(0) - if FLAGS.use_parallel: - strategy = fluid.dygraph.parallel.prepare_context() - parallel_log = "Note: use parallel " - # Model main_arch = cfg.architecture model = create(cfg.architecture, mode='train', open_debug=cfg.open_debug) - # Parallel Model - if FLAGS.use_parallel: - #strategy = fluid.dygraph.parallel.prepare_context() - model = fluid.dygraph.parallel.DataParallel(model, strategy) - parallel_log += "with data parallel!" - print(parallel_log) - # Optimizer lr = create('LearningRate')() optimizer = create('OptimizerBuilder')(lr, model.parameters()) @@ -122,11 +109,15 @@ def run(FLAGS, cfg): FLAGS.ckpt_type, open_debug=cfg.open_debug) + # Parallel Model + if FLAGS.use_parallel: + strategy = fluid.dygraph.parallel.prepare_context() + model = fluid.dygraph.parallel.DataParallel(model, strategy) + # Data Reader start_iter = 0 if cfg.use_gpu: - devices_num = fluid.core.get_cuda_device_count( - ) if FLAGS.use_parallel else 1 + devices_num = fluid.core.get_cuda_device_count() else: devices_num = int(os.environ.get('CPU_NUM', 1)) @@ -169,7 +160,9 @@ def run(FLAGS, cfg): break # Save Stage - if iter_id > 0 and iter_id % int(cfg.snapshot_iter) == 0: + if iter_id > 0 and iter_id % int( + cfg.snapshot_iter) == 0 and fluid.dygraph.parallel.Env( + ).local_rank == 0: cfg_name = os.path.basename(FLAGS.config).split('.')[0] save_name = str( iter_id) if iter_id != cfg.max_iters - 1 else "model_final" @@ -187,7 +180,6 @@ def main(): check_version() place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ - if FLAGS.use_parallel else fluid.CUDAPlace(0) \ if cfg.use_gpu else fluid.CPUPlace() with fluid.dygraph.guard(place):