program.py 14.3 KB
Newer Older
littletomatodonkey's avatar
littletomatodonkey 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import numpy as np

from collections import OrderedDict

import paddle
import paddle.nn.functional as F

from paddle.distributed import fleet
from paddle.distributed.fleet import DistributedStrategy

# from ppcls.optimizer import OptimizerBuilder
# from ppcls.optimizer.learning_rate import LearningRateBuilder

from ppcls.arch import build_model
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.optimizer import build_lr_scheduler

from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger


def create_feeds(image_shape, use_mix=None, dtype="float32"):
    """
    Create feeds as model input

    Args:
        image_shape(list[int] | tuple[int]): model input shape without the
            batch dimension, such as [3, 224, 224]
        use_mix(bool): whether to use mix (include mixup, cutmix, fmix).
            When truthy, 'y_a', 'y_b' and 'lam' feeds are created instead
            of 'label'.
        dtype(str): dtype of the 'data' feed (and of 'lam' in mix mode)

    Returns:
        feeds(OrderedDict): ordered dict of model input variables; 'data'
            always comes first
    """
    feeds = OrderedDict()
    # list() so that a tuple (or any other sequence) shape is accepted too;
    # `[None] + tuple` would raise TypeError.
    feeds['data'] = paddle.static.data(
        name="data", shape=[None] + list(image_shape), dtype=dtype)
    if use_mix:
        feeds['y_a'] = paddle.static.data(
            name="y_a", shape=[None, 1], dtype="int64")
        feeds['y_b'] = paddle.static.data(
            name="y_b", shape=[None, 1], dtype="int64")
        feeds['lam'] = paddle.static.data(
            name="lam", shape=[None, 1], dtype=dtype)
    else:
        feeds['label'] = paddle.static.data(
            name="label", shape=[None, 1], dtype="int64")

    return feeds


def create_fetchs(out,
                  feeds,
                  architecture,
                  topk=5,
                  epsilon=None,
                  use_mix=False,
                  config=None,
                  mode="Train"):
    """
    Create fetchs as model outputs (loss and metrics).

    The loss is built from ``config["Loss"][mode]`` with ``build_loss`` and
    the metrics from ``config["Metric"][mode]`` with ``build_metrics``; both
    are applied to ``out`` and the reshaped 'label' feed. In any mode other
    than "Train", when more than one process is running, every metric is
    summed across processes with ``all_reduce`` and divided by the world
    size.

    Args:
        out(variable): model output variable
        feeds(dict): dict of model input variables; must contain 'label'
        architecture(dict): architecture information (currently unused)
        topk(int): currently unused, kept for backward compatibility
        epsilon(float): currently unused, kept for backward compatibility
        use_mix(bool): whether mix training (mixup, cutmix, fmix) is used.
            Not supported yet.
        config(dict): global config; must contain the "Loss" and "Metric"
            sections
        mode(str): config sub-section to use, e.g. "Train" or "Eval"

    Returns:
        fetchs(OrderedDict): name -> (variable, AverageMeter); 'loss' comes
            first, followed by every key returned by the metric function

    Raises:
        NotImplementedError: if ``use_mix`` is truthy.
    """
    # TODO(littletomatodonkey): support mix training. Fail fast with a clear
    # error instead of an undefined `target` further down.
    if use_mix:
        raise NotImplementedError(
            "mix training (mixup/cutmix/fmix) is not supported yet")

    fetchs = OrderedDict()
    target = paddle.reshape(feeds['label'], [-1, 1])

    # build loss
    loss_func = build_loss(config["Loss"][mode])
    loss_dict = loss_func(out, target)
    fetchs['loss'] = (loss_dict["loss"],
                      AverageMeter('loss', '7.4f', need_avg=True))

    # build metric
    metric_func = build_metrics(config["Metric"][mode])
    metric_dict = metric_func(out, target)

    world_size = paddle.distributed.get_world_size()
    # Outside training, average metrics over all processes so every rank
    # reports the same value.
    sync_metrics = mode != "Train" and world_size > 1
    for key in metric_dict:
        metric = metric_dict[key]
        if sync_metrics:
            paddle.distributed.all_reduce(
                metric, op=paddle.distributed.ReduceOp.SUM)
            metric = metric / world_size
        fetchs[key] = (metric, AverageMeter(key, '7.4f', need_avg=True))

    return fetchs


def create_optimizer(config, step_each_epoch):
    """
    Build the optimizer and its learning-rate scheduler from config.

    Args:
        config(dict): global config; reads config["Optimizer"] and
            config["Global"]["epochs"]
        step_each_epoch(int): number of iterations per epoch

    Returns:
        tuple: (optimizer, lr_scheduler) as produced by build_optimizer
    """
    optimizer_cfg = config["Optimizer"]
    total_epochs = config["Global"]["epochs"]
    opt, scheduler = build_optimizer(optimizer_cfg, total_epochs,
                                     step_each_epoch)
    return opt, scheduler


def create_strategy(config):
    """
    Create build strategy and exec strategy.

    The fusion flags (fuse_bn_act_ops, fuse_elewise_add_act_ops,
    fuse_bn_add_act_ops, enable_addto) default to enabled when the config
    has an 'AMP' section. Each flag can be overridden by a top-level config
    key of the same name.

    Args:
        config(dict): config

    Returns:
        build_strategy: build strategy
        exec_strategy: exec strategy
    """
    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()

    use_amp = 'AMP' in config
    # An 'AMP' key with an empty value (None) is treated as an empty
    # section, the same way mixed_precision_optimizer handles it.
    amp_cfg = config.get('AMP') or dict()
    use_pure_fp16 = amp_cfg.get("use_pure_fp16", False)

    exec_strategy.num_threads = 1
    exec_strategy.num_iteration_per_drop_scope = (10000
                                                  if use_pure_fp16 else 10)

    fuse_op = use_amp

    build_strategy.fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
    build_strategy.fuse_elewise_add_act_ops = config.get(
        'fuse_elewise_add_act_ops', fuse_op)
    build_strategy.fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops',
                                                    fuse_op)
    build_strategy.enable_addto = config.get('enable_addto', fuse_op)

    return build_strategy, exec_strategy


def dist_optimizer(config, optimizer):
    """
    Wrap a normal optimizer with fleet for distributed training.

    The build/exec strategies come from create_strategy(config). The
    DistributedStrategy additionally uses one NCCL communicator and fused
    all-reduce ops with fuse_grad_size_in_MB = 16.

    Args:
        config(dict): config passed to create_strategy
        optimizer: a normal (non-distributed) optimizer

    Returns:
        optimizer: the distributed optimizer returned by
            fleet.distributed_optimizer
    """
    build_strat, exec_strat = create_strategy(config)

    strategy = DistributedStrategy()
    strategy.execution_strategy = exec_strat
    strategy.build_strategy = build_strat

    # communication settings
    strategy.nccl_comm_num = 1
    strategy.fuse_all_reduce_ops = True
    strategy.fuse_grad_size_in_MB = 16

    return fleet.distributed_optimizer(optimizer, strategy=strategy)


def mixed_precision_optimizer(config, optimizer):
    """
    Decorate the optimizer for static-graph AMP when config has 'AMP'.

    Settings read from the AMP section (an empty/None section counts as
    empty):
        scale_loss (default 1.0) -> init_loss_scaling
        use_dynamic_loss_scaling (default False)
        use_pure_fp16 (default False)
    use_fp16_guard is always True.

    Args:
        config(dict): config
        optimizer: optimizer to decorate

    Returns:
        the decorated optimizer, or ``optimizer`` unchanged when the config
        has no 'AMP' key
    """
    if 'AMP' not in config:
        return optimizer

    amp_settings = config.AMP or dict()
    return paddle.static.amp.decorate(
        optimizer,
        init_loss_scaling=amp_settings.get('scale_loss', 1.0),
        use_dynamic_loss_scaling=amp_settings.get('use_dynamic_loss_scaling',
                                                  False),
        use_pure_fp16=amp_settings.get('use_pure_fp16', False),
        use_fp16_guard=True)


def build(config,
          main_prog,
          startup_prog,
          step_each_epoch=100,
          is_train=True,
          is_distributed=True):
    """
    Build a static program using a model and an optimizer.
        1. create feeds
        2. create a model
        3. create fetchs (loss and metrics)
        4. create an optimizer (training only), decorate it for AMP and,
           if requested, for distributed training, then minimize the loss

    Args:
        config(dict): config
        main_prog(): main program
        startup_prog(): startup program
        step_each_epoch(int): number of iterations per epoch, used to build
            the learning-rate scheduler
        is_train(bool): train or eval; selects the "Train"/"Eval" config
            sections
        is_distributed(bool): whether to use distributed training method

    Returns:
        fetchs(dict): dict of model outputs (included loss and measures)
        lr_scheduler: learning-rate scheduler, or None when not training
        feeds(dict): dict of model input variables
        optimizer: the (decorated) optimizer, or None when not training
    """
    with paddle.static.program_guard(main_prog, startup_prog):
        with paddle.utils.unique_name.guard():
            mode = "Train" if is_train else "Eval"
            use_mix = "batch_transform_ops" in config["DataLoader"][mode][
                "dataset"]
            feeds = create_feeds(
                config["Global"]["image_shape"],
                use_mix=use_mix,
                dtype="float32")

            # build model
            # data_format should be assigned in arch-dict
            model = build_model(config["Arch"])
            out = model(feeds["data"])

            fetchs = create_fetchs(
                out,
                feeds,
                config["Arch"],
                epsilon=config.get('ls_epsilon'),
                use_mix=use_mix,
                config=config,
                mode=mode)
            lr_scheduler = None
            optimizer = None
            if is_train:
                optimizer, lr_scheduler = create_optimizer(config,
                                                           step_each_epoch)
                optimizer = mixed_precision_optimizer(config, optimizer)
                if is_distributed:
                    optimizer = dist_optimizer(config, optimizer)
                optimizer.minimize(fetchs['loss'][0])
    return fetchs, lr_scheduler, feeds, optimizer


def compile(config, program, loss_name=None, share_prog=None):
    """
    Wrap a static program in a data-parallel CompiledProgram.

    The build and exec strategies are taken from create_strategy(config).

    Args:
        config(dict): config
        program(): the program to compile
        loss_name(str): name of the loss variable
        share_prog(): program whose variables are shared, used for
            evaluation during training

    Returns:
        compiled_program(): a compiled program
    """
    strategies = create_strategy(config)
    parallel_kwargs = dict(
        share_vars_from=share_prog,
        loss_name=loss_name,
        build_strategy=strategies[0],
        exec_strategy=strategies[1])
    return paddle.static.CompiledProgram(program).with_data_parallel(
        **parallel_kwargs)


# Module-level step counter shared by all run() calls. It is incremented
# once per batch while a vdl_writer is given and is used as the step when
# logging the loss.
total_step = 0


def run(dataloader,
        exe,
        program,
        feeds,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None,
        lr_scheduler=None):
    """
    Feed data to the model, fetch the measures and loss, and log progress.

    Args:
        dataloader: when config["Global"]["use_dali"] is set, an iterator
            consumed with next() and reset() at the end; otherwise a
            callable whose call returns an iterator of batches
        exe(): executor used to run ``program``
        program(): the (compiled) program to run
        feeds(dict): ordered dict of input variables; in the non-DALI path
            the i-th element of each batch is fed to the i-th feed
        fetchs(dict): name -> (variable, AverageMeter), as built by
            create_fetchs
        epoch(int): epoch of training or evaluation, used for logging only
        mode(str): 'train' or 'eval'; controls lr logging and log format
        config(dict): global config; reads Global.use_dali and
            print_interval (default 10)
        vdl_writer: optional VisualDL writer; the loss of every batch is
            logged with the module-level ``total_step`` counter
        lr_scheduler: optional scheduler, stepped once per batch

    Returns:
        In 'eval' mode, the running average of the 'top1' fetch (used to
        save the best model); otherwise None.
    """
    global total_step

    fetch_list = [f[0] for f in fetchs.values()]
    metric_dict = OrderedDict([("lr", AverageMeter(
        'lr', 'f', postfix=",", need_avg=False))])

    for k in fetchs:
        metric_dict[k] = fetchs[k][1]

    metric_dict["batch_time"] = AverageMeter(
        'batch_cost', '.5f', postfix=" s,")
    metric_dict["reader_time"] = AverageMeter(
        'reader_cost', '.5f', postfix=" s,")

    for m in metric_dict.values():
        m.reset()

    use_dali = config["Global"].get('use_dali', False)
    print_interval = config.get('print_interval', 10)
    tic = time.time()

    if not use_dali:
        dataloader = dataloader()

    idx = 0
    batch_size = None
    while True:
        # DALI may raise RuntimeError for some particular images, such as
        # ImageNet1k/n04418357_26036.JPEG; skip them and read again.
        try:
            batch = next(dataloader)
        except StopIteration:
            break
        except RuntimeError:
            logger.warning(
                "Except RuntimeError when reading data from dataloader, try to read once again..."
            )
            continue
        idx += 1
        # ignore the warmup iters
        if idx == 5:
            metric_dict["batch_time"].reset()
            metric_dict["reader_time"].reset()

        metric_dict['reader_time'].update(time.time() - tic)

        if use_dali:
            batch_size = batch[0]["data"].shape()[0]
            feed_dict = batch[0]
        else:
            batch_size = batch[0].shape()[0]
            feed_dict = {
                var.name: batch[pos]
                for pos, var in enumerate(feeds.values())
            }

        metrics = exe.run(program=program,
                          feed=feed_dict,
                          fetch_list=fetch_list)

        for name, m in zip(fetchs.keys(), metrics):
            metric_dict[name].update(np.mean(m), batch_size)
        metric_dict["batch_time"].update(time.time() - tic)
        # lr_scheduler is optional (default None); only log lr if present.
        if mode == "train" and lr_scheduler is not None:
            metric_dict['lr'].update(lr_scheduler.get_lr())

        fetchs_str = ' '.join([
            str(metric_dict[key].mean)
            if "time" in key else str(metric_dict[key].value)
            for key in metric_dict
        ])
        ips_info = " ips: {:.5f} images/sec.".format(
            batch_size / metric_dict["batch_time"].avg)
        fetchs_str += ips_info

        if lr_scheduler is not None:
            lr_scheduler.step()

        if vdl_writer:
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1
        if idx % print_interval == 0:
            if mode == 'eval':
                logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
                                                           fetchs_str))
            else:
                epoch_str = "epoch:{:<3d}".format(epoch)
                step_str = "{:s} step:{:<4d}".format(mode, idx)
                logger.info("{:s} {:s} {:s}".format(epoch_str, step_str,
                                                    fetchs_str))

        tic = time.time()

    end_str = ' '.join([str(m.mean) for m in metric_dict.values()] +
                       [metric_dict["batch_time"].total])
    batch_time = metric_dict["batch_time"]
    if batch_size is not None and batch_time.sum > 0:
        ips_info = "ips: {:.5f} images/sec.".format(
            batch_size * batch_time.count / batch_time.sum)
    else:
        # No batch was processed (e.g. empty dataloader): throughput is
        # undefined, and `None * count` would raise TypeError.
        ips_info = "ips: N/A"
    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))
    if use_dali:
        dataloader.reset()

    # return top1_acc in order to save the best model
    if mode == 'eval':
        return fetchs["top1"][1].avg