diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index 4d23f62656f5561dd93f40fa97a3f7874e4b2040..c824d4404ad5f1fbee0d672fbc06d365f7d4eade 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -1,131 +1,89 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import errno +import logging import os -import shutil -import tempfile - -import paddle -import paddle.fluid as fluid - -from .utility import initial_logger -import re -logger = initial_logger() - - -def _mkdir_if_not_exist(path): - """ - mkdir if not exists, ignore the exception when multiprocess mkdir together - """ - if not os.path.exists(path): - try: - os.makedirs(path) - except OSError as e: - if e.errno == errno.EEXIST and os.path.isdir(path): - logger.warning( - 'be happy if some process has already created {}'.format( - path)) - else: - raise OSError('Failed to mkdir {}'.format(path)) - - -def _load_state(path): - if os.path.exists(path + '.pdopt'): - # XXX another hack to ignore the optimizer state - tmp = tempfile.mkdtemp() - dst = os.path.join(tmp, os.path.basename(os.path.normpath(path))) - shutil.copy(path + '.pdparams', dst + '.pdparams') - state = fluid.io.load_program_state(dst) - shutil.rmtree(tmp) - else: - state = fluid.io.load_program_state(path) - return state - - -def load_params(exe, prog, path, ignore_params=[]): - """ - Load model from the given path. - Args: - exe (fluid.Executor): The fluid.Executor object. - prog (fluid.Program): load weight to which Program object. - path (string): URL string or loca model path. - ignore_params (list): ignore variable to load when finetuning. - It can be specified by finetune_exclude_pretrained_params - and the usage can refer to docs/advanced_tutorials/TRANSFER_LEARNING.md - """ - if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): - raise ValueError("Model pretrain path {} does not " - "exists.".format(path)) - - logger.info('Loading parameters from {}...'.format(path)) - - ignore_set = set() - state = _load_state(path) - - # ignore the parameter which mismatch the shape - # between the model and pretrain weight. - all_var_shape = {} - for block in prog.blocks: - for param in block.all_parameters(): - all_var_shape[param.name] = param.shape - ignore_set.update([ - name for name, shape in all_var_shape.items() - if name in state and shape != state[name].shape - ]) - - if ignore_params: - all_var_names = [var.name for var in prog.list_vars()] - ignore_list = filter( - lambda var: any([re.match(name, var) for name in ignore_params]), - all_var_names) - ignore_set.update(list(ignore_list)) - - if len(ignore_set) > 0: - for k in ignore_set: - if k in state: - logger.warning('variable {} not used'.format(k)) - del state[k] - fluid.io.set_program_state(prog, state) - - -def init_model(config, program, exe): - """ - load model from checkpoint or pretrained_model - """ - checkpoints = config['Global'].get('checkpoints') - if checkpoints: - path = checkpoints - fluid.load(program, path, exe) - logger.info("Finish initing model from {}".format(path)) - return - - pretrain_weights = config['Global'].get('pretrain_weights') - if pretrain_weights: - path = pretrain_weights - load_params(exe, program, path) - logger.info("Finish initing model from {}".format(path)) - return - - -def save_model(program, model_path): - """ - save model to the target path - """ - fluid.save(program, model_path) - logger.info("Already save model in {}".format(model_path)) + + +def initial_logger(): + FORMAT = '%(asctime)s-%(levelname)s: %(message)s' + logging.basicConfig(level=logging.INFO, format=FORMAT) + logger = logging.getLogger(__name__) + return logger + + +import importlib + + +def create_module(module_str): + tmpss = module_str.split(",") + assert len(tmpss) == 2, "Error formate\ + of the module path: {}".format(module_str) + module_name, function_name = tmpss[0], tmpss[1] + somemodule = importlib.import_module(module_name, __package__) + function = getattr(somemodule, function_name) + return function + + +def get_check_global_params(mode): + check_params = ['use_gpu', 'max_text_length', 'image_shape',\ + 'image_shape', 'character_type', 'loss_type'] + if mode == "train_eval": + check_params = check_params + [\ + 'train_batch_size_per_card', 'test_batch_size_per_card'] + elif mode == "test": + check_params = check_params + ['test_batch_size_per_card'] + return check_params + + +def get_check_reader_params(mode): + check_params = [] + if mode == "train_eval": + check_params = ['TrainReader', 'EvalReader'] + elif mode == "test": + check_params = ['TestReader'] + return check_params + + +def get_image_file_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp'] + if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end: + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + if single_file.split('.')[-1] in img_end: + imgs_lists.append(os.path.join(img_file, single_file)) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + return imgs_lists + + +from paddle import fluid + + +def create_multi_devices_program(program, loss_var_name): + build_strategy = fluid.BuildStrategy() + build_strategy.memory_optimize = False + build_strategy.enable_inplace = True + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_iteration_per_drop_scope = 1 + compile_program = fluid.CompiledProgram(program).with_data_parallel( + loss_name=loss_var_name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + return compile_program