未验证 提交 98a5af1a 编写于 作者: W whs 提交者: GitHub

Fix DequantizeTwoScale kernel (#45632)

上级 a6476418
...@@ -88,16 +88,14 @@ __global__ void DequantizeTwoScale(const T* in, ...@@ -88,16 +88,14 @@ __global__ void DequantizeTwoScale(const T* in,
const T* scale_two, const T* scale_two,
T max_range, T max_range,
int num, int num,
int iter_size, int n_scales,
int channel, int quant_stride,
T* out) { T* out) {
int tid = threadIdx.x; int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
int channel_size = num / (iter_size * channel); for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
int scale_index = blockIdx.x % channel; int scale_index = (i / quant_stride) % n_scales;
const T* in_c = in + blockIdx.x * channel_size; T s = scale_one[scale_index] * scale_two[0];
T* out_c = out + blockIdx.x * channel_size; out[i] = in[i] * s / max_range;
for (int i = tid; i < channel_size; i += blockDim.x) {
out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range;
} }
} }
...@@ -115,6 +113,8 @@ struct ChannelDequantizeFunctor<phi::GPUContext, T> { ...@@ -115,6 +113,8 @@ struct ChannelDequantizeFunctor<phi::GPUContext, T> {
const T* in_data = in->data<T>(); const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace()); T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (scale_num == 1) { if (scale_num == 1) {
// Dequantize inputs or weights before quantizable operators and after
// quantization operators. inputs --> quant -- > deqaunt --> conv2d -->
int64_t num = in->numel(); int64_t num = in->numel();
const T* scale_factor = scales[0]->data<T>(); const T* scale_factor = scales[0]->data<T>();
int64_t block_size = std::min( int64_t block_size = std::min(
...@@ -140,25 +140,39 @@ struct ChannelDequantizeFunctor<phi::GPUContext, T> { ...@@ -140,25 +140,39 @@ struct ChannelDequantizeFunctor<phi::GPUContext, T> {
quant_stride, quant_stride,
out_data); out_data);
} else if (scale_num == 2) { } else if (scale_num == 2) {
// Not need to consider quant_axis // Dequantize activations after quantizable operators.
int num = in->numel(); // inputs --> quant --> conv2d --> deqaunt -->
int iter_size = 1; // Note 1: Not need to consider 'quant_axis'. Because 'quant_aixs' is the
for (int i = 0; i < x_num_col_dims; i++) { // axis of weights to be quantized on while dequantization is applied on
iter_size *= in->dims()[i]; // activations. Note 2: 'x_num_col_dims' is the axis of activations to be
} // quantized on. `x_num_col_dims` is -1 for operator in ['matmul',
int channel = in->dims()[x_num_col_dims]; // 'matmul_v2', 'mul'] and is 1 for other operators.
int64_t num = in->numel();
int n_scales = in->dims()[x_num_col_dims];
const T* scale_one = scales[0]->data<T>(); const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>(); const T* scale_two = scales[1]->data<T>();
int block = 1024;
int grid = iter_size * channel; int64_t block_size = std::min(
DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
scale_one, int64_t max_threads =
scale_two, dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM
max_range, const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
num, static_cast<int64_t>(1));
iter_size, const int64_t grid_size =
channel, std::min(max_blocks, (num + block_size - 1) / block_size);
out_data); int quant_stride = 1;
for (int i = x_num_col_dims + 1; i < in_dims.size(); i++) {
quant_stride *= in_dims[i];
}
DequantizeTwoScale<T>
<<<grid_size, block_size, 0, dev_ctx.stream()>>>(in_data,
scale_one,
scale_two,
max_range,
num,
n_scales,
quant_stride,
out_data);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册