Unverified · Commit 8af1c07f authored by wangguanzhong, committed by GitHub

fix save load in dygraph (#1092)

Parent 13759f95
@@ -41,8 +41,7 @@ def run(FLAGS, cfg):
     model = create(cfg.architecture, mode='infer', open_debug=cfg.open_debug)
 
     # Init Model
-    if os.path.isfile(cfg.weights):
-        param_state_dict, opti_state_dict = fluid.load_dygraph(cfg.weights)
+    param_state_dict = fluid.dygraph.load_dygraph(cfg.weights)[0]
     model.set_dict(param_state_dict)
 
     # Data Reader
...
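The new call site relies on fluid.dygraph.load_dygraph returning a (param_state_dict, optimizer_state_dict) tuple, so indexing with [0] keeps only the parameters for inference. A minimal sketch of that load path, assuming the Paddle 1.x dygraph API and a hypothetical checkpoint prefix output/model_final (not from the diff):

    # Sketch, not the repo's exact code: load_dygraph returns a
    # (param_state_dict, optimizer_state_dict) tuple; [0] keeps the parameters.
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        param_state_dict, opti_state_dict = fluid.dygraph.load_dygraph(
            'output/model_final')  # hypothetical prefix; no file suffix needed
        # model.set_dict(param_state_dict) would restore the weights;
        # opti_state_dict can feed optimizer.set_dict() when resuming training.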
@@ -89,26 +89,13 @@ def run(FLAGS, cfg):
         np.random.seed(local_seed)
     if FLAGS.enable_ce or cfg.open_debug:
-        fluid.default_startup_program().random_seed = 1000
-        fluid.default_main_program().random_seed = 1000
         random.seed(0)
         np.random.seed(0)
 
-    if FLAGS.use_parallel:
-        strategy = fluid.dygraph.parallel.prepare_context()
-        parallel_log = "Note: use parallel "
-
     # Model
     main_arch = cfg.architecture
     model = create(cfg.architecture, mode='train', open_debug=cfg.open_debug)
 
-    # Parallel Model
-    if FLAGS.use_parallel:
-        #strategy = fluid.dygraph.parallel.prepare_context()
-        model = fluid.dygraph.parallel.DataParallel(model, strategy)
-        parallel_log += "with data parallel!"
-    print(parallel_log)
-
     # Optimizer
     lr = create('LearningRate')()
     optimizer = create('OptimizerBuilder')(lr, model.parameters())
 
@@ -122,11 +109,15 @@ def run(FLAGS, cfg):
                               FLAGS.ckpt_type,
                               open_debug=cfg.open_debug)
 
+    # Parallel Model
+    if FLAGS.use_parallel:
+        strategy = fluid.dygraph.parallel.prepare_context()
+        model = fluid.dygraph.parallel.DataParallel(model, strategy)
+
     # Data Reader
     start_iter = 0
     if cfg.use_gpu:
-        devices_num = fluid.core.get_cuda_device_count(
-        ) if FLAGS.use_parallel else 1
+        devices_num = fluid.core.get_cuda_device_count()
     else:
         devices_num = int(os.environ.get('CPU_NUM', 1))
 
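The net effect of the two hunks above is an ordering change: the model is now wrapped in DataParallel only after the optimizer has been built from model.parameters() and the checkpoint has been loaded into the unwrapped model. A condensed sketch of the resulting flow (create, FLAGS, and cfg follow the diff; the checkpoint-loading body is elided, as in the hunk):

    model = create(cfg.architecture, mode='train', open_debug=cfg.open_debug)
    lr = create('LearningRate')()
    optimizer = create('OptimizerBuilder')(lr, model.parameters())
    # ... checkpoint loading happens here, on the plain (unwrapped) model ...
    if FLAGS.use_parallel:
        # wrap last, so parameter names and optimizer state match the
        # checkpoint that was just restored
        strategy = fluid.dygraph.parallel.prepare_context()
        model = fluid.dygraph.parallel.DataParallel(model, strategy)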
@@ -169,7 +160,9 @@ def run(FLAGS, cfg):
                 break
 
             # Save Stage
-            if iter_id > 0 and iter_id % int(cfg.snapshot_iter) == 0:
+            if iter_id > 0 and iter_id % int(
+                    cfg.snapshot_iter) == 0 and fluid.dygraph.parallel.Env(
+                    ).local_rank == 0:
                 cfg_name = os.path.basename(FLAGS.config).split('.')[0]
                 save_name = str(
                     iter_id) if iter_id != cfg.max_iters - 1 else "model_final"
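fluid.dygraph.parallel.Env().local_rank identifies the trainer process under multi-process launch; gating the save stage on rank 0 keeps the workers from racing to write the same checkpoint files. The save call itself is elided from the hunk, so the following is only a sketch of the pattern using the standard Paddle 1.x save_dygraph API (save_path is a hypothetical prefix):

    if fluid.dygraph.parallel.Env().local_rank == 0:
        # save_dygraph writes <prefix>.pdparams for a parameter state dict
        # and <prefix>.pdopt for an optimizer state dict
        fluid.dygraph.save_dygraph(model.state_dict(), save_path)
        fluid.dygraph.save_dygraph(optimizer.state_dict(), save_path)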
@@ -187,7 +180,6 @@ def main():
     check_version()
 
     place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
-        if FLAGS.use_parallel else fluid.CUDAPlace(0) \
         if cfg.use_gpu else fluid.CPUPlace()
 
     with fluid.dygraph.guard(place):
...
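With the use_parallel branch gone, one expression covers both single- and multi-card runs: Env().dev_id is read from the environment that the parallel launcher sets and falls back to device 0 in a plain single-process run. A small sketch of the selection logic as the diff leaves it (cfg follows the diff):

    # dev_id defaults to 0 when no launcher environment is present, so this
    # single expression replaces the old three-way conditional
    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
        if cfg.use_gpu else fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        pass  # build and run the model under this place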