classifier.py 18.5 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
import numpy as np
import time
import math
import tqdm
import paddle.fluid as fluid
import paddlex.utils.logging as logging
from paddlex.utils import seconds_to_hms
import paddlex
from collections import OrderedDict
from .base import BaseAPI


class BaseClassifier(BaseAPI):
    """Build a classifier and implement its training, evaluation, prediction
    and model export.

    Args:
        model_name (str): Name of the classification network. The documented
            choices are ['ResNet18', 'ResNet34', 'ResNet50', 'ResNet101',
            'ResNet50_vd', 'ResNet101_vd', 'DarkNet53', 'MobileNetV1',
            'MobileNetV2', 'Xception41', 'Xception65', 'Xception71'].
            Defaults to 'ResNet50'.
            NOTE(review): the constructor accepts any name whose lower-cased
            form is an attribute of ``paddlex.cv.nets``, so this list may be
            incomplete.
        num_classes (int): Number of classes. Defaults to 1000.
    """

39
    def __init__(self, model_name='ResNet50', num_classes=1000):
        """Record the constructor arguments and validate the requested network.

        Args:
            model_name (str): Name of the network. Its lower-cased form must be
                an attribute of ``paddlex.cv.nets``. Defaults to 'ResNet50'.
            num_classes (int): Number of classes. Defaults to 1000.

        Raises:
            Exception: If ``paddlex.cv.nets`` has no attribute matching the
                lower-cased ``model_name``.
        """
        # Keep this as the first statement: locals() is captured before any
        # other local name is bound, so it holds only the call arguments
        # (including ``self``).
        self.init_params = locals()
        super(BaseClassifier, self).__init__('classifier')
        if not hasattr(paddlex.cv.nets, str.lower(model_name)):
            raise Exception("ERROR: There's no model named {}.".format(
                model_name))
        self.model_name = model_name
        # Filled in from ``train_dataset.labels`` by train().
        self.labels = None
        self.num_classes = num_classes
        # When not None, build_net() reads indices 1 and 0 of this value for
        # the last two dimensions of the image input.
        self.fixed_input_shape = None
J
jiangjiajun 已提交
49 50

    def build_net(self, mode='train'):
        """Build the classification network for the given mode.

        Args:
            mode (str): 'train', 'eval' or 'test'. Any mode other than 'test'
                also creates a label input plus the loss and accuracy ops;
                'train' additionally calls ``self.optimizer.minimize``.
                Defaults to 'train'.

        Returns:
            tuple (inputs, outputs): Two OrderedDicts.
                In 'test' mode ``inputs`` holds 'image' and ``outputs`` holds
                'predict' (the softmax output); the raw network output is also
                stored in ``self.interpretation_feats['logits']``.
                Otherwise ``inputs`` holds 'image' and 'label', and
                ``outputs`` holds 'loss' (omitted in 'eval' mode), 'acc1' and
                'acc{k}' with k = min(5, num_classes).
        """
        is_test = mode == 'test'

        # Image input: the last two dims are fixed only when a fixed input
        # shape has been configured.
        fixed_shape = self.fixed_input_shape
        if fixed_shape is None:
            image_shape = [None, 3, None, None]
        else:
            image_shape = [None, 3, fixed_shape[1], fixed_shape[0]]
        image = fluid.data(dtype='float32', shape=image_shape, name='image')
        if not is_test:
            label = fluid.data(dtype='int64', shape=[None, 1], name='label')

        net_fn = getattr(paddlex.cv.nets, str.lower(self.model_name))
        logits = net_fn(image, num_classes=self.num_classes)
        probs = fluid.layers.softmax(logits, use_cudnn=False)

        if is_test:
            self.interpretation_feats = OrderedDict([('logits', logits)])
            inputs = OrderedDict([('image', image)])
            outputs = OrderedDict([('predict', probs)])
            return inputs, outputs

        cost = fluid.layers.cross_entropy(input=probs, label=label)
        avg_cost = fluid.layers.mean(cost)
        acc1 = fluid.layers.accuracy(input=probs, label=label, k=1)
        k = min(5, self.num_classes)
        acck = fluid.layers.accuracy(input=probs, label=label, k=k)
        if mode == 'train':
            self.optimizer.minimize(avg_cost)

        inputs = OrderedDict([('image', image), ('label', label)])
        outputs = OrderedDict()
        # The loss is reported in every non-test mode except 'eval'.
        if mode != 'eval':
            outputs['loss'] = avg_cost
        outputs['acc1'] = acc1
        outputs['acc{}'.format(k)] = acck
        return inputs, outputs

J
jiangjiajun 已提交
83 84
    def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
                          lr_decay_epochs, lr_decay_gamma,
                          num_steps_each_epoch):
        """Build the default optimizer: Momentum on a piecewise-decayed
        learning rate, optionally preceded by a linear warm-up.

        Args:
            learning_rate (float): Initial learning rate.
            warmup_steps (int): Number of warm-up steps. Warm-up is applied
                only when this is greater than 0.
            warmup_start_lr (float): Learning rate at the start of warm-up.
            lr_decay_epochs (list): Epochs at which the learning rate is
                multiplied by ``lr_decay_gamma``. May be empty, in which case
                the learning rate is never decayed.
            lr_decay_gamma (float): Multiplicative decay factor.
            num_steps_each_epoch (int): Number of training steps per epoch,
                used to convert decay epochs into step boundaries.

        Returns:
            fluid.optimizer.Momentum: Optimizer with momentum 0.9 and
            L2Decay(1e-04) regularization.
        """
        boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
        values = [
            learning_rate * (lr_decay_gamma**i)
            for i in range(len(lr_decay_epochs) + 1)
        ]
        lr_decay = fluid.layers.piecewise_decay(
            boundaries=boundaries, values=values)
        if warmup_steps > 0:
            # Warm-up must finish before the first decay boundary. With no
            # boundaries there is nothing to check; indexing
            # lr_decay_epochs[0] here used to raise IndexError in that case.
            if boundaries and warmup_steps > boundaries[0]:
                logging.error(
                    "In function train(), parameters should satisfy: warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
                    exit=False)
                logging.error(
                    "See this doc for more information: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
                    exit=False)
                # Unlike the two calls above, this one does not pass
                # exit=False -- presumably this is where the logger terminates
                # the program (TODO confirm in paddlex.utils.logging).
                logging.error(
                    "warmup_steps should less than {} or lr_decay_epochs[0] greater than {}, please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
                    format(lr_decay_epochs[0] * num_steps_each_epoch,
                           warmup_steps // num_steps_each_epoch))

            lr_decay = fluid.layers.linear_lr_warmup(
                learning_rate=lr_decay,
                warmup_steps=warmup_steps,
                start_lr=warmup_start_lr,
                end_lr=learning_rate)
        optimizer = fluid.optimizer.Momentum(
            lr_decay,
            momentum=0.9,
            regularization=fluid.regularizer.L2Decay(1e-04))
        return optimizer

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights='IMAGENET',
              optimizer=None,
              learning_rate=0.025,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=[30, 60, 90],
              lr_decay_gamma=0.1,
              use_vdl=False,
              sensitivities_file=None,
              eval_metric_loss=0.05,
              early_stop=False,
              early_stop_patience=5,
              resume_checkpoint=None):
        """Train the classifier.

        Args:
            num_epochs (int): Number of training epochs.
            train_dataset (paddlex.datasets): Training data reader.
            train_batch_size (int): Training batch size; also used as the
                validation batch size. Defaults to 64.
            eval_dataset (paddlex.datasets): Validation data reader.
            save_interval_epochs (int): Interval, in epochs, between model
                saves. Defaults to 1.
            log_interval_steps (int): Interval, in steps, between training log
                outputs. Defaults to 2.
            save_dir (str): Directory the model is saved to.
            pretrain_weights (str): If a path, the pretrained model under that
                path is loaded; if 'IMAGENET', weights pretrained on ImageNet
                are downloaded automatically; if None, no pretrained model is
                used. Defaults to 'IMAGENET'.
            optimizer (paddle.fluid.optimizer): Optimizer. When None, the
                default optimizer is used: fluid.layers.piecewise_decay
                schedule with fluid.optimizer.Momentum.
            learning_rate (float): Initial learning rate of the default
                optimizer. Defaults to 0.025.
            warmup_steps (int): Number of steps over which the learning rate
                rises from warmup_start_lr to learning_rate. Defaults to 0.
            warmup_start_lr (float): Learning rate at the start of warm-up.
                Defaults to 0.0.
            lr_decay_epochs (list): Epochs at which the default optimizer
                decays the learning rate. Defaults to [30, 60, 90].
            lr_decay_gamma (float): Learning-rate decay factor of the default
                optimizer. Defaults to 0.1.
            use_vdl (bool): Whether to use VisualDL for visualization.
                Defaults to False.
            sensitivities_file (str): If a path, the sensitivity information
                under that path is loaded for pruning; if 'DEFAULT',
                sensitivity information obtained on ImageNet is downloaded
                automatically for pruning; if None, no pruning is done.
                Defaults to None.
            eval_metric_loss (float): Tolerable accuracy loss. Defaults to
                0.05.
            early_stop (bool): Whether to use early stopping. Defaults to
                False.
            early_stop_patience (int): With early stopping enabled, training
                stops when validation accuracy keeps dropping or stays flat
                for `early_stop_patience` consecutive epochs. Defaults to 5.
            resume_checkpoint (str): Path of the model saved by a previous run
                to resume from. If None, training is not resumed. Defaults to
                None.

        Raises:
            ValueError: The model was loaded from an inference model.
        """
        if not self.trainable:
            raise ValueError("Model is not trainable from load_model method.")

        self.labels = train_dataset.labels

        # Use the caller's optimizer when given; otherwise build the default
        # one from the learning-rate schedule arguments.
        if optimizer is not None:
            self.optimizer = optimizer
        else:
            steps_per_epoch = train_dataset.num_samples // train_batch_size
            self.optimizer = self.default_optimizer(
                learning_rate=learning_rate,
                warmup_steps=warmup_steps,
                warmup_start_lr=warmup_start_lr,
                lr_decay_epochs=lr_decay_epochs,
                lr_decay_gamma=lr_decay_gamma,
                num_steps_each_epoch=steps_per_epoch)

        # Build the train / eval / test programs.
        self.build_program()

        # Initialize the network weights.
        self.net_initialize(
            startup_prog=fluid.default_startup_program(),
            pretrain_weights=pretrain_weights,
            save_dir=save_dir,
            sensitivities_file=sensitivities_file,
            eval_metric_loss=eval_metric_loss,
            resume_checkpoint=resume_checkpoint)

        # Run the training loop.
        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            use_vdl=use_vdl,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience)
J
jiangjiajun 已提交
202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219

    def evaluate(self,
                 eval_dataset,
                 batch_size=1,
                 epoch_id=None,
                 return_details=False):
        """Evaluate the model on a dataset.

        Args:
            eval_dataset (paddlex.datasets): Validation data reader.
            batch_size (int): Validation batch size. Defaults to 1.
            epoch_id (int): Training epoch the evaluated model belongs to;
                here it is only written to the debug log.
            return_details (bool): Whether to also return per-sample details.
                Defaults to False.

        Returns:
            OrderedDict: When ``return_details`` is False, the metrics keyed
                by 'acc1' and 'acc{k}' with k = min(5, num_classes), i.e. the
                top-1 and top-k accuracy.
            tuple (metrics, eval_details): When ``return_details`` is True,
                additionally a dict with keys 'true_labels' and 'pred_scores'
                holding the ground-truth class ids and the per-class
                predicted scores.
        """
        self.arrange_transforms(transforms=eval_dataset.transforms, mode='eval')
        data_generator = eval_dataset.generator(
            batch_size=batch_size, drop_last=False)
        k = min(5, self.num_classes)
        total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
        true_labels = list()
        pred_scores = list()
        if not hasattr(self, 'parallel_test_prog'):
            self.parallel_test_prog = fluid.CompiledProgram(
                self.test_prog).with_data_parallel(
                    share_vars_from=self.parallel_train_prog)
        # The returned per-card batch size is not needed here. The call is
        # kept in case the helper (defined outside this file) validates
        # batch_size -- TODO confirm.
        self._get_single_card_bs(batch_size)
        logging.info("Start to evaluating(total_samples={}, total_steps={})...".
                     format(eval_dataset.num_samples, total_steps))
        for step, data in tqdm.tqdm(
                enumerate(data_generator()), total=total_steps):
            images = np.array([d[0] for d in data]).astype('float32')
            labels = [d[1] for d in data]
            num_samples = images.shape[0]
            if num_samples < batch_size:
                # Pad a short batch up to batch_size by repeating its first
                # image; the padded rows are sliced off again after inference.
                num_pad_samples = batch_size - num_samples
                pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1))
                images = np.concatenate([images, pad_images])
            outputs = self.exe.run(self.parallel_test_prog,
                                   feed={'image': images},
                                   fetch_list=list(self.test_outputs.values()))
            batch_scores = outputs[0][:num_samples]
            true_labels.extend(labels)
            pred_scores.extend(batch_scores.tolist())
            logging.debug("[EVAL] Epoch={}, Step={}/{}".format(epoch_id, step +
                                                               1, total_steps))

        # Sort class indices by score once and reuse the result for both the
        # top-1 and the top-k predictions.
        ranked_labels = np.argsort(pred_scores)
        pred_top1_label = ranked_labels[:, -1]
        pred_topk_label = ranked_labels[:, -k:]
        acc1 = sum(pred_top1_label == true_labels) / len(true_labels)
        acck = sum(
            [np.isin(x, y)
             for x, y in zip(true_labels, pred_topk_label)]) / len(true_labels)
        metrics = OrderedDict([('acc1', acc1), ('acc{}'.format(k), acck)])
        if return_details:
            eval_details = {
                'true_labels': true_labels,
                'pred_scores': pred_scores
            }
            return metrics, eval_details
        return metrics

    def predict(self, img_file, transforms=None, topk=1):
        """Predict the classes of a single image.

        Args:
            img_file (str): Path of the image to predict.
            transforms (paddlex.cls.transforms): Preprocessing operations.
                When None, ``self.test_transforms`` is used.
            topk (int): Number of highest-scoring classes to return; capped
                at ``num_classes``.

        Returns:
            list: One dict per returned class, ordered from highest to lowest
            score, with keys 'category_id', 'category' and 'score' holding
            the predicted class id, class label and score.

        Raises:
            Exception: ``transforms`` is None and the model has no
                ``test_transforms`` attribute.
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise Exception("transforms need to be defined, now is None.")
        true_topk = min(self.num_classes, topk)

        if transforms is None:
            self.arrange_transforms(
                transforms=self.test_transforms, mode='test')
            im = self.test_transforms(img_file)
        else:
            self.arrange_transforms(transforms=transforms, mode='test')
            im = transforms(img_file)

        result = self.exe.run(self.test_prog,
                              feed={'image': im},
                              fetch_list=list(self.test_outputs.values()),
                              use_program_cache=True)

        # Class ids ordered by descending score, truncated to the top-k.
        ranked_ids = np.argsort(result[0][0])[::-1][:true_topk]
        res = []
        for class_id in ranked_ids:
            res.append({
                'category_id': class_id,
                'category': self.labels[class_id],
                'score': result[0][0][class_id]
            })
        return res
S
sunyanfang01 已提交
298

J
jiangjiajun 已提交
299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333

class ResNet18(BaseClassifier):
    """Classifier preset that passes model_name='ResNet18' to BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(ResNet18, self).__init__(
            model_name='ResNet18',
            num_classes=num_classes)


class ResNet34(BaseClassifier):
    """Classifier preset that passes model_name='ResNet34' to BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(ResNet34, self).__init__(
            model_name='ResNet34',
            num_classes=num_classes)


class ResNet50(BaseClassifier):
    """Classifier preset that passes model_name='ResNet50' to BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(ResNet50, self).__init__(
            model_name='ResNet50',
            num_classes=num_classes)


class ResNet101(BaseClassifier):
    """Classifier preset that passes model_name='ResNet101' to BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(ResNet101, self).__init__(
            model_name='ResNet101',
            num_classes=num_classes)


class ResNet50_vd(BaseClassifier):
    """Classifier preset that passes model_name='ResNet50_vd' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(ResNet50_vd, self).__init__(
            model_name='ResNet50_vd',
            num_classes=num_classes)


class ResNet101_vd(BaseClassifier):
    """Classifier preset that passes model_name='ResNet101_vd' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(ResNet101_vd, self).__init__(
            model_name='ResNet101_vd',
            num_classes=num_classes)
J
jiangjiajun 已提交
334 335


S
sunyanfang01 已提交
336 337
class ResNet50_vd_ssld(BaseClassifier):
    """Classifier preset that passes model_name='ResNet50_vd_ssld' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(ResNet50_vd_ssld, self).__init__(
            model_name='ResNet50_vd_ssld',
            num_classes=num_classes)


S
sunyanfang01 已提交
342 343
class ResNet101_vd_ssld(BaseClassifier):
    """Classifier preset that passes model_name='ResNet101_vd_ssld' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(ResNet101_vd_ssld, self).__init__(
            model_name='ResNet101_vd_ssld',
            num_classes=num_classes)
J
jiangjiajun 已提交
346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375


class DarkNet53(BaseClassifier):
    """Classifier preset that passes model_name='DarkNet53' to BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(DarkNet53, self).__init__(
            model_name='DarkNet53',
            num_classes=num_classes)


class MobileNetV1(BaseClassifier):
    """Classifier preset that passes model_name='MobileNetV1' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(MobileNetV1, self).__init__(
            model_name='MobileNetV1',
            num_classes=num_classes)


class MobileNetV2(BaseClassifier):
    """Classifier preset that passes model_name='MobileNetV2' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(MobileNetV2, self).__init__(
            model_name='MobileNetV2',
            num_classes=num_classes)


class MobileNetV3_small(BaseClassifier):
    """Classifier preset that passes model_name='MobileNetV3_small' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(MobileNetV3_small, self).__init__(
            model_name='MobileNetV3_small',
            num_classes=num_classes)


class MobileNetV3_large(BaseClassifier):
    """Classifier preset that passes model_name='MobileNetV3_large' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(MobileNetV3_large, self).__init__(
            model_name='MobileNetV3_large',
            num_classes=num_classes)
J
jiangjiajun 已提交
376 377


S
sunyanfang01 已提交
378 379
class MobileNetV3_small_ssld(BaseClassifier):
    """Classifier preset that passes model_name='MobileNetV3_small_ssld' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(MobileNetV3_small_ssld, self).__init__(
            model_name='MobileNetV3_small_ssld',
            num_classes=num_classes)
S
sunyanfang01 已提交
382 383 384 385


class MobileNetV3_large_ssld(BaseClassifier):
    """Classifier preset that passes model_name='MobileNetV3_large_ssld' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(MobileNetV3_large_ssld, self).__init__(
            model_name='MobileNetV3_large_ssld',
            num_classes=num_classes)
J
jiangjiajun 已提交
388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423


class Xception65(BaseClassifier):
    """Classifier preset that passes model_name='Xception65' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(Xception65, self).__init__(
            model_name='Xception65',
            num_classes=num_classes)


class Xception41(BaseClassifier):
    """Classifier preset that passes model_name='Xception41' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(Xception41, self).__init__(
            model_name='Xception41',
            num_classes=num_classes)


class DenseNet121(BaseClassifier):
    """Classifier preset that passes model_name='DenseNet121' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(DenseNet121, self).__init__(
            model_name='DenseNet121',
            num_classes=num_classes)


class DenseNet161(BaseClassifier):
    """Classifier preset that passes model_name='DenseNet161' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(DenseNet161, self).__init__(
            model_name='DenseNet161',
            num_classes=num_classes)


class DenseNet201(BaseClassifier):
    """Classifier preset that passes model_name='DenseNet201' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(DenseNet201, self).__init__(
            model_name='DenseNet201',
            num_classes=num_classes)


class ShuffleNetV2(BaseClassifier):
    """Classifier preset that passes model_name='ShuffleNetV2' to
    BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(ShuffleNetV2, self).__init__(
            model_name='ShuffleNetV2',
            num_classes=num_classes)
424 425 426 427 428 429


class HRNet_W18(BaseClassifier):
    """Classifier preset that passes model_name='HRNet_W18' to BaseClassifier.

    Args:
        num_classes (int): Number of classes. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super(HRNet_W18, self).__init__(
            model_name='HRNet_W18',
            num_classes=num_classes)