未验证 提交 8af1c07f 编写于 作者: W wangguanzhong 提交者: GitHub

fix save load in dygraph (#1092)

上级 13759f95
......@@ -41,8 +41,7 @@ def run(FLAGS, cfg):
model = create(cfg.architecture, mode='infer', open_debug=cfg.open_debug)
# Init Model
if os.path.isfile(cfg.weights):
param_state_dict, opti_state_dict = fluid.load_dygraph(cfg.weights)
param_state_dict = fluid.dygraph.load_dygraph(cfg.weights)[0]
model.set_dict(param_state_dict)
# Data Reader
......
......@@ -89,26 +89,13 @@ def run(FLAGS, cfg):
np.random.seed(local_seed)
if FLAGS.enable_ce or cfg.open_debug:
fluid.default_startup_program().random_seed = 1000
fluid.default_main_program().random_seed = 1000
random.seed(0)
np.random.seed(0)
if FLAGS.use_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
parallel_log = "Note: use parallel "
# Model
main_arch = cfg.architecture
model = create(cfg.architecture, mode='train', open_debug=cfg.open_debug)
# Parallel Model
if FLAGS.use_parallel:
#strategy = fluid.dygraph.parallel.prepare_context()
model = fluid.dygraph.parallel.DataParallel(model, strategy)
parallel_log += "with data parallel!"
print(parallel_log)
# Optimizer
lr = create('LearningRate')()
optimizer = create('OptimizerBuilder')(lr, model.parameters())
......@@ -122,11 +109,15 @@ def run(FLAGS, cfg):
FLAGS.ckpt_type,
open_debug=cfg.open_debug)
# Parallel Model
if FLAGS.use_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
model = fluid.dygraph.parallel.DataParallel(model, strategy)
# Data Reader
start_iter = 0
if cfg.use_gpu:
devices_num = fluid.core.get_cuda_device_count(
) if FLAGS.use_parallel else 1
devices_num = fluid.core.get_cuda_device_count()
else:
devices_num = int(os.environ.get('CPU_NUM', 1))
......@@ -169,7 +160,9 @@ def run(FLAGS, cfg):
break
# Save Stage
if iter_id > 0 and iter_id % int(cfg.snapshot_iter) == 0:
if iter_id > 0 and iter_id % int(
cfg.snapshot_iter) == 0 and fluid.dygraph.parallel.Env(
).local_rank == 0:
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_name = str(
iter_id) if iter_id != cfg.max_iters - 1 else "model_final"
......@@ -187,7 +180,6 @@ def main():
check_version()
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if FLAGS.use_parallel else fluid.CUDAPlace(0) \
if cfg.use_gpu else fluid.CPUPlace()
with fluid.dygraph.guard(place):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册