# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os, sys
# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

# ignore numba warning
import warnings
warnings.filterwarnings('ignore')
import random
import datetime
import time
import numpy as np

import paddle
from paddle.distributed import ParallelEnv

from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model

import ppdet.utils.cli as cli
import ppdet.utils.check as check
import ppdet.utils.stats as stats
from ppdet.utils.logger import setup_logger
logger = setup_logger('train')
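
# Typical invocation (a sketch, not the authoritative CLI: `-c/--config` and
# `-o/--opt` come from ppdet.utils.cli.ArgsParser, the remaining flags are
# defined in parse_args() below, and the config path is only illustrative):
#   python tools/train.py -c configs/yolov3_darknet.yml --use_gpu --eval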


def parse_args():
    parser = cli.ArgsParser()
    parser.add_argument(
        "--weight_type",
        default='pretrain',
        type=str,
        help="Loading checkpoints supports only 'pretrain', 'finetune', or 'resume'."
    )
    parser.add_argument(
        "--fp16",
        action='store_true',
        default=False,
        help="Enable mixed precision training.")
    parser.add_argument(
        "--loss_scale",
        default=8.,
        type=float,
        help="Mixed precision training loss scale.")
    parser.add_argument(
        "--eval",
        action='store_true',
        default=False,
        help="Whether to perform evaluation during training.")
    parser.add_argument(
        "--output_eval",
        default=None,
        type=str,
        help="Evaluation directory, default is current directory.")
    parser.add_argument(
        "--enable_ce",
        action='store_true',
        default=False,
        help="If set, enable continuous evaluation job. "
        "This flag is only used for internal tests.")
    parser.add_argument(
        "--use_gpu",
        action='store_true',
        default=False,
        help="Whether to use GPU for training.")
    args = parser.parse_args()
    return args


def run(FLAGS, cfg, place):
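    # `paddle.distributed.launch` exports PADDLE_TRAINER_ID and
    # PADDLE_TRAINERS_NUM; use them to detect a multi-trainer run and derive a
    # per-trainer seed so that shuffling differs across ranks.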
    env = os.environ
    FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
    if FLAGS.dist:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        local_seed = (99 + trainer_id)
        random.seed(local_seed)
        np.random.seed(local_seed)

    if FLAGS.enable_ce:
        random.seed(0)
        np.random.seed(0)

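    # Initialize the collective communication context (e.g. NCCL) before the
    # model is built.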
    if ParallelEnv().nranks > 1:
        paddle.distributed.init_parallel_env()

    # Data 
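    # create() instantiates components registered in the ppdet config workspace.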
    datasets = cfg.TrainDataset
    train_loader = create('TrainReader')(datasets, cfg['worker_num'])
    steps = len(train_loader)

    # Model
    model = create(cfg.architecture)

    # Optimizer
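    # The LearningRate component receives steps-per-epoch so that epoch-based
    # decay points in the config can be converted into iteration counts.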
    lr = create('LearningRate')(steps)
    optimizer = create('OptimizerBuilder')(lr, model.parameters())

    # Init Model & Optimizer
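    # 'resume' restores model and optimizer state and returns the epoch to
    # continue from; 'pretrain' / 'finetune' load model weights only.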
    start_epoch = 0
    if FLAGS.weight_type == 'resume':
        start_epoch = load_weight(model, cfg.pretrain_weights, optimizer)
    else:
        load_pretrain_weight(model, cfg.pretrain_weights,
                             cfg.get('load_static_weights', False),
                             FLAGS.weight_type)

    if getattr(model.backbone, 'norm_type', None) == 'sync_bn':
        assert cfg.use_gpu and ParallelEnv().nranks > 1, \
            'you should use bn rather than sync_bn while using a single gpu'
    # sync_bn = (getattr(model.backbone, 'norm_type', None) == 'sync_bn' and
    #            cfg.use_gpu and ParallelEnv().nranks > 1)
    # if sync_bn:
    #     model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Parallel Model 
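    # DataParallel wraps the model so gradients are all-reduced across trainers.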
    if ParallelEnv().nranks > 1:
        model = paddle.DataParallel(model)

    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)

    # Train Setup
    end_epoch = int(cfg.epoch)
    batch_size = int(cfg['TrainReader']['batch_size'])
    total_steps = (end_epoch - start_epoch) * steps
    step_id = 0

    train_stats = stats.TrainingStats(cfg.log_iter)
    batch_time = stats.SmoothedValue(fmt='{avg:.4f}')
    data_time = stats.SmoothedValue(fmt='{avg:.4f}')

    end_time = time.time()
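    # Pad the logged iteration index to the width of the largest step count.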
    space_fmt = ':' + str(len(str(steps))) + 'd'
    # Run Train
    for cur_eid in range(start_epoch, end_epoch):
        datasets.set_epoch(cur_eid)
        for iter_id, data in enumerate(train_loader):
            data_time.update(time.time() - end_time)
            # Model Forward
            model.train()
            outputs = model(data, mode='train')
            loss = outputs['loss']
            # Model Backward
            loss.backward()
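            # Apply gradients, advance the LR schedule, then clear gradients
            # manually (Paddle accumulates them across steps by default).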
            optimizer.step()
            curr_lr = optimizer.get_lr()
            lr.step()
            optimizer.clear_grad()

            batch_time.update(time.time() - end_time)
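            # Only the master rank accumulates stats and emits log lines.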
            if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
                train_stats.update(outputs)
                logs = train_stats.log()
                if iter_id % cfg.log_iter == 0:
                    eta_sec = (total_steps - step_id) * batch_time.global_avg
                    eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
                    ips = float(batch_size) / batch_time.avg
                    fmt = ' '.join([
                        'Epoch: [{}]',
                        '[{' + space_fmt + '}/{}]',
                        '{meters}',
                        'eta: {eta}',
                        'batch_cost: {btime}',
                        'data_cost: {dtime}',
                        'ips: {ips:.4f} images/s',
                    ])
                    fmt = fmt.format(
                        cur_eid,
                        iter_id,
                        steps,
                        meters=logs,
                        eta=eta_str,
                        btime=str(batch_time),
                        dtime=str(data_time),
                        ips=ips)
                    logger.info(fmt)
            step_id += 1
            end_time = time.time()  # reset timer after outputs are copied to CPU.
        # Save Stage 
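        # Rank 0 saves a checkpoint every `snapshot_epoch` epochs; the final
        # epoch is always saved as "model_final".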
        if (ParallelEnv().local_rank == 0 and
                (cur_eid % cfg.snapshot_epoch) == 0) or (cur_eid + 1) == end_epoch:
            save_name = str(
                cur_eid) if cur_eid + 1 != end_epoch else "model_final"
            save_model(model, optimizer, save_dir, save_name, cur_eid + 1)


def main():
    FLAGS = parse_args()

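    # Load the YAML config, then apply any command-line overrides (FLAGS.opt).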
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check.check_config(cfg)
    check.check_gpu(cfg.use_gpu)
    check.check_version()

    place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
    place = paddle.set_device(place)

    run(FLAGS, cfg, place)


if __name__ == "__main__":
    main()