Unverified · Commit 98a5af1a authored by whs, committed by GitHub

Fix DequantizeTwoScale kernel (#45632)

Parent: a6476418
@@ -88,16 +88,14 @@ __global__ void DequantizeTwoScale(const T* in,
                                    const T* scale_two,
                                    T max_range,
                                    int num,
-                                   int iter_size,
-                                   int channel,
+                                   int n_scales,
+                                   int quant_stride,
                                    T* out) {
-  int tid = threadIdx.x;
-  int channel_size = num / (iter_size * channel);
-  int scale_index = blockIdx.x % channel;
-  const T* in_c = in + blockIdx.x * channel_size;
-  T* out_c = out + blockIdx.x * channel_size;
-  for (int i = tid; i < channel_size; i += blockDim.x) {
-    out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range;
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
+    int scale_index = (i / quant_stride) % n_scales;
+    T s = scale_one[scale_index] * scale_two[0];
+    out[i] = in[i] * s / max_range;
   }
 }
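The fix replaces the old per-block indexing, where each block owned one channel-sized slice and derived its scale from `blockIdx.x`, with a grid-stride loop that derives the scale directly from the flat element index via `(i / quant_stride) % n_scales`. Correctness no longer depends on the launch shape, and `int64_t` indexing covers tensors with more than `INT_MAX` elements. The host-side sketch below illustrates that index-to-scale mapping; the shape, the `x_num_col_dims` value, and the `main()` harness are illustrative and not part of the patch.

```cpp
#include <cstdint>
#include <cstdio>

// Host-side reference for the fixed index-to-scale mapping: in a row-major
// tensor, every run of quant_stride consecutive elements shares one scale,
// and the pattern repeats every n_scales runs.
int main() {
  // Hypothetical activation shape [2, 3, 2, 2], quantized on axis 1 (NCHW).
  const int64_t dims[4] = {2, 3, 2, 2};
  const int x_num_col_dims = 1;
  const int n_scales = static_cast<int>(dims[x_num_col_dims]);  // 3 channels
  int64_t quant_stride = 1;  // product of the dims after the quantized axis
  for (int i = x_num_col_dims + 1; i < 4; ++i) quant_stride *= dims[i];
  int64_t num = 1;
  for (int i = 0; i < 4; ++i) num *= dims[i];
  // Elements 0..3 -> scale 0, 4..7 -> scale 1, 8..11 -> scale 2, then the
  // cycle restarts for the second batch element.
  for (int64_t i = 0; i < num; ++i) {
    int scale_index = static_cast<int>((i / quant_stride) % n_scales);
    std::printf("flat index %2lld -> scale %d\n",
                static_cast<long long>(i), scale_index);
  }
  return 0;
}
```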
@@ -115,6 +113,8 @@ struct ChannelDequantizeFunctor<phi::GPUContext, T> {
     const T* in_data = in->data<T>();
     T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
     if (scale_num == 1) {
+      // Dequantize inputs or weights before quantizable operators and after
+      // quantization operators: inputs --> quant --> dequant --> conv2d -->
       int64_t num = in->numel();
       const T* scale_factor = scales[0]->data<T>();
       int64_t block_size = std::min(
@@ -140,25 +140,39 @@ struct ChannelDequantizeFunctor<phi::GPUContext, T> {
                                                            quant_stride,
                                                            out_data);
     } else if (scale_num == 2) {
-      // Not need to consider quant_axis
-      int num = in->numel();
-      int iter_size = 1;
-      for (int i = 0; i < x_num_col_dims; i++) {
-        iter_size *= in->dims()[i];
-      }
-      int channel = in->dims()[x_num_col_dims];
+      // Dequantize activations after quantizable operators:
+      // inputs --> quant --> conv2d --> dequant -->
+      // Note 1: No need to consider 'quant_axis', because 'quant_axis' is the
+      // axis of weights to be quantized on, while dequantization here is
+      // applied on activations.
+      // Note 2: 'x_num_col_dims' is the axis of activations to be quantized
+      // on. 'x_num_col_dims' is -1 for operators in ['matmul', 'matmul_v2',
+      // 'mul'] and is 1 for other operators.
+      int64_t num = in->numel();
+      int n_scales = in->dims()[x_num_col_dims];
       const T* scale_one = scales[0]->data<T>();
       const T* scale_two = scales[1]->data<T>();
-      int block = 1024;
-      int grid = iter_size * channel;
-      DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(in_data,
-                                                                  scale_one,
-                                                                  scale_two,
-                                                                  max_range,
-                                                                  num,
-                                                                  iter_size,
-                                                                  channel,
-                                                                  out_data);
+      int64_t block_size = std::min(
+          num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
+      int64_t max_threads =
+          dev_ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
+      const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
+                                          static_cast<int64_t>(1));
+      const int64_t grid_size =
+          std::min(max_blocks, (num + block_size - 1) / block_size);
+      int quant_stride = 1;
+      for (int i = x_num_col_dims + 1; i < in_dims.size(); i++) {
+        quant_stride *= in_dims[i];
+      }
+      DequantizeTwoScale<T>
+          <<<grid_size, block_size, 0, dev_ctx.stream()>>>(in_data,
+                                                           scale_one,
+                                                           scale_two,
+                                                           max_range,
+                                                           num,
+                                                           n_scales,
+                                                           quant_stride,
+                                                           out_data);
     }
   }
 };
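For readers outside the Paddle tree, the launch sizing used in this branch can be reproduced with raw CUDA device queries. Below is a minimal standalone sketch, assuming float data; `DequantizeTwoScaleRef` and `LaunchDequantize` are illustrative names, and the `cudaDeviceProp` fields stand in for Paddle's `dev_ctx.GetMaxThreadsPerBlock()` and `dev_ctx.GetMaxPhysicalThreadCount()` helpers.

```cpp
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

__global__ void DequantizeTwoScaleRef(const float* in, const float* scale_one,
                                      const float* scale_two, float max_range,
                                      int64_t num, int n_scales,
                                      int quant_stride, float* out) {
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
    int scale_index = (i / quant_stride) % n_scales;
    out[i] = in[i] * scale_one[scale_index] * scale_two[0] / max_range;
  }
}

// Launch-size computation mirroring the diff: cap the block at a quarter of
// maxThreadsPerBlock and the grid at the device's resident-thread capacity,
// relying on the grid-stride loop to cover any remaining elements.
void LaunchDequantize(const float* in, const float* s1, const float* s2,
                      float max_range, int64_t num, int n_scales,
                      int quant_stride, float* out, cudaStream_t stream) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  int64_t block_size =
      std::min(num, static_cast<int64_t>(prop.maxThreadsPerBlock / 4));
  int64_t max_threads = static_cast<int64_t>(prop.multiProcessorCount) *
                        prop.maxThreadsPerMultiProcessor;  // SM * threads/SM
  int64_t max_blocks = std::max((max_threads - 1) / block_size + 1,
                                static_cast<int64_t>(1));
  int64_t grid_size =
      std::min(max_blocks, (num + block_size - 1) / block_size);
  DequantizeTwoScaleRef<<<grid_size, block_size, 0, stream>>>(
      in, s1, s2, max_range, num, n_scales, quant_stride, out);
}
```

Capping `grid_size` at the device's resident-thread capacity launches just enough blocks to saturate the GPU; the grid-stride loop then absorbs any leftover elements, so the same configuration stays valid for any `num`.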