提交 b7db3e9a 编写于 作者: C chenzomi

add fake quant per channel and bug fix

上级 bd3e8da6
......@@ -171,6 +171,6 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMax, FakeQuantGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel)
} // namespace kernel
} // namespace mindspore
......@@ -153,6 +153,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxGrad, FakeQuantGradGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel)
} // namespace kernel
} // namespace mindspore
......@@ -175,6 +175,6 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannel, FakeQuantPerChannelGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernel)
} // namespace kernel
} // namespace mindspore
......@@ -143,6 +143,6 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
} // namespace kernel
} // namespace mindspore
......@@ -14,6 +14,7 @@
# ============================================================================
"""Aware quantization."""
from functools import partial
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
......@@ -101,10 +102,9 @@ class BatchNormFoldCell(Cell):
return batch_mean, batch_std, running_mean, running_std
class FakeQuantWithMinMaxD(Cell):
class FakeQuantWithMinMaxAscend(Cell):
r"""
Aware Quantization training op of ascend. This OP provide Fake quantization observer
function on data with min and max.
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
Args:
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
......@@ -125,7 +125,7 @@ class FakeQuantWithMinMaxD(Cell):
Tensor, with the same type and shape as the `x`.
Examples:
>>> fake_quant = nn.FakeQuantWithMinMaxD()
>>> fake_quant = FakeQuantWithMinMax()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""
......@@ -137,75 +137,77 @@ class FakeQuantWithMinMaxD(Cell):
ema=False,
ema_decay=0.999,
per_channel=False,
channel_size=1,
channel_axis=1,
out_channels=1,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
"""init FakeQuantWithMinMax ascend layer"""
super(FakeQuantWithMinMaxD, self).__init__()
"""init FakeQuantWithMinMaxAscend layer"""
super(FakeQuantWithMinMaxAscend, self).__init__()
self.min_init = min_init
self.num_bits = num_bits
self.max_init = max_init
self.num_bits = num_bits
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.channel_size = channel_size
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
if not per_channel:
self.fake_quant = P.FakeQuantWithMinMax(num_bits=self.num_bits,
# init tensor min and max for fake quant op
if isinstance(min_init, int):
min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32)
elif isinstance(min_init, list):
min_array = np.array([self.min_init for i in range(
0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(
0, self.out_channels)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
if per_channel:
quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
else:
quant_fun = P.FakeQuantPerLayer
ema_fun = P.FakeQuantMinMaxPerLayerUpdate
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=training)
self.ema_update = P.FakeQuantWithMinMaxUpdate(num_bits=self.num_bits,
training=self.training)
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=training)
else:
raise RuntimeError("not support per channel")
if isinstance(min_init, Parameter):
self.minq = min_init
self.maxq = max_init
else:
self.minq = Parameter(Tensor(np.array([min_init]).astype(np.float32)),
name='quant_min',
requires_grad=False)
self.maxq = Parameter(Tensor(np.array([max_init]).astype(np.float32)),
name='quant_max',
requires_grad=False)
self.reduce_min = P.ReduceMin()
self.reduce_max = P.ReduceMax()
training=self.training)
def extend_repr(self):
s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
self.quant_delay)
s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay,
self.per_channel, self.quant_delay, self.channel_axis)
return s
def construct(self, x, minq, maxq):
if self.training:
min_up, max_up = self.ema_update(x, minq, maxq)
def construct(self, x):
if self.update:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
else:
out = self.fake_quant(x, minq, maxq)
out = self.fake_quant(x, self.minq, self.maxq)
return out
class FakeQuantWithMinMax(Cell):
class FakeQuantWithMinMaxGPU(Cell):
r"""
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
......@@ -240,98 +242,69 @@ class FakeQuantWithMinMax(Cell):
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
out_channels=1,
quant_delay=0,
symmetric=False,
narrow_range=False):
"""init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
narrow_range=False,
training=True):
super(FakeQuantWithMinMaxGPU, self).__init__()
self.min_init = min_init
self.num_bits = num_bits
self.max_init = max_init
self.num_bits = num_bits
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.out_channels = out_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
if per_channel:
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(0, self.channel_size)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
else:
# init tensor min and max for fake quant op
if isinstance(min_init, int):
min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32)
elif isinstance(min_init, list):
min_array = np.array([self.min_init for i in range(
0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(
0, self.out_channels)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
if context.get_context('device_target') == "Ascend":
self.fake_quant_train = FakeQuantWithMinMaxD(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True,
min_init=self.minq,
max_init=self.maxq)
self.fake_quant_infer = FakeQuantWithMinMaxD(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False,
min_init=self.minq,
max_init=self.maxq)
elif context.get_context('device_target') == "GPU":
self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
if per_channel:
quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
else:
quant_fun = P.FakeQuantPerLayer
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
else:
raise ValueError("Not support platform.")
training=self.training)
def extend_repr(self):
s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.quant_delay)
s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay,
self.per_channel, self.quant_delay, self.channel_axis)
return s
def construct(self, x):
if self.training:
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant_infer(x, self.minq, self.maxq)
out = self.fake_quant(x, self.minq, self.maxq)
return out
def FakeQuantWithMinMax(**kwargs):
if context.get_context('device_target') == "Ascend":
out = FakeQuantWithMinMaxAscend(**kwargs)
if context.get_context('device_target') == "GPU":
out = FakeQuantWithMinMaxGPU(**kwargs)
else:
raise ValueError("Not support platform or channel mode.")
return out
class Conv2dBatchNormQuant(Cell):
r"""
2D convolution with BatchNormal op folded layer.
......@@ -420,7 +393,6 @@ class Conv2dBatchNormQuant(Cell):
self.per_channel = per_channel
self.symmetric = symmetric
self.narrow_range = narrow_range
self.channel_axis = int(group > 1)
self.is_gpu = context.get_context('device_target') == "GPU"
# initialize convolution op and Parameter
......@@ -435,6 +407,7 @@ class Conv2dBatchNormQuant(Cell):
dilation=self.dilation)
if weight_init is None:
weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
channel_axis = 1
else:
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=self.kernel_size,
......@@ -445,6 +418,7 @@ class Conv2dBatchNormQuant(Cell):
group=group)
if weight_init is None:
weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
channel_axis = 0
self.weight = Parameter(weight_init, name='weight')
# initialize batchnorm Parameter
......@@ -472,7 +446,7 @@ class Conv2dBatchNormQuant(Cell):
symmetric=symmetric,
narrow_range=narrow_range)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = P.CorrectionMul(self.channel_axis)
self.correct_mul = P.CorrectionMul(channel_axis)
if context.get_context('device_target') == "Ascend":
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
......@@ -520,7 +494,7 @@ class Conv2dBatchNormQuant(Cell):
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
F.control_depend(out, self.assignadd(self.step, self.one))
else:
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std)
return out
......
......@@ -20,10 +20,11 @@ from .grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@bprop_getters.register(P.FakeQuantWithMinMax)
@bprop_getters.register(P.FakeQuantPerLayer)
def get_bprop_fakequant_with_minmax(self):
"""Generate bprop for FakeQuantWithMinMax for GPU and Ascend"""
op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
"""Generate bprop for FakeQuantPerLayer for GPU and Ascend"""
op = P.FakeQuantPerLayerGrad(
num_bits=self.num_bits, quant_delay=self.quant_delay)
def bprop(x, x_min, x_max, out, dout):
dx = op(dout, x, x_min, x_max)
......@@ -32,10 +33,14 @@ def get_bprop_fakequant_with_minmax(self):
return bprop
@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel)
@bprop_getters.register(P.FakeQuantPerChannel)
def get_bprop_fakequant_with_minmax_perchannel(self):
"""Generate bprop for FakeQuantWithMinMaxPerChannel for GPU"""
op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
"""Generate bprop for FakeQuantPerChannel"""
op = P.FakeQuantPerChannelGrad(num_bits=self.num_bits,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.symmetric,
channel_axis=self.channel_axis)
def bprop(x, x_min, x_max, out, dout):
dx = op(dout, x, x_min, x_max)
......@@ -117,9 +122,19 @@ def get_bprop_batchnorm_fold2_(self):
return bprop
@bprop_getters.register(P.FakeQuantWithMinMaxUpdate)
def get_bprop_fakequant_with_minmax_update(self):
"""Generate bprop for FakeQuantWithMinMaxUpdate for Ascend"""
@bprop_getters.register(P.FakeQuantMinMaxPerLayerUpdate)
def get_bprop_fakequant_with_minmax_per_layer_update(self):
"""Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
return bprop
@bprop_getters.register(P.FakeQuantMinMaxPerChannelUpdate)
def get_bprop_fakequant_with_minmax_per_channel_update(self):
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerChannelUpdate op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \
.compute_cost(10) \
.kernel_name("fake_quant_min_max_per_channel_update") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
.output(0, "min_up", True, "required", "all") \
.output(1, "max_up", True, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.get_op_info()
@op_info_register(fake_quant_min_max_per_channel_update_op_info)
def _fake_quant_min_max_per_channel_update_tbe():
"""FakeQuantPerChannelUpdate TBE register"""
return
@fusion_manager.register("fake_quant_min_max_per_channel_update")
def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
ema, ema_decay, quant_min, quant_max, training, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerChannelUpdate compute"""
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
axis = [0, 2, 3]
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerLayer op"""
x_shape = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
tensor_list = [input_data, min_data, max_data] + list(res_list)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMaxUpdate op"""
"""FakeQuantMinMaxPerLayerUpdate op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
......@@ -23,12 +23,12 @@ from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_update.so") \
.binfile_name("fake_quant_minmax_update.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_update") \
.kernel_name("fake_quant_minmax_update") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
......@@ -36,7 +36,6 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
......@@ -47,16 +46,16 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.get_op_info()
@op_info_register(fake_quant_update_op_info)
def _fake_quant_update_tbe():
"""_FakeQuantWithMinMaxUpdate TBE register"""
@op_info_register(fake_quant_minmax_update_op_info)
def _fake_quant_minmax_update_tbe():
"""FakeQuantMinMaxPerLayerUpdate TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_update")
def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_update"):
"""FakeQuantWithMinMaxUpdate compute"""
@fusion_manager.register("fake_quant_minmax_update")
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantMinMaxPerLayerUpdate compute"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
......@@ -70,19 +69,21 @@ def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay,
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
kernel_name="fake_quant_update"):
"""FakeQuantWithMinMax op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantPerLayer op"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
......@@ -123,7 +124,7 @@ def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data,
res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
with tvm.target.cce():
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantPerChannel op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_perchannel_op_info = TBERegOp("FakeQuantPerChannel") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fake_quant_perchannel.so") \
.compute_cost(10) \
.kernel_name("fake_quant_perchannel") \
.partial_flag(True) \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(fake_quant_perchannel_op_info)
def _fake_quant_perchannel_tbe():
"""FakeQuantPerChannel TBE register"""
return
@fusion_manager.register("fake_quant_perchannel")
def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
x_shape = te.lang.cce.util.shape_to_list(x.shape)
minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
quant_max = tvm.const(quant_max, x.dtype)
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
# FakeQuant
nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape)
nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape)
scale_b = te.lang.cce.broadcast(scale, x_shape)
input_x = te.lang.cce.vmin(nudge_max_b, te.lang.cce.vmax(nudge_min_b, x))
nudge_input_ = te.lang.cce.vdiv(
te.lang.cce.vsub(input_x, nudge_min_b), scale_b)
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5))
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale_b), nudge_min_b)
return res
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel(x, min_val, max_val, y,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
x_shape = x.get("shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c = min_val.get("shape")
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res = fake_quant_perchannel_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [input_data, min_data, max_data, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantPerChannelGrad op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
D_TYPE = 'float32'
fake_quant_perchannel_grad_op_info = TBERegOp("FakeQuantPerChannelGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_perchannel_grad.so") \
.compute_cost(10) \
.kernel_name("fake_quant_perchannel_grad") \
.partial_flag(True) \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \
.input(2, "min", None, "required", None) \
.input(3, "max", None, "required", None) \
.output(0, "dx", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
def _less_compare_float32(data_x, data_y):
"""_less_compare_float32 compute"""
input_shape = te.lang.cce.util.shape_to_list(data_x.shape)
min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
max_value = tvm.const(2 ** 62, dtype=D_TYPE)
factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
data_zero = te.lang.cce.broadcast(
tvm.const(0, dtype=D_TYPE), input_shape, D_TYPE)
min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
res_sub = te.lang.cce.vsub(data_y, data_x)
res_min = te.lang.cce.vmin(res_sub, min_value_tensor)
res_max = te.lang.cce.vmax(res_min, data_zero)
res_max_mul = te.lang.cce.vmuls(res_max, max_value)
res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value)
res = te.lang.cce.vmuls(res_max_mul_max, factor_value)
return res
@op_info_register(fake_quant_perchannel_grad_op_info)
def _fake_quant_perchannel_grad_tbe():
"""FakeQuantPerChannelGrad TBE register"""
return
@fusion_manager.register("fake_quant_perchannel_grad")
def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
x_shape = te.lang.cce.util.shape_to_list(x.shape)
minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
quant_max = tvm.const(quant_max, x.dtype)
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
# FakeQuant Grad
nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape)
nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape)
bool_over_min = _less_compare_float32(nudge_min_b, x)
bool_less_max = _less_compare_float32(x, nudge_max_b)
bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max)
res = te.lang.cce.vmul(dout, bool_between)
return res
@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
x_shape = x.get("shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c = min_val.get("shape")
dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res = fake_quant_perchannel_grad_compute(dout_data, input_data, min_data, max_data,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [dout_data, input_data, min_data, max_data, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
......@@ -13,8 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMax op"""
"""FakeQuantPerLayer op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
......@@ -23,20 +22,16 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_vars_ema.so") \
.binfile_name("fake_quant_per_layer.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_vars_ema") \
.kernel_name("fake_quant_per_layer") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
......@@ -49,15 +44,15 @@ fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
@op_info_register(fake_quant_op_info)
def _fake_quant_tbe():
"""FakeQuantWithMinMax TBE register"""
def _fake_quant_per_layer_tbe():
"""FakeQuantPerLayer TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_vars_ema")
def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="correction_mul"):
"""FakeQuantWithMinMax"""
@fusion_manager.register("fake_quant_per_layer")
def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="fake_quant_per_layer"):
"""FakeQuantPerLayer"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
......@@ -66,10 +61,13 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
......@@ -80,17 +78,19 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
# FakeQuant
input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
nudge_input = te.lang.cce.round(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale))
nudge_input_ = te.lang.cce.vdiv(
te.lang.cce.vsub(input_x, nudge_min), scale)
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5))
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
return res
@util.check_input_type(dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
kernel_name="fake_quant"):
"""FakeQuantWithMinMax"""
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, str)
def fake_quant_per_layer(x, min_val, max_val, y,
symmetric, narrow_range, num_bits,
kernel_name="fake_quant_per_layer"):
"""FakeQuantPerLayer"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
......@@ -131,7 +131,7 @@ def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y,
res = fake_quant_per_layer_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMaxGrad op"""
"""FakeQuantPerLayerGrad op"""
from functools import reduce as functools_reduce
import te.lang.cce
......@@ -26,15 +26,14 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
D_TYPE = 'float32'
fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \
fake_quant_per_layer_grad_op_info = TBERegOp("FakeQuantPerLayerGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_grad.so") \
.binfile_name("fake_quant_per_layer_grad.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_grad") \
.kernel_name("fake_quant_per_layer_grad") \
.partial_flag(True) \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.input(0, "dout", None, "required", None) \
......@@ -57,7 +56,8 @@ def _less_compare_float32(data_x, data_y):
min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
max_value = tvm.const(2 ** 62, dtype=D_TYPE)
factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
data_zero = te.lang.cce.broadcast(
tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
res_sub = te.lang.cce.vsub(data_y, data_x)
......@@ -71,16 +71,16 @@ def _less_compare_float32(data_x, data_y):
return res
@op_info_register(fake_quant_grad_op_info)
def _fake_quant_grad_tbe():
"""FakeQuantWithMinMaxGrad TBE register"""
@op_info_register(fake_quant_per_layer_grad_op_info)
def _fake_quant_per_layer_grad_tbe():
"""FakeQuantPerLayerGrad TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_grad")
def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_with_min_max_grad"):
"""FakeQuantWithMinMaxGrad"""
@fusion_manager.register("fake_quant_per_layer_grad")
def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_per_layer_grad"):
"""FakeQuantPerLayerGrad"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
......@@ -89,10 +89,13 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
quant_max = te.lang.cce.broadcast(quant_max, shape_min)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
nudge_min = te.lang.cce.broadcast(nudge_min, shape)
......@@ -106,11 +109,11 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
return res
@util.check_input_type(dict, dict, dict, dict, dict, int, int, bool, bool, str)
def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
num_bits, quant_delay, symmetric, narrow_range,
kernel_name="fake_quant_with_min_max_grad"):
"""FakeQuantWithMinMaxGrad"""
@util.check_input_type(dict, dict, dict, dict, dict, int, bool, bool, str)
def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
num_bits, symmetric, narrow_range,
kernel_name="fake_quant_per_layer_grad"):
"""FakeQuantPerLayerGrad"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
......@@ -152,7 +155,7 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
quant_max, kernel_name)
with tvm.target.cce():
......
......@@ -20,10 +20,12 @@ from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype
__all__ = ["FakeQuantWithMinMax",
"FakeQuantWithMinMaxGrad",
"FakeQuantWithMinMaxPerChannel",
"FakeQuantWithMinMaxPerChannelGrad",
__all__ = ["FakeQuantPerLayer",
"FakeQuantPerLayerGrad",
"FakeQuantPerChannel",
"FakeQuantPerChannelGrad",
"FakeQuantMinMaxPerLayerUpdate",
"FakeQuantMinMaxPerChannelUpdate",
"BatchNormFold",
"BatchNormFoldGrad",
"CorrectionMul",
......@@ -36,11 +38,10 @@ __all__ = ["FakeQuantWithMinMax",
"BatchNormFold2_D",
"BatchNormFold2GradD",
"BatchNormFold2GradReduce",
"FakeQuantWithMinMaxUpdate",
]
class FakeQuantWithMinMax(PrimitiveWithInfer):
class FakeQuantPerLayer(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time.
......@@ -67,49 +68,67 @@ class FakeQuantWithMinMax(PrimitiveWithInfer):
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
>>> output_tensor = P.FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
def __init__(self,
num_bits=8,
ema=False,
ema_decay=0.999,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
"""init FakeQuantWithMinMax OP"""
"""init FakeQuantPerLayer OP"""
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return x_type
class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
class FakeQuantPerLayerGrad(PrimitiveWithInfer):
r"""
Performs grad of FakeQuantWithMinMax operation.
Performs grad of FakeQuantPerLayerGrad operation.
Examples:
>>> fake_min_max_grad = P.FakeQuantWithMinMaxGrad()
>>> fake_min_max_grad = P.FakeQuantPerLayerGrad()
>>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
>>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
>>> _min = Tensor(np.array([-4]), mindspore.float32)
......@@ -119,32 +138,48 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False):
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.init_prim_io_names(
inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
validator.check("dout shape", dout_shape, "x shape",
x_shape, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return dout_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return dout_type
class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
class FakeQuantPerChannel(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time base on per channel.
......@@ -168,53 +203,73 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
- Tensor, has the same type as input.
Examples:
>>> fake_quant = P.FakeQuantWithMinMaxPerChannel()
>>> fake_quant = P.FakeQuantPerChannel()
>>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
>>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32)
>>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
>>> result = fake_quant(input_x, _min, _max)
"""
support_quant_bit = [4, 7, 8]
channel_axis = 0
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantWithMinMaxPerChannel OP"""
def __init__(self,
num_bits=8,
ema=False,
ema_decay=0.999,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True,
channel_axis=1):
"""init FakeQuantPerChannel OP"""
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' Attr \'num_bits\' is not support.")
raise ValueError(
f"For '{self.name}' Attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.channel_axis = validator.check_integer(
'channel_axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer(
"min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer(
"max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return x_type
class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
class FakeQuantPerChannelGrad(PrimitiveWithInfer):
r"""
Performs grad of FakeQuantWithMinMaxPerChannel operation.
Performs grad of FakeQuantPerChannelGrad operation.
Examples:
>>> fqmmpc_grad = P.FakeQuantWithMinMaxPerChannelGrad()
>>> fqmmpc_grad = P.FakeQuantPerChannelGrad()
>>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
>>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
>>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
......@@ -224,16 +279,29 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
"""init FakeQuantWithMinMaxPerChannel Fill"""
def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False,
channel_axis=1):
"""init FakeQuantPerChannelGrad Fill"""
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
validator.check("dout shape", dout_shape, "x shape", x_shape)
......@@ -242,10 +310,13 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return dout_type
......@@ -744,17 +815,14 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
return dout_type, dout_type
class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time.
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
simulate aware quantize funcion. After delay step in training time begin simulate the aware
quantize funcion. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
......@@ -776,36 +844,121 @@ class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantWithMinMax OP"""
"""init FakeQuantMinMaxPerLayerUpdate OP"""
from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册