提交 b7db3e9a 编写于 作者: C chenzomi

add fake quant per channel and bug fix

上级 bd3e8da6
...@@ -171,6 +171,6 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std ...@@ -171,6 +171,6 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
return true; return true;
} }
MS_REG_GPU_KERNEL(FakeQuantWithMinMax, FakeQuantGpuKernel) MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
...@@ -153,6 +153,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const ...@@ -153,6 +153,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
return true; return true;
} }
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxGrad, FakeQuantGradGpuKernel) MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
...@@ -175,6 +175,6 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs, ...@@ -175,6 +175,6 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
return true; return true;
} }
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannel, FakeQuantPerChannelGpuKernel) MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernel)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
...@@ -143,6 +143,6 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp ...@@ -143,6 +143,6 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
return true; return true;
} }
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannelGrad, FakeQuantPerChannelGradGpuKernel) MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Aware quantization.""" """Aware quantization."""
from functools import partial
import numpy as np import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
...@@ -101,10 +102,9 @@ class BatchNormFoldCell(Cell): ...@@ -101,10 +102,9 @@ class BatchNormFoldCell(Cell):
return batch_mean, batch_std, running_mean, running_std return batch_mean, batch_std, running_mean, running_std
class FakeQuantWithMinMaxD(Cell): class FakeQuantWithMinMaxAscend(Cell):
r""" r"""
Aware Quantization training op of ascend. This OP provide Fake quantization observer Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
function on data with min and max.
Args: Args:
min_init (int, list): The dimension of channel or 1(layer). Default: -6. min_init (int, list): The dimension of channel or 1(layer). Default: -6.
...@@ -125,7 +125,7 @@ class FakeQuantWithMinMaxD(Cell): ...@@ -125,7 +125,7 @@ class FakeQuantWithMinMaxD(Cell):
Tensor, with the same type and shape as the `x`. Tensor, with the same type and shape as the `x`.
Examples: Examples:
>>> fake_quant = nn.FakeQuantWithMinMaxD() >>> fake_quant = FakeQuantWithMinMax()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x) >>> result = fake_quant(input_x)
""" """
...@@ -137,75 +137,77 @@ class FakeQuantWithMinMaxD(Cell): ...@@ -137,75 +137,77 @@ class FakeQuantWithMinMaxD(Cell):
ema=False, ema=False,
ema_decay=0.999, ema_decay=0.999,
per_channel=False, per_channel=False,
channel_size=1, channel_axis=1,
out_channels=1,
quant_delay=0, quant_delay=0,
symmetric=False, symmetric=False,
narrow_range=False, narrow_range=False,
training=True): training=True):
"""init FakeQuantWithMinMax ascend layer""" """init FakeQuantWithMinMaxAscend layer"""
super(FakeQuantWithMinMaxD, self).__init__() super(FakeQuantWithMinMaxAscend, self).__init__()
self.min_init = min_init self.min_init = min_init
self.num_bits = num_bits
self.max_init = max_init self.max_init = max_init
self.num_bits = num_bits
self.ema = ema self.ema = ema
self.ema_decay = ema_decay self.ema_decay = ema_decay
self.per_channel = per_channel self.per_channel = per_channel
self.channel_size = channel_size self.channel_axis = channel_axis
self.quant_delay = quant_delay self.quant_delay = quant_delay
self.symmetric = symmetric self.symmetric = symmetric
self.narrow_range = narrow_range self.narrow_range = narrow_range
self.training = training self.training = training
if not per_channel: # init tensor min and max for fake quant op
self.fake_quant = P.FakeQuantWithMinMax(num_bits=self.num_bits, if isinstance(min_init, int):
ema=self.ema, min_array = np.array([min_init]).reshape(1).astype(np.float32)
ema_decay=self.ema_decay, max_array = np.array([max_init]).reshape(1).astype(np.float32)
quant_delay=self.quant_delay, elif isinstance(min_init, list):
symmetric=self.symmetric, min_array = np.array([self.min_init for i in range(
narrow_range=self.narrow_range, 0, self.out_channels)]).astype(np.float32)
training=training) max_array = np.array([self.max_init for i in range(
self.ema_update = P.FakeQuantWithMinMaxUpdate(num_bits=self.num_bits, 0, self.out_channels)]).astype(np.float32)
ema=self.ema, self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
ema_decay=self.ema_decay, self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=training)
else:
raise RuntimeError("not support per channel")
if isinstance(min_init, Parameter): if per_channel:
self.minq = min_init quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
self.maxq = max_init ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
else: else:
self.minq = Parameter(Tensor(np.array([min_init]).astype(np.float32)), quant_fun = P.FakeQuantPerLayer
name='quant_min', ema_fun = P.FakeQuantMinMaxPerLayerUpdate
requires_grad=False)
self.maxq = Parameter(Tensor(np.array([max_init]).astype(np.float32)), self.fake_quant = quant_fun(num_bits=self.num_bits,
name='quant_max', ema=self.ema,
requires_grad=False) ema_decay=self.ema_decay,
self.reduce_min = P.ReduceMin() quant_delay=self.quant_delay,
self.reduce_max = P.ReduceMax() symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
def extend_repr(self): def extend_repr(self):
s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format( s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size, self.min_init, self.max_init, self.ema, self.ema_decay,
self.quant_delay) self.per_channel, self.quant_delay, self.channel_axis)
return s return s
def construct(self, x, minq, maxq): def construct(self, x):
if self.training: if self.update:
min_up, max_up = self.ema_update(x, minq, maxq) min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up) out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up) P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up) P.Assign()(self.maxq, max_up)
else: else:
out = self.fake_quant(x, minq, maxq) out = self.fake_quant(x, self.minq, self.maxq)
return out return out
class FakeQuantWithMinMax(Cell): class FakeQuantWithMinMaxGPU(Cell):
r""" r"""
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
...@@ -240,98 +242,69 @@ class FakeQuantWithMinMax(Cell): ...@@ -240,98 +242,69 @@ class FakeQuantWithMinMax(Cell):
ema=False, ema=False,
ema_decay=0.999, ema_decay=0.999,
per_channel=False, per_channel=False,
channel_axis=1,
out_channels=1, out_channels=1,
quant_delay=0, quant_delay=0,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
"""init FakeQuantWithMinMax layer""" training=True):
super(FakeQuantWithMinMax, self).__init__() super(FakeQuantWithMinMaxGPU, self).__init__()
self.min_init = min_init self.min_init = min_init
self.num_bits = num_bits
self.max_init = max_init self.max_init = max_init
self.num_bits = num_bits
self.ema = ema self.ema = ema
self.ema_decay = ema_decay self.ema_decay = ema_decay
self.per_channel = per_channel self.per_channel = per_channel
self.out_channels = out_channels self.channel_axis = channel_axis
self.quant_delay = quant_delay self.quant_delay = quant_delay
self.symmetric = symmetric self.symmetric = symmetric
self.narrow_range = narrow_range self.narrow_range = narrow_range
self.training = training
if per_channel: # init tensor min and max for fake quant op
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) if isinstance(min_init, int):
max_array = np.array([self.max_init for i in range(0, self.channel_size)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
else:
min_array = np.array([min_init]).reshape(1).astype(np.float32) min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32) max_array = np.array([max_init]).reshape(1).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) elif isinstance(min_init, list):
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) min_array = np.array([self.min_init for i in range(
if context.get_context('device_target') == "Ascend": 0, self.out_channels)]).astype(np.float32)
self.fake_quant_train = FakeQuantWithMinMaxD(num_bits=self.num_bits, max_array = np.array([self.max_init for i in range(
ema=self.ema, 0, self.out_channels)]).astype(np.float32)
ema_decay=self.ema_decay, self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
quant_delay=self.quant_delay, self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
symmetric=self.symmetric,
narrow_range=self.narrow_range, if per_channel:
training=True, quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
min_init=self.minq, else:
max_init=self.maxq) quant_fun = P.FakeQuantPerLayer
self.fake_quant_infer = FakeQuantWithMinMaxD(num_bits=self.num_bits, self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema, ema=self.ema,
ema_decay=self.ema_decay, ema_decay=ema_decay,
quant_delay=self.quant_delay, quant_delay=quant_delay,
symmetric=self.symmetric, symmetric=self.symmetric,
narrow_range=self.narrow_range, narrow_range=self.narrow_range,
training=False, training=self.training)
min_init=self.minq,
max_init=self.maxq)
elif context.get_context('device_target') == "GPU":
self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
else:
raise ValueError("Not support platform.")
def extend_repr(self): def extend_repr(self):
s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format( s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.quant_delay) self.min_init, self.max_init, self.ema, self.ema_decay,
self.per_channel, self.quant_delay, self.channel_axis)
return s return s
def construct(self, x): def construct(self, x):
if self.training: out = self.fake_quant(x, self.minq, self.maxq)
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant_infer(x, self.minq, self.maxq)
return out return out
def FakeQuantWithMinMax(**kwargs):
if context.get_context('device_target') == "Ascend":
out = FakeQuantWithMinMaxAscend(**kwargs)
if context.get_context('device_target') == "GPU":
out = FakeQuantWithMinMaxGPU(**kwargs)
else:
raise ValueError("Not support platform or channel mode.")
return out
class Conv2dBatchNormQuant(Cell): class Conv2dBatchNormQuant(Cell):
r""" r"""
2D convolution with BatchNormal op folded layer. 2D convolution with BatchNormal op folded layer.
...@@ -420,7 +393,6 @@ class Conv2dBatchNormQuant(Cell): ...@@ -420,7 +393,6 @@ class Conv2dBatchNormQuant(Cell):
self.per_channel = per_channel self.per_channel = per_channel
self.symmetric = symmetric self.symmetric = symmetric
self.narrow_range = narrow_range self.narrow_range = narrow_range
self.channel_axis = int(group > 1)
self.is_gpu = context.get_context('device_target') == "GPU" self.is_gpu = context.get_context('device_target') == "GPU"
# initialize convolution op and Parameter # initialize convolution op and Parameter
...@@ -435,6 +407,7 @@ class Conv2dBatchNormQuant(Cell): ...@@ -435,6 +407,7 @@ class Conv2dBatchNormQuant(Cell):
dilation=self.dilation) dilation=self.dilation)
if weight_init is None: if weight_init is None:
weight_init = initializer('normal', [1, in_channels, *self.kernel_size]) weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
channel_axis = 1
else: else:
self.conv = P.Conv2D(out_channel=out_channels, self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=self.kernel_size, kernel_size=self.kernel_size,
...@@ -445,6 +418,7 @@ class Conv2dBatchNormQuant(Cell): ...@@ -445,6 +418,7 @@ class Conv2dBatchNormQuant(Cell):
group=group) group=group)
if weight_init is None: if weight_init is None:
weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size]) weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
channel_axis = 0
self.weight = Parameter(weight_init, name='weight') self.weight = Parameter(weight_init, name='weight')
# initialize batchnorm Parameter # initialize batchnorm Parameter
...@@ -472,7 +446,7 @@ class Conv2dBatchNormQuant(Cell): ...@@ -472,7 +446,7 @@ class Conv2dBatchNormQuant(Cell):
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = P.CorrectionMul(self.channel_axis) self.correct_mul = P.CorrectionMul(channel_axis)
if context.get_context('device_target') == "Ascend": if context.get_context('device_target') == "Ascend":
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn) self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0) self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
...@@ -520,7 +494,7 @@ class Conv2dBatchNormQuant(Cell): ...@@ -520,7 +494,7 @@ class Conv2dBatchNormQuant(Cell):
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std) out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
F.control_depend(out, self.assignadd(self.step, self.one)) F.control_depend(out, self.assignadd(self.step, self.one))
else: else:
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std) out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std)
return out return out
......
...@@ -20,10 +20,11 @@ from .grad_base import bprop_getters ...@@ -20,10 +20,11 @@ from .grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..composite.multitype_ops.zeros_like_impl import zeros_like
@bprop_getters.register(P.FakeQuantWithMinMax) @bprop_getters.register(P.FakeQuantPerLayer)
def get_bprop_fakequant_with_minmax(self): def get_bprop_fakequant_with_minmax(self):
"""Generate bprop for FakeQuantWithMinMax for GPU and Ascend""" """Generate bprop for FakeQuantPerLayer for GPU and Ascend"""
op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) op = P.FakeQuantPerLayerGrad(
num_bits=self.num_bits, quant_delay=self.quant_delay)
def bprop(x, x_min, x_max, out, dout): def bprop(x, x_min, x_max, out, dout):
dx = op(dout, x, x_min, x_max) dx = op(dout, x, x_min, x_max)
...@@ -32,10 +33,14 @@ def get_bprop_fakequant_with_minmax(self): ...@@ -32,10 +33,14 @@ def get_bprop_fakequant_with_minmax(self):
return bprop return bprop
@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel) @bprop_getters.register(P.FakeQuantPerChannel)
def get_bprop_fakequant_with_minmax_perchannel(self): def get_bprop_fakequant_with_minmax_perchannel(self):
"""Generate bprop for FakeQuantWithMinMaxPerChannel for GPU""" """Generate bprop for FakeQuantPerChannel"""
op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) op = P.FakeQuantPerChannelGrad(num_bits=self.num_bits,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.symmetric,
channel_axis=self.channel_axis)
def bprop(x, x_min, x_max, out, dout): def bprop(x, x_min, x_max, out, dout):
dx = op(dout, x, x_min, x_max) dx = op(dout, x, x_min, x_max)
...@@ -77,7 +82,7 @@ def get_bprop_batchnorm_fold2(self): ...@@ -77,7 +82,7 @@ def get_bprop_batchnorm_fold2(self):
d_batch_std, d_batch_mean, d_beta, d_gamma, d_x = op_f(dout, x, gamma, batch_std, batch_mean, running_std, d_batch_std, d_batch_mean, d_beta, d_gamma, d_x = op_f(dout, x, gamma, batch_std, batch_mean, running_std,
running_mean, global_step) running_mean, global_step)
return d_x, d_beta, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std), zeros_like(running_mean), \ return d_x, d_beta, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std), zeros_like(running_mean), \
zeros_like(global_step) zeros_like(global_step)
return bprop return bprop
...@@ -117,9 +122,19 @@ def get_bprop_batchnorm_fold2_(self): ...@@ -117,9 +122,19 @@ def get_bprop_batchnorm_fold2_(self):
return bprop return bprop
@bprop_getters.register(P.FakeQuantWithMinMaxUpdate) @bprop_getters.register(P.FakeQuantMinMaxPerLayerUpdate)
def get_bprop_fakequant_with_minmax_update(self): def get_bprop_fakequant_with_minmax_per_layer_update(self):
"""Generate bprop for FakeQuantWithMinMaxUpdate for Ascend""" """Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
return bprop
@bprop_getters.register(P.FakeQuantMinMaxPerChannelUpdate)
def get_bprop_fakequant_with_minmax_per_channel_update(self):
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
def bprop(x, x_min, x_max, out, dout): def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max) return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerChannelUpdate op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \
.compute_cost(10) \
.kernel_name("fake_quant_min_max_per_channel_update") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
.output(0, "min_up", True, "required", "all") \
.output(1, "max_up", True, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.get_op_info()
@op_info_register(fake_quant_min_max_per_channel_update_op_info)
def _fake_quant_min_max_per_channel_update_tbe():
"""FakeQuantPerChannelUpdate TBE register"""
return
@fusion_manager.register("fake_quant_min_max_per_channel_update")
def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
ema, ema_decay, quant_min, quant_max, training, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerChannelUpdate compute"""
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
axis = [0, 2, 3]
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerLayer op"""
x_shape = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
tensor_list = [input_data, min_data, max_data] + list(res_list)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FakeQuantWithMinMaxUpdate op""" """FakeQuantMinMaxPerLayerUpdate op"""
from functools import reduce as functools_reduce from functools import reduce as functools_reduce
import te.lang.cce import te.lang.cce
from te import tvm from te import tvm
...@@ -23,12 +23,12 @@ from topi.cce import util ...@@ -23,12 +23,12 @@ from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
.fusion_type("OPAQUE") \ .fusion_type("OPAQUE") \
.async_flag(False) \ .async_flag(False) \
.binfile_name("fake_quant_with_min_max_update.so") \ .binfile_name("fake_quant_minmax_update.so") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("fake_quant_with_min_max_update") \ .kernel_name("fake_quant_minmax_update") \
.partial_flag(True) \ .partial_flag(True) \
.attr("ema", "optional", "bool", "all") \ .attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \ .attr("ema_decay", "optional", "float", "all") \
...@@ -36,7 +36,6 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ ...@@ -36,7 +36,6 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.attr("narrow_range", "optional", "bool", "all") \ .attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \ .attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \ .attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "x", None, "required", None) \ .input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \ .input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \ .input(2, "max", None, "required", None) \
...@@ -47,16 +46,16 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ ...@@ -47,16 +46,16 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.get_op_info() .get_op_info()
@op_info_register(fake_quant_update_op_info) @op_info_register(fake_quant_minmax_update_op_info)
def _fake_quant_update_tbe(): def _fake_quant_minmax_update_tbe():
"""_FakeQuantWithMinMaxUpdate TBE register""" """FakeQuantMinMaxPerLayerUpdate TBE register"""
return return
@fusion_manager.register("fake_quant_with_min_max_update") @fusion_manager.register("fake_quant_minmax_update")
def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_update"): kernel_name="fake_quant_minmax_update"):
"""FakeQuantWithMinMaxUpdate compute""" """FakeQuantMinMaxPerLayerUpdate compute"""
shape = te.lang.cce.util.shape_to_list(x.shape) shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
...@@ -70,19 +69,21 @@ def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, ...@@ -70,19 +69,21 @@ def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay,
x_max = te.lang.cce.reduce_max(x, axis=axis) x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min) x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min) x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0) min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0) max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val] return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) @util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up, def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay, ema, ema_decay, symmetric, narrow_range, training, num_bits,
kernel_name="fake_quant_update"): kernel_name="fake_quant_minmax_update"):
"""FakeQuantWithMinMax op""" """FakeQuantPerLayer op"""
input_shape = x.get("shape") input_shape = x.get("shape")
input_dtype = x.get("dtype") input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape") min_shape = min_val.get("ori_shape")
...@@ -123,8 +124,8 @@ def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up, ...@@ -123,8 +124,8 @@ def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data, res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name) ema, ema_decay, quant_min, quant_max, training, kernel_name)
with tvm.target.cce(): with tvm.target.cce():
sch = generic.auto_schedule(res_list) sch = generic.auto_schedule(res_list)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantPerChannel op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_perchannel_op_info = TBERegOp("FakeQuantPerChannel") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fake_quant_perchannel.so") \
.compute_cost(10) \
.kernel_name("fake_quant_perchannel") \
.partial_flag(True) \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(fake_quant_perchannel_op_info)
def _fake_quant_perchannel_tbe():
"""FakeQuantPerChannel TBE register"""
return
@fusion_manager.register("fake_quant_perchannel")
def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
x_shape = te.lang.cce.util.shape_to_list(x.shape)
minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
quant_max = tvm.const(quant_max, x.dtype)
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
# FakeQuant
nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape)
nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape)
scale_b = te.lang.cce.broadcast(scale, x_shape)
input_x = te.lang.cce.vmin(nudge_max_b, te.lang.cce.vmax(nudge_min_b, x))
nudge_input_ = te.lang.cce.vdiv(
te.lang.cce.vsub(input_x, nudge_min_b), scale_b)
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5))
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale_b), nudge_min_b)
return res
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel(x, min_val, max_val, y,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
x_shape = x.get("shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c = min_val.get("shape")
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res = fake_quant_perchannel_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [input_data, min_data, max_data, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantPerChannelGrad op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
D_TYPE = 'float32'
fake_quant_perchannel_grad_op_info = TBERegOp("FakeQuantPerChannelGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_perchannel_grad.so") \
.compute_cost(10) \
.kernel_name("fake_quant_perchannel_grad") \
.partial_flag(True) \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \
.input(2, "min", None, "required", None) \
.input(3, "max", None, "required", None) \
.output(0, "dx", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
def _less_compare_float32(data_x, data_y):
"""_less_compare_float32 compute"""
input_shape = te.lang.cce.util.shape_to_list(data_x.shape)
min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
max_value = tvm.const(2 ** 62, dtype=D_TYPE)
factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
data_zero = te.lang.cce.broadcast(
tvm.const(0, dtype=D_TYPE), input_shape, D_TYPE)
min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
res_sub = te.lang.cce.vsub(data_y, data_x)
res_min = te.lang.cce.vmin(res_sub, min_value_tensor)
res_max = te.lang.cce.vmax(res_min, data_zero)
res_max_mul = te.lang.cce.vmuls(res_max, max_value)
res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value)
res = te.lang.cce.vmuls(res_max_mul_max, factor_value)
return res
@op_info_register(fake_quant_perchannel_grad_op_info)
def _fake_quant_perchannel_grad_tbe():
"""FakeQuantPerChannelGrad TBE register"""
return
@fusion_manager.register("fake_quant_perchannel_grad")
def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
x_shape = te.lang.cce.util.shape_to_list(x.shape)
minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
quant_max = tvm.const(quant_max, x.dtype)
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
# FakeQuant Grad
nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape)
nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape)
bool_over_min = _less_compare_float32(nudge_min_b, x)
bool_less_max = _less_compare_float32(x, nudge_max_b)
bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max)
res = te.lang.cce.vmul(dout, bool_between)
return res
@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
x_shape = x.get("shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c = min_val.get("shape")
dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res = fake_quant_perchannel_grad_compute(dout_data, input_data, min_data, max_data,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [dout_data, input_data, min_data, max_data, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FakeQuantWithMinMax op""" """FakeQuantPerLayer op"""
from functools import reduce as functools_reduce from functools import reduce as functools_reduce
import te.lang.cce import te.lang.cce
from te import tvm from te import tvm
...@@ -23,20 +22,16 @@ from topi import generic ...@@ -23,20 +22,16 @@ from topi import generic
from topi.cce import util from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \ fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \
.fusion_type("ELEMWISE") \ .fusion_type("ELEMWISE") \
.async_flag(False) \ .async_flag(False) \
.binfile_name("fake_quant_with_min_max_vars_ema.so") \ .binfile_name("fake_quant_per_layer.so") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("fake_quant_with_min_max_vars_ema") \ .kernel_name("fake_quant_per_layer") \
.partial_flag(True) \ .partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \ .attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \ .attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \ .attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "x", None, "required", None) \ .input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \ .input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \ .input(2, "max", None, "required", None) \
...@@ -49,15 +44,15 @@ fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \ ...@@ -49,15 +44,15 @@ fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
@op_info_register(fake_quant_op_info) @op_info_register(fake_quant_op_info)
def _fake_quant_tbe(): def _fake_quant_per_layer_tbe():
"""FakeQuantWithMinMax TBE register""" """FakeQuantPerLayer TBE register"""
return return
@fusion_manager.register("fake_quant_with_min_max_vars_ema") @fusion_manager.register("fake_quant_per_layer")
def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max, def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="correction_mul"): kernel_name="fake_quant_per_layer"):
"""FakeQuantWithMinMax""" """FakeQuantPerLayer"""
shape = te.lang.cce.util.shape_to_list(x.shape) shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype) quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
...@@ -66,10 +61,13 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, ...@@ -66,10 +61,13 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
# CalNudge(NudgeMinMax) # CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point # Nudge zero point
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min))) nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
...@@ -80,17 +78,19 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, ...@@ -80,17 +78,19 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
# FakeQuant # FakeQuant
input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x)) input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
nudge_input = te.lang.cce.round(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale)) nudge_input_ = te.lang.cce.vdiv(
te.lang.cce.vsub(input_x, nudge_min), scale)
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5))
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min) res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
return res return res
@util.check_input_type(dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) @util.check_input_type(dict, dict, dict, dict, bool, bool, int, str)
def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y, def fake_quant_per_layer(x, min_val, max_val, y,
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay, symmetric, narrow_range, num_bits,
kernel_name="fake_quant"): kernel_name="fake_quant_per_layer"):
"""FakeQuantWithMinMax""" """FakeQuantPerLayer"""
input_shape = x.get("shape") input_shape = x.get("shape")
input_dtype = x.get("dtype") input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape") min_shape = min_val.get("ori_shape")
...@@ -131,8 +131,8 @@ def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y, ...@@ -131,8 +131,8 @@ def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y, res = fake_quant_per_layer_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name) quant_min, quant_max, kernel_name)
with tvm.target.cce(): with tvm.target.cce():
sch = generic.auto_schedule(res) sch = generic.auto_schedule(res)
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FakeQuantWithMinMaxGrad op""" """FakeQuantPerLayerGrad op"""
from functools import reduce as functools_reduce from functools import reduce as functools_reduce
import te.lang.cce import te.lang.cce
...@@ -26,15 +26,14 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType ...@@ -26,15 +26,14 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648 SHAPE_SIZE_LIMIT = 2147483648
D_TYPE = 'float32' D_TYPE = 'float32'
fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \ fake_quant_per_layer_grad_op_info = TBERegOp("FakeQuantPerLayerGrad") \
.fusion_type("OPAQUE") \ .fusion_type("OPAQUE") \
.async_flag(False) \ .async_flag(False) \
.binfile_name("fake_quant_with_min_max_grad.so") \ .binfile_name("fake_quant_per_layer_grad.so") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("fake_quant_with_min_max_grad") \ .kernel_name("fake_quant_per_layer_grad") \
.partial_flag(True) \ .partial_flag(True) \
.attr("num_bits", "optional", "int", "all") \ .attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.attr("symmetric", "optional", "bool", "all") \ .attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \ .attr("narrow_range", "optional", "bool", "all") \
.input(0, "dout", None, "required", None) \ .input(0, "dout", None, "required", None) \
...@@ -57,7 +56,8 @@ def _less_compare_float32(data_x, data_y): ...@@ -57,7 +56,8 @@ def _less_compare_float32(data_x, data_y):
min_value = tvm.const(2 ** (-126), dtype=D_TYPE) min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
max_value = tvm.const(2 ** 62, dtype=D_TYPE) max_value = tvm.const(2 ** 62, dtype=D_TYPE)
factor_value = tvm.const(2 ** 2, dtype=D_TYPE) factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE) data_zero = te.lang.cce.broadcast(
tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
min_value_tensor = te.lang.cce.vadds(data_zero, min_value) min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
res_sub = te.lang.cce.vsub(data_y, data_x) res_sub = te.lang.cce.vsub(data_y, data_x)
...@@ -71,16 +71,16 @@ def _less_compare_float32(data_x, data_y): ...@@ -71,16 +71,16 @@ def _less_compare_float32(data_x, data_y):
return res return res
@op_info_register(fake_quant_grad_op_info) @op_info_register(fake_quant_per_layer_grad_op_info)
def _fake_quant_grad_tbe(): def _fake_quant_per_layer_grad_tbe():
"""FakeQuantWithMinMaxGrad TBE register""" """FakeQuantPerLayerGrad TBE register"""
return return
@fusion_manager.register("fake_quant_with_min_max_grad") @fusion_manager.register("fake_quant_per_layer_grad")
def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_with_min_max_grad"): kernel_name="fake_quant_per_layer_grad"):
"""FakeQuantWithMinMaxGrad""" """FakeQuantPerLayerGrad"""
shape = te.lang.cce.util.shape_to_list(x.shape) shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype) quant_min = tvm.const(quant_min, x.dtype)
...@@ -89,10 +89,13 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q ...@@ -89,10 +89,13 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
quant_max = te.lang.cce.broadcast(quant_max, shape_min) quant_max = te.lang.cce.broadcast(quant_max, shape_min)
# CalNudge(NudgeMinMax) # CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point # Nudge zero point
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min))) nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
nudge_min = te.lang.cce.broadcast(nudge_min, shape) nudge_min = te.lang.cce.broadcast(nudge_min, shape)
...@@ -106,11 +109,11 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q ...@@ -106,11 +109,11 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
return res return res
@util.check_input_type(dict, dict, dict, dict, dict, int, int, bool, bool, str) @util.check_input_type(dict, dict, dict, dict, dict, int, bool, bool, str)
def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
num_bits, quant_delay, symmetric, narrow_range, num_bits, symmetric, narrow_range,
kernel_name="fake_quant_with_min_max_grad"): kernel_name="fake_quant_per_layer_grad"):
"""FakeQuantWithMinMaxGrad""" """FakeQuantPerLayerGrad"""
input_shape = x.get("shape") input_shape = x.get("shape")
input_dtype = x.get("dtype") input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape") min_shape = min_val.get("ori_shape")
...@@ -152,8 +155,8 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, ...@@ -152,8 +155,8 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data, quant_min, res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
quant_max, kernel_name) quant_max, kernel_name)
with tvm.target.cce(): with tvm.target.cce():
sch = generic.auto_schedule(res) sch = generic.auto_schedule(res)
......
...@@ -20,10 +20,12 @@ from ..._checkparam import Rel ...@@ -20,10 +20,12 @@ from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype from ...common import dtype as mstype
__all__ = ["FakeQuantWithMinMax", __all__ = ["FakeQuantPerLayer",
"FakeQuantWithMinMaxGrad", "FakeQuantPerLayerGrad",
"FakeQuantWithMinMaxPerChannel", "FakeQuantPerChannel",
"FakeQuantWithMinMaxPerChannelGrad", "FakeQuantPerChannelGrad",
"FakeQuantMinMaxPerLayerUpdate",
"FakeQuantMinMaxPerChannelUpdate",
"BatchNormFold", "BatchNormFold",
"BatchNormFoldGrad", "BatchNormFoldGrad",
"CorrectionMul", "CorrectionMul",
...@@ -36,11 +38,10 @@ __all__ = ["FakeQuantWithMinMax", ...@@ -36,11 +38,10 @@ __all__ = ["FakeQuantWithMinMax",
"BatchNormFold2_D", "BatchNormFold2_D",
"BatchNormFold2GradD", "BatchNormFold2GradD",
"BatchNormFold2GradReduce", "BatchNormFold2GradReduce",
"FakeQuantWithMinMaxUpdate",
] ]
class FakeQuantWithMinMax(PrimitiveWithInfer): class FakeQuantPerLayer(PrimitiveWithInfer):
r""" r"""
Simulate the quantize and dequantize operations in training time. Simulate the quantize and dequantize operations in training time.
...@@ -67,49 +68,67 @@ class FakeQuantWithMinMax(PrimitiveWithInfer): ...@@ -67,49 +68,67 @@ class FakeQuantWithMinMax(PrimitiveWithInfer):
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32) >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32) >>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) >>> output_tensor = P.FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
""" """
support_quant_bit = [4, 7, 8] support_quant_bit = [4, 7, 8]
@prim_attr_register @prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, def __init__(self,
num_bits=8,
ema=False,
ema_decay=0.999,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True): training=True):
"""init FakeQuantWithMinMax OP""" """init FakeQuantPerLayer OP"""
if num_bits not in self.support_quant_bit: if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay: if ema and not ema_decay:
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name) self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) self.symmetric = validator.check_value_type(
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) 'symmetric', symmetric, (bool,), self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name) self.narrow_range = validator.check_value_type(
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) 'narrow_range', narrow_range, (bool,), self.name)
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) self.training = validator.check_value_type(
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) 'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['out']) outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape): def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) validator.check("min shape", min_shape, "max shape",
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return x_shape return x_shape
def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) validator.check_tensor_type_same(
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) {"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return x_type return x_type
class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): class FakeQuantPerLayerGrad(PrimitiveWithInfer):
r""" r"""
Performs grad of FakeQuantWithMinMax operation. Performs grad of FakeQuantPerLayerGrad operation.
Examples: Examples:
>>> fake_min_max_grad = P.FakeQuantWithMinMaxGrad() >>> fake_min_max_grad = P.FakeQuantPerLayerGrad()
>>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32) >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
>>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32) >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
>>> _min = Tensor(np.array([-4]), mindspore.float32) >>> _min = Tensor(np.array([-4]), mindspore.float32)
...@@ -119,32 +138,48 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): ...@@ -119,32 +138,48 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
support_quant_bit = [4, 7, 8] support_quant_bit = [4, 7, 8]
@prim_attr_register @prim_attr_register
def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False): def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False):
if num_bits not in self.support_quant_bit: if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) self.num_bits = validator.check_integer(
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) 'num_bits', num_bits, 0, Rel.GT, self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) self.quant_delay = validator.check_value_type(
self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) 'quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.init_prim_io_names(
inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) validator.check("dout shape", dout_shape, "x shape",
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) x_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return dout_shape return dout_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type): def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name) validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) validator.check_tensor_type_same(
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) {"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return dout_type return dout_type
class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): class FakeQuantPerChannel(PrimitiveWithInfer):
r""" r"""
Simulate the quantize and dequantize operations in training time base on per channel. Simulate the quantize and dequantize operations in training time base on per channel.
...@@ -168,53 +203,73 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): ...@@ -168,53 +203,73 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
- Tensor, has the same type as input. - Tensor, has the same type as input.
Examples: Examples:
>>> fake_quant = P.FakeQuantWithMinMaxPerChannel() >>> fake_quant = P.FakeQuantPerChannel()
>>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32) >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
>>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32) >>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32)
>>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32) >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
>>> result = fake_quant(input_x, _min, _max) >>> result = fake_quant(input_x, _min, _max)
""" """
support_quant_bit = [4, 7, 8] support_quant_bit = [4, 7, 8]
channel_axis = 0
@prim_attr_register @prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, def __init__(self,
training=True): num_bits=8,
"""init FakeQuantWithMinMaxPerChannel OP""" ema=False,
ema_decay=0.999,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True,
channel_axis=1):
"""init FakeQuantPerChannel OP"""
if num_bits not in self.support_quant_bit: if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' Attr \'num_bits\' is not support.") raise ValueError(
f"For '{self.name}' Attr \'num_bits\' is not support.")
if ema and not ema_decay: if ema and not ema_decay:
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name) self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) self.symmetric = validator.check_value_type(
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) 'symmetric', symmetric, (bool,), self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name) self.narrow_range = validator.check_value_type(
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) 'narrow_range', narrow_range, (bool,), self.name)
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) self.training = validator.check_value_type(
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) 'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.channel_axis = validator.check_integer(
'channel_axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape): def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) validator.check_integer(
validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) "min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer(
"max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
return x_shape return x_shape
def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) validator.check_tensor_type_same(
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) {"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return x_type return x_type
class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): class FakeQuantPerChannelGrad(PrimitiveWithInfer):
r""" r"""
Performs grad of FakeQuantWithMinMaxPerChannel operation. Performs grad of FakeQuantPerChannelGrad operation.
Examples: Examples:
>>> fqmmpc_grad = P.FakeQuantWithMinMaxPerChannelGrad() >>> fqmmpc_grad = P.FakeQuantPerChannelGrad()
>>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32) >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
>>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32) >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
>>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32) >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
...@@ -224,16 +279,29 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): ...@@ -224,16 +279,29 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
support_quant_bit = [4, 7, 8] support_quant_bit = [4, 7, 8]
@prim_attr_register @prim_attr_register
def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False): def __init__(self,
"""init FakeQuantWithMinMaxPerChannel Fill""" num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False,
channel_axis=1):
"""init FakeQuantPerChannelGrad Fill"""
if num_bits not in self.support_quant_bit: if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) self.num_bits = validator.check_integer(
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) 'num_bits', num_bits, 0, Rel.GT, self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) self.quant_delay = validator.check_value_type(
self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) 'quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
validator.check("dout shape", dout_shape, "x shape", x_shape) validator.check("dout shape", dout_shape, "x shape", x_shape)
...@@ -242,10 +310,13 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): ...@@ -242,10 +310,13 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type, min_type, max_type): def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name) validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) validator.check_tensor_type_same(
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) {"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return dout_type return dout_type
...@@ -744,17 +815,14 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer): ...@@ -744,17 +815,14 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
return dout_type, dout_type return dout_type, dout_type
class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer): class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
r""" r"""
Simulate the quantize and dequantize operations in training time. Update min and max value for fake quant per layer op.
Args: Args:
num_bits (int) : Number bits for aware quantilization. Default: 8. num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False. ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
simulate aware quantize funcion. After delay step in training time begin simulate the aware
quantize funcion. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True. training (bool): Training the network or not. Default: True.
...@@ -776,36 +844,121 @@ class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer): ...@@ -776,36 +844,121 @@ class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
support_quant_bit = [4, 7, 8] support_quant_bit = [4, 7, 8]
@prim_attr_register @prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True): training=True):
"""init FakeQuantWithMinMax OP""" """init FakeQuantMinMaxPerLayerUpdate OP"""
from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update
if num_bits not in self.support_quant_bit: if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay: if ema and not ema_decay:
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name) self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) self.symmetric = validator.check_value_type(
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) 'symmetric', symmetric, (bool,), self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name) self.narrow_range = validator.check_value_type(
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) 'narrow_range', narrow_range, (bool,), self.name)
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) self.training = validator.check_value_type(
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) 'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up']) outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape): def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) validator.check("min shape", min_shape, "max shape",
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) validator.check_tensor_type_same(
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) {"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type return min_type, max_type
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册