# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
21
import numpy as np
W
WuHaobo 已提交
22 23 24 25 26 27 28 29 30 31

from collections import OrderedDict

import paddle.fluid as fluid

from ppcls.optimizer import LearningRateBuilder
from ppcls.optimizer import OptimizerBuilder
from ppcls.modeling import architectures
from ppcls.modeling.loss import CELoss
from ppcls.modeling.loss import MixCELoss
littletomatodonkey's avatar
littletomatodonkey 已提交
32
from ppcls.modeling.loss import JSDivLoss
W
WuHaobo 已提交
33 34 35 36 37 38
from ppcls.modeling.loss import GoogLeNetLoss
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger

from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.collective import DistributedStrategy
L
Leo Chen 已提交
39
import paddle.fluid as fluid
W
WuHaobo 已提交
40

S
shippingwang 已提交
41
from ema import ExponentialMovingAverage
R
fix  
root 已提交
42

W
WuHaobo 已提交
43

def create_feeds(image_shape, use_mix=None, use_dali=None):
    """
    Create feeds as model input

    Args:
        image_shape(list[int]|tuple[int]): model input shape, such as
            [3, 224, 224]; the batch dimension is prepended as None
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        use_dali(bool): whether DALI is used for data loading. When it is
            set, the mix-specific label feeds are not created here even if
            use_mix is on; only 'image' and 'label' are created.

    Returns:
        feeds(dict): ordered dict of model input variables, with key
            'image' plus either 'label' or
            'feed_y_a'/'feed_y_b'/'feed_lam'
    """
    feeds = OrderedDict()
    # list() lets callers pass the shape as a tuple as well as a list.
    feeds['image'] = fluid.data(
        name="feed_image", shape=[None] + list(image_shape), dtype="float32")

    # With DALI, the caller (build) applies dali.mix to these feeds
    # afterwards, so the plain label feed is created in that case.
    if use_mix and not use_dali:
        feeds['feed_y_a'] = fluid.data(
            name="feed_y_a", shape=[None, 1], dtype="int64")
        feeds['feed_y_b'] = fluid.data(
            name="feed_y_b", shape=[None, 1], dtype="int64")
        feeds['feed_lam'] = fluid.data(
            name="feed_lam", shape=[None, 1], dtype="float32")
    else:
        feeds['label'] = fluid.data(
            name="feed_label", shape=[None, 1], dtype="int64")

    return feeds


def create_dataloader(feeds):
    """
    Build an iterable fluid DataLoader backed by a Python generator.

    Args:
        feeds(iterable): model input variables used as the loader's
            feed_list (build passes feeds.values())

    Returns:
        dataloader(fluid dataloader): iterable, double-buffered loader
    """
    num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    # A single trainer uses a larger queue capacity than each trainer of a
    # multi-trainer job.
    queue_capacity = 8 if num_trainers > 1 else 64
    return fluid.io.DataLoader.from_generator(
        feed_list=feeds,
        capacity=queue_capacity,
        use_double_buffer=True,
        iterable=True)


def create_model(architecture, image, classes_num, is_train):
    """
    Create a model

    Args:
        architecture(dict): architecture information,
            name(such as ResNet50) is needed; the optional "params" dict is
            passed as keyword arguments to the architecture constructor
        image(variable): model input variable
        classes_num(int): num of classes
        is_train(bool): whether the program is built for training; if the
            architecture params contain 'is_test' it is set to
            `not is_train`

    Returns:
        out(variable): model output variable
    """
    name = architecture["name"]
    # Work on a copy so that overriding 'is_test' below does not mutate the
    # caller's config, which may be reused to build other programs.
    params = dict(architecture.get("params") or {})

    if "is_test" in params:
        params['is_test'] = not is_train
    model = architectures.__dict__[name](**params)

    if params.get("data_format") == "NHWC":
        # The image feed is presumably NCHW (e.g. [3, 224, 224]); transpose
        # it for NHWC models and keep it out of the gradient computation.
        image = fluid.layers.transpose(image, [0, 2, 3, 1])
        image.stop_gradient = True
    out = model.net(input=image, class_dim=classes_num)
    return out


def create_loss(out,
                feeds,
                architecture,
                classes_num=1000,
                epsilon=None,
                use_mix=False,
                use_distillation=False):
    """
    Create a loss for optimization, such as:
        1. CrossEntropy loss
        2. CrossEntropy loss with label smoothing
        3. CrossEntropy loss with mix(mixup, cutmix, fmix)
        4. CrossEntropy loss with label smoothing and (mixup, cutmix, fmix)
        5. GoogLeNet loss
        6. distillation loss (JSDivLoss)

    Args:
        out(variable): model output variable; a sequence of 3 outputs for
            GoogLeNet and of 2 outputs when use_distillation is set
        feeds(dict): dict of model input variables
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        classes_num(int): num of classes
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        use_distillation(bool): whether to build the distillation loss
            JSDivLoss(out[1], out[0]) instead of a label-based loss

    Returns:
        loss(variable): loss variable
    """
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        loss = GoogLeNetLoss(class_dim=classes_num, epsilon=epsilon)
        target = feeds['label']
        return loss(out[0], out[1], out[2], target)

    if use_distillation:
        assert len(out) == 2, ("distillation output length must be 2, "
                               "but got {}".format(len(out)))
        loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
        return loss(out[1], out[0])

    if use_mix:
        # Mixed samples carry two labels and the mixing ratio.
        loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
        return loss(out, feeds['feed_y_a'], feeds['feed_y_b'],
                    feeds['feed_lam'])

    loss = CELoss(class_dim=classes_num, epsilon=epsilon)
    return loss(out, feeds['label'])


def create_metric(out,
                  feeds,
                  architecture,
                  topk=5,
                  classes_num=1000,
                  use_distillation=False):
    """
    Create measures of model accuracy, such as top1 and top5

    Args:
        out(variable): model output variable; a sequence of 3 outputs for
            GoogLeNet and of 2 outputs when use_distillation is set
        feeds(dict): dict of model input variables(included label)
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        topk(int): usually top5; capped at classes_num
        classes_num(int): num of classes
        use_distillation(bool): if set, metrics are computed on out[1] only

    Returns:
        fetchs(dict): ordered dict mapping 'top1' and 'top{k}' to
            (accuracy variable, AverageMeter) pairs
    """
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        # out[0] is used as-is (no extra softmax is applied here).
        softmax_out = out[0]
    else:
        # just need student label to get metrics
        if use_distillation:
            out = out[1]
        softmax_out = fluid.layers.softmax(out, use_cudnn=False)

    fetchs = OrderedDict()
    label = feeds['label']

    # set top1 to fetchs
    top1 = fluid.layers.accuracy(softmax_out, label=label, k=1)
    fetchs['top1'] = (top1, AverageMeter('top1', '.4f', need_avg=True))

    # set topk to fetchs; k can never exceed the number of classes
    k = min(topk, classes_num)
    topk_acc = fluid.layers.accuracy(softmax_out, label=label, k=k)
    topk_name = 'top{}'.format(k)
    fetchs[topk_name] = (topk_acc,
                         AverageMeter(topk_name, '.4f', need_avg=True))

    return fetchs


def create_fetchs(out,
                  feeds,
                  architecture,
                  topk=5,
                  classes_num=1000,
                  epsilon=None,
                  use_mix=False,
                  use_distillation=False):
    """
    Create fetchs as model outputs(included loss and measures),
    will call create_loss and, when use_mix is off, create_metric.

    Args:
        out(variable): model output variable
        feeds(dict): dict of model input variables.
            If use mix_up, it will not include label.
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        topk(int): usually top5
        classes_num(int): num of classes
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        use_distillation(bool): whether the model is trained with
            distillation; forwarded to create_loss and create_metric

    Returns:
        fetchs(dict): ordered dict of model outputs(included loss and
            measures); each value is a (variable, AverageMeter) pair and
            'loss' is always the first entry
    """
    fetchs = OrderedDict()
    loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
                       use_distillation)
    fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))

    # Accuracy needs the plain 'label' feed, which create_feeds omits when
    # mixing is enabled (without DALI), so metrics are skipped for mix.
    if not use_mix:
        metric = create_metric(out, feeds, architecture, topk, classes_num,
                               use_distillation)
        fetchs.update(metric)

    return fetchs


def create_optimizer(config):
    """
    Create an optimizer using config, usually including
    learning rate and regularization.

    Args:
        config(dict):  such as
        {
            'LEARNING_RATE':
                {'function': 'Cosine',
                 'params': {'lr': 0.1}
                },
            'OPTIMIZER':
                {'function': 'Momentum',
                 'params':{'momentum': 0.9},
                 'regularizer':
                    {'function': 'L2', 'factor': 0.0001}
                }
        }
        It must also provide 'epochs', 'total_images' and
        config['TRAIN']['batch_size'], which are injected into the
        learning-rate params.

    Returns:
        an optimizer instance
    """
    # create learning_rate instance. Copy the section and its params so the
    # injected 'epochs'/'step_each_epoch' do not mutate the caller's config,
    # and tolerate a LEARNING_RATE section without 'params'.
    lr_config = dict(config['LEARNING_RATE'])
    lr_params = dict(lr_config.get('params') or {})
    lr_params.update({
        'epochs': config['epochs'],
        'step_each_epoch':
        config['total_images'] // config['TRAIN']['batch_size'],
    })
    lr_config['params'] = lr_params
    lr = LearningRateBuilder(**lr_config)()

    # create optimizer instance
    opt_config = config['OPTIMIZER']
    opt = OptimizerBuilder(**opt_config)
    return opt(lr)


def dist_optimizer(config, optimizer):
    """
    Wrap a regular optimizer with fleet for collective distributed training.

    Args:
        config(dict): training config; not read by this function
        optimizer(): a normal optimizer

    Returns:
        optimizer: a distributed optimizer produced by
            fleet.distributed_optimizer
    """
    strategy = DistributedStrategy()
    strategy.nccl_comm_num = 1
    strategy.fuse_all_reduce_ops = True

    executor_strategy = fluid.ExecutionStrategy()
    executor_strategy.num_threads = 3
    executor_strategy.num_iteration_per_drop_scope = 10
    strategy.exec_strategy = executor_strategy

    return fleet.distributed_optimizer(optimizer, strategy=strategy)


def mixed_precision_optimizer(config, optimizer):
    """
    Optionally wrap an optimizer for automatic mixed precision training.

    Args:
        config(dict): config; reads 'use_fp16' (default False),
            'amp_scale_loss' (default 1.0) and 'use_dynamic_loss_scaling'
            (default False)
        optimizer: the optimizer to wrap

    Returns:
        the optimizer decorated by fluid.contrib.mixed_precision.decorate
        when 'use_fp16' is set, otherwise `optimizer` itself, unchanged
    """
    if not config.get('use_fp16', False):
        return optimizer

    return fluid.contrib.mixed_precision.decorate(
        optimizer,
        init_loss_scaling=config.get('amp_scale_loss', 1.0),
        use_dynamic_loss_scaling=config.get('use_dynamic_loss_scaling',
                                            False))


def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
    """
    Build a program using a model and an optimizer
        1. create feeds
        2. create a dataloader
        3. create a model
        4. create fetchs
        5. create an optimizer

    Args:
        config(dict): config
        main_prog(): main program
        startup_prog(): startup program
        is_train(bool): train or valid
        is_distributed(bool): whether to use distributed training method

    Returns:
        dataloader(): a bridge between the model and the data; None when
            config 'use_dali' is set
        fetchs(dict): dict of model outputs(included loss and measures)
        ema: only when is_train and config 'use_ema' are both set, a third
            element holding the ExponentialMovingAverage instance
    """
    with fluid.program_guard(main_prog, startup_prog):
        with fluid.unique_name.guard():
            use_mix = config.get('use_mix') and is_train
            use_dali = config.get('use_dali')
            use_distillation = config.get('use_distillation')
            feeds = create_feeds(config.image_shape, use_mix, use_dali)

            if use_dali and use_mix:
                # Imported lazily: only needed when DALI performs the mix.
                import dali
                feeds = dali.mix(feeds, config, is_train)

            dataloader = None if use_dali else create_dataloader(
                feeds.values())

            out = create_model(config.ARCHITECTURE, feeds['image'],
                               config.classes_num, is_train)
            fetchs = create_fetchs(
                out,
                feeds,
                config.ARCHITECTURE,
                config.topk,
                config.classes_num,
                epsilon=config.get('ls_epsilon'),
                use_mix=use_mix,
                use_distillation=use_distillation)

            if is_train:
                optimizer = create_optimizer(config)
                # Fetch the learning rate before the optimizer is wrapped.
                lr = optimizer._global_learning_rate()
                fetchs['lr'] = (lr, AverageMeter('lr', 'f', need_avg=False))

                optimizer = mixed_precision_optimizer(config, optimizer)
                if is_distributed:
                    optimizer = dist_optimizer(config, optimizer)
                optimizer.minimize(fetchs['loss'][0])

                if config.get('use_ema'):
                    # The global step counter is passed as thres_steps;
                    # presumably it ramps the EMA decay early in training
                    # (see ema.ExponentialMovingAverage).
                    global_steps = (fluid.layers.learning_rate_scheduler.
                                    _decay_step_counter())
                    ema = ExponentialMovingAverage(
                        config.get('ema_decay'), thres_steps=global_steps)
                    ema.update()
                    return dataloader, fetchs, ema

    return dataloader, fetchs


def _try_set_build_option(build_strategy, name, value, hint):
    """
    Set `build_strategy.<name> = value`, logging `hint` instead of failing
    when the installed PaddlePaddle does not support the option.
    """
    try:
        setattr(build_strategy, name, value)
    except Exception:
        # Best effort: older PaddlePaddle releases lack some of these
        # options, and compiling without them is still valid.
        logger.info(hint)


def compile(config, program, loss_name=None, share_prog=None):
    """
    Compile the program

    Args:
        config(dict): config; reads 'use_fp16' and, only when it is set,
            'fuse_bn_act_ops', 'fuse_elewise_add_act_ops',
            'fuse_bn_add_act_ops' and 'enable_addto' (all default True)
        program(): the program to wrap in a data-parallel CompiledProgram
        loss_name(str): loss name
        share_prog(): the shared program, used for evaluation during training

    Returns:
        compiled_program(): a compiled program
    """
    build_strategy = fluid.compiler.BuildStrategy()
    exec_strategy = fluid.ExecutionStrategy()

    exec_strategy.num_threads = 1
    exec_strategy.num_iteration_per_drop_scope = 10

    use_fp16 = config.get('use_fp16', False)
    fuse_bn_act_ops = config.get('fuse_bn_act_ops', True)
    fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', True)
    fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', True)
    enable_addto = config.get('enable_addto', True)

    if use_fp16:
        _try_set_build_option(
            build_strategy, 'fuse_bn_act_ops', fuse_bn_act_ops,
            "PaddlePaddle version 1.7.0 or higher is "
            "required when you want to fuse batch_norm and activation_op.")
        _try_set_build_option(
            build_strategy, 'fuse_elewise_add_act_ops',
            fuse_elewise_add_act_ops,
            "PaddlePaddle version 1.7.0 or higher is "
            "required when you want to fuse elewise_add_act and activation_op."
        )
        _try_set_build_option(
            build_strategy, 'fuse_bn_add_act_ops', fuse_bn_add_act_ops,
            "PaddlePaddle 2.0-rc or higher is "
            "required when you want to enable fuse_bn_add_act_ops strategy.")
        _try_set_build_option(
            build_strategy, 'enable_addto', enable_addto,
            "PaddlePaddle 2.0-rc or higher is "
            "required when you want to enable addto strategy.")

    compiled_program = fluid.CompiledProgram(program).with_data_parallel(
        share_vars_from=share_prog,
        loss_name=loss_name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    return compiled_program


# Module-level step counter shared across calls to run(). It is advanced
# once per batch, only when a VisualDL writer is passed to run(), and is
# used as the step index of the logged loss scalar.
total_step = 0


def run(dataloader,
        exe,
        program,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(fluid dataloader): a fluid DataLoader (called to get the
            batch iterator) or, when config 'use_dali' is set, a DALI
            iterator that is iterated directly and reset afterwards
        exe(): executor used to run the program
        program(): the (compiled) program to run
        fetchs(dict): dict of measures and the loss; each value is a
            (variable, AverageMeter) pair and the loss comes first
        epoch(int): epoch of training or validation
        mode(str): 'train', 'eval' or 'valid'; selects the log format, and
            'valid' makes the function return the top1 average
        config(dict): config; reads 'use_dali' and 'print_interval'
            (default 10). None is treated as an empty config.
        vdl_writer: optional VisualDL writer; when given, each batch's loss
            is logged against the module-level `total_step`

    Returns:
        the running average of top1 when mode == 'valid', otherwise None
    """
    global total_step

    if config is None:
        config = {}
    use_dali = config.get('use_dali')
    print_interval = config.get('print_interval', 10)

    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [f[1] for f in fetchs.values()]
    for m in metric_list:
        m.reset()
    batch_time = AverageMeter('elapse', '.5f', need_avg=True)
    # Stays 0 if the dataloader yields nothing, so the summary below never
    # refers to an undefined name.
    batch_size = 0

    tic = time.time()
    dataloader = dataloader if use_dali else dataloader()()
    for idx, batch in enumerate(dataloader):
        # Statistics of the first 10 iterations are discarded (presumably
        # to exclude warm-up cost).
        if idx == 10:
            for m in metric_list:
                m.reset()
            batch_time.reset()
        batch_size = batch[0]["feed_image"].shape()[0]
        metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
        batch_time.update(time.time() - tic)
        for i, m in enumerate(metrics):
            metric_list[i].update(np.mean(m), batch_size)
        fetchs_str = ''.join([str(m.value) + ' '
                              for m in metric_list] + [batch_time.mean]) + 's'
        ips_info = " ips: {:.5f} images/sec.".format(batch_size /
                                                     batch_time.avg)
        fetchs_str += ips_info

        if vdl_writer:
            # metrics[0] is the loss: create_fetchs puts 'loss' first.
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1

        if idx % print_interval == 0:
            if mode == 'eval':
                logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
                                                           fetchs_str))
            else:
                epoch_str = "epoch:{:<3d}".format(epoch)
                step_str = "{:s} step:{:<4d}".format(mode, idx)
                logger.info("{:s} {:s} {:s}".format(epoch_str, step_str,
                                                    fetchs_str))
        tic = time.time()

    if use_dali:
        dataloader.reset()

    end_str = ''.join([str(m.mean) + ' '
                       for m in metric_list] + [batch_time.total]) + 's'
    # Guard against an empty dataloader (no timed batches => sum == 0).
    if batch_time.sum > 0:
        ips = batch_size * batch_time.count / batch_time.sum
    else:
        ips = 0.0
    ips_info = "ips: {:.5f} images/sec.".format(ips)

    if mode == 'eval':
        # end_str already ends with 's'; no extra suffix is appended.
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return fetchs["top1"][1].avg