diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index c9d79a3a390a3d045db61289ee4c8ceebe6cd575..4ba02066e60c639d4075a4510ca29f597539eec9 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -289,7 +289,8 @@ std::map> ExecutorPy::FetchI MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; std::map> fake_quant_table; auto filter = [](AnfNodePtr node) { - return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul)); + return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) || + IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative)); }; std::vector nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter); auto is_quant_cnode = [](AnfNodePtr node) { diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py index 14a1aa8554060ec7e3ba78c9ec68eda2fc88c04d..384f6251338abd8fdb0daee5c7acbbf7dcd934af 100644 --- a/mindspore/nn/layer/activation.py +++ b/mindspore/nn/layer/activation.py @@ -530,6 +530,7 @@ _activation = { 'relu6': ReLU6, 'tanh': Tanh, 'gelu': GELU, + 'elu': ELU, 'sigmoid': Sigmoid, 'prelu': PReLU, 'leakyrelu': LeakyReLU, diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index f0c82937c57efb189a83ba3d9056cee0674ee509..32f7fa4db10a4ffddac956b5dbc9d3729b4b7a2d 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -17,6 +17,7 @@ from functools import partial import numpy as np +from mindspore import nn import mindspore.common.dtype as mstype from mindspore.ops import operations as P from mindspore.ops import functional as F @@ -41,8 +42,7 @@ __all__ = [ 'Conv2dBatchNormQuant', 'Conv2dQuant', 'DenseQuant', - 'ReLUQuant', - 'ReLU6Quant', + 'ActQuant', 'HSwishQuant', 'HSigmoidQuant', 'TensorAddQuant', @@ -375,9 +375,10 @@ class FakeQuantWithMinMax(Cell): def extend_repr(self): s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ - 'quant_delay={}, min_init={}, max_init={}'.format( - self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel, - self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init) + 'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range, + self.ema, self.ema_decay, self.per_channel, + self.channel_axis, self.num_channels, self.quant_delay, + self.min_init, self.max_init) return s def construct(self, x): @@ -540,10 +541,12 @@ class Conv2dBatchNormQuant(Cell): def extend_repr(self): s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ 'pad_mode={}, padding={}, dilation={}, group={}, ' \ - 'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format( - self.in_channels, self.out_channels, self.kernel_size, self.stride, - self.pad_mode, self.padding, self.dilation, self.group, - self.fake, self.freeze_bn, self.momentum, self.quant_delay) + 'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels, + self.kernel_size, self.stride, + self.pad_mode, self.padding, self.dilation, + self.group, + self.fake, self.freeze_bn, self.momentum, + self.quant_delay) return s def construct(self, x): @@ -685,10 +688,9 @@ class Conv2dQuant(Cell): def extend_repr(self): s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ 'pad_mode={}, padding={}, dilation={}, group={}, ' \ - 'has_bias={}, quant_delay={}'.format( - self.in_channels, self.out_channels, self.kernel_size, self.stride, - self.pad_mode, self.padding, self.dilation, self.group, - self.has_bias, self.quant_delay) + 'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, + self.pad_mode, self.padding, self.dilation, self.group, + self.has_bias, self.quant_delay) return s @@ -799,76 +801,23 @@ class DenseQuant(Cell): class _QuantActivation(Cell): r""" - Base class for Quant activation function. Add Fake Quant OP after activation OP. + Base class for quantization aware training activation function. Add Fake Quant OP after activation OP. """ def get_origin(self): raise NotImplementedError -class ReLUQuant(_QuantActivation): +class ActQuant(_QuantActivation): r""" - ReLUQuant activation function. Add Fake Quant OP after Relu OP. + Quantization aware training activation function. - For a more Detailed overview of ReLU op. - - Args: - ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - symmetric (bool): Quantization algorithm use symmetric or not. Default: False. - narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. - - Inputs: - - **x** (Tensor) - The input of ReLUQuant. - - Outputs: - Tensor, with the same type and shape as the `x`. - - Examples: - >>> relu_quant = nn.ReLUQuant() - >>> input_x = Tensor(np.array([[1, 2, 0], [-1, -2, 1]]), mindspore.float32) - >>> result = relu_quant(input_x) - """ - - def __init__(self, - ema_decay=0.999, - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): - super(ReLUQuant, self).__init__() - self.fake_quant_act = FakeQuantWithMinMax(min_init=0, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) - self.relu = P.ReLU() - - def construct(self, x): - x = self.relu(x) - x = self.fake_quant_act(x) - return x - - def get_origin(self): - return self.relu - - -class ReLU6Quant(_QuantActivation): - r""" - ReLU6Quant activation function. - - Add Fake Quant OP after Relu6. Not Recommand to used these cell for Fake Quant Op + Add Fake Quant OP after activation. Not Recommand to used these cell for Fake Quant Op Will climp the max range of the activation and the relu6 do the same operation. For a more Detailed overview of ReLU6 op. Args: + activation (Cell): Activation cell class. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. @@ -883,19 +832,20 @@ class ReLU6Quant(_QuantActivation): Tensor, with the same type and shape as the `x`. Examples: - >>> relu6_quant = nn.ReLU6Quant(4, 1) + >>> act_quant = nn.ActQuant(4, 1) >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32) - >>> result = relu6_quant(input_x) + >>> result = act_quant(input_x) """ def __init__(self, + activation, ema_decay=0.999, per_channel=False, num_bits=8, symmetric=False, narrow_range=False, quant_delay=0): - super(ReLU6Quant, self).__init__() + super(ActQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=0, max_init=6, ema=True, @@ -905,15 +855,15 @@ class ReLU6Quant(_QuantActivation): symmetric=symmetric, narrow_range=narrow_range, quant_delay=quant_delay) - self.relu6 = P.ReLU6() + self.act = activation def construct(self, x): - x = self.relu6(x) + x = self.act(x) x = self.fake_quant_act(x) return x def get_origin(self): - return self.relu6 + return self.act class HSwishQuant(_QuantActivation): @@ -923,6 +873,7 @@ class HSwishQuant(_QuantActivation): For a more Detailed overview of HSwish op. Args: + activation (Cell): Activation cell class. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. @@ -943,6 +894,7 @@ class HSwishQuant(_QuantActivation): """ def __init__(self, + activation, ema_decay=0.999, per_channel=False, num_bits=8, @@ -968,7 +920,10 @@ class HSwishQuant(_QuantActivation): symmetric=symmetric, narrow_range=narrow_range, quant_delay=quant_delay) - self.act = P.HSwish() + if isinstance(activation, nn.HSwish): + self.act = activation + else: + raise ValueError("Activation should be `nn.HSwish`") def construct(self, x): x = self.fake_quant_act_before(x) @@ -987,6 +942,7 @@ class HSigmoidQuant(_QuantActivation): For a more Detailed overview of HSigmoid op. Args: + activation (Cell): Activation cell class. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. @@ -1007,6 +963,7 @@ class HSigmoidQuant(_QuantActivation): """ def __init__(self, + activation, ema_decay=0.999, per_channel=False, num_bits=8, @@ -1032,7 +989,10 @@ class HSigmoidQuant(_QuantActivation): symmetric=symmetric, narrow_range=narrow_range, quant_delay=quant_delay) - self.act = P.HSigmoid() + if isinstance(activation, nn.HSwish): + self.act = activation + else: + raise ValueError("Activation should be `nn.HSigmoid`") def construct(self, x): x = self.fake_quant_act_before(x) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 2b85c752bfd64015cfe64cde277f05a4804177f2..7117e494e491f52b4d904a593ef24a86515ab825 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1004,6 +1004,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): def infer_dtype(self, x_dtype, w_dtype): args = {'x': x_dtype, 'w': w_dtype} validator.check_tensor_type_same(args, mstype.number_type, self.name) + if x_dtype.element_type() == mstype.int8: + return mstype.tensor_type(mstype.int32) return x_dtype diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index bc44ba22c27bd01afb37b2037ecb80b7e983a8f4..e769fa1cdd97e49a6097590ae57f6c200add0240 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -33,8 +33,10 @@ from ...ops.operations import _inner_ops as inner from ...train import serialization from . import quant_utils -_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, - nn.ReLU6: quant.ReLU6Quant, +_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant, + nn.ReLU6: quant.ActQuant, + nn.LeakyReLU: quant.ActQuant, + nn.Sigmoid: quant.ActQuant, nn.HSigmoid: quant.HSigmoidQuant, nn.HSwish: quant.HSwishQuant} @@ -257,9 +259,9 @@ class ConvertToQuantNetwork: def _convert_activation(self, activation): act_class = activation.__class__ if act_class not in _ACTIVATION_MAP: - raise ValueError( - "Unsupported activation in auto quant: ", act_class) - return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, + raise ValueError("Unsupported activation in auto quant: ", act_class) + return _ACTIVATION_MAP[act_class](activation=act_class, + num_bits=self.act_bits, quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, @@ -317,7 +319,7 @@ class ExportToQuantInferNetwork: minq = self.all_parameters[minq_name] scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type) else: - logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}") + logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}") return None # Build the `Quant` `Dequant` op. @@ -325,7 +327,7 @@ class ExportToQuantInferNetwork: quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in)) sqrt_mode = False scale_deq = scale_a_out * scale_w - if scale_deq < 2 ** -14: + if (scale_deq < 2 ** -14).all(): scale_deq = np.sqrt(scale_deq) sqrt_mode = True dequant_op = inner.AscendDequant(sqrt_mode) @@ -404,11 +406,17 @@ def export(network, *inputs, file_name, file_format='GEIR'): file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model. - GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model. """ + supported_device = ["Ascend"] supported_formats = ['GEIR'] + if context.get_context('device_target') not in supported_device: + raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) + if file_format not in supported_formats: raise ValueError('Illegal file format {}.'.format(file_format)) + network.set_train(False) + if file_format == 'GEIR': exporter = ExportToQuantInferNetwork(network, *inputs) deploy_net = exporter.run() diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py index c4a8004012a131eae39cf6bfa7cbeac5566dd05f..da6d4fc87282cb500f45c7293ec6704f24a92f52 100644 --- a/mindspore/train/quant/quant_utils.py +++ b/mindspore/train/quant/quant_utils.py @@ -45,7 +45,7 @@ def cal_quantization_params(input_min, raise ValueError("input min shape should equal to input max.") if len(input_min.shape) > 1: raise ValueError("input min and max shape should be one dim.") - if input_min > input_max: + if (input_min > input_max).all(): raise ValueError("input_min min should less than input max.") if (input_max == input_min).all(): # scale = 1.0, zp = 0.0 @@ -85,9 +85,7 @@ def cal_quantization_params(input_min, return scale, zp -def weight2int(data, - scale, - zero_point): +def weight2int(data, scale, zero_point): r""" Calculate int8/uint8 weight from fp32. the formula is defined as: @@ -103,12 +101,24 @@ def weight2int(data, weight (numpy.ndarray): The dimension of channel or 1. """ if scale.shape != zero_point.shape: - raise ValueError("scale and zero_point should have the same shape.") - if scale.shape[0] > 0: - scale = scale.reshape(1, -1) - zero_point = zero_point.reshape(1, -1) + raise ValueError("`scale` and `zero_point` should have the same shape.") + if scale.shape[0] < 0: + raise ValueError("`scale` and `zero_point` shape should greater than zero.") + + if scale.shape[0] == data.shape[0]: + # `Conv2d` or `Dense` op weight + shape_list = [-1] + [1] * len(data.shape[1:]) + scale = scale.reshape(shape_list) + zero_point = zero_point.reshape(shape_list) + elif scale.shape[0] == data.shape[1]: + # `DepthwiseConv2d` op weight + shape_list = [1, -1] + [1] * len(data.shape[2:]) + scale = scale.reshape(shape_list) + zero_point = zero_point.reshape(shape_list) + else: + raise ValueError("Unsupported weight shape({})".format(data.shape)) - return np.round((data/scale) + zero_point) + return np.round((data / scale) + zero_point) def scale_zp_from_fack_quant_cell(cell, data_type): @@ -183,9 +193,20 @@ def fold_batchnorm(weight, cell_quant): beta = cell_quant.beta.data.asnumpy() epsilon = cell_quant.eps sigma = np.sqrt(variance + epsilon) - gamma = gamma.reshape(-1, 1, 1, 1) - sigma = sigma.reshape(-1, 1, 1, 1) - mean = mean.reshape(-1, 1, 1, 1) - weight = weight * gamma / sigma + + if gamma.shape[0] == weight.shape[0]: + # `Conv2d` or `Dense` op weight + shape_list = [-1] + [1] * len(weight.shape[1:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + elif gamma.shape[0] == weight.shape[1]: + # `DepthwiseConv2d` op weight + shape_list = [1, -1] + [1] * len(weight.shape[2:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + else: + raise ValueError("Unsupported weight shape({})".format(weight.shape)) + + weight = weight * _gamma / _sigma bias = beta - gamma * mean / sigma return weight, bias diff --git a/model_zoo/mobilenetv2_quant/export.py b/model_zoo/mobilenetv2_quant/export.py new file mode 100644 index 0000000000000000000000000000000000000000..00e377cece25fa912a785e28908659668ef7eb75 --- /dev/null +++ b/model_zoo/mobilenetv2_quant/export.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Export MobilenetV2 on ImageNet""" + +import argparse +import numpy as np + +import mindspore +from mindspore import Tensor +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.quant import quant + +from src.mobilenetV2 import mobilenetV2 +from src.config import config_ascend + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +parser.add_argument('--device_target', type=str, default=None, help='Run device target') +args_opt = parser.parse_args() + +if __name__ == '__main__': + cfg = None + if args_opt.device_target == "Ascend": + cfg = config_ascend + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) + else: + raise ValueError("Unsupported device target: {}.".format(args_opt.device_target)) + + # define fusion network + network = mobilenetV2(num_classes=cfg.num_classes) + # convert fusion network to quantization aware network + network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + # load checkpoint + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(network, param_dict) + + # export network + print("============== Starting export ==============") + inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32) + quant.export(network, inputs, file_name="mobilenet_quant", file_format='GEIR') + print("============== End export ==============")