From 4878616e438672301536f74584ec7ea1623365fa Mon Sep 17 00:00:00 2001
From: chenzomi
Date: Tue, 16 Jun 2020 15:01:54 +0800
Subject: [PATCH] change combined to nn

---
 mindspore/nn/layer/combined.py            | 182 ------------------
 mindspore/nn/layer/quant.py               | 171 +++++++++++++++-
 mindspore/train/quant/quant.py            |   7 +-
 .../train/quant/mobilenetv2_combined.py   |  55 +++---
 tests/ut/python/train/quant/test_quant.py |  13 +-
 5 files changed, 205 insertions(+), 223 deletions(-)
 delete mode 100644 mindspore/nn/layer/combined.py

diff --git a/mindspore/nn/layer/combined.py b/mindspore/nn/layer/combined.py
deleted file mode 100644
index e5b3d0d03..000000000
--- a/mindspore/nn/layer/combined.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""Use combination of Conv, Dense, Relu, Batchnorm."""
-
-from .normalization import BatchNorm2d
-from .activation import get_activation
-from ..cell import Cell
-from . import conv, basic
-from ..._checkparam import ParamValidator as validator
-
-
-__all__ = ['Conv2d', 'Dense']
-
-class Conv2d(Cell):
-    r"""
-    A combination of convolution, Batchnorm, activation layer.
-
-    For a more Detailed overview of Conv2d op.
-
-    Args:
-        in_channels (int): The number of input channel :math:`C_{in}`.
-        out_channels (int): The number of output channel :math:`C_{out}`.
-        kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height
-            and width of the 2D convolution window. Single int means the value if for both height and width of
-            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
-            width of the kernel.
-        stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be
-            greater or equal to 1 but bounded by the height and width of the input. Default: 1.
-        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
-        padding (int): Implicit paddings on both sides of the input. Default: 0.
-        dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
-            there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater
-            or equal to 1 and bounded by the height and width of the input. Default: 1.
-        group (int): Split filter into groups, `in_ channels` and `out_channels` should be
-            divisible by the number of groups. Default: 1.
-        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
-        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
-            It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
-            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
-            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
-            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
-            Initializer for more details. Default: 'normal'.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
-            Initializer and string are the same as 'weight_init'. Refer to the values of
-            Initializer for more details. Default: 'zeros'.
-        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
-        activation (string): Specifies activation type. The optional values are as following:
-            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
-            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
-
-    Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
-
-    Outputs:
-        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
-
-    Examples:
-        >>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='ReLU')
-        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
-        >>> net(input).shape
-        (1, 240, 1024, 640)
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 pad_mode='same',
-                 padding=0,
-                 dilation=1,
-                 group=1,
-                 has_bias=False,
-                 weight_init='normal',
-                 bias_init='zeros',
-                 batchnorm=None,
-                 activation=None):
-        super(Conv2d, self).__init__()
-        self.conv = conv.Conv2d(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            pad_mode,
-            padding,
-            dilation,
-            group,
-            has_bias,
-            weight_init,
-            bias_init)
-        self.has_bn = batchnorm is not None
-        self.has_act = activation is not None
-        self.batchnorm = batchnorm
-        if batchnorm is True:
-            self.batchnorm = BatchNorm2d(out_channels)
-        elif batchnorm is not None:
-            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
-        self.activation = get_activation(activation)
-
-    def construct(self, x):
-        x = self.conv(x)
-        if self.has_bn:
-            x = self.batchnorm(x)
-        if self.has_act:
-            x = self.activation(x)
-        return x
-
-
-class Dense(Cell):
-    r"""
-    A combination of Dense, Batchnorm, activation layer.
-
-    For a more Detailed overview of Dense op.
-
-    Args:
-        in_channels (int): The number of channels in the input space.
-        out_channels (int): The number of channels in the output space.
-        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
-            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
-            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
-        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
-        activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
-        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
-        activation (string): Specifies activation type. The optional values are as following:
-            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
-            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
-
-    Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
-
-    Outputs:
-        Tensor of shape :math:`(N, out\_channels)`.
-
-    Examples:
-        >>> net = nn.Dense(3, 4)
-        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
-        >>> net(input)
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 weight_init='normal',
-                 bias_init='zeros',
-                 has_bias=True,
-                 batchnorm=None,
-                 activation=None):
-        super(Dense, self).__init__()
-        self.dense = basic.Dense(
-            in_channels,
-            out_channels,
-            weight_init,
-            bias_init,
-            has_bias)
-        self.has_bn = batchnorm is not None
-        self.has_act = activation is not None
-        if batchnorm is True:
-            self.batchnorm = BatchNorm2d(out_channels)
-        elif batchnorm is not None:
-            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
-        self.activation = get_activation(activation)
-
-    def construct(self, x):
-        x = self.dense(x)
-        if self.has_bn:
-            x = self.batchnorm(x)
-        if self.has_act:
-            x = self.activation(x)
-        return x
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 46f6cd5a4..ca4134105 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -27,8 +27,16 @@ from mindspore._checkparam import Validator as validator, Rel
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 import mindspore.context as context
+from .normalization import BatchNorm2d
+from .activation import get_activation
+from ..cell import Cell
+from . import conv, basic
+from ..._checkparam import ParamValidator as validator
+
 __all__ = [
+    'Conv2dBnAct',
+    'DenseBnAct',
     'FakeQuantWithMinMax',
     'Conv2dBatchNormQuant',
     'Conv2dQuant',
@@ -42,6 +50,165 @@ __all__ = [
 ]
 
 
+class Conv2dBnAct(Cell):
+    r"""
+    A combination of convolution, batch normalization, and activation layers.
+
+    For a more detailed overview, see the Conv2d operator.
+
+    Args:
+        in_channels (int): The number of input channel :math:`C_{in}`.
+        out_channels (int): The number of output channel :math:`C_{out}`.
+        kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height
+            and width of the 2D convolution window. A single int means the value is for both the height and
+            the width of the kernel. A tuple of 2 ints means the first value is for the height and the other
+            is for the width of the kernel.
+        stride (int): Specifies stride for all spatial dimensions with the same value. The value of stride
+            should be greater than or equal to 1 but bounded by the height and width of the input. Default: 1.
+        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
+        padding (int): Implicit paddings on both sides of the input. Default: 0.
+        dilation (int): Specifies the dilation rate to use for dilated convolution. If set to :math:`k > 1`,
+            there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater
+            than or equal to 1 and bounded by the height and width of the input. Default: 1.
+        group (int): Splits the filter into groups; `in_channels` and `out_channels` should be
+            divisible by the number of groups. Default: 1.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
+            It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
+            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
+            as constant 'One' and 'Zero' distributions are possible. Aliases 'xavier_uniform', 'he_uniform', 'ones'
+            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
+            Initializer for more details. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. The possible
+            Initializer and string values are the same as for 'weight_init'. Refer to the values of
+            Initializer for more details. Default: 'zeros'.
+        batchnorm (Union[bool, BatchNorm2d]): Specifies whether to use batch normalization. Pass True to create
+            a default BatchNorm2d, or pass a BatchNorm2d cell directly. Default: None.
+        activation (str): Specifies the activation type. The optional values are:
+            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
+            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+
+    Examples:
+        >>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='relu')
+        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
+        >>> net(input).shape
+        (1, 240, 1024, 640)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 pad_mode='same',
+                 padding=0,
+                 dilation=1,
+                 group=1,
+                 has_bias=False,
+                 weight_init='normal',
+                 bias_init='zeros',
+                 batchnorm=None,
+                 activation=None):
+        super(Conv2dBnAct, self).__init__()
+        self.conv = conv.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            pad_mode,
+            padding,
+            dilation,
+            group,
+            has_bias,
+            weight_init,
+            bias_init)
+        self.has_bn = batchnorm is not None
+        self.has_act = activation is not None
+        self.batchnorm = batchnorm
+        if batchnorm is True:
+            self.batchnorm = BatchNorm2d(out_channels)
+        elif batchnorm is not None:
+            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
+        self.activation = get_activation(activation)
+
+    def construct(self, x):
+        x = self.conv(x)
+        if self.has_bn:
+            x = self.batchnorm(x)
+        if self.has_act:
+            x = self.activation(x)
+        return x
+
+
+class DenseBnAct(Cell):
+    r"""
+    A combination of Dense, batch normalization, and activation layers.
+
+    For a more detailed overview, see the Dense operator.
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is the same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            the same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+        batchnorm (Union[bool, BatchNorm2d]): Specifies whether to use batch normalization. Pass True to create
+            a default BatchNorm2d, or pass a BatchNorm2d cell directly. Default: None.
+        activation (str): Specifies the activation type applied to the output of the layer. The optional values are:
+            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
+            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
+
+    Outputs:
+        Tensor of shape :math:`(N, out\_channels)`.
+
+    Examples:
+        >>> net = DenseBnAct(3, 4)
+        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
+        >>> net(input)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 weight_init='normal',
+                 bias_init='zeros',
+                 has_bias=True,
+                 batchnorm=None,
+                 activation=None):
+        super(DenseBnAct, self).__init__()
+        self.dense = basic.Dense(
+            in_channels,
+            out_channels,
+            weight_init,
+            bias_init,
+            has_bias)
+        self.has_bn = batchnorm is not None
+        self.has_act = activation is not None
+        self.batchnorm = batchnorm
+        if batchnorm is True:
+            self.batchnorm = BatchNorm2d(out_channels)
+        elif batchnorm is not None:
+            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
+        self.activation = get_activation(activation)
+
+    def construct(self, x):
+        x = self.dense(x)
+        if self.has_bn:
+            x = self.batchnorm(x)
+        if self.has_act:
+            x = self.activation(x)
+        return x
+
+
 class BatchNormFoldCell(Cell):
     """
     Batch normalization folded.
@@ -302,8 +469,8 @@ class Conv2dBatchNormQuant(Cell):
 
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
-            validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant')
-            validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant')
+            validator.check_integer('group', group, in_channels, Rel.EQ)
+            validator.check_integer('group', group, out_channels, Rel.EQ)
             self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                 kernel_size=self.kernel_size,
                                                 pad_mode=pad_mode,
diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index ff4042693..5caa5a444 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -19,7 +19,6 @@ from ... import nn
 from ... import ops
 from ..._checkparam import ParamValidator as validator
 from ..._checkparam import Rel
-from ...nn.layer import combined
 from ...nn.layer import quant
 
 _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
@@ -123,13 +122,13 @@ class ConvertToQuantNetwork:
             subcell = cells[name]
             if subcell == network:
                 continue
-            elif isinstance(subcell, combined.Conv2d):
+            elif isinstance(subcell, quant.Conv2dBnAct):
                 prefix = subcell.param_prefix
                 new_subcell = self._convert_conv(subcell)
                 new_subcell.update_parameters_name(prefix + '.')
                 network.insert_child_to_cell(name, new_subcell)
                 change = True
-            elif isinstance(subcell, combined.Dense):
+            elif isinstance(subcell, quant.DenseBnAct):
                 prefix = subcell.param_prefix
                 new_subcell = self._convert_dense(subcell)
                 new_subcell.update_parameters_name(prefix + '.')
@@ -159,7 +158,7 @@ class ConvertToQuantNetwork:
 
     def _convert_conv(self, subcell):
         """
-        convet conv cell to combine cell
+        convert conv cell to quant cell
         """
         conv_inner = subcell.conv
         bn_inner = subcell.batchnorm
diff --git a/tests/ut/python/train/quant/mobilenetv2_combined.py b/tests/ut/python/train/quant/mobilenetv2_combined.py
index 5ae241c0f..7ed1498fb 100644
--- a/tests/ut/python/train/quant/mobilenetv2_combined.py
+++ b/tests/ut/python/train/quant/mobilenetv2_combined.py
@@ -1,6 +1,5 @@
 """mobile net v2"""
 from mindspore import nn
-from mindspore.nn.layer import combined
 from mindspore.ops import operations as P
 
 
@@ -14,11 +13,11 @@ def _conv_bn(in_channel,
              out_channel,
              ksize,
              stride=1):
     """Get a conv2d batchnorm and relu layer."""
     return nn.SequentialCell(
-        [combined.Conv2d(in_channel,
-                         out_channel,
-                         kernel_size=ksize,
-                         stride=stride,
-                         batchnorm=True)])
+        [nn.Conv2dBnAct(in_channel,
+                        out_channel,
+                        kernel_size=ksize,
+                        stride=stride,
+                        batchnorm=True)])
 
 
 class InvertedResidual(nn.Cell):
@@ -31,30 +30,30 @@ class InvertedResidual(nn.Cell):
         self.use_res_connect = self.stride == 1 and inp == oup
         if expend_ratio == 1:
             self.conv = nn.SequentialCell([
-                combined.Conv2d(hidden_dim,
-                                hidden_dim,
-                                3,
-                                stride,
-                                group=hidden_dim,
-                                batchnorm=True,
-                                activation='relu6'),
-                combined.Conv2d(hidden_dim, oup, 1, 1,
-                                batchnorm=True)
+                nn.Conv2dBnAct(hidden_dim,
+                               hidden_dim,
+                               3,
+                               stride,
+                               group=hidden_dim,
+                               batchnorm=True,
+                               activation='relu6'),
+                nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
+                               batchnorm=True)
             ])
         else:
             self.conv = nn.SequentialCell([
-                combined.Conv2d(inp, hidden_dim, 1, 1,
-                                batchnorm=True,
-                                activation='relu6'),
-                combined.Conv2d(hidden_dim,
-                                hidden_dim,
-                                3,
-                                stride,
-                                group=hidden_dim,
-                                batchnorm=True,
-                                activation='relu6'),
-                combined.Conv2d(hidden_dim, oup, 1, 1,
-                                batchnorm=True)
+                nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
+                               batchnorm=True,
+                               activation='relu6'),
+                nn.Conv2dBnAct(hidden_dim,
+                               hidden_dim,
+                               3,
+                               stride,
+                               group=hidden_dim,
+                               batchnorm=True,
+                               activation='relu6'),
+                nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
+                               batchnorm=True)
             ])
         self.add = P.TensorAdd()
@@ -99,7 +98,7 @@ class MobileNetV2(nn.Cell):
 
         self.features = nn.SequentialCell(features)
         self.mean = P.ReduceMean(keep_dims=False)
-        self.classifier = combined.Dense(self.last_channel, num_class)
+        self.classifier = nn.DenseBnAct(self.last_channel, num_class)
 
     def construct(self, input_x):
         out = input_x
diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py
index d11f169e2..e299c7b9f 100644
--- a/tests/ut/python/train/quant/test_quant.py
+++ b/tests/ut/python/train/quant/test_quant.py
@@ -15,7 +15,7 @@
 """ tests for quant """
 import mindspore.context as context
 from mindspore import nn
-from mindspore.nn.layer import combined
+
 
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
 
@@ -37,12 +37,11 @@ class LeNet5(nn.Cell):
     def __init__(self, num_class=10):
         super(LeNet5, self).__init__()
         self.num_class = num_class
-        self.conv1 = combined.Conv2d(
-            1, 6, kernel_size=5, batchnorm=True, activation='relu6')
-        self.conv2 = combined.Conv2d(6, 16, kernel_size=5, activation='relu')
-        self.fc1 = combined.Dense(16 * 5 * 5, 120, activation='relu')
-        self.fc2 = combined.Dense(120, 84, activation='relu')
-        self.fc3 = combined.Dense(84, self.num_class)
+        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6')
+        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu')
+        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
+        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
+        self.fc3 = nn.DenseBnAct(84, self.num_class)
         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
         self.flattern = nn.Flatten()
--
GitLab
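
Reviewer note: a minimal usage sketch of the renamed cells, hedged. It assumes only
what this patch shows: Conv2dBnAct and DenseBnAct are reachable as nn.Conv2dBnAct
and nn.DenseBnAct (the updated tests import them that way), and the output shape
follows the Conv2dBnAct docstring example.

    import numpy as np
    import mindspore
    from mindspore import Tensor, nn

    # Conv2dBnAct runs conv -> (optional) BatchNorm2d -> (optional) activation
    # in a single cell, so quantization can later fold the three together.
    conv = nn.Conv2dBnAct(120, 240, 4, batchnorm=True, activation='relu')
    x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
    print(conv(x).shape)  # per the docstring example: (1, 240, 1024, 640)

    # DenseBnAct is the fully connected counterpart; batchnorm and activation
    # are both optional and default to None.
    fc = nn.DenseBnAct(3, 4, activation='relu')
    y = fc(Tensor(np.ones([2, 3]), mindspore.float32))  # shape: (2, 4)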