program.py 14.7 KB
Newer Older
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
W
WuHaobo 已提交
2
#
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
W
WuHaobo 已提交
6 7 8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WuHaobo 已提交
14 15 16 17 18 19 20

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
L
littletomatodonkey 已提交
21
import datetime
W
WuHaobo 已提交
22 23
from collections import OrderedDict

littletomatodonkey's avatar
littletomatodonkey 已提交
24
import paddle
littletomatodonkey's avatar
littletomatodonkey 已提交
25 26 27
from paddle import to_tensor
import paddle.nn as nn
import paddle.nn.functional as F
W
WuHaobo 已提交
28 29 30 31

from ppcls.optimizer import LearningRateBuilder
from ppcls.optimizer import OptimizerBuilder
from ppcls.modeling import architectures
Y
yaohai 已提交
32
from ppcls.modeling.loss import MultiLabelLoss
W
WuHaobo 已提交
33 34
from ppcls.modeling.loss import CELoss
from ppcls.modeling.loss import MixCELoss
littletomatodonkey's avatar
littletomatodonkey 已提交
35
from ppcls.modeling.loss import JSDivLoss
W
WuHaobo 已提交
36 37 38
from ppcls.modeling.loss import GoogLeNetLoss
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
Y
yaohai 已提交
39 40 41
from ppcls.utils import multi_hot_encode
from ppcls.utils import hamming_distance
from ppcls.utils import accuracy_score
W
WuHaobo 已提交
42 43


W
WuHaobo 已提交
44
def create_model(architecture, classes_num):
    """
    Instantiate a network from the architectures registry.

    Args:
        architecture(dict): architecture config; must contain "name"
            (such as ResNet50) and may contain a "params" dict of extra
            constructor keyword arguments
        classes_num(int): number of output classes

    Returns:
        the constructed model object
    """
    model_name = architecture["name"]
    extra_params = architecture.get("params", {})
    model_cls = architectures.__dict__[model_name]
    return model_cls(class_dim=classes_num, **extra_params)
W
WuHaobo 已提交
60 61


62 63
def create_loss(feeds,
                out,
                architecture,
                classes_num=1000,
                epsilon=None,
                use_mix=False,
                use_distillation=False,
                multilabel=False):
    """
    Build and apply the training loss. Supported variants:
        1. CrossEntropy loss
        2. CrossEntropy loss with label smoothing
        3. CrossEntropy loss with mix (mixup, cutmix, fmix)
        4. CrossEntropy loss with label smoothing and mix
        5. GoogLeNet loss (three outputs)
        6. JS-divergence loss for distillation
        7. Multi-label loss

    Args:
        feeds(dict): dict of model input variables
        out(variable): model output variable
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        classes_num(int): num of classes
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        use_distillation(bool): whether the output is (teacher, student)
        multilabel(bool): whether labels are multi-hot vectors

    Returns:
        loss(variable): loss variable
    """
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        criterion = GoogLeNetLoss(class_dim=classes_num, epsilon=epsilon)
        return criterion(out[0], out[1], out[2], feeds["label"])

    if use_distillation:
        assert len(out) == 2, ("distillation output length must be 2, "
                               "but got {}".format(len(out)))
        criterion = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
        # out[1] is compared against out[0] — presumably student vs
        # teacher; verify against the distillation model definition
        return criterion(out[1], out[0])

    if use_mix:
        criterion = MixCELoss(class_dim=classes_num, epsilon=epsilon)
        return criterion(out, feeds['y_a'], feeds['y_b'], feeds['lam'])

    loss_cls = MultiLabelLoss if multilabel else CELoss
    criterion = loss_cls(class_dim=classes_num, epsilon=epsilon)
    return criterion(out, feeds["label"])
W
WuHaobo 已提交
113 114


W
WuHaobo 已提交
115
def create_metric(out,
                  label,
                  architecture,
                  topk=5,
                  classes_num=1000,
                  use_distillation=False,
                  multilabel=False,
                  mode="train"):
    """
    Create measures of model accuracy, such as top1 and top5.

    Args:
        out(variable): model output variable
        label(variable): ground-truth label variable
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        topk(int): usually top5
        classes_num(int): num of classes
        use_distillation(bool): whether to use distillation training
        multilabel(bool): compute hamming distance and label-based
            accuracy instead of topk accuracy
        mode(str): mode, train/valid

    Returns:
        fetchs(dict): dict of measures
    """
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        out = out[0]
    else:
        # just need student label to get metrics
        if use_distillation:
            out = out[1]

    fetchs = OrderedDict()
    if not multilabel:
        # softmax is only needed for the topk metrics; the original code
        # also computed it unconditionally before this branch, which was
        # redundant (and wasted in the multilabel path) — fixed here.
        softmax_out = F.softmax(out)
        # set top1 to fetchs
        fetchs['top1'] = paddle.metric.accuracy(softmax_out, label=label, k=1)
        # set topk to fetchs; clamp k so tiny class counts stay valid
        k = min(topk, classes_num)
        fetchs["top{}".format(k)] = paddle.metric.accuracy(
            softmax_out, label=label, k=k)
    else:
        out = F.sigmoid(out)
        preds = multi_hot_encode(out.numpy())
        targets = label.numpy()
        fetchs["multilabel_accuracy"] = to_tensor(
            accuracy_score(
                preds, targets, base="label"))
        fetchs["hamming_distance"] = to_tensor(
            hamming_distance(preds, targets))

    # multi cards' eval: average each metric across all ranks.
    # fetchs' keys are exactly the metric names, so iterate them directly
    # (the original kept a redundant metric_names set for this).
    if mode != "train" and paddle.distributed.get_world_size() > 1:
        world_size = paddle.distributed.get_world_size()
        for metric_name in list(fetchs):
            fetchs[metric_name] = paddle.distributed.all_reduce(
                fetchs[metric_name],
                op=paddle.distributed.ReduceOp.SUM) / world_size

    return fetchs


littletomatodonkey's avatar
littletomatodonkey 已提交
188
def create_fetchs(feeds, net, config, mode="train"):
    """
    Run the forward pass and collect model outputs (loss and measures).

    Always calls create_loss; calls create_metric only when mix training
    is inactive (mixed batches carry no plain label to score against).

    Args:
        feeds(dict): dict of model input variables.
            If use mix_up, it will not include label.
        net: the model to run forward
        config(dict-like): global configuration; ARCHITECTURE, topk,
            classes_num, ls_epsilon, use_mix, use_distillation and
            multilabel are read from it
        mode(str): train/valid

    Returns:
        fetchs(dict): dict of model outputs (included loss and measures)
    """
    arch_cfg = config.ARCHITECTURE
    mixing = config.get('use_mix') and mode == 'train'
    distill = config.get('use_distillation')
    is_multilabel = config.get('multilabel', False)

    logits = net(feeds["image"])

    fetchs = OrderedDict()
    fetchs['loss'] = create_loss(feeds, logits, arch_cfg,
                                 config.classes_num,
                                 config.get('ls_epsilon'), mixing, distill,
                                 is_multilabel)
    if not mixing:
        fetchs.update(
            create_metric(
                logits,
                feeds["label"],
                arch_cfg,
                config.topk,
                config.classes_num,
                distill,
                multilabel=is_multilabel,
                mode=mode))

    return fetchs


W
WuHaobo 已提交
236
def create_optimizer(config, parameter_list=None):
    """
    Build a learning-rate schedule and an optimizer from config.

    Args:
        config(dict): such as
        {
            'LEARNING_RATE':
                {'function': 'Cosine',
                 'params': {'lr': 0.1}
                },
            'OPTIMIZER':
                {'function': 'Momentum',
                 'params':{'momentum': 0.9},
                 'regularizer':
                    {'function': 'L2', 'factor': 0.0001}
                }
        }
        parameter_list: parameters the optimizer should update

    Returns:
        a (optimizer instance, learning rate) tuple
    """
    # The schedule needs the run length; derive steps per epoch from the
    # dataset size and batch size. NOTE: this mutates
    # config['LEARNING_RATE']['params'] in place, as the original did.
    steps_per_epoch = config['total_images'] // config['TRAIN']['batch_size']
    lr_conf = config['LEARNING_RATE']
    lr_conf['params'].update({
        'epochs': config['epochs'],
        'step_each_epoch': steps_per_epoch,
    })
    lr = LearningRateBuilder(**lr_conf)()

    # build the optimizer around the schedule
    optimizer_builder = OptimizerBuilder(**config['OPTIMIZER'])
    return optimizer_builder(lr, parameter_list), lr
W
WuHaobo 已提交
272 273


Y
yaohai 已提交
274
def create_feeds(batch, use_mix, num_classes, multilabel=False):
    """
    Convert a dataloader batch into the feed dict the model expects.

    Args:
        batch: (image, label) pair, or (image, y_a, y_b, lam) when mix
            augmentation is active
        use_mix(bool): whether the batch carries mix-augmentation targets
        num_classes(int): class count; shapes multilabel targets
        multilabel(bool): whether labels are multi-hot float vectors

    Returns:
        feeds(dict): {"image", "label"} or {"image", "y_a", "y_b", "lam"}
    """

    def _as_tensor(item, dtype, shape):
        # cast a batch element to the given dtype/shape and wrap it
        return to_tensor(item.numpy().astype(dtype).reshape(*shape))

    image = batch[0]
    if use_mix:
        return {
            "image": image,
            "y_a": _as_tensor(batch[1], "int64", (-1, 1)),
            "y_b": _as_tensor(batch[2], "int64", (-1, 1)),
            "lam": _as_tensor(batch[3], "float32", (-1, 1)),
        }
    if multilabel:
        label = _as_tensor(batch[1], "float32", (-1, num_classes))
    else:
        label = _as_tensor(batch[1], "int64", (-1, 1))
    return {"image": image, "label": label}


T
Tingquan Gao 已提交
290 291 292
# global VisualDL step counter, shared across epochs (see run below)
total_step = 0


def run(dataloader,
        config,
        net,
        optimizer=None,
        lr_scheduler=None,
        epoch=0,
        mode='train',
        vdl_writer=None):
    """
    Feed one epoch of data to the model and collect the measures and loss.

    In 'train' mode this also runs backward, steps the optimizer and the
    learning-rate scheduler, and logs scalars to VisualDL when a writer
    is given. In other modes it only evaluates and logs.

    Args:
        dataloader(paddle dataloader): yields batches for one epoch
        config(dict-like): global configuration
        net: model to run
        optimizer: required when mode == 'train'
        lr_scheduler: optional LR scheduler, stepped in train mode
        epoch(int): current epoch index (logging / ETA only)
        mode(str): 'train', 'valid' or 'eval'
        vdl_writer: optional VisualDL writer

    Returns:
        epoch-average accuracy (top1 or multilabel) when mode is 'valid';
        otherwise None
    """
    print_interval = config.get("print_interval", 10)
    # mix augmentation only applies while training
    use_mix = config.get("use_mix", False) and mode == "train"
    multilabel = config.get("multilabel", False)
    classes_num = config.get("classes_num")

    # running meters; built as a list first so metrics can be inserted
    # at the front, then frozen into an OrderedDict for stable logging order
    metric_list = [
        ("loss", AverageMeter(
            'loss', '7.5f', postfix=",")),
        ("lr", AverageMeter(
            'lr', 'f', postfix=",", need_avg=False)),
        ("batch_time", AverageMeter(
            'batch_cost', '.5f', postfix=" s,")),
        ("reader_time", AverageMeter(
            'reader_cost', '.5f', postfix=" s,")),
    ]
    if not use_mix:
        # accuracy meters mirror the keys produced by create_metric
        if not multilabel:
            topk_name = 'top{}'.format(config.topk)
            metric_list.insert(
                0, (topk_name, AverageMeter(
                    topk_name, '.5f', postfix=",")))
            metric_list.insert(
                0, ("top1", AverageMeter(
                    "top1", '.5f', postfix=",")))
        else:
            metric_list.insert(0, ("multilabel_accuracy", AverageMeter(
                                   "multilabel_accuracy", '.5f', postfix=",")))
            metric_list.insert(0, ("hamming_distance", AverageMeter(
                                   "hamming_distance", '.5f', postfix=",")))

    metric_list = OrderedDict(metric_list)

    tic = time.time()
    for idx, batch in enumerate(dataloader()):
        # avoid statistics from warmup time
        if idx == 10:
            metric_list["batch_time"].reset()
            metric_list["reader_time"].reset()

        # time spent waiting on the dataloader since the previous step
        metric_list['reader_time'].update(time.time() - tic)
        batch_size = len(batch[0])
        feeds = create_feeds(batch, use_mix, classes_num, multilabel)
        fetchs = create_fetchs(feeds, net, config, mode)
        if mode == 'train':
            avg_loss = fetchs['loss']
            avg_loss.backward()

            optimizer.step()
            optimizer.clear_grad()
            # NOTE(review): _global_learning_rate() is a private paddle
            # optimizer API — confirm it survives paddle upgrades
            lr_value = optimizer._global_learning_rate().numpy()[0]
            metric_list['lr'].update(lr_value, batch_size)

            if lr_scheduler is not None:
                if lr_scheduler.update_specified:
                    # step only every update_step_interval global steps,
                    # starting from update_start_step
                    curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
                    update = max(
                        0, curr_global_counter - lr_scheduler.update_start_step
                    ) % lr_scheduler.update_step_interval == 0
                    if update:
                        lr_scheduler.step()
                else:
                    lr_scheduler.step()

        # fold every fetched scalar into its meter, weighted by batch size
        for name, fetch in fetchs.items():
            metric_list[name].update(fetch.numpy()[0], batch_size)
        metric_list["batch_time"].update(time.time() - tic)
        tic = time.time()

        # scalar logging to VisualDL (train mode only; lr_value is only
        # defined on the train path)
        if vdl_writer and mode == "train":
            global total_step
            logger.scaler(
                name="lr", value=lr_value, step=total_step, writer=vdl_writer)
            for name, fetch in fetchs.items():
                logger.scaler(
                    name="train_{}".format(name),
                    value=fetch.numpy()[0],
                    step=total_step,
                    writer=vdl_writer)
            total_step += 1

        # timings report their running mean, other meters their last value
        fetchs_str = ' '.join([
            str(metric_list[key].mean)
            if "time" in key else str(metric_list[key].value)
            for key in metric_list
        ])

        if idx % print_interval == 0:
            ips_info = "ips: {:.5f} images/sec".format(
                batch_size / metric_list["batch_time"].avg)

            if mode == "train":
                epoch_str = "epoch:{:<3d}".format(epoch)
                step_str = "{:s} step:{:<4d}".format(mode, idx)
                # remaining steps in this run times the average step cost
                eta_sec = ((config["epochs"] - epoch) * len(dataloader) - idx
                           ) * metric_list["batch_time"].avg
                eta_str = "eta: {:s}".format(
                    str(datetime.timedelta(seconds=int(eta_sec))))
                logger.info("{:s}, {:s}, {:s} {:s}, {:s}".format(
                    epoch_str, step_str, fetchs_str, ips_info, eta_str))
            else:
                logger.info("{:s} step:{:<4d}, {:s} {:s}".format(
                    mode, idx, fetchs_str, ips_info))

    # NOTE(review): batch_size below is the last loop value — this raises
    # NameError if the dataloader yielded no batches
    end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
                       [metric_list['batch_time'].total])
    ips_info = "ips: {:.5f} images/sec.".format(
        batch_size * metric_list["batch_time"].count /
        metric_list["batch_time"].sum)

    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        if multilabel:
            return metric_list['multilabel_accuracy'].avg
        else:
            return metric_list['top1'].avg