# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate a per-language text-recognition training config.

Loads the template ``./rec_multi_language_lite_train.yml`` from the current
working directory. It specializes the template for the language passed with
``-l/--language`` and applies any ``-o``, ``--train``, ``--val``, ``--dict``
and ``--data_dir`` overrides. The result is written to
``rec_<lang>_lite_train.yml`` in the current directory.
"""
import yaml
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import os.path
import logging

logging.basicConfig(level=logging.INFO)

# Languages that are accepted as-is (code -> human-readable name).
support_list = {
    'it': 'italian',
    'xi': 'spanish',
    'pu': 'portuguese',
    'ru': 'russian',
    'ar': 'arabic',
    'ta': 'tamil',
    'ug': 'uyghur',
    'fa': 'persian',
    'ur': 'urdu',
    'rs': 'serbian latin',
    'oc': 'occitan',
    'rsc': 'serbian cyrillic',
    'bg': 'bulgarian',
    'uk': 'ukranian',
    'be': 'belarusian',
    'te': 'telugu',
    'ka': 'kannada',
    'chinese_cht': 'chinese tradition',
    'hi': 'hindi',
    'mr': 'marathi',
    'ne': 'nepali',
}

# Language codes that are collapsed onto a shared script-level model
# ("latin", "arabic", "cyrillic", "devanagari") by ArgsParser._set_language.
latin_lang = [
    'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
    'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
    'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
    'sw', 'tl', 'tr', 'uz', 'vi', 'latin'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
    'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
    'dar', 'inh', 'che', 'lbe', 'lez', 'tab', 'cyrillic'
]
devanagari_lang = [
    'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
    'sa', 'bgc', 'devanagari'
]
multi_lang = latin_lang + arabic_lang + cyrillic_lang + devanagari_lang

assert os.path.isfile("./rec_multi_language_lite_train.yml"), (
    "Loss basic configuration file rec_multi_language_lite_train.yml."
    "You can download it from "
    "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
)

# The template is a local, trusted file. yaml.Loader is kept for compatibility.
# Use `with` so the file handle is closed once it has been parsed.
with open("./rec_multi_language_lite_train.yml", 'rb') as _template_file:
    global_config = yaml.load(_template_file, Loader=yaml.Loader)

# Three directories above the cwd. The final log line points users to
# configs/rec/multi_language/, so this presumably resolves to the repository
# root when the script is run from that directory. Verify before relying on it.
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))


class ArgsParser(ArgumentParser):
    """Argument parser that also post-processes ``-o`` and ``-l`` values.

    ``parse_args`` returns a namespace in which ``opt`` is a dict of parsed
    overrides and ``language`` is the resolved language/script name. As a side
    effect, resolving the language updates ``global_config``.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            "-l",
            "--language",
            nargs='+',
            help="set language type, support {}".format(support_list))
        self.add_argument(
            "--train",
            type=str,
            help="you can use this command to change the train dataset default path"
        )
        self.add_argument(
            "--val",
            type=str,
            help="you can use this command to change the eval dataset default path"
        )
        self.add_argument(
            "--dict",
            type=str,
            help="you can use this command to change the dictionary default path"
        )
        self.add_argument(
            "--data_dir",
            type=str,
            help="you can use this command to change the dataset default root path"
        )

    def parse_args(self, argv=None):
        """Parse ``argv``, then convert ``opt`` and resolve ``language``."""
        args = super(ArgsParser, self).parse_args(argv)
        args.opt = self._parse_opt(args.opt)
        args.language = self._set_language(args.language)
        return args

    def _parse_opt(self, opts):
        """Turn a list of ``key=value`` strings into a dict.

        Each value is parsed with YAML, so ``"3"`` becomes ``3`` and
        ``"[a, b]"`` becomes a list. Returns an empty dict when ``opts`` is
        None or empty.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split on the first '=' only, so values may themselves contain '='.
            k, v = s.split('=', 1)
            # NOTE(review): yaml.Loader can construct arbitrary Python objects;
            # these values come straight from the command line.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config

    def _set_language(self, type):
        """Resolve the ``-l`` value and point ``global_config`` at it.

        Args:
            type: list of strings from ``-l/--language`` (or None if omitted).
                Only the first element is used. The name shadows the builtin
                and is kept for backward compatibility.

        Returns:
            The language code. Members of the latin/arabic/cyrillic/devanagari
            groups are collapsed to that group name.

        Side effects:
            Sets the dictionary path, save dir, train/eval label lists and
            ``character_type`` in ``global_config``. Asserts that the
            dictionary file exists under ``project_path``.
        """
        # Validate before indexing: a missing -l yields None, and indexing it
        # would raise TypeError instead of showing this message.
        assert type, "please use -l or --language to choose language type"
        lang = type[0]
        assert (lang in support_list or lang in multi_lang), (
            "the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, "
            "please check your running command".format(multi_lang, type))
        if lang in latin_lang:
            lang = "latin"
        elif lang in arabic_lang:
            lang = "arabic"
        elif lang in cyrillic_lang:
            lang = "cyrillic"
        elif lang in devanagari_lang:
            lang = "devanagari"

        global_cfg = global_config['Global']
        global_cfg['character_dict_path'] = (
            'ppocr/utils/dict/{}_dict.txt'.format(lang))
        global_cfg['save_model_dir'] = './output/rec_{}_lite'.format(lang)
        global_config['Train']['dataset']['label_file_list'] = [
            "train_data/{}_train.txt".format(lang)
        ]
        global_config['Eval']['dataset']['label_file_list'] = [
            "train_data/{}_val.txt".format(lang)
        ]
        global_cfg['character_type'] = lang

        assert os.path.isfile(
            os.path.join(project_path, global_cfg['character_dict_path'])
        ), ("Loss default dictionary file {}_dict.txt.You can download it from "
            "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/"
            .format(lang))
        return lang


def merge_config(config):
    """
    Merge config into global config.

    A key without dots replaces the top-level entry. If both the new value and
    the existing entry are dicts, the existing entry is ``update``-d instead.
    A dotted key such as ``Global.epoch_num`` walks into nested dicts and sets
    the last component. Its first component must already exist.

    Args:
        config (dict): Config to be merged.
    Returns: global config
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            # Descend through the intermediate keys, then assign the leaf.
            for sub_key in sub_keys[1:-1]:
                cur = cur[sub_key]
            cur[sub_keys[-1]] = value


def loss_file(path):
    """Assert that ``path`` exists (file or directory)."""
    assert (
        os.path.exists(path)
    ), "There is no such file:{},Please do not forget to put in the specified file".format(
        path)


if __name__ == '__main__':
    FLAGS = ArgsParser().parse_args()
    merge_config(FLAGS.opt)
    save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
    if os.path.isfile(save_file_path):
        os.remove(save_file_path)

    # Explicit path overrides. Each path is relative to project_path and must exist.
    if FLAGS.train:
        global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
        train_label_path = os.path.join(project_path, FLAGS.train)
        loss_file(train_label_path)
    if FLAGS.val:
        global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
        eval_label_path = os.path.join(project_path, FLAGS.val)
        loss_file(eval_label_path)
    if FLAGS.dict:
        global_config['Global']['character_dict_path'] = FLAGS.dict
        dict_path = os.path.join(project_path, FLAGS.dict)
        loss_file(dict_path)
    if FLAGS.data_dir:
        global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
        global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
        data_dir = os.path.join(project_path, FLAGS.data_dir)
        loss_file(data_dir)

    with open(save_file_path, 'w') as f:
        yaml.dump(
            dict(global_config), f, default_flow_style=False, sort_keys=False)
    logging.info("Project path is :{}".format(project_path))
    logging.info("Train list path set to :{}".format(global_config['Train'][
        'dataset']['label_file_list'][0]))
    logging.info("Eval list path set to :{}".format(global_config['Eval'][
        'dataset']['label_file_list'][0]))
    logging.info("Dataset root path set to :{}".format(global_config['Eval'][
        'dataset']['data_dir']))
    logging.info("Dict path set to :{}".format(global_config['Global'][
        'character_dict_path']))
    logging.info("Config file set to :configs/rec/multi_language/{}".
                 format(save_file_path))