program.py 20.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import numpy as np

from collections import OrderedDict
H
huangxu96 已提交
24
from optimizer import OptimizerBuilder
25 26 27

import paddle
import paddle.nn.functional as F
H
huangxu96 已提交
28 29
from paddle import fluid
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
30 31 32 33 34 35 36 37 38 39 40 41 42 43

from ppcls.optimizer.learning_rate import LearningRateBuilder
from ppcls.modeling import architectures
from ppcls.modeling.loss import CELoss
from ppcls.modeling.loss import MixCELoss
from ppcls.modeling.loss import JSDivLoss
from ppcls.modeling.loss import GoogLeNetLoss
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger

from paddle.distributed import fleet
from paddle.distributed.fleet import DistributedStrategy


def create_feeds(image_shape, use_mix=None, use_dali=None, dtype="float32"):
    """
    Create feeds as model input

    Args:
        image_shape(list[int]|tuple[int]): model input shape without the
            batch dimension, such as [3, 224, 224]
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        use_dali(bool): whether the data is read by DALI. When True, the
            mix feeds are not created even if use_mix is True; only the
            plain label feed is created (presumably mixing is then applied
            on the DALI side — confirm)
        dtype(str): dtype of the image feed and, when mix feeds are
            created, of feed_lam, such as "float32" or "float16"

    Returns:
        feeds(dict): OrderedDict of model input variables. Keys are 'image'
            followed by either 'feed_y_a', 'feed_y_b', 'feed_lam'
            (mix without DALI) or 'label'. The order matters: batches are
            fed positionally in this order.
    """
    feeds = OrderedDict()
    # list() allows a tuple shape; `[None] + tuple` would raise TypeError.
    feeds['image'] = paddle.static.data(
        name="feed_image", shape=[None] + list(image_shape), dtype=dtype)
    if use_mix and not use_dali:
        # two label sets plus the mixing coefficient, one value per sample
        feeds['feed_y_a'] = paddle.static.data(
            name="feed_y_a", shape=[None, 1], dtype="int64")
        feeds['feed_y_b'] = paddle.static.data(
            name="feed_y_b", shape=[None, 1], dtype="int64")
        feeds['feed_lam'] = paddle.static.data(
            name="feed_lam", shape=[None, 1], dtype=dtype)
    else:
        feeds['label'] = paddle.static.data(
            name="feed_label", shape=[None, 1], dtype="int64")

    return feeds


def create_model(architecture, image, classes_num, config, is_train):
    """
    Create a model

    Args:
        architecture(dict): architecture information,
            name(such as ResNet50) is needed; the optional "params" dict is
            passed as keyword arguments to the architecture constructor
        image(variable): model input variable
        classes_num(int): num of classes
        config(dict): model config; reads use_pure_fp16, data_format,
            image_shape and use_dali
        is_train(bool): whether the model is built for training; used to
            set params["is_test"] when the architecture params declare it

    Returns:
        out(variable): model output variable
    """
    use_pure_fp16 = config.get("use_pure_fp16", False)
    name = architecture["name"]
    # Work on a copy so the shared config is not mutated by successive
    # builds (e.g. train and eval programs built from the same config).
    params = dict(architecture.get("params") or {})

    data_format = "NCHW"
    if "data_format" in config:
        params["data_format"] = config["data_format"]
        data_format = config["data_format"]
    input_image_channel = config.get('image_shape', [3, 224, 224])[0]
    if input_image_channel != 3:
        logger.warning(
            "Input image channel is changed to {}, maybe for better speed-up".
            format(input_image_channel))
        params["input_image_channel"] = input_image_channel
    if "is_test" in params:
        params['is_test'] = not is_train
    model = architectures.__dict__[name](class_dim=classes_num, **params)

    if use_pure_fp16 and not config.get("use_dali", False):
        # Cast the image feed to float16 for the pure fp16 model.
        # NOTE(review): with use_dali no cast is done here, and build only
        # creates a float16 image feed for ResNet50 — confirm other
        # architectures are handled.
        image = image.astype('float16')
    if data_format == "NHWC":
        # the feed is laid out channel-first; permute it to NHWC
        image = paddle.tensor.transpose(image, [0, 2, 3, 1])
        image.stop_gradient = True
    out = model(image)
    if use_pure_fp16:
        # apply the pure-fp16 program cast and hand float32 output to the
        # loss/metric ops
        cast_model_to_fp16(paddle.static.default_main_program())
        out = out.astype('float32')
    return out


def create_loss(out,
                feeds,
                architecture,
                classes_num=1000,
                epsilon=None,
                use_mix=False,
                use_distillation=False,
                use_pure_fp16=False):
    """
    Create a loss for optimization, such as:
        1. CrossEntropy loss
        2. CrossEntropy loss with label smoothing
        3. CrossEntropy loss with mix(mixup, cutmix, fmix)
        4. CrossEntropy loss with label smoothing and (mixup, cutmix, fmix)
        5. GoogLeNet loss
        6. JSDiv loss for distillation

    Args:
        out(variable): model output variable; a sequence of 3 outputs for
            GoogLeNet, or of 2 outputs when use_distillation is True
        feeds(dict): dict of model input variables
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        classes_num(int): num of classes
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        use_distillation(bool): whether to return JSDivLoss(out[1], out[0])
        use_pure_fp16(bool): whether to use pure fp16 data as training parameter

    Returns:
        loss(variable): loss variable

    Raises:
        ValueError: if architecture is GoogLeNet and use_mix is True; the
            GoogLeNet loss takes a single label and cannot use mix targets
    """
    if architecture["name"] == "GoogLeNet" and use_mix:
        # Without this check `target` below would be unbound and the call
        # would fail with an opaque UnboundLocalError.
        raise ValueError("GoogLeNet loss does not support mix "
                         "(mixup, cutmix, fmix) augmentation")

    if use_mix:
        feed_y_a = paddle.reshape(feeds['feed_y_a'], [-1, 1])
        feed_y_b = paddle.reshape(feeds['feed_y_b'], [-1, 1])
        feed_lam = paddle.reshape(feeds['feed_lam'], [-1, 1])
    else:
        target = paddle.reshape(feeds['label'], [-1, 1])

    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        loss = GoogLeNetLoss(class_dim=classes_num, epsilon=epsilon)
        return loss(out[0], out[1], out[2], target)

    if use_distillation:
        assert len(out) == 2, ("distillation output length must be 2, "
                               "but got {}".format(len(out)))
        loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
        return loss(out[1], out[0])

    if use_mix:
        loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
        return loss(out, feed_y_a, feed_y_b, feed_lam, use_pure_fp16)
    else:
        loss = CELoss(class_dim=classes_num, epsilon=epsilon)
        return loss(out, target, use_pure_fp16)


def create_metric(out,
                  feeds,
                  architecture,
                  topk=5,
                  classes_num=1000,
                  config=None,
                  use_distillation=False):
    """
    Create measures of model accuracy, such as top1 and top5

    Args:
        out(variable): model output variable
        feeds(dict): dict of model input variables(included label)
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        topk(int): usually top5; capped at classes_num
        classes_num(int): num of classes
        config(dict) : model config; not used here, kept for interface
            compatibility
        use_distillation(bool): whether out holds 2 outputs, of which
            out[1] (the student) is measured

    Returns:
        fetchs(dict): dict of measures, keyed 'top1' and 'top{k}' with
            k = min(topk, classes_num); each value is (variable, AverageMeter)
    """
    label = paddle.reshape(feeds['label'], [-1, 1])
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        out = out[0]
    elif use_distillation:
        # just need student label to get metrics
        out = out[1]
    softmax_out = F.softmax(out)

    fetchs = OrderedDict()
    # set top1 to fetchs
    top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
    fetchs['top1'] = (top1, AverageMeter('top1', '.4f', need_avg=True))
    # set topk to fetchs; k cannot exceed the number of classes
    k = min(topk, classes_num)
    topk_acc = paddle.metric.accuracy(softmax_out, label=label, k=k)
    topk_name = 'top{}'.format(k)
    fetchs[topk_name] = (topk_acc,
                         AverageMeter(topk_name, '.4f', need_avg=True))
    return fetchs


def create_fetchs(out,
                  feeds,
                  architecture,
                  topk=5,
                  classes_num=1000,
                  epsilon=None,
                  use_mix=False,
                  config=None,
                  use_distillation=False):
    """
    Create fetchs as model outputs(included loss and measures),
    will call create_loss, and create_metric unless use_mix is set.

    Args:
        out(variable): model output variable
        feeds(dict): dict of model input variables.
            If use mix_up, it will not include label.
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        topk(int): usually top5
        classes_num(int): num of classes
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        config(dict): model config; reads use_pure_fp16. May be None, in
            which case pure fp16 is treated as disabled.
        use_distillation(bool): whether out holds distillation outputs

    Returns:
        fetchs(dict): dict of model outputs(included loss and measures);
            each value is (variable, AverageMeter)
    """
    fetchs = OrderedDict()
    # config defaults to None; calling .get on it would raise AttributeError.
    use_pure_fp16 = (config.get("use_pure_fp16", False)
                     if config is not None else False)
    loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
                       use_distillation, use_pure_fp16)
    fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
    if not use_mix:
        # mix training has no single hard label, so no accuracy metrics
        metric = create_metric(out, feeds, architecture, topk, classes_num,
                               config, use_distillation)
        fetchs.update(metric)

    return fetchs


def create_optimizer(config):
    """
    Create an optimizer using config, usually including
    learning rate and regularization.

    Args:
        config(dict):  such as
        {
            'LEARNING_RATE':
                {'function': 'Cosine',
                 'params': {'lr': 0.1}
                },
            'OPTIMIZER':
                {'function': 'Momentum',
                 'params':{'momentum': 0.9},
                 'regularizer':
                    {'function': 'L2', 'factor': 0.0001}
                }
        }
        The keys 'epochs', 'total_images' and config['TRAIN']['batch_size']
        are also required. 'epochs' and the derived 'step_each_epoch'
        (total_images // batch_size) are written into
        config['LEARNING_RATE']['params'] in place.

    Returns:
        (optimizer, lr): the optimizer instance built with the learning
        rate, and the learning rate object produced by LearningRateBuilder
    """
    # create learning_rate instance
    lr_config = config['LEARNING_RATE']
    lr_config['params'].update({
        'epochs': config['epochs'],
        'step_each_epoch':
        config['total_images'] // config['TRAIN']['batch_size'],
    })
    lr = LearningRateBuilder(**lr_config)()

    # create optimizer instance
    opt_config = config['OPTIMIZER']
    opt = OptimizerBuilder(config, **opt_config)
    return opt(lr), lr

def create_strategy(config):
    """
    Create the Paddle build strategy and execution strategy from config.

    The fusion switches (fuse_bn_act_ops, fuse_elewise_add_act_ops,
    fuse_bn_add_act_ops, enable_addto) default to the value of
    ``use_amp or use_pure_fp16`` and can each be overridden by a config key
    of the same name. A switch that the installed PaddlePaddle rejects is
    skipped and an info message is logged instead.

    Args:
        config(dict): config

    Returns:
        build_strategy: build strategy
        exec_strategy: exec strategy
    """
    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()

    pure_fp16 = config.get('use_pure_fp16', False)
    exec_strategy.num_threads = 1
    exec_strategy.num_iteration_per_drop_scope = 10000 if pure_fp16 else 10

    default_fuse = config.get('use_amp', False) or pure_fp16
    # (attribute, value, message logged when the attribute is rejected)
    optional_options = (
        ('fuse_bn_act_ops',
         config.get('fuse_bn_act_ops', default_fuse),
         "PaddlePaddle version 1.7.0 or higher is "
         "required when you want to fuse batch_norm and activation_op."),
        ('fuse_elewise_add_act_ops',
         config.get('fuse_elewise_add_act_ops', default_fuse),
         "PaddlePaddle version 1.7.0 or higher is "
         "required when you want to fuse elewise_add_act and activation_op."),
        ('fuse_bn_add_act_ops',
         config.get('fuse_bn_add_act_ops', default_fuse),
         "PaddlePaddle 2.0-rc or higher is "
         "required when you want to enable fuse_bn_add_act_ops strategy."),
        ('enable_addto',
         config.get('enable_addto', default_fuse),
         "PaddlePaddle 2.0-rc or higher is "
         "required when you want to enable addto strategy."),
    )

    for attr_name, attr_value, unsupported_msg in optional_options:
        try:
            setattr(build_strategy, attr_name, attr_value)
        except Exception:
            # best effort: older Paddle versions lack some of these options
            logger.info(unsupported_msg)
    return build_strategy, exec_strategy


def dist_optimizer(config, optimizer):
    """
    Wrap a normal optimizer with fleet for distributed training.

    Args:
        config(dict): config used to create the build/exec strategies
        optimizer(): a normal optimizer

    Returns:
        optimizer: a distributed optimizer
    """
    build_strategy, exec_strategy = create_strategy(config)

    strategy = DistributedStrategy()
    strategy.execution_strategy = exec_strategy
    strategy.build_strategy = build_strategy

    # all-reduce settings: a single NCCL communicator, fused all-reduce ops,
    # and a 16 MB size setting for fused gradients
    strategy.nccl_comm_num = 1
    strategy.fuse_all_reduce_ops = True
    strategy.fuse_grad_size_in_MB = 16

    return fleet.distributed_optimizer(optimizer, strategy=strategy)


def mixed_precision_optimizer(config, optimizer):
    """
    Decorate the optimizer for mixed precision training when enabled.

    Args:
        config(dict): config; reads use_amp, scale_loss (default 1.0) and
            use_dynamic_loss_scaling (default False)
        optimizer(): a normal optimizer

    Returns:
        the optimizer decorated by fluid.contrib.mixed_precision.decorate
        when use_amp is set, otherwise the given optimizer unchanged
    """
    if not config.get('use_amp', False):
        return optimizer
    return fluid.contrib.mixed_precision.decorate(
        optimizer,
        init_loss_scaling=config.get('scale_loss', 1.0),
        use_dynamic_loss_scaling=config.get('use_dynamic_loss_scaling',
                                            False))


def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
    """
    Build a program using a model and an optimizer
        1. create feeds
        2. create a model
        3. create fetchs
        4. create an optimizer and minimize the loss (training only)

    Args:
        config(dict): config
        main_prog(): main program
        startup_prog(): startup program
        is_train(bool): train or valid
        is_distributed(bool): whether to use distributed training method

    Returns:
        fetchs(dict): dict of model outputs(included loss and measures)
        lr_scheduler: learning rate object from create_optimizer, or None
            when is_train is False
        feeds(dict): dict of model input variables
    """
    with paddle.static.program_guard(main_prog, startup_prog):
        with paddle.utils.unique_name.guard():
            use_mix = config.get('use_mix') and is_train
            use_dali = config.get('use_dali', False)
            use_distillation = config.get('use_distillation')

            # a float16 image feed is only used for ResNet50 trained in pure
            # fp16 with DALI; every other case feeds float32 images
            fp16_image = (config["ARCHITECTURE"]["name"] == "ResNet50" and
                          config.get("use_pure_fp16", False) and use_dali)
            image_dtype = "float16" if fp16_image else "float32"
            feeds = create_feeds(
                config.image_shape,
                use_mix=use_mix,
                use_dali=use_dali,
                dtype=image_dtype)
            if use_dali and use_mix:
                import dali
                feeds = dali.mix(feeds, config, is_train)
            out = create_model(config.ARCHITECTURE, feeds['image'],
                               config.classes_num, config, is_train)
            fetchs = create_fetchs(
                out,
                feeds,
                config.ARCHITECTURE,
                config.topk,
                config.classes_num,
                epsilon=config.get('ls_epsilon'),
                use_mix=use_mix,
                config=config,
                use_distillation=use_distillation)
            lr_scheduler = None
            if is_train:
                optimizer, lr_scheduler = create_optimizer(config)
                optimizer = mixed_precision_optimizer(config, optimizer)
                if is_distributed:
                    optimizer = dist_optimizer(config, optimizer)
                optimizer.minimize(fetchs['loss'][0])
    return fetchs, lr_scheduler, feeds


def compile(config, program, loss_name=None, share_prog=None):
    """
    Compile the program for data-parallel execution

    Args:
        config(dict): config used to create the build/exec strategies
        program(): the program to be wrapped by a CompiledProgram
        loss_name(str): loss name
        share_prog(): the shared program, used for evaluation during training

    Returns:
        compiled_program(): a compiled program
    """
    build_strategy, exec_strategy = create_strategy(config)
    wrapped = paddle.static.CompiledProgram(program)
    return wrapped.with_data_parallel(
        share_vars_from=share_prog,
        loss_name=loss_name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)


# global step counter for the loss curve written through vdl_writer
total_step = 0


def run(dataloader,
        exe,
        program,
        feeds,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None,
        lr_scheduler=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle io dataloader): when config.use_dali is set, an
            iterable that is reset after the loop; otherwise a callable
            returning an iterable of batches
        exe(): executor used to run the program
        program(): program (or compiled program) to run
        feeds(dict): dict of model input variables; for non-DALI batches the
            i-th batch element is fed to the i-th variable
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or validation
        mode(str): 'train' or 'valid'; selects the log format, whether lr is
            recorded, and whether top1 is returned
        config(dict): config; reads topk, use_dali and print_interval
        vdl_writer(): writer passed to logger.scaler to log the loss
            (presumably VisualDL); skipped when None
        lr_scheduler(): learning rate object, stepped after each batch
            according to its update_specified settings

    Returns:
        the average top1 accuracy when mode is 'valid', otherwise None
    """
    global total_step

    fetch_list = [f[0] for f in fetchs.values()]

    # Display order: top1, topk, loss, then any other fetched output, then
    # lr and timing. Accuracy entries are absent when training with mix,
    # and topk is named top{classes_num} when classes_num < topk.
    topk_name = 'top{}'.format(config.topk)
    metric_list = OrderedDict()
    for name in ("top1", topk_name, "loss"):
        if name in fetchs:
            metric_list[name] = fetchs[name][1]
    for name, (_, meter) in fetchs.items():
        metric_list.setdefault(name, meter)
    metric_list["lr"] = AverageMeter('lr', 'f', postfix=",", need_avg=False)
    metric_list["batch_time"] = AverageMeter(
        'batch_cost', '.5f', postfix=" s,")
    metric_list["reader_time"] = AverageMeter(
        'reader_cost', '.5f', postfix=" s,")

    for m in metric_list.values():
        m.reset()

    use_dali = config.get('use_dali', False)
    dataloader = dataloader if use_dali else dataloader()
    # stays 0 when the dataloader yields nothing, so the summary still works
    batch_size = 0
    tic = time.time()
    for idx, batch in enumerate(dataloader):
        # ignore the warmup iters
        if idx == 5:
            metric_list["batch_time"].reset()
            metric_list["reader_time"].reset()

        metric_list['reader_time'].update(time.time() - tic)

        if use_dali:
            batch_size = batch[0]["feed_image"].shape()[0]
            feed_dict = batch[0]
        else:
            batch_size = batch[0].shape()[0]
            feed_dict = {
                var.name: batch[i]
                for i, var in enumerate(feeds.values())
            }
        metrics = exe.run(program=program,
                          feed=feed_dict,
                          fetch_list=fetch_list)

        for name, value in zip(fetchs.keys(), metrics):
            metric_list[name].update(np.mean(value), batch_size)
        metric_list["batch_time"].update(time.time() - tic)
        if mode == "train" and lr_scheduler is not None:
            metric_list['lr'].update(lr_scheduler.get_lr())

        fetchs_str = ' '.join([
            str(metric_list[key].mean)
            if "time" in key else str(metric_list[key].value)
            for key in metric_list
        ])
        ips_info = " ips: {:.5f} images/sec.".format(
            batch_size / metric_list["batch_time"].avg)
        fetchs_str += ips_info

        if lr_scheduler is not None:
            if lr_scheduler.update_specified:
                curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
                update = max(
                    0, curr_global_counter - lr_scheduler.
                    update_start_step) % lr_scheduler.update_step_interval == 0
                if update:
                    lr_scheduler.step()
            else:
                lr_scheduler.step()

        if vdl_writer:
            # metrics[0] is the first fetch, i.e. the loss
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1
        if mode == 'valid':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
                                                           fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)

            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} {:s} {:s}".format(
                    logger.coloring(epoch_str, "HEADER")
                    if idx == 0 else epoch_str,
                    logger.coloring(step_str, "PURPLE"),
                    logger.coloring(fetchs_str, 'OKGREEN')))

        tic = time.time()

    end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
                       [metric_list["batch_time"].total])
    batch_time = metric_list["batch_time"]
    # avoid ZeroDivisionError when no batch was timed (empty dataloader)
    ips = (batch_size * batch_time.count / batch_time.sum
           if batch_time.sum else 0.0)
    ips_info = "ips: {:.5f} images/sec.".format(ips)
    if mode == 'valid':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))
    if use_dali:
        dataloader.reset()

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return fetchs["top1"][1].avg