From e92cb0b93c8fa1ce78b908bc776a94050fd89c16 Mon Sep 17 00:00:00 2001
From: littletomatodonkey <2120160898@bit.edu.cn>
Date: Thu, 3 Dec 2020 12:50:33 +0800
Subject: [PATCH] fix init model in static mode (#444)

---
 tools/static/program.py   |  27 -------
 tools/static/save_load.py | 145 ++++++++++++++++++++++++++++++++++++++
 tools/static/train.py     |   5 +-
 3 files changed, 149 insertions(+), 28 deletions(-)
 create mode 100644 tools/static/save_load.py

diff --git a/tools/static/program.py b/tools/static/program.py
index c57e60ab..147aa4b7 100644
--- a/tools/static/program.py
+++ b/tools/static/program.py
@@ -39,33 +39,6 @@ from paddle.distributed import fleet
 from paddle.distributed.fleet import DistributedStrategy
 
 
-def _mkdir_if_not_exist(path):
-    """
-    mkdir if not exists, ignore the exception when multiprocess mkdir together
-    """
-    if not os.path.exists(path):
-        try:
-            os.makedirs(path)
-        except OSError as e:
-            if e.errno == errno.EEXIST and os.path.isdir(path):
-                logger.warning(
-                    'be happy if some process has already created {}'.format(
-                        path))
-            else:
-                raise OSError('Failed to mkdir {}'.format(path))
-
-
-def save_model(program, model_path, epoch_id, prefix='ppcls'):
-    """
-    save model to the target path
-    """
-    model_path = os.path.join(model_path, str(epoch_id))
-    _mkdir_if_not_exist(model_path)
-    model_prefix = os.path.join(model_path, prefix)
-    paddle.static.save(program, model_prefix)
-    logger.info("Already save model in {}".format(model_path))
-
-
 def create_feeds(image_shape, use_mix=None, use_dali=None):
     """
     Create feeds as model input
diff --git a/tools/static/save_load.py b/tools/static/save_load.py
new file mode 100644
index 00000000..7f20b292
--- /dev/null
+++ b/tools/static/save_load.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import errno
+import os
+import re
+import shutil
+import tempfile
+
+import paddle
+
+from ppcls.utils import logger
+
+__all__ = ['init_model', 'save_model']
+
+
+def _mkdir_if_not_exist(path):
+    """
+    mkdir if not exists; ignore the race when multiple processes mkdir together
+    """
+    if not os.path.exists(path):
+        try:
+            os.makedirs(path)
+        except OSError as e:
+            if e.errno == errno.EEXIST and os.path.isdir(path):
+                logger.warning(
+                    '{} already exists, possibly created by another '
+                    'process'.format(path))
+            else:
+                raise OSError('Failed to mkdir {}'.format(path))
+
+
+def _load_state(path):
+    if os.path.exists(path + '.pdopt'):
+        # copy only the .pdparams file to a temporary directory so that
+        # the optimizer state (.pdopt) is not loaded
+        tmp = tempfile.mkdtemp()
+        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
+        shutil.copy(path + '.pdparams', dst + '.pdparams')
+        state = paddle.static.load_program_state(dst)
+        shutil.rmtree(tmp)
+    else:
+        state = paddle.static.load_program_state(path)
+    return state
+
+
+def load_params(exe, prog, path, ignore_params=None):
+    """
+    Load model from the given path.
+
+    Args:
+        exe (paddle.static.Executor): the executor used to run the program.
+        prog (paddle.static.Program): the program to load the weights into.
+        path (string): URL string or local model path.
+        ignore_params (list): variables to ignore when loading, used for
+            fine-tuning. They can be specified via
+            finetune_exclude_pretrained_params; see
+            docs/advanced_tutorials/TRANSFER_LEARNING.md for usage.
+    """
+    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
+        raise ValueError("Pretrained model path {} does not "
+                         "exist.".format(path))
+
+    logger.info(
+        logger.coloring('Loading parameters from {}...'.format(path),
+                        'HEADER'))
+
+    ignore_set = set()
+    state = _load_state(path)
+
+    # ignore parameters whose shapes mismatch between
+    # the program and the pretrained weights
+    all_var_shape = {}
+    for block in prog.blocks:
+        for param in block.all_parameters():
+            all_var_shape[param.name] = param.shape
+    ignore_set.update([
+        name for name, shape in all_var_shape.items()
+        if name in state and shape != state[name].shape
+    ])
+
+    if ignore_params:
+        all_var_names = [var.name for var in prog.list_vars()]
+        ignore_list = filter(
+            lambda var: any([re.match(name, var) for name in ignore_params]),
+            all_var_names)
+        ignore_set.update(list(ignore_list))
+
+    if len(ignore_set) > 0:
+        for k in ignore_set:
+            if k in state:
+                logger.warning(
+                    'variable {} is excluded and will not be loaded'.format(k))
+                del state[k]
+
+    paddle.static.set_program_state(prog, state)
+
+
+def init_model(config, program, exe):
+    """
+    load model from checkpoints or pretrained_model
+    """
+    checkpoints = config.get('checkpoints')
+    if checkpoints:
+        paddle.static.load(program, checkpoints, exe)
+        logger.info(
+            logger.coloring("Finished initializing model from {}".format(
+                checkpoints), "HEADER"))
+        return
+
+    pretrained_model = config.get('pretrained_model')
+    if pretrained_model:
+        if not isinstance(pretrained_model, list):
+            pretrained_model = [pretrained_model]
+        for pretrain in pretrained_model:
+            load_params(exe, program, pretrain)
+        logger.info(
+            logger.coloring("Finished initializing model from {}".format(
+                pretrained_model), "HEADER"))
+
+
+def save_model(program, model_path, epoch_id, prefix='ppcls'):
+    """
+    save model to the target path
+    """
+    model_path = os.path.join(model_path, str(epoch_id))
+    _mkdir_if_not_exist(model_path)
+    model_prefix = os.path.join(model_path, prefix)
+    paddle.static.save(program, model_prefix)
+    logger.info(
+        logger.coloring("Model saved in {}".format(model_path),
+                        "HEADER"))
diff --git a/tools/static/train.py b/tools/static/train.py
index 36427d87..2b44befa 100644
--- a/tools/static/train.py
+++ b/tools/static/train.py
@@ -32,7 +32,7 @@ from ppcls.data import Reader
 from ppcls.utils.config import get_config
 from ppcls.utils import logger
 from tools.static import program
-from program import save_model
+from save_load import init_model, save_model
 
 
 def parse_args():
@@ -103,6 +103,9 @@ def main(args):
     # Parameter initialization
     exe.run(startup_prog)
 
+    # load pretrained models or checkpoints
+    init_model(config, train_prog, exe)
+
     if not config.get('use_dali', False):
         train_dataloader = Reader(config, 'train', places=place)()
         if config.validate and paddle.distributed.get_rank() == 0:
--
GitLab
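
For reference, a minimal usage sketch of the two helpers this patch adds, assuming save_load.py is importable the same way train.py imports it (i.e. running from tools/static). The toy network, the config dict, and the paths below are hypothetical stand-ins for what tools/static/train.py builds from its YAML config:

    import paddle
    from save_load import init_model, save_model

    paddle.enable_static()

    # build a trivial static-mode program in place of the real PPCls network
    train_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(train_prog, startup_prog):
        image = paddle.static.data(
            name='image', shape=[None, 3, 224, 224], dtype='float32')
        out = paddle.static.nn.fc(x=image, size=1000)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)

    # 'checkpoints' resumes a full paddle.static.save() checkpoint
    # (parameters plus optimizer state); 'pretrained_model' loads
    # parameters only, skipping any whose shapes mismatch the program.
    config = {'pretrained_model': './pretrained/ppcls'}  # hypothetical path
    init_model(config, train_prog, exe)

    # ... training loop ...

    # writes ./output/0/ppcls.pdparams, .pdopt and .pdmodel
    save_model(train_prog, './output', epoch_id=0)

Note that when init_model is given a pretrained_model path, _load_state deliberately copies only the .pdparams file into a temporary directory before loading, so a fine-tuning run never inherits the optimizer state of the pretrained weights.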