From 8991e9ae2664305bbf6b276a40724b03f0557609 Mon Sep 17 00:00:00 2001 From: whs Date: Wed, 23 Mar 2022 13:49:12 +0800 Subject: [PATCH] Fix quant and dequant cuda kernels when quant_axis==1 (#40772) --- paddle/fluid/operators/fake_dequantize_op.cu | 50 ++++++++++------- paddle/fluid/operators/fake_quantize_op.cu | 58 +++++++++++--------- 2 files changed, 62 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index c88a8fe196..c0ec44909a 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -58,19 +58,15 @@ __global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale, } template -__global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale, - T max_range, const int num, - const int cin, const int cout, - T* out) { - int bid = blockIdx.x; - T s = scale[bid % cout]; - - int wh_size = num / (cin * cout); - const T* in_current = in + bid * wh_size; - T* out_current = out + bid * wh_size; - - for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { - out_current[i] = in_current[i] * s / max_range; +__global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale, + const T max_range, + const int64_t num, + const int n_scales, + const int quant_stride, T* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + T s = scale[(i / quant_stride) % n_scales]; + out[i] = in[i] * s / max_range; } } @@ -98,20 +94,32 @@ struct ChannelDequantizeFunctor { const T* in_data = in->data(); T* out_data = out->mutable_data(dev_ctx.GetPlace()); if (scale_num == 1) { - int num = in->numel(); + int64_t num = in->numel(); const T* scale_factor = scales[0]->data(); if (quant_axis == 0) { int grid = in_dims[0]; int block = 1024; DequantizeOneScaleQuantAxis0<<>>( in_data, scale_factor, max_range, num, in_dims[0], out_data); - } else if (quant_axis == 1) { - // Dequantize weight of Cin * Cout * W * H - int grid = in_dims[0] * in_dims[1]; - int block = 1024; - DequantizeOneScaleQuantAxis1<<>>( - in_data, scale_factor, max_range, num, in_dims[0], in_dims[1], - out_data); + } else { + int quant_stride = 1; + for (int i = quant_axis + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; + } + + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = std::max( + ((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + + DequantizeOneScaleQuantAxisN< + T><<>>( + in_data, scale_factor, max_range, num, in_dims[quant_axis], + quant_stride, out_data); } } else if (scale_num == 2) { // Not need to consider quant_axis diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 70597be393..01384a6caf 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -273,18 +273,18 @@ struct ClipAndFakeQuantDequantFunctor { template __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, const int bin_cnt, - const int n, const int c, - T* out) { + const int64_t n, + const int c, T* out) { int tid = threadIdx.x; - int channel_size = n / c; + int64_t channel_size = n / c; const T* in_c = in + blockIdx.x * channel_size; T* out_c = out + blockIdx.x * channel_size; T s = scale[blockIdx.x]; T inv_s = inverse(s); - for (int i = tid; i < channel_size; i += blockDim.x) { + for (int64_t i = tid; i < channel_size; i += blockDim.x) { T x = in_c[i]; T v = x > s ? s : x; v = v < -s ? -s : v; @@ -293,25 +293,20 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, } } -// ChannelClipAndQuantKernel for quant_axis is 1 +// ChannelClipAndQuantKernel for quant_axis is N template -__global__ void ChannelClipAndQuantKernelQuantAxis1(const T* in, const T* scale, - const int bin_cnt, - const int n, const int cin, - const int cout, T* out) { - T s = scale[blockIdx.x % cout]; - T inv_s = inverse(s); - - int wh_size = n / (cin * cout); - const T* in_c = in + blockIdx.x * wh_size; - T* out_c = out + blockIdx.x * wh_size; - - for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { - T x = in_c[i]; +__global__ void ChannelClipAndQuantKernelQuantAxisN( + const T* in, const T* scale, const int bin_cnt, const int64_t n, + const int nScale, const int quant_stride, T* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) { + T s = scale[(i / quant_stride) % nScale]; + T inv_s = 1.0 / s; + T x = in[i]; T v = x > s ? s : x; v = v < -s ? -s : v; v = bin_cnt * inv_s * v; - out_c[i] = round(v); + out[i] = round(v); } } @@ -327,7 +322,7 @@ struct ChannelClipAndFakeQuantFunctor { "the received is %d", quant_axis)); - int num = in.numel(); + int64_t num = in.numel(); auto in_dims = in.dims(); const T* in_data = in.data(); const T* scale_data = scale.data(); @@ -338,11 +333,24 @@ struct ChannelClipAndFakeQuantFunctor { int block = 1024; ChannelClipAndQuantKernelQuantAxis0<<>>( in_data, scale_data, bin_cnt, num, in_dims[0], out_data); - } else if (quant_axis == 1) { - int grid = in_dims[0] * in_dims[1]; - int block = 1024; - ChannelClipAndQuantKernelQuantAxis1<<>>( - in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data); + } else { + int quant_stride = 1; + for (int i = quant_axis + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; + } + int64_t block_size = + std::min(num, static_cast(ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), + static_cast(1)); + + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + + ChannelClipAndQuantKernelQuantAxisN<<>>( + in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride, + out_data); } } }; -- GitLab