train.py 10.9 KB
Newer Older
C
chenguowei01 已提交
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
C
chenguowei01 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle.fluid as fluid
C
chenguowei01 已提交
19 20
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
C
chenguowei01 已提交
21
from paddle.incubate.hapi.distributed import DistributedBatchSampler
C
chenguowei01 已提交
22

C
chenguowei01 已提交
23
from datasets import OpticDiscSeg, Cityscapes
C
chenguowei01 已提交
24 25 26 27
import transforms as T
import models
import utils.logging as logging
from utils import get_environ_info
C
chenguowei01 已提交
28
from utils import load_pretrained_model
C
chenguowei01 已提交
29
from utils import resume
C
chenguowei01 已提交
30
from utils import Timer, calculate_eta
C
chenguowei01 已提交
31
from val import evaluate
C
chenguowei01 已提交
32 33 34 35 36 37


def parse_args(argv=None):
    """Parse the command-line options for model training.

    Args:
        argv (list[str], optional): Arguments to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` calls behave exactly as before.

    Returns:
        argparse.Namespace: The parsed training options.
    """
    parser = argparse.ArgumentParser(description='Model training')

    # params of model
    parser.add_argument(
        '--model_name',
        dest='model_name',
        help="Model type for training, which is one of ('UNet')",
        type=str,
        default='UNet')

    # params of dataset
    parser.add_argument(
        '--dataset',
        dest='dataset',
        help=
        "The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
        type=str,
        default='OpticDiscSeg')

    # params of training
    parser.add_argument(
        "--input_size",
        dest="input_size",
        help="The image size for net inputs.",
        nargs=2,
        default=[512, 512],
        type=int)
    parser.add_argument(
        '--num_epochs',
        dest='num_epochs',
        help='Number epochs for training',
        type=int,
        default=100)
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size of one gpu or cpu',
        type=int,
        default=2)
    parser.add_argument(
        '--learning_rate',
        dest='learning_rate',
        help='Learning rate',
        type=float,
        default=0.01)
    parser.add_argument(
        '--pretrained_model',
        dest='pretrained_model',
        help='The path of pretrained model',
        type=str,
        default=None)
    parser.add_argument(
        '--resume_model',
        dest='resume_model',
        help='The path of resume model',
        type=str,
        default=None)
    parser.add_argument(
        '--save_interval_epochs',
        dest='save_interval_epochs',
        help='The interval epochs for save a model snapshot',
        type=int,
        default=5)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the model snapshot',
        type=str,
        default='./output')
    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='Num workers for data loader',
        type=int,
        default=0)
    parser.add_argument(
        '--do_eval',
        dest='do_eval',
        help='Eval while training',
        action='store_true')
    parser.add_argument(
        '--log_steps',
        dest='log_steps',
        help='Display logging information at every log_steps',
        default=10,
        type=int)
    parser.add_argument(
        '--use_vdl',
        dest='use_vdl',
        help='Whether to record the data during training to VisualDL',
        action='store_true')

    return parser.parse_args(argv)


def train(model,
          train_dataset,
          places=None,
          eval_dataset=None,
          optimizer=None,
          save_dir='output',
          num_epochs=100,
          batch_size=2,
          pretrained_model=None,
          resume_model=None,
          save_interval_epochs=1,
          log_steps=10,
          num_classes=None,
          num_workers=8,
          use_vdl=False):
    """Train ``model`` on ``train_dataset`` in dygraph mode.

    When more than one rank is running, the model is wrapped in
    ``fluid.dygraph.DataParallel``. Only local rank 0 logs, saves
    snapshots, evaluates and writes VisualDL records.

    Args:
        model: Dygraph model exposing ``ignore_index``, ``state_dict()``,
            ``clear_gradients()`` and ``train()``, and callable as
            ``model(images, labels, mode='train')`` returning the loss.
        train_dataset: Dataset fed to the distributed batch sampler.
        places: Device place(s) for the data loader and for ``evaluate``.
        eval_dataset: Optional dataset evaluated whenever a snapshot is
            saved; the snapshot with the best mean IoU is also saved to
            ``<save_dir>/best_model``.
        optimizer: Dygraph optimizer (required). Its ``minimize``,
            ``current_step_lr`` and ``state_dict`` are used.
        save_dir (str): Output directory for snapshots, the best model and
            VisualDL logs. If a non-directory file exists at this path it
            is removed.
        num_epochs (int): Epoch index to train up to (resumed epochs count).
        batch_size (int): Mini-batch size per card.
        pretrained_model (str): Weights to load when not resuming.
        resume_model (str): Snapshot to resume from; takes precedence over
            ``pretrained_model``. The start epoch comes from ``resume``.
        save_interval_epochs (int): Save a snapshot every N epochs, and
            always after the last epoch.
        log_steps (int): Log the averaged loss every N steps.
        num_classes (int): Forwarded to ``evaluate``.
        num_workers (int): Number of data-loader workers.
        use_vdl (bool): Record loss/lr/metrics with VisualDL.

    Raises:
        ValueError: If ``optimizer`` is None or ``log_steps`` is not a
            positive integer.
    """
    # Fail fast instead of crashing after the first forward pass.
    if optimizer is None:
        raise ValueError('An optimizer is required for training.')
    if log_steps <= 0:
        raise ValueError(
            'log_steps should be a positive integer, but got {}'.format(
                log_steps))

    ignore_index = model.ignore_index
    nranks = ParallelEnv().nranks
    local_rank = ParallelEnv().local_rank

    start_epoch = 0
    if resume_model is not None:
        start_epoch = resume(model, optimizer, resume_model)
    elif pretrained_model is not None:
        load_pretrained_model(model, pretrained_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        # exist_ok: with several ranks, another process may create the
        # directory between the isdir() check above and this call.
        os.makedirs(save_dir, exist_ok=True)

    if nranks > 1:
        strategy = fluid.dygraph.prepare_context()
        model_parallel = fluid.dygraph.DataParallel(model, strategy)

    batch_sampler = DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    loader = DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        places=places,
        num_workers=num_workers,
        return_list=True,
    )

    # Only rank 0 ever writes records, so only rank 0 opens a writer.
    log_writer = None
    if use_vdl and local_rank == 0:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    try:
        timer = Timer()
        timer.start()
        avg_loss = 0.0
        steps_per_epoch = len(batch_sampler)
        total_steps = steps_per_epoch * (num_epochs - start_epoch)
        num_steps = 0
        best_mean_iou = -1.0
        best_model_epoch = 1
        for epoch in range(start_epoch, num_epochs):
            for step, data in enumerate(loader):
                images = data[0]
                labels = data[1].astype('int64')
                if nranks > 1:
                    loss = model_parallel(images, labels, mode='train')
                    loss = model_parallel.scale_loss(loss)
                    loss.backward()
                    model_parallel.apply_collective_grads()
                else:
                    loss = model(images, labels, mode='train')
                    loss.backward()
                optimizer.minimize(loss)
                model.clear_gradients()
                avg_loss += loss.numpy()[0]
                lr = optimizer.current_step_lr()
                num_steps += 1
                if num_steps % log_steps == 0:
                    if local_rank == 0:
                        avg_loss /= log_steps
                        time_step = timer.elapsed_time() / log_steps
                        remain_steps = total_steps - num_steps
                        logging.info(
                            "[TRAIN] Epoch={}/{}, Step={}/{}, loss={:.4f}, lr={:.6f}, sec/step={:.4f} | ETA {}"
                            .format(epoch + 1, num_epochs, step + 1,
                                    steps_per_epoch, avg_loss, lr, time_step,
                                    calculate_eta(remain_steps, time_step)))
                        if log_writer is not None:
                            log_writer.add_scalar('Train/loss', avg_loss,
                                                  num_steps)
                            log_writer.add_scalar('Train/lr', lr, num_steps)
                    # Reset on every rank so the running sum stays bounded.
                    avg_loss = 0.0
                    timer.restart()

            if ((epoch + 1) % save_interval_epochs == 0
                    or epoch + 1 == num_epochs) and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "epoch_{}".format(epoch + 1))
                os.makedirs(current_save_dir, exist_ok=True)
                # Model and optimizer state share the 'model' prefix;
                # presumably save_dygraph gives them different file
                # suffixes — confirm against the installed Paddle version.
                fluid.save_dygraph(model.state_dict(),
                                   os.path.join(current_save_dir, 'model'))
                fluid.save_dygraph(optimizer.state_dict(),
                                   os.path.join(current_save_dir, 'model'))

                if eval_dataset is not None:
                    mean_iou, mean_acc = evaluate(
                        model,
                        eval_dataset,
                        places=places,
                        model_dir=current_save_dir,
                        num_classes=num_classes,
                        batch_size=batch_size,
                        ignore_index=ignore_index,
                        epoch_id=epoch + 1)
                    if mean_iou > best_mean_iou:
                        best_mean_iou = mean_iou
                        best_model_epoch = epoch + 1
                        best_model_dir = os.path.join(save_dir, "best_model")
                        os.makedirs(best_model_dir, exist_ok=True)
                        fluid.save_dygraph(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model'))
                        logging.info(
                            'Current evaluated best model in eval_dataset is epoch_{}, miou={:.4f}'
                            .format(best_model_epoch, best_mean_iou))

                    if log_writer is not None:
                        log_writer.add_scalar('Evaluate/mean_iou', mean_iou,
                                              epoch + 1)
                        log_writer.add_scalar('Evaluate/mean_acc', mean_acc,
                                              epoch + 1)
                    # Switch back to training mode after evaluation.
                    model.train()
    finally:
        # Flush and close the VisualDL writer even if training fails.
        if log_writer is not None:
            log_writer.close()


def main(args):
    """Build the datasets, model and optimizer from ``args`` and train.

    Args:
        args (argparse.Namespace): Options produced by ``parse_args``.

    Raises:
        ValueError: If ``args.dataset`` or ``args.model_name`` names an
            unsupported dataset or model.
    """
    env_info = get_environ_info()
    places = fluid.CUDAPlace(ParallelEnv().dev_id) \
        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
        else fluid.CPUPlace()

    dataset_name = args.dataset.lower()
    if dataset_name == 'opticdiscseg':
        dataset = OpticDiscSeg
    elif dataset_name == 'cityscapes':
        dataset = Cityscapes
    else:
        raise ValueError(
            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
        )

    # Validate up front: an unknown model name used to leave `model`
    # unbound and fail later with an UnboundLocalError.
    if args.model_name != 'UNet':
        raise ValueError(
            "The --model_name set wrong. It should be one of ('UNet')")

    with fluid.dygraph.guard(places):
        # Create dataset reader
        train_transforms = T.Compose([
            T.Resize(args.input_size),
            T.RandomHorizontalFlip(),
            T.Normalize()
        ])
        train_dataset = dataset(transforms=train_transforms, mode='train')

        eval_dataset = None
        if args.do_eval:
            eval_transforms = T.Compose(
                [T.Resize(args.input_size),
                 T.Normalize()])
            eval_dataset = dataset(transforms=eval_transforms, mode='eval')

        model = models.UNet(
            num_classes=train_dataset.num_classes, ignore_index=255)

        # Create optimizer
        # Match the per-rank length of DistributedBatchSampler with
        # drop_last=True (as used in train()): assumes each rank receives
        # ceil(len / nranks) samples and keeps only full batches — confirm
        # against the installed Paddle version. This keeps the polynomial
        # decay from reaching 0 before the last step.
        nranks = ParallelEnv().nranks
        samples_per_rank = (len(train_dataset) + nranks - 1) // nranks
        num_steps_each_epoch = samples_per_rank // args.batch_size
        decay_step = args.num_epochs * num_steps_each_epoch
        lr_decay = fluid.layers.polynomial_decay(
            args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
        optimizer = fluid.optimizer.Momentum(
            lr_decay,
            momentum=0.9,
            parameter_list=model.parameters(),
            regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))

        train(
            model,
            train_dataset,
            places=places,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_dir=args.save_dir,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            pretrained_model=args.pretrained_model,
            resume_model=args.resume_model,
            save_interval_epochs=args.save_interval_epochs,
            log_steps=args.log_steps,
            num_classes=train_dataset.num_classes,
            num_workers=args.num_workers,
            use_vdl=args.use_vdl)


if __name__ == '__main__':
    # Script entry point: read the command line, then launch training.
    main(parse_args())