Unverified commit 65aac811, authored by cc, committed by GitHub

Fix fake_quant error when cout > 1024, test=develop (#28603)

Parent 2cd10fc4
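Background for the hunks below: CUDA caps a thread block at 1024 threads on current devices, and the old launches took the block size directly from a channel dimension (block = in_dims[0]), so any channel count above the cap made the kernel launch fail with an invalid-configuration error. A minimal standalone repro of that failure mode (hypothetical test program, not part of the patch):

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void Noop() {}

    int main() {
      int max_threads = 0;
      cudaDeviceGetAttribute(&max_threads, cudaDevAttrMaxThreadsPerBlock, 0);
      printf("max threads per block: %d\n", max_threads);  // typically 1024
      Noop<<<1, max_threads + 1>>>();  // mimics block = channel count > 1024
      // Expected output: "invalid configuration argument"
      printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
      return 0;
    }

The patch therefore removes every launch parameter that scales with a channel count, as the hunks below show.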
@@ -62,14 +62,14 @@ __global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
                                              T max_range, const int num,
                                              const int cin, const int cout,
                                              T* out) {
-  int cout_wh_size = num / cin;
-  int wh_size = cout_wh_size / cout;
-  T s = scale[blockIdx.x];
-  const T* in_current = in + threadIdx.x * cout_wh_size + blockIdx.x * wh_size;
-  T* out_current = out + threadIdx.x * cout_wh_size + blockIdx.x * wh_size;
-  for (int i = 0; i < wh_size; i++) {
+  int bid = blockIdx.x;
+  T s = scale[bid % cout];
+  int wh_size = num / (cin * cout);
+  const T* in_current = in + bid * wh_size;
+  T* out_current = out + bid * wh_size;
+  for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
     out_current[i] = in_current[i] * s / max_range;
   }
 }
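What this hunk changes: the old kernel used threadIdx.x to pick the input-channel slice (so the block size had to equal cin) and each thread walked its W*H elements serially; the new kernel gives every (cin, cout) slice its own block and lets the block's threads stride over the W*H elements, so neither grid nor block depends on a channel count for correctness. A standalone sketch of the new indexing (hypothetical names, float only):

    // One block per (cin_i, cout_j) slice of wh_size elements; threads
    // block-stride over the slice, so blockDim.x is a free tuning choice.
    __global__ void DequantSliceSketch(const float* in, const float* scale,
                                       float max_range, int cout, int wh_size,
                                       float* out) {
      int bid = blockIdx.x;         // slice index in [0, cin * cout)
      float s = scale[bid % cout];  // quant_axis == 1: one scale per cout channel
      const float* src = in + bid * wh_size;
      float* dst = out + bid * wh_size;
      for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
        dst[i] = src[i] * s / max_range;
      }
    }

Because the slices are laid out with the cout axis fastest, bid % cout recovers the output channel that owns the scale.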
@@ -107,8 +107,8 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
           in_data, scale_factor, max_range, num, in_dims[0], out_data);
     } else if (quant_axis == 1) {
       // Dequantize weight of Cin * Cout * W * H
-      int grid = in_dims[1];
-      int block = in_dims[0];
+      int grid = in_dims[0] * in_dims[1];
+      int block = 1024;
       DequantizeOneScaleQuantAxis1<T><<<grid, block, 0, dev_ctx.stream()>>>(
           in_data, scale_factor, max_range, num, in_dims[0], in_dims[1],
           out_data);
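The matching launch: one block per W*H slice and a fixed 1024-thread block. grid.x can go up to 2^31 - 1 on compute capability 3.0 and later, so cin * cout blocks are safe where cin threads per block were not. A sketch of the launch arithmetic, reusing DequantSliceSketch from the previous snippet (d_in, d_scale, d_out, and max_range are assumed to exist):

    int cin = 128, cout = 2048, wh = 9;  // e.g. a 3x3 conv weight, cout > 1024
    int num = cin * cout * wh;
    int grid = cin * cout;               // one block per slice
    int block = 1024;                    // constant, never a tensor dimension
    DequantSliceSketch<<<grid, block>>>(d_in, d_scale, max_range, cout,
                                        num / grid, d_out);

When wh_size is smaller than the block, the surplus threads simply fail the loop condition and exit; that idle work is the trade-off for a launch configuration that can never exceed the hardware cap.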
@@ -131,7 +131,7 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
     }
     __syncthreads();
   }
-  if (tid == 0) {
+  if (tid == 0 && shared_max_data[0] > out[bid]) {
     out[bid] = shared_max_data[0];
   }
 }
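The extra condition here pairs with the chunked launches in the next hunk: the same out[bid] is now written by several kernel launches per channel, so thread 0 must fold its block's maximum into the running value instead of overwriting it (out is zeroed once before the first launch). A sketch of the pattern with a simplified kernel (hypothetical, not the patch's kernel; assumes blockDim.x is a power of two and rows <= blockDim.x):

    // Each block folds `rows` values for channel bid into a running maximum.
    // out[] must be zero-initialized (cudaMemset) before the first launch;
    // launches on one stream serialize, so the guarded write does not race.
    __global__ void BlockAbsMaxMerge(const float* in, int rows, float* out) {
      extern __shared__ float smem[];
      int tid = threadIdx.x, bid = blockIdx.x;
      smem[tid] = (tid < rows) ? fabsf(in[bid * rows + tid]) : 0.0f;
      __syncthreads();
      for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // tree reduction
        if (tid < s) smem[tid] = fmaxf(smem[tid], smem[tid + s]);
        __syncthreads();
      }
      if (tid == 0 && smem[0] > out[bid]) {  // merge, do not overwrite
        out[bid] = smem[0];
      }
    }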
@@ -148,20 +148,36 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
                           quant_axis));
     const int num = in_tensor.numel();
     auto in_dims = in_tensor.dims();
-    int channel = in_dims[quant_axis];
     const T* in_data = in_tensor.data<T>();
     if (quant_axis == 0) {
-      int grid = channel;
+      int cout = in_dims[0];
+      int grid = cout;
       int block = 1024;
       FindChannelAbsMaxKernelQuantAxis0<
           T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
-          in_data, num, channel, out_abs_max);
+          in_data, num, cout, out_abs_max);
     } else if (quant_axis == 1) {
-      int grid = in_dims[1];
-      int block = in_dims[0];
-      FindChannelAbsMaxKernelQuantAxis1<
-          T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
-          in_data, num, in_dims[0], in_dims[1], out_abs_max);
+      int cin = in_dims[0];
+      int cout = in_dims[1];
+      int grid = cout;
+      int max_threads = 1024;
+
+      cudaMemset(out_abs_max, 0, sizeof(T) * cout);
+
+      for (int i = 0; i < cin / max_threads; i++) {
+        int block = max_threads;
+        FindChannelAbsMaxKernelQuantAxis1<
+            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
+            in_data, num, cin, cout, out_abs_max);
+        in_data += num / cin;
+      }
+
+      int block = cin % max_threads;
+      if (block > 0) {
+        FindChannelAbsMaxKernelQuantAxis1<
+            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
+            in_data, num, in_dims[0], in_dims[1], out_abs_max);
+      }
     }
   }
 };
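The quant_axis == 1 path now reduces cin in chunks of at most 1024 threads: floor(cin / 1024) full launches plus one tail launch of cin % 1024 threads, all merging into the same out_abs_max buffer, which is why the cudaMemset and the guarded write in the previous hunk are needed. A simplified host-side sketch of the chunking; it assumes a hypothetical kernel in which thread tid reduces row tid of the current chunk, so each pass advances the pointer by a whole chunk of rows (the patch itself advances by num / cin per pass, whose meaning depends on the kernel's row indexing, not shown in this diff):

    #include <cuda_runtime.h>

    __global__ void FindAbsMaxChunk(const float* in, int num, int cin,
                                    int cout, float* out);  // assumed kernel

    void LaunchChunkedAbsMax(const float* in_data, int num, int cin, int cout,
                             float* out_abs_max, cudaStream_t stream) {
      const int max_threads = 1024;  // per-block thread cap
      // Running maxima start at zero so the first merge always wins.
      cudaMemsetAsync(out_abs_max, 0, sizeof(float) * cout, stream);
      int row = num / cin;           // elements per cin row (cout * W * H)
      for (int i = 0; i < cin / max_threads; i++) {
        FindAbsMaxChunk<<<cout, max_threads, max_threads * sizeof(float),
                          stream>>>(in_data, num, cin, cout, out_abs_max);
        in_data += static_cast<long long>(row) * max_threads;  // next chunk
      }
      int tail = cin % max_threads;  // leftover rows get a smaller block
      if (tail > 0) {
        FindAbsMaxChunk<<<cout, tail, tail * sizeof(float), stream>>>(
            in_data, num, cin, cout, out_abs_max);
      }
    }

The design point of the hunk is that the block size is now either the constant 1024 or the remainder cin % 1024, so no dimension-derived value above the hardware cap can ever reach the launch configuration.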