Unverified Commit c59bdcd6 authored by Yang Zhang, committed by GitHub

Update checkpoint to use paddle 1.7 API (#229)

* Update checkpoint to use paddle 1.7 API

* Add new checkpoint support to `load_and_fusebn`

* Make yapf happy
Parent f92927ef
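For context: the pre-1.7 checkpointing calls removed in this diff (`fluid.io.save_persistables` / `fluid.io.load_persistables` / `fluid.io.load_vars`) wrote one file per persistable variable into a directory, while the 1.7 API adopted below packs everything into `path + '.pdparams'` (parameters) and `path + '.pdopt'` (optimizer state). A minimal sketch of the two styles; the helper names are illustrative, not part of the patch:

import paddle.fluid as fluid

def save_pre_17(exe, prog, path):
    # pre-1.7 style: one file per persistable variable inside a directory
    fluid.io.save_persistables(exe, path, prog)

def save_17(prog, path):
    # 1.7 style: writes path + '.pdparams' and path + '.pdopt'
    fluid.save(prog, path)

def load_17(exe, prog, path):
    # 1.7 counterpart of load_persistables
    fluid.load(prog, path, executor=exe)

def load_17_as_dict(prog, path, skip=()):
    # the 1.7 API can also expose a checkpoint as a plain dict of
    # numpy arrays, which is how this patch drops unwanted entries
    state = fluid.io.load_program_state(path)  # {var_name: ndarray}
    for name in skip:
        state.pop(name, None)
    fluid.io.set_program_state(prog, state)

The dict-based path is what makes the `ignore_params` filtering and the optimizer-state workaround in the diff below possible.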
@@ -20,6 +20,7 @@ from __future__ import unicode_literals
import errno
import os
import shutil
import tempfile
import time
import numpy as np
import re
@@ -78,6 +79,19 @@ def _get_weight_path(path):
return path
def _load_state(path):
if os.path.exists(path + '.pdopt'):
# XXX another hack to ignore the optimizer state
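        # (`fluid.io.load_program_state` would otherwise also pick up the
        # `.pdopt` file; copying just the `.pdparams` file into an empty
        # temp dir hides the optimizer state from it)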
tmp = tempfile.mkdtemp()
dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
shutil.copy(path + '.pdparams', dst + '.pdparams')
state = fluid.io.load_program_state(dst)
shutil.rmtree(tmp)
else:
state = fluid.io.load_program_state(path)
return state
def load_params(exe, prog, path, ignore_params=[]):
"""
Load model from the given path.
@@ -92,30 +106,40 @@ def load_params(exe, prog, path, ignore_params=[]):
if is_url(path):
path = _get_weight_path(path)
if not os.path.exists(path):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
logger.info('Loading parameters from {}...'.format(path))
def _if_exist(var):
do_ignore = False
param_exist = os.path.exists(os.path.join(path, var.name))
if len(ignore_params) > 0:
# Parameter related to num_classes will be ignored in finetuning
do_ignore_list = [
bool(re.match(name, var.name)) for name in ignore_params
]
do_ignore = any(do_ignore_list)
if do_ignore and param_exist:
logger.info('In load_params, ignore {}'.format(var.name))
do_load = param_exist and not do_ignore
if do_load:
logger.debug('load weight {}'.format(var.name))
return do_load
fluid.io.load_vars(exe, path, prog, predicate=_if_exist)
ignore_list = None
if ignore_params:
all_var_names = [var.name for var in prog.list_vars()]
ignore_list = filter(
lambda var: any([re.match(name, var) for name in ignore_params]),
all_var_names)
ignore_list = list(ignore_list)
if os.path.isdir(path):
if not ignore_list:
fluid.load(prog, path, executor=exe)
return
# XXX this is hackish, but seems to be the least contrived way...
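        # (copying the weight dir without the ignored files keeps
        # `fluid.load` from touching them)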
tmp = tempfile.mkdtemp()
dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
shutil.copytree(path, dst, ignore=shutil.ignore_patterns(*ignore_list))
fluid.load(prog, dst, executor=exe)
shutil.rmtree(tmp)
return
state = _load_state(path)
if ignore_list:
for k in ignore_list:
if k in state:
del state[k]
fluid.io.set_program_state(prog, state)
def load_checkpoint(exe, prog, path):
@@ -128,13 +152,10 @@ def load_checkpoint(exe, prog, path):
"""
if is_url(path):
path = _get_weight_path(path)
if not os.path.exists(path):
raise ValueError("Model checkpoint path {} does not "
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
logger.info('Loading checkpoint from {}...'.format(path))
fluid.io.load_persistables(exe, path, prog)
fluid.load(prog, path, executor=exe)
def global_step(scope=None):
@@ -165,7 +186,7 @@ def save(exe, prog, path):
if os.path.isdir(path):
shutil.rmtree(path)
logger.info('Save model to {}.'.format(path))
fluid.io.save_persistables(exe, path, prog)
fluid.save(prog, path)
def load_and_fusebn(exe, prog, path):
@@ -186,15 +207,6 @@ def load_and_fusebn(exe, prog, path):
if not os.path.exists(path):
raise ValueError("Model path {} does not exists.".format(path))
def _if_exist(var):
b = os.path.exists(os.path.join(path, var.name))
if b:
logger.debug('load weight {}'.format(var.name))
return b
all_vars = list(filter(_if_exist, prog.list_vars()))
# Since the program uses affine-channel, there is no running mean and var
# in the program, here append running mean and var.
# NOTE, the params of batch norm should be like:
@@ -206,15 +218,25 @@ def load_and_fusebn(exe, prog, path):
mean_variances = set()
bn_vars = []
bn_in_path = True
state = None
if os.path.exists(path + '.pdparams'):
state = _load_state(path)
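    # running mean/variance may live either as keys in a 1.7 state dict
    # or as per-variable files in a pre-1.7 weight directory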
def check_mean_and_bias(prefix):
m = prefix + 'mean'
v = prefix + 'variance'
if state:
return v in state and m in state
else:
return (os.path.exists(os.path.join(path, m)) and
os.path.exists(os.path.join(path, v)))
has_mean_bias = True
inner_prog = fluid.Program()
inner_start_prog = fluid.Program()
inner_block = inner_prog.global_block()
with fluid.program_guard(inner_prog, inner_start_prog):
with fluid.program_guard(prog, fluid.Program()):
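        # the affine-channel program has no running mean/variance vars of
        # its own (see the note above), so declare them here in `prog`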
for block in prog.blocks:
ops = list(block.ops)
if not bn_in_path:
if not has_mean_bias:
break
for op in ops:
if op.type == 'affine_channel':
@@ -224,28 +246,22 @@ def load_and_fusebn(exe, prog, path):
prefix = scale_name[:-5]
mean_name = prefix + 'mean'
variance_name = prefix + 'variance'
if not os.path.exists(os.path.join(path, mean_name)):
bn_in_path = False
break
if not os.path.exists(os.path.join(path, variance_name)):
bn_in_path = False
if not check_mean_and_bias(prefix):
has_mean_bias = False
break
bias = block.var(bias_name)
mean_vb = inner_block.create_var(
mean_vb = block.create_var(
name=mean_name,
type=bias.type,
shape=bias.shape,
dtype=bias.dtype,
persistable=True)
variance_vb = inner_block.create_var(
dtype=bias.dtype)
variance_vb = block.create_var(
name=variance_name,
type=bias.type,
shape=bias.shape,
dtype=bias.dtype,
persistable=True)
dtype=bias.dtype)
mean_variances.add(mean_vb)
mean_variances.add(variance_vb)
@@ -253,21 +269,17 @@ def load_and_fusebn(exe, prog, path):
bn_vars.append(
[scale_name, bias_name, mean_name, variance_name])
if not bn_in_path:
fluid.io.load_vars(exe, path, prog, vars=all_vars)
if state:
fluid.io.set_program_state(prog, state)
else:
load_params(exe, prog, path)
if not has_mean_bias:
logger.warning(
"There is no paramters of batch norm in model {}. "
"Skip to fuse batch norm. And load paramters done.".format(path))
return
# load running mean and running variance on cpu place into global scope.
place = fluid.CPUPlace()
exe_cpu = fluid.Executor(place)
fluid.io.load_vars(exe_cpu, path, vars=[v for v in mean_variances])
# load params on real place into global scope.
fluid.io.load_vars(exe, path, prog, vars=all_vars)
eps = 1e-5
for names in bn_vars:
scale_name, bias_name, mean_name, var_name = names
......