From 65aac81191c00dcbe79cd191f595be942a8bb749 Mon Sep 17 00:00:00 2001
From: cc <52520497+juncaipeng@users.noreply.github.com>
Date: Tue, 17 Nov 2020 10:07:30 +0800
Subject: [PATCH] Fix fake_quant error when cout > 1024, test=develop (#28603)

---
 paddle/fluid/operators/fake_dequantize_op.cu | 16 ++++-----
 paddle/fluid/operators/fake_quantize_op.cu   | 34 ++++++++++++++------
 2 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu
index 54a92b055a..a89c430c7a 100644
--- a/paddle/fluid/operators/fake_dequantize_op.cu
+++ b/paddle/fluid/operators/fake_dequantize_op.cu
@@ -62,14 +62,14 @@ __global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
                                              T max_range, const int num,
                                              const int cin, const int cout,
                                              T* out) {
-  int cout_wh_size = num / cin;
-  int wh_size = cout_wh_size / cout;
+  int bid = blockIdx.x;
+  T s = scale[bid % cout];
 
-  T s = scale[blockIdx.x];
-  const T* in_current = in + threadIdx.x * cout_wh_size + blockIdx.x * wh_size;
-  T* out_current = out + threadIdx.x * cout_wh_size + blockIdx.x * wh_size;
+  int wh_size = num / (cin * cout);
+  const T* in_current = in + bid * wh_size;
+  T* out_current = out + bid * wh_size;
 
-  for (int i = 0; i < wh_size; i++) {
+  for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
     out_current[i] = in_current[i] * s / max_range;
   }
 }
@@ -107,8 +107,8 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
           in_data, scale_factor, max_range, num, in_dims[0], out_data);
     } else if (quant_axis == 1) {
       // Dequantize weight of Cin * Cout * W * H
-      int grid = in_dims[1];
-      int block = in_dims[0];
+      int grid = in_dims[0] * in_dims[1];
+      int block = 1024;
       DequantizeOneScaleQuantAxis1<T><<<grid, block, 0, dev_ctx.stream()>>>(
           in_data, scale_factor, max_range, num, in_dims[0], in_dims[1],
           out_data);
diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu
index 8bc14dde86..26dcf8bf39 100644
--- a/paddle/fluid/operators/fake_quantize_op.cu
+++ b/paddle/fluid/operators/fake_quantize_op.cu
@@ -131,7 +131,7 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
     }
     __syncthreads();
   }
-  if (tid == 0) {
+  if (tid == 0 && shared_max_data[0] > out[bid]) {
     out[bid] = shared_max_data[0];
   }
 }
@@ -148,20 +148,36 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
                           quant_axis));
     const int num = in_tensor.numel();
     auto in_dims = in_tensor.dims();
-    int channel = in_dims[quant_axis];
     const T* in_data = in_tensor.data<T>();
     if (quant_axis == 0) {
-      int grid = channel;
+      int cout = in_dims[0];
+      int grid = cout;
       int block = 1024;
       FindChannelAbsMaxKernelQuantAxis0<
           T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
-          in_data, num, channel, out_abs_max);
+          in_data, num, cout, out_abs_max);
     } else if (quant_axis == 1) {
-      int grid = in_dims[1];
-      int block = in_dims[0];
-      FindChannelAbsMaxKernelQuantAxis1<
-          T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
-          in_data, num, in_dims[0], in_dims[1], out_abs_max);
+      int cin = in_dims[0];
+      int cout = in_dims[1];
+      int grid = cout;
+      int max_threads = 1024;
+
+      cudaMemset(out_abs_max, 0, sizeof(T) * cout);
+
+      for (int i = 0; i < cin / max_threads; i++) {
+        int block = max_threads;
+        FindChannelAbsMaxKernelQuantAxis1<
+            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
+            in_data, num, cin, cout, out_abs_max);
+        in_data += num / cin;
+      }
+
+      int block = cin % max_threads;
+      if (block > 0) {
+        FindChannelAbsMaxKernelQuantAxis1<
+            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
+            in_data, num, in_dims[0], in_dims[1], out_abs_max);
+      }
     }
   }
 };
--
GitLab
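
A note for readers, appended outside the patch itself: the root cause is CUDA's hard limit of 1024 threads per block. The old launches used a tensor dimension directly as the block size (block = in_dims[0], i.e. Cin), so any weight with Cin > 1024 produced an invalid kernel launch. The patch fixes this in two ways. The dequantize kernel now uses a fixed block of 1024 threads with a thread-stride loop over each W*H slice. The abs-max kernel is instead launched repeatedly over chunks of at most 1024 rows of Cin, each pass folding its block-reduced result into a running per-channel maximum; that is why the patch adds the cudaMemset that zeroes out_abs_max and the "shared_max_data[0] > out[bid]" guard, which turns a plain store into a max-merge across passes.

Below is a minimal, self-contained CUDA sketch of the thread-stride pattern the dequantize fix adopts. It is illustrative only; none of the identifiers (DequantSliceSketch and friends) come from Paddle.

#include <cstdio>
#include <cuda_runtime.h>

// One block per (cin, cout) slice; threads stride over the slice's W*H
// elements, so blockDim.x never needs to exceed the 1024-thread limit.
template <typename T>
__global__ void DequantSliceSketch(const T* in, const T* scale, T max_range,
                                   int wh_size, int cout, T* out) {
  int bid = blockIdx.x;     // enumerates all cin * cout slices
  T s = scale[bid % cout];  // per-output-channel scale, as in the patch
  const T* in_cur = in + bid * wh_size;
  T* out_cur = out + bid * wh_size;
  for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
    out_cur[i] = in_cur[i] * s / max_range;
  }
}

int main() {
  const int cin = 2048, cout = 4, wh = 9;  // cin > 1024 on purpose
  const int num = cin * cout * wh;
  float *in, *scale, *out;
  cudaMallocManaged(&in, num * sizeof(float));
  cudaMallocManaged(&scale, cout * sizeof(float));
  cudaMallocManaged(&out, num * sizeof(float));
  for (int i = 0; i < num; ++i) in[i] = 127.f;
  for (int c = 0; c < cout; ++c) scale[c] = 0.5f;
  // The pre-patch scheme would have requested cin = 2048 threads per block,
  // beyond the hardware limit; a fixed 1024-thread block always works.
  DequantSliceSketch<float><<<cin * cout, 1024>>>(in, scale, 127.f, wh, cout,
                                                  out);
  cudaDeviceSynchronize();
  printf("out[0] = %f\n", out[0]);  // expect 127 * 0.5 / 127 = 0.5
  return 0;
}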