From 8ffb5bffef2915a2b218d5b57db9e27f1a89b50a Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Tue, 3 Nov 2020 11:20:02 +0800 Subject: [PATCH] fix load in metric learning (#4928) * fix load in metric learning * fix load in metric learning * fix typo * fix typo --- PaddleCV/metric_learning/train_elem.py | 4 +- PaddleCV/metric_learning/train_pair.py | 4 +- PaddleCV/metric_learning/utility.py | 59 ++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/PaddleCV/metric_learning/train_elem.py b/PaddleCV/metric_learning/train_elem.py index 3438667c..485b623c 100644 --- a/PaddleCV/metric_learning/train_elem.py +++ b/PaddleCV/metric_learning/train_elem.py @@ -33,6 +33,7 @@ from losses import SoftmaxLoss from losses import ArcMarginLoss from utility import add_arguments, print_arguments from utility import fmt_time, recall_topk, get_gpu_num, check_cuda +from utility import load_params parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) @@ -190,8 +191,7 @@ def train_async(args): fluid.load(program=train_prog, model_path=checkpoint, executor=exe) if pretrained_model: - fluid.load( - program=train_prog, model_path=pretrained_model, executor=exe) + load_params(exe, train_prog, pretrained_model) if args.use_gpu: devicenum = get_gpu_num() diff --git a/PaddleCV/metric_learning/train_pair.py b/PaddleCV/metric_learning/train_pair.py index 009ac9ae..2f3ef7a0 100644 --- a/PaddleCV/metric_learning/train_pair.py +++ b/PaddleCV/metric_learning/train_pair.py @@ -35,6 +35,7 @@ from losses import EmlLoss from losses import NpairsLoss from utility import add_arguments, print_arguments from utility import fmt_time, recall_topk, get_gpu_num, check_cuda +from utility import load_params parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) @@ -188,8 +189,7 @@ def train_async(args): 
fluid.load(program=train_prog, model_path=checkpoint, executor=exe) if pretrained_model: - fluid.load( - program=train_prog, model_path=pretrained_model, executor=exe) + load_params(exe, train_prog, pretrained_model) if args.use_gpu: devicenum = get_gpu_num() diff --git a/PaddleCV/metric_learning/utility.py b/PaddleCV/metric_learning/utility.py index de1f78cb..ada51e42 100644 --- a/PaddleCV/metric_learning/utility.py +++ b/PaddleCV/metric_learning/utility.py @@ -176,3 +176,62 @@ def check_cuda(use_cuda, err = \ sys.exit(1) except Exception as e: pass + + +def _load_state(path): + if os.path.exists(path + '.pdopt'): + # XXX another hack to ignore the optimizer state + tmp = tempfile.mkdtemp() + dst = os.path.join(tmp, os.path.basename(os.path.normpath(path))) + shutil.copy(path + '.pdparams', dst + '.pdparams') + state = fluid.io.load_program_state(dst) + shutil.rmtree(tmp) + else: + state = fluid.io.load_program_state(path) + return state + + +def load_params(exe, prog, path, ignore_params=None): + """ + Load model from the given path. + Args: + exe (fluid.Executor): The fluid.Executor object. + prog (fluid.Program): load weight to which Program object. + path (string): local model path. + ignore_params (list): ignore variable to load when finetuning. + """ + if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): + raise ValueError("Model pretrain path {} does not " + "exist.".format(path)) + + print('Loading parameters from {}...'.format(path)) + + ignore_set = set() + state = _load_state(path) + + # ignore the parameter which mismatches the shape + # between the model and pretrain weight. 
+ all_var_shape = {} + for block in prog.blocks: + for param in block.all_parameters(): + all_var_shape[param.name] = param.shape + ignore_set.update([ + name for name, shape in all_var_shape.items() + if name in state and shape != state[name].shape + ]) + + if ignore_params: + all_var_names = [var.name for var in prog.list_vars()] + ignore_list = filter( + lambda var: any([re.match(name, var) for name in ignore_params]), + all_var_names) + ignore_set.update(list(ignore_list)) + + if len(ignore_set) > 0: + for k in ignore_set: + if k in state: + print('warning: variable {} is already excluded automatically'. + format(k)) + del state[k] + + fluid.io.set_program_state(prog, state) -- GitLab