# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler

from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
from utils import resume
from val import evaluate


def parse_args():
    parser = argparse.ArgumentParser(description='Model training')

    # params of model
    parser.add_argument(
        '--model_name',
        dest='model_name',
        help="Model type for training, which is one of ('UNet')",
        type=str,
        default='UNet')

    # params of dataset
    parser.add_argument(
        '--dataset',
        dest='dataset',
        help=
        "The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
        type=str,
        default='OpticDiscSeg')

    # params of training
    parser.add_argument(
        "--input_size",
        dest="input_size",
        help="The image size for net inputs.",
        nargs=2,
        default=[512, 512],
        type=int)
    parser.add_argument(
        '--num_epochs',
        dest='num_epochs',
        help='Number of epochs for training',
        type=int,
        default=100)
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size on each GPU or CPU',
        type=int,
        default=2)
    parser.add_argument(
        '--learning_rate',
        dest='learning_rate',
        help='Learning rate',
        type=float,
        default=0.01)
    parser.add_argument(
        '--pretrained_model',
        dest='pretrained_model',
        help='The path of the pretrained weights',
        type=str,
        default=None)
    parser.add_argument(
        '--save_interval_epochs',
        dest='save_interval_epochs',
        help='The epoch interval for saving a model snapshot',
        type=int,
        default=5)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the model snapshot',
        type=str,
        default='./output')
    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='Number of workers for the data loader',
        type=int,
        default=0)
    parser.add_argument(
        '--do_eval',
        dest='do_eval',
        help='Eval while training',
        action='store_true')

    return parser.parse_args()


def train(model,
          train_dataset,
          places=None,
          eval_dataset=None,
          optimizer=None,
          save_dir='output',
          num_epochs=100,
          batch_size=2,
          pretrained_model=None,
          resume_model=None,
          save_interval_epochs=1,
          num_classes=None,
          num_workers=8):
    ignore_index = model.ignore_index
    nranks = ParallelEnv().nranks

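    # Resume from a checkpoint if one is given; otherwise optionally start
    # from pretrained weights.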
    start_epoch = 0
    if resume_model is not None:
        start_epoch = resume(optimizer, resume_model)
    elif pretrained_model is not None:
        load_pretrained_model(model, pretrained_model)

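    # Make sure save_dir exists as a directory; a regular file occupying that
    # path is removed first.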
    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

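    # On multi-card runs, wrap the model with DataParallel so gradients can be
    # synchronized across trainers.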
    if nranks > 1:
        strategy = fluid.dygraph.prepare_context()
        model_parallel = fluid.dygraph.DataParallel(model, strategy)

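    # The distributed batch sampler gives each trainer its own shard of the
    # training data in every epoch.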
    batch_sampler = DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    loader = DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        places=places,
        num_workers=num_workers,
        return_list=True,
    )

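    # Main training loop: one forward/backward pass and one optimizer step per
    # mini-batch.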
    for epoch in range(start_epoch, num_epochs):
        for step, data in enumerate(loader):
            images = data[0]
            labels = data[1].astype('int64')
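            # Forward/backward pass; on multi-card runs the loss is scaled and
            # the gradients are all-reduced across trainers.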
            if nranks > 1:
                loss = model_parallel(images, labels, mode='train')
                loss = model_parallel.scale_loss(loss)
                loss.backward()
                model_parallel.apply_collective_grads()
            else:
                loss = model(images, labels, mode='train')
                loss.backward()
            optimizer.minimize(loss)
            model.clear_gradients()
            lr = optimizer.current_step_lr()
            logging.info(
                "[TRAIN] Epoch={}/{}, Step={}/{}, loss={}, lr={}".format(
                    epoch + 1, num_epochs, step + 1, len(batch_sampler),
                    loss.numpy(), lr))

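        # Periodically save a snapshot of the model and optimizer state; only
        # the local rank-0 trainer writes checkpoints.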
        if ((epoch + 1) % save_interval_epochs == 0
                or epoch == num_epochs - 1) and ParallelEnv().local_rank == 0:
            current_save_dir = os.path.join(save_dir,
                                            "epoch_{}".format(epoch + 1))
            if not os.path.isdir(current_save_dir):
                os.makedirs(current_save_dir)
            fluid.save_dygraph(model.state_dict(),
                               os.path.join(current_save_dir, 'model'))
            fluid.save_dygraph(optimizer.state_dict(),
                               os.path.join(current_save_dir, 'model'))

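            # Optionally evaluate the snapshot that was just saved, then switch
            # the model back to training mode.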
            if eval_dataset is not None:
                evaluate(
                    model,
                    eval_dataset,
                    places=places,
                    model_dir=current_save_dir,
                    num_classes=num_classes,
                    batch_size=batch_size,
                    ignore_index=ignore_index,
                    epoch_id=epoch + 1)
                model.train()


def main(args):
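    # Run on GPU when CUDA is available and requested, otherwise fall back to CPU.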
    env_info = get_environ_info()
    places = fluid.CUDAPlace(ParallelEnv().dev_id) \
        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
        else fluid.CPUPlace()

    if args.dataset.lower() == 'opticdiscseg':
        dataset = OpticDiscSeg
    elif args.dataset.lower() == 'cityscapes':
        dataset = Cityscapes
    else:
        raise Exception(
            "The --dataset is wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
        )

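    # Everything below runs in dygraph (imperative) mode on the selected place.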
    with fluid.dygraph.guard(places):
        # Create the dataset reader
        train_transforms = T.Compose([
            T.Resize(args.input_size),
            T.RandomHorizontalFlip(),
            T.Normalize()
        ])
        train_dataset = dataset(transforms=train_transforms, mode='train')

        eval_dataset = None
        if args.do_eval:
            eval_transforms = T.Compose(
                [T.Resize(args.input_size),
                 T.Normalize()])
            eval_dataset = dataset(transforms=eval_transforms, mode='eval')

        if args.model_name == 'UNet':
            model = models.UNet(
                num_classes=train_dataset.num_classes, ignore_index=255)

        # Create the optimizer
        # TODO: this may be one step fewer than len(loader)
        num_steps_each_epoch = len(train_dataset) // (
            args.batch_size * ParallelEnv().nranks)
        logging.info('Number of steps per epoch: {}'.format(num_steps_each_epoch))
        decay_step = args.num_epochs * num_steps_each_epoch
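        # Polynomial decay takes the learning rate from args.learning_rate down
        # to 0 over the full training schedule.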
        lr_decay = fluid.layers.polynomial_decay(
            args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
        optimizer = fluid.optimizer.Momentum(
            lr_decay,
            momentum=0.9,
            parameter_list=model.parameters(),
            regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))

        train(
            model,
            train_dataset,
            places=places,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_dir=args.save_dir,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            pretrained_model=args.pretrained_model,
            save_interval_epochs=args.save_interval_epochs,
            num_classes=train_dataset.num_classes,
            num_workers=args.num_workers)


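# Example invocation (flag values below are illustrative, not required defaults):
#   python train.py --model_name UNet --dataset OpticDiscSeg \
#       --input_size 512 512 --num_epochs 10 --batch_size 4 --do_eval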
if __name__ == '__main__':
    args = parse_args()
    main(args)