提交 e32ea53d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2864 add mobilenetV2 export

Merge pull request !2864 from chenzhongming/master
......@@ -289,7 +289,8 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
auto filter = [](AnfNodePtr node) {
return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
};
std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
auto is_quant_cnode = [](AnfNodePtr node) {
......
......@@ -530,6 +530,7 @@ _activation = {
'relu6': ReLU6,
'tanh': Tanh,
'gelu': GELU,
'elu': ELU,
'sigmoid': Sigmoid,
'prelu': PReLU,
'leakyrelu': LeakyReLU,
......
......@@ -17,6 +17,7 @@
from functools import partial
import numpy as np
from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
......@@ -41,8 +42,7 @@ __all__ = [
'Conv2dBatchNormQuant',
'Conv2dQuant',
'DenseQuant',
'ReLUQuant',
'ReLU6Quant',
'ActQuant',
'HSwishQuant',
'HSigmoidQuant',
'TensorAddQuant',
......@@ -375,9 +375,10 @@ class FakeQuantWithMinMax(Cell):
def extend_repr(self):
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
'quant_delay={}, min_init={}, max_init={}'.format(
self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init)
'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range,
self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.num_channels, self.quant_delay,
self.min_init, self.max_init)
return s
def construct(self, x):
......@@ -540,10 +541,12 @@ class Conv2dBatchNormQuant(Cell):
def extend_repr(self):
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.fake, self.freeze_bn, self.momentum, self.quant_delay)
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation,
self.group,
self.fake, self.freeze_bn, self.momentum,
self.quant_delay)
return s
def construct(self, x):
......@@ -685,10 +688,9 @@ class Conv2dQuant(Cell):
def extend_repr(self):
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'has_bias={}, quant_delay={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.quant_delay)
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.quant_delay)
return s
......@@ -799,76 +801,23 @@ class DenseQuant(Cell):
class _QuantActivation(Cell):
r"""
Base class for Quant activation function. Add Fake Quant OP after activation OP.
Base class for quantization aware training activation function. Add Fake Quant OP after activation OP.
"""
def get_origin(self):
raise NotImplementedError
class ReLUQuant(_QuantActivation):
class ActQuant(_QuantActivation):
r"""
ReLUQuant activation function. Add Fake Quant OP after Relu OP.
Quantization aware training activation function.
For a more Detailed overview of ReLU op.
Args:
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of ReLUQuant.
Outputs:
Tensor, with the same type and shape as the `x`.
Examples:
>>> relu_quant = nn.ReLUQuant()
>>> input_x = Tensor(np.array([[1, 2, 0], [-1, -2, 1]]), mindspore.float32)
>>> result = relu_quant(input_x)
"""
def __init__(self,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
super(ReLUQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.relu = P.ReLU()
def construct(self, x):
x = self.relu(x)
x = self.fake_quant_act(x)
return x
def get_origin(self):
return self.relu
class ReLU6Quant(_QuantActivation):
r"""
ReLU6Quant activation function.
Add Fake Quant OP after Relu6. Not Recommand to used these cell for Fake Quant Op
Add Fake Quant OP after activation. Not Recommand to used these cell for Fake Quant Op
Will climp the max range of the activation and the relu6 do the same operation.
For a more Detailed overview of ReLU6 op.
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
......@@ -883,19 +832,20 @@ class ReLU6Quant(_QuantActivation):
Tensor, with the same type and shape as the `x`.
Examples:
>>> relu6_quant = nn.ReLU6Quant(4, 1)
>>> act_quant = nn.ActQuant(4, 1)
>>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
>>> result = relu6_quant(input_x)
>>> result = act_quant(input_x)
"""
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
super(ReLU6Quant, self).__init__()
super(ActQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
ema=True,
......@@ -905,15 +855,15 @@ class ReLU6Quant(_QuantActivation):
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.relu6 = P.ReLU6()
self.act = activation
def construct(self, x):
x = self.relu6(x)
x = self.act(x)
x = self.fake_quant_act(x)
return x
def get_origin(self):
return self.relu6
return self.act
class HSwishQuant(_QuantActivation):
......@@ -923,6 +873,7 @@ class HSwishQuant(_QuantActivation):
For a more Detailed overview of HSwish op.
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
......@@ -943,6 +894,7 @@ class HSwishQuant(_QuantActivation):
"""
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
......@@ -968,7 +920,10 @@ class HSwishQuant(_QuantActivation):
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = P.HSwish()
if isinstance(activation, nn.HSwish):
self.act = activation
else:
raise ValueError("Activation should be `nn.HSwish`")
def construct(self, x):
x = self.fake_quant_act_before(x)
......@@ -987,6 +942,7 @@ class HSigmoidQuant(_QuantActivation):
For a more Detailed overview of HSigmoid op.
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
......@@ -1007,6 +963,7 @@ class HSigmoidQuant(_QuantActivation):
"""
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
......@@ -1032,7 +989,10 @@ class HSigmoidQuant(_QuantActivation):
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = P.HSigmoid()
if isinstance(activation, nn.HSwish):
self.act = activation
else:
raise ValueError("Activation should be `nn.HSigmoid`")
def construct(self, x):
x = self.fake_quant_act_before(x)
......
......@@ -1004,6 +1004,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, w_dtype):
args = {'x': x_dtype, 'w': w_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
if x_dtype.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32)
return x_dtype
......
......@@ -33,8 +33,10 @@ from ...ops.operations import _inner_ops as inner
from ...train import serialization
from . import quant_utils
_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
nn.ReLU6: quant.ReLU6Quant,
_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
nn.ReLU6: quant.ActQuant,
nn.LeakyReLU: quant.ActQuant,
nn.Sigmoid: quant.ActQuant,
nn.HSigmoid: quant.HSigmoidQuant,
nn.HSwish: quant.HSwishQuant}
......@@ -257,9 +259,9 @@ class ConvertToQuantNetwork:
def _convert_activation(self, activation):
act_class = activation.__class__
if act_class not in _ACTIVATION_MAP:
raise ValueError(
"Unsupported activation in auto quant: ", act_class)
return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
raise ValueError("Unsupported activation in auto quant: ", act_class)
return _ACTIVATION_MAP[act_class](activation=act_class,
num_bits=self.act_bits,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
......@@ -317,7 +319,7 @@ class ExportToQuantInferNetwork:
minq = self.all_parameters[minq_name]
scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
else:
logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
return None
# Build the `Quant` `Dequant` op.
......@@ -325,7 +327,7 @@ class ExportToQuantInferNetwork:
quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
sqrt_mode = False
scale_deq = scale_a_out * scale_w
if scale_deq < 2 ** -14:
if (scale_deq < 2 ** -14).all():
scale_deq = np.sqrt(scale_deq)
sqrt_mode = True
dequant_op = inner.AscendDequant(sqrt_mode)
......@@ -404,11 +406,17 @@ def export(network, *inputs, file_name, file_format='GEIR'):
file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.
- GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model.
"""
supported_device = ["Ascend"]
supported_formats = ['GEIR']
if context.get_context('device_target') not in supported_device:
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
if file_format not in supported_formats:
raise ValueError('Illegal file format {}.'.format(file_format))
network.set_train(False)
if file_format == 'GEIR':
exporter = ExportToQuantInferNetwork(network, *inputs)
deploy_net = exporter.run()
......
......@@ -45,7 +45,7 @@ def cal_quantization_params(input_min,
raise ValueError("input min shape should equal to input max.")
if len(input_min.shape) > 1:
raise ValueError("input min and max shape should be one dim.")
if input_min > input_max:
if (input_min > input_max).all():
raise ValueError("input_min min should less than input max.")
if (input_max == input_min).all():
# scale = 1.0, zp = 0.0
......@@ -85,9 +85,7 @@ def cal_quantization_params(input_min,
return scale, zp
def weight2int(data,
scale,
zero_point):
def weight2int(data, scale, zero_point):
r"""
Calculate int8/uint8 weight from fp32. the formula is defined as:
......@@ -103,12 +101,24 @@ def weight2int(data,
weight (numpy.ndarray): The dimension of channel or 1.
"""
if scale.shape != zero_point.shape:
raise ValueError("scale and zero_point should have the same shape.")
if scale.shape[0] > 0:
scale = scale.reshape(1, -1)
zero_point = zero_point.reshape(1, -1)
raise ValueError("`scale` and `zero_point` should have the same shape.")
if scale.shape[0] < 0:
raise ValueError("`scale` and `zero_point` shape should greater than zero.")
if scale.shape[0] == data.shape[0]:
# `Conv2d` or `Dense` op weight
shape_list = [-1] + [1] * len(data.shape[1:])
scale = scale.reshape(shape_list)
zero_point = zero_point.reshape(shape_list)
elif scale.shape[0] == data.shape[1]:
# `DepthwiseConv2d` op weight
shape_list = [1, -1] + [1] * len(data.shape[2:])
scale = scale.reshape(shape_list)
zero_point = zero_point.reshape(shape_list)
else:
raise ValueError("Unsupported weight shape({})".format(data.shape))
return np.round((data/scale) + zero_point)
return np.round((data / scale) + zero_point)
def scale_zp_from_fack_quant_cell(cell, data_type):
......@@ -183,9 +193,20 @@ def fold_batchnorm(weight, cell_quant):
beta = cell_quant.beta.data.asnumpy()
epsilon = cell_quant.eps
sigma = np.sqrt(variance + epsilon)
gamma = gamma.reshape(-1, 1, 1, 1)
sigma = sigma.reshape(-1, 1, 1, 1)
mean = mean.reshape(-1, 1, 1, 1)
weight = weight * gamma / sigma
if gamma.shape[0] == weight.shape[0]:
# `Conv2d` or `Dense` op weight
shape_list = [-1] + [1] * len(weight.shape[1:])
_gamma = gamma.reshape(shape_list)
_sigma = sigma.reshape(shape_list)
elif gamma.shape[0] == weight.shape[1]:
# `DepthwiseConv2d` op weight
shape_list = [1, -1] + [1] * len(weight.shape[2:])
_gamma = gamma.reshape(shape_list)
_sigma = sigma.reshape(shape_list)
else:
raise ValueError("Unsupported weight shape({})".format(weight.shape))
weight = weight * _gamma / _sigma
bias = beta - gamma * mean / sigma
return weight, bias
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export MobilenetV2 on ImageNet"""
import argparse
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant
from src.mobilenetV2 import mobilenetV2
from src.config import config_ascend
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--device_target', type=str, default=None, help='Run device target')
args_opt = parser.parse_args()
if __name__ == '__main__':
cfg = None
if args_opt.device_target == "Ascend":
cfg = config_ascend
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
else:
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
# define fusion network
network = mobilenetV2(num_classes=cfg.num_classes)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(network, param_dict)
# export network
print("============== Starting export ==============")
inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
quant.export(network, inputs, file_name="mobilenet_quant", file_format='GEIR')
print("============== End export ==============")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册