# generate_multi_language_configs.py
import yaml
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import os.path
import logging
logging.basicConfig(level=logging.INFO)

# Language codes accepted by -l/--language, mapped to a human-readable name.
# The code is also used to build the default dictionary, label-list and output
# paths for the generated config (see ArgsParser._set_language).
support_list = {
    'it': 'italian', 'xi': 'spanish', 'pu': 'portuguese', 'ru': 'russian', 'ar': 'arabic',
    'ta': 'tamil', 'ug': 'uyghur', 'fa': 'persian', 'ur': 'urdu', 'rs': 'serbian latin',
    'oc': 'occitan', 'rsc': 'serbian cyrillic', 'bg': 'bulgarian', 'uk': 'ukranian', 'be': 'belarusian',
    'te': 'telugu', 'ka': 'kannada', 'chinese_cht': 'chinese tradition', 'hi': 'hindi', 'mr': 'marathi',
    'ne': 'nepali',
}

# Template configuration every generated per-language config starts from.
# It is looked up relative to the current working directory.
BASE_CONFIG_FILE = "./rec_multi_language_lite_train.yml"

# If this assert is stripped (python -O), the open() below still fails with
# FileNotFoundError, so a missing template never goes unnoticed.
assert os.path.isfile(BASE_CONFIG_FILE), (
    "Loss basic configuration file rec_multi_language_lite_train.yml."
    "You can download it from "
    "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
)

# `with` closes the handle as soon as the YAML is parsed (the original leaked it).
# NOTE: yaml.Loader can construct arbitrary Python objects; acceptable only
# because this is a local template file -- do not point it at untrusted input.
with open(BASE_CONFIG_FILE, 'rb') as _base_config_file:
    global_config = yaml.load(_base_config_file, Loader=yaml.Loader)

# Repository root, computed as three levels above the working directory; this
# assumes the script is run from configs/rec/multi_language/.
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))


class ArgsParser(ArgumentParser):
    """Command-line parser for the multi-language config generator.

    Besides normal argparse parsing, ``parse_args`` post-processes two options:
    ``--opt`` values become a ``{dotted_key: value}`` dict, and ``--language``
    is validated and applied to the module-level ``global_config``
    (dictionary path, output dir, train/eval label files).
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            "-l", "--language", nargs='+', help="set language type, support {}".format(support_list))
        self.add_argument(
            "--train", type=str, help="you can use this command to change the train dataset default path")
        self.add_argument(
            "--val", type=str, help="you can use this command to change the eval dataset default path")
        self.add_argument(
            "--dict", type=str, help="you can use this command to change the dictionary default path")
        self.add_argument(
            "--data_dir", type=str, help="you can use this command to change the dataset default root path")

    def parse_args(self, argv=None):
        """Parse ``argv`` and post-process ``--opt`` and ``--language``.

        Returns:
            The argparse namespace, with ``opt`` converted to a dict and
            ``language`` reduced to the single selected language code.
        """
        args = super(ArgsParser, self).parse_args(argv)
        args.opt = self._parse_opt(args.opt)
        args.language = self._set_language(args.language)
        return args

    @staticmethod
    def _split_opt(opt):
        """Split one ``key=value`` option string at its FIRST ``=``.

        Splitting only once keeps values that themselves contain ``=`` intact,
        e.g. ``"Global.name=a=b"`` -> ``("Global.name", "a=b")``.

        Raises:
            ValueError: if ``opt`` contains no ``=``.
        """
        key, sep, value = opt.strip().partition('=')
        if not sep:
            raise ValueError(
                "option '{}' must have the form key=value".format(opt))
        return key, value

    def _parse_opt(self, opts):
        """Turn ``-o key=value ...`` strings into a ``{key: value}`` dict.

        Each value is parsed as YAML, so ``10`` becomes an int and ``[a, b]``
        a list. Returns an empty dict when no options were given.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            k, v = self._split_opt(s)
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config

    def _set_language(self, type):
        """Validate the requested language and point ``global_config`` at it.

        Args:
            type: list of language codes from ``-l/--language``, or None when
                the option was omitted. Only the first entry is used. (The name
                shadows the ``type`` builtin but is kept for signature
                compatibility.)

        Returns:
            The selected language code (``type[0]``).

        Raises:
            AssertionError: if no language was given, the code is not a key of
                ``support_list``, or its default dictionary file is missing
                under ``project_path``.
        """
        # Explicit raises instead of ``assert`` so validation still runs under
        # ``python -O``; AssertionError is kept so the failure type is unchanged.
        if not type:
            raise AssertionError(
                "please use -l or --language to choose language type")
        lang = type[0]
        if lang not in support_list:
            raise AssertionError(
                "the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, "
                "please check your running command".format(support_list, type))
        if len(type) > 1:
            # --language accepts several values, but only one config is generated.
            logging.warning(
                'Only the first language "%s" is used; ignoring %s', lang, type[1:])

        # Default per-language paths, all derived from the language code.
        global_config['Global']['character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(lang)
        global_config['Global']['save_model_dir'] = './output/rec_{}_lite'.format(lang)
        global_config['Train']['dataset']['label_file_list'] = ["train_data/{}_train.txt".format(lang)]
        global_config['Eval']['dataset']['label_file_list'] = ["train_data/{}_val.txt".format(lang)]

        dict_file = os.path.join(project_path, global_config['Global']['character_dict_path'])
        if not os.path.isfile(dict_file):
            raise AssertionError(
                "Loss default dictionary file {}_dict.txt.You can download it from "
                "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(lang))
        return lang


def merge_config(config, base_config=None):
    """
    Merge config into global config.

    Plain keys: a dict value is merged (``dict.update``) into an existing
    section of the same name; any other value creates or replaces the key.
    Dotted keys such as ``Train.dataset.data_dir`` walk the nested dicts and
    set the final segment.

    Args:
        config (dict): Config to be merged, e.g. the parsed ``-o`` options.
        base_config (dict, optional): Config to merge into, modified in place.
            Defaults to the module-level ``global_config``.
    Returns: the merged config (``base_config`` or ``global_config``)
    Raises:
        AssertionError: if the first segment of a dotted key is not a
            top-level key of the target config.
        KeyError: if an intermediate segment of a dotted key does not exist.
    """
    target = global_config if base_config is None else base_config
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in target:
                target[key].update(value)
            else:
                target[key] = value
            continue

        section, *middle, leaf = key.split('.')
        # Explicit raise so the check survives ``python -O``.
        if section not in target:
            raise AssertionError(
                "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                    target.keys(), section))
        cur = target[section]
        for sub_key in middle:
            cur = cur[sub_key]
        cur[leaf] = value
    return target


def loss_file(path):
    """Emit a logging warning when ``path`` does not exist on disk.

    Used after a command-line path override so the user is reminded to put
    the referenced file in place. Returns None in every case.
    """
    if os.path.exists(path):
        return
    reminder = 'There is no such file:{},Please do not forget to put in the specified file'.format(path)
    logging.warning(reminder)


if __name__ == '__main__':
    FLAGS = ArgsParser().parse_args()
    merge_config(FLAGS.opt)

    # The dedicated path flags are applied after ``-o`` overrides, so they take
    # precedence. Each one warns (via loss_file) when the path, resolved
    # against project_path, does not exist yet.
    if FLAGS.train:
        global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
        train_label_path = os.path.join(project_path, FLAGS.train)
        loss_file(train_label_path)
    if FLAGS.val:
        global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
        eval_label_path = os.path.join(project_path, FLAGS.val)
        loss_file(eval_label_path)
    if FLAGS.dict:
        global_config['Global']['character_dict_path'] = FLAGS.dict
        dict_path = os.path.join(project_path, FLAGS.dict)
        loss_file(dict_path)
    if FLAGS.data_dir:
        global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
        global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
        data_dir = os.path.join(project_path, FLAGS.data_dir)
        loss_file(data_dir)

    # Write the generated config next to the template, replacing any previous one.
    save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
    if os.path.isfile(save_file_path):
        os.remove(save_file_path)
    with open(save_file_path, 'w') as f:
        yaml.dump(dict(global_config), f, default_flow_style=False, sort_keys=False)

    logging.info("Project path is          :{}".format(project_path))
    logging.info("Train list path set to   :{}".format(global_config['Train']['dataset']['label_file_list'][0]))
    logging.info("Eval list path set to    :{}".format(global_config['Eval']['dataset']['label_file_list'][0]))
    logging.info("Dataset root path set to :{}".format(global_config['Eval']['dataset']['data_dir']))
    logging.info("Dict path set to         :{}".format(global_config['Global']['character_dict_path']))
    logging.info("Config file set to       :configs/rec/multi_language/{}".format(save_file_path))