program.py 12.8 KB
Newer Older
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
W
WuHaobo 已提交
2
#
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
W
WuHaobo 已提交
6 7 8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WuHaobo 已提交
14 15 16 17 18 19 20

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
L
littletomatodonkey 已提交
21
import datetime
W
WuHaobo 已提交
22 23
from collections import OrderedDict

littletomatodonkey's avatar
littletomatodonkey 已提交
24
import paddle
littletomatodonkey's avatar
littletomatodonkey 已提交
25 26 27
from paddle import to_tensor
import paddle.nn as nn
import paddle.nn.functional as F
W
WuHaobo 已提交
28 29 30 31 32 33

from ppcls.optimizer import LearningRateBuilder
from ppcls.optimizer import OptimizerBuilder
from ppcls.modeling import architectures
from ppcls.modeling.loss import CELoss
from ppcls.modeling.loss import MixCELoss
littletomatodonkey's avatar
littletomatodonkey 已提交
34
from ppcls.modeling.loss import JSDivLoss
W
WuHaobo 已提交
35 36 37 38 39
from ppcls.modeling.loss import GoogLeNetLoss
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger


W
WuHaobo 已提交
40
def create_model(architecture, classes_num):
    """
    Create a model

    Args:
        architecture(dict): architecture information. "name" (such as
            ResNet50) is required; the optional "params" dict is forwarded
            as keyword arguments to the architecture constructor.
        classes_num(int): num of classes, passed as ``class_dim``

    Returns:
        the network built by the constructor named ``architecture["name"]``
        in ``ppcls.modeling.architectures``

    Raises:
        KeyError: if "name" is missing from ``architecture`` or does not
            name an entry of ``ppcls.modeling.architectures``
    """
    name = architecture["name"]
    # "params" may be present but None (e.g. an empty config section);
    # treat that the same as "no extra params".
    params = architecture.get("params") or {}
    try:
        model_builder = architectures.__dict__[name]
    except KeyError:
        raise KeyError("architecture '{}' is not defined in "
                       "ppcls.modeling.architectures".format(name)) from None
    return model_builder(class_dim=classes_num, **params)
W
WuHaobo 已提交
56 57


58 59
def create_loss(feeds,
                out,
                architecture,
                classes_num=1000,
                epsilon=None,
                use_mix=False,
                use_distillation=False):
    """
    Build the optimization loss and evaluate it on the model output.

    Depending on the arguments this is one of:
        1. CrossEntropy loss
        2. CrossEntropy loss with label smoothing
        3. CrossEntropy loss with mix (mixup, cutmix, fmix)
        4. CrossEntropy loss with label smoothing and mix
        5. GoogLeNet loss (model has three outputs)
        6. JS-divergence loss for distillation (model has two outputs)

    Args:
        feeds(dict): dict of model input variables; must contain "label",
            or "y_a", "y_b" and "lam" when use_mix is True
        out(variable): model output variable; a sequence of 3 outputs for
            GoogLeNet, or of 2 outputs (out[1] being the student output)
            when use_distillation is True
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        classes_num(int): num of classes
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        use_distillation(bool): whether to use distillation training

    Returns:
        loss(variable): loss variable
    """
    # GoogLeNet takes precedence over every other option.
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        googlenet_loss = GoogLeNetLoss(class_dim=classes_num, epsilon=epsilon)
        return googlenet_loss(out[0], out[1], out[2], feeds["label"])

    if use_distillation:
        assert len(out) == 2, ("distillation output length must be 2, "
                               "but got {}".format(len(out)))
        js_div_loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
        return js_div_loss(out[1], out[0])

    if not use_mix:
        ce_loss = CELoss(class_dim=classes_num, epsilon=epsilon)
        return ce_loss(out, feeds["label"])

    # Mix augmentation: two label sets blended by the factor "lam".
    mix_loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
    return mix_loss(out, feeds['y_a'], feeds['y_b'], feeds['lam'])
W
WuHaobo 已提交
105 106


W
WuHaobo 已提交
107
def create_metric(out,
                  label,
                  architecture,
                  topk=5,
                  classes_num=1000,
                  use_distillation=False,
                  mode="train"):
    """
    Create measures of model accuracy, such as top1 and top5

    Args:
        out(variable): model output variable. For GoogLeNet it must hold
            3 outputs and only out[0] is scored; with use_distillation it
            holds 2 outputs and only the student output out[1] is scored.
        label(variable): ground-truth labels
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        topk(int): usually top5; clamped to classes_num
        classes_num(int): num of classes
        use_distillation(bool): whether to use distillation training
        mode(str): mode, train/valid. Outside "train", when more than one
            card is used, the metrics are averaged over all cards.

    Returns:
        fetchs(dict): OrderedDict with keys "top1" and "top{k}", where
            k = min(topk, classes_num)
    """
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        out = out[0]
    elif use_distillation:
        # just need student label to get metrics
        out = out[1]
    softmax_out = F.softmax(out)

    # top-k accuracy with k larger than the number of classes is meaningless
    k = min(topk, classes_num)
    top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
    topk_acc = paddle.metric.accuracy(softmax_out, label=label, k=k)

    # multi cards' eval
    if mode != "train" and paddle.distributed.get_world_size() > 1:
        world_size = paddle.distributed.get_world_size()
        # all_reduce sums its argument in place. Depending on the Paddle
        # version it returns the tensor, None, or a task object, so do not
        # use its return value.
        paddle.distributed.all_reduce(
            top1, op=paddle.distributed.ReduceOp.SUM)
        paddle.distributed.all_reduce(
            topk_acc, op=paddle.distributed.ReduceOp.SUM)
        top1 = top1 / world_size
        topk_acc = topk_acc / world_size

    fetchs = OrderedDict()
    fetchs['top1'] = top1
    fetchs['top{}'.format(k)] = topk_acc
    return fetchs


littletomatodonkey's avatar
littletomatodonkey 已提交
160
def create_fetchs(feeds, net, config, mode="train"):
    """
    Run the network on one batch and collect its loss and measures.

    create_loss is always called; create_metric is called only when mix
    augmentation is NOT in effect for this mode.

    Args:
        feeds(dict): dict of model input variables (see create_feeds).
            "image" is always present; it holds "label" normally, or
            "y_a"/"y_b"/"lam" when mix is used (and then no "label").
        net: the network; called as ``net(feeds["image"])``
        config: global config; reads ARCHITECTURE, topk, classes_num and
            the optional keys ls_epsilon, use_mix and use_distillation
        mode(str): train/valid/eval; use_mix only takes effect in "train"

    Returns:
        fetchs(dict): OrderedDict holding "loss" first, followed by the
            accuracy measures when they are computed
    """
    # Read every config entry up front, before running the network.
    architecture = config.ARCHITECTURE
    topk, classes_num = config.topk, config.classes_num
    epsilon = config.get('ls_epsilon')
    mixing = config.get('use_mix') and mode == 'train'
    distilling = config.get('use_distillation')

    logits = net(feeds["image"])

    fetchs = OrderedDict()
    fetchs['loss'] = create_loss(feeds, logits, architecture, classes_num,
                                 epsilon, mixing, distilling)
    if mixing:
        return fetchs

    fetchs.update(
        create_metric(
            logits,
            feeds["label"],
            architecture,
            topk,
            classes_num,
            distilling,
            mode=mode))
    return fetchs


W
WuHaobo 已提交
205
def create_optimizer(config, parameter_list=None):
    """
    Create an optimizer using config, usually including
    learning rate and regularization.

    Args:
        config(dict): global config. Besides the two sections shown below it
            must provide 'epochs', 'total_images' and
            'TRAIN': {'batch_size': ...}, from which the number of steps per
            epoch is derived for the learning-rate builder. For example:
        {
            'LEARNING_RATE':
                {'function': 'Cosine',
                 'params': {'lr': 0.1}
                },
            'OPTIMIZER':
                {'function': 'Momentum',
                 'params':{'momentum': 0.9},
                 'regularizer':
                    {'function': 'L2', 'factor': 0.0001}
                }
        }
        parameter_list: parameters to optimize, forwarded to the optimizer

    Returns:
        (optimizer, lr): the optimizer instance and the learning-rate
        object built by LearningRateBuilder that the optimizer uses
    """
    # create learning_rate instance. Work on copies so that the caller's
    # config is not mutated, and tolerate a missing/empty 'params' section.
    lr_config = dict(config['LEARNING_RATE'])
    lr_params = dict(lr_config.get('params') or {})
    lr_params.update({
        'epochs': config['epochs'],
        'step_each_epoch':
        config['total_images'] // config['TRAIN']['batch_size'],
    })
    lr_config['params'] = lr_params
    lr = LearningRateBuilder(**lr_config)()

    # create optimizer instance
    opt = OptimizerBuilder(**config['OPTIMIZER'])
    return opt(lr, parameter_list), lr
W
WuHaobo 已提交
241 242


243
def create_feeds(batch, use_mix):
    """
    Pack one dataloader batch into the feed dict used by create_fetchs.

    Args:
        batch: sequence whose first item is the image batch. Without mix the
            second item holds the labels; with mix, items 1-3 hold y_a, y_b
            and lam. Non-image items must provide ``.numpy()``.
        use_mix(bool): whether the batch carries mix (mixup/cutmix/fmix)
            targets

    Returns:
        dict: {"image", "label"} without mix, or
        {"image", "y_a", "y_b", "lam"} with mix. Every non-image entry is
        converted to a tensor of shape (-1, 1).
    """
    def _as_column(data, dtype):
        # cast on the host via numpy, reshape to one column, re-wrap
        return to_tensor(data.numpy().astype(dtype).reshape(-1, 1))

    if not use_mix:
        return {"image": batch[0], "label": _as_column(batch[1], 'int64')}
    return {
        "image": batch[0],
        "y_a": _as_column(batch[1], "int64"),
        "y_b": _as_column(batch[2], "int64"),
        "lam": _as_column(batch[3], "float32"),
    }


T
Tingquan Gao 已提交
256 257 258
# Global step counter for VisualDL scalars; it keeps increasing across
# calls to run() so successive epochs do not overwrite each other.
total_step = 0


def _build_metric_list(config, use_mix):
    """Create the ordered AverageMeters that run() updates and prints."""
    metric_list = [
        ("loss", AverageMeter(
            'loss', '7.5f', postfix=",")),
        ("lr", AverageMeter(
            'lr', 'f', postfix=",", need_avg=False)),
        ("batch_time", AverageMeter(
            'batch_cost', '.5f', postfix=" s,")),
        ("reader_time", AverageMeter(
            'reader_cost', '.5f', postfix=" s,")),
    ]
    if not use_mix:
        # Must match the key create_metric emits, which clamps topk to
        # classes_num; a mismatch makes the meter update raise KeyError.
        topk_name = 'top{}'.format(min(config.topk, config.classes_num))
        metric_list.insert(
            0, (topk_name, AverageMeter(
                topk_name, '.5f', postfix=",")))
        metric_list.insert(
            0, ("top1", AverageMeter(
                "top1", '.5f', postfix=",")))
    return OrderedDict(metric_list)


def _step_lr_scheduler(lr_scheduler, epoch, idx):
    """Advance lr_scheduler after training step idx of the given epoch."""
    if not lr_scheduler.update_specified:
        lr_scheduler.step()
        return
    curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
    update = max(
        0, curr_global_counter - lr_scheduler.update_start_step
    ) % lr_scheduler.update_step_interval == 0
    if update:
        lr_scheduler.step()


def run(dataloader,
        config,
        net,
        optimizer=None,
        lr_scheduler=None,
        epoch=0,
        mode='train',
        vdl_writer=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle dataloader): called as ``dataloader()`` to iterate
            batches; ``len(dataloader)`` is used for the training ETA
        config: global config (print_interval, use_mix, topk, classes_num,
            epochs, ...)
        net: the network
        optimizer: optimizer used for backward/step; required in 'train'
        lr_scheduler: optional; stepped after every training step
        epoch(int): epoch of training or validation
        mode(str): 'train', 'valid' or 'eval'
        vdl_writer: optional writer passed to ``logger.scaler`` for
            train-mode scalars (presumably a VisualDL LogWriter)

    Returns:
        the average top1 accuracy when mode == 'valid' (used to save the
        best model), otherwise None
    """
    global total_step

    print_interval = config.get("print_interval", 10)
    use_mix = config.get("use_mix", False) and mode == "train"
    metric_list = _build_metric_list(config, use_mix)

    # Stays 0 if the dataloader yields nothing, so the summary below
    # cannot hit an undefined name.
    batch_size = 0
    tic = time.time()
    for idx, batch in enumerate(dataloader()):
        # avoid statistics from warmup time
        if idx == 10:
            metric_list["batch_time"].reset()
            metric_list["reader_time"].reset()

        metric_list['reader_time'].update(time.time() - tic)
        batch_size = len(batch[0])
        feeds = create_feeds(batch, use_mix)
        fetchs = create_fetchs(feeds, net, config, mode)
        if mode == 'train':
            avg_loss = fetchs['loss']
            avg_loss.backward()

            optimizer.step()
            optimizer.clear_grad()
            lr_value = optimizer._global_learning_rate().numpy()[0]
            metric_list['lr'].update(lr_value, batch_size)

            if lr_scheduler is not None:
                _step_lr_scheduler(lr_scheduler, epoch, idx)

        # Convert each fetched tensor to a scalar once; reused below.
        fetch_values = OrderedDict(
            (name, fetch.numpy()[0]) for name, fetch in fetchs.items())
        for name, value in fetch_values.items():
            metric_list[name].update(value, batch_size)
        metric_list["batch_time"].update(time.time() - tic)
        tic = time.time()

        if vdl_writer and mode == "train":
            logger.scaler(
                name="lr", value=lr_value, step=total_step, writer=vdl_writer)
            for name, value in fetch_values.items():
                logger.scaler(
                    name="train_{}".format(name),
                    value=value,
                    step=total_step,
                    writer=vdl_writer)
            total_step += 1

        if idx % print_interval == 0:
            # Only format the meters on steps that are actually logged.
            fetchs_str = ' '.join([
                str(metric_list[key].mean)
                if "time" in key else str(metric_list[key].value)
                for key in metric_list
            ])
            ips_info = "ips: {:.5f} images/sec".format(
                batch_size / metric_list["batch_time"].avg)

            if mode == "train":
                epoch_str = "epoch:{:<3d}".format(epoch)
                step_str = "{:s} step:{:<4d}".format(mode, idx)
                eta_sec = ((config["epochs"] - epoch) * len(dataloader) - idx
                           ) * metric_list["batch_time"].avg
                eta_str = "eta: {:s}".format(
                    str(datetime.timedelta(seconds=int(eta_sec))))
                logger.info("{:s}, {:s}, {:s} {:s}, {:s}".format(
                    epoch_str, step_str, fetchs_str, ips_info, eta_str))
            else:
                logger.info("{:s} step:{:<4d}, {:s} {:s}".format(
                    mode, idx, fetchs_str, ips_info))

    end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
                       [metric_list['batch_time'].total])
    batch_time = metric_list["batch_time"]
    # Approximation based on the last batch size; 0 when nothing was timed.
    ips = (batch_size * batch_time.count / batch_time.sum
           if batch_time.sum > 0 else 0.0)
    ips_info = "ips: {:.5f} images/sec.".format(ips)

    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return metric_list['top1'].avg