From c59bdcd6803de17f784f46cc19f201867ce17a72 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Mon, 17 Feb 2020 14:58:17 +0800 Subject: [PATCH] Update checkpoint to use paddle 1.7 API (#229) * Update checkpoint to use paddle 1.7 API * Add new checkpoint support to `load_and_fusebn` * Make yapf happy --- ppdet/utils/checkpoint.py | 138 +++++++++++++++++++++----------------- 1 file changed, 75 insertions(+), 63 deletions(-) diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index 54c364812..7b5541a2c 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -20,6 +20,7 @@ from __future__ import unicode_literals import errno import os import shutil +import tempfile import time import numpy as np import re @@ -78,6 +79,19 @@ def _get_weight_path(path): return path +def _load_state(path): + if os.path.exists(path + '.pdopt'): + # XXX another hack to ignore the optimizer state + tmp = tempfile.mkdtemp() + dst = os.path.join(tmp, os.path.basename(os.path.normpath(path))) + shutil.copy(path + '.pdparams', dst + '.pdparams') + state = fluid.io.load_program_state(dst) + shutil.rmtree(tmp) + else: + state = fluid.io.load_program_state(path) + return state + + def load_params(exe, prog, path, ignore_params=[]): """ Load model from the given path. @@ -92,30 +106,40 @@ def load_params(exe, prog, path, ignore_params=[]): if is_url(path): path = _get_weight_path(path) - - if not os.path.exists(path): + if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): raise ValueError("Model pretrain path {} does not " "exists.".format(path)) logger.info('Loading parameters from {}...'.format(path)) - def _if_exist(var): - do_ignore = False - param_exist = os.path.exists(os.path.join(path, var.name)) - if len(ignore_params) > 0: - # Parameter related to num_classes will be ignored in finetuning - do_ignore_list = [ - bool(re.match(name, var.name)) for name in ignore_params - ] - do_ignore = any(do_ignore_list) - if do_ignore and param_exist: - logger.info('In load_params, ignore {}'.format(var.name)) - do_load = param_exist and not do_ignore - if do_load: - logger.debug('load weight {}'.format(var.name)) - return do_load - - fluid.io.load_vars(exe, path, prog, predicate=_if_exist) + ignore_list = None + if ignore_params: + all_var_names = [var.name for var in prog.list_vars()] + ignore_list = filter( + lambda var: any([re.match(name, var) for name in ignore_params]), + all_var_names) + ignore_list = list(ignore_list) + + if os.path.isdir(path): + if not ignore_list: + fluid.load(prog, path, executor=exe) + return + + # XXX this is hackish, but seems to be the least contrived way... + tmp = tempfile.mkdtemp() + dst = os.path.join(tmp, os.path.basename(os.path.normpath(path))) + shutil.copytree(path, dst, ignore=shutil.ignore_patterns(*ignore_list)) + fluid.load(prog, dst, executor=exe) + shutil.rmtree(tmp) + return + + state = _load_state(path) + + if ignore_list: + for k in ignore_list: + if k in state: + del state[k] + fluid.io.set_program_state(prog, state) def load_checkpoint(exe, prog, path): @@ -128,13 +152,10 @@ def load_checkpoint(exe, prog, path): """ if is_url(path): path = _get_weight_path(path) - - if not os.path.exists(path): - raise ValueError("Model checkpoint path {} does not " + if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): + raise ValueError("Model pretrain path {} does not " "exists.".format(path)) - - logger.info('Loading checkpoint from {}...'.format(path)) - fluid.io.load_persistables(exe, path, prog) + fluid.load(prog, path, executor=exe) def global_step(scope=None): @@ -165,7 +186,7 @@ def save(exe, prog, path): if os.path.isdir(path): shutil.rmtree(path) logger.info('Save model to {}.'.format(path)) - fluid.io.save_persistables(exe, path, prog) + fluid.save(prog, path) def load_and_fusebn(exe, prog, path): @@ -186,15 +207,6 @@ def load_and_fusebn(exe, prog, path): if not os.path.exists(path): raise ValueError("Model path {} does not exists.".format(path)) - def _if_exist(var): - b = os.path.exists(os.path.join(path, var.name)) - - if b: - logger.debug('load weight {}'.format(var.name)) - return b - - all_vars = list(filter(_if_exist, prog.list_vars())) - # Since the program uses affine-channel, there is no running mean and var # in the program, here append running mean and var. # NOTE, the params of batch norm should be like: @@ -206,15 +218,25 @@ def load_and_fusebn(exe, prog, path): mean_variances = set() bn_vars = [] - bn_in_path = True + state = None + if os.path.exists(path + '.pdparams'): + state = _load_state(path) + + def check_mean_and_bias(prefix): + m = prefix + 'mean' + v = prefix + 'variance' + if state: + return v in state and m in state + else: + return (os.path.exists(os.path.join(path, m)) and + os.path.exists(os.path.join(path, v))) + + has_mean_bias = True - inner_prog = fluid.Program() - inner_start_prog = fluid.Program() - inner_block = inner_prog.global_block() - with fluid.program_guard(inner_prog, inner_start_prog): + with fluid.program_guard(prog, fluid.Program()): for block in prog.blocks: ops = list(block.ops) - if not bn_in_path: + if not has_mean_bias: break for op in ops: if op.type == 'affine_channel': @@ -224,28 +246,22 @@ def load_and_fusebn(exe, prog, path): prefix = scale_name[:-5] mean_name = prefix + 'mean' variance_name = prefix + 'variance' - - if not os.path.exists(os.path.join(path, mean_name)): - bn_in_path = False - break - if not os.path.exists(os.path.join(path, variance_name)): - bn_in_path = False + if not check_mean_and_bias(prefix): + has_mean_bias = False break bias = block.var(bias_name) - mean_vb = inner_block.create_var( + mean_vb = block.create_var( name=mean_name, type=bias.type, shape=bias.shape, - dtype=bias.dtype, - persistable=True) - variance_vb = inner_block.create_var( + dtype=bias.dtype) + variance_vb = block.create_var( name=variance_name, type=bias.type, shape=bias.shape, - dtype=bias.dtype, - persistable=True) + dtype=bias.dtype) mean_variances.add(mean_vb) mean_variances.add(variance_vb) @@ -253,21 +269,17 @@ def load_and_fusebn(exe, prog, path): bn_vars.append( [scale_name, bias_name, mean_name, variance_name]) - if not bn_in_path: - fluid.io.load_vars(exe, path, prog, vars=all_vars) + if state: + fluid.io.set_program_state(prog, state) + else: + load_params(exe, prog, path) + + if not has_mean_bias: logger.warning( "There is no paramters of batch norm in model {}. " "Skip to fuse batch norm. And load paramters done.".format(path)) return - # load running mean and running variance on cpu place into global scope. - place = fluid.CPUPlace() - exe_cpu = fluid.Executor(place) - fluid.io.load_vars(exe_cpu, path, vars=[v for v in mean_variances]) - - # load params on real place into global scope. - fluid.io.load_vars(exe, path, prog, vars=all_vars) - eps = 1e-5 for names in bn_vars: scale_name, bias_name, mean_name, var_name = names -- GitLab