diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 3f3f2ecdc045ac3936cbc80b48085e26583a65d1..89bb406a273110c311cce1203d7a631ef6bf0e7a 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -97,7 +97,7 @@ class Cell: After invoked, can get all the cell's children's name perfix by '_param_perfix'. """ - cells = self.cells_and_names + cells = self.cells_and_names() for cell_name, cell in cells: cell._param_perfix = cell_name diff --git a/mindspore/nn/layer/combined.py b/mindspore/nn/layer/combined.py new file mode 100644 index 0000000000000000000000000000000000000000..671365e393989c16e6142ddc4de7a526aac3ab6f --- /dev/null +++ b/mindspore/nn/layer/combined.py @@ -0,0 +1,182 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Use combination of Conv, Dense, Relu, Batchnorm.""" + +from .normalization import BatchNorm2d +from .activation import get_activation +from ..cell import Cell +from . import conv, basic +from ..._checkparam import ParamValidator as validator + + +__all__ = ['Conv2d', 'Dense'] + +class Conv2d(Cell): + r""" + A combination of convolution, Batchnorm, activation layer. + + For a more Detailed overview of Conv2d op. + + Args: + in_channels (int): The number of input channel :math:`C_{in}`. + out_channels (int): The number of output channel :math:`C_{out}`. + kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height + and width of the 2D convolution window. Single int means the value if for both height and width of + the kernel. A tuple of 2 ints means the first value is for the height and the other is for the + width of the kernel. + stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be + greater or equal to 1 but bounded by the height and width of the input. Default: 1. + pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". + padding (int): Implicit paddings on both sides of the input. Default: 0. + dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`, + there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater + or equal to 1 and bounded by the height and width of the input. Default: 1. + group (int): Split filter into groups, `in_ channels` and `out_channels` should be + divisible by the number of groups. Default: 1. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. + It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified, + values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well + as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' + and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of + Initializer for more details. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible + Initializer and string are the same as 'weight_init'. Refer to the values of + Initializer for more details. Default: 'zeros'. + batchnorm (bool): Specifies to used batchnorm or not. Default: None. + activation (string): Specifies activation type. The optional values are as following: + 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', + 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='ReLU') + >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) + >>> net(input).shape() + (1, 240, 1024, 640) + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + weight_init='normal', + bias_init='zeros', + batchnorm=None, + activation=None): + super(Conv2d, self).__init__() + self.conv = conv.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + pad_mode, + padding, + dilation, + group, + has_bias, + weight_init, + bias_init) + self.has_bn = batchnorm is not None + self.has_act = activation is not None + self.batchnorm = batchnorm + if batchnorm is True: + self.batchnorm = BatchNorm2d(out_channels) + elif batchnorm is not None: + validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) + self.activation = get_activation(activation) + + def construct(self, x): + x = self.conv(x) + if self.has_bn: + x = self.batchnorm(x) + if self.has_act: + x = self.activation(x) + return x + + +class Dense(Cell): + r""" + A combination of Dense, Batchnorm, activation layer. + + For a more Detailed overview of Dense op. + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype + is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is + same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. + activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. + batchnorm (bool): Specifies to used batchnorm or not. Default: None. + activation (string): Specifies activation type. The optional values are as following: + 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', + 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. + + Outputs: + Tensor of shape :math:`(N, out\_channels)`. + + Examples: + >>> net = nn.Dense(3, 4) + >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) + >>> net(input) + """ + + def __init__(self, + in_channels, + out_channels, + weight_init='normal', + bias_init='zeros', + has_bias=True, + batchnorm=None, + activation=None): + super(Dense, self).__init__() + self.dense = basic.Dense( + in_channels, + out_channels, + weight_init, + bias_init, + has_bias) + self.has_bn = batchnorm is not None + self.has_act = activation is not None + if batchnorm is True: + self.batchnorm = BatchNorm2d(out_channels) + elif batchnorm is not None: + validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) + self.activation = get_activation(activation) + + def construct(self, x): + x = self.dense(x) + if self.has_bn: + x = self.batchnorm(x) + if self.has_act: + x = self.activation(x) + return x diff --git a/mindspore/train/quant/__init__.py b/mindspore/train/quant/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..531db34b2b7176d2d863a4934925cef869525cc4 --- /dev/null +++ b/mindspore/train/quant/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +quantization. + +User can use aware quantization to train a model. Mindspore supports quantization aware training, +which models quantization errors in both the forward and backward passes using fake-quantization +ops. Note that the entire computation is carried out in floating point. At the end of quantization +aware training, Mindspore provides conversion functions to convert the trained model into lower precision. +""" + +from .quant import convert_quant_network + +__all__ = ["convert_quant_network"] diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a035bc77e1b89c5aef8a210f2601d3543648c1 --- /dev/null +++ b/mindspore/train/quant/quant.py @@ -0,0 +1,262 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""aware quantization.""" + +import re +from ... import nn +from ... import ops +from ..._checkparam import ParamValidator as validator +from ..._checkparam import Rel +from ...nn.layer import combined +from ...nn.layer import quant + +_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, + nn.ReLU6: quant.ReLU6Quant, + nn.HSigmoid: quant.HSigmoidQuant, + nn.HSwish: quant.HSwishQuant} + + +class _AddFakeQuantInputOutput(nn.Cell): + """ + Add FakeQuant at input and output of the Network. Only support one input and one output case. + """ + + def __init__(self, network, quant_delay=0): + super(_AddFakeQuantInputOutput, self).__init__(auto_prefix=False) + self.network = network + self.fake_quant_input = quant.FakeQuantWithMinMax( + min_init=-6, max_init=6, quant_delay=quant_delay, ema=True) + self.fake_quant_input.update_parameters_name('fake_quant_input') + self.fake_quant_output = quant.FakeQuantWithMinMax( + min_init=-6, max_init=6, quant_delay=quant_delay, ema=True) + self.fake_quant_output.update_parameters_name('fake_quant_output') + + def construct(self, data): + data = self.fake_quant_input(data) + output = self.network(data) + output = self.fake_quant_output(output) + return output + + +class _AddFakeQuantAfterSubCell(nn.Cell): + """ + Add FakeQuant after of the sub Cell. + """ + + def __init__(self, subcell, quant_delay=0, num_bits=8): + super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) + self.subcell = subcell + self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6, + max_init=6, + num_bits=num_bits, + quant_delay=quant_delay, + ema=True) + + def construct(self, *data): + output = self.subcell(*data) + output = self.fake_quant_act(output) + return output + + +class ConvertToQuantNetwork: + """ + Convert network to quantization aware network + """ + __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] + + def __init__(self, + network, + quant_delay=0, + bn_fold=False, + freeze_bn=0, + weight_bits=8, + act_bits=8, + per_channel=False, + symmetric=False, + narrow_range=False): + self.network = validator.check_isinstance( + 'network', network, (nn.Cell,)) + self.quant_delay = validator.check_integer( + "quant delay", quant_delay, 0, Rel.GE) + self.freeze_bn = validator.check_integer( + "freeze bn", freeze_bn, 0, Rel.GE) + self.weight_bits = validator.check_integer( + "weights bit", weight_bits, 0, Rel.GE) + self.act_bits = validator.check_integer( + "activations bit", act_bits, 0, Rel.GE) + self.bn_fold = validator.check_bool("bn fold", bn_fold) + self.per_channel = validator.check_bool("per channel", per_channel) + self.symmetric = validator.check_bool("symmetric", symmetric) + self.narrow_range = validator.check_bool("narrow range", narrow_range) + + def _convert_op_name(self, name): + pattern = re.compile(r'([A-Z]{1})') + name_new = re.sub(pattern, r'_\1', name).lower() + if name_new[0] == '_': + name_new = name_new[1:] + return name_new + + def run(self): + self.network.update_cell_prefix() + network = self._convert_subcells2quant(self.network) + return network + + def _convert_subcells2quant(self, network): + """ + convet sub cell to quant cell + """ + cells = network.name_cells() + change = False + for name in cells: + subcell = cells[name] + if subcell == network: + continue + elif isinstance(subcell, combined.Conv2d): + prefix = subcell.param_prefix + new_subcell = self._convert_conv(subcell) + new_subcell.update_parameters_name(prefix + '.') + network.insert_child_to_cell(name, new_subcell) + change = True + elif isinstance(subcell, combined.Dense): + prefix = subcell.param_prefix + new_subcell = self._convert_dense(subcell) + new_subcell.update_parameters_name(prefix + '.') + network.insert_child_to_cell(name, new_subcell) + change = True + else: + self._convert_subcells2quant(subcell) + if isinstance(network, nn.SequentialCell) and change: + network.cell_list = list(network.cells()) + + # tensoradd to tensoradd quant + add_list = [] + for name in network.__dict__: + if name[0] == '_': + continue + attr = network.__dict__[name] + if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__: + add_list.append((name, attr)) + for name, prim_op in add_list: + prefix = name + add_quant = _AddFakeQuantAfterSubCell(prim_op) # quant.TensorAddQuant() + prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) + add_quant.update_parameters_name(prefix + '.') + del network.__dict__[name] + network.insert_child_to_cell(name, add_quant) + return network + + def _convert_conv(self, subcell): + """ + convet conv cell to combine cell + """ + conv_inner = subcell.conv + bn_inner = subcell.batchnorm + if subcell.batchnorm is not None and self.bn_fold: + conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels, + conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, + stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, + padding=conv_inner.padding, + dilation=conv_inner.dilation, + group=conv_inner.group, + eps=bn_inner.eps, + momentum=bn_inner.momentum, + quant_delay=self.quant_delay, + freeze_bn=self.freeze_bn, + per_channel=self.per_channel, + num_bits=self.weight_bits, + fake=True, + symmetric=self.symmetric, + narrow_range=self.narrow_range) + del subcell.batchnorm + subcell.batchnorm = None + subcell.has_bn = False + else: + conv_inner = quant.Conv2dQuant(conv_inner.in_channels, + conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, + stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, + padding=conv_inner.padding, + dilation=conv_inner.dilation, + group=conv_inner.group, + has_bias=conv_inner.has_bias, + quant_delay=self.quant_delay, + per_channel=self.per_channel, + num_bits=self.weight_bits, + symmetric=self.symmetric, + narrow_range=self.narrow_range) + subcell.conv = conv_inner + if subcell.activation is not None: + subcell.activation = self._convert_activation(subcell.activation) + else: + subcell = _AddFakeQuantAfterSubCell(subcell) + return subcell + + def _convert_dense(self, subcell): + """ + convert dense cell to combine dense cell + """ + dense_inner = subcell.dense + dense_inner = quant.DenseQuant(dense_inner.in_channels, + dense_inner.out_channels, + has_bias=dense_inner.has_bias, + quant_delay=self.quant_delay, + per_channel=self.per_channel, + num_bits=self.weight_bits) + subcell.dense = dense_inner + if subcell.activation is not None: + subcell.activation = self._convert_activation(subcell.activation) + return subcell + + def _convert_activation(self, activation): + act_class = activation.__class__ + if act_class not in _ACTIVATION_MAP: + raise ValueError( + "Unsupported activation in auto Quant: ", act_class) + return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay) + + +def convert_quant_network(network, + quant_delay=0, + bn_fold=False, + freeze_bn=0, + weight_bits=8, + act_bits=8, + per_channel=False, + symmetric=False, + narrow_range=False + ): + r""" + Create aware quantizaiton training network. + + Args: + network (Cell): Obtain a pipeline through network for saving graph summary. + quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0. + bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False. + freeze_bn (bool): Number of steps after which BN parameters used total mean and variance. Default: 0. + weight_bits (int): Number of bits to use for quantizing weights. Default: 8. + act_bits (int): Number of bits to use for quantizing activations. Default: 8. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + + returns: + Cell, Network which has change to aware quantization training network. + """ + net = ConvertToQuantNetwork( + network, quant_delay, bn_fold, freeze_bn, weight_bits, act_bits, per_channel, symmetric, narrow_range) + return net.run() diff --git a/tests/ut/python/train/quant/mobilenetv2.py b/tests/ut/python/train/quant/mobilenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..38daaf314b378b0589852f610ccce8ac39f5f024 --- /dev/null +++ b/tests/ut/python/train/quant/mobilenetv2.py @@ -0,0 +1,100 @@ +"""MobileNetV2""" +from mindspore import nn +from mindspore.ops import operations as P + + +def make_divisible(input_x, div_by=8): + return int((input_x + div_by) // div_by) + + +def _conv_bn(in_channel, + out_channel, + ksize, + stride=1): + """Get a conv2d batchnorm and relu layer.""" + return nn.SequentialCell( + [nn.Conv2d(in_channel, + out_channel, + kernel_size=ksize, + stride=stride), + nn.BatchNorm2d(out_channel)]) + + +class InvertedResidual(nn.Cell): + def __init__(self, inp, oup, stride, expend_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(inp * expend_ratio) + self.use_res_connect = self.stride == 1 and inp == oup + if expend_ratio == 1: + self.conv = nn.SequentialCell([ + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(), + nn.Conv2d(hidden_dim, oup, 1, 1), + nn.BatchNorm2d(oup) + ]) + else: + self.conv = nn.SequentialCell([ + nn.Conv2d(inp, hidden_dim, 1, 1), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(), + + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(), + + nn.Conv2d(hidden_dim, oup, 1, 1), + nn.BatchNorm2d(oup) + ]) + + def construct(self, input_x): + out = self.conv(input_x) + if self.use_res_connect: + out = input_x + out + return out + + +class MobileNetV2(nn.Cell): + def __init__(self, num_class=1000, input_size=224, width_mul=1.): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + inverted_residual_setting = [ + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 230, 1, 1], + ] + if width_mul > 1.0: + last_channel = make_divisible(last_channel * width_mul) + self.last_channel = last_channel + features = [_conv_bn(3, input_channel, 3, 2)] + + for t, c, n, s in inverted_residual_setting: + out_channel = make_divisible(c * width_mul) if t > 1 else c + for i in range(n): + if i == 0: + features.append(block(input_channel, out_channel, s, t)) + else: + features.append(block(input_channel, out_channel, 1, t)) + input_channel = out_channel + + features.append(_conv_bn(input_channel, self.last_channel, 1)) + + self.features = nn.SequentialCell(features) + self.mean = P.ReduceMean(keep_dims=False) + self.classifier = nn.Dense(self.last_channel, num_class) + + def construct(self, input_x): + out = input_x + out = self.features(out) + out = self.mean(out, (2, 3)) + out = self.classifier(out) + return out diff --git a/tests/ut/python/train/quant/mobilenetv2_combined.py b/tests/ut/python/train/quant/mobilenetv2_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..e8161dcd94976a6d51d3f853a6460dbfdca3012b --- /dev/null +++ b/tests/ut/python/train/quant/mobilenetv2_combined.py @@ -0,0 +1,108 @@ +"""mobile net v2""" +from mindspore import nn +from mindspore.ops import operations as P +from mindspore.nn.layer import combined + + +def make_divisible(input_x, div_by=8): + return int((input_x + div_by) // div_by) + + +def _conv_bn(in_channel, + out_channel, + ksize, + stride=1): + """Get a conv2d batchnorm and relu layer.""" + return nn.SequentialCell( + [combined.Conv2d(in_channel, + out_channel, + kernel_size=ksize, + stride=stride, + batchnorm=True)]) + + +class InvertedResidual(nn.Cell): + def __init__(self, inp, oup, stride, expend_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(inp * expend_ratio) + self.use_res_connect = self.stride == 1 and inp == oup + if expend_ratio == 1: + self.conv = nn.SequentialCell([ + combined.Conv2d(hidden_dim, + hidden_dim, + 3, + stride, + group=hidden_dim, + batchnorm=True, + activation='relu6'), + combined.Conv2d(hidden_dim, oup, 1, 1, + batchnorm=True) + ]) + else: + self.conv = nn.SequentialCell([ + combined.Conv2d(inp, hidden_dim, 1, 1, + batchnorm=True, + activation='relu6'), + combined.Conv2d(hidden_dim, + hidden_dim, + 3, + stride, + group=hidden_dim, + batchnorm=True, + activation='relu6'), + combined.Conv2d(hidden_dim, oup, 1, 1, + batchnorm=True) + ]) + self.add = P.TensorAdd() + + def construct(self, input_x): + out = self.conv(input_x) + if self.use_res_connect: + out = self.add(input_x, out) + return out + + +class MobileNetV2(nn.Cell): + def __init__(self, num_class=1000, input_size=224, width_mul=1.): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + inverted_residual_setting = [ + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 230, 1, 1], + ] + if width_mul > 1.0: + last_channel = make_divisible(last_channel * width_mul) + self.last_channel = last_channel + features = [_conv_bn(3, input_channel, 3, 2)] + + for t, c, n, s in inverted_residual_setting: + out_channel = make_divisible(c * width_mul) if t > 1 else c + for i in range(n): + if i == 0: + features.append(block(input_channel, out_channel, s, t)) + else: + features.append(block(input_channel, out_channel, 1, t)) + input_channel = out_channel + + features.append(_conv_bn(input_channel, self.last_channel, 1)) + + self.features = nn.SequentialCell(features) + self.mean = P.ReduceMean(keep_dims=False) + self.classifier = combined.Dense(self.last_channel, num_class) + + def construct(self, input_x): + out = input_x + out = self.features(out) + out = self.mean(out, (2, 3)) + out = self.classifier(out) + return out diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..d640938519340f19528ed7a68fe15079e56e0ad8 --- /dev/null +++ b/tests/ut/python/train/quant/test_quant.py @@ -0,0 +1,94 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" tests for quant """ +import numpy as np +from mindspore import Tensor +from mindspore.train.quant import quant as qat +from mindspore import nn +import mindspore.ops.operations as P +from mindspore.nn.layer import combined +import mindspore.context as context +from mobilenetv2_combined import MobileNetV2 + +context.set_context(mode=context.GRAPH_MODE) + + +class LeNet5(nn.Cell): + """ + Lenet network + + Args: + num_class (int): Num classes. Default: 10. + + Returns: + Tensor, output tensor + Examples: + >>> LeNet(num_class=10) + + """ + + def __init__(self, num_class=10): + super(LeNet5, self).__init__() + self.num_class = num_class + self.conv1 = combined.Conv2d( + 1, 6, kernel_size=5, batchnorm=True, activation='relu6') + self.conv2 = combined.Conv2d(6, 16, kernel_size=5, activation='relu') + self.fc1 = combined.Dense(16 * 5 * 5, 120, activation='relu') + self.fc2 = combined.Dense(120, 84, activation='relu') + self.fc3 = combined.Dense(84, self.num_class) + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flattern = nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.bn(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.max_pool2d(x) + x = self.flattern(x) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + return x + + +def test_qat_lenet(): + net = LeNet5() + net = qat.convert_quant_network( + net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8) + + +def test_qat_mobile(): + net = MobileNetV2() + img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) + net = qat.convert_quant_network( + net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8) + net(img) + + +def test_qat_mobile_train(): + net = MobileNetV2(num_class=10) + img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) + label = Tensor(np.ones((1, 10)).astype(np.float32)) + net = qat.convert_quant_network( + net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8) + + loss = nn.SoftmaxCrossEntropyWithLogits(reduction='mean') + optimizer = nn.Momentum(net.trainable_params(), + learning_rate=0.1, momentum=0.9) + net = nn.WithLossCell(net, loss) + net = nn.TrainOneStepCell(net, optimizer) + net(img, label)