diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc index ade7c32da05ac2fa8367ec7b90b522b0a2d5282c..31f37bd733355f928c7b65f41f44796d87a1509e 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc @@ -171,6 +171,6 @@ bool FakeQuantGpuKernel::Launch(const std::vector &inputs, const std return true; } -MS_REG_GPU_KERNEL(FakeQuantWithMinMax, FakeQuantGpuKernel) +MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc index d92696d1bd6e43f2ee7f3a5b891977b9e5bd180c..db025945018b414e0cdcb701a6f7c8cdb7815d88 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc @@ -153,6 +153,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector &inputs, const return true; } -MS_REG_GPU_KERNEL(FakeQuantWithMinMaxGrad, FakeQuantGradGpuKernel) +MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc index 083bf7f011e8b580f0ab23e5679afd4e8b4a3bfc..ea1fea33227a136f1da3b1fb9206355e77bbe4ba 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc @@ -175,6 +175,6 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector &inputs, return true; } -MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannel, FakeQuantPerChannelGpuKernel) +MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernel) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc index 88c976285cc8ab4b54a19e7ec74ae6d56f51ed0f..b43e178eb1d4c949d53279a067fc16294cbbb421 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc @@ -143,6 +143,6 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inp return true; } -MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannelGrad, FakeQuantPerChannelGradGpuKernel) +MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernel) } // namespace kernel } // namespace mindspore diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 13421ce908c1f3fa929cd141b411c1f2245a98ca..77fda2162e8e67c1e474ec8a573109c657b790cf 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -14,6 +14,7 @@ # ============================================================================ """Aware quantization.""" +from functools import partial import numpy as np import mindspore.common.dtype as mstype from mindspore.ops import operations as P @@ -101,10 +102,9 @@ class BatchNormFoldCell(Cell): return batch_mean, batch_std, running_mean, running_std -class FakeQuantWithMinMaxD(Cell): +class FakeQuantWithMinMaxAscend(Cell): r""" - Aware Quantization training op of ascend. This OP provide Fake quantization observer - function on data with min and max. + Aware Quantization op. 
This OP provide Fake quantization observer function on data with min and max. Args: min_init (int, list): The dimension of channel or 1(layer). Default: -6. @@ -125,7 +125,7 @@ class FakeQuantWithMinMaxD(Cell): Tensor, with the same type and shape as the `x`. Examples: - >>> fake_quant = nn.FakeQuantWithMinMaxD() + >>> fake_quant = FakeQuantWithMinMax() >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) >>> result = fake_quant(input_x) """ @@ -137,75 +137,77 @@ class FakeQuantWithMinMaxD(Cell): ema=False, ema_decay=0.999, per_channel=False, - channel_size=1, + channel_axis=1, + out_channels=1, quant_delay=0, symmetric=False, narrow_range=False, training=True): - """init FakeQuantWithMinMax ascend layer""" - super(FakeQuantWithMinMaxD, self).__init__() - + """init FakeQuantWithMinMaxAscend layer""" + super(FakeQuantWithMinMaxAscend, self).__init__() self.min_init = min_init - self.num_bits = num_bits self.max_init = max_init + self.num_bits = num_bits self.ema = ema self.ema_decay = ema_decay self.per_channel = per_channel - self.channel_size = channel_size + self.channel_axis = channel_axis self.quant_delay = quant_delay self.symmetric = symmetric self.narrow_range = narrow_range self.training = training - if not per_channel: - self.fake_quant = P.FakeQuantWithMinMax(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - quant_delay=self.quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=training) - self.ema_update = P.FakeQuantWithMinMaxUpdate(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - quant_delay=self.quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=training) - else: - raise RuntimeError("not support per channel") + # init tensor min and max for fake quant op + if isinstance(min_init, int): + min_array = np.array([min_init]).reshape(1).astype(np.float32) + max_array = np.array([max_init]).reshape(1).astype(np.float32) + elif isinstance(min_init, list): + min_array = np.array([self.min_init for i in range( + 0, self.out_channels)]).astype(np.float32) + max_array = np.array([self.max_init for i in range( + 0, self.out_channels)]).astype(np.float32) + self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) + self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) - if isinstance(min_init, Parameter): - self.minq = min_init - self.maxq = max_init + if per_channel: + quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis) + ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis) else: - self.minq = Parameter(Tensor(np.array([min_init]).astype(np.float32)), - name='quant_min', - requires_grad=False) - self.maxq = Parameter(Tensor(np.array([max_init]).astype(np.float32)), - name='quant_max', - requires_grad=False) - self.reduce_min = P.ReduceMin() - self.reduce_max = P.ReduceMax() + quant_fun = P.FakeQuantPerLayer + ema_fun = P.FakeQuantMinMaxPerLayerUpdate + + self.fake_quant = quant_fun(num_bits=self.num_bits, + ema=self.ema, + ema_decay=self.ema_decay, + quant_delay=self.quant_delay, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + training=self.training) + self.ema_update = ema_fun(num_bits=self.num_bits, + ema=self.ema, + ema_decay=self.ema_decay, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + training=self.training) def extend_repr(self): - s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, 
quant_delay={}'.format( - self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size, - self.quant_delay) + s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format( + self.min_init, self.max_init, self.ema, self.ema_decay, + self.per_channel, self.quant_delay, self.channel_axis) return s - def construct(self, x, minq, maxq): - if self.training: - min_up, max_up = self.ema_update(x, minq, maxq) + def construct(self, x): + if self.update: + min_up, max_up = self.ema_update(x, self.minq, self.maxq) out = self.fake_quant(x, min_up, max_up) P.Assign()(self.minq, min_up) P.Assign()(self.maxq, max_up) else: - out = self.fake_quant(x, minq, maxq) + out = self.fake_quant(x, self.minq, self.maxq) return out -class FakeQuantWithMinMax(Cell): +class FakeQuantWithMinMaxGPU(Cell): r""" Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. @@ -240,98 +242,69 @@ class FakeQuantWithMinMax(Cell): ema=False, ema_decay=0.999, per_channel=False, + channel_axis=1, out_channels=1, quant_delay=0, symmetric=False, - narrow_range=False): - """init FakeQuantWithMinMax layer""" - super(FakeQuantWithMinMax, self).__init__() - + narrow_range=False, + training=True): + super(FakeQuantWithMinMaxGPU, self).__init__() self.min_init = min_init - self.num_bits = num_bits self.max_init = max_init + self.num_bits = num_bits self.ema = ema self.ema_decay = ema_decay self.per_channel = per_channel - self.out_channels = out_channels + self.channel_axis = channel_axis self.quant_delay = quant_delay self.symmetric = symmetric self.narrow_range = narrow_range + self.training = training - if per_channel: - min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) - max_array = np.array([self.max_init for i in range(0, self.channel_size)]).astype(np.float32) - self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) - self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) - self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - quant_delay=self.quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=True) - self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - quant_delay=self.quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=False) - else: + # init tensor min and max for fake quant op + if isinstance(min_init, int): min_array = np.array([min_init]).reshape(1).astype(np.float32) max_array = np.array([max_init]).reshape(1).astype(np.float32) - self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) - self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) - if context.get_context('device_target') == "Ascend": - self.fake_quant_train = FakeQuantWithMinMaxD(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - quant_delay=self.quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=True, - min_init=self.minq, - max_init=self.maxq) - self.fake_quant_infer = FakeQuantWithMinMaxD(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - quant_delay=self.quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=False, - min_init=self.minq, - max_init=self.maxq) - elif 
context.get_context('device_target') == "GPU": - self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - quant_delay=self.quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=True) - self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits, - ema=self.ema, - ema_decay=ema_decay, - quant_delay=quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=False) - else: - raise ValueError("Not support platform.") + elif isinstance(min_init, list): + min_array = np.array([self.min_init for i in range( + 0, self.out_channels)]).astype(np.float32) + max_array = np.array([self.max_init for i in range( + 0, self.out_channels)]).astype(np.float32) + self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) + self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) + + if per_channel: + quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis) + else: + quant_fun = P.FakeQuantPerLayer + self.fake_quant = quant_fun(num_bits=self.num_bits, + ema=self.ema, + ema_decay=ema_decay, + quant_delay=quant_delay, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + training=self.training) def extend_repr(self): - s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format( - self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.quant_delay) + s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format( + self.min_init, self.max_init, self.ema, self.ema_decay, + self.per_channel, self.quant_delay, self.channel_axis) return s def construct(self, x): - if self.training: - out = self.fake_quant_train(x, self.minq, self.maxq) - else: - out = self.fake_quant_infer(x, self.minq, self.maxq) + out = self.fake_quant(x, self.minq, self.maxq) return out +def FakeQuantWithMinMax(**kwargs): + if context.get_context('device_target') == "Ascend": + out = FakeQuantWithMinMaxAscend(**kwargs) + if context.get_context('device_target') == "GPU": + out = FakeQuantWithMinMaxGPU(**kwargs) + else: + raise ValueError("Not support platform or channel mode.") + return out + class Conv2dBatchNormQuant(Cell): r""" 2D convolution with BatchNormal op folded layer. 
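The FakeQuantWithMinMax factory added above selects the Ascend or GPU cell from context.get_context('device_target'). Below is a minimal sketch of the dispatch it appears to intend; the helper name is hypothetical, and it assumes the two device branches are meant to be mutually exclusive (as written in the patch, the second check's else branch would also raise the ValueError on Ascend, since any non-GPU target falls into it).

from mindspore import context

def make_fake_quant_with_min_max(**kwargs):
    # Hypothetical dispatcher mirroring the FakeQuantWithMinMax factory above:
    # return early per device target so only unsupported targets reach the error.
    device = context.get_context('device_target')
    if device == "Ascend":
        return FakeQuantWithMinMaxAscend(**kwargs)
    if device == "GPU":
        return FakeQuantWithMinMaxGPU(**kwargs)
    raise ValueError("FakeQuantWithMinMax only supports Ascend and GPU, got: " + device)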
@@ -420,7 +393,6 @@ class Conv2dBatchNormQuant(Cell): self.per_channel = per_channel self.symmetric = symmetric self.narrow_range = narrow_range - self.channel_axis = int(group > 1) self.is_gpu = context.get_context('device_target') == "GPU" # initialize convolution op and Parameter @@ -435,6 +407,7 @@ class Conv2dBatchNormQuant(Cell): dilation=self.dilation) if weight_init is None: weight_init = initializer('normal', [1, in_channels, *self.kernel_size]) + channel_axis = 1 else: self.conv = P.Conv2D(out_channel=out_channels, kernel_size=self.kernel_size, @@ -445,6 +418,7 @@ class Conv2dBatchNormQuant(Cell): group=group) if weight_init is None: weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size]) + channel_axis = 0 self.weight = Parameter(weight_init, name='weight') # initialize batchnorm Parameter @@ -472,7 +446,7 @@ class Conv2dBatchNormQuant(Cell): symmetric=symmetric, narrow_range=narrow_range) self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) - self.correct_mul = P.CorrectionMul(self.channel_axis) + self.correct_mul = P.CorrectionMul(channel_axis) if context.get_context('device_target') == "Ascend": self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn) self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0) @@ -520,7 +494,7 @@ class Conv2dBatchNormQuant(Cell): out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std) F.control_depend(out, self.assignadd(self.step, self.one)) else: - out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std) + out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std) return out diff --git a/mindspore/ops/_grad/grad_quant_ops.py b/mindspore/ops/_grad/grad_quant_ops.py index 0a5cd5430691cfb4ee5f328b2ef56654af0e4206..7ab9192040977cc172b2306364f2731abe64b23d 100644 --- a/mindspore/ops/_grad/grad_quant_ops.py +++ b/mindspore/ops/_grad/grad_quant_ops.py @@ -20,10 +20,11 @@ from .grad_base import bprop_getters from ..composite.multitype_ops.zeros_like_impl import zeros_like -@bprop_getters.register(P.FakeQuantWithMinMax) +@bprop_getters.register(P.FakeQuantPerLayer) def get_bprop_fakequant_with_minmax(self): - """Generate bprop for FakeQuantWithMinMax for GPU and Ascend""" - op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) + """Generate bprop for FakeQuantPerLayer for GPU and Ascend""" + op = P.FakeQuantPerLayerGrad( + num_bits=self.num_bits, quant_delay=self.quant_delay) def bprop(x, x_min, x_max, out, dout): dx = op(dout, x, x_min, x_max) @@ -32,10 +33,14 @@ def get_bprop_fakequant_with_minmax(self): return bprop -@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel) +@bprop_getters.register(P.FakeQuantPerChannel) def get_bprop_fakequant_with_minmax_perchannel(self): - """Generate bprop for FakeQuantWithMinMaxPerChannel for GPU""" - op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) + """Generate bprop for FakeQuantPerChannel""" + op = P.FakeQuantPerChannelGrad(num_bits=self.num_bits, + quant_delay=self.quant_delay, + symmetric=self.symmetric, + narrow_range=self.symmetric, + channel_axis=self.channel_axis) def bprop(x, x_min, x_max, out, dout): dx = op(dout, x, x_min, x_max) @@ -77,7 +82,7 @@ def get_bprop_batchnorm_fold2(self): d_batch_std, d_batch_mean, d_beta, d_gamma, d_x = op_f(dout, x, gamma, batch_std, batch_mean, running_std, running_mean, 
global_step) return d_x, d_beta, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std), zeros_like(running_mean), \ - zeros_like(global_step) + zeros_like(global_step) return bprop @@ -117,9 +122,19 @@ def get_bprop_batchnorm_fold2_(self): return bprop -@bprop_getters.register(P.FakeQuantWithMinMaxUpdate) -def get_bprop_fakequant_with_minmax_update(self): - """Generate bprop for FakeQuantWithMinMaxUpdate for Ascend""" +@bprop_getters.register(P.FakeQuantMinMaxPerLayerUpdate) +def get_bprop_fakequant_with_minmax_per_layer_update(self): + """Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend""" + + def bprop(x, x_min, x_max, out, dout): + return zeros_like(x), zeros_like(x_min), zeros_like(x_max) + + return bprop + + +@bprop_getters.register(P.FakeQuantMinMaxPerChannelUpdate) +def get_bprop_fakequant_with_minmax_per_channel_update(self): + """Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend""" def bprop(x, x_min, x_max, out, dout): return zeros_like(x), zeros_like(x_min), zeros_like(x_max) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py new file mode 100644 index 0000000000000000000000000000000000000000..7694753d8f574c572a3a8f3be3782767c7449fc1 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py @@ -0,0 +1,135 @@ + +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""FakeQuantMinMaxPerChannelUpdate op""" +import te.lang.cce +from te import tvm +from te.platform.fusion_manager import fusion_manager +from topi import generic +from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + + +fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fake_quant_min_max_per_channel_update.so") \ + .compute_cost(10) \ + .kernel_name("fake_quant_min_max_per_channel_update") \ + .partial_flag(True) \ + .attr("ema", "optional", "bool", "all") \ + .attr("ema_decay", "optional", "float", "all") \ + .attr("symmetric", "optional", "bool", "all") \ + .attr("narrow_range", "optional", "bool", "all") \ + .attr("training", "optional", "bool", "all") \ + .attr("num_bits", "optional", "int", "all") \ + .attr("channel_axis", "optional", "int", "all") \ + .input(0, "x", None, "required", None) \ + .input(1, "min", None, "required", None) \ + .input(2, "max", None, "required", None) \ + .output(0, "min_up", True, "required", "all") \ + .output(1, "max_up", True, "required", "all") \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(fake_quant_min_max_per_channel_update_op_info) +def _fake_quant_min_max_per_channel_update_tbe(): + """FakeQuantPerChannelUpdate TBE register""" + return + + +@fusion_manager.register("fake_quant_min_max_per_channel_update") +def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val, + ema, ema_decay, quant_min, quant_max, training, channel_axis, + kernel_name="fake_quant_min_max_per_channel_update"): + """FakeQuantPerChannelUpdate compute""" + shape_min = te.lang.cce.util.shape_to_list(min_val.shape) + + if not ema: + ema_decay = 0.0 + if training: + # CalMinMax + axis = [0, 2, 3] + x_min = te.lang.cce.reduce_min(x, axis=axis) + x_max = te.lang.cce.reduce_max(x, axis=axis) + x_min = te.lang.cce.broadcast(x_min, shape_min) + x_max = te.lang.cce.broadcast(x_max, shape_min) + min_val = te.lang.cce.vadd(te.lang.cce.vmuls( + min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) + max_val = te.lang.cce.vadd(te.lang.cce.vmuls( + max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) + min_val = te.lang.cce.vmins(min_val, 0) + max_val = te.lang.cce.vmaxs(max_val, 0) + + return [min_val, max_val] + + +@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) +def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, + ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis, + kernel_name="fake_quant_min_max_per_channel_update"): + """FakeQuantPerLayer op""" + x_shape = x.get("ori_shape") + x_format = x.get("format") + x_dtype = x.get("dtype") + min_shape = min_val.get("ori_shape") + min_dtype = min_val.get("dtype") + max_shape = max_val.get("ori_shape") + max_dtype = max_val.get("dtype") + + util.check_kernel_name(kernel_name) + util.check_shape_rule(x_shape) + util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis]) + util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis]) + util.check_tensor_shape_size(x_shape) + util.check_tensor_shape_size(min_shape) + util.check_tensor_shape_size(max_shape) + + check_list = ["float32", "float16"] + x_dtype = x_dtype.lower() + min_dtype = min_dtype.lower() + max_dtype = 
max_dtype.lower() + util.check_dtype_rule(x_dtype, check_list) + util.check_dtype_rule(min_dtype, check_list) + util.check_dtype_rule(max_dtype, check_list) + + if symmetric: + quant_min = 0 - 2 ** (num_bits - 1) + quant_max = 2 ** (num_bits - 1) - 1 + else: + quant_min = 0 + quant_max = 2 ** num_bits - 1 + if narrow_range: + quant_min = quant_min + 1 + + shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] + input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype) + min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) + max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) + res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data, + ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name) + + with tvm.target.cce(): + sch = generic.auto_schedule(res_list) + + tensor_list = [input_data, min_data, max_data] + list(res_list) + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(sch, config) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_update.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py similarity index 75% rename from mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_update.py rename to mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py index 58eeeda9fbb20a0bb1ff3f5a82a1fd0ff0cc2bc2..0ad2315bb3f65e61ec21b8c8340ce5fd798aaf35 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_update.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""FakeQuantWithMinMaxUpdate op""" +"""FakeQuantMinMaxPerLayerUpdate op""" from functools import reduce as functools_reduce import te.lang.cce from te import tvm @@ -23,12 +23,12 @@ from topi.cce import util from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType -fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ +fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("fake_quant_with_min_max_update.so") \ + .binfile_name("fake_quant_minmax_update.so") \ .compute_cost(10) \ - .kernel_name("fake_quant_with_min_max_update") \ + .kernel_name("fake_quant_minmax_update") \ .partial_flag(True) \ .attr("ema", "optional", "bool", "all") \ .attr("ema_decay", "optional", "float", "all") \ @@ -36,7 +36,6 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ .attr("narrow_range", "optional", "bool", "all") \ .attr("training", "optional", "bool", "all") \ .attr("num_bits", "optional", "int", "all") \ - .attr("quant_delay", "optional", "int", "all") \ .input(0, "x", None, "required", None) \ .input(1, "min", None, "required", None) \ .input(2, "max", None, "required", None) \ @@ -47,16 +46,16 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ .get_op_info() -@op_info_register(fake_quant_update_op_info) -def _fake_quant_update_tbe(): - """_FakeQuantWithMinMaxUpdate TBE register""" +@op_info_register(fake_quant_minmax_update_op_info) +def _fake_quant_minmax_update_tbe(): + """FakeQuantMinMaxPerLayerUpdate TBE register""" return -@fusion_manager.register("fake_quant_with_min_max_update") -def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, - 
kernel_name="fake_quant_update"): - """FakeQuantWithMinMaxUpdate compute""" +@fusion_manager.register("fake_quant_minmax_update") +def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, + kernel_name="fake_quant_minmax_update"): + """FakeQuantMinMaxPerLayerUpdate compute""" shape = te.lang.cce.util.shape_to_list(x.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape) min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) @@ -70,19 +69,21 @@ def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, x_max = te.lang.cce.reduce_max(x, axis=axis) x_min = te.lang.cce.broadcast(x_min, shape_min) x_max = te.lang.cce.broadcast(x_max, shape_min) - min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) - max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) + min_val = te.lang.cce.vadd(te.lang.cce.vmuls( + min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) + max_val = te.lang.cce.vadd(te.lang.cce.vmuls( + max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) min_val = te.lang.cce.vmins(min_val, 0) max_val = te.lang.cce.vmaxs(max_val, 0) return [min_val, max_val] -@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) -def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up, - ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay, - kernel_name="fake_quant_update"): - """FakeQuantWithMinMax op""" +@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str) +def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, + ema, ema_decay, symmetric, narrow_range, training, num_bits, + kernel_name="fake_quant_minmax_update"): + """FakeQuantPerLayer op""" input_shape = x.get("shape") input_dtype = x.get("dtype") min_shape = min_val.get("ori_shape") @@ -123,8 +124,8 @@ def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up, input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) - res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data, - ema, ema_decay, quant_min, quant_max, training, kernel_name) + res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data, + ema, ema_decay, quant_min, quant_max, training, kernel_name) with tvm.target.cce(): sch = generic.auto_schedule(res_list) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py new file mode 100644 index 0000000000000000000000000000000000000000..827d7a433c5cf41adaf288afa59cf1dc8899ae54 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py @@ -0,0 +1,145 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FakeQuantPerChannel op""" +import te.lang.cce +from te import tvm +from te.platform.fusion_manager import fusion_manager +from topi import generic +from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +fake_quant_perchannel_op_info = TBERegOp("FakeQuantPerChannel") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("fake_quant_perchannel.so") \ + .compute_cost(10) \ + .kernel_name("fake_quant_perchannel") \ + .partial_flag(True) \ + .attr("symmetric", "optional", "bool", "all") \ + .attr("narrow_range", "optional", "bool", "all") \ + .attr("num_bits", "optional", "int", "all") \ + .attr("channel_axis", "optional", "int", "all") \ + .input(0, "x", None, "required", None) \ + .input(1, "min", None, "required", None) \ + .input(2, "max", None, "required", None) \ + .output(0, "y", True, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(fake_quant_perchannel_op_info) +def _fake_quant_perchannel_tbe(): + """FakeQuantPerChannel TBE register""" + return + + +@fusion_manager.register("fake_quant_perchannel") +def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max, + kernel_name="fake_quant_perchannel"): + """FakeQuantPerChannel""" + x_shape = te.lang.cce.util.shape_to_list(x.shape) + minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape) + quant_min = tvm.const(quant_min, x.dtype) + quant_max = tvm.const(quant_max, x.dtype) + quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype) + quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype) + + # CalNudge(NudgeMinMax) + scale = te.lang.cce.vdiv(te.lang.cce.vsub( + max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) + zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) + + # Nudge zero point + nudge_zp_ = te.lang.cce.vmin( + quant_max, te.lang.cce.vmax(quant_min, zp_from_min)) + nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5)) + nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) + nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) + + # FakeQuant + nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape) + nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape) + scale_b = te.lang.cce.broadcast(scale, x_shape) + + input_x = te.lang.cce.vmin(nudge_max_b, te.lang.cce.vmax(nudge_min_b, x)) + nudge_input_ = te.lang.cce.vdiv( + te.lang.cce.vsub(input_x, nudge_min_b), scale_b) + nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5)) + res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale_b), nudge_min_b) + + return res + + +@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str) +def fake_quant_perchannel(x, min_val, max_val, y, + symmetric, narrow_range, num_bits, channel_axis, + kernel_name="fake_quant_perchannel"): + """FakeQuantPerChannel""" + x_shape = x.get("shape") + x_format = x.get("format") + x_dtype = 
x.get("dtype") + min_shape = min_val.get("ori_shape") + min_dtype = min_val.get("dtype") + max_shape = max_val.get("ori_shape") + max_dtype = max_val.get("dtype") + + util.check_kernel_name(kernel_name) + util.check_shape_rule(x_shape) + util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis]) + util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis]) + util.check_tensor_shape_size(x_shape) + util.check_tensor_shape_size(min_shape) + util.check_tensor_shape_size(max_shape) + + check_list = ["float32", "float16"] + x_dtype = x_dtype.lower() + min_dtype = min_dtype.lower() + max_dtype = max_dtype.lower() + util.check_dtype_rule(x_dtype, check_list) + util.check_dtype_rule(min_dtype, check_list) + util.check_dtype_rule(max_dtype, check_list) + + if symmetric: + quant_min = 0 - 2 ** (num_bits - 1) + quant_max = 2 ** (num_bits - 1) - 1 + else: + quant_min = 0 + quant_max = 2 ** num_bits - 1 + if narrow_range: + quant_min = quant_min + 1 + + shape_c = [1] * len(x_shape) + shape_c[channel_axis] = min_val.get("ori_shape")[0] + if x_format == "NC1HWC0" and channel_axis == 1: + shape_c = min_val.get("shape") + input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype) + min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) + max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) + res = fake_quant_perchannel_compute(input_data, min_data, max_data, y, + quant_min, quant_max, kernel_name) + + with tvm.target.cce(): + sch = generic.auto_schedule(res) + + tensor_list = [input_data, min_data, max_data, res] + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(sch, config) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..91fb694154f79acc7c8dceab4a7c6753a0484ad9 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py @@ -0,0 +1,171 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""FakeQuantPerChannelGrad op""" +import te.lang.cce +from te import tvm +from te.platform.fusion_manager import fusion_manager +from topi import generic +from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +SHAPE_SIZE_LIMIT = 2147483648 +D_TYPE = 'float32' + +fake_quant_perchannel_grad_op_info = TBERegOp("FakeQuantPerChannelGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fake_quant_perchannel_grad.so") \ + .compute_cost(10) \ + .kernel_name("fake_quant_perchannel_grad") \ + .partial_flag(True) \ + .attr("symmetric", "optional", "bool", "all") \ + .attr("narrow_range", "optional", "bool", "all") \ + .attr("num_bits", "optional", "int", "all") \ + .attr("channel_axis", "optional", "int", "all") \ + .input(0, "dout", None, "required", None) \ + .input(1, "x", None, "required", None) \ + .input(2, "min", None, "required", None) \ + .input(3, "max", None, "required", None) \ + .output(0, "dx", True, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +def _less_compare_float32(data_x, data_y): + """_less_compare_float32 compute""" + input_shape = te.lang.cce.util.shape_to_list(data_x.shape) + min_value = tvm.const(2 ** (-126), dtype=D_TYPE) + max_value = tvm.const(2 ** 62, dtype=D_TYPE) + factor_value = tvm.const(2 ** 2, dtype=D_TYPE) + data_zero = te.lang.cce.broadcast( + tvm.const(0, dtype=D_TYPE), input_shape, D_TYPE) + min_value_tensor = te.lang.cce.vadds(data_zero, min_value) + + res_sub = te.lang.cce.vsub(data_y, data_x) + res_min = te.lang.cce.vmin(res_sub, min_value_tensor) + res_max = te.lang.cce.vmax(res_min, data_zero) + + res_max_mul = te.lang.cce.vmuls(res_max, max_value) + res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value) + res = te.lang.cce.vmuls(res_max_mul_max, factor_value) + + return res + + +@op_info_register(fake_quant_perchannel_grad_op_info) +def _fake_quant_perchannel_grad_tbe(): + """FakeQuantPerChannelGrad TBE register""" + return + + +@fusion_manager.register("fake_quant_perchannel_grad") +def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, + kernel_name="fake_quant_perchannel_grad"): + """FakeQuantPerChannelGrad""" + x_shape = te.lang.cce.util.shape_to_list(x.shape) + minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape) + quant_min = tvm.const(quant_min, x.dtype) + quant_max = tvm.const(quant_max, x.dtype) + quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype) + quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype) + + # CalNudge(NudgeMinMax) + scale = te.lang.cce.vdiv(te.lang.cce.vsub( + max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) + zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) + + # Nudge zero point + nudge_zp_ = te.lang.cce.vmin( + quant_max, te.lang.cce.vmax(quant_min, zp_from_min)) + nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5)) + nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, 
nudge_zp), scale) + nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) + + # FakeQuant Grad + nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape) + nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape) + + bool_over_min = _less_compare_float32(nudge_min_b, x) + bool_less_max = _less_compare_float32(x, nudge_max_b) + bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max) + res = te.lang.cce.vmul(dout, bool_between) + + return res + + +@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str) +def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx, + symmetric, narrow_range, num_bits, channel_axis, + kernel_name="fake_quant_perchannel_grad"): + """FakeQuantPerChannelGrad""" + x_shape = x.get("shape") + x_format = x.get("format") + x_dtype = x.get("dtype") + min_shape = min_val.get("ori_shape") + min_dtype = min_val.get("dtype") + max_shape = max_val.get("ori_shape") + max_dtype = max_val.get("dtype") + + util.check_kernel_name(kernel_name) + util.check_shape_rule(x_shape) + util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis]) + util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis]) + util.check_tensor_shape_size(x_shape) + util.check_tensor_shape_size(min_shape) + util.check_tensor_shape_size(max_shape) + + check_list = ["float32", "float16"] + x_dtype = x_dtype.lower() + min_dtype = min_dtype.lower() + max_dtype = max_dtype.lower() + util.check_dtype_rule(x_dtype, check_list) + util.check_dtype_rule(min_dtype, check_list) + util.check_dtype_rule(max_dtype, check_list) + + if symmetric: + quant_min = 0 - 2 ** (num_bits - 1) + quant_max = 2 ** (num_bits - 1) - 1 + else: + quant_min = 0 + quant_max = 2 ** num_bits - 1 + if narrow_range: + quant_min = quant_min + 1 + + shape_c = [1] * len(x_shape) + shape_c[channel_axis] = min_val.get("ori_shape")[0] + if x_format == "NC1HWC0" and channel_axis == 1: + shape_c = min_val.get("shape") + dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype) + input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype) + min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) + max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) + res = fake_quant_perchannel_grad_compute(dout_data, input_data, min_data, max_data, + quant_min, quant_max, kernel_name) + + with tvm.target.cce(): + sch = generic.auto_schedule(res) + + tensor_list = [dout_data, input_data, min_data, max_data, res] + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(sch, config) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py similarity index 75% rename from mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py rename to mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py index f35dfae39bbe9222be08795a751b60ba868a01c7..81322acccf1a27dd7c8b831a7f7b6a0a99716855 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py @@ -13,8 +13,7 @@ # limitations under the License. 
# ============================================================================ -"""FakeQuantWithMinMax op""" - +"""FakeQuantPerLayer op""" from functools import reduce as functools_reduce import te.lang.cce from te import tvm @@ -23,20 +22,16 @@ from topi import generic from topi.cce import util from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType -fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \ +fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \ .fusion_type("ELEMWISE") \ .async_flag(False) \ - .binfile_name("fake_quant_with_min_max_vars_ema.so") \ + .binfile_name("fake_quant_per_layer.so") \ .compute_cost(10) \ - .kernel_name("fake_quant_with_min_max_vars_ema") \ + .kernel_name("fake_quant_per_layer") \ .partial_flag(True) \ - .attr("ema", "optional", "bool", "all") \ - .attr("ema_decay", "optional", "float", "all") \ .attr("symmetric", "optional", "bool", "all") \ .attr("narrow_range", "optional", "bool", "all") \ - .attr("training", "optional", "bool", "all") \ .attr("num_bits", "optional", "int", "all") \ - .attr("quant_delay", "optional", "int", "all") \ .input(0, "x", None, "required", None) \ .input(1, "min", None, "required", None) \ .input(2, "max", None, "required", None) \ @@ -49,15 +44,15 @@ fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \ @op_info_register(fake_quant_op_info) -def _fake_quant_tbe(): - """FakeQuantWithMinMax TBE register""" +def _fake_quant_per_layer_tbe(): + """FakeQuantPerLayer TBE register""" return -@fusion_manager.register("fake_quant_with_min_max_vars_ema") -def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max, - kernel_name="correction_mul"): - """FakeQuantWithMinMax""" +@fusion_manager.register("fake_quant_per_layer") +def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max, + kernel_name="fake_quant_per_layer"): + """FakeQuantPerLayer""" shape = te.lang.cce.util.shape_to_list(x.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape) quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype) @@ -66,10 +61,13 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) # CalNudge(NudgeMinMax) - scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) + scale = te.lang.cce.vdiv(te.lang.cce.vsub( + max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) # Nudge zero point - nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min))) + nudge_zp_ = te.lang.cce.vmin( + quant_max, te.lang.cce.vmax(quant_min, zp_from_min)) + nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5)) nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) @@ -80,17 +78,19 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, # FakeQuant input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x)) - nudge_input = te.lang.cce.round(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale)) + nudge_input_ = te.lang.cce.vdiv( + te.lang.cce.vsub(input_x, nudge_min), scale) + nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5)) res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min) return res -@util.check_input_type(dict, dict, dict, dict, bool, 
float, bool, bool, bool, int, int, str) -def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y, - ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay, - kernel_name="fake_quant"): - """FakeQuantWithMinMax""" +@util.check_input_type(dict, dict, dict, dict, bool, bool, int, str) +def fake_quant_per_layer(x, min_val, max_val, y, + symmetric, narrow_range, num_bits, + kernel_name="fake_quant_per_layer"): + """FakeQuantPerLayer""" input_shape = x.get("shape") input_dtype = x.get("dtype") min_shape = min_val.get("ori_shape") @@ -131,8 +131,8 @@ def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y, input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) - res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y, - quant_min, quant_max, kernel_name) + res = fake_quant_per_layer_compute(input_data, min_data, max_data, y, + quant_min, quant_max, kernel_name) with tvm.target.cce(): sch = generic.auto_schedule(res) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py similarity index 79% rename from mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_grad.py rename to mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py index 5137f7c42b8d94c838b77846a9ca854675ba7cff..9a5b8bc7d037878cf55dce8b71eba9d16b22ae2e 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_grad.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""FakeQuantWithMinMaxGrad op""" +"""FakeQuantPerLayerGrad op""" from functools import reduce as functools_reduce import te.lang.cce @@ -26,15 +26,14 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType SHAPE_SIZE_LIMIT = 2147483648 D_TYPE = 'float32' -fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \ +fake_quant_per_layer_grad_op_info = TBERegOp("FakeQuantPerLayerGrad") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("fake_quant_with_min_max_grad.so") \ + .binfile_name("fake_quant_per_layer_grad.so") \ .compute_cost(10) \ - .kernel_name("fake_quant_with_min_max_grad") \ + .kernel_name("fake_quant_per_layer_grad") \ .partial_flag(True) \ .attr("num_bits", "optional", "int", "all") \ - .attr("quant_delay", "optional", "int", "all") \ .attr("symmetric", "optional", "bool", "all") \ .attr("narrow_range", "optional", "bool", "all") \ .input(0, "dout", None, "required", None) \ @@ -57,7 +56,8 @@ def _less_compare_float32(data_x, data_y): min_value = tvm.const(2 ** (-126), dtype=D_TYPE) max_value = tvm.const(2 ** 62, dtype=D_TYPE) factor_value = tvm.const(2 ** 2, dtype=D_TYPE) - data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE) + data_zero = te.lang.cce.broadcast( + tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE) min_value_tensor = te.lang.cce.vadds(data_zero, min_value) res_sub = te.lang.cce.vsub(data_y, data_x) @@ -71,16 +71,16 @@ def _less_compare_float32(data_x, data_y): return res -@op_info_register(fake_quant_grad_op_info) -def _fake_quant_grad_tbe(): - """FakeQuantWithMinMaxGrad TBE register""" +@op_info_register(fake_quant_per_layer_grad_op_info) +def _fake_quant_per_layer_grad_tbe(): + """FakeQuantPerLayerGrad TBE 
register""" return -@fusion_manager.register("fake_quant_with_min_max_grad") -def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, - kernel_name="fake_quant_with_min_max_grad"): - """FakeQuantWithMinMaxGrad""" +@fusion_manager.register("fake_quant_per_layer_grad") +def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, + kernel_name="fake_quant_per_layer_grad"): + """FakeQuantPerLayerGrad""" shape = te.lang.cce.util.shape_to_list(x.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape) quant_min = tvm.const(quant_min, x.dtype) @@ -89,10 +89,13 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q quant_max = te.lang.cce.broadcast(quant_max, shape_min) # CalNudge(NudgeMinMax) - scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) + scale = te.lang.cce.vdiv(te.lang.cce.vsub( + max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) # Nudge zero point - nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min))) + nudge_zp_ = te.lang.cce.vmin( + quant_max, te.lang.cce.vmax(quant_min, zp_from_min)) + nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5)) nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) nudge_min = te.lang.cce.broadcast(nudge_min, shape) @@ -106,11 +109,11 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q return res -@util.check_input_type(dict, dict, dict, dict, dict, int, int, bool, bool, str) -def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, - num_bits, quant_delay, symmetric, narrow_range, - kernel_name="fake_quant_with_min_max_grad"): - """FakeQuantWithMinMaxGrad""" +@util.check_input_type(dict, dict, dict, dict, dict, int, bool, bool, str) +def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx, + num_bits, symmetric, narrow_range, + kernel_name="fake_quant_per_layer_grad"): + """FakeQuantPerLayerGrad""" input_shape = x.get("shape") input_dtype = x.get("dtype") min_shape = min_val.get("ori_shape") @@ -152,8 +155,8 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) - res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data, quant_min, - quant_max, kernel_name) + res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min, + quant_max, kernel_name) with tvm.target.cce(): sch = generic.auto_schedule(res) diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 347d72bf5fb8d97ded3a67eca3e7c29e7a6babb0..e1aa5630bafced6828ae6719988fcb3d16386fcf 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -20,10 +20,12 @@ from ..._checkparam import Rel from ..primitive import PrimitiveWithInfer, prim_attr_register from ...common import dtype as mstype -__all__ = ["FakeQuantWithMinMax", - "FakeQuantWithMinMaxGrad", - "FakeQuantWithMinMaxPerChannel", - "FakeQuantWithMinMaxPerChannelGrad", +__all__ = ["FakeQuantPerLayer", + "FakeQuantPerLayerGrad", + "FakeQuantPerChannel", + 
"FakeQuantPerChannelGrad", + "FakeQuantMinMaxPerLayerUpdate", + "FakeQuantMinMaxPerChannelUpdate", "BatchNormFold", "BatchNormFoldGrad", "CorrectionMul", @@ -36,11 +38,10 @@ __all__ = ["FakeQuantWithMinMax", "BatchNormFold2_D", "BatchNormFold2GradD", "BatchNormFold2GradReduce", - "FakeQuantWithMinMaxUpdate", ] -class FakeQuantWithMinMax(PrimitiveWithInfer): +class FakeQuantPerLayer(PrimitiveWithInfer): r""" Simulate the quantize and dequantize operations in training time. @@ -67,49 +68,67 @@ class FakeQuantWithMinMax(PrimitiveWithInfer): >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) >>> min_tensor = Tensor(np.array([-6]), mstype.float32) >>> max_tensor = Tensor(np.array([6]), mstype.float32) - >>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) + >>> output_tensor = P.FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor) """ support_quant_bit = [4, 7, 8] @prim_attr_register - def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, + def __init__(self, + num_bits=8, + ema=False, + ema_decay=0.999, + quant_delay=0, + symmetric=False, + narrow_range=False, training=True): - """init FakeQuantWithMinMax OP""" + """init FakeQuantPerLayer OP""" if num_bits not in self.support_quant_bit: - raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") + raise ValueError( + f"For '{self.name}' attr \'num_bits\' is not support.") if ema and not ema_decay: - raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + raise ValueError( + f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") self.ema = validator.check_value_type('ema', ema, (bool,), self.name) - self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) - self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) - self.training = validator.check_value_type('training', training, (bool,), self.name) - self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) - self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) - self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) + self.symmetric = validator.check_value_type( + 'symmetric', symmetric, (bool,), self.name) + self.narrow_range = validator.check_value_type( + 'narrow_range', narrow_range, (bool,), self.name) + self.training = validator.check_value_type( + 'training', training, (bool,), self.name) + self.ema_decay = validator.check_number_range( + 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.num_bits = validator.check_integer( + 'num_bits', num_bits, 0, Rel.GT, self.name) + self.quant_delay = validator.check_value_type( + 'quant_delay', quant_delay, (int,), self.name) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) def infer_shape(self, x_shape, min_shape, max_shape): validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) - validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) - validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) + validator.check("min shape", min_shape, "max shape", + max_shape, Rel.EQ, self.name) + validator.check_integer("min rank", len( + min_shape), 1, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, min_type, max_type): valid_types = (mstype.float16, mstype.float32) 
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"max": max_type}, valid_types, self.name) return x_type -class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): +class FakeQuantPerLayerGrad(PrimitiveWithInfer): r""" - Performs grad of FakeQuantWithMinMax operation. + Performs grad of FakeQuantPerLayerGrad operation. Examples: - >>> fake_min_max_grad = P.FakeQuantWithMinMaxGrad() + >>> fake_min_max_grad = P.FakeQuantPerLayerGrad() >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32) >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32) >>> _min = Tensor(np.array([-4]), mindspore.float32) @@ -119,32 +138,48 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): support_quant_bit = [4, 7, 8] @prim_attr_register - def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False): + def __init__(self, + num_bits=8, + quant_delay=0, + symmetric=False, + narrow_range=False): if num_bits not in self.support_quant_bit: - raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") - - self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) - self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) - self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) - self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) - self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) + raise ValueError( + f"For '{self.name}' attr \'num_bits\' is not support.") + + self.num_bits = validator.check_integer( + 'num_bits', num_bits, 0, Rel.GT, self.name) + self.quant_delay = validator.check_value_type( + 'quant_delay', quant_delay, (int,), self.name) + self.symmetric = validator.check_value_type( + 'symmetric', symmetric, (bool,), self.name) + self.narrow_range = validator.check_value_type( + 'narrow_range', narrow_range, (bool,), self.name) + self.init_prim_io_names( + inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): - validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) - validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) - validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) + validator.check("dout shape", dout_shape, "x shape", + x_shape, Rel.EQ, self.name) + validator.check("min shape", min_shape, "max shape", + max_shape, Rel.EQ, self.name) + validator.check_integer("min rank", len( + min_shape), 1, Rel.EQ, self.name) return dout_shape def infer_dtype(self, dout_type, x_type, min_type, max_type): valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"dout": dout_type}, valid_types, self.name) validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"max": 


-class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
+class FakeQuantPerLayerGrad(PrimitiveWithInfer):
     r"""
-    Performs grad of FakeQuantWithMinMax operation.
+    Performs grad of FakeQuantPerLayer operation.

     Examples:
-        >>> fake_min_max_grad = P.FakeQuantWithMinMaxGrad()
+        >>> fake_min_max_grad = P.FakeQuantPerLayerGrad()
         >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
         >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
         >>> _min = Tensor(np.array([-4]), mindspore.float32)
@@ -119,32 +138,48 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
     support_quant_bit = [4, 7, 8]

     @prim_attr_register
-    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
+    def __init__(self,
+                 num_bits=8,
+                 quant_delay=0,
+                 symmetric=False,
+                 narrow_range=False):
         if num_bits not in self.support_quant_bit:
-            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
-
-        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
-        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
-        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
-        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
+            raise ValueError(
+                f"For '{self.name}' attr \'num_bits\' is not support.")
+
+        self.num_bits = validator.check_integer(
+            'num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type(
+            'quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type(
+            'symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type(
+            'narrow_range', narrow_range, (bool,), self.name)
+        self.init_prim_io_names(
+            inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
-        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
-        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
-        validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
+        validator.check("dout shape", dout_shape, "x shape",
+                        x_shape, Rel.EQ, self.name)
+        validator.check("min shape", min_shape, "max shape",
+                        max_shape, Rel.EQ, self.name)
+        validator.check_integer("min rank", len(
+            min_shape), 1, Rel.EQ, self.name)
         return dout_shape

     def infer_dtype(self, dout_type, x_type, min_type, max_type):
         valid_types = (mstype.float16, mstype.float32)
-        validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"dout": dout_type}, valid_types, self.name)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"max": max_type}, valid_types, self.name)
         return dout_type


-class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
+class FakeQuantPerChannel(PrimitiveWithInfer):
     r"""
     Simulate the quantize and dequantize operations in training time base on per channel.
@@ -168,53 +203,73 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
         - Tensor, has the same type as input.

     Examples:
-        >>> fake_quant = P.FakeQuantWithMinMaxPerChannel()
+        >>> fake_quant = P.FakeQuantPerChannel()
         >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
         >>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32)
         >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
         >>> result = fake_quant(input_x, _min, _max)
     """
     support_quant_bit = [4, 7, 8]
-    channel_axis = 0

     @prim_attr_register
-    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
-                 training=True):
-        """init FakeQuantWithMinMaxPerChannel OP"""
+    def __init__(self,
+                 num_bits=8,
+                 ema=False,
+                 ema_decay=0.999,
+                 quant_delay=0,
+                 symmetric=False,
+                 narrow_range=False,
+                 training=True,
+                 channel_axis=1):
+        """init FakeQuantPerChannel OP"""
         if num_bits not in self.support_quant_bit:
-            raise ValueError(f"For '{self.name}' Attr \'num_bits\' is not support.")
+            raise ValueError(
                f"For '{self.name}' Attr \'num_bits\' is not support.")
         if ema and not ema_decay:
-            raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+            raise ValueError(
+                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")

         self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
-        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
-        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
-        self.training = validator.check_value_type('training', training, (bool,), self.name)
-        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
-        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type(
+            'symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type(
+            'narrow_range', narrow_range, (bool,), self.name)
+        self.training = validator.check_value_type(
+            'training', training, (bool,), self.name)
+        self.ema_decay = validator.check_number_range(
+            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.num_bits = validator.check_integer(
+            'num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type(
+            'quant_delay', quant_delay, (int,), self.name)
+        self.channel_axis = validator.check_integer(
+            'channel_axis', channel_axis, 0, Rel.GE, self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])

     def infer_shape(self, x_shape, min_shape, max_shape):
         validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
-        validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
-        validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
+        validator.check_integer(
+            "min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
+        validator.check_integer(
+            "max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_type, min_type, max_type):
         valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"max": max_type}, valid_types, self.name)
         return x_type
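# --- Illustrative reference (assumed, for illustration only; not the kernel's
# implementation): FakeQuantPerChannel applies the same simulated quantization,
# but with one (min, max) pair per slice along `channel_axis`. This helper
# reuses fake_quant_per_layer_ref from the sketch above.
import numpy as np

def fake_quant_per_channel_ref(x, mins, maxs, channel_axis=1, num_bits=8):
    """Fake-quantize each slice of x along channel_axis with its own (min, max)."""
    out = np.empty_like(x)
    for c in range(x.shape[channel_axis]):
        idx = [slice(None)] * x.ndim
        idx[channel_axis] = c
        out[tuple(idx)] = fake_quant_per_layer_ref(
            x[tuple(idx)], mins[c], maxs[c], num_bits)
    return out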


-class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
+class FakeQuantPerChannelGrad(PrimitiveWithInfer):
     r"""
-    Performs grad of FakeQuantWithMinMaxPerChannel operation.
+    Performs grad of FakeQuantPerChannel operation.

     Examples:
-        >>> fqmmpc_grad = P.FakeQuantWithMinMaxPerChannelGrad()
+        >>> fqmmpc_grad = P.FakeQuantPerChannelGrad()
         >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
         >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
         >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
@@ -224,16 +279,29 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
     support_quant_bit = [4, 7, 8]

     @prim_attr_register
-    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
-        """init FakeQuantWithMinMaxPerChannel Fill"""
+    def __init__(self,
+                 num_bits=8,
+                 quant_delay=0,
+                 symmetric=False,
+                 narrow_range=False,
+                 channel_axis=1):
+        """init FakeQuantPerChannelGrad Fill"""
         if num_bits not in self.support_quant_bit:
-            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
-
-        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
-        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
-        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
-        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
+            raise ValueError(
+                f"For '{self.name}' attr \'num_bits\' is not support.")
+
+        self.num_bits = validator.check_integer(
+            'num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type(
+            'quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type(
+            'symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type(
+            'narrow_range', narrow_range, (bool,), self.name)
+        self.channel_axis = validator.check_integer(
+            'channel axis', channel_axis, 0, Rel.GE, self.name)
+        self.init_prim_io_names(
+            inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
         validator.check("dout shape", dout_shape, "x shape", x_shape)
@@ -242,10 +310,13 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
         return dout_type

     def infer_dtype(self, dout_type, x_type, min_type, max_type):
         valid_types = (mstype.float16, mstype.float32)
-        validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"dout": dout_type}, valid_types, self.name)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"max": max_type}, valid_types, self.name)
         return dout_type
@@ -744,17 +815,14 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
         return dout_type, dout_type


-class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
+class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
     r"""
-    Simulate the quantize and dequantize operations in training time.
+    Update min and max value for fake quant per layer op.

     Args:
         num_bits (int) : Number bits for aware quantilization. Default: 8.
         ema (bool): Use EMA algorithm update value min and max. Default: False.
         ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
-        quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
-            simulate aware quantize funcion. After delay step in training time begin simulate the aware
-            quantize funcion. Default: 0.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
         training (bool): Training the network or not. Default: True.
@@ -776,36 +844,121 @@ class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
     support_quant_bit = [4, 7, 8]

     @prim_attr_register
-    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
+    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
                  training=True):
-        """init FakeQuantWithMinMax OP"""
+        """init FakeQuantMinMaxPerLayerUpdate OP"""
         from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad
         from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad
         from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update
         if num_bits not in self.support_quant_bit:
-            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
+            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not support.")
         if ema and not ema_decay:
-            raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+            raise ValueError(
+                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")

         self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
-        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
-        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
-        self.training = validator.check_value_type('training', training, (bool,), self.name)
-        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
-        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type(
+            'symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type(
+            'narrow_range', narrow_range, (bool,), self.name)
+        self.training = validator.check_value_type(
+            'training', training, (bool,), self.name)
+        self.ema_decay = validator.check_number_range(
+            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.num_bits = validator.check_integer(
+            'num_bits', num_bits, 0, Rel.GT, self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])

     def infer_shape(self, x_shape, min_shape, max_shape):
         validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
-        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
-        validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
+        validator.check("min shape", min_shape, "max shape",
+                        max_shape, Rel.EQ, self.name)
+        validator.check_integer("min rank", len(
+            min_shape), 1, Rel.EQ, self.name)
         return min_shape, max_shape

     def infer_dtype(self, x_type, min_type, max_type):
         valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"max": max_type}, valid_types, self.name)
+        return min_type, max_type
+
+
+class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
+    r"""
+    Update min and max value for fake quant per channel op.
+
+    Args:
+        num_bits (int) : Number bits for aware quantilization. Default: 8.
+        ema (bool): Use EMA algorithm update value min and max. Default: False.
+        ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
+        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
+        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        training (bool): Training the network or not. Default: True.
+        channel_axis (int): Channel axis for per channel compute. Default: 1.
+
+    Inputs:
+        - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
+        - **min** (Tensor) : Value of the min range of the input data x.
+        - **max** (Tensor) : Value of the max range of the input data x.
+
+    Outputs:
+        - Tensor: Simulate quantize tensor of x.
+
+    Examples:
+        >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+        >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
+        >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
+        >>> output_tensor = P.FakeQuantMinMaxPerChannelUpdate(num_bits=8)(x, min, max)
+    """
+    support_quant_bit = [4, 7, 8]
+
+    @prim_attr_register
+    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
+                 training=True, channel_axis=1):
+        """init FakeQuantMinMaxPerChannelUpdate OP for Ascend"""
+        if num_bits not in self.support_quant_bit:
+            raise ValueError(
+                f"For '{self.name}' attr \'num_bits\' is not support.")
+        if ema and not ema_decay:
+            raise ValueError(
+                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+
+        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+        self.symmetric = validator.check_value_type(
+            'symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type(
+            'narrow_range', narrow_range, (bool,), self.name)
+        self.training = validator.check_value_type(
+            'training', training, (bool,), self.name)
+        self.ema_decay = validator.check_number_range(
+            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.num_bits = validator.check_integer(
+            'num_bits', num_bits, 0, Rel.GT, self.name)
+        self.channel_axis = validator.check_integer(
+            'channel axis', channel_axis, 0, Rel.GE, self.name)
+        self.init_prim_io_names(
+            inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
+
+    def infer_shape(self, x_shape, min_shape, max_shape):
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
+        validator.check("min shape", min_shape, "max shape",
+                        max_shape, Rel.EQ, self.name)
+        validator.check_integer("min rank", len(
+            min_shape), 1, Rel.EQ, self.name)
+        return min_shape, max_shape
+
+    def infer_dtype(self, x_type, min_type, max_type):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same(
+            {"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"max": max_type}, valid_types, self.name)
+        return min_type, max_type
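# --- Illustrative sketch (assumed semantics, for illustration only): the
# FakeQuantMinMax*Update primitives only refresh the observed (min, max) range
# that the fake-quant ops consume. An EMA-style update of that range can be
# written roughly as below; the helper name and the zero-anchoring at the end
# are assumptions, not the exact Ascend kernel behaviour.
import numpy as np

def min_max_update_ref(x, running_min, running_max, ema=True, ema_decay=0.999):
    """Return updated (min, max) statistics for the observed tensor x."""
    batch_min, batch_max = float(np.min(x)), float(np.max(x))
    if ema:
        new_min = ema_decay * running_min + (1 - ema_decay) * batch_min
        new_max = ema_decay * running_max + (1 - ema_decay) * batch_max
    else:
        new_min, new_max = batch_min, batch_max
    # Keep zero inside the representable range so exact zeros survive quantization.
    return min(new_min, 0.0), max(new_max, 0.0)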