# generate_multi_language_configs.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import yaml
from argparse import ArgumentParser, RawDescriptionHelpFormatter
只会git clone的程序员's avatar
只会git clone的程序员 已提交
17
import os.path
只会git clone的程序员's avatar
只会git clone的程序员 已提交
18 19
import logging
logging.basicConfig(level=logging.INFO)
20 21

# Languages this generator can produce a recognition config for.
# The key is the short code accepted by -l/--language and is substituted
# verbatim into the generated dictionary, label-file and output paths, so keys
# must not be renamed. Within this file the value is only a display name shown
# in the --language help text and in the "unsupported language" error.
support_list = {
    'it': 'italian',
    'es': 'spanish',
    'pt': 'portuguese',
    'ru': 'russian',
    'ar': 'arabic',
    'ta': 'tamil',
    'ug': 'uyghur',
    'fa': 'persian',
    'ur': 'urdu',
    'rs': 'serbian latin',
    'oc': 'occitan',
    'rsc': 'serbian cyrillic',
    'bg': 'bulgarian',
    'uk': 'ukrainian',
    'be': 'belarusian',
    'te': 'telugu',
    'ka': 'kannada',
    'chinese_cht': 'chinese tradition',
    'hi': 'hindi',
    'mr': 'marathi',
    'ne': 'nepali',
}
T
tink2123 已提交
44 45
# The generator always starts from the shared base recipe, which must be
# present in the current working directory.
# Raised explicitly instead of via `assert` so that the check is not stripped
# when Python runs with -O; the exception type and message are unchanged.
if not os.path.isfile("./rec_multi_language_lite_train.yml"):
    raise AssertionError(
        "Loss basic configuration file rec_multi_language_lite_train.yml."
        "You can download it from "
        "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
    )

# Module-level config dict that the rest of this script mutates in place.
# NOTE(review): yaml.Loader can construct arbitrary Python objects. That is
# acceptable only as long as the base recipe is a trusted local file; consider
# yaml.safe_load if the recipe never relies on Python-specific tags.
with open("./rec_multi_language_lite_train.yml", 'rb') as _base_config_file:
    global_config = yaml.load(_base_config_file, Loader=yaml.Loader)

# Three directories above the current working directory; dataset, label and
# dictionary paths are resolved against it when their existence is checked.
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
52

T
tink2123 已提交
53

54 55 56 57 58 59 60
class ArgsParser(ArgumentParser):
    """Command-line parser for the multi-language config generator.

    Parsing is not side-effect free: ``parse_args`` rewrites several entries
    of the module-level ``global_config`` according to the selected language
    (see ``_set_language``).
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            "-l",
            "--language",
            nargs='+',
            help="set language type, support {}".format(support_list))
        self.add_argument(
            "--train",
            type=str,
            help="you can use this command to change the train dataset default path"
        )
        self.add_argument(
            "--val",
            type=str,
            help="you can use this command to change the eval dataset default path"
        )
        self.add_argument(
            "--dict",
            type=str,
            help="you can use this command to change the dictionary default path"
        )
        self.add_argument(
            "--data_dir",
            type=str,
            help="you can use this command to change the dataset default root path"
        )

    def parse_args(self, argv=None):
        """Parse ``argv`` and post-process ``opt`` and ``language``.

        Args:
            argv: Argument list forwarded to ``ArgumentParser.parse_args``;
                ``None`` lets argparse read the process arguments.

        Returns:
            The argparse namespace, with ``opt`` replaced by a dict of parsed
            ``key=value`` overrides and ``language`` replaced by the single
            selected language code.
        """
        args = super(ArgsParser, self).parse_args(argv)
        args.opt = self._parse_opt(args.opt)
        args.language = self._set_language(args.language)
        return args

    def _parse_opt(self, opts):
        """Turn ``["key=value", ...]`` into ``{key: parsed_value}``.

        Args:
            opts: List of ``key=value`` strings, or a falsy value when the
                ``-o`` option was not given.

        Returns:
            dict: Keys as written on the command line; values parsed as YAML.
            Empty when ``opts`` is falsy.

        Raises:
            ValueError: If an entry contains no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split on the first '=' only, so that a value which itself
            # contains '=' is kept intact instead of failing to unpack.
            k, v = s.split('=', 1)
            # NOTE(review): yaml.Loader can construct arbitrary Python
            # objects; this is fed from the user's own command line here.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config

    def _set_language(self, type):
        """Validate the requested language and apply its default paths.

        Only the first element of ``type`` is used. The dictionary path,
        model output directory, train/eval label lists and character type in
        ``global_config`` are overwritten with values derived from it.

        Args:
            type: List of language codes as collected by ``--language``.

        Returns:
            The selected language code (``type[0]``).

        Raises:
            AssertionError: If no language was given, the language is not a
                key of ``support_list``, or its default dictionary file does
                not exist under ``project_path``.
        """
        # Explicit raises (rather than `assert`) keep these checks active
        # under `python -O` while preserving the AssertionError type.
        if not type:
            raise AssertionError(
                "please use -l or --language to choose language type")
        lang = type[0]
        if lang not in support_list:
            raise AssertionError(
                "the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, "
                "please check your running command".format(support_list, type))

        global_config['Global'][
            'character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(lang)
        global_config['Global'][
            'save_model_dir'] = './output/rec_{}_lite'.format(lang)
        global_config['Train']['dataset'][
            'label_file_list'] = ["train_data/{}_train.txt".format(lang)]
        global_config['Eval']['dataset'][
            'label_file_list'] = ["train_data/{}_val.txt".format(lang)]
        global_config['Global']['character_type'] = lang

        # The default dictionary is looked up relative to project_path.
        dict_file = os.path.join(
            project_path, global_config['Global']['character_dict_path'])
        if not os.path.isfile(dict_file):
            raise AssertionError(
                "Loss default dictionary file {}_dict.txt.You can download it from "
                "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(
                    lang))
        return lang
126 127


只会git clone的程序员's avatar
只会git clone的程序员 已提交
128
def merge_config(config):
    """
    Merge config into global config.

    A key without a dot addresses a top-level entry: a dict value is merged
    into an existing entry with ``dict.update``, anything else replaces the
    entry. A dotted key such as ``"A.b.c"`` walks ``global_config["A"]["b"]``
    and assigns to its ``"c"`` item.

    Args:
        config (dict): Config to be merged.
    Returns: global config
    Raises:
        AssertionError: If the first component of a dotted key is not a
            top-level key of the global config.
        KeyError: If an intermediate component of a dotted key is missing.
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
            continue

        sub_keys = key.split('.')
        # Explicit raise (rather than `assert`) so the check survives
        # `python -O`; exception type and message are unchanged.
        if sub_keys[0] not in global_config:
            raise AssertionError(
                "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                    global_config.keys(), sub_keys[0]))
        cur = global_config[sub_keys[0]]
        # Descend through every intermediate component, then assign the leaf.
        for sub_key in sub_keys[1:-1]:
            cur = cur[sub_key]
        cur[sub_keys[-1]] = value
    return global_config
T
tink2123 已提交
153 154


只会git clone的程序员's avatar
只会git clone的程序员 已提交
155
def loss_file(path):
    """Abort if ``path`` does not exist on disk.

    Args:
        path: File or directory path to check with ``os.path.exists``.

    Raises:
        AssertionError: If ``path`` does not exist.
    """
    # Raised explicitly instead of via `assert`: this check is the function's
    # whole purpose and must not vanish when Python runs with -O.
    if not os.path.exists(path):
        raise AssertionError(
            "There is no such file:{},Please do not forget to put in the specified file".format(
                path))

161 162 163

if __name__ == '__main__':
    # Parsing also writes the language defaults into global_config; the -o
    # overrides are merged on top of those defaults.
    FLAGS = ArgsParser().parse_args()
    merge_config(FLAGS.opt)

    # Drop any previously generated config for this language.
    save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
    if os.path.isfile(save_file_path):
        os.remove(save_file_path)

    # Explicit path flags take precedence over the defaults and -o values.
    # Each one is stored as given, then checked for existence relative to
    # project_path.
    if FLAGS.train:
        global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
        loss_file(os.path.join(project_path, FLAGS.train))
    if FLAGS.val:
        global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
        loss_file(os.path.join(project_path, FLAGS.val))
    if FLAGS.dict:
        global_config['Global']['character_dict_path'] = FLAGS.dict
        loss_file(os.path.join(project_path, FLAGS.dict))
    if FLAGS.data_dir:
        global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
        global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
        loss_file(os.path.join(project_path, FLAGS.data_dir))

    # Write the final config, keeping the key order of the base recipe.
    with open(save_file_path, 'w') as out_file:
        yaml.dump(
            dict(global_config),
            out_file,
            default_flow_style=False,
            sort_keys=False)

    # Report the effective settings.
    logging.info("Project path is          :{}".format(project_path))
    train_list = global_config['Train']['dataset']['label_file_list'][0]
    logging.info("Train list path set to   :{}".format(train_list))
    eval_list = global_config['Eval']['dataset']['label_file_list'][0]
    logging.info("Eval list path set to    :{}".format(eval_list))
    data_root = global_config['Eval']['dataset']['data_dir']
    logging.info("Dataset root path set to :{}".format(data_root))
    dict_file = global_config['Global']['character_dict_path']
    logging.info("Dict path set to         :{}".format(dict_file))
    logging.info("Config file set to       :configs/rec/multi_language/{}".
                 format(save_file_path))