# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Image classification.
"""
import math
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common import initializer as init
from mindspore.common.initializer import initializer
from .utils.var_init import default_recurisive_init, KaimingNormal


def _make_layer(base, args, batch_norm):
    """Make stage network of VGG."""
    layers = []
    in_channels = 3
    for v in base:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            weight_shape = (v, in_channels, 3, 3)
            weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
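            # For imagenet2012 the Xavier tensor above is replaced with the
            # built-in 'normal' initializer; Vgg.custom_init_weight() re-inits
            # the conv weights with Kaiming normal after construction anyway.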
            if args.dataset == "imagenet2012":
                weight = 'normal'
            conv2d = nn.Conv2d(in_channels=in_channels,
                               out_channels=v,
                               kernel_size=3,
                               padding=args.padding,
                               pad_mode=args.pad_mode,
                               has_bias=args.has_bias,
                               weight_init=weight)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
            else:
                layers += [conv2d, nn.ReLU()]
            in_channels = v
    return nn.SequentialCell(layers)


class Vgg(nn.Cell):
    """
    VGG network definition.

    Args:
        base (list): Layer configuration; integers give conv output channels
            and 'M' marks a max-pooling layer.
        num_classes (int): Number of classes. Default: 1000.
        batch_norm (bool): Whether to apply batch normalization after each
            conv layer. Default: False.
        batch_size (int): Batch size. Unused; kept for interface
            compatibility. Default: 1.
        args (Namespace): Network configuration (dataset, padding, pad_mode,
            has_bias). Default: None.
        phase (str): "train" or "test"; dropout is disabled for "test".
            Default: "train".

    Returns:
        Tensor, infer output tensor.

    Examples:
        >>> from argparse import Namespace
        >>> args = Namespace(dataset="cifar10", padding=0, pad_mode="same", has_bias=False)
        >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        ...     num_classes=1000, batch_norm=False, batch_size=1, args=args)
    """

    def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"):
        super(Vgg, self).__init__()
        _ = batch_size  # unused; accepted so callers passing batch_size keep working
        self.layers = _make_layer(base, args, batch_norm=batch_norm)
        self.flatten = nn.Flatten()
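        # nn.Dropout in this MindSpore generation takes the *keep* probability,
        # so 1.0 below disables dropout (cifar10 training and every "test" run).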
        dropout_ratio = 0.5
        if args.dataset == "cifar10" or phase == "test":
            dropout_ratio = 1.0
        self.classifier = nn.SequentialCell([
            nn.Dense(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(dropout_ratio),
            nn.Dense(4096, 4096),
            nn.ReLU(),
            nn.Dropout(dropout_ratio),
            nn.Dense(4096, num_classes)])
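        # ImageNet runs overwrite the per-layer initializers above: first a
        # recursive default init over all cells, then the custom scheme below.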
        if args.dataset == "imagenet2012":
            default_recurisive_init(self)
            self.custom_init_weight()

    def construct(self, x):
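        """Forward pass: conv feature stages -> flatten -> classifier."""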
        x = self.layers(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

    def custom_init_weight(self):
        """
        Init the weight of Conv2d and Dense in the net.
        """
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.default_input = init.initializer(
                    KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
                    cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
                if cell.bias is not None:
                    cell.bias.default_input = init.initializer(
                        'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor()
            elif isinstance(cell, nn.Dense):
                cell.weight.default_input = init.initializer(
                    init.Normal(0.01), cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
                if cell.bias is not None:
                    cell.bias.default_input = init.initializer(
                        'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor()


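# Per-depth layer configurations for VGG-11/13/16/19, consumed by _make_layer:
# integers are conv output channels, 'M' inserts a 2x2 max-pool.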
cfg = {
    '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg16(num_classes=1000, args=None, phase="train"):
    """
    Get Vgg16 neural network with batch normalization.

    Args:
        num_classes (int): Class numbers. Default: 1000.
        args (Namespace): Network configuration, forwarded to Vgg. Default: None.
        phase (str): "train" or "test" mode. Default: "train".

    Returns:
        Cell, cell instance of Vgg16 neural network with batch normalization.

    Examples:
        >>> from argparse import Namespace
        >>> args = Namespace(dataset="cifar10", padding=0, pad_mode="same", has_bias=False)
        >>> vgg16(num_classes=10, args=args)
    """

    net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=True, phase=phase)
    return net
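

# A minimal usage sketch (illustrative only: the `args` fields below mirror the
# attributes this module reads; the relative import at the top means this file
# is meant to be imported by the training scripts rather than run directly):
#
#     from argparse import Namespace
#     import numpy as np
#     from mindspore import Tensor
#
#     args = Namespace(dataset="cifar10", padding=0, pad_mode="same", has_bias=False)
#     net = vgg16(num_classes=10, args=args, phase="test")
#     logits = net(Tensor(np.zeros((1, 3, 224, 224), np.float32)))  # -> (1, 10)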