未验证 提交 8991e9ae 编写于 作者: W whs 提交者: GitHub

Fix quant and dequant cuda kernels when quant_axis==1 (#40772)

上级 319f95d0
...@@ -58,19 +58,15 @@ __global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale, ...@@ -58,19 +58,15 @@ __global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale,
} }
// Dequantizes `num` elements of `in` into `out` for an arbitrary quant_axis.
// Element i belongs to scale channel (i / quant_stride) % n_scales, where
// quant_stride is the product of the tensor dimensions after quant_axis
// (computed by the caller).  out[i] = in[i] * scale[channel] / max_range.
// Grid-stride loop: correct for any <<<grid, block>>> configuration and for
// num > gridDim.x * blockDim.x.
template <typename T>
__global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale,
                                             const T max_range,
                                             const int64_t num,
                                             const int n_scales,
                                             const int quant_stride, T* out) {
  const int64_t start = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = start; i < num; i += step) {
    const T channel_scale = scale[(i / quant_stride) % n_scales];
    out[i] = in[i] * channel_scale / max_range;
  }
}
...@@ -98,20 +94,32 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> { ...@@ -98,20 +94,32 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
const T* in_data = in->data<T>(); const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace()); T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (scale_num == 1) { if (scale_num == 1) {
int num = in->numel(); int64_t num = in->numel();
const T* scale_factor = scales[0]->data<T>(); const T* scale_factor = scales[0]->data<T>();
if (quant_axis == 0) { if (quant_axis == 0) {
int grid = in_dims[0]; int grid = in_dims[0];
int block = 1024; int block = 1024;
DequantizeOneScaleQuantAxis0<T><<<grid, block, 0, dev_ctx.stream()>>>( DequantizeOneScaleQuantAxis0<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, in_dims[0], out_data); in_data, scale_factor, max_range, num, in_dims[0], out_data);
} else if (quant_axis == 1) { } else {
// Dequantize weight of Cin * Cout * W * H int quant_stride = 1;
int grid = in_dims[0] * in_dims[1]; for (int i = quant_axis + 1; i < in_dims.size(); i++) {
int block = 1024; quant_stride *= in_dims[i];
DequantizeOneScaleQuantAxis1<T><<<grid, block, 0, dev_ctx.stream()>>>( }
in_data, scale_factor, max_range, num, in_dims[0], in_dims[1],
out_data); int64_t block_size = std::min(
num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
int64_t max_threads =
dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM
const int64_t max_blocks = std::max(
((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (num + block_size - 1) / block_size);
DequantizeOneScaleQuantAxisN<
T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, in_dims[quant_axis],
quant_stride, out_data);
} }
} else if (scale_num == 2) { } else if (scale_num == 2) {
// Not need to consider quant_axis // Not need to consider quant_axis
......
...@@ -273,18 +273,18 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> { ...@@ -273,18 +273,18 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
template <typename T> template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
const int bin_cnt, const int bin_cnt,
const int n, const int c, const int64_t n,
T* out) { const int c, T* out) {
int tid = threadIdx.x; int tid = threadIdx.x;
int channel_size = n / c; int64_t channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size; const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size; T* out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x]; T s = scale[blockIdx.x];
T inv_s = inverse(s); T inv_s = inverse(s);
for (int i = tid; i < channel_size; i += blockDim.x) { for (int64_t i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i]; T x = in_c[i];
T v = x > s ? s : x; T v = x > s ? s : x;
v = v < -s ? -s : v; v = v < -s ? -s : v;
...@@ -293,25 +293,20 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, ...@@ -293,25 +293,20 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
} }
} }
// ChannelClipAndQuantKernel for an arbitrary quant_axis.
// Element i belongs to scale channel (i / quant_stride) % nScale, where
// quant_stride is the product of the tensor dimensions after quant_axis
// (computed by the caller).  Each element is clipped to [-s, s] and then
// quantized: out[i] = round(bin_cnt * in_clipped[i] / s).
// Grid-stride loop: correct for any <<<grid, block>>> configuration and for
// n > gridDim.x * blockDim.x.
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxisN(
    const T* in, const T* scale, const int bin_cnt, const int64_t n,
    const int nScale, const int quant_stride, T* out) {
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) {
    T s = scale[(i / quant_stride) % nScale];
    // Compute the reciprocal in T: the bare literal 1.0 is double, so
    // `1.0 / s` would silently promote float kernels to fp64 arithmetic.
    T inv_s = static_cast<T>(1.0) / s;
    T x = in[i];
    // Clip to [-s, s], then scale onto the bin grid and round.
    T v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt * inv_s * v;
    out[i] = round(v);
  }
}
...@@ -327,7 +322,7 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> { ...@@ -327,7 +322,7 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
"the received is %d", "the received is %d",
quant_axis)); quant_axis));
int num = in.numel(); int64_t num = in.numel();
auto in_dims = in.dims(); auto in_dims = in.dims();
const T* in_data = in.data<T>(); const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>(); const T* scale_data = scale.data<T>();
...@@ -338,11 +333,24 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> { ...@@ -338,11 +333,24 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
int block = 1024; int block = 1024;
ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>( ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, in_dims[0], out_data); in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
} else if (quant_axis == 1) { } else {
int grid = in_dims[0] * in_dims[1]; int quant_stride = 1;
int block = 1024; for (int i = quant_axis + 1; i < in_dims.size(); i++) {
ChannelClipAndQuantKernelQuantAxis1<T><<<grid, block, 0, ctx.stream()>>>( quant_stride *= in_dims[i];
in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data); }
int64_t block_size =
std::min(num, static_cast<int64_t>(ctx.GetMaxThreadsPerBlock() / 4));
int64_t max_threads =
ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM
const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (num + block_size - 1) / block_size);
ChannelClipAndQuantKernelQuantAxisN<T><<<grid_size, block_size>>>(
in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride,
out_data);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册