vgg.py 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
M
ms_yan 已提交
15 16 17 18
"""
Image classification.
"""
import math
19 20
import mindspore.nn as nn
import mindspore.common.dtype as mstype
M
ms_yan 已提交
21 22 23
from mindspore.common import initializer as init
from mindspore.common.initializer import initializer
from .utils.var_init import default_recurisive_init, KaimingNormal
24

M
ms_yan 已提交
25 26

def _make_layer(base, args, batch_norm):
    """
    Make stage network of VGG.

    Args:
        base (list): Layer configuration. The string 'M' adds a 2x2 max-pooling
            layer with stride 2; every other entry is used as the output channel
            number of a 3x3 convolution followed by ReLU.
        args (namespace): Network options. This function reads `initialize_mode`,
            `padding`, `pad_mode` and `has_bias`.
        batch_norm (bool): Whether to insert a BatchNorm2d between each
            convolution and its ReLU.

    Returns:
        SequentialCell, the layers described by `base`, in order.
    """
    layers = []
    # The first convolution always takes 3 input channels (presumably RGB images).
    in_channels = 3
    for v in base:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue

        # 'ones' is the fallback initializer. It is replaced here for
        # "XavierUniform"; for "KaimingNormal" the weights are re-initialized
        # afterwards by Vgg.custom_init_weight.
        weight = 'ones'
        if args.initialize_mode == "XavierUniform":
            weight_shape = (v, in_channels, 3, 3)
            weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()

        conv2d = nn.Conv2d(in_channels=in_channels,
                           out_channels=v,
                           kernel_size=3,
                           padding=args.padding,
                           pad_mode=args.pad_mode,
                           has_bias=args.has_bias,
                           weight_init=weight)
        layers.append(conv2d)
        if batch_norm:
            layers.append(nn.BatchNorm2d(v))
        layers.append(nn.ReLU())
        # The next convolution consumes what this one produced.
        in_channels = v
    return nn.SequentialCell(layers)


class Vgg(nn.Cell):
    """
    VGG network definition.

    Args:
        base (list): Configuration for different layers, mainly the channel number of Conv layer.
        num_classes (int): Class numbers. Default: 1000.
        batch_norm (bool): Whether to do the batchnorm. Default: False.
        batch_size (int): Batch size. Accepted for compatibility but not used. Default: 1.
        args (namespace): Network options. Required even though it defaults to None:
            `has_dropout` and `initialize_mode` are read here, and `padding`,
            `pad_mode` and `has_bias` are read by `_make_layer`.
        phase (str): "train" or "test". "test" disables dropout. Default: "train".

    Returns:
        Tensor, infer output tensor.

    Raises:
        ValueError: If `args` is None.

    Examples:
        >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        >>>     num_classes=1000, batch_norm=False, batch_size=1, args=args)
    """

    def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"):
        super(Vgg, self).__init__()
        if args is None:
            # Fail with a clear message instead of an AttributeError on `None`.
            raise ValueError("Vgg requires `args` providing has_dropout, initialize_mode, "
                             "padding, pad_mode and has_bias.")
        _ = batch_size
        self.layers = _make_layer(base, args, batch_norm=batch_norm)
        self.flatten = nn.Flatten()

        # The value is passed positionally to nn.Dropout; 1.0 is used when
        # dropout is switched off or the network is built for testing.
        # NOTE(review): this relies on the argument being a keep probability —
        # confirm against the MindSpore version in use.
        dropout_ratio = 0.5
        if not args.has_dropout or phase == "test":
            dropout_ratio = 1.0

        # The first Dense layer expects 512 * 7 * 7 flattened features.
        self.classifier = nn.SequentialCell([
            nn.Dense(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(dropout_ratio),
            nn.Dense(4096, 4096),
            nn.ReLU(),
            nn.Dropout(dropout_ratio),
            nn.Dense(4096, num_classes)])

        if args.initialize_mode == "KaimingNormal":
            default_recurisive_init(self)
            self.custom_init_weight()

    def construct(self, x):
        """Run the feature layers, flatten the result and apply the classifier."""
        x = self.layers(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

    def custom_init_weight(self):
        """
        Init the weight of Conv2d and Dense in the net.

        Conv2d weights use KaimingNormal (fan_out, relu), Dense weights use
        Normal(0.01), and any existing bias of either cell type is set to zeros.
        Other cell types are left untouched.
        """
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                weight_init = KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu')
            elif isinstance(cell, nn.Dense):
                weight_init = init.Normal(0.01)
            else:
                continue

            cell.weight.default_input = init.initializer(
                weight_init, cell.weight.shape, cell.weight.dtype)
            if cell.bias is not None:
                cell.bias.default_input = init.initializer(
                    'zeros', cell.bias.shape, cell.bias.dtype)
M
ms_yan 已提交
116

117 118 119 120 121 122 123 124 125

# VGG layouts keyed by depth (conv layers plus the three Dense layers of the
# classifier). Integers are conv output channels; 'M' marks a max-pooling layer.
# Each row below is one stage, ending with its pooling layer.
cfg = {
    '11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    '13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    '16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    '19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}


M
ms_yan 已提交
126
def vgg16(num_classes=1000, args=None, phase="train"):
    """
    Get Vgg16 neural network.

    Batch normalization is used only when `args.batch_norm` is true.

    Args:
        num_classes (int): Class numbers. Default: 1000.
        args (namespace): Param for net init. Required even though it defaults
            to None; `batch_norm` is read here and the rest is forwarded to Vgg.
        phase (str): Train or test mode. Default: "train".

    Returns:
        Cell, cell instance of Vgg16 neural network.

    Raises:
        ValueError: If `args` is None.

    Examples:
        >>> vgg16(num_classes=1000, args=args)
    """
    if args is None:
        # Fail with a clear message instead of an AttributeError on `None`.
        raise ValueError("vgg16 requires `args`; it cannot be None.")

    return Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)