Unverified commit 5d422287, authored by Leo Chen, committed by GitHub

Add float16 to fake quantize/dequantize OP (#40664)

Parent 7ce0ee69
@@ -17,10 +17,13 @@ limitations under the License. */
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
+using float16 = paddle::platform::float16;
 
 REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
                         ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
-                        ops::FakeDequantizeMaxAbsKernel<CUDA, double>);
+                        ops::FakeDequantizeMaxAbsKernel<CUDA, double>,
+                        ops::FakeDequantizeMaxAbsKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(
     fake_channel_wise_dequantize_max_abs,
     ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float>,
-    ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, double>);
+    ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, double>,
+    ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float16>);
@@ -19,17 +19,22 @@ namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
 using float16 = paddle::platform::float16;
 REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
-                        ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
+                        ops::FakeQuantizeAbsMaxKernel<CUDA, float>,
+                        ops::FakeQuantizeAbsMaxKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max,
                         ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>,
                         ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float16>);
-REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
-                        ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(
+    fake_channel_wise_quantize_abs_max,
+    ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>,
+    ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
-                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
+                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>,
+                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(
     fake_quantize_moving_average_abs_max,
-    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
+    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>,
+    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
                         ops::MovingAverageAbsMaxScaleKernel<CUDA, float>,
                         ops::MovingAverageAbsMaxScaleKernel<CUDA, float16>);
......
@@ -24,6 +24,16 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+template <typename T>
+struct QuantizeDataType {
+  using type = T;
+};
+
+template <>
+struct QuantizeDataType<paddle::platform::float16> {
+  using type = float;
+};
+
 template <typename T>
 __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -87,10 +97,12 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
   int tid = threadIdx.x;
   int channel_size = n / c;
   const T* in_c = in + blockIdx.x * channel_size;
-  extern __shared__ T shared_max_data[];
+  extern __shared__ char* shared_max_data_tmp[];
+  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
   T local_max_data = T(0);
   for (int i = tid; i < channel_size; i += blockDim.x) {
-    T tmp = fabs(in_c[i]);
+    T tmp = static_cast<T>(
+        fabs(static_cast<typename QuantizeDataType<T>::type>(in_c[i])));
     if (tmp > local_max_data) {
       local_max_data = tmp;
     }
@@ -112,7 +124,8 @@ template <typename T>
 __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
                                                   const int cin, const int cout,
                                                   T* out) {
-  extern __shared__ T shared_max_data[];
+  extern __shared__ char* shared_max_data_tmp[];
+  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
   int cout_wh_size = n / cin;
   int wh_size = n / (cin * cout);
 
@@ -121,7 +134,8 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
   const T* in_current = in + tid * cout_wh_size + bid * wh_size;
   T local_max_data = T(0);
   for (int i = 0; i < wh_size; i++) {
-    T tmp = fabs(in_current[i]);
+    T tmp = static_cast<T>(
+        fabs(static_cast<typename QuantizeDataType<T>::type>(in_current[i])));
     if (tmp > local_max_data) {
       local_max_data = tmp;
     }
@@ -205,12 +219,14 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
   T s = scale[0];
   T inv_s = inverse(s);
+  T bin_cnt_t = static_cast<T>(bin_cnt);
 
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
     T x = in[i];
     T v = x > s ? s : x;
     v = v < -s ? -s : v;
-    v = bin_cnt * inv_s * v;
-    out[i] = round(v);
+    v = bin_cnt_t * inv_s * v;
+    out[i] = static_cast<T>(
+        round(static_cast<typename QuantizeDataType<T>::type>(v)));
   }
 }
@@ -230,7 +246,8 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
     x = x > s ? s : x;
     x = x < -s ? -s : x;
     x = bin_cnt_t * inv_s * x;
-    x = static_cast<T>(round(static_cast<float>(x)));
+    x = static_cast<T>(
+        round(static_cast<typename QuantizeDataType<T>::type>(x)));
     out[i] = (x * s) / bin_cnt_t;
   }
 }
@@ -287,13 +304,15 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
   T s = scale[blockIdx.x];
   T inv_s = inverse(s);
+  T bin_cnt_t = static_cast<T>(bin_cnt);
 
   for (int64_t i = tid; i < channel_size; i += blockDim.x) {
     T x = in_c[i];
     T v = x > s ? s : x;
     v = v < -s ? -s : v;
-    v = bin_cnt * inv_s * v;
-    out_c[i] = round(v);
+    v = bin_cnt_t * inv_s * v;
+    out_c[i] = static_cast<T>(
+        round(static_cast<typename QuantizeDataType<T>::type>(v)));
   }
 }
@@ -303,14 +322,16 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
     const T* in, const T* scale, const int bin_cnt, const int64_t n,
     const int nScale, const int quant_stride, T* out) {
   int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+  T bin_cnt_t = static_cast<T>(bin_cnt);
   for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) {
     T s = scale[(i / quant_stride) % nScale];
     T inv_s = inverse(s);
     T x = in[i];
     T v = x > s ? s : x;
     v = v < -s ? -s : v;
-    v = bin_cnt * inv_s * v;
-    out[i] = round(v);
+    v = bin_cnt_t * inv_s * v;
+    out[i] = static_cast<T>(
+        round(static_cast<typename QuantizeDataType<T>::type>(v)));
   }
 }
@@ -376,7 +397,8 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
   scale_arr[idx] = cur;
   T max = last_scale[0];
   out_scale[0] = max < cur ? cur : max;
-  if (fabs(removed - max) < 1e-6) {
+  if (fabs(static_cast<typename QuantizeDataType<T>::type>(removed - max)) <
+      1e-6) {
     need_find_max[0] = 1;
     out_size[0] = it > window_size ? window_size : it;
   } else {
......
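
The kernel changes above all funnel through the new QuantizeDataType trait: for float16 tensors the device-side math (fabs, round) runs in float, and the result is cast back to the tensor's dtype. Below is a minimal numpy sketch of that promote-compute-cast pattern; the helper name is illustrative only and not part of the commit.

import numpy as np


def fake_quantize_abs_max_ref(x, bit_length=8):
    # Promote float16 to float32 for the math, mirroring QuantizeDataType.
    compute_dtype = np.float32 if x.dtype == np.float16 else x.dtype
    xc = x.astype(compute_dtype)
    scale = np.max(np.abs(xc))
    bnt = (1 << (bit_length - 1)) - 1  # e.g. 127 for 8 bits
    v = np.clip(xc, -scale, scale) * (bnt / scale)
    # C/CUDA round() sends halves away from zero, unlike numpy.round.
    q = np.where(v >= 0, np.floor(v + 0.5), np.ceil(v - 0.5))
    # Cast the results back to the input dtype, as the kernels do.
    return q.astype(x.dtype), np.array(scale, dtype=x.dtype)


x = np.random.randn(4, 16).astype(np.float16)
q, s = fake_quantize_abs_max_ref(x)
print(q.dtype, s.dtype)  # float16 float16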
@@ -18,6 +18,7 @@ import unittest
 import numpy as np
 import math
 from op_test import OpTest
+import paddle.fluid.core as core
 
 
 def quantize_max_abs(x, max_range):
@@ -76,22 +77,25 @@ def channel_wise_dequantize_max_abs(x,
 class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
     def set_args(self):
         self.quant_bits = [8, 8]
-        self.data_type = "float32"
         self.activation_scale = 0.7861
 
+    def set_dtype(self):
+        self.dtype = np.float32
+
     def setUp(self):
         self.set_args()
+        self.set_dtype()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
-        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        x = np.random.randn(4, 3, 64, 64).astype(self.dtype)
         yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], 1)
         ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, 1,
                                               self.activation_scale)
         self.inputs = {
             'X': yq,
-            'Scales': [("scales0", np.array(scales).astype(self.data_type)),
-                       ("scales1", np.array(
-                           [self.activation_scale]).astype(self.data_type))]
+            'Scales': [("scales0", np.array(scales).astype(self.dtype)),
+                       ("scales1",
+                        np.array([self.activation_scale]).astype(self.dtype))]
         }
         self.attrs = {'quant_bits': self.quant_bits}
         self.outputs = {'Out': ydq}
@@ -100,16 +104,28 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
         self.check_output()
 
 
+class TestFakeChannelWiseDequantizeMaxAbsOpTwoScalesFloat16(
+        TestFakeChannelWiseDequantizeMaxAbsOpTwoScales):
+    def set_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
+
+
 class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
     def set_args(self):
         self.quant_bits = [8]
-        self.data_type = "float32"
         self.quant_axis = 0
 
+    def set_dtype(self):
+        self.dtype = np.float32
+
     def setUp(self):
         self.set_args()
+        self.set_dtype()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
-        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        x = np.random.randn(4, 3, 64, 64).astype(self.dtype)
         yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0],
                                                    self.quant_axis)
         ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
@@ -117,7 +133,7 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
         self.inputs = {
             'X': yq,
-            'Scales': [("scales0", np.array(scales).astype(self.data_type))]
+            'Scales': [("scales0", np.array(scales).astype(self.dtype))]
         }
         self.attrs = {
             'quant_bits': self.quant_bits,
@@ -133,24 +149,44 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1(
         TestFakeChannelWiseDequantizeMaxAbsOpOneScale):
     def set_args(self):
         self.quant_bits = [8]
-        self.data_type = "float32"
         self.quant_axis = 1
 
 
+class TestFakeChannelWiseDequantizeMaxAbsOpOneScaleFloat16(
+        TestFakeChannelWiseDequantizeMaxAbsOpOneScale):
+    def set_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
+
+
+class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1Float16(
+        TestFakeChannelWiseDequantizeMaxAbsOpOneScale1):
+    def set_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
+
+
 class TestFakeDequantizeMaxAbsOp(OpTest):
     def set_args(self):
         self.num_bits = 8
         self.max_range = math.pow(2, self.num_bits - 1) - 1
-        self.data_type = "float32"
 
+    def set_dtype(self):
+        self.dtype = np.float32
+
     def setUp(self):
         self.set_args()
+        self.set_dtype()
         self.op_type = "fake_dequantize_max_abs"
-        x = np.random.randn(31, 65).astype(self.data_type)
+        x = np.random.randn(31, 65).astype(self.dtype)
         yq, scale = quantize_max_abs(x, self.max_range)
         ydq = dequantize_max_abs(yq, scale, self.max_range)
-        self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.data_type)}
+        self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.dtype)}
         self.attrs = {'max_range': self.max_range}
         self.outputs = {'Out': ydq}
@@ -159,17 +195,22 @@ class TestFakeDequantizeMaxAbsOp(OpTest):
 
 class TestFakeDequantizeMaxAbsOpDouble(TestFakeDequantizeMaxAbsOp):
-    def set_args(self):
-        self.num_bits = 8
-        self.max_range = math.pow(2, self.num_bits - 1) - 1
-        self.data_type = "float64"
+    def set_dtype(self):
+        self.dtype = np.float64
 
 
 class TestFakeDequantizeMaxAbsOp5Bits(TestFakeDequantizeMaxAbsOp):
     def set_args(self):
         self.num_bits = 5
         self.max_range = math.pow(2, self.num_bits - 1) - 1
-        self.data_type = "float32"
+
+
+class TestFakeDequantizeMaxAbsOpFloat16(TestFakeDequantizeMaxAbsOp):
+    def set_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
 
 
 class TestChannelWiseDequantizeOp(OpTest):
......
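
The float16 variants added above reuse the existing quantize_max_abs and dequantize_max_abs reference helpers (their bodies are elided from this diff) and only relax the tolerance to atol=1e-2, since float16 carries roughly three decimal digits. Below is a rough numpy sketch of the round trip these tests exercise, under the assumption that dequantization simply rescales by scale / max_range; the helper name is hypothetical.

import numpy as np


def quantize_dequantize_max_abs_ref(x, num_bits=8):
    # Assumed behaviour, mirroring the elided test helpers: quantize to
    # [-max_range, max_range] with an abs-max scale, then rescale back.
    max_range = (1 << (num_bits - 1)) - 1           # 127 for 8 bits
    scale = np.max(np.abs(x)).astype(np.float32)    # compute in float32
    yq = np.round(x.astype(np.float32) / scale * max_range)
    ydq = (yq * scale / max_range).astype(x.dtype)  # back to the input dtype
    return ydq, scale.astype(x.dtype)


ydq, scale = quantize_dequantize_max_abs_ref(
    np.random.randn(31, 65).astype(np.float16))
print(ydq.dtype, scale.dtype)  # float16 float16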
@@ -15,28 +15,51 @@
 from __future__ import print_function
 
 import unittest
+import math
 import numpy as np
 import math
 from op_test import OpTest
 import paddle.fluid.core as core
 
 
+# numpy.round has different behavior in comparision to c++ round function
+# so we use round_c instead of numpy.round to align the output data
+def round_c_single_element(x):
+    dtype = type(x)
+    if x >= 0:
+        return dtype(np.floor(x + 0.5))
+    else:
+        return dtype(np.ceil(x - 0.5))
+
+
+round_c = np.vectorize(round_c_single_element)
+
+
 class TestFakeQuantizeOp(OpTest):
     def setUp(self):
+        self.set_dtype()
         self.op_type = "fake_quantize_abs_max"
         self.attrs = {'bit_length': 8}
-        self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
-        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
+        self.inputs = {'X': np.random.random((124, 240)).astype(self.dtype), }
+        scale = np.max(np.abs(self.inputs['X'])).astype(self.dtype)
         self.outputs = {
-            'Out': np.round(self.inputs['X'] / scale * (
+            'Out': round_c(self.inputs['X'] / scale * (
                 (1 << (self.attrs['bit_length'] - 1)) - 1)),
-            'OutScale': np.array(scale).astype("float32"),
+            'OutScale': np.array(scale).astype(self.dtype),
         }
 
+    def set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output()
 
 
+class TestFakeQuantizeOpFloat16(TestFakeQuantizeOp):
+    def set_dtype(self):
+        self.dtype = np.float16
+
+
 class TestFakeQuantizeOp1(OpTest):
     def setUp(self):
         self.op_type = "fake_quantize_abs_max"
@@ -73,6 +96,7 @@ class TestFakeQuantizeOp2(OpTest):
 
 class TestFakeChannelWiseQuantizeOp(OpTest):
     def setUp(self):
+        self.set_dtype()
         self.set_arg()
         assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."
 
@@ -84,53 +108,70 @@ class TestFakeChannelWiseQuantizeOp(OpTest):
         bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
         if self.quant_axis == 0:
             for i in range(self.inputs['X'].shape[0]):
-                scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32")
+                scale_v = np.max(np.abs(self.inputs['X'][i])).astype(self.dtype)
                 scales.append(scale_v)
-                outputs[i] = np.round(outputs[i] / scale_v * bnt)
+                outputs[i] = round_c(
+                    self.dtype(bnt) * (self.dtype(1.0) / scale_v) * outputs[i])
         elif self.quant_axis == 1:
             for i in range(self.inputs['X'].shape[1]):
                 scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
-                    "float32")
+                    self.dtype)
                 scales.append(scale_v)
-                outputs[:, i] = np.round(outputs[:, i] / scale_v * bnt)
+                outputs[:, i] = round_c(
+                    self.dtype(bnt) * (self.dtype(1.0) / scale_v) *
+                    outputs[:, i])
 
         self.outputs = {
             'Out': outputs,
-            'OutScale': np.array(scales).astype("float32"),
+            'OutScale': np.array(scales).astype(self.dtype),
         }
 
     def set_arg(self):
         self.quant_axis = 0
         self.inputs = {
-            'X': np.random.random((20, 15, 6, 6)).astype("float32"),
+            'X': np.random.random((20, 15, 6, 6)).astype(self.dtype),
         }
 
+    def set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output()
 
 
+class TestFakeChannelWiseQuantizeOpFloat16(TestFakeChannelWiseQuantizeOp):
+    def set_dtype(self):
+        self.dtype = np.float16
+
+
 class TestFakeChannelWiseQuantizeOp1(TestFakeChannelWiseQuantizeOp):
     def set_quant_axis(self):
         self.quant_axis = 1
         self.inputs = {
-            'X': np.random.random((15, 20, 5, 5)).astype("float32"),
+            'X': np.random.random((15, 20, 5, 5)).astype(self.dtype),
         }
 
 
+class TestFakeChannelWiseQuantizeOp1Float16(TestFakeChannelWiseQuantizeOp1):
+    def set_dtype(self):
+        self.dtype = np.float16
+
+
 class TestFakeChannelWiseQuantizeOp2(TestFakeChannelWiseQuantizeOp):
     def set_quant_axis(self):
         self.quant_axis = 0
-        self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
+        self.inputs = {'X': np.random.random((30, 15)).astype(self.dtype), }
 
 
 class TestFakeChannelWiseQuantizeOp3(TestFakeChannelWiseQuantizeOp):
     def set_quant_axis(self):
         self.quant_axis = 1
-        self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
+        self.inputs = {'X': np.random.random((30, 15)).astype(self.dtype), }
 
 
 class TestFakeQuantizeRangeAbsMaxOp(OpTest):
     def setUp(self):
+        self.set_dtype()
         self.op_type = "fake_quantize_range_abs_max"
         self.attrs = {
             'bit_length': int(5),
@@ -138,27 +179,36 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
             'is_test': False
         }
         x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
-        x = x.astype("float32")
+        x = x.astype(self.dtype)
         self.inputs = {
             'X': x,
             'Iter': np.zeros(1).astype("int64"),
-            'InScale': np.zeros(1).astype("float32")
+            'InScale': np.zeros(1).astype(self.dtype)
         }
-        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
+        scale = np.max(np.abs(self.inputs['X'])).astype(self.dtype)
 
-        out_scales = np.zeros(self.attrs['window_size']).astype("float32")
+        out_scales = np.zeros(self.attrs['window_size']).astype(self.dtype)
         out_scales[0] = scale
         self.outputs = {
-            'Out': np.round(self.inputs['X'] / scale * (
-                (1 << (self.attrs['bit_length'] - 1)) - 1)),
+            'Out': round_c(
+                self.dtype((1 << (self.attrs['bit_length'] - 1)) - 1) *
+                (self.dtype(1.0) / scale) * self.inputs['X']),
             'OutScale': scale,
             'OutScales': out_scales,
         }
 
+    def set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output()
 
 
+class TestFakeQuantizeRangeAbsMaxOpFloat16(TestFakeQuantizeRangeAbsMaxOp):
+    def set_dtype(self):
+        self.dtype = np.float16
+
+
 class TestMovingAverageAbsMaxScaleOp(OpTest):
     def setUp(self):
         self.op_type = "moving_average_abs_max_scale"
@@ -195,6 +245,7 @@ class TestMovingAverageAbsMaxScaleOp(OpTest):
 
 class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
     def setUp(self):
+        self.set_dtype()
         self.op_type = "fake_quantize_range_abs_max"
         self.attrs = {
             'bit_length': int(8),
@@ -202,55 +253,68 @@ class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
             'is_test': True
         }
         x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
-        x = x.astype("float32")
-        scale = np.array([np.max(np.abs(x)).astype("float32") - 1.0])
-        out_scales = np.zeros(self.attrs['window_size']).astype("float32")
-        out_scales[0] = scale
+        x = x.astype(self.dtype)
+        scale = np.array([np.max(np.abs(x)).astype(self.dtype) - 1.0])
+        out_scales = np.zeros(self.attrs['window_size']).astype(self.dtype)
+        out_scales[0] = scale.astype(self.dtype)
         self.inputs = {
             'X': x,
             'Iter': np.zeros(1).astype("int64"),
-            'InScale': scale.astype("float32")
+            'InScale': scale.astype(self.dtype)
         }
-        xs = np.clip(x, -scale, scale)
-        qs = np.round(xs / scale * ((1 << (self.attrs['bit_length'] - 1)) - 1))
+        xs = np.clip(x, -scale, scale).astype(self.dtype)
+        qs = round_c(
+            self.dtype(
+                self.dtype((1 << (self.attrs['bit_length'] - 1)) - 1) * (
+                    self.dtype(1.0) / scale) * xs))
         self.outputs = {
             'Out': qs,
-            'OutScale': scale.astype("float32"),
+            'OutScale': scale.astype(self.dtype),
             'OutScales': out_scales,
         }
 
+    def set_dtype(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output(no_check_set=set(['OutScale', 'OutScales']))
 
 
+class TestFakeQuantizeRangeAbsMaxOp2Float16(TestFakeQuantizeRangeAbsMaxOp2):
+    def set_dtype(self):
+        self.dtype = np.float16
+
+
 class TestMovingOpBase(OpTest):
     def setUp(self):
+        self.set_dtype()
         self.init_type()
         self.attrs = {
             'bit_length': int(5),
             'moving_rate': float(0.9),
             'is_test': False
         }
-        accum = np.zeros(1).astype("float32")
+        accum = np.zeros(1).astype(self.dtype)
         accum[0] = 1
-        state = np.zeros(1).astype("float32")
-        state[0] = 1
-        scale = np.zeros(1).astype("float32")
+        state = np.zeros(1).astype(self.dtype)
+        state[0] = self.dtype(1.0)
+        scale = np.zeros(1).astype(self.dtype)
         scale[0] = 0.001
         self.inputs = {
-            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
+            'X': np.random.random((8, 16, 7, 7)).astype(self.dtype),
             'InScale': scale,
             'InAccum': accum,
             'InState': state,
         }
 
-        out_accum = np.zeros(1).astype("float32")
-        out_state = np.zeros(1).astype("float32")
-        out_scale = np.zeros(1).astype("float32")
-        out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
-            np.abs(self.inputs['X'])).astype("float32")
-        out_state[0] = self.attrs['moving_rate'] * state[0] + 1
-        out_scale = out_accum / out_state
+        out_accum = np.zeros(1).astype(self.dtype)
+        out_state = np.zeros(1).astype(self.dtype)
+        out_scale = np.zeros(1).astype(self.dtype)
+        out_accum[0] = self.dtype(self.attrs['moving_rate']) * self.dtype(accum[
+            0]) + np.max(np.abs(self.inputs['X'])).astype(self.dtype)
+        out_state[0] = self.dtype(self.attrs['moving_rate']) * self.dtype(state[
+            0]) + self.dtype(1.0)
+        out_scale = self.dtype(self.dtype(out_accum) / self.dtype(out_state))
         out_data = self.calc_output(out_scale)
         self.outputs = {
             'Out': out_data,
@@ -259,17 +323,28 @@ class TestMovingOpBase(OpTest):
             'OutScale': out_scale,
         }
 
+    def set_dtype(self):
+        self.dtype = np.float32
+
     def init_type(self):
         self.op_type = "fake_quantize_moving_average_abs_max"
 
     def calc_output(self, out_scale):
-        return np.round(self.inputs['X'] / out_scale * (
+        return round_c(self.inputs['X'] / out_scale * (
             (1 << (self.attrs['bit_length'] - 1)) - 1))
 
     def test_check_output(self):
         self.check_output()
 
 
+class TestMovingOpBaseFloat16(TestMovingOpBase):
+    def set_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
+
+
 class TestFakeQuantDequantMovingOp(TestMovingOpBase):
     def init_type(self):
         self.op_type = "fake_quantize_dequantize_moving_average_abs_max"
......
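
The round_c helper introduced in the last test file above exists because numpy.round resolves .5 ties to the nearest even value, while the C/CUDA round() used by the kernels rounds ties away from zero; the two only disagree on exact halves. A small standalone check of the difference (not part of the commit; the helper below is a simplified reimplementation of the same rule):

import numpy as np


def round_half_away_from_zero(x):
    # Same tie-breaking rule as C/CUDA round(): halves move away from zero,
    # which is what round_c applies element by element.
    x = np.asarray(x, dtype=np.float64)
    return np.where(x >= 0, np.floor(x + 0.5), np.ceil(x - 0.5))


ties = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
print(np.round(ties))                   # [-2. -2. -0.  0.  2.  2.]  ties to even
print(round_half_away_from_zero(ties))  # [-3. -2. -1.  1.  2.  3.]  ties away from zero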