Commit 1c3e5796 authored by wandongdong

fix bug in quant and correction_mul_grad

Parent 75f791d8
@@ -21,7 +21,7 @@
 #include "fake_quant_impl.cuh"
 __global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
-                             const float *nudge_max, const float *scale, bool symmetric) {
+                             const float *nudge_max, const float *scale) {
   float input_x = 0.f;
   int nudge_input = 0;
@@ -35,7 +35,7 @@ __global__ void FakeQuantize(const float *input, float *output, const int size,
     input_x = nudge_max[0];
   }
   // clamp shift
-  nudge_input = floor((input_x - nudge_min[0]) / scale[0] + 0.5f);
+  nudge_input = round((input_x - nudge_min[0]) / scale[0]);
   // quantize
   output[i] = nudge_input * scale[0] + nudge_min[0];
@@ -99,8 +99,7 @@ __global__ void UpdateInputMinMax(float *input_min, float *input_max, const floa
 void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
                      const float *scale, bool symmetric, cudaStream_t cuda_stream) {
-  FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale,
-                                                                  symmetric);
+  FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale);
   return;
 }
...
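The substantive change in this CUDA kernel is the rounding step: floor(x + 0.5f) becomes round(x). Adding 0.5f first can round the sum up to the next representable float and shift the result by one quantization step, and the two forms also disagree on negative half-way values. A minimal numpy sketch of the clamp/shift/round/de-quantize pipeline the kernel implements (illustration only, not the MindSpore source; note that np.round rounds halves to even, whereas CUDA's round rounds halves away from zero):

import numpy as np

def fake_quantize(x, nudge_min, nudge_max, scale, use_round=True):
    x = np.clip(x, nudge_min, nudge_max)                  # clamp to the nudged range
    q = (x - nudge_min) / scale                           # shift into units of scale
    q = np.round(q) if use_round else np.floor(q + 0.5)   # new vs. old rounding
    return q * scale + nudge_min                          # de-quantize

x = np.array([-6.3, -1.0, 0.2, 5.9], dtype=np.float32)
print(fake_quantize(x, nudge_min=-6.0, nudge_max=6.0, scale=12.0 / 255))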
@@ -22,7 +22,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import check_int_positive, check_bool, twice
-from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Validator as validator, Rel
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 import mindspore.context as context
@@ -207,7 +207,7 @@ class FakeQuantWithMinMaxD(Cell):
 class FakeQuantWithMinMax(Cell):
     r"""
-    Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max.
+    Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
     Args:
         min_init (int, list): The dimension of channel or 1(layer). Default: -6.
@@ -243,8 +243,7 @@ class FakeQuantWithMinMax(Cell):
                  out_channels=1,
                  quant_delay=0,
                  symmetric=False,
-                 narrow_range=False,
-                 training=True):
+                 narrow_range=False):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
@@ -258,7 +257,6 @@ class FakeQuantWithMinMax(Cell):
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
-        self.training = training
         if per_channel:
             min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
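For context, the per_channel branch above allocates one (min, max) pair per output channel, while the layer-wise case keeps a single pair. A small sketch of the two initializations, assuming min_init = -6 and a symmetric max_init = 6 (the max default is an assumption, not shown in this hunk):

import numpy as np

min_init, max_init, out_channels = -6, 6, 4  # max_init is assumed for illustration

# per_channel=True: one entry per output channel, shape (out_channels,)
min_array = np.array([min_init for _ in range(out_channels)]).astype(np.float32)
max_array = np.array([max_init for _ in range(out_channels)]).astype(np.float32)

# per_channel=False: a single layer-wise entry, shape (1,)
min_layer = np.array([min_init]).astype(np.float32)
print(min_array.shape, min_layer.shape)  # (4,) (1,)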
@@ -422,11 +420,13 @@ class Conv2dBatchNormQuant(Cell):
         self.per_channel = per_channel
         self.symmetric = symmetric
         self.narrow_range = narrow_range
+        self.channel_axis = int(group > 1)
+        self.is_gpu = context.get_context('device_target') == "GPU"
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
-            validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
-            validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant')
+            validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant')
+            validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant')
             self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                 kernel_size=self.kernel_size,
                                                 pad_mode=pad_mode,
@@ -472,7 +472,7 @@ class Conv2dBatchNormQuant(Cell):
                                                symmetric=symmetric,
                                                narrow_range=narrow_range)
         self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
-        self.correct_mul = P.CorrectionMul()
+        self.correct_mul = P.CorrectionMul(self.channel_axis)
         if context.get_context('device_target') == "Ascend":
             self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
             self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
...
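The new channel_axis = int(group > 1) encodes where the channel dimension sits in the weight tensor: a regular Conv2d weight is (out_channels, in_channels, kH, kW) with the channel at axis 0, while the DepthwiseConv2dNative weight used in the group > 1 branch carries it at axis 1, which is presumably why CorrectionMul now takes the axis as an argument. A hedged numpy sketch of the broadcast CorrectionMul performs (out = weight * batch_std / running_std along the channel axis; not the operator's actual kernel):

import numpy as np

def correction_mul(weight, batch_std, running_std, channel_axis):
    # Reshape the per-channel factor so it broadcasts along channel_axis.
    shape = [1] * weight.ndim
    shape[channel_axis] = -1
    factor = (batch_std / running_std).reshape(shape)
    return weight * factor

w = np.ones((8, 3, 3, 3), dtype=np.float32)   # (out_channels, in_channels, kH, kW)
bstd = np.full(8, 2.0, dtype=np.float32)
rstd = np.ones(8, dtype=np.float32)
print(correction_mul(w, bstd, rstd, channel_axis=0).mean())  # -> 2.0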
@@ -93,8 +93,8 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
     util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
     util.check_dtype_rule(inp_dtype_x, ("float16", "float32"))
-    util.check_dtype_rule(inp_dtype_batch_std, ("float32",))
-    util.check_dtype_rule(inp_dtype_running_std, ("float32",))
+    util.check_dtype_rule(inp_dtype_batch_std, ("float16", "float32"))
+    util.check_dtype_rule(inp_dtype_running_std, ("float16", "float32"))
     util.compare_tensor_dict_key(dout, x, "dtype")
     util.compare_tensor_dict_key(dout, x, "shape")
     util.compare_tensor_dict_key(dx, x, "shape")
...
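The dtype relaxation lets the gradient kernel accept float16 as well as float32 for the std inputs. For reference, since the forward op computes out = x * batch_std / running_std, its two gradients follow directly; a numpy sketch under the assumption that channel selects the broadcast axis (fixed to axis 0 here for brevity):

import numpy as np

def correction_mul_grad(dout, x, batch_std, running_std, channel_axis=0):
    shape = [1] * x.ndim
    shape[channel_axis] = -1
    # dx: scale the incoming gradient by the same per-channel factor
    dx = dout * (batch_std / running_std).reshape(shape)
    # d_batch_std: reduce dout * x over every non-channel axis, then divide by running_std
    axes = tuple(i for i in range(x.ndim) if i != channel_axis)
    d_batch_std = (dout * x).sum(axis=axes) / running_std
    return dx, d_batch_std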
@@ -80,8 +80,7 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
     # FakeQuant
     input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
-    nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale),
-                                                      0.5))
+    nudge_input = te.lang.cce.round(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale))
     res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
     return res
...
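The same floor(x + 0.5) to round substitution is applied here in the Ascend TBE kernel. The nudge_min, nudge_max, and scale tensors it consumes come from the usual min/max nudging scheme of TF-style fake quantization, sketched below under assumed 8-bit bounds quant_min = 0, quant_max = 255 (the nudge shifts the range so that zero lands exactly on an integer grid point):

import numpy as np

def nudge_min_max(min_val, max_val, quant_min=0, quant_max=255):
    # Assumed TF-style nudging, not lifted from this repository.
    scale = (max_val - min_val) / (quant_max - quant_min)
    zp_from_min = quant_min - min_val / scale
    nudged_zp = np.clip(np.round(zp_from_min), quant_min, quant_max)
    nudge_min = (quant_min - nudged_zp) * scale
    nudge_max = (quant_max - nudged_zp) * scale
    return nudge_min, nudge_max, scale

print(nudge_min_max(-6.0, 6.1))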