Unverified commit 827b6a0e, authored by Leo Chen and committed by GitHub

Improve the performance of fake quantize OP (#40491)

* Move the computation of moving average scale to device

* Use register to save local maximum in a thread
Parent 3082ed46
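The second commit-message bullet ("Use register to save local maximum in a thread") refers to the three abs-max kernels in the diff below: instead of accumulating each thread's running maximum directly in shared memory, the per-thread scan now keeps it in a register and touches shared memory only once before the block-wide tree reduction. Below is a minimal standalone sketch of that pattern; the kernel name, launch configuration, and host harness are illustrative only and are not part of this commit.

```cuda
#include <cmath>
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

// Sketch of the register-accumulator pattern: each thread keeps its running
// maximum in a register and writes shared memory exactly once, then the block
// performs the usual tree reduction (as the kernels in the diff do).
__global__ void AbsMaxSketch(const float* in, int n, float* out) {
  extern __shared__ float smem[];
  int tid = threadIdx.x;

  float local_max = 0.0f;  // register accumulator instead of smem[tid]
  for (int i = blockIdx.x * blockDim.x + tid; i < n;
       i += blockDim.x * gridDim.x) {
    float v = fabsf(in[i]);
    if (v > local_max) local_max = v;
  }
  smem[tid] = local_max;  // single shared-memory store per thread
  __syncthreads();

  // Block-wide tree reduction over shared memory (blockDim.x is a power of 2).
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s && smem[tid + s] > smem[tid]) smem[tid] = smem[tid + s];
    __syncthreads();
  }
  if (tid == 0) out[blockIdx.x] = smem[0];
}

int main() {
  const int n = 1 << 20;
  std::vector<float> h(n);
  for (int i = 0; i < n; ++i) h[i] = std::sin(0.001f * i);  // values in [-1, 1]

  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  const int threads = 256;
  AbsMaxSketch<<<1, threads, threads * sizeof(float)>>>(d_in, n, d_out);

  float result = 0.0f;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("abs max = %f\n", result);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

Keeping the accumulator in a register removes a shared-memory load and store per element in the strided loop; shared memory is needed only for the cross-thread reduction at the end.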
@@ -28,13 +28,14 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
   extern __shared__ char* shared_max_data_tmp[];
   auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
   if (gridDim.x > 1) {
-    shared_max_data[tid] = T(0);
+    T local_max_data = T(0);
     for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
       T tmp = abs(in[i]);
-      if (tmp > shared_max_data[tid]) {
-        shared_max_data[tid] = tmp;
+      if (tmp > local_max_data) {
+        local_max_data = tmp;
       }
     }
+    shared_max_data[tid] = local_max_data;
   } else {
     if (bid < n) {
       shared_max_data[tid] = abs(in[bid]);
@@ -83,13 +84,14 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
   int channel_size = n / c;
   const T* in_c = in + blockIdx.x * channel_size;
   extern __shared__ T shared_max_data[];
-  shared_max_data[tid] = T(0);
+  T local_max_data = T(0);
   for (int i = tid; i < channel_size; i += blockDim.x) {
     T tmp = fabs(in_c[i]);
-    if (tmp > shared_max_data[tid]) {
-      shared_max_data[tid] = tmp;
+    if (tmp > local_max_data) {
+      local_max_data = tmp;
     }
   }
+  shared_max_data[tid] = local_max_data;
   __syncthreads();
   for (int i = blockDim.x / 2; i > 0; i >>= 1) {
     if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
@@ -113,13 +115,14 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
   int tid = threadIdx.x;
   int bid = blockIdx.x;
   const T* in_current = in + tid * cout_wh_size + bid * wh_size;
-  shared_max_data[tid] = T(0);
+  T local_max_data = T(0);
   for (int i = 0; i < wh_size; i++) {
     T tmp = fabs(in_current[i]);
-    if (tmp > shared_max_data[tid]) {
-      shared_max_data[tid] = tmp;
+    if (tmp > local_max_data) {
+      local_max_data = tmp;
     }
   }
+  shared_max_data[tid] = local_max_data;
   __syncthreads();
   int len = blockDim.x;
@@ -404,6 +407,19 @@ struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
   }
 };
 
+template <typename T>
+__global__ void FindMovingAverageAbsMaxKernel(const T* in_state,
+                                              const T* in_accum,
+                                              const T* cur_scale, const T rate,
+                                              T* out_state, T* out_accum,
+                                              T* out_scale) {
+  T state = rate * (*in_state) + T(1.0f);
+  T accum = rate * (*in_accum) + (*cur_scale);
+  *out_state = state;
+  *out_accum = accum;
+  *out_scale = accum / state;
+}
+
 template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
 
 template <typename T>
@@ -415,29 +431,14 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
                   framework::Tensor* out_accum, framework::Tensor* out_scale) {
     const auto gpu_place = ctx.GetPlace();
-    T accum;
-    T state;
-    T scale;
-    memory::Copy(platform::CPUPlace(), &accum, gpu_place, in_accum.data<T>(),
-                 sizeof(T), ctx.stream());
-    memory::Copy(platform::CPUPlace(), &state, gpu_place, in_state.data<T>(),
-                 sizeof(T), ctx.stream());
-    memory::Copy(platform::CPUPlace(), &scale, gpu_place, cur_scale, sizeof(T),
-                 ctx.stream());
-    ctx.Wait();
     T rate_t = static_cast<T>(rate);
-    state = rate_t * state + static_cast<T>(1.0);
-    accum = rate_t * accum + scale;
-    scale = accum / state;
-    memory::Copy(gpu_place, out_accum->mutable_data<T>(gpu_place),
-                 platform::CPUPlace(), &accum, sizeof(T), ctx.stream());
-    memory::Copy(gpu_place, out_state->mutable_data<T>(gpu_place),
-                 platform::CPUPlace(), &state, sizeof(T), ctx.stream());
-    memory::Copy(gpu_place, out_scale->mutable_data<T>(gpu_place),
-                 platform::CPUPlace(), &scale, sizeof(T), ctx.stream());
-    ctx.Wait();
+    T* out_state_data = out_state->mutable_data<T>(gpu_place);
+    T* out_accum_data = out_accum->mutable_data<T>(gpu_place);
+    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);
+    FindMovingAverageAbsMaxKernel<T><<<1, 1, 0, ctx.stream()>>>(
+        in_state.data<T>(), in_accum.data<T>(), cur_scale, rate_t,
+        out_state_data, out_accum_data, out_scale_data);
   }
 };
...
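The first commit-message bullet ("Move the computation of moving average scale to device") corresponds to the last hunk: three device-to-host copies, a host-side scalar update, three host-to-device copies, and two ctx.Wait() calls are replaced by a single one-thread kernel launch, so the moving-average update stays on the GPU and the CPU no longer blocks on the stream. Below is a standalone sketch of the same idea; the harness, variable names, and numbers are illustrative and not taken from the commit.

```cuda
#include <cstdio>

#include <cuda_runtime.h>

// One-thread kernel that updates the moving-average state, accumulator, and
// scale entirely on the device, mirroring the structure of the new
// FindMovingAverageAbsMaxKernel in the diff.
__global__ void MovingAverageSketch(const float* in_state,
                                    const float* in_accum,
                                    const float* cur_scale, float rate,
                                    float* out_state, float* out_accum,
                                    float* out_scale) {
  float state = rate * (*in_state) + 1.0f;
  float accum = rate * (*in_accum) + (*cur_scale);
  *out_state = state;
  *out_accum = accum;
  *out_scale = accum / state;
}

int main() {
  float h_state = 10.0f, h_accum = 5.0f, h_scale = 0.8f, rate = 0.9f;

  float *d_state, *d_accum, *d_scale;
  cudaMalloc(&d_state, sizeof(float));
  cudaMalloc(&d_accum, sizeof(float));
  cudaMalloc(&d_scale, sizeof(float));
  cudaMemcpy(d_state, &h_state, sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_accum, &h_accum, sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_scale, &h_scale, sizeof(float), cudaMemcpyHostToDevice);

  // In-place update of the three scalars; this is safe because the single
  // thread reads every input before writing any output.
  MovingAverageSketch<<<1, 1>>>(d_state, d_accum, d_scale, rate, d_state,
                                d_accum, d_scale);

  cudaMemcpy(&h_state, d_state, sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(&h_accum, d_accum, sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(&h_scale, d_scale, sizeof(float), cudaMemcpyDeviceToHost);
  printf("state=%f accum=%f scale=%f\n", h_state, h_accum, h_scale);

  cudaFree(d_state);
  cudaFree(d_accum);
  cudaFree(d_scale);
  return 0;
}
```

A <<<1, 1>>> launch is sufficient because the update touches only three scalars; the gain comes from removing the copies and stream synchronizations, not from parallelizing the arithmetic.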