diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py
index c824d4404ad5f1fbee0d672fbc06d365f7d4eade..74a200f5e5661ecfe1409290871a931bdf18e99d 100755
--- a/ppocr/utils/utility.py
+++ b/ppocr/utils/utility.py
@@ -1,89 +1,128 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+#    http://www.apache.org/licenses/LICENSE-2.0
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 
-def initial_logger():
-    FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
-    logging.basicConfig(level=logging.INFO, format=FORMAT)
-    logger = logging.getLogger(__name__)
-    return logger
-
-
-import importlib
-
-
-def create_module(module_str):
-    tmpss = module_str.split(",")
-    assert len(tmpss) == 2, "Error formate\
-        of the module path: {}".format(module_str)
-    module_name, function_name = tmpss[0], tmpss[1]
-    somemodule = importlib.import_module(module_name, __package__)
-    function = getattr(somemodule, function_name)
-    return function
-
-
-def get_check_global_params(mode):
-    check_params = ['use_gpu', 'max_text_length', 'image_shape',\
-                    'image_shape', 'character_type', 'loss_type']
-    if mode == "train_eval":
-        check_params = check_params + [\
-            'train_batch_size_per_card', 'test_batch_size_per_card']
-    elif mode == "test":
-        check_params = check_params + ['test_batch_size_per_card']
-    return check_params
-
-
-def get_check_reader_params(mode):
-    check_params = []
-    if mode == "train_eval":
-        check_params = ['TrainReader', 'EvalReader']
-    elif mode == "test":
-        check_params = ['TestReader']
-    return check_params
-
-
-def get_image_file_list(img_file):
-    imgs_lists = []
-    if img_file is None or not os.path.exists(img_file):
-        raise Exception("not found any img file in {}".format(img_file))
-
-    img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
-    if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
-        imgs_lists.append(img_file)
-    elif os.path.isdir(img_file):
-        for single_file in os.listdir(img_file):
-            if single_file.split('.')[-1] in img_end:
-                imgs_lists.append(os.path.join(img_file, single_file))
-    if len(imgs_lists) == 0:
-        raise Exception("not found any img file in {}".format(img_file))
-    return imgs_lists
-
-
-from paddle import fluid
-
-
-def create_multi_devices_program(program, loss_var_name):
-    build_strategy = fluid.BuildStrategy()
-    build_strategy.memory_optimize = False
-    build_strategy.enable_inplace = True
-    exec_strategy = fluid.ExecutionStrategy()
-    exec_strategy.num_iteration_per_drop_scope = 1
-    compile_program = fluid.CompiledProgram(program).with_data_parallel(
-        loss_name=loss_var_name,
-        build_strategy=build_strategy,
-        exec_strategy=exec_strategy)
-    return compile_program
+import errno
+import os
+import shutil
+import tempfile
+
+import paddle.fluid as fluid
+
+from .utility import initial_logger
+import re
+logger = initial_logger()
+
+
+def _mkdir_if_not_exist(path):
+    """
+    mkdir if not exists, ignore the exception when multiprocess mkdir together
+    """
+    if not os.path.exists(path):
+        try:
+            os.makedirs(path)
+        except OSError as e:
+            if e.errno == errno.EEXIST and os.path.isdir(path):
+                logger.warning(
+                    'be happy if some process has already created {}'.format(
+                        path))
+            else:
+                raise OSError('Failed to mkdir {}'.format(path))
+
+
+def _load_state(path):
+    if os.path.exists(path + '.pdopt'):
+        # XXX another hack to ignore the optimizer state
+        tmp = tempfile.mkdtemp()
+        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
+        shutil.copy(path + '.pdparams', dst + '.pdparams')
+        state = fluid.io.load_program_state(dst)
+        shutil.rmtree(tmp)
+    else:
+        state = fluid.io.load_program_state(path)
+    return state
+
+
+def load_params(exe, prog, path, ignore_params=[]):
+    """
+    Load model from the given path.
+    Args:
+        exe (fluid.Executor): The fluid.Executor object.
+        prog (fluid.Program): load weight to which Program object.
+        path (string): URL string or loca model path.
+        ignore_params (list): ignore variable to load when finetuning.
+            It can be specified by finetune_exclude_pretrained_params
+            and the usage can refer to docs/advanced_tutorials/TRANSFER_LEARNING.md
+    """
+    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
+        raise ValueError("Model pretrain path {} does not "
+                         "exists.".format(path))
+
+    logger.info('Loading parameters from {}...'.format(path))
+
+    ignore_set = set()
+    state = _load_state(path)
+
+    # ignore the parameter which mismatch the shape
+    # between the model and pretrain weight.
+    all_var_shape = {}
+    for block in prog.blocks:
+        for param in block.all_parameters():
+            all_var_shape[param.name] = param.shape
+    ignore_set.update([
+        name for name, shape in all_var_shape.items()
+        if name in state and shape != state[name].shape
+    ])
+
+    if ignore_params:
+        all_var_names = [var.name for var in prog.list_vars()]
+        ignore_list = filter(
+            lambda var: any([re.match(name, var) for name in ignore_params]),
+            all_var_names)
+        ignore_set.update(list(ignore_list))
+
+    if len(ignore_set) > 0:
+        for k in ignore_set:
+            if k in state:
+                logger.warning('variable {} not used'.format(k))
+                del state[k]
+    fluid.io.set_program_state(prog, state)
+
+
+def init_model(config, program, exe):
+    """
+    load model from checkpoint or pretrained_model
+    """
+    checkpoints = config['Global'].get('checkpoints')
+    if checkpoints:
+        path = checkpoints
+        fluid.load(program, path, exe)
+        logger.info("Finish initing model from {}".format(path))
+
+    pretrain_weights = config['Global'].get('pretrain_weights')
+    if pretrain_weights:
+        path = pretrain_weights
+        load_params(exe, program, path)
+        logger.info("Finish initing model from {}".format(path))
+
+
+def save_model(program, model_path):
+    """
+    save model to the target path
+    """
+    fluid.save(program, model_path)
+    logger.info("Already save model in {}".format(model_path))
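
A minimal usage sketch for the helpers introduced above (not part of the patch). It builds a toy program, saves it with save_model, and restores it through init_model and load_params. The import path, the single-FC network, the ./output/sketch_model path and the config keys are assumptions made only for illustration; adjust the import to wherever the patched helpers actually live.

    # sketch.py -- illustrative only; assumes Paddle 1.x fluid APIs
    import os
    import paddle.fluid as fluid

    # assumed import path: the module patched above
    from ppocr.utils.utility import init_model, load_params, save_model

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    train_program = fluid.Program()
    startup_program = fluid.Program()
    with fluid.program_guard(train_program, startup_program):
        # a single FC layer stands in for the real OCR network
        image = fluid.data(name='image', shape=[None, 1, 32, 100], dtype='float32')
        logits = fluid.layers.fc(input=image, size=10)

    exe.run(startup_program)

    # save_model wraps fluid.save; parameters land in sketch_model.pdparams
    os.makedirs('./output', exist_ok=True)
    save_model(train_program, './output/sketch_model')

    # resuming: Global.checkpoints takes priority over Global.pretrain_weights
    config = {'Global': {'checkpoints': './output/sketch_model',
                         'pretrain_weights': None}}
    init_model(config, train_program, exe)

    # finetuning: parameters whose names match one of the regexes are dropped
    # from the loaded state, so only the remaining weights are restored
    load_params(exe, train_program, './output/sketch_model',
                ignore_params=['fc_0'])

Note that load_params also drops any pretrained tensor whose shape no longer matches the current program, so a head with a different output size is simply left at its fresh initialization instead of raising an error.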