Commit e32ea53d authored by mindspore-ci-bot, committed by Gitee

!2864 add mobilenetV2 export

Merge pull request !2864 from chenzhongming/master
@@ -289,7 +289,8 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
   MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
   std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
   auto filter = [](AnfNodePtr node) {
-    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
+    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
+             IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
   };
   std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
   auto is_quant_cnode = [](AnfNodePtr node) {
...
@@ -530,6 +530,7 @@ _activation = {
     'relu6': ReLU6,
     'tanh': Tanh,
     'gelu': GELU,
+    'elu': ELU,
     'sigmoid': Sigmoid,
     'prelu': PReLU,
     'leakyrelu': LeakyReLU,
...
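A minimal lookup sketch (not part of the commit), assuming the `_activation` table above backs `mindspore.nn.get_activation`: with the new entry, the string 'elu' resolves to the ELU cell like any other registered name.

    from mindspore import nn

    # 'elu' now resolves through the _activation table.
    act = nn.get_activation('elu')   # returns an nn.ELU cell instance
    print(type(act))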
@@ -17,6 +17,7 @@
 from functools import partial
 import numpy as np
+from mindspore import nn
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -41,8 +42,7 @@ __all__ = [
     'Conv2dBatchNormQuant',
     'Conv2dQuant',
     'DenseQuant',
-    'ReLUQuant',
-    'ReLU6Quant',
+    'ActQuant',
     'HSwishQuant',
     'HSigmoidQuant',
     'TensorAddQuant',
...
@@ -375,9 +375,10 @@ class FakeQuantWithMinMax(Cell):
     def extend_repr(self):
-        s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
-            'quant_delay={}, min_init={}, max_init={}'.format(
-                self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
-                self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init)
+        s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
+            'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range,
+                                                              self.ema, self.ema_decay, self.per_channel,
+                                                              self.channel_axis, self.num_channels, self.quant_delay,
+                                                              self.min_init, self.max_init)
         return s

     def construct(self, x):
@@ -540,10 +541,12 @@ class Conv2dBatchNormQuant(Cell):
     def extend_repr(self):
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
             'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                self.pad_mode, self.padding, self.dilation, self.group,
-                self.fake, self.freeze_bn, self.momentum, self.quant_delay)
+            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
+                                                                        self.kernel_size, self.stride,
+                                                                        self.pad_mode, self.padding, self.dilation,
+                                                                        self.group,
+                                                                        self.fake, self.freeze_bn, self.momentum,
+                                                                        self.quant_delay)
         return s

     def construct(self, x):
@@ -685,10 +688,9 @@ class Conv2dQuant(Cell):
     def extend_repr(self):
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
             'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'has_bias={}, quant_delay={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                self.pad_mode, self.padding, self.dilation, self.group,
-                self.has_bias, self.quant_delay)
+            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                                                 self.pad_mode, self.padding, self.dilation, self.group,
+                                                 self.has_bias, self.quant_delay)
         return s
@@ -799,76 +801,23 @@ class DenseQuant(Cell):
 class _QuantActivation(Cell):
     r"""
-    Base class for Quant activation function. Add Fake Quant OP after activation OP.
+    Base class for quantization aware training activation function. Add Fake Quant OP after activation OP.
     """

     def get_origin(self):
         raise NotImplementedError


-class ReLUQuant(_QuantActivation):
+class ActQuant(_QuantActivation):
     r"""
-    ReLUQuant activation function. Add Fake Quant OP after Relu OP.
-    For a more Detailed overview of ReLU op.
-
-    Args:
-        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
-        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
-
-    Inputs:
-        - **x** (Tensor) - The input of ReLUQuant.
-
-    Outputs:
-        Tensor, with the same type and shape as the `x`.
-
-    Examples:
-        >>> relu_quant = nn.ReLUQuant()
-        >>> input_x = Tensor(np.array([[1, 2, 0], [-1, -2, 1]]), mindspore.float32)
-        >>> result = relu_quant(input_x)
-    """
-
-    def __init__(self,
-                 ema_decay=0.999,
-                 per_channel=False,
-                 num_bits=8,
-                 symmetric=False,
-                 narrow_range=False,
-                 quant_delay=0):
-        super(ReLUQuant, self).__init__()
-        self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
-                                                  max_init=6,
-                                                  ema=True,
-                                                  ema_decay=ema_decay,
-                                                  per_channel=per_channel,
-                                                  num_bits=num_bits,
-                                                  symmetric=symmetric,
-                                                  narrow_range=narrow_range,
-                                                  quant_delay=quant_delay)
-        self.relu = P.ReLU()
-
-    def construct(self, x):
-        x = self.relu(x)
-        x = self.fake_quant_act(x)
-        return x
-
-    def get_origin(self):
-        return self.relu
-
-
-class ReLU6Quant(_QuantActivation):
-    r"""
-    ReLU6Quant activation function.
-
-    Add Fake Quant OP after Relu6. Not Recommand to used these cell for Fake Quant Op
+    Quantization aware training activation function.
+
+    Add Fake Quant OP after activation. Not recommended to use this cell as a standalone fake quant op.
     Will clamp the max range of the activation; the relu6 op does the same operation.
     For a more detailed overview of the ReLU6 op.

     Args:
+        activation (Cell): Activation cell class.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -883,19 +832,20 @@ class ReLU6Quant(_QuantActivation):
         Tensor, with the same type and shape as the `x`.

     Examples:
-        >>> relu6_quant = nn.ReLU6Quant(4, 1)
+        >>> act_quant = nn.ActQuant(4, 1)
         >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
-        >>> result = relu6_quant(input_x)
+        >>> result = act_quant(input_x)
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
                  symmetric=False,
                  narrow_range=False,
                  quant_delay=0):
-        super(ReLU6Quant, self).__init__()
+        super(ActQuant, self).__init__()
         self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
                                                   max_init=6,
                                                   ema=True,
@@ -905,15 +855,15 @@ class ReLU6Quant(_QuantActivation):
                                                   symmetric=symmetric,
                                                   narrow_range=narrow_range,
                                                   quant_delay=quant_delay)
-        self.relu6 = P.ReLU6()
+        self.act = activation

     def construct(self, x):
-        x = self.relu6(x)
+        x = self.act(x)
         x = self.fake_quant_act(x)
         return x

     def get_origin(self):
-        return self.relu6
+        return self.act


 class HSwishQuant(_QuantActivation):
@@ -923,6 +873,7 @@ class HSwishQuant(_QuantActivation):
     For a more detailed overview of the HSwish op.

     Args:
+        activation (Cell): Activation cell class.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -943,6 +894,7 @@ class HSwishQuant(_QuantActivation):
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
@@ -968,7 +920,10 @@ class HSwishQuant(_QuantActivation):
                                                          symmetric=symmetric,
                                                          narrow_range=narrow_range,
                                                          quant_delay=quant_delay)
-        self.act = P.HSwish()
+        if isinstance(activation, nn.HSwish):
+            self.act = activation
+        else:
+            raise ValueError("Activation should be `nn.HSwish`")

     def construct(self, x):
         x = self.fake_quant_act_before(x)
@@ -987,6 +942,7 @@ class HSigmoidQuant(_QuantActivation):
     For a more detailed overview of the HSigmoid op.

     Args:
+        activation (Cell): Activation cell class.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -1007,6 +963,7 @@ class HSigmoidQuant(_QuantActivation):
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
@@ -1032,7 +989,10 @@ class HSigmoidQuant(_QuantActivation):
                                                          symmetric=symmetric,
                                                          narrow_range=narrow_range,
                                                          quant_delay=quant_delay)
-        self.act = P.HSigmoid()
+        if isinstance(activation, nn.HSigmoid):
+            self.act = activation
+        else:
+            raise ValueError("Activation should be `nn.HSigmoid`")

     def construct(self, x):
         x = self.fake_quant_act_before(x)
...
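A small usage sketch of the reworked `ActQuant` cell (illustrative, not part of the diff; assumes `ActQuant` is re-exported through `mindspore.nn` as its docstring examples suggest):

    import numpy as np
    import mindspore
    from mindspore import Tensor, nn

    # ActQuant wraps a concrete activation cell and fake-quantizes its output.
    act_quant = nn.ActQuant(activation=nn.ReLU6())
    input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
    result = act_quant(input_x)   # ReLU6 first, then FakeQuantWithMinMax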
@@ -1004,6 +1004,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype, w_dtype):
         args = {'x': x_dtype, 'w': w_dtype}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        if x_dtype.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
         return x_dtype
...
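A standalone sketch of the dtype rule above (illustrative only; `promoted_output_dtype` is a hypothetical helper, not MindSpore API): an int8 quantized depthwise convolution accumulates into int32.

    import mindspore.common.dtype as mstype

    def promoted_output_dtype(x_dtype):
        # int8 inputs produce int32 outputs; everything else passes through.
        if x_dtype == mstype.int8:
            return mstype.int32
        return x_dtype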
@@ -33,8 +33,10 @@ from ...ops.operations import _inner_ops as inner
 from ...train import serialization
 from . import quant_utils

-_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
-                   nn.ReLU6: quant.ReLU6Quant,
+_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
+                   nn.ReLU6: quant.ActQuant,
+                   nn.LeakyReLU: quant.ActQuant,
+                   nn.Sigmoid: quant.ActQuant,
                    nn.HSigmoid: quant.HSigmoidQuant,
                    nn.HSwish: quant.HSwishQuant}
@@ -257,9 +259,9 @@ class ConvertToQuantNetwork:
     def _convert_activation(self, activation):
         act_class = activation.__class__
         if act_class not in _ACTIVATION_MAP:
-            raise ValueError(
-                "Unsupported activation in auto quant: ", act_class)
-        return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
+            raise ValueError("Unsupported activation in auto quant: ", act_class)
+        return _ACTIVATION_MAP[act_class](activation=act_class,
+                                          num_bits=self.act_bits,
                                           quant_delay=self.act_qdelay,
                                           per_channel=self.act_channel,
                                           symmetric=self.act_symmetric,
@@ -317,7 +319,7 @@ class ExportToQuantInferNetwork:
             minq = self.all_parameters[minq_name]
             scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
         else:
-            logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
+            logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
             return None

         # Build the `Quant` `Dequant` op.
@@ -325,7 +327,7 @@ class ExportToQuantInferNetwork:
         quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
         sqrt_mode = False
         scale_deq = scale_a_out * scale_w
-        if scale_deq < 2 ** -14:
+        if (scale_deq < 2 ** -14).all():
             scale_deq = np.sqrt(scale_deq)
             sqrt_mode = True
         dequant_op = inner.AscendDequant(sqrt_mode)
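A numeric sketch of the `sqrt_mode` branch (values are made up): when every combined dequant scale falls below 2**-14, the square root is stored and `AscendDequant(sqrt_mode=True)` squares it back on device, keeping the stored scale representable.

    import numpy as np

    scale_deq = np.array([2e-5, 3e-5])      # per-channel dequant scales
    sqrt_mode = False
    if (scale_deq < 2 ** -14).all():
        scale_deq = np.sqrt(scale_deq)      # squared again inside the dequant op
        sqrt_mode = True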
@@ -404,11 +406,17 @@ def export(network, *inputs, file_name, file_format='GEIR'):
         file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.

             - GEIR: Graph Engine Intermediate Representation. An intermediate representation format of Ascend model.
     """
+    supported_device = ["Ascend"]
     supported_formats = ['GEIR']
+
+    if context.get_context('device_target') not in supported_device:
+        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

     if file_format not in supported_formats:
         raise ValueError('Illegal file format {}.'.format(file_format))

+    network.set_train(False)
+
     if file_format == 'GEIR':
         exporter = ExportToQuantInferNetwork(network, *inputs)
         deploy_net = exporter.run()
...
@@ -45,7 +45,7 @@ def cal_quantization_params(input_min,
         raise ValueError("input min shape should equal to input max.")
     if len(input_min.shape) > 1:
         raise ValueError("input min and max shape should be one dim.")
-    if input_min > input_max:
+    if (input_min > input_max).all():
         raise ValueError("input_min min should less than input max.")
     if (input_max == input_min).all():
         # scale = 1.0, zp = 0.0
@@ -85,9 +85,7 @@ def cal_quantization_params(input_min,
     return scale, zp


-def weight2int(data,
-               scale,
-               zero_point):
+def weight2int(data, scale, zero_point):
     r"""
     Calculate int8/uint8 weight from fp32. the formula is defined as:
@@ -103,12 +101,24 @@ def weight2int(data,
         weight (numpy.ndarray): The dimension of channel or 1.
     """
     if scale.shape != zero_point.shape:
-        raise ValueError("scale and zero_point should have the same shape.")
-    if scale.shape[0] > 0:
-        scale = scale.reshape(1, -1)
-        zero_point = zero_point.reshape(1, -1)
+        raise ValueError("`scale` and `zero_point` should have the same shape.")
+    if scale.shape[0] < 0:
+        raise ValueError("`scale` and `zero_point` shape should greater than zero.")
+
+    if scale.shape[0] == data.shape[0]:
+        # `Conv2d` or `Dense` op weight
+        shape_list = [-1] + [1] * len(data.shape[1:])
+        scale = scale.reshape(shape_list)
+        zero_point = zero_point.reshape(shape_list)
+    elif scale.shape[0] == data.shape[1]:
+        # `DepthwiseConv2d` op weight
+        shape_list = [1, -1] + [1] * len(data.shape[2:])
+        scale = scale.reshape(shape_list)
+        zero_point = zero_point.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(data.shape))

-    return np.round((data/scale) + zero_point)
+    return np.round((data / scale) + zero_point)


 def scale_zp_from_fack_quant_cell(cell, data_type):
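A worked example of the new per-channel broadcasting in `weight2int` (hypothetical shapes): for a `Conv2d` weight the scale runs along axis 0 and is reshaped to (-1, 1, 1, 1) before the division.

    import numpy as np

    conv_w = np.random.randn(8, 3, 3, 3).astype(np.float32)   # out_channels first
    scale = np.full(8, 0.02, np.float32)
    zp = np.zeros(8, np.float32)
    # Axis-0 case: reshape so division broadcasts per output channel.
    q = np.round(conv_w / scale.reshape(-1, 1, 1, 1) + zp.reshape(-1, 1, 1, 1))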
@@ -183,9 +193,20 @@ def fold_batchnorm(weight, cell_quant):
     beta = cell_quant.beta.data.asnumpy()
     epsilon = cell_quant.eps
     sigma = np.sqrt(variance + epsilon)
-    gamma = gamma.reshape(-1, 1, 1, 1)
-    sigma = sigma.reshape(-1, 1, 1, 1)
-    mean = mean.reshape(-1, 1, 1, 1)
-    weight = weight * gamma / sigma
+
+    if gamma.shape[0] == weight.shape[0]:
+        # `Conv2d` or `Dense` op weight
+        shape_list = [-1] + [1] * len(weight.shape[1:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    elif gamma.shape[0] == weight.shape[1]:
+        # `DepthwiseConv2d` op weight
+        shape_list = [1, -1] + [1] * len(weight.shape[2:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(weight.shape))
+
+    weight = weight * _gamma / _sigma
     bias = beta - gamma * mean / sigma
     return weight, bias
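The hunk above generalizes standard batch-norm folding; a compact restatement with made-up shapes:

    import numpy as np

    # w_fold = w * gamma / sqrt(var + eps);  b_fold = beta - gamma * mean / sqrt(var + eps)
    w = np.random.randn(8, 3, 3, 3).astype(np.float32)
    gamma, beta = np.ones(8, np.float32), np.zeros(8, np.float32)
    mean, var, eps = np.zeros(8, np.float32), np.ones(8, np.float32), 1e-5
    sigma = np.sqrt(var + eps)
    w_fold = w * (gamma / sigma).reshape(-1, 1, 1, 1)
    b_fold = beta - gamma * mean / sigma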
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export MobilenetV2 on ImageNet"""
import argparse
import numpy as np

import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant
from src.mobilenetV2 import mobilenetV2
from src.config import config_ascend

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--device_target', type=str, default=None, help='Run device target')
args_opt = parser.parse_args()

if __name__ == '__main__':
    cfg = None
    if args_opt.device_target == "Ascend":
        cfg = config_ascend
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
    else:
        raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))

    # define fusion network
    network = mobilenetV2(num_classes=cfg.num_classes)
    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])

    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(network, param_dict)

    # export network
    print("============== Starting export ==============")
    inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
    quant.export(network, inputs, file_name="mobilenet_quant", file_format='GEIR')
    print("============== End export ==============")