# inceptionv3.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
import paddle.nn as nn
from paddle.nn import Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
from paddle.fluid.param_attr import ParamAttr

from paddle.utils.download import get_weights_path_from_url
from ..ops import ConvNormActivation

__all__ = []

# Pretrained weight locations: arch name -> (download URL, MD5 checksum).
# The tuple is passed positionally to get_weights_path_from_url by
# inception_v3() below.
model_urls = {
    "inception_v3":
    ("https://paddle-hapi.bj.bcebos.com/models/inception_v3.pdparams",
     "649a4547c3243e8b59c656f41fe330b8")
}


class InceptionStem(nn.Layer):
    """Stem of InceptionV3: the plain conv/pool layers run before the first
    Inception block.

    Maps a 3-channel input to a 192-channel feature map. Layers are applied
    strictly in sequence (see ``forward``). The same ``max_pool`` layer is
    applied twice: once after ``conv_2b_3x3`` and once after ``conv_4a_3x3``.

    NOTE(review): with a 299x299 input this should produce a 35x35 map,
    assuming ConvNormActivation uses standard convolution arithmetic with a
    default stride of 1. Confirm against ``..ops.ConvNormActivation``.
    """

    def __init__(self):
        super().__init__()
        # Strided 3x3 conv: first spatial reduction.
        self.conv_1a_3x3 = ConvNormActivation(
            in_channels=3,
            out_channels=32,
            kernel_size=3,
            stride=2,
            padding=0,
            activation_layer=nn.ReLU)
        self.conv_2a_3x3 = ConvNormActivation(
            in_channels=32,
            out_channels=32,
            kernel_size=3,
            stride=1,
            padding=0,
            activation_layer=nn.ReLU)
        # padding=1 with a 3x3 kernel keeps the spatial size (for stride 1).
        self.conv_2b_3x3 = ConvNormActivation(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            padding=1,
            activation_layer=nn.ReLU)

        self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=0)

        self.conv_3b_1x1 = ConvNormActivation(
            in_channels=64,
            out_channels=80,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.conv_4a_3x3 = ConvNormActivation(
            in_channels=80,
            out_channels=192,
            kernel_size=3,
            padding=0,
            activation_layer=nn.ReLU)

    def forward(self, x):
        x = self.conv_1a_3x3(x)
        x = self.conv_2a_3x3(x)
        x = self.conv_2b_3x3(x)
        x = self.max_pool(x)
        x = self.conv_3b_1x1(x)
        x = self.conv_4a_3x3(x)
        x = self.max_pool(x)
        return x


class InceptionA(nn.Layer):
    """Inception block with four parallel branches whose outputs are
    concatenated along axis 1.

    Branches:

    * ``branch1x1``: 1x1 conv -> 64 channels
    * ``branch5x5_*``: 1x1 conv (48) then 5x5 conv -> 64 channels
    * ``branch3x3dbl_*``: 1x1 conv (64), then two 3x3 convs -> 96 channels
    * ``branch_pool`` / ``branch_pool_conv``: 3x3 average pool (stride 1)
      then 1x1 conv -> ``pool_features`` channels

    Every kernel is paired with "same"-style padding (k=5/p=2, k=3/p=1,
    k=1/p=0). Assuming ConvNormActivation's default stride is 1 (TODO
    confirm), each branch therefore keeps the input's spatial size, so the
    concatenation lines up.

    Args:
        num_channels (int): channel count of the input feature map.
        pool_features (int): output channels of the pooling branch.

    The output has ``64 + 64 + 96 + pool_features`` channels.
    """

    def __init__(self, num_channels, pool_features):
        super().__init__()
        self.branch1x1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=64,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)

        self.branch5x5_1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=48,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.branch5x5_2 = ConvNormActivation(
            in_channels=48,
            out_channels=64,
            kernel_size=5,
            padding=2,
            activation_layer=nn.ReLU)

        self.branch3x3dbl_1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=64,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.branch3x3dbl_2 = ConvNormActivation(
            in_channels=64,
            out_channels=96,
            kernel_size=3,
            padding=1,
            activation_layer=nn.ReLU)
        self.branch3x3dbl_3 = ConvNormActivation(
            in_channels=96,
            out_channels=96,
            kernel_size=3,
            padding=1,
            activation_layer=nn.ReLU)

        # NOTE(review): exclusive=False presumably makes padded cells count
        # toward the averaging divisor; confirm against the AvgPool2D docs.
        self.branch_pool = AvgPool2D(
            kernel_size=3, stride=1, padding=1, exclusive=False)
        self.branch_pool_conv = ConvNormActivation(
            in_channels=num_channels,
            out_channels=pool_features,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = self.branch_pool(x)
        branch_pool = self.branch_pool_conv(branch_pool)
        x = paddle.concat(
            [branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1)
        return x


class InceptionB(nn.Layer):
    """Grid-reduction block: three parallel branches, each ending in a
    stride-2 operation, concatenated along axis 1.

    Branches:

    * ``branch3x3``: 3x3 conv, stride 2 -> 384 channels
    * ``branch3x3dbl_*``: 1x1 conv (64), 3x3 conv (96), then 3x3 conv with
      stride 2 -> 96 channels
    * ``branch_pool``: 3x3 max pool, stride 2 (channel count unchanged)

    Args:
        num_channels (int): channel count of the input feature map.

    The output has ``384 + 96 + num_channels`` channels.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.branch3x3 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=384,
            kernel_size=3,
            stride=2,
            padding=0,
            activation_layer=nn.ReLU)

        self.branch3x3dbl_1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=64,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.branch3x3dbl_2 = ConvNormActivation(
            in_channels=64,
            out_channels=96,
            kernel_size=3,
            padding=1,
            activation_layer=nn.ReLU)
        self.branch3x3dbl_3 = ConvNormActivation(
            in_channels=96,
            out_channels=96,
            kernel_size=3,
            stride=2,
            padding=0,
            activation_layer=nn.ReLU)

        # Same kernel/stride/padding (3/2/0) as the strided convs above, so
        # all three branches land on the same spatial grid.
        self.branch_pool = MaxPool2D(kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = self.branch_pool(x)

        x = paddle.concat([branch3x3, branch3x3dbl, branch_pool], axis=1)

        return x


class InceptionC(nn.Layer):
    """Inception block with factorized 7x7 convolutions (1x7 / 7x1 pairs),
    four branches concatenated along axis 1.

    Branches:

    * ``branch1x1``: 1x1 conv -> 192 channels
    * ``branch7x7_*``: 1x1 conv, 1x7 conv, 7x1 conv -> 192 channels
    * ``branch7x7dbl_*``: 1x1 conv, then alternating 7x1 / 1x7 convs
      (four of them) -> 192 channels
    * ``branch_pool`` / ``branch_pool_conv``: 3x3 average pool (stride 1)
      then 1x1 conv -> 192 channels

    Intermediate convs use ``channels_7x7`` channels. Each asymmetric kernel
    is padded by 3 along its long axis only, so (for stride 1) the spatial
    size is preserved.

    Args:
        num_channels (int): channel count of the input feature map.
        channels_7x7 (int): width of the intermediate factorized convs.

    The output always has ``4 * 192 = 768`` channels.
    """

    def __init__(self, num_channels, channels_7x7):
        super().__init__()
        self.branch1x1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=192,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)

        self.branch7x7_1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=channels_7x7,
            kernel_size=1,
            stride=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.branch7x7_2 = ConvNormActivation(
            in_channels=channels_7x7,
            out_channels=channels_7x7,
            kernel_size=(1, 7),
            stride=1,
            padding=(0, 3),
            activation_layer=nn.ReLU)
        self.branch7x7_3 = ConvNormActivation(
            in_channels=channels_7x7,
            out_channels=192,
            kernel_size=(7, 1),
            stride=1,
            padding=(3, 0),
            activation_layer=nn.ReLU)

        self.branch7x7dbl_1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=channels_7x7,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.branch7x7dbl_2 = ConvNormActivation(
            in_channels=channels_7x7,
            out_channels=channels_7x7,
            kernel_size=(7, 1),
            padding=(3, 0),
            activation_layer=nn.ReLU)
        self.branch7x7dbl_3 = ConvNormActivation(
            in_channels=channels_7x7,
            out_channels=channels_7x7,
            kernel_size=(1, 7),
            padding=(0, 3),
            activation_layer=nn.ReLU)
        self.branch7x7dbl_4 = ConvNormActivation(
            in_channels=channels_7x7,
            out_channels=channels_7x7,
            kernel_size=(7, 1),
            padding=(3, 0),
            activation_layer=nn.ReLU)
        self.branch7x7dbl_5 = ConvNormActivation(
            in_channels=channels_7x7,
            out_channels=192,
            kernel_size=(1, 7),
            padding=(0, 3),
            activation_layer=nn.ReLU)

        # NOTE(review): exclusive=False presumably makes padded cells count
        # toward the averaging divisor; confirm against the AvgPool2D docs.
        self.branch_pool = AvgPool2D(
            kernel_size=3, stride=1, padding=1, exclusive=False)
        self.branch_pool_conv = ConvNormActivation(
            in_channels=num_channels,
            out_channels=192,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = self.branch_pool(x)
        branch_pool = self.branch_pool_conv(branch_pool)

        x = paddle.concat(
            [branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1)

        return x


class InceptionD(nn.Layer):
    """Second grid-reduction block: three parallel branches, each ending in
    a stride-2 operation, concatenated along axis 1.

    Branches:

    * ``branch3x3_*``: 1x1 conv (192), then 3x3 conv with stride 2 -> 320
      channels
    * ``branch7x7x3_*``: 1x1 conv, 1x7 conv, 7x1 conv, then 3x3 conv with
      stride 2 -> 192 channels
    * ``branch_pool``: 3x3 max pool, stride 2 (channel count unchanged)

    Args:
        num_channels (int): channel count of the input feature map.

    The output has ``320 + 192 + num_channels`` channels.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.branch3x3_1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=192,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.branch3x3_2 = ConvNormActivation(
            in_channels=192,
            out_channels=320,
            kernel_size=3,
            stride=2,
            padding=0,
            activation_layer=nn.ReLU)

        self.branch7x7x3_1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=192,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.branch7x7x3_2 = ConvNormActivation(
            in_channels=192,
            out_channels=192,
            kernel_size=(1, 7),
            padding=(0, 3),
            activation_layer=nn.ReLU)
        self.branch7x7x3_3 = ConvNormActivation(
            in_channels=192,
            out_channels=192,
            kernel_size=(7, 1),
            padding=(3, 0),
            activation_layer=nn.ReLU)
        self.branch7x7x3_4 = ConvNormActivation(
            in_channels=192,
            out_channels=192,
            kernel_size=3,
            stride=2,
            padding=0,
            activation_layer=nn.ReLU)

        # Same kernel/stride/padding (3/2/0) as the strided convs above, so
        # all three branches land on the same spatial grid.
        self.branch_pool = MaxPool2D(kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = self.branch_pool(x)

        x = paddle.concat([branch3x3, branch7x7x3, branch_pool], axis=1)
        return x


class InceptionE(nn.Layer):
    """Inception block with split 1x3 / 3x1 heads, concatenated along axis 1.

    Branches:

    * ``branch1x1``: 1x1 conv -> 320 channels
    * ``branch3x3_*``: 1x1 conv (384), then a 1x3 conv and a 3x1 conv
      applied *in parallel* to that result and concatenated -> 768 channels
    * ``branch3x3dbl_*``: 1x1 conv (448), 3x3 conv (384), then parallel
      1x3 / 3x1 convs concatenated -> 768 channels
    * ``branch_pool`` / ``branch_pool_conv``: 3x3 average pool (stride 1)
      then 1x1 conv -> 192 channels

    Args:
        num_channels (int): channel count of the input feature map.

    The output always has ``320 + 768 + 768 + 192 = 2048`` channels.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.branch1x1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=320,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.branch3x3_1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=384,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.branch3x3_2a = ConvNormActivation(
            in_channels=384,
            out_channels=384,
            kernel_size=(1, 3),
            padding=(0, 1),
            activation_layer=nn.ReLU)
        self.branch3x3_2b = ConvNormActivation(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 1),
            padding=(1, 0),
            activation_layer=nn.ReLU)

        self.branch3x3dbl_1 = ConvNormActivation(
            in_channels=num_channels,
            out_channels=448,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)
        self.branch3x3dbl_2 = ConvNormActivation(
            in_channels=448,
            out_channels=384,
            kernel_size=3,
            padding=1,
            activation_layer=nn.ReLU)
        self.branch3x3dbl_3a = ConvNormActivation(
            in_channels=384,
            out_channels=384,
            kernel_size=(1, 3),
            padding=(0, 1),
            activation_layer=nn.ReLU)
        self.branch3x3dbl_3b = ConvNormActivation(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 1),
            padding=(1, 0),
            activation_layer=nn.ReLU)

        # NOTE(review): exclusive=False presumably makes padded cells count
        # toward the averaging divisor; confirm against the AvgPool2D docs.
        self.branch_pool = AvgPool2D(
            kernel_size=3, stride=1, padding=1, exclusive=False)
        self.branch_pool_conv = ConvNormActivation(
            in_channels=num_channels,
            out_channels=192,
            kernel_size=1,
            padding=0,
            activation_layer=nn.ReLU)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        # The 1x3 and 3x1 convs both read the same 1x1 output (fan-out),
        # rather than being chained.
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = paddle.concat(branch3x3, axis=1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = paddle.concat(branch3x3dbl, axis=1)

        branch_pool = self.branch_pool(x)
        branch_pool = self.branch_pool_conv(branch_pool)

        x = paddle.concat(
            [branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1)
        return x


class InceptionV3(nn.Layer):
    """
    InceptionV3
    Args:
        num_classes (int, optional): output dim of last fc layer. If num_classes <= 0, last fc layer
                            will not be defined. Default: 1000.
        with_pool (bool, optional): use pool before the last fc layer or not. Default: True.
                            When False and num_classes > 0, the final feature map must already
                            reduce to 2048 values per sample, otherwise the fc layer rejects it.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.vision.models import InceptionV3

            inception_v3 = InceptionV3()

            x = paddle.rand([1, 3, 299, 299])
            out = inception_v3(x)

            print(out.shape)
    """

    def __init__(self, num_classes=1000, with_pool=True):
        super().__init__()
        self.num_classes = num_classes
        self.with_pool = with_pool
        # Per-stage block arguments. The input-channel lists are chained:
        # each entry equals the previous block's output channel count
        # (e.g. InceptionA(192, 32) emits 64+64+96+32 = 256 channels), and
        # the final InceptionE emits 2048 channels, which the fc layer
        # below relies on.
        self.layers_config = {
            "inception_a": [[192, 256, 288], [32, 64, 64]],
            "inception_b": [288],
            "inception_c": [[768, 768, 768, 768], [128, 160, 160, 192]],
            "inception_d": [768],
            "inception_e": [1280, 2048]
        }

        inception_a_list = self.layers_config["inception_a"]
        inception_b_list = self.layers_config["inception_b"]
        inception_c_list = self.layers_config["inception_c"]
        inception_d_list = self.layers_config["inception_d"]
        inception_e_list = self.layers_config["inception_e"]

        self.inception_stem = InceptionStem()

        # NOTE: block order (A x3, B, C x4, D, E x2) determines the indices
        # in this LayerList. These presumably appear in the state-dict keys
        # used when loading pretrained weights, so keep the order stable.
        self.inception_block_list = nn.LayerList()
        for num_channels, pool_features in zip(*inception_a_list):
            self.inception_block_list.append(
                InceptionA(num_channels, pool_features))

        for num_channels in inception_b_list:
            self.inception_block_list.append(InceptionB(num_channels))

        for num_channels, channels_7x7 in zip(*inception_c_list):
            self.inception_block_list.append(
                InceptionC(num_channels, channels_7x7))

        for num_channels in inception_d_list:
            self.inception_block_list.append(InceptionD(num_channels))

        for num_channels in inception_e_list:
            self.inception_block_list.append(InceptionE(num_channels))

        if with_pool:
            self.avg_pool = AdaptiveAvgPool2D(1)

        if num_classes > 0:
            self.dropout = Dropout(p=0.2, mode="downscale_in_infer")
            # Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) init for the fc weight.
            stdv = 1.0 / math.sqrt(2048 * 1.0)
            self.fc = Linear(
                2048,
                num_classes,
                weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
                bias_attr=ParamAttr())

    def forward(self, x):
        """Run the network.

        Args:
            x (Tensor): input image batch. Presumably NCHW; the stem's first
                conv expects 3 input channels.

        Returns:
            Tensor: when ``num_classes > 0``, fc outputs of shape
            ``[N, num_classes]``. Otherwise, the 2048-channel feature map
            (pooled to 1x1 when ``with_pool`` is True).
        """
        x = self.inception_stem(x)
        for inception_block in self.inception_block_list:
            x = inception_block(x)

        if self.with_pool:
            x = self.avg_pool(x)

        if self.num_classes > 0:
            # Flatten everything but the batch axis. A reshape to
            # [-1, 2048] would silently fold un-pooled spatial positions
            # into the batch dimension; flatten keeps one row per sample,
            # so a size mismatch surfaces as an error in the fc layer.
            x = paddle.flatten(x, start_axis=1)
            x = self.dropout(x)
            x = self.fc(x)
        return x


def inception_v3(pretrained=False, **kwargs):
    """
    InceptionV3 model from
    `"Rethinking the Inception Architecture for Computer Vision" <https://arxiv.org/pdf/1512.00567.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        **kwargs: Forwarded unchanged to ``InceptionV3`` (``num_classes``,
            ``with_pool``).

    Returns:
        InceptionV3: the constructed model, with the downloaded weights
        loaded when ``pretrained`` is True.

    Raises:
        ValueError: if ``pretrained`` is True but no pretrained weights are
            registered in ``model_urls``.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.vision.models import inception_v3

            # build model
            model = inception_v3()

            # build model and load imagenet pretrained weight
            # model = inception_v3(pretrained=True)

            x = paddle.rand([1, 3, 299, 299])
            out = model(x)

            print(out.shape)
    """
    model = InceptionV3(**kwargs)
    arch = "inception_v3"
    if pretrained:
        # Explicit check rather than `assert`, so the guard is not stripped
        # when Python runs with -O.
        if arch not in model_urls:
            raise ValueError(
                "{} model does not have a pretrained model now, you should "
                "set pretrained=False".format(arch))
        url, md5 = model_urls[arch]
        weight_path = get_weights_path_from_url(url, md5)

        # NOTE(review): the published weights presumably assume the default
        # num_classes / with_pool; confirm before combining with kwargs.
        param = paddle.load(weight_path)
        model.set_dict(param)
    return model