未验证 提交 65aac811 作者: cc 提交者: GitHub

Fix fake_quant error when cout > 1024, test=develop (#28603)

上级 2cd10fc4
......@@ -62,14 +62,14 @@ __global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
T max_range, const int num,
const int cin, const int cout,
T* out) {
int cout_wh_size = num / cin;
int wh_size = cout_wh_size / cout;
int bid = blockIdx.x;
T s = scale[bid % cout];
T s = scale[blockIdx.x];
const T* in_current = in + threadIdx.x * cout_wh_size + blockIdx.x * wh_size;
T* out_current = out + threadIdx.x * cout_wh_size + blockIdx.x * wh_size;
int wh_size = num / (cin * cout);
const T* in_current = in + bid * wh_size;
T* out_current = out + bid * wh_size;
for (int i = 0; i < wh_size; i++) {
for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
out_current[i] = in_current[i] * s / max_range;
}
}
......@@ -107,8 +107,8 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
in_data, scale_factor, max_range, num, in_dims[0], out_data);
} else if (quant_axis == 1) {
// Dequantize weight of Cin * Cout * W * H
int grid = in_dims[1];
int block = in_dims[0];
int grid = in_dims[0] * in_dims[1];
int block = 1024;
DequantizeOneScaleQuantAxis1<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, in_dims[0], in_dims[1],
out_data);
......
......@@ -131,7 +131,7 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
}
__syncthreads();
}
if (tid == 0) {
if (tid == 0 && shared_max_data[0] > out[bid]) {
out[bid] = shared_max_data[0];
}
}
......@@ -148,22 +148,38 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
quant_axis));
const int num = in_tensor.numel();
auto in_dims = in_tensor.dims();
int channel = in_dims[quant_axis];
const T* in_data = in_tensor.data<T>();
if (quant_axis == 0) {
int grid = channel;
int cout = in_dims[0];
int grid = cout;
int block = 1024;
FindChannelAbsMaxKernelQuantAxis0<
T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, channel, out_abs_max);
in_data, num, cout, out_abs_max);
} else if (quant_axis == 1) {
int grid = in_dims[1];
int block = in_dims[0];
int cin = in_dims[0];
int cout = in_dims[1];
int grid = cout;
int max_threads = 1024;
cudaMemset(out_abs_max, 0, sizeof(T) * cout);
for (int i = 0; i < cin / max_threads; i++) {
int block = max_threads;
FindChannelAbsMaxKernelQuantAxis1<
T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, cin, cout, out_abs_max);
in_data += num / cin;
}
int block = cin % max_threads;
if (block > 0) {
FindChannelAbsMaxKernelQuantAxis1<
T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, in_dims[0], in_dims[1], out_abs_max);
}
}
}
};
template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册