未验证 提交 8ffb5bff 编写于 作者: L littletomatodonkey 提交者: GitHub

fix load in metric learning (#4928)

* fix load in metric learning

* fix load in metric learning

* fix typo

* fix typo
上级 2f52d598
......@@ -33,6 +33,7 @@ from losses import SoftmaxLoss
from losses import ArcMarginLoss
from utility import add_arguments, print_arguments
from utility import fmt_time, recall_topk, get_gpu_num, check_cuda
from utility import load_params
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
......@@ -190,8 +191,7 @@ def train_async(args):
fluid.load(program=train_prog, model_path=checkpoint, executor=exe)
if pretrained_model:
fluid.load(
program=train_prog, model_path=pretrained_model, executor=exe)
load_params(exe, train_prog, pretrained_model)
if args.use_gpu:
devicenum = get_gpu_num()
......
......@@ -35,6 +35,7 @@ from losses import EmlLoss
from losses import NpairsLoss
from utility import add_arguments, print_arguments
from utility import fmt_time, recall_topk, get_gpu_num, check_cuda
from utility import load_params
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
......@@ -188,8 +189,7 @@ def train_async(args):
fluid.load(program=train_prog, model_path=checkpoint, executor=exe)
if pretrained_model:
fluid.load(
program=train_prog, model_path=pretrained_model, executor=exe)
load_params(exe, train_prog, pretrained_model)
if args.use_gpu:
devicenum = get_gpu_num()
......
......@@ -176,3 +176,62 @@ def check_cuda(use_cuda, err = \
sys.exit(1)
except Exception as e:
pass
def _load_state(path):
    """Load a program state dict from ``path``, skipping optimizer state.

    If an optimizer state file (``<path>.pdopt``) exists alongside the
    parameter file, only the ``.pdparams`` file is loaded: the parameter
    file is copied into a scratch directory so that
    ``fluid.io.load_program_state`` cannot pick up the optimizer state.

    Args:
        path (string): model path prefix (without extension).

    Returns:
        dict: program state as returned by ``fluid.io.load_program_state``.
    """
    if os.path.exists(path + '.pdopt'):
        # HACK: copy only the .pdparams file into a temp dir so the
        # optimizer state (.pdopt) next to it is ignored.
        tmp = tempfile.mkdtemp()
        try:
            dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
            shutil.copy(path + '.pdparams', dst + '.pdparams')
            state = fluid.io.load_program_state(dst)
        finally:
            # Fix: remove the scratch dir even if copying/loading raises,
            # otherwise the temp directory leaks on failure.
            shutil.rmtree(tmp)
    else:
        state = fluid.io.load_program_state(path)
    return state
def load_params(exe, prog, path, ignore_params=None):
    """
    Load model from the given path.
    Args:
        exe (fluid.Executor): The fluid.Executor object.
        prog (fluid.Program): load weight to which Program object.
        path (string): local model path.
        ignore_params (list): ignore variable to load when finetuning.
    """
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))

    print('Loading parameters from {}...'.format(path))

    state = _load_state(path)

    excluded = set()

    # Drop pretrained weights whose shape disagrees with the model's
    # parameter of the same name — loading them would fail.
    model_shapes = {
        param.name: param.shape
        for block in prog.blocks
        for param in block.all_parameters()
    }
    for var_name, var_shape in model_shapes.items():
        if var_name in state and var_shape != state[var_name].shape:
            excluded.add(var_name)

    # Also exclude any variable matching a caller-supplied regex pattern
    # (used to skip e.g. the classification head when finetuning).
    if ignore_params:
        for var in prog.list_vars():
            if any(re.match(pattern, var.name) for pattern in ignore_params):
                excluded.add(var.name)

    for k in excluded:
        if k in state:
            print('warning: variable {} is already excluded automatically'.
                  format(k))
            del state[k]

    fluid.io.set_program_state(prog, state)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册