Unverified commit 9783e887, authored by Guanghua Yu, committed by GitHub

[cherry pick #43088 #40664] Add float16 to fake quantize/dequantize OP (#43689)

* cherry pick #43088 #40664

* fix clang format
Parent a363e5ab
@@ -17,10 +17,13 @@ limitations under the License. */
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
+using float16 = paddle::platform::float16;
 REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
                         ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
-                        ops::FakeDequantizeMaxAbsKernel<CUDA, double>);
+                        ops::FakeDequantizeMaxAbsKernel<CUDA, double>,
+                        ops::FakeDequantizeMaxAbsKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(
     fake_channel_wise_dequantize_max_abs,
     ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float>,
-    ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, double>);
+    ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, double>,
+    ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float16>);
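For orientation, here is a minimal NumPy sketch of the abs-max quantize/dequantize round trip that these newly registered float16 kernels cover. The `_ref` helper names are illustrative only (they mirror the reference helpers used in the tests further down, not a Paddle API).

import numpy as np

def quantize_max_abs_ref(x, max_range):
    # Illustrative abs-max quantization: scale by max|x| and round into
    # the signed integer range [-max_range, max_range].
    scale = float(np.abs(x).max())
    return np.round(x.astype(np.float32) / scale * max_range), scale

def dequantize_max_abs_ref(y, scale, max_range):
    # Inverse mapping back to real values.
    return y.astype(np.float32) * scale / max_range

x = np.random.randn(31, 65).astype(np.float16)
y, scale = quantize_max_abs_ref(x, (1 << 7) - 1)           # 8 bits -> 127
x_dq = dequantize_max_abs_ref(y, scale, (1 << 7) - 1).astype(np.float16)
print(np.abs(x_dq - x).max())                               # small, but not zero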
@@ -19,17 +19,22 @@ namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
 using float16 = paddle::platform::float16;
 REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
-                        ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
+                        ops::FakeQuantizeAbsMaxKernel<CUDA, float>,
+                        ops::FakeQuantizeAbsMaxKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max,
                         ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>,
                         ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float16>);
-REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
-                        ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(
+    fake_channel_wise_quantize_abs_max,
+    ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>,
+    ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
-                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
+                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>,
+                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(
     fake_quantize_moving_average_abs_max,
-    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
+    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>,
+    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float16>);
 REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
                         ops::MovingAverageAbsMaxScaleKernel<CUDA, float>,
                         ops::MovingAverageAbsMaxScaleKernel<CUDA, float16>);
......
@@ -17,6 +17,7 @@ limitations under the License. */
#endif // PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_
#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
@@ -24,6 +25,16 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+template <typename T>
+struct QuantizeDataType {
+  using type = T;
+};
+
+template <>
+struct QuantizeDataType<paddle::platform::float16> {
+  using type = float;
+};
 template <typename T>
 __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
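The QuantizeDataType trait added above maps the storage type to the type the arithmetic should run in: float16 values are promoted to float, everything else is left alone. A rough NumPy analogue of that idea, using a hypothetical compute_dtype helper:

import numpy as np

def compute_dtype(storage_dtype):
    # Hypothetical analogue of QuantizeDataType: float16 tensors are
    # promoted to float32 for the arithmetic, other dtypes are kept as-is.
    return np.float32 if storage_dtype == np.float16 else storage_dtype

x = np.random.randn(1024).astype(np.float16)
abs_max = np.abs(x.astype(compute_dtype(x.dtype))).max()   # reduce in float32
print(abs_max.dtype)                                        # float32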
@@ -87,10 +98,12 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
   int tid = threadIdx.x;
   int channel_size = n / c;
   const T* in_c = in + blockIdx.x * channel_size;
-  extern __shared__ T shared_max_data[];
+  extern __shared__ char* shared_max_data_tmp[];
+  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
   T local_max_data = T(0);
   for (int i = tid; i < channel_size; i += blockDim.x) {
-    T tmp = fabs(in_c[i]);
+    T tmp = static_cast<T>(
+        fabs(static_cast<typename QuantizeDataType<T>::type>(in_c[i])));
     if (tmp > local_max_data) {
       local_max_data = tmp;
     }
@@ -112,7 +125,8 @@ template <typename T>
 __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
                                                   const int cin, const int cout,
                                                   T* out) {
-  extern __shared__ T shared_max_data[];
+  extern __shared__ char* shared_max_data_tmp[];
+  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
   int cout_wh_size = n / cin;
   int wh_size = n / (cin * cout);
@@ -121,7 +135,8 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
   const T* in_current = in + tid * cout_wh_size + bid * wh_size;
   T local_max_data = T(0);
   for (int i = 0; i < wh_size; i++) {
-    T tmp = fabs(in_current[i]);
+    T tmp = static_cast<T>(
+        fabs(static_cast<typename QuantizeDataType<T>::type>(in_current[i])));
     if (tmp > local_max_data) {
       local_max_data = tmp;
     }
@@ -203,14 +218,18 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;
-  T s = scale[0];
-  T inv_s = inverse(s);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+  ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
+  ComputeDataType inv_s = inverse(s);
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
-    T x = in[i];
-    T v = x > s ? s : x;
+    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
+    ComputeDataType v = x > s ? s : x;
     v = v < -s ? -s : v;
-    v = bin_cnt * inv_s * v;
-    out[i] = round(v);
+    v = bin_cnt_t * inv_s * v;
+    out[i] = static_cast<T>(round(v));
   }
 }
@@ -221,17 +240,19 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;
-  T s = scale[0];
-  T inv_s = inverse(s);
-  T bin_cnt_t = static_cast<T>(bin_cnt);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+  ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
+  ComputeDataType inv_s = inverse(s);
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
-    T x = in[i];
+    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
     x = x > s ? s : x;
     x = x < -s ? -s : x;
     x = bin_cnt_t * inv_s * x;
-    x = static_cast<T>(round(static_cast<float>(x)));
-    out[i] = (x * s) / bin_cnt_t;
+    x = round(x);
+    out[i] = static_cast<T>((x * s) / bin_cnt_t);
   }
 }
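The two kernels above now clip, scale, and round in ComputeDataType (float when the tensor is float16) and cast back to T only when writing the output. A NumPy sketch of the same quantize-dequantize math, under the assumption that the scale is the tensor's abs-max:

import numpy as np

def clip_quant_dequant_ref(x, scale, bin_cnt):
    # Illustrative analogue of ClipAndQuantDequantKernel: promote to float32,
    # clip to [-scale, scale], quantize to bin_cnt levels, dequantize, and
    # cast back to the storage dtype only at the very end.
    xf = x.astype(np.float32)
    s = np.float32(scale)
    v = np.clip(xf, -s, s)
    q = np.round(bin_cnt / s * v)
    return (q * s / bin_cnt).astype(x.dtype)

x = np.random.randn(4, 3, 64, 64).astype(np.float16)
out = clip_quant_dequant_ref(x, np.abs(x).max(), 127)
print(np.abs(out.astype(np.float32) - x.astype(np.float32)).max())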
@@ -285,15 +306,18 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
   const T* in_c = in + blockIdx.x * channel_size;
   T* out_c = out + blockIdx.x * channel_size;
-  T s = scale[blockIdx.x];
-  T inv_s = inverse(s);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+  ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x]);
+  ComputeDataType inv_s = inverse(s);
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
   for (int64_t i = tid; i < channel_size; i += blockDim.x) {
-    T x = in_c[i];
-    T v = x > s ? s : x;
+    ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
+    ComputeDataType v = x > s ? s : x;
     v = v < -s ? -s : v;
-    v = bin_cnt * inv_s * v;
-    out_c[i] = round(v);
+    v = bin_cnt_t * inv_s * v;
+    out_c[i] = static_cast<T>(round(v));
   }
 }
@@ -303,14 +327,17 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
     const T* in, const T* scale, const int bin_cnt, const int64_t n,
     const int nScale, const int quant_stride, T* out) {
   int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
   for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) {
-    T s = scale[(i / quant_stride) % nScale];
-    T inv_s = inverse(s);
-    T x = in[i];
-    T v = x > s ? s : x;
+    ComputeDataType s =
+        static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]);
+    ComputeDataType inv_s = inverse(s);
+    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
+    ComputeDataType v = x > s ? s : x;
     v = v < -s ? -s : v;
-    v = bin_cnt * inv_s * v;
-    out[i] = round(v);
+    v = bin_cnt_t * inv_s * v;
+    out[i] = static_cast<T>(round(v));
   }
 }
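The channel-wise kernels above apply one scale per slice along the quantization axis. A small NumPy sketch of how such per-channel scales are obtained (the helper name is illustrative, not the test utility itself):

import numpy as np

def channel_wise_abs_max_ref(x, quant_axis):
    # Illustrative per-channel abs-max: reduce over every axis except
    # quant_axis, doing the reduction in float32 even for float16 input.
    reduce_axes = tuple(i for i in range(x.ndim) if i != quant_axis)
    return np.abs(x.astype(np.float32)).max(axis=reduce_axes)

x = np.random.randn(4, 3, 64, 64).astype(np.float16)
print(channel_wise_abs_max_ref(x, quant_axis=0).shape)   # (4,)
print(channel_wise_abs_max_ref(x, quant_axis=1).shape)   # (3,)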
@@ -376,7 +403,8 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
   scale_arr[idx] = cur;
   T max = last_scale[0];
   out_scale[0] = max < cur ? cur : max;
-  if (fabs(removed - max) < 1e-6) {
+  if (fabs(static_cast<typename QuantizeDataType<T>::type>(removed - max)) <
+      1e-6) {
     need_find_max[0] = 1;
     out_size[0] = it > window_size ? window_size : it;
   } else {
......
@@ -18,6 +18,7 @@ import unittest
 import numpy as np
 import math
 from op_test import OpTest
+import paddle.fluid.core as core
 def quantize_max_abs(x, max_range):
@@ -76,22 +77,25 @@ def channel_wise_dequantize_max_abs(x,
 class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
     def set_args(self):
         self.quant_bits = [8, 8]
-        self.data_type = "float32"
         self.activation_scale = 0.7861
+    def set_dtype(self):
+        self.dtype = np.float32
     def setUp(self):
         self.set_args()
+        self.set_dtype()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
-        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        x = np.random.randn(4, 3, 64, 64).astype(self.dtype)
         yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], 1)
         ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, 1,
                                               self.activation_scale)
         self.inputs = {
             'X': yq,
-            'Scales': [("scales0", np.array(scales).astype(self.data_type)),
-                       ("scales1", np.array(
-                           [self.activation_scale]).astype(self.data_type))]
+            'Scales':
+            [("scales0", np.array(scales).astype(self.dtype)),
+             ("scales1", np.array([self.activation_scale]).astype(self.dtype))]
         }
         self.attrs = {'quant_bits': self.quant_bits}
         self.outputs = {'Out': ydq}
@@ -100,16 +104,28 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
         self.check_output()
+class TestFakeChannelWiseDequantizeMaxAbsOpTwoScalesFloat16(
+        TestFakeChannelWiseDequantizeMaxAbsOpTwoScales):
+    def set_dtype(self):
+        self.dtype = np.float16
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
 class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
     def set_args(self):
         self.quant_bits = [8]
-        self.data_type = "float32"
         self.quant_axis = 0
+    def set_dtype(self):
+        self.dtype = np.float32
     def setUp(self):
         self.set_args()
+        self.set_dtype()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
-        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        x = np.random.randn(4, 3, 64, 64).astype(self.dtype)
         yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0],
                                                    self.quant_axis)
         ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
@@ -117,7 +133,7 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
         self.inputs = {
             'X': yq,
-            'Scales': [("scales0", np.array(scales).astype(self.data_type))]
+            'Scales': [("scales0", np.array(scales).astype(self.dtype))]
         }
         self.attrs = {
             'quant_bits': self.quant_bits,
@@ -133,24 +149,44 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1(
         TestFakeChannelWiseDequantizeMaxAbsOpOneScale):
     def set_args(self):
         self.quant_bits = [8]
-        self.data_type = "float32"
         self.quant_axis = 1
+class TestFakeChannelWiseDequantizeMaxAbsOpOneScaleFloat16(
+        TestFakeChannelWiseDequantizeMaxAbsOpOneScale):
+    def set_dtype(self):
+        self.dtype = np.float16
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
+class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1Float16(
+        TestFakeChannelWiseDequantizeMaxAbsOpOneScale1):
+    def set_dtype(self):
+        self.dtype = np.float16
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
 class TestFakeDequantizeMaxAbsOp(OpTest):
     def set_args(self):
         self.num_bits = 8
         self.max_range = math.pow(2, self.num_bits - 1) - 1
-        self.data_type = "float32"
+    def set_dtype(self):
+        self.dtype = np.float32
     def setUp(self):
         self.set_args()
+        self.set_dtype()
         self.op_type = "fake_dequantize_max_abs"
-        x = np.random.randn(31, 65).astype(self.data_type)
+        x = np.random.randn(31, 65).astype(self.dtype)
         yq, scale = quantize_max_abs(x, self.max_range)
         ydq = dequantize_max_abs(yq, scale, self.max_range)
-        self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.data_type)}
+        self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.dtype)}
         self.attrs = {'max_range': self.max_range}
         self.outputs = {'Out': ydq}
@@ -159,17 +195,22 @@ class TestFakeDequantizeMaxAbsOp(OpTest):
 class TestFakeDequantizeMaxAbsOpDouble(TestFakeDequantizeMaxAbsOp):
     def set_args(self):
         self.num_bits = 8
         self.max_range = math.pow(2, self.num_bits - 1) - 1
-        self.data_type = "float64"
+    def set_dtype(self):
+        self.dtype = np.float64
 class TestFakeDequantizeMaxAbsOp5Bits(TestFakeDequantizeMaxAbsOp):
     def set_args(self):
         self.num_bits = 5
         self.max_range = math.pow(2, self.num_bits - 1) - 1
-        self.data_type = "float32"
+class TestFakeDequantizeMaxAbsOpFloat16(TestFakeDequantizeMaxAbsOp):
+    def set_dtype(self):
+        self.dtype = np.float16
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
 class TestChannelWiseDequantizeOp(OpTest):
......
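The float16 test variants above compare against the reference with atol=1e-2 rather than the default tolerance, which fits float16's limited resolution. A quick NumPy illustration:

import numpy as np

# float16 carries roughly 3 decimal digits of precision, so quantize/dequantize
# round trips cannot be expected to agree much more tightly than about 1e-2.
print(np.finfo(np.float16).eps)              # ~9.77e-04
print(np.float16(2048.0) + np.float16(1.0))  # 2048.0 -- the 1.0 is rounded away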