train.py 11.0 KB
Newer Older
C
chenguowei01 已提交
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
C
chenguowei01 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle.fluid as fluid
C
chenguowei01 已提交
19 20
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
C
chenguowei01 已提交
21
from paddle.incubate.hapi.distributed import DistributedBatchSampler
C
chenguowei01 已提交
22

C
chenguowei01 已提交
23
from datasets import OpticDiscSeg, Cityscapes
C
chenguowei01 已提交
24
import transforms as T
C
chenguowei01 已提交
25
from models import MODELS
C
chenguowei01 已提交
26 27
import utils.logging as logging
from utils import get_environ_info
C
chenguowei01 已提交
28
from utils import load_pretrained_model
C
chenguowei01 已提交
29
from utils import resume
C
chenguowei01 已提交
30
from utils import Timer, calculate_eta
C
chenguowei01 已提交
31
from val import evaluate
C
chenguowei01 已提交
32 33 34 35 36 37


def parse_args():
    """Define the training command-line options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options covering model/dataset selection,
        training hyper-parameters, checkpoint locations and logging switches.
    """
    parser = argparse.ArgumentParser(description='Model training')
    add = parser.add_argument

    # Model selection; the help text lists every registered model name.
    model_help = 'Model type for training, which is one of {}'.format(
        str(list(MODELS.keys())))
    add('--model_name',
        dest='model_name',
        help=model_help,
        type=str,
        default='UNet')

    # Dataset selection.
    add('--dataset',
        dest='dataset',
        help=
        "The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
        type=str,
        default='OpticDiscSeg')

    # Training hyper-parameters.
    add('--input_size',
        dest='input_size',
        help="The image size for net inputs.",
        nargs=2,
        default=[512, 512],
        type=int)
    add('--num_epochs',
        dest='num_epochs',
        help='Number epochs for training',
        type=int,
        default=100)
    add('--batch_size',
        dest='batch_size',
        help='Mini batch size of one gpu or cpu',
        type=int,
        default=2)
    add('--learning_rate',
        dest='learning_rate',
        help='Learning rate',
        type=float,
        default=0.01)

    # Weight initialisation / resumption.
    add('--pretrained_model',
        dest='pretrained_model',
        help='The path of pretrained model',
        type=str,
        default=None)
    add('--resume_model',
        dest='resume_model',
        help='The path of resume model',
        type=str,
        default=None)

    # Checkpointing and data loading.
    add('--save_interval_epochs',
        dest='save_interval_epochs',
        help='The interval epochs for save a model snapshot',
        type=int,
        default=5)
    add('--save_dir',
        dest='save_dir',
        help='The directory for saving the model snapshot',
        type=str,
        default='./output')
    add('--num_workers',
        dest='num_workers',
        help='Num workers for data loader',
        type=int,
        default=0)

    # Evaluation and logging switches.
    add('--do_eval',
        dest='do_eval',
        help='Eval while training',
        action='store_true')
    add('--log_steps',
        dest='log_steps',
        help='Display logging information at every log_steps',
        default=10,
        type=int)
    add('--use_vdl',
        dest='use_vdl',
        help='Whether to record the data to VisualDL during training',
        action='store_true')

    return parser.parse_args()


def train(model,
          train_dataset,
          places=None,
          eval_dataset=None,
          optimizer=None,
          save_dir='output',
          num_epochs=100,
          batch_size=2,
          pretrained_model=None,
          resume_model=None,
          save_interval_epochs=1,
          log_steps=10,
          num_classes=None,
          num_workers=8,
          use_vdl=False):
    """Run the dygraph training loop with periodic checkpoints and optional eval.

    Args:
        model: Network called as ``model(images, labels, mode='train')`` that
            returns the loss; its ``ignore_index`` attribute is forwarded to
            ``evaluate``.
        train_dataset: Dataset batched through a ``DistributedBatchSampler``
            (shuffled, last partial batch dropped).
        places: Device(s) handed to the ``DataLoader``.
        eval_dataset: If not None, evaluated after every snapshot save; the
            weights with the best mean IoU so far are also saved under
            ``<save_dir>/best_model``.
        optimizer: Optimizer used for ``minimize``/``current_step_lr`` and
            whose ``state_dict`` is saved with each snapshot.
        save_dir: Output directory for snapshots (and VisualDL records).
        num_epochs: Total number of epochs. Training starts from the epoch
            returned by ``resume`` when ``resume_model`` is given.
        batch_size: Mini-batch size per process.
        pretrained_model: Path of weights to load; ignored if
            ``resume_model`` is set.
        resume_model: Path of a snapshot to resume model and optimizer from.
        save_interval_epochs: Save a snapshot every this many epochs; the
            last epoch is always saved.
        log_steps: Log (and record to VisualDL) every this many steps.
        num_classes: Forwarded to ``evaluate``.
        num_workers: Number of ``DataLoader`` workers.
        use_vdl: Record training/eval scalars to VisualDL.
    """
    ignore_index = model.ignore_index
    nranks = ParallelEnv().nranks
    # Only the rank-0 process logs, saves snapshots and evaluates.
    is_master = ParallelEnv().local_rank == 0

    start_epoch = 0
    if resume_model is not None:
        start_epoch = resume(model, optimizer, resume_model)
    elif pretrained_model is not None:
        load_pretrained_model(model, pretrained_model)

    if not os.path.isdir(save_dir):
        # A regular file occupying the path is removed so the directory can
        # be created in its place.
        if os.path.exists(save_dir):
            os.remove(save_dir)
        # exist_ok: with several trainer processes another one may have
        # created the directory between the check above and this call.
        os.makedirs(save_dir, exist_ok=True)

    if nranks > 1:
        strategy = fluid.dygraph.prepare_context()
        model_parallel = fluid.dygraph.DataParallel(model, strategy)

    batch_sampler = DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    loader = DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        places=places,
        num_workers=num_workers,
        return_list=True,
    )

    # Scalars are only written by the master process, so only it opens a
    # VisualDL writer (other ranks would just leave empty records behind).
    log_writer = None
    if use_vdl and is_master:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    timer = Timer()
    timer.start()
    avg_loss = 0.0
    steps_per_epoch = len(batch_sampler)
    total_steps = steps_per_epoch * (num_epochs - start_epoch)
    num_steps = 0
    best_mean_iou = -1.0
    best_model_epoch = -1
    try:
        for epoch in range(start_epoch, num_epochs):
            for step, data in enumerate(loader):
                images = data[0]
                labels = data[1].astype('int64')
                if nranks > 1:
                    loss = model_parallel(images, labels, mode='train')
                    loss = model_parallel.scale_loss(loss)
                    loss.backward()
                    model_parallel.apply_collective_grads()
                else:
                    loss = model(images, labels, mode='train')
                    loss.backward()
                optimizer.minimize(loss)
                model.clear_gradients()
                avg_loss += loss.numpy()[0]
                lr = optimizer.current_step_lr()
                num_steps += 1
                if num_steps % log_steps == 0 and is_master:
                    avg_loss /= log_steps
                    time_step = timer.elapsed_time() / log_steps
                    remain_steps = total_steps - num_steps
                    logging.info(
                        "[TRAIN] Epoch={}/{}, Step={}/{}, loss={:.4f}, lr={:.6f}, sec/step={:.4f} | ETA {}"
                        .format(epoch + 1, num_epochs, step + 1,
                                steps_per_epoch, avg_loss, lr, time_step,
                                calculate_eta(remain_steps, time_step)))
                    if log_writer is not None:
                        log_writer.add_scalar('Train/loss', avg_loss,
                                              num_steps)
                        log_writer.add_scalar('Train/lr', lr, num_steps)
                    avg_loss = 0.0
                    timer.restart()

            is_save_epoch = ((epoch + 1) % save_interval_epochs == 0
                             or epoch + 1 == num_epochs)
            if not (is_save_epoch and is_master):
                continue

            current_save_dir = os.path.join(save_dir,
                                            "epoch_{}".format(epoch + 1))
            if not os.path.isdir(current_save_dir):
                os.makedirs(current_save_dir)
            # Same path prefix for both: model and optimizer state are
            # written as separate files under the 'model' prefix.
            fluid.save_dygraph(model.state_dict(),
                               os.path.join(current_save_dir, 'model'))
            fluid.save_dygraph(optimizer.state_dict(),
                               os.path.join(current_save_dir, 'model'))

            if eval_dataset is not None:
                mean_iou, mean_acc = evaluate(
                    model,
                    eval_dataset,
                    model_dir=current_save_dir,
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    epoch_id=epoch + 1)
                if mean_iou > best_mean_iou:
                    best_mean_iou = mean_iou
                    best_model_epoch = epoch + 1
                    best_model_dir = os.path.join(save_dir, "best_model")
                    fluid.save_dygraph(model.state_dict(),
                                       os.path.join(best_model_dir, 'model'))
                logging.info(
                    'Current evaluated best model in eval_dataset is epoch_{}, miou={:.4f}'
                    .format(best_model_epoch, best_mean_iou))

                if log_writer is not None:
                    log_writer.add_scalar('Evaluate/mean_iou', mean_iou,
                                          epoch + 1)
                    log_writer.add_scalar('Evaluate/mean_acc', mean_acc,
                                          epoch + 1)
                # Switch back to training mode after evaluation
                # (presumably evaluate() puts the model in eval mode).
                model.train()
    finally:
        # Close the VisualDL writer even when training raises.
        if log_writer is not None:
            log_writer.close()


def main(args):
    """Build datasets, model and optimizer from parsed CLI options, then train.

    Args:
        args: Namespace returned by ``parse_args``.

    Raises:
        Exception: If ``--dataset`` or ``--model_name`` is not supported.
    """
    env_info = get_environ_info()
    # Use this process's GPU when CUDA is requested and available, else CPU.
    if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda():
        places = fluid.CUDAPlace(ParallelEnv().dev_id)
    else:
        places = fluid.CPUPlace()

    dataset_name = args.dataset.lower()
    if dataset_name == 'opticdiscseg':
        dataset = OpticDiscSeg
    elif dataset_name == 'cityscapes':
        dataset = Cityscapes
    else:
        raise Exception(
            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
        )

    # Validate the model name before building datasets so a typo fails fast.
    if args.model_name not in MODELS:
        raise Exception(
            '--model_name is invalid. it should be one of {}'.format(
                str(list(MODELS.keys()))))

    with fluid.dygraph.guard(places):
        # Create dataset readers
        train_transforms = T.Compose([
            T.Resize(args.input_size),
            T.RandomHorizontalFlip(),
            T.Normalize()
        ])
        train_dataset = dataset(transforms=train_transforms, mode='train')

        eval_dataset = None
        if args.do_eval:
            eval_transforms = T.Compose(
                [T.Resize(args.input_size),
                 T.Normalize()])
            eval_dataset = dataset(transforms=eval_transforms, mode='eval')

        model = MODELS[args.model_name](num_classes=train_dataset.num_classes)

        # Create optimizer.
        # The decay horizon should equal the batches each process sees per
        # epoch. This is meant to mirror DistributedBatchSampler's length with
        # drop_last=True: each rank receives ceil(len / nranks) samples, and
        # the trailing partial batch is dropped. The previous
        # len // (batch_size * nranks) could under-count, driving lr to 0
        # before training finished.
        nranks = ParallelEnv().nranks
        samples_per_rank = (len(train_dataset) + nranks - 1) // nranks
        num_steps_each_epoch = samples_per_rank // args.batch_size
        decay_step = args.num_epochs * num_steps_each_epoch
        lr_decay = fluid.layers.polynomial_decay(
            args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
        optimizer = fluid.optimizer.Momentum(
            lr_decay,
            momentum=0.9,
            parameter_list=model.parameters(),
            regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))

        train(
            model,
            train_dataset,
            places=places,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_dir=args.save_dir,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            pretrained_model=args.pretrained_model,
            resume_model=args.resume_model,
            save_interval_epochs=args.save_interval_epochs,
            log_steps=args.log_steps,
            num_classes=train_dataset.num_classes,
            num_workers=args.num_workers,
            use_vdl=args.use_vdl)


if __name__ == '__main__':
    # Script entry point: parse the command line, then run training.
    args = parse_args()
    main(args)