Unverified · Commit 87e87ae7 · authored by chengjuntao, committed by GitHub

modify save and load to 1.7 api for rrpn (#4310)

* modify save and load to 1.7 api

* add func to load params
Parent 97808bc9
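For context, `fluid.io.save_persistables` (the pre-1.7 style) writes one file per persistable variable into a directory, while the 1.7 `fluid.save` / `fluid.io.load_program_state` pair works with a single `path.pdparams` file plus a `path.pdopt` file for optimizer state. A minimal sketch of both styles, assuming a Paddle 1.7 fluid install; the tiny program and checkpoint paths are illustrative, not part of this commit:

    import paddle.fluid as fluid

    # A tiny program so there is something to save (illustrative only).
    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        x = fluid.layers.data(name='x', shape=[4], dtype='float32')
        y = fluid.layers.fc(input=x, size=2)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)

    # Pre-1.7: a directory with one file per persistable variable.
    fluid.io.save_persistables(exe, 'ckpt_dir', main_prog)
    fluid.io.load_persistables(exe, 'ckpt_dir', main_prog)

    # 1.7: a single-file checkpoint 'ckpt.pdparams' (+ 'ckpt.pdopt' when
    # the program holds optimizer state).
    fluid.save(main_prog, 'ckpt')
    state = fluid.io.load_program_state('ckpt')  # dict: name -> ndarray
    fluid.io.set_program_state(main_prog, state)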
@@ -28,6 +28,19 @@ import logging
 logger = logging.getLogger(__name__)


+def _load_state(path):
+    if os.path.exists(path + '.pdopt'):
+        # XXX another hack to ignore the optimizer state
+        tmp = tempfile.mkdtemp()
+        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
+        shutil.copy(path + '.pdparams', dst + '.pdparams')
+        state = fluid.io.load_program_state(dst)
+        shutil.rmtree(tmp)
+    else:
+        state = fluid.io.load_program_state(path)
+    return state
+
+
 def load_params(exe, prog, path):
     """
     Load model from the given path.
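The temporary-directory dance in `_load_state` above exists because `fluid.io.load_program_state(path)` also picks up `path.pdopt` when it is present; copying only the `.pdparams` file next to a fresh prefix hides the optimizer state from the loader, so the returned dict holds parameters alone. A short usage sketch, with an illustrative checkpoint prefix:

    import numpy as np

    # Assuming a checkpoint saved via fluid.save(prog, 'output/rrpn_final'):
    state = _load_state('output/rrpn_final')  # illustrative prefix
    for name, value in state.items():
        # Each entry maps a variable name to its numpy value.
        assert isinstance(value, np.ndarray)
        print(name, value.shape)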
@@ -64,7 +77,7 @@ def save(exe, prog, path):
     if os.path.isdir(path):
         shutil.rmtree(path)
     logger.info('Save model to {}.'.format(path))
-    fluid.io.save_persistables(exe, path, prog)
+    fluid.save(prog, path)


 def load_and_fusebn(exe, prog, path):
@@ -81,15 +94,6 @@ def load_and_fusebn(exe, prog, path):
     if not os.path.exists(path):
         raise ValueError("Model path {} does not exist.".format(path))

-    def _if_exist(var):
-        b = os.path.exists(os.path.join(path, var.name))
-        if b:
-            logger.debug('load weight {}'.format(var.name))
-        return b
-
-    all_vars = list(filter(_if_exist, prog.list_vars()))
-
     # Since the program uses affine-channel, there is no running mean and var
     # in the program, here append running mean and var.
     # NOTE, the params of batch norm should be like:
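The deleted `_if_exist` filter only makes sense for directory checkpoints, where each variable is its own file; a single-file `.pdparams` checkpoint has nothing to `os.path.exists` per variable, which is why the code below keys off the checkpoint layout instead. A small sketch mirroring that dispatch (the helper name is hypothetical):

    import os

    def checkpoint_layout(path):
        # Hypothetical helper mirroring the patch's dispatch logic.
        if os.path.exists(path + '.pdparams'):
            return 'single-file (fluid.save, 1.7)'
        if os.path.isdir(path):
            return 'per-variable directory (fluid.io.save_persistables)'
        raise ValueError('no checkpoint found at {}'.format(path))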
@@ -101,15 +105,25 @@ def load_and_fusebn(exe, prog, path):
     mean_variances = set()
     bn_vars = []
-    bn_in_path = True
+    state = None
+    if os.path.exists(path + '.pdparams'):
+        state = _load_state(path)

-    inner_prog = fluid.Program()
-    inner_start_prog = fluid.Program()
-    inner_block = inner_prog.global_block()
-    with fluid.program_guard(inner_prog, inner_start_prog):
+    def check_mean_and_bias(prefix):
+        m = prefix + 'mean'
+        v = prefix + 'variance'
+        if state:
+            return v in state and m in state
+        else:
+            return (os.path.exists(os.path.join(path, m)) and
+                    os.path.exists(os.path.join(path, v)))
+
+    has_mean_bias = True
+
+    with fluid.program_guard(prog, fluid.Program()):
         for block in prog.blocks:
             ops = list(block.ops)
-            if not bn_in_path:
+            if not has_mean_bias:
                 break
             for op in ops:
                 if op.type == 'affine_channel':
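For each `affine_channel` op, the loop derives a batch-norm prefix from the scale input's name and then asks `check_mean_and_bias` whether the matching running statistics were saved. A tiny sketch of the naming convention, using a hypothetical variable name:

    # 'scale' is 5 characters, so scale_name[:-5] keeps the stem.
    scale_name = 'bn2a_branch1_scale'  # hypothetical
    prefix = scale_name[:-5]
    assert prefix + 'mean' == 'bn2a_branch1_mean'
    assert prefix + 'variance' == 'bn2a_branch1_variance'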
@@ -119,28 +133,22 @@ def load_and_fusebn(exe, prog, path):
                     prefix = scale_name[:-5]
                     mean_name = prefix + 'mean'
                     variance_name = prefix + 'variance'
-                    if not os.path.exists(os.path.join(path, mean_name)):
-                        bn_in_path = False
-                        break
-                    if not os.path.exists(os.path.join(path, variance_name)):
-                        bn_in_path = False
+                    if not check_mean_and_bias(prefix):
+                        has_mean_bias = False
                         break

                     bias = block.var(bias_name)

-                    mean_vb = inner_block.create_var(
+                    mean_vb = block.create_var(
                         name=mean_name,
                         type=bias.type,
                         shape=bias.shape,
-                        dtype=bias.dtype,
-                        persistable=True)
-                    variance_vb = inner_block.create_var(
+                        dtype=bias.dtype)
+                    variance_vb = block.create_var(
                         name=variance_name,
                         type=bias.type,
                         shape=bias.shape,
-                        dtype=bias.dtype,
-                        persistable=True)
+                        dtype=bias.dtype)

                     mean_variances.add(mean_vb)
                     mean_variances.add(variance_vb)
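Creating the mean/variance placeholders directly in `prog`'s blocks, rather than in a throwaway `inner_prog`, only records graph metadata; no tensor is allocated until something writes the variable at run time, so this is safe to do on the program being restored. A standalone sketch, with illustrative names and shapes:

    import paddle.fluid as fluid

    prog = fluid.Program()
    block = prog.global_block()
    # Declaring a variable adds metadata only; its value arrives later,
    # e.g. when a state dict or a loader writes the scope.
    mean_vb = block.create_var(name='bn1_mean', shape=[64], dtype='float32')
    print(mean_vb.name, mean_vb.shape)  # bn1_mean (64,)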
@@ -148,21 +156,16 @@ def load_and_fusebn(exe, prog, path):
                     bn_vars.append(
                         [scale_name, bias_name, mean_name, variance_name])

-    if not bn_in_path:
-        fluid.io.load_vars(exe, path, prog, vars=all_vars)
+    if state:
+        fluid.io.set_program_state(prog, state)
+    else:
+        load_params(exe, prog, path)
+
+    if not has_mean_bias:
         logger.warning(
             "There are no parameters of batch norm in model {}. "
             "Skip to fuse batch norm. And load parameters done.".format(path))
         return

-    # load running mean and running variance on cpu place into global scope.
-    place = fluid.CPUPlace()
-    exe_cpu = fluid.Executor(place)
-    fluid.io.load_vars(exe_cpu, path, vars=[v for v in mean_variances])
-
-    # load params on real place into global scope.
-    fluid.io.load_vars(exe, path, prog, vars=all_vars)
-
     eps = 1e-5
     for names in bn_vars:
         scale_name, bias_name, mean_name, var_name = names
......
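The remainder of this hunk is collapsed above; in the rrpn sources, the loop that follows `eps = 1e-5` folds each pair of loaded statistics into the corresponding `affine_channel` scale and bias. A hedged NumPy sketch of that standard batch-norm fusion (the exact loop body is not shown in this diff):

    import numpy as np

    def fuse_bn(scale, bias, mean, variance, eps=1e-5):
        # affine_channel computes y = scale * x + bias; batch norm with
        # frozen statistics is y = scale * (x - mean) / sqrt(var + eps) + bias.
        # Folding the statistics in gives new per-channel scale and bias.
        std = np.sqrt(variance + eps)
        fused_scale = scale / std
        fused_bias = bias - mean * fused_scale
        return fused_scale, fused_bias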