# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# Make this script's directory importable, and put its parent directory
# first on the search path so the ``ppocr`` and ``tools`` packages imported
# below resolve when the script is launched directly.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, os.pardir)))

# These imports must stay below the sys.path setup above.
import yaml
import paddle
import paddle.distributed as dist

from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model, apply_to_static
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed
import tools.program as program

# NOTE(review): the return value is discarded, so this looks vestigial; it is
# kept so import-time behaviour is unchanged — confirm before removing.
dist.get_world_size()


def _set_sar_ignore_index(loss_entry, loss_name, ignore_index):
    """Write ``ignore_index`` into one ``{loss_name: params}`` loss-list item.

    A ``None`` params value (a loss listed with no options) is replaced by a
    fresh dict instead of raising ``TypeError``.

    Raises:
        ValueError: if the item's first key is not ``loss_name``.
    """
    first_key = next(iter(loss_entry), None)
    if first_key != loss_name:
        raise ValueError(
            "expected loss entry '{}' in Loss.loss_config_list, got '{}'".
            format(loss_name, first_key))
    if loss_entry[loss_name] is None:
        loss_entry[loss_name] = {'ignore_index': ignore_index}
    else:
        loss_entry[loss_name]['ignore_index'] = ignore_index


def _configure_rec_head(config, post_process_class):
    """Size recognition heads (and SAR loss settings) from the charset.

    Only acts when ``post_process_class`` exposes a ``character`` sequence
    (recognition post-processors); its length drives the head output
    channels. ``config`` is modified in place.

    Args:
        config (dict): full training config; ``Architecture``, ``Loss`` and
            ``PostProcess`` sections are read and updated.
        post_process_class: object returned by ``build_post_process``.

    Raises:
        ValueError: if a MultiHead model is configured but the SAR loss entry
            is not at the position this function writes ``ignore_index`` to.
    """
    if not hasattr(post_process_class, 'character'):
        return
    char_num = len(post_process_class.character)
    arch_config = config['Architecture']
    post_process_name = config['PostProcess']['name']

    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
        for key in arch_config["Models"]:
            head_config = arch_config['Models'][key]['Head']
            if head_config['name'] != 'MultiHead':
                head_config['out_channels'] = char_num
                continue
            # Per-model count: decrementing the shared ``char_num`` here would
            # compound across several MultiHead sub-models and leak into the
            # sub-models iterated after them.
            model_char_num = char_num
            if post_process_name == 'DistillationSARLabelDecode':
                model_char_num = char_num - 2
            _set_sar_ignore_index(config['Loss']['loss_config_list'][-1],
                                  'DistillationSARLoss', model_char_num + 1)
            head_config['out_channels_list'] = {
                'CTCLabelDecode': model_char_num,
                'SARLabelDecode': model_char_num + 2,
            }
    elif arch_config['Head']['name'] == 'MultiHead':  # for multi head
        if post_process_name == 'SARLabelDecode':
            char_num = char_num - 2
        _set_sar_ignore_index(config['Loss']['loss_config_list'][1], 'SARLoss',
                              char_num + 1)
        arch_config['Head']['out_channels_list'] = {
            'CTCLabelDecode': char_num,
            'SARLabelDecode': char_num + 2,
        }
    else:  # base rec model
        arch_config["Head"]['out_channels'] = char_num

    if post_process_name == 'SARLabelDecode':  # for SAR model
        # Uses char_num after any MultiHead adjustment above (as before).
        config['Loss']['ignore_index'] = char_num - 1


def _setup_amp(global_config, model, optimizer, amp_level):
    """Prepare mixed-precision training when ``Global.use_amp`` is enabled.

    Args:
        global_config (dict): the ``Global`` config section.
        model: the network to train.
        optimizer: optimizer built for ``model``.
        amp_level (str): AMP level; ``"O2"`` additionally decorates the model
            and optimizer with master weights.

    Returns:
        tuple: ``(model, optimizer, scaler)``. When AMP is disabled the inputs
        come back unchanged and ``scaler`` is ``None``.
    """
    if not global_config.get("use_amp", False):
        return model, optimizer, None
    AMP_RELATED_FLAGS_SETTING = {
        'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
        'FLAGS_max_inplace_grad_add': 8,
    }
    # ``paddle.fluid`` is deprecated and absent from recent Paddle releases;
    # ``paddle.set_flags`` is the public equivalent.
    paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
    scaler = paddle.amp.GradScaler(
        init_loss_scaling=global_config.get("scale_loss", 1.0),
        use_dynamic_loss_scaling=global_config.get("use_dynamic_loss_scaling",
                                                   False))
    if amp_level == "O2":
        model, optimizer = paddle.amp.decorate(
            models=model,
            optimizers=optimizer,
            level=amp_level,
            master_weight=True)
    return model, optimizer, scaler


def main(config, device, logger, vdl_writer):
    """Build the training components described by ``config`` and train.

    Builds, in order: dataloaders, post-process, model, loss, optimizer,
    metric and (optionally) AMP state. It then loads pretrained weights via
    ``load_model`` and hands everything to ``program.train``. Returns early,
    after logging an error, when the training dataloader is empty.

    Args:
        config (dict): parsed training config (``Global``, ``Architecture``,
            ``Loss``, ``Optimizer``, ``PostProcess``, ``Metric``, ``Train``
            and an optional ``Eval`` section).
        device: device handle forwarded to the dataloaders and trainer.
        logger: logger used for progress and error messages.
        vdl_writer: writer forwarded unchanged to ``program.train``
            (presumably a VisualDL writer or None — confirm in ``program``).
    """
    global_config = config['Global']
    # init dist environment
    if global_config['distributed']:
        dist.init_parallel_env()

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No Images in train dataset, please ensure\n"
            "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
            "\t2. The annotation file and path in the configuration file are provided normally."
        )
        return

    # The Eval section may be omitted entirely; treat that as "no eval".
    if config.get('Eval'):
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model (recognition heads are sized from the post-process charset)
    _configure_rec_head(config, post_process_class)
    model = build_model(config['Architecture'])

    if global_config.get("use_sync_bn", False):
        model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logger.info('convert_sync_batchnorm')

    model = apply_to_static(model, config, logger)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=global_config['epoch_num'],
        step_each_epoch=len(train_dataloader),
        model=model)

    # build metric
    eval_class = build_metric(config['Metric'])

    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    # amp_level and the black list are forwarded to the trainer even when
    # AMP itself is disabled.
    amp_level = global_config.get("amp_level", 'O2')
    amp_custom_black_list = global_config.get('amp_custom_black_list', [])
    model, optimizer, scaler = _setup_amp(global_config, model, optimizer,
                                          amp_level)

    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer,
                                     config['Architecture']["model_type"])

    if global_config['distributed']:
        model = paddle.DataParallel(model)
    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
                  amp_level, amp_custom_black_list)


def test_reader(config, device, logger):
    """Iterate the training dataloader once and log per-batch timings.

    Debugging helper for checking that the ``Train`` data pipeline can be
    built and consumed; it does not train anything. A failure while reading
    is logged as an error and the success message is not emitted.

    Args:
        config (dict): training config passed to ``build_dataloader``.
        device: device handle passed to ``build_dataloader``.
        logger: logger receiving the timing and status lines.
    """
    import time

    loader = build_dataloader(config, 'Train', device, logger)
    start_time = time.time()
    count = 0
    try:
        for data in loader():
            count += 1
            batch_time = time.time() - start_time
            start_time = time.time()
            logger.info("reader: {}, {}, {}".format(count, len(data[0]),
                                                    batch_time))
    except Exception as e:
        # Best-effort debug helper: report the failure instead of falling
        # through to the "Success!" message.
        logger.error("reader failed after {} batches: {}".format(count, e))
        return
    logger.info("finish reader: {}, Success!".format(count))


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    # Seed defaults to 1024 when Global.seed is not configured.
    seed = config['Global'].get('seed', 1024)
    set_seed(seed)
    main(config, device, logger, vdl_writer)
    # To exercise only the data pipeline, call
    # test_reader(config, device, logger) instead of main().