# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
# from paddle.incubate.hapi.distributed import DistributedBatchSampler
from paddle.io import DistributedBatchSampler
import paddle.nn.functional as F

import dygraph.utils.logger as logger
from dygraph.utils import load_pretrained_model
from dygraph.utils import resume
from dygraph.utils import Timer, calculate_eta
from .val import evaluate


def check_logits_losses(logits, losses):
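    """Verify that the number of logits matches the number of configured losses."""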
    len_logits = len(logits)
    len_losses = len(losses['types'])
    if len_logits != len_losses:
        raise RuntimeError(
            'The number of logits should equal the number of loss types in the config: {} != {}.'
            .format(len_logits, len_losses))


def loss_computation(logits, label, losses):
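    """Return the weighted sum of each configured loss applied to its logit."""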
    check_logits_losses(logits, losses)
    loss = 0
    for i in range(len(logits)):
        logit = logits[i]
        # Upsample the logit to the label's spatial size before computing the loss.
        if logit.shape[-2:] != label.shape[-2:]:
            logit = F.resize_bilinear(logit, label.shape[-2:])
        loss_i = losses['types'][i](logit, label)
        loss += losses['coef'][i] * loss_i
    return loss
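

# A minimal sketch of the `losses` config this module expects: parallel lists
# of loss callables and their weights. The loss class named below is an
# illustrative placeholder supplied by the caller, not defined here:
#
#     losses = {
#         'types': [CrossEntropyLoss(), CrossEntropyLoss()],
#         'coef': [1.0, 0.4],
#     }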


def train(model,
          train_dataset,
          places=None,
          eval_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval_iters=1000,
          log_iters=10,
          num_classes=None,
          num_workers=8,
          use_vdl=False,
          losses=None):
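    """Train a segmentation model, checkpointing and optionally evaluating it.

    Args:
        model: Model to train; must expose an ``ignore_index`` attribute.
        train_dataset: Dataset used for training.
        places: Device places passed to the DataLoader.
        eval_dataset: Optional dataset evaluated at every checkpoint.
        optimizer: Optimizer used to update the model parameters.
        save_dir: Directory where checkpoints are written.
        iters: Total number of training iterations.
        batch_size: Batch size per process.
        resume_model: Optional checkpoint to resume training from.
        save_interval_iters: Interval (in iterations) between checkpoints.
        log_iters: Interval (in iterations) between log messages.
        num_classes: Number of classes, forwarded to evaluation.
        num_workers: Number of DataLoader worker processes.
        use_vdl: Whether to log scalars to VisualDL.
        losses: Dict with keys 'types' (loss callables) and 'coef' (weights).
    """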
    ignore_index = model.ignore_index
    nranks = ParallelEnv().nranks

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        # Initialize the parallel context and wrap the model for
        # multi-GPU data-parallel training.
        strategy = fluid.dygraph.prepare_context()
        ddp_model = fluid.dygraph.DataParallel(model, strategy)

    # Shard the dataset across ranks so each process trains on a distinct subset.
    batch_sampler = DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    loader = DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        places=places,
        num_workers=num_workers,
        return_list=True,
    )

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    timer = Timer()
    avg_loss = 0.0
    iters_per_epoch = len(batch_sampler)
    best_mean_iou = -1.0
    best_model_iter = -1
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    timer.start()

    iter = start_iter
    while iter < iters:
        for data in loader:
            iter += 1
            if iter > iters:
                break
            train_reader_cost += timer.elapsed_time()
            images = data[0]
            labels = data[1].astype('int64')
            if nranks > 1:
                logits = ddp_model(images)
                loss = loss_computation(logits, labels, losses)
                # loss = ddp_model(images, labels)
                # apply_collective_grads sums gradients over multiple GPUs.
                loss = ddp_model.scale_loss(loss)
                loss.backward()
                ddp_model.apply_collective_grads()
            else:
                logits = model(images)
                loss = loss_computation(logits, labels, losses)
                # loss = model(images, labels)
                loss.backward()
            optimizer.minimize(loss)
            model.clear_gradients()
            avg_loss += loss.numpy()[0]
            lr = optimizer.current_step_lr()
            train_batch_cost += timer.elapsed_time()
            if iter % log_iters == 0 and ParallelEnv().local_rank == 0:
                avg_loss /= log_iters
                avg_train_reader_cost = train_reader_cost / log_iters
                avg_train_batch_cost = train_batch_cost / log_iters
                train_reader_cost = 0.0
                train_batch_cost = 0.0
                remain_iters = iters - iter
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                logger.info(
                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss * nranks, lr, avg_train_batch_cost,
                            avg_train_reader_cost, eta))
                if use_vdl:
                    log_writer.add_scalar('Train/loss', avg_loss * nranks, iter)
                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)
                avg_loss = 0.0

            if (iter % save_interval_iters == 0
                    or iter == iters) and ParallelEnv().local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                # save_dygraph stores model parameters as 'model.pdparams' and
                # optimizer state as 'model.pdopt' under the same path prefix.
                fluid.save_dygraph(model.state_dict(),
                                   os.path.join(current_save_dir, 'model'))
                fluid.save_dygraph(optimizer.state_dict(),
                                   os.path.join(current_save_dir, 'model'))

                if eval_dataset is not None:
                    mean_iou, avg_acc = evaluate(
                        model,
                        eval_dataset,
                        model_dir=current_save_dir,
                        num_classes=num_classes,
                        ignore_index=ignore_index,
                        iter_id=iter)
                    if mean_iou > best_mean_iou:
                        best_mean_iou = mean_iou
                        best_model_iter = iter
                        best_model_dir = os.path.join(save_dir, "best_model")
                        fluid.save_dygraph(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model'))
                    logger.info(
                        'Current best model on eval_dataset is iter_{}, miou={:.4f}'
                        .format(best_model_iter, best_mean_iou))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter)
                        log_writer.add_scalar('Evaluate/aAcc', avg_acc, iter)
                    # Switch back to train mode after evaluation.
                    model.train()
            timer.restart()
    if use_vdl:
        log_writer.close()
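

# A minimal invocation sketch, assuming a config-driven entry point builds the
# model, datasets, and optimizer first. All names below (FCN, CrossEntropyLoss,
# train_dataset, val_dataset) are hypothetical placeholders, not part of this
# module:
#
#     model = FCN(num_classes=19)
#     optimizer = fluid.optimizer.SGD(learning_rate=0.01,
#                                     parameter_list=model.parameters())
#     train(model,
#           train_dataset,
#           eval_dataset=val_dataset,
#           optimizer=optimizer,
#           iters=10000,
#           batch_size=2,
#           num_classes=19,
#           losses={'types': [CrossEntropyLoss()], 'coef': [1.0]})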