diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 9f7e4fb8d5749cf6bd54ed3e3bf9699199c0d3e6..70597be393c35e6939b83d86ce2f9be8f2c36805 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -28,13 +28,14 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { extern __shared__ char* shared_max_data_tmp[]; auto shared_max_data = reinterpret_cast(shared_max_data_tmp); if (gridDim.x > 1) { - shared_max_data[tid] = T(0); + T local_max_data = T(0); for (int i = bid; i < n; i += blockDim.x * gridDim.x) { T tmp = abs(in[i]); - if (tmp > shared_max_data[tid]) { - shared_max_data[tid] = tmp; + if (tmp > local_max_data) { + local_max_data = tmp; } } + shared_max_data[tid] = local_max_data; } else { if (bid < n) { shared_max_data[tid] = abs(in[bid]); @@ -83,13 +84,14 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, int channel_size = n / c; const T* in_c = in + blockIdx.x * channel_size; extern __shared__ T shared_max_data[]; - shared_max_data[tid] = T(0); + T local_max_data = T(0); for (int i = tid; i < channel_size; i += blockDim.x) { T tmp = fabs(in_c[i]); - if (tmp > shared_max_data[tid]) { - shared_max_data[tid] = tmp; + if (tmp > local_max_data) { + local_max_data = tmp; } } + shared_max_data[tid] = local_max_data; __syncthreads(); for (int i = blockDim.x / 2; i > 0; i >>= 1) { if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { @@ -113,13 +115,14 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, int tid = threadIdx.x; int bid = blockIdx.x; const T* in_current = in + tid * cout_wh_size + bid * wh_size; - shared_max_data[tid] = T(0); + T local_max_data = T(0); for (int i = 0; i < wh_size; i++) { T tmp = fabs(in_current[i]); - if (tmp > shared_max_data[tid]) { - shared_max_data[tid] = tmp; + if (tmp > local_max_data) { + local_max_data = tmp; } } + shared_max_data[tid] = local_max_data; __syncthreads(); int len = blockDim.x; @@ -404,6 +407,19 @@ struct FindRangeAbsMaxFunctor { } }; +template +__global__ void FindMovingAverageAbsMaxKernel(const T* in_state, + const T* in_accum, + const T* cur_scale, const T rate, + T* out_state, T* out_accum, + T* out_scale) { + T state = rate * (*in_state) + T(1.0f); + T accum = rate * (*in_accum) + (*cur_scale); + *out_state = state; + *out_accum = accum; + *out_scale = accum / state; +} + template struct FindRangeAbsMaxFunctor; template @@ -415,29 +431,14 @@ struct FindMovingAverageAbsMaxFunctor { framework::Tensor* out_accum, framework::Tensor* out_scale) { const auto gpu_place = ctx.GetPlace(); - T accum; - T state; - T scale; - memory::Copy(platform::CPUPlace(), &accum, gpu_place, in_accum.data(), - sizeof(T), ctx.stream()); - memory::Copy(platform::CPUPlace(), &state, gpu_place, in_state.data(), - sizeof(T), ctx.stream()); - memory::Copy(platform::CPUPlace(), &scale, gpu_place, cur_scale, sizeof(T), - ctx.stream()); - ctx.Wait(); - T rate_t = static_cast(rate); - state = rate_t * state + static_cast(1.0); - accum = rate_t * accum + scale; - scale = accum / state; - - memory::Copy(gpu_place, out_accum->mutable_data(gpu_place), - platform::CPUPlace(), &accum, sizeof(T), ctx.stream()); - memory::Copy(gpu_place, out_state->mutable_data(gpu_place), - platform::CPUPlace(), &state, sizeof(T), ctx.stream()); - memory::Copy(gpu_place, out_scale->mutable_data(gpu_place), - platform::CPUPlace(), &scale, sizeof(T), ctx.stream()); - ctx.Wait(); + T* out_state_data = out_state->mutable_data(gpu_place); + T* out_accum_data = out_accum->mutable_data(gpu_place); + T* out_scale_data = out_scale->mutable_data(gpu_place); + + FindMovingAverageAbsMaxKernel<<<1, 1, 0, ctx.stream()>>>( + in_state.data(), in_accum.data(), cur_scale, rate_t, + out_state_data, out_accum_data, out_scale_data); } };