Unverified commit c59bdcd6 authored by Yang Zhang, committed by GitHub

Update checkpoint to use paddle 1.7 API (#229)

* Update checkpoint to use paddle 1.7 API

* Add new checkpoint support to `load_and_fusebn`

* Make yapf happy
Parent f92927ef
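For context on the commit message above: the diff below replaces the pre-1.7 per-variable checkpoint directories (`fluid.io.save_persistables` / `fluid.io.load_persistables`) with the Paddle 1.7 single-prefix format. Here is a minimal sketch of the 1.7 calls this diff adopts; the toy network and the `ckpt` prefix are illustrative, not from the repo:

```python
import paddle.fluid as fluid

# Toy program so the sketch is self-contained.
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.fc(input=x, size=2)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)

# Old style (pre-1.7): a directory holding one file per variable.
# fluid.io.save_persistables(exe, 'ckpt_dir', main_prog)
# fluid.io.load_persistables(exe, 'ckpt_dir', main_prog)

# New style (1.7): a single prefix, producing `ckpt.pdparams`
# (plus `ckpt.pdopt` when optimizer state is present).
fluid.save(main_prog, 'ckpt')
fluid.load(main_prog, 'ckpt', executor=exe)

# Or load into a plain dict, prune entries, then assign -- the pattern
# that `_load_state` and `set_program_state` implement in this diff.
state = fluid.io.load_program_state('ckpt')
fluid.io.set_program_state(main_prog, state)
```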
@@ -20,6 +20,7 @@ from __future__ import unicode_literals
 import errno
 import os
 import shutil
+import tempfile
 import time
 import numpy as np
 import re
@@ -78,6 +79,19 @@ def _get_weight_path(path):
     return path


+def _load_state(path):
+    if os.path.exists(path + '.pdopt'):
+        # XXX another hack to ignore the optimizer state
+        tmp = tempfile.mkdtemp()
+        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
+        shutil.copy(path + '.pdparams', dst + '.pdparams')
+        state = fluid.io.load_program_state(dst)
+        shutil.rmtree(tmp)
+    else:
+        state = fluid.io.load_program_state(path)
+    return state
+
+
 def load_params(exe, prog, path, ignore_params=[]):
     """
     Load model from the given path.
@@ -92,30 +106,40 @@ def load_params(exe, prog, path, ignore_params=[]):
     if is_url(path):
         path = _get_weight_path(path)

-    if not os.path.exists(path):
+    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))

     logger.info('Loading parameters from {}...'.format(path))

-    def _if_exist(var):
-        do_ignore = False
-        param_exist = os.path.exists(os.path.join(path, var.name))
-        if len(ignore_params) > 0:
-            # Parameter related to num_classes will be ignored in finetuning
-            do_ignore_list = [
-                bool(re.match(name, var.name)) for name in ignore_params
-            ]
-            do_ignore = any(do_ignore_list)
-        if do_ignore and param_exist:
-            logger.info('In load_params, ignore {}'.format(var.name))
-        do_load = param_exist and not do_ignore
-        if do_load:
-            logger.debug('load weight {}'.format(var.name))
-        return do_load
-
-    fluid.io.load_vars(exe, path, prog, predicate=_if_exist)
+    ignore_list = None
+    if ignore_params:
+        all_var_names = [var.name for var in prog.list_vars()]
+        ignore_list = filter(
+            lambda var: any([re.match(name, var) for name in ignore_params]),
+            all_var_names)
+        ignore_list = list(ignore_list)
+
+    if os.path.isdir(path):
+        if not ignore_list:
+            fluid.load(prog, path, executor=exe)
+            return
+
+        # XXX this is hackish, but seems to be the least contrived way...
+        tmp = tempfile.mkdtemp()
+        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
+        shutil.copytree(path, dst, ignore=shutil.ignore_patterns(*ignore_list))
+        fluid.load(prog, dst, executor=exe)
+        shutil.rmtree(tmp)
+        return
+
+    state = _load_state(path)
+
+    if ignore_list:
+        for k in ignore_list:
+            if k in state:
+                del state[k]
+
+    fluid.io.set_program_state(prog, state)


 def load_checkpoint(exe, prog, path):
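Usage note on the rewritten `load_params` above: `ignore_params` entries are regexes matched via `re.match` against every variable name in the program, so finetuning with a different number of classes can skip class-dependent weights. A hypothetical call follows; the path and name patterns are illustrative, not from the repo:

```python
# Hypothetical finetuning call: any program variable whose name matches
# one of the patterns is dropped before the remaining weights are set.
load_params(exe, train_prog, 'pretrained/faster_rcnn',
            ignore_params=['cls_score_w', 'cls_score_b'])
```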
@@ -128,13 +152,10 @@ def load_checkpoint(exe, prog, path):
     """
     if is_url(path):
         path = _get_weight_path(path)

-    if not os.path.exists(path):
-        raise ValueError("Model checkpoint path {} does not "
+    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
+        raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))
-
-    logger.info('Loading checkpoint from {}...'.format(path))
-    fluid.io.load_persistables(exe, path, prog)
+    fluid.load(prog, path, executor=exe)


 def global_step(scope=None):
@@ -165,7 +186,7 @@ def save(exe, prog, path):
     if os.path.isdir(path):
         shutil.rmtree(path)
     logger.info('Save model to {}.'.format(path))
-    fluid.io.save_persistables(exe, path, prog)
+    fluid.save(prog, path)


 def load_and_fusebn(exe, prog, path):
@@ -186,15 +207,6 @@ def load_and_fusebn(exe, prog, path):
     if not os.path.exists(path):
         raise ValueError("Model path {} does not exists.".format(path))

-    def _if_exist(var):
-        b = os.path.exists(os.path.join(path, var.name))
-        if b:
-            logger.debug('load weight {}'.format(var.name))
-        return b
-
-    all_vars = list(filter(_if_exist, prog.list_vars()))
-
     # Since the program uses affine-channel, there is no running mean and var
     # in the program, here append running mean and var.
     # NOTE, the params of batch norm should be like:
@@ -206,15 +218,25 @@ def load_and_fusebn(exe, prog, path):
     mean_variances = set()
     bn_vars = []

-    bn_in_path = True
+    state = None
+    if os.path.exists(path + '.pdparams'):
+        state = _load_state(path)
+
+    def check_mean_and_bias(prefix):
+        m = prefix + 'mean'
+        v = prefix + 'variance'
+        if state:
+            return v in state and m in state
+        else:
+            return (os.path.exists(os.path.join(path, m)) and
+                    os.path.exists(os.path.join(path, v)))
+
+    has_mean_bias = True

-    inner_prog = fluid.Program()
-    inner_start_prog = fluid.Program()
-    inner_block = inner_prog.global_block()
-    with fluid.program_guard(inner_prog, inner_start_prog):
+    with fluid.program_guard(prog, fluid.Program()):
         for block in prog.blocks:
             ops = list(block.ops)
-            if not bn_in_path:
+            if not has_mean_bias:
                 break
             for op in ops:
                 if op.type == 'affine_channel':
@@ -224,28 +246,22 @@ def load_and_fusebn(exe, prog, path):
                     prefix = scale_name[:-5]
                     mean_name = prefix + 'mean'
                     variance_name = prefix + 'variance'
-
-                    if not os.path.exists(os.path.join(path, mean_name)):
-                        bn_in_path = False
-                        break
-                    if not os.path.exists(os.path.join(path, variance_name)):
-                        bn_in_path = False
+                    if not check_mean_and_bias(prefix):
+                        has_mean_bias = False
                         break

                     bias = block.var(bias_name)
-                    mean_vb = inner_block.create_var(
+                    mean_vb = block.create_var(
                         name=mean_name,
                         type=bias.type,
                         shape=bias.shape,
-                        dtype=bias.dtype,
-                        persistable=True)
-                    variance_vb = inner_block.create_var(
+                        dtype=bias.dtype)
+                    variance_vb = block.create_var(
                         name=variance_name,
                         type=bias.type,
                         shape=bias.shape,
-                        dtype=bias.dtype,
-                        persistable=True)
+                        dtype=bias.dtype)

                     mean_variances.add(mean_vb)
                     mean_variances.add(variance_vb)
@@ -253,21 +269,17 @@ def load_and_fusebn(exe, prog, path):
                     bn_vars.append(
                         [scale_name, bias_name, mean_name, variance_name])

-    if not bn_in_path:
-        fluid.io.load_vars(exe, path, prog, vars=all_vars)
+    if state:
+        fluid.io.set_program_state(prog, state)
+    else:
+        load_params(exe, prog, path)
+
+    if not has_mean_bias:
         logger.warning(
             "There is no paramters of batch norm in model {}. "
             "Skip to fuse batch norm. And load paramters done.".format(path))
         return

-    # load running mean and running variance on cpu place into global scope.
-    place = fluid.CPUPlace()
-    exe_cpu = fluid.Executor(place)
-    fluid.io.load_vars(exe_cpu, path, vars=[v for v in mean_variances])
-
-    # load params on real place into global scope.
-    fluid.io.load_vars(exe, path, prog, vars=all_vars)
-
     eps = 1e-5
     for names in bn_vars:
         scale_name, bias_name, mean_name, var_name = names
...
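For context on the `eps` loop that the last hunk leads into (the rest of the function is truncated above): fusing batch norm into an affine channel folds the running statistics into the scale and bias. A minimal NumPy sketch of the standard folding formula; the helper name is illustrative, and the actual file applies this to Paddle tensors:

```python
import numpy as np

def fold_bn(scale, bias, mean, variance, eps=1e-5):
    # Batch norm computes y = scale * (x - mean) / sqrt(variance + eps) + bias.
    # Rearranging gives y = new_scale * x + new_bias, so the running mean and
    # variance can be absorbed into the affine-channel parameters.
    std = np.sqrt(variance + eps)
    new_scale = scale / std
    new_bias = bias - mean * new_scale
    return new_scale, new_bias
```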