提交 7891b53f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2704 fix quantization aware training auto create graph bug

Merge pull request !2704 from chenzhongming/r0.5
......@@ -81,6 +81,7 @@ class Cell:
self.enable_hook = False
self._bprop_debug = False
self._is_run = False
self.cell_type = None
@property
def is_run(self):
......@@ -140,6 +141,14 @@ class Cell:
for cell_name, cell in cells_name:
cell._param_prefix = cell_name
def update_cell_type(self, cell_type):
"""
Update current cell type mainly identify if quantization aware training network.
After invoked, can set the cell type to 'cell_type'.
"""
self.cell_type = cell_type
@cell_init_args.setter
def cell_init_args(self, value):
if not isinstance(value, str):
......
......@@ -17,6 +17,7 @@ from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator, Rel
from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
from mindspore._extends import cell_attr_register
from ..cell import Cell
......@@ -397,3 +398,150 @@ class Conv2dTranspose(_Conv):
self.weight,
self.bias)
return s
class DepthwiseConv2d(Cell):
r"""
2D depthwise convolution layer.
Applies a 2D depthwise convolution over an input tensor which is typically of shape:
math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
For each batch of shape:math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
.. math::
out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th
filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
:math:`\text{ks_w}` are height and width of the convolution kernel. The full kernel has shape
:math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
to split the input in the channel dimension.
If the 'pad_mode' is set to be "valid", the output height and width will be
:math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
(\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
:math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
(\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
<http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the height
and width of the 2D convolution window. Single int means the value if for both height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
the height and width of movement are both strides, or a tuple of two int numbers that
represent height and width of movement respectively. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same".
- same: Adopts the way of completion. Output height and width will be the same as the input.
Total number of padding will be calculated for horizontal and vertical
direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the
last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
must be 0.
- valid: Adopts the way of discarding. The possibly largest height and width of output will be return
without padding. Extra pixels will be discarded. If this mode is set, `padding`
must be 0.
- pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
Tensor borders. `padding` should be greater than or equal to 0.
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
to use for dilated convolution. If set to be :math:`k > 1`, there will
be :math:`k - 1` pixels skipped for each sampling location. Its value should
be greater or equal to 1 and bounded by the height and width of the
input. Default: 1.
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> net = nn.DepthwiseConv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape
(1, 240, 1024, 640)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros'):
super(DepthwiseConv2d, self).__init__()
self.kernel_size = twice(kernel_size)
self.stride = twice(stride)
self.dilation = twice(dilation)
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
validator.check_integer('group', group, in_channels, Rel.EQ)
validator.check_integer('group', group, out_channels, Rel.EQ)
validator.check_integer('group', group, 1, Rel.GE)
self.pad_mode = pad_mode
self.padding = padding
self.dilation = dilation
self.group = group
self.has_bias = has_bias
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation)
self.bias_add = P.BiasAdd()
weight_shape = [1, in_channels, *self.kernel_size]
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
if check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else:
if bias_init != 'zeros':
logger.warning("value of `has_bias` is False, value of `bias_init` will be ignore.")
self.bias = None
def construct(self, x):
out = self.conv(x, self.weight)
if self.has_bias:
out = self.bias_add(out, self.bias)
return out
def extend_repr(self):
s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={},' \
'has_bias={}, weight_init={}, bias_init={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.weight_init, self.bias_init)
if self.has_bias:
s += ', bias={}'.format(self.bias)
return s
此差异已折叠。
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Generate bprop for aware quantization ops"""
"""Generate bprop for quantization aware ops"""
from .. import operations as P
from ..operations import _quant_ops as Q
......@@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self):
return bprop
@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate)
@bprop_getters.register(Q.MinMaxUpdatePerLayer)
def get_bprop_fakequant_with_minmax_per_layer_update(self):
"""Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
"""Generate bprop for MinMaxUpdatePerLayer for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
......@@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self):
return bprop
@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate)
@bprop_getters.register(Q.MinMaxUpdatePerChannel)
def get_bprop_fakequant_with_minmax_per_channel_update(self):
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
"""Generate bprop for MinMaxUpdatePerChannel for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
......
......@@ -30,7 +30,6 @@ batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
.compute_cost(10) \
.kernel_name("batchnorm_fold2") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", None, "required", None) \
.input(1, "beta", None, "required", None) \
.input(2, "gamma", None, "required", None) \
......
......@@ -30,7 +30,6 @@ batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
.compute_cost(10) \
.kernel_name("batchnorm_fold2_grad") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "dout", None, "required", None) \
.input(1, "dout_reduce", None, "required", None) \
.input(2, "dout_x_reduce", None, "required", None) \
......
......@@ -31,7 +31,6 @@ batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
.compute_cost(10) \
.kernel_name("batchnorm_fold2_grad_reduce") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \
.output(0, "dout_reduce", True, "required", "all") \
......
......@@ -30,7 +30,6 @@ correction_mul_op_info = TBERegOp("CorrectionMul") \
.compute_cost(10) \
.kernel_name("correction_mul") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "batch_std", None, "required", None) \
......
......@@ -30,7 +30,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
.compute_cost(10) \
.kernel_name("correction_mul_grad") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \
......@@ -128,7 +127,6 @@ correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \
.compute_cost(10) \
.kernel_name("correction_mul_grad_reduce") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \
.output(0, "d_batch_std", True, "required", "all") \
......
......@@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
......@@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape")
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
......
......@@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
......@@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape")
dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -14,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerChannelUpdate op"""
"""MinMaxUpdatePerChannel op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
......@@ -22,20 +21,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \
.binfile_name("minmax_update_perchannel.so") \
.compute_cost(10) \
.kernel_name("fake_quant_min_max_per_channel_update") \
.kernel_name("minmax_update_perchannel") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
......@@ -47,43 +41,46 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChan
.get_op_info()
@op_info_register(fake_quant_min_max_per_channel_update_op_info)
def _fake_quant_min_max_per_channel_update_tbe():
"""FakeQuantPerChannelUpdate TBE register"""
@op_info_register(minmax_update_perchannel_op_info)
def _minmax_update_perchannel_tbe():
"""MinMaxUpdatePerChannel TBE register"""
return
@fusion_manager.register("fake_quant_min_max_per_channel_update")
def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
ema, ema_decay, quant_min, quant_max, training, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerChannelUpdate compute"""
@fusion_manager.register("minmax_update_perchannel")
def minmax_update_perchannel_compute(x, min_val, max_val,
ema, ema_decay, channel_axis):
"""MinMaxUpdatePerChannel compute"""
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
# CalMinMax
if channel_axis == 0:
axis = [1, 2, 3, 4]
else:
axis = [0, 2, 3]
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerLayer op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str)
def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
ema, ema_decay, channel_axis,
kernel_name="minmax_update_perchannel"):
"""MinMaxUpdatePerChannel op"""
x_shape = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
......@@ -91,11 +88,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
......@@ -108,21 +109,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
if channel_axis_ == 0:
shape_c = min_val.get("ori_shape")
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
ema, ema_decay, channel_axis_)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerLayerUpdate op"""
"""MinMaxUpdatePerLayer op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
......@@ -22,20 +22,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_minmax_update.so") \
.binfile_name("minmax_update_perlayer.so") \
.compute_cost(10) \
.kernel_name("fake_quant_minmax_update") \
.kernel_name("minmax_update_perlayer") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
......@@ -46,44 +41,42 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
.get_op_info()
@op_info_register(fake_quant_minmax_update_op_info)
def _fake_quant_minmax_update_tbe():
"""FakeQuantMinMaxPerLayerUpdate TBE register"""
@op_info_register(minmax_update_perlayer_op_info)
def _minmax_update_perlayer_tbe():
"""MinMaxUpdatePerLayer TBE register"""
return
@fusion_manager.register("fake_quant_minmax_update")
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantMinMaxPerLayerUpdate compute"""
@fusion_manager.register("minmax_update_perlayer")
def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
"""MinMaxUpdatePerLayer compute"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
# CalMinMax
axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantPerLayer op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
ema, ema_decay, kernel_name="minmax_update_perlayer"):
"""MinMaxUpdatePerLayer op"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
......@@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
......
......@@ -21,12 +21,12 @@ from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype
__all__ = ["FakeQuantPerLayer",
__all__ = ["MinMaxUpdatePerLayer",
"MinMaxUpdatePerChannel",
"FakeQuantPerLayer",
"FakeQuantPerLayerGrad",
"FakeQuantPerChannel",
"FakeQuantPerChannelGrad",
"FakeQuantMinMaxPerLayerUpdate",
"FakeQuantMinMaxPerChannelUpdate",
"BatchNormFold",
"BatchNormFoldGrad",
"CorrectionMul",
......@@ -38,20 +38,141 @@ __all__ = ["FakeQuantPerLayer",
"BatchNormFoldGradD",
"BatchNormFold2_D",
"BatchNormFold2GradD",
"BatchNormFold2GradReduce",
"BatchNormFold2GradReduce"
]
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
r"""
Update min and max per layer.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
r"""
Update min and max per channel.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
channel_axis (int): Quantization by channel axis, support 0 and 1. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
support_x_rank = [2, 4]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.channel_axis = validator.check_int_range(
'channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
if len(x_shape) not in self.support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.support_x_rank}'")
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantPerLayer(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
num_bits (int) : Number bits for quantization aware. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
simulate aware quantize funcion. After delay step in training time begin simulate the aware
simulate quantization aware funcion. After delay step in training time begin simulate the aware
quantize funcion. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
......@@ -103,8 +224,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.quant_delay = validator.check_integer(
'quant_delay', quant_delay, 0, Rel.GE, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['out'])
......@@ -196,6 +317,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Quantization by channel axis, support 0 and 1. Default: 1.
Inputs:
- **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
......@@ -213,6 +335,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
>>> result = fake_quant(input_x, _min, _max)
"""
support_quant_bit = [4, 7, 8]
support_x_rank = [2, 4]
@prim_attr_register
def __init__(self,
......@@ -245,14 +368,15 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.channel_axis = validator.check_integer(
'channel_axis', channel_axis, 0, Rel.GE, self.name)
self.quant_delay = validator.check_integer(
'quant_delay', quant_delay, 0, Rel.GE, self.name)
self.channel_axis = validator.check_int_range(
'channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
if len(x_shape) not in self.support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.support_x_rank}'")
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer(
"min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
......@@ -832,153 +956,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type):
validator.check("dout type", dout_type, "x type", x_type)
return dout_type, dout_type
class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
......@@ -32,7 +32,6 @@ from ...ops.operations import _inner_ops as inner
from ...train import serialization
from . import quant_utils
_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
nn.ReLU6: quant.ReLU6Quant,
nn.HSigmoid: quant.HSigmoidQuant,
......@@ -41,15 +40,14 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
class _AddFakeQuantInput(nn.Cell):
"""
Add FakeQuant at input and output of the Network. Only support one input and one output case.
Add FakeQuant OP at input of the network. Only support one input case.
"""
def __init__(self, network, quant_delay=0):
super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
self.network = network
self.fake_quant_input = quant.FakeQuantWithMinMax(
min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input.update_parameters_name('fake_quant_input')
self.network = network
def construct(self, data):
data = self.fake_quant_input(data)
......@@ -59,7 +57,7 @@ class _AddFakeQuantInput(nn.Cell):
class _AddFakeQuantAfterSubCell(nn.Cell):
"""
Add FakeQuant after of the sub Cell.
Add FakeQuant OP after of the sub Cell.
"""
def __init__(self, subcell, **kwargs):
......@@ -114,11 +112,12 @@ class ConvertToQuantNetwork:
self.network.update_cell_prefix()
network = self._convert_subcells2quant(self.network)
network = _AddFakeQuantInput(network)
self.network.update_cell_type("quant")
return network
def _convert_subcells2quant(self, network):
"""
convet sub cell to quant cell
convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
"""
cells = network.name_cells()
change = False
......@@ -137,19 +136,19 @@ class ConvertToQuantNetwork:
if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells())
# tensoradd to tensoradd quant
# add FakeQuant OP after OP in while list
add_list = []
for name in network.__dict__:
if name[0] == '_':
continue
attr = network.__dict__[name]
if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__:
if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
add_list.append((name, attr))
for name, prim_op in add_list:
prefix = name
add_quant = _AddFakeQuantAfterSubCell(prim_op,
num_bits=self.act_bits,
quant_delay=self.act_delay,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
narrow_range=self.act_range)
......@@ -163,11 +162,11 @@ class ConvertToQuantNetwork:
def _convert_conv(self, subcell):
"""
convet conv cell to quant cell
convert Conv2d cell to quant cell
"""
conv_inner = subcell.conv
bn_inner = subcell.batchnorm
if subcell.batchnorm is not None and self.bn_fold:
if subcell.has_bn and self.bn_fold:
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
conv_inner.out_channels,
kernel_size=conv_inner.kernel_size,
......@@ -177,7 +176,7 @@ class ConvertToQuantNetwork:
dilation=conv_inner.dilation,
group=conv_inner.group,
eps=bn_inner.eps,
momentum=bn_inner.momentum,
momentum=1 - bn_inner.momentum,
quant_delay=self.weight_qdelay,
freeze_bn=self.freeze_bn,
per_channel=self.weight_channel,
......@@ -185,6 +184,11 @@ class ConvertToQuantNetwork:
fake=True,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
# change original network BatchNormal OP parameters to quant network
conv_inner.gamma = subcell.batchnorm.gamma
conv_inner.beta = subcell.batchnorm.beta
conv_inner.moving_mean = subcell.batchnorm.moving_mean
conv_inner.moving_variance = subcell.batchnorm.moving_variance
del subcell.batchnorm
subcell.batchnorm = None
subcell.has_bn = False
......@@ -203,6 +207,10 @@ class ConvertToQuantNetwork:
num_bits=self.weight_bits,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
# change original network Conv2D OP parameters to quant network
conv_inner.weight = subcell.conv.weight
if subcell.conv.has_bias:
conv_inner.bias = subcell.conv.bias
subcell.conv = conv_inner
if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation)
......@@ -229,6 +237,10 @@ class ConvertToQuantNetwork:
per_channel=self.weight_channel,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
# change original network Dense OP parameters to quant network
dense_inner.weight = subcell.dense.weight
if subcell.dense.has_bias:
dense_inner.bias = subcell.dense.bias
subcell.dense = dense_inner
if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation)
......@@ -236,7 +248,7 @@ class ConvertToQuantNetwork:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits,
quant_delay=self.act_delay,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
narrow_range=self.act_range)
......@@ -246,12 +258,12 @@ class ConvertToQuantNetwork:
act_class = activation.__class__
if act_class not in _ACTIVATION_MAP:
raise ValueError(
"Unsupported activation in auto Quant: ", act_class)
"Unsupported activation in auto quant: ", act_class)
return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
symmetric=self.act_symmetric,
narrow_range=self.act_range)
class ExportQuantNetworkDeploy:
......@@ -405,7 +417,7 @@ def convert_quant_network(network,
narrow_range=(False, False)
):
r"""
Create aware quantizaiton training network.
Create quantization aware training network.
Args:
network (Cell): Obtain a pipeline through network for saving graph summary.
......@@ -419,7 +431,7 @@ def convert_quant_network(network,
then base on per channel otherwise base on per layer. The first element represent weights
and second element represent data flow. Default: [False, False]
symmetric (list of bool): Quantization algorithm use symmetric or not. If `True` then base on
symmetric otherwise base on assymmetric. The first element represent weights and second
symmetric otherwise base on asymmetric. The first element represent weights and second
element represent data flow. Default: [False, False]
narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True` then base
on narrow range otherwise base on off narrow range. The first element represent weights and
......@@ -428,6 +440,7 @@ def convert_quant_network(network,
Returns:
Cell, Network which has change to aware quantization training network cell.
"""
def convert2list(name, value):
if not isinstance(value, list) and not isinstance(value, tuple):
value = [value]
......
......@@ -2,13 +2,13 @@
## Description
Training LeNet with MNIST dataset in MindSpore with quantization aware trainging.
Training LeNet with MNIST dataset in MindSpore with quantization aware training.
This is the simple and basic tutorial for constructing a network in MindSpore with quantization aware.
In this tutorial, you will:
1. Train a Mindspore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
1. Train a MindSpore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file.
3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend.
4. See the persistence of accuracy in inference backend and a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples.
......@@ -24,16 +24,16 @@ Install MindSpore base on the ascend device and GPU device from [MindSpore](http
```python
pip uninstall -y mindspore-ascend
pip uninstall -y mindspore-gpu
pip install mindspore-ascend-0.4.0.whl
pip install mindspore-ascend.whl
```
then you will get the following display
Then you will get the following display
```bash
>>> Found existing installation: mindspore-ascend
>>> Uninstalling mindspore-ascend:
>>> Successfully uninstalled mindspore-ascend.
>>> Successfully uninstalled mindspore-ascend.
```
### Prepare Dataset
......@@ -87,7 +87,7 @@ class LeNet5(nn.Cell):
return x
```
get the MNIST from scratch dataset.
Get the MNIST from scratch dataset.
```Python
ds_train = create_dataset(os.path.join(args.data_path, "train"),
......@@ -97,7 +97,7 @@ step_size = ds_train.get_dataset_size()
### Train model
Load teh Lenet fusion network, traing network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`.
Load the Lenet fusion network, training network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`.
```Python
# Define the network
......@@ -133,7 +133,7 @@ After all the following we will get the loss value of each step as following:
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
```
To save your time, just run this command.
Also, you can just run this command instead.
```python
python train.py --data_path MNIST_Data --device_target Ascend
......@@ -165,17 +165,17 @@ Note that the resulting model is quantization aware but not quantized (e.g. the
# define funsion network
network = LeNet5Fusion(cfg.num_classes)
# load aware quantizaiton network checkpoint
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
# convert funsion netwrok to aware quantizaiton network
# convert funsion netwrok to quantization aware network
network = quant.convert_quant_network(network)
```
### load checkpoint
after convert to quantization aware network, we can load the checkpoint file.
After convert to quantization aware network, we can load the checkpoint file.
```python
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
......@@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
### train quantization aware model
To save your time, just run this command.
Also, you can just run this command instead.
```python
python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
......@@ -210,7 +210,7 @@ Procedure of quantization aware model evaluation is different from normal. Becau
# define funsion network
network = LeNet5Fusion(cfg.num_classes)
# load aware quantizaiton network checkpoint
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
......@@ -218,10 +218,10 @@ load_param_into_net(network, param_dict)
network = quant.convert_quant_network(network)
```
To save your time, just run this command.
Also, you can just run this command insread.
```python
python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
python eval_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
```
The top1 accuracy would display on shell.
......@@ -235,7 +235,7 @@ The top1 accuracy would display on shell.
Here are some optional parameters:
```bash
--device_target {Ascend,GPU,CPU}
--device_target {Ascend,GPU}
device where the code will be implemented (default: Ascend)
--data_path DATA_PATH
path where the dataset is saved
......
......@@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'],
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
......
......@@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'],
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
......@@ -61,7 +61,7 @@ if __name__ == "__main__":
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
print("============== Starting Testing ==============")
......
......@@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'],
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
......@@ -56,8 +56,7 @@ if __name__ == "__main__":
# call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max,
model_type=network.type)
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model
......
......@@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'],
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
......@@ -50,11 +50,13 @@ if __name__ == "__main__":
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......@@ -64,8 +66,7 @@ if __name__ == "__main__":
# call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max,
model_type="quant")
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model
......
......@@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "eval" ];
if [ -d "../eval" ];
then
rm -rf ../eval
fi
......
......@@ -62,7 +62,7 @@ run_gpu()
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "train" ];
if [ -d "../train" ];
then
rm -rf ../train
fi
......
......@@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "eval" ];
if [ -d "../eval" ];
then
rm -rf ../eval
fi
......
......@@ -60,7 +60,7 @@ run_gpu()
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "train" ];
if [ -d "../train" ];
then
rm -rf ../train
fi
......
......@@ -17,7 +17,7 @@ def _conv_bn(in_channel,
out_channel,
kernel_size=ksize,
stride=stride,
batchnorm=True)])
has_bn=True)])
class InvertedResidual(nn.Cell):
......@@ -35,25 +35,25 @@ class InvertedResidual(nn.Cell):
3,
stride,
group=hidden_dim,
batchnorm=True,
has_bn=True,
activation='relu6'),
nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
batchnorm=True)
has_bn=True)
])
else:
self.conv = nn.SequentialCell([
nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
batchnorm=True,
has_bn=True,
activation='relu6'),
nn.Conv2dBnAct(hidden_dim,
hidden_dim,
3,
stride,
group=hidden_dim,
batchnorm=True,
has_bn=True,
activation='relu6'),
nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
batchnorm=True)
has_bn=True)
])
self.add = P.TensorAdd()
......
......@@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
def __init__(self, num_class=10):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册