未验证 提交 f5565494 编写于 作者: C ceci3 提交者: GitHub

speedup ChannelClipAndQuantDequantKernelQuantAxis1 kernel (#46471) (#46551)

上级 9cc3f69f
...@@ -590,20 +590,16 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in, ...@@ -590,20 +590,16 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in,
const T *scale, const T *scale,
const int bin_cnt, const int bin_cnt,
const int round_type, const int round_type,
const int n, const int wh_size,
const int c, const int num,
const int cout,
T *out) { T *out) {
int tid = threadIdx.x; int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
int channel_size = n / c;
const T *in_c = in + blockIdx.x * channel_size;
T *out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x]; for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
T s = scale[(i / wh_size) % cout];
T inv_s = inverse(s); T inv_s = inverse(s);
T x = in[i];
for (int i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i];
if (round_type == 0) { if (round_type == 0) {
x = bin_cnt * inv_s * x; x = bin_cnt * inv_s * x;
x = roundWithTiesToEven(x); x = roundWithTiesToEven(x);
...@@ -611,12 +607,12 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in, ...@@ -611,12 +607,12 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in,
T min_bound = -bin_cnt - static_cast<T>(1); T min_bound = -bin_cnt - static_cast<T>(1);
x = x > max_bound ? max_bound : x; x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x; x = x < min_bound ? min_bound : x;
out_c[i] = (x * s) / bin_cnt; out[i] = (x * s) / bin_cnt;
} else { } else {
T v = x > s ? s : x; T v = x > s ? s : x;
v = v < -s ? -s : v; v = v < -s ? -s : v;
v = bin_cnt * inv_s * v; v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt; out[i] = round(v) * s / bin_cnt;
} }
} }
} }
...@@ -627,19 +623,16 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in, ...@@ -627,19 +623,16 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in,
const T *scale, const T *scale,
const int bin_cnt, const int bin_cnt,
const int round_type, const int round_type,
const int n, const int wh_size,
const int cin, const int num,
const int cout, const int cout,
T *out) { T *out) {
T s = scale[blockIdx.x % cout]; int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
T inv_s = inverse(s);
int wh_size = n / (cin * cout);
const T *in_c = in + blockIdx.x * wh_size;
T *out_c = out + blockIdx.x * wh_size;
for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
T x = in_c[i]; T s = scale[(i / wh_size) % cout];
T inv_s = inverse(s);
T x = in[i];
if (round_type == 0) { if (round_type == 0) {
x = bin_cnt * inv_s * x; x = bin_cnt * inv_s * x;
x = roundWithTiesToEven(x); x = roundWithTiesToEven(x);
...@@ -647,12 +640,12 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in, ...@@ -647,12 +640,12 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in,
T min_bound = -bin_cnt - static_cast<T>(1); T min_bound = -bin_cnt - static_cast<T>(1);
x = x > max_bound ? max_bound : x; x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x; x = x < min_bound ? min_bound : x;
out_c[i] = (x * s) / bin_cnt; out[i] = (x * s) / bin_cnt;
} else { } else {
T v = x > s ? s : x; T v = x > s ? s : x;
v = v < -s ? -s : v; v = v < -s ? -s : v;
v = bin_cnt * inv_s * v; v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt; out[i] = round(v) * s / bin_cnt;
} }
} }
} }
...@@ -682,28 +675,37 @@ struct ChannelClipFakeQuantDequantFunctor<phi::GPUContext, T> { ...@@ -682,28 +675,37 @@ struct ChannelClipFakeQuantDequantFunctor<phi::GPUContext, T> {
const T *scale_data = scale.data<T>(); const T *scale_data = scale.data<T>();
T *out_data = out->mutable_data<T>(ctx.GetPlace()); T *out_data = out->mutable_data<T>(ctx.GetPlace());
int64_t block_size =
std::min(static_cast<int64_t>(num),
static_cast<int64_t>(ctx.GetMaxThreadsPerBlock() / 4));
int64_t max_threads = ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM
const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (num + block_size - 1) / block_size);
if (quant_axis == 0) { if (quant_axis == 0) {
int grid = in_dims[0]; const int window_size = num / in_dims[0];
int block = 1024;
ChannelClipAndQuantDequantKernelQuantAxis0<T> ChannelClipAndQuantDequantKernelQuantAxis0<T>
<<<grid, block, 0, ctx.stream()>>>(in_data, <<<grid_size, block_size, 0, ctx.stream()>>>(in_data,
scale_data, scale_data,
bin_cnt, bin_cnt,
round_type, round_type,
window_size,
num, num,
in_dims[0], in_dims[0],
out_data); out_data);
} else if (quant_axis == 1) { } else if (quant_axis == 1) {
int grid = in_dims[0] * in_dims[1]; const int window_size = num / (in_dims[0] * in_dims[1]);
int block = 1024;
ChannelClipAndQuantDequantKernelQuantAxis1<T> ChannelClipAndQuantDequantKernelQuantAxis1<T>
<<<grid, block, 0, ctx.stream()>>>(in_data, <<<grid_size, block_size, 0, ctx.stream()>>>(in_data,
scale_data, scale_data,
bin_cnt, bin_cnt,
round_type, round_type,
window_size,
num, num,
in_dims[0],
in_dims[1], in_dims[1],
out_data); out_data);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册