Unverified commit 87e87ae7, authored by chengjuntao, committed by GitHub

modify save and load to 1.7 api for rrpn (#4310)

* modify save and load to 1.7 api

* add func to load parm
Parent 97808bc9
@@ -28,6 +28,19 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+
+def _load_state(path):
+    if os.path.exists(path + '.pdopt'):
+        # XXX another hack to ignore the optimizer state
+        tmp = tempfile.mkdtemp()
+        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
+        shutil.copy(path + '.pdparams', dst + '.pdparams')
+        state = fluid.io.load_program_state(dst)
+        shutil.rmtree(tmp)
+    else:
+        state = fluid.io.load_program_state(path)
+    return state
+
 
 def load_params(exe, prog, path):
     """
     Load model from the given path.
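For context: in the 1.7 API, fluid.io.load_program_state(path) also restores the optimizer state from path + '.pdopt' whenever that file sits next to path + '.pdparams'. The temp-dir copy above is a workaround: only the .pdparams file is staged, so the returned dict holds parameters alone. A minimal sketch, assuming a hypothetical checkpoint prefix './output/rrpn':

    import paddle.fluid as fluid

    # Would load rrpn.pdparams and, if present, rrpn.pdopt from ./output/.
    full_state = fluid.io.load_program_state('./output/rrpn')
    # _load_state('./output/rrpn') instead copies only the .pdparams file
    # into a temp dir first, so optimizer entries never enter the dict.
    params_only = _load_state('./output/rrpn')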
@@ -64,7 +77,7 @@ def save(exe, prog, path):
     if os.path.isdir(path):
         shutil.rmtree(path)
     logger.info('Save model to {}.'.format(path))
-    fluid.io.save_persistables(exe, path, prog)
+    fluid.save(prog, path)
 
 
 def load_and_fusebn(exe, prog, path):
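The on-disk format changes with this one-line swap: fluid.io.save_persistables(exe, path, prog) writes one file per persistable variable under the directory path, while the 1.7-style fluid.save(prog, path) treats path as a file prefix and serializes everything into path + '.pdparams' plus path + '.pdopt'. A hedged sketch with a hypothetical prefix:

    # Old layout: ./output/rrpn/<one file per variable>
    # New layout: ./output/rrpn.pdparams and ./output/rrpn.pdopt
    fluid.save(prog, './output/rrpn')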
@@ -81,15 +94,6 @@ def load_and_fusebn(exe, prog, path):
     if not os.path.exists(path):
         raise ValueError("Model path {} does not exists.".format(path))
 
-    def _if_exist(var):
-        b = os.path.exists(os.path.join(path, var.name))
-        if b:
-            logger.debug('load weight {}'.format(var.name))
-        return b
-
-    all_vars = list(filter(_if_exist, prog.list_vars()))
-
     # Since the program uses affine-channel, there is no running mean and var
     # in the program, here append running mean and var.
     # NOTE, the params of batch norm should be like:
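For orientation, the tail of load_and_fusebn (truncated at the end of this diff, after the eps = 1e-5 line) folds the loaded running statistics into the affine_channel scale and bias. A sketch of that standard fusion; the numpy array names scale_arr, bias_arr, mean_arr, var_arr are illustrative assumptions, not names from this diff:

    import numpy as np

    eps = 1e-5  # matches the eps used later in load_and_fusebn
    bn_std = np.sqrt(var_arr + eps)
    new_scale = scale_arr / bn_std               # folded multiplier
    new_bias = bias_arr - mean_arr * new_scale   # folded offset
    # afterwards: y = new_scale * x + new_bias reproduces BN(x)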
@@ -101,15 +105,25 @@ def load_and_fusebn(exe, prog, path):
     mean_variances = set()
     bn_vars = []
 
-    bn_in_path = True
+    state = None
+    if os.path.exists(path + '.pdparams'):
+        state = _load_state(path)
 
-    inner_prog = fluid.Program()
-    inner_start_prog = fluid.Program()
-    inner_block = inner_prog.global_block()
-    with fluid.program_guard(inner_prog, inner_start_prog):
+    def check_mean_and_bias(prefix):
+        m = prefix + 'mean'
+        v = prefix + 'variance'
+        if state:
+            return v in state and m in state
+        else:
+            return (os.path.exists(os.path.join(path, m)) and
+                    os.path.exists(os.path.join(path, v)))
+
+    has_mean_bias = True
+
+    with fluid.program_guard(prog, fluid.Program()):
         for block in prog.blocks:
             ops = list(block.ops)
-            if not bn_in_path:
+            if not has_mean_bias:
                 break
             for op in ops:
                 if op.type == 'affine_channel':
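check_mean_and_bias handles both storage formats because load_program_state returns a plain Python dict mapping variable names to numpy arrays, so a dict membership test replaces the old per-file os.path.exists check. A small sketch with hypothetical names and paths:

    state = _load_state('./output/rrpn')   # new format: name -> ndarray dict
    has_new = 'bn_conv1_mean' in state     # dict membership test
    has_old = os.path.exists(os.path.join('./old_ckpt_dir', 'bn_conv1_mean'))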
@@ -119,28 +133,22 @@ def load_and_fusebn(exe, prog, path):
                     prefix = scale_name[:-5]
                     mean_name = prefix + 'mean'
                     variance_name = prefix + 'variance'
-
-                    if not os.path.exists(os.path.join(path, mean_name)):
-                        bn_in_path = False
-                        break
-                    if not os.path.exists(os.path.join(path, variance_name)):
-                        bn_in_path = False
+                    if not check_mean_and_bias(prefix):
+                        has_mean_bias = False
                         break
 
                     bias = block.var(bias_name)
 
-                    mean_vb = inner_block.create_var(
+                    mean_vb = block.create_var(
                         name=mean_name,
                         type=bias.type,
                         shape=bias.shape,
-                        dtype=bias.dtype,
-                        persistable=True)
-                    variance_vb = inner_block.create_var(
+                        dtype=bias.dtype)
+                    variance_vb = block.create_var(
                         name=variance_name,
                         type=bias.type,
                         shape=bias.shape,
-                        dtype=bias.dtype,
-                        persistable=True)
+                        dtype=bias.dtype)
 
                     mean_variances.add(mean_vb)
                     mean_variances.add(variance_vb)
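The prefix = scale_name[:-5] slice relies on Paddle's batch-norm parameter naming convention: stripping the trailing 'scale' leaves the common prefix shared by all four parameters. A worked example with a hypothetical layer name:

    scale_name = 'bn_conv1_scale'        # hypothetical BN scale parameter
    prefix = scale_name[:-5]             # 'bn_conv1_'  (drops 'scale')
    mean_name = prefix + 'mean'          # 'bn_conv1_mean'
    variance_name = prefix + 'variance'  # 'bn_conv1_variance'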
@@ -148,21 +156,16 @@ def load_and_fusebn(exe, prog, path):
                     bn_vars.append(
                         [scale_name, bias_name, mean_name, variance_name])
 
-    if not bn_in_path:
-        fluid.io.load_vars(exe, path, prog, vars=all_vars)
+    if state:
+        fluid.io.set_program_state(prog, state)
+    else:
+        load_params(exe, prog, path)
+
+    if not has_mean_bias:
         logger.warning(
             "There is no paramters of batch norm in model {}. "
             "Skip to fuse batch norm. And load paramters done.".format(path))
         return
 
-    # load running mean and running variance on cpu place into global scope.
-    place = fluid.CPUPlace()
-    exe_cpu = fluid.Executor(place)
-    fluid.io.load_vars(exe_cpu, path, vars=[v for v in mean_variances])
-
-    # load params on real place into global scope.
-    fluid.io.load_vars(exe, path, prog, vars=all_vars)
-
     eps = 1e-5
     for names in bn_vars:
         scale_name, bias_name, mean_name, var_name = names
......
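Taken together, the commit replaces the per-variable directory format and the CPU-side load_vars workaround with the single-file 1.7 flow: save once with fluid.save, then load the whole state dict and bind it to the program in one call. A hedged end-to-end sketch; executor, program, and prefix are illustrative:

    import paddle.fluid as fluid

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    prog = fluid.default_main_program()
    exe.run(fluid.default_startup_program())

    fluid.save(prog, './output/rrpn')            # writes .pdparams / .pdopt

    state = fluid.io.load_program_state('./output/rrpn')
    fluid.io.set_program_state(prog, state)      # bind arrays to prog's vars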