From 98a5af1aef00679544284e2beac2bd4ade6c1d0b Mon Sep 17 00:00:00 2001 From: whs Date: Tue, 6 Sep 2022 13:40:18 +0800 Subject: [PATCH] Fix DequantizeTwoScale kernel (#45632) --- .../fluid/operators/fake_dequantize_op.cu.h | 66 +++++++++++-------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/fake_dequantize_op.cu.h b/paddle/fluid/operators/fake_dequantize_op.cu.h index 161b87ea392..17b0d978716 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu.h +++ b/paddle/fluid/operators/fake_dequantize_op.cu.h @@ -88,16 +88,14 @@ __global__ void DequantizeTwoScale(const T* in, const T* scale_two, T max_range, int num, - int iter_size, - int channel, + int n_scales, + int quant_stride, T* out) { - int tid = threadIdx.x; - int channel_size = num / (iter_size * channel); - int scale_index = blockIdx.x % channel; - const T* in_c = in + blockIdx.x * channel_size; - T* out_c = out + blockIdx.x * channel_size; - for (int i = tid; i < channel_size; i += blockDim.x) { - out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range; + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + int scale_index = (i / quant_stride) % n_scales; + T s = scale_one[scale_index] * scale_two[0]; + out[i] = in[i] * s / max_range; } } @@ -115,6 +113,8 @@ struct ChannelDequantizeFunctor { const T* in_data = in->data(); T* out_data = out->mutable_data(dev_ctx.GetPlace()); if (scale_num == 1) { + // Dequantize inputs or weights before quantizable operators and after + // quantization operators. inputs --> quant -- > deqaunt --> conv2d --> int64_t num = in->numel(); const T* scale_factor = scales[0]->data(); int64_t block_size = std::min( @@ -140,25 +140,39 @@ struct ChannelDequantizeFunctor { quant_stride, out_data); } else if (scale_num == 2) { - // Not need to consider quant_axis - int num = in->numel(); - int iter_size = 1; - for (int i = 0; i < x_num_col_dims; i++) { - iter_size *= in->dims()[i]; - } - int channel = in->dims()[x_num_col_dims]; + // Dequantize activations after quantizable operators. + // inputs --> quant --> conv2d --> deqaunt --> + // Note 1: Not need to consider 'quant_axis'. Because 'quant_aixs' is the + // axis of weights to be quantized on while dequantization is applied on + // activations. Note 2: 'x_num_col_dims' is the axis of activations to be + // quantized on. `x_num_col_dims` is -1 for operator in ['matmul', + // 'matmul_v2', 'mul'] and is 1 for other operators. + int64_t num = in->numel(); + int n_scales = in->dims()[x_num_col_dims]; const T* scale_one = scales[0]->data(); const T* scale_two = scales[1]->data(); - int block = 1024; - int grid = iter_size * channel; - DequantizeTwoScale<<>>(in_data, - scale_one, - scale_two, - max_range, - num, - iter_size, - channel, - out_data); + + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), + static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + int quant_stride = 1; + for (int i = x_num_col_dims + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; + } + DequantizeTwoScale + <<>>(in_data, + scale_one, + scale_two, + max_range, + num, + n_scales, + quant_stride, + out_data); } } }; -- GitLab