From 9c42672820a630d405c60e65e6952161ba8d7f99 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Tue, 27 Sep 2022 17:12:03 +0800 Subject: [PATCH] speedup ChannelClipAndQuantDequantKernelQuantAxis1 kernel (#46471) --- paddle/fluid/operators/fake_quantize_op.cu.h | 92 ++++++++++---------- 1 file changed, 47 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/operators/fake_quantize_op.cu.h b/paddle/fluid/operators/fake_quantize_op.cu.h index 22ba8254cd..9c71cce770 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu.h +++ b/paddle/fluid/operators/fake_quantize_op.cu.h @@ -590,20 +590,16 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in, const T *scale, const int bin_cnt, const int round_type, - const int n, - const int c, + const int wh_size, + const int num, + const int cout, T *out) { - int tid = threadIdx.x; - - int channel_size = n / c; - const T *in_c = in + blockIdx.x * channel_size; - T *out_c = out + blockIdx.x * channel_size; - - T s = scale[blockIdx.x]; - T inv_s = inverse(s); + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - for (int i = tid; i < channel_size; i += blockDim.x) { - T x = in_c[i]; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + T s = scale[(i / wh_size) % cout]; + T inv_s = inverse(s); + T x = in[i]; if (round_type == 0) { x = bin_cnt * inv_s * x; x = roundWithTiesToEven(x); @@ -611,12 +607,12 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in, T min_bound = -bin_cnt - static_cast(1); x = x > max_bound ? max_bound : x; x = x < min_bound ? min_bound : x; - out_c[i] = (x * s) / bin_cnt; + out[i] = (x * s) / bin_cnt; } else { T v = x > s ? s : x; v = v < -s ? -s : v; v = bin_cnt * inv_s * v; - out_c[i] = round(v) * s / bin_cnt; + out[i] = round(v) * s / bin_cnt; } } } @@ -627,19 +623,16 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in, const T *scale, const int bin_cnt, const int round_type, - const int n, - const int cin, + const int wh_size, + const int num, const int cout, T *out) { - T s = scale[blockIdx.x % cout]; - T inv_s = inverse(s); - - int wh_size = n / (cin * cout); - const T *in_c = in + blockIdx.x * wh_size; - T *out_c = out + blockIdx.x * wh_size; + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { - T x = in_c[i]; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + T s = scale[(i / wh_size) % cout]; + T inv_s = inverse(s); + T x = in[i]; if (round_type == 0) { x = bin_cnt * inv_s * x; x = roundWithTiesToEven(x); @@ -647,12 +640,12 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in, T min_bound = -bin_cnt - static_cast(1); x = x > max_bound ? max_bound : x; x = x < min_bound ? min_bound : x; - out_c[i] = (x * s) / bin_cnt; + out[i] = (x * s) / bin_cnt; } else { T v = x > s ? s : x; v = v < -s ? -s : v; v = bin_cnt * inv_s * v; - out_c[i] = round(v) * s / bin_cnt; + out[i] = round(v) * s / bin_cnt; } } } @@ -682,30 +675,39 @@ struct ChannelClipFakeQuantDequantFunctor { const T *scale_data = scale.data(); T *out_data = out->mutable_data(ctx.GetPlace()); + int64_t block_size = + std::min(static_cast(num), + static_cast(ctx.GetMaxThreadsPerBlock() / 4)); + + int64_t max_threads = ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + if (quant_axis == 0) { - int grid = in_dims[0]; - int block = 1024; + const int window_size = num / in_dims[0]; ChannelClipAndQuantDequantKernelQuantAxis0 - <<>>(in_data, - scale_data, - bin_cnt, - round_type, - num, - in_dims[0], - out_data); + <<>>(in_data, + scale_data, + bin_cnt, + round_type, + window_size, + num, + in_dims[0], + out_data); } else if (quant_axis == 1) { - int grid = in_dims[0] * in_dims[1]; - int block = 1024; + const int window_size = num / (in_dims[0] * in_dims[1]); ChannelClipAndQuantDequantKernelQuantAxis1 - <<>>(in_data, - scale_data, - bin_cnt, - round_type, - num, - in_dims[0], - in_dims[1], - out_data); + <<>>(in_data, + scale_data, + bin_cnt, + round_type, + window_size, + num, + in_dims[1], + out_data); } } }; -- GitLab