提交 b7db3e9a 编写于 作者: C chenzomi

add fake quant per channel and bug fix

上级 bd3e8da6
......@@ -171,6 +171,6 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMax, FakeQuantGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel)
} // namespace kernel
} // namespace mindspore
......@@ -153,6 +153,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxGrad, FakeQuantGradGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel)
} // namespace kernel
} // namespace mindspore
......@@ -175,6 +175,6 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannel, FakeQuantPerChannelGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernel)
} // namespace kernel
} // namespace mindspore
......@@ -143,6 +143,6 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
} // namespace kernel
} // namespace mindspore
......@@ -14,6 +14,7 @@
# ============================================================================
"""Aware quantization."""
from functools import partial
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
......@@ -101,10 +102,9 @@ class BatchNormFoldCell(Cell):
return batch_mean, batch_std, running_mean, running_std
class FakeQuantWithMinMaxD(Cell):
class FakeQuantWithMinMaxAscend(Cell):
r"""
Aware Quantization training op of ascend. This OP provide Fake quantization observer
function on data with min and max.
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
Args:
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
......@@ -125,7 +125,7 @@ class FakeQuantWithMinMaxD(Cell):
Tensor, with the same type and shape as the `x`.
Examples:
>>> fake_quant = nn.FakeQuantWithMinMaxD()
>>> fake_quant = FakeQuantWithMinMax()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""
......@@ -137,75 +137,77 @@ class FakeQuantWithMinMaxD(Cell):
ema=False,
ema_decay=0.999,
per_channel=False,
channel_size=1,
channel_axis=1,
out_channels=1,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
"""init FakeQuantWithMinMax ascend layer"""
super(FakeQuantWithMinMaxD, self).__init__()
"""init FakeQuantWithMinMaxAscend layer"""
super(FakeQuantWithMinMaxAscend, self).__init__()
self.min_init = min_init
self.num_bits = num_bits
self.max_init = max_init
self.num_bits = num_bits
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.channel_size = channel_size
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
if not per_channel:
self.fake_quant = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=training)
self.ema_update = P.FakeQuantWithMinMaxUpdate(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=training)
else:
raise RuntimeError("not support per channel")
# init tensor min and max for fake quant op
if isinstance(min_init, int):
min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32)
elif isinstance(min_init, list):
min_array = np.array([self.min_init for i in range(
0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(
0, self.out_channels)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
if isinstance(min_init, Parameter):
self.minq = min_init
self.maxq = max_init
if per_channel:
quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
else:
self.minq = Parameter(Tensor(np.array([min_init]).astype(np.float32)),
name='quant_min',
requires_grad=False)
self.maxq = Parameter(Tensor(np.array([max_init]).astype(np.float32)),
name='quant_max',
requires_grad=False)
self.reduce_min = P.ReduceMin()
self.reduce_max = P.ReduceMax()
quant_fun = P.FakeQuantPerLayer
ema_fun = P.FakeQuantMinMaxPerLayerUpdate
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
def extend_repr(self):
s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
self.quant_delay)
s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay,
self.per_channel, self.quant_delay, self.channel_axis)
return s
def construct(self, x, minq, maxq):
if self.training:
min_up, max_up = self.ema_update(x, minq, maxq)
def construct(self, x):
if self.update:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
else:
out = self.fake_quant(x, minq, maxq)
out = self.fake_quant(x, self.minq, self.maxq)
return out
class FakeQuantWithMinMax(Cell):
class FakeQuantWithMinMaxGPU(Cell):
r"""
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
......@@ -240,98 +242,69 @@ class FakeQuantWithMinMax(Cell):
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
out_channels=1,
quant_delay=0,
symmetric=False,
narrow_range=False):
"""init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
narrow_range=False,
training=True):
super(FakeQuantWithMinMaxGPU, self).__init__()
self.min_init = min_init
self.num_bits = num_bits
self.max_init = max_init
self.num_bits = num_bits
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.out_channels = out_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
if per_channel:
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(0, self.channel_size)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
else:
# init tensor min and max for fake quant op
if isinstance(min_init, int):
min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
if context.get_context('device_target') == "Ascend":
self.fake_quant_train = FakeQuantWithMinMaxD(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True,
min_init=self.minq,
max_init=self.maxq)
self.fake_quant_infer = FakeQuantWithMinMaxD(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False,
min_init=self.minq,
max_init=self.maxq)
elif context.get_context('device_target') == "GPU":
self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
else:
raise ValueError("Not support platform.")
elif isinstance(min_init, list):
min_array = np.array([self.min_init for i in range(
0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(
0, self.out_channels)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
if per_channel:
quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
else:
quant_fun = P.FakeQuantPerLayer
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
def extend_repr(self):
s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.quant_delay)
s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay,
self.per_channel, self.quant_delay, self.channel_axis)
return s
def construct(self, x):
if self.training:
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant_infer(x, self.minq, self.maxq)
out = self.fake_quant(x, self.minq, self.maxq)
return out
def FakeQuantWithMinMax(**kwargs):
if context.get_context('device_target') == "Ascend":
out = FakeQuantWithMinMaxAscend(**kwargs)
if context.get_context('device_target') == "GPU":
out = FakeQuantWithMinMaxGPU(**kwargs)
else:
raise ValueError("Not support platform or channel mode.")
return out
class Conv2dBatchNormQuant(Cell):
r"""
2D convolution with BatchNormal op folded layer.
......@@ -420,7 +393,6 @@ class Conv2dBatchNormQuant(Cell):
self.per_channel = per_channel
self.symmetric = symmetric
self.narrow_range = narrow_range
self.channel_axis = int(group > 1)
self.is_gpu = context.get_context('device_target') == "GPU"
# initialize convolution op and Parameter
......@@ -435,6 +407,7 @@ class Conv2dBatchNormQuant(Cell):
dilation=self.dilation)
if weight_init is None:
weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
channel_axis = 1
else:
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=self.kernel_size,
......@@ -445,6 +418,7 @@ class Conv2dBatchNormQuant(Cell):
group=group)
if weight_init is None:
weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
channel_axis = 0
self.weight = Parameter(weight_init, name='weight')
# initialize batchnorm Parameter
......@@ -472,7 +446,7 @@ class Conv2dBatchNormQuant(Cell):
symmetric=symmetric,
narrow_range=narrow_range)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = P.CorrectionMul(self.channel_axis)
self.correct_mul = P.CorrectionMul(channel_axis)
if context.get_context('device_target') == "Ascend":
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
......@@ -520,7 +494,7 @@ class Conv2dBatchNormQuant(Cell):
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
F.control_depend(out, self.assignadd(self.step, self.one))
else:
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std)
return out
......
......@@ -20,10 +20,11 @@ from .grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@bprop_getters.register(P.FakeQuantWithMinMax)
@bprop_getters.register(P.FakeQuantPerLayer)
def get_bprop_fakequant_with_minmax(self):
"""Generate bprop for FakeQuantWithMinMax for GPU and Ascend"""
op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
"""Generate bprop for FakeQuantPerLayer for GPU and Ascend"""
op = P.FakeQuantPerLayerGrad(
num_bits=self.num_bits, quant_delay=self.quant_delay)
def bprop(x, x_min, x_max, out, dout):
dx = op(dout, x, x_min, x_max)
......@@ -32,10 +33,14 @@ def get_bprop_fakequant_with_minmax(self):
return bprop
@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel)
@bprop_getters.register(P.FakeQuantPerChannel)
def get_bprop_fakequant_with_minmax_perchannel(self):
"""Generate bprop for FakeQuantWithMinMaxPerChannel for GPU"""
op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
"""Generate bprop for FakeQuantPerChannel"""
op = P.FakeQuantPerChannelGrad(num_bits=self.num_bits,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.symmetric,
channel_axis=self.channel_axis)
def bprop(x, x_min, x_max, out, dout):
dx = op(dout, x, x_min, x_max)
......@@ -77,7 +82,7 @@ def get_bprop_batchnorm_fold2(self):
d_batch_std, d_batch_mean, d_beta, d_gamma, d_x = op_f(dout, x, gamma, batch_std, batch_mean, running_std,
running_mean, global_step)
return d_x, d_beta, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std), zeros_like(running_mean), \
zeros_like(global_step)
zeros_like(global_step)
return bprop
......@@ -117,9 +122,19 @@ def get_bprop_batchnorm_fold2_(self):
return bprop
@bprop_getters.register(P.FakeQuantWithMinMaxUpdate)
def get_bprop_fakequant_with_minmax_update(self):
"""Generate bprop for FakeQuantWithMinMaxUpdate for Ascend"""
@bprop_getters.register(P.FakeQuantMinMaxPerLayerUpdate)
def get_bprop_fakequant_with_minmax_per_layer_update(self):
"""Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
return bprop
@bprop_getters.register(P.FakeQuantMinMaxPerChannelUpdate)
def get_bprop_fakequant_with_minmax_per_channel_update(self):
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerChannelUpdate op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \
.compute_cost(10) \
.kernel_name("fake_quant_min_max_per_channel_update") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
.output(0, "min_up", True, "required", "all") \
.output(1, "max_up", True, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.get_op_info()
@op_info_register(fake_quant_min_max_per_channel_update_op_info)
def _fake_quant_min_max_per_channel_update_tbe():
"""FakeQuantPerChannelUpdate TBE register"""
return
@fusion_manager.register("fake_quant_min_max_per_channel_update")
def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
ema, ema_decay, quant_min, quant_max, training, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerChannelUpdate compute"""
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
axis = [0, 2, 3]
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerLayer op"""
x_shape = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
tensor_list = [input_data, min_data, max_data] + list(res_list)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMaxUpdate op"""
"""FakeQuantMinMaxPerLayerUpdate op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
......@@ -23,12 +23,12 @@ from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_update.so") \
.binfile_name("fake_quant_minmax_update.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_update") \
.kernel_name("fake_quant_minmax_update") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
......@@ -36,7 +36,6 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
......@@ -47,16 +46,16 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.get_op_info()
@op_info_register(fake_quant_update_op_info)
def _fake_quant_update_tbe():
"""_FakeQuantWithMinMaxUpdate TBE register"""
@op_info_register(fake_quant_minmax_update_op_info)
def _fake_quant_minmax_update_tbe():
"""FakeQuantMinMaxPerLayerUpdate TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_update")
def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_update"):
"""FakeQuantWithMinMaxUpdate compute"""
@fusion_manager.register("fake_quant_minmax_update")
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantMinMaxPerLayerUpdate compute"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
......@@ -70,19 +69,21 @@ def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay,
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
kernel_name="fake_quant_update"):
"""FakeQuantWithMinMax op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantPerLayer op"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
......@@ -123,8 +124,8 @@ def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantPerChannel op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_perchannel_op_info = TBERegOp("FakeQuantPerChannel") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fake_quant_perchannel.so") \
.compute_cost(10) \
.kernel_name("fake_quant_perchannel") \
.partial_flag(True) \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(fake_quant_perchannel_op_info)
def _fake_quant_perchannel_tbe():
"""FakeQuantPerChannel TBE register"""
return
@fusion_manager.register("fake_quant_perchannel")
def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
x_shape = te.lang.cce.util.shape_to_list(x.shape)
minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
quant_max = tvm.const(quant_max, x.dtype)
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
# FakeQuant
nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape)
nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape)
scale_b = te.lang.cce.broadcast(scale, x_shape)
input_x = te.lang.cce.vmin(nudge_max_b, te.lang.cce.vmax(nudge_min_b, x))
nudge_input_ = te.lang.cce.vdiv(
te.lang.cce.vsub(input_x, nudge_min_b), scale_b)
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5))
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale_b), nudge_min_b)
return res
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel(x, min_val, max_val, y,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
x_shape = x.get("shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c = min_val.get("shape")
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res = fake_quant_perchannel_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [input_data, min_data, max_data, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantPerChannelGrad op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
D_TYPE = 'float32'
fake_quant_perchannel_grad_op_info = TBERegOp("FakeQuantPerChannelGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_perchannel_grad.so") \
.compute_cost(10) \
.kernel_name("fake_quant_perchannel_grad") \
.partial_flag(True) \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \
.input(2, "min", None, "required", None) \
.input(3, "max", None, "required", None) \
.output(0, "dx", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
def _less_compare_float32(data_x, data_y):
"""_less_compare_float32 compute"""
input_shape = te.lang.cce.util.shape_to_list(data_x.shape)
min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
max_value = tvm.const(2 ** 62, dtype=D_TYPE)
factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
data_zero = te.lang.cce.broadcast(
tvm.const(0, dtype=D_TYPE), input_shape, D_TYPE)
min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
res_sub = te.lang.cce.vsub(data_y, data_x)
res_min = te.lang.cce.vmin(res_sub, min_value_tensor)
res_max = te.lang.cce.vmax(res_min, data_zero)
res_max_mul = te.lang.cce.vmuls(res_max, max_value)
res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value)
res = te.lang.cce.vmuls(res_max_mul_max, factor_value)
return res
@op_info_register(fake_quant_perchannel_grad_op_info)
def _fake_quant_perchannel_grad_tbe():
"""FakeQuantPerChannelGrad TBE register"""
return
@fusion_manager.register("fake_quant_perchannel_grad")
def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
x_shape = te.lang.cce.util.shape_to_list(x.shape)
minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
quant_max = tvm.const(quant_max, x.dtype)
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
# FakeQuant Grad
nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape)
nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape)
bool_over_min = _less_compare_float32(nudge_min_b, x)
bool_less_max = _less_compare_float32(x, nudge_max_b)
bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max)
res = te.lang.cce.vmul(dout, bool_between)
return res
@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
x_shape = x.get("shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c = min_val.get("shape")
dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res = fake_quant_perchannel_grad_compute(dout_data, input_data, min_data, max_data,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [dout_data, input_data, min_data, max_data, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
......@@ -13,8 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMax op"""
"""FakeQuantPerLayer op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
......@@ -23,20 +22,16 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_vars_ema.so") \
.binfile_name("fake_quant_per_layer.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_vars_ema") \
.kernel_name("fake_quant_per_layer") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
......@@ -49,15 +44,15 @@ fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
@op_info_register(fake_quant_op_info)
def _fake_quant_tbe():
"""FakeQuantWithMinMax TBE register"""
def _fake_quant_per_layer_tbe():
"""FakeQuantPerLayer TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_vars_ema")
def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="correction_mul"):
"""FakeQuantWithMinMax"""
@fusion_manager.register("fake_quant_per_layer")
def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="fake_quant_per_layer"):
"""FakeQuantPerLayer"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
......@@ -66,10 +61,13 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
......@@ -80,17 +78,19 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
# FakeQuant
input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
nudge_input = te.lang.cce.round(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale))
nudge_input_ = te.lang.cce.vdiv(
te.lang.cce.vsub(input_x, nudge_min), scale)
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5))
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
return res
@util.check_input_type(dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
kernel_name="fake_quant"):
"""FakeQuantWithMinMax"""
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, str)
def fake_quant_per_layer(x, min_val, max_val, y,
symmetric, narrow_range, num_bits,
kernel_name="fake_quant_per_layer"):
"""FakeQuantPerLayer"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
......@@ -131,8 +131,8 @@ def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
res = fake_quant_per_layer_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMaxGrad op"""
"""FakeQuantPerLayerGrad op"""
from functools import reduce as functools_reduce
import te.lang.cce
......@@ -26,15 +26,14 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
D_TYPE = 'float32'
fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \
fake_quant_per_layer_grad_op_info = TBERegOp("FakeQuantPerLayerGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_grad.so") \
.binfile_name("fake_quant_per_layer_grad.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_grad") \
.kernel_name("fake_quant_per_layer_grad") \
.partial_flag(True) \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.input(0, "dout", None, "required", None) \
......@@ -57,7 +56,8 @@ def _less_compare_float32(data_x, data_y):
min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
max_value = tvm.const(2 ** 62, dtype=D_TYPE)
factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
data_zero = te.lang.cce.broadcast(
tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
res_sub = te.lang.cce.vsub(data_y, data_x)
......@@ -71,16 +71,16 @@ def _less_compare_float32(data_x, data_y):
return res
@op_info_register(fake_quant_grad_op_info)
def _fake_quant_grad_tbe():
"""FakeQuantWithMinMaxGrad TBE register"""
@op_info_register(fake_quant_per_layer_grad_op_info)
def _fake_quant_per_layer_grad_tbe():
"""FakeQuantPerLayerGrad TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_grad")
def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_with_min_max_grad"):
"""FakeQuantWithMinMaxGrad"""
@fusion_manager.register("fake_quant_per_layer_grad")
def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_per_layer_grad"):
"""FakeQuantPerLayerGrad"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
......@@ -89,10 +89,13 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
quant_max = te.lang.cce.broadcast(quant_max, shape_min)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
nudge_min = te.lang.cce.broadcast(nudge_min, shape)
......@@ -106,11 +109,11 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
return res
@util.check_input_type(dict, dict, dict, dict, dict, int, int, bool, bool, str)
def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
num_bits, quant_delay, symmetric, narrow_range,
kernel_name="fake_quant_with_min_max_grad"):
"""FakeQuantWithMinMaxGrad"""
@util.check_input_type(dict, dict, dict, dict, dict, int, bool, bool, str)
def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
num_bits, symmetric, narrow_range,
kernel_name="fake_quant_per_layer_grad"):
"""FakeQuantPerLayerGrad"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
......@@ -152,8 +155,8 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
quant_max, kernel_name)
res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册