/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Threads per block for every launch in this file. The shared-memory tree
// reduction below halves the active range each step, so this must be a
// power of two.
static constexpr int kFakeQuantBlockSize = 1024;

// Returns the maximum of |in[i]| for i = begin, begin + stride, ... < end,
// or T(0) when that range is empty. NaN inputs are skipped because
// `tmp > m` is false for NaN.
template <typename T>
__device__ __forceinline__ T FakeQuantThreadAbsMax(const T* in, int begin,
                                                   int end, int stride) {
  T m = T(0);
  for (int i = begin; i < end; i += stride) {
    T tmp = fabs(in[i]);
    if (tmp > m) {
      m = tmp;
    }
  }
  return m;
}

// Tree-reduces shared_max[0 .. blockDim.x) so that shared_max[0] holds the
// maximum. Contains __syncthreads(), so every thread of the block must call
// it; blockDim.x must be a power of two.
template <typename T>
__device__ __forceinline__ void FakeQuantBlockReduceMax(T* shared_max) {
  const int tid = threadIdx.x;
  __syncthreads();  // make every thread's partial visible before reducing
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && shared_max[tid] < shared_max[tid + i]) {
      shared_max[tid] = shared_max[tid + i];
    }
    __syncthreads();
  }
}

// Abs-max reduction: block b writes max(|in[i]|) over the elements its
// threads visit (grid-stride over [0, n)) to out[b]. Expects blockDim.x to be
// a power of two, blockDim.x * sizeof(T) bytes of dynamic shared memory, and
// `out` sized for gridDim.x elements. Empty input yields T(0).
template <typename T>
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
  extern __shared__ T shared_max_data[];
  const int gid = threadIdx.x + blockIdx.x * blockDim.x;
  // The grid-stride walk covers every element whatever the grid size, so a
  // single-block launch no longer drops elements past blockDim.x.
  shared_max_data[threadIdx.x] =
      FakeQuantThreadAbsMax(in, gid, n, blockDim.x * gridDim.x);
  FakeQuantBlockReduceMax(shared_max_data);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

// Writes max(|in[0 .. num)|) to out[0] (0 for num <= 0) using two passes:
// per-block partial maxima, then a single block reducing those partials.
template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
                  const int num, T* out) {
    const int block = kFakeQuantBlockSize;
    int grid = (num + block - 1) / block;
    // Cap the first pass at `block` blocks so the second, single-block pass
    // has at most one partial per thread; the grid-stride loop covers the
    // rest of the input.
    grid = grid > block ? block : grid;
    // Never launch a zero-sized grid (invalid configuration); with one block
    // an empty input simply reduces to T(0).
    grid = grid < 1 ? 1 : grid;

    framework::Tensor max;
    T* max_data =
        max.mutable_data<T>(framework::make_ddim({grid}), ctx.GetPlace());
    const size_t shm_bytes = block * sizeof(T);
    FindAbsMaxKernel<T><<<grid, block, shm_bytes, ctx.stream()>>>(in, num,
                                                                   max_data);
    FindAbsMaxKernel<T><<<1, block, shm_bytes, ctx.stream()>>>(max_data, grid,
                                                                out);
  }
};

template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;

// Per-channel abs-max: block b reduces the contiguous slice
// in[b * (n / c), (b + 1) * (n / c)) and writes the result to out[b].
// Assumes n is a multiple of c (a remainder would never be read) -- TODO
// confirm against callers. Same shared-memory/blockDim requirements as
// FindAbsMaxKernel.
template <typename T>
__global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c,
                                        T* out) {
  extern __shared__ T shared_max_data[];
  const int channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  shared_max_data[threadIdx.x] =
      FakeQuantThreadAbsMax(in_c, threadIdx.x, channel_size, blockDim.x);
  FakeQuantBlockReduceMax(shared_max_data);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

// Writes the abs-max of each of `channel` equal slices of in[0 .. num) to
// out[0 .. channel), one block per channel.
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
                  const int num, const int channel, T* out) {
    // No channels means no outputs; a zero-sized grid is also invalid.
    if (channel <= 0) return;
    const int block = kFakeQuantBlockSize;
    const int grid = channel;
    FindChannelAbsMaxKernel<T><<<grid, block, block * sizeof(T),
                                 ctx.stream()>>>(in, num, channel, out);
  }
};

template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;

// Returns the quantization multiplier bin_cnt / s, or 0 when s == 0. An
// all-zero tensor (or channel) has abs-max scale 0; dividing by it would give
// inf * 0 = NaN for every element, whereas every clipped value is already 0
// in that case, so a zero multiplier yields the all-zero result.
template <typename T>
__device__ __forceinline__ T FakeQuantMultiplier(const int bin_cnt,
                                                 const T s) {
  return s == T(0) ? T(0) : bin_cnt / s;
}

// Clips x to [-s, s] and returns round(multiplier * clipped), where
// `multiplier` is FakeQuantMultiplier(bin_cnt, s).
template <typename T>
__device__ __forceinline__ T FakeQuantClipAndRound(const T x, const T s,
                                                   const T multiplier) {
  T v = x > s ? s : x;
  v = v < -s ? -s : v;
  return round(multiplier * v);
}

// out[i] = round(clip(in[i], -scale[0], scale[0]) * bin_cnt / scale[0]) for
// i in [0, n), visited with a grid-stride loop.
template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
                                   const int bin_cnt, const int n, T* out) {
  const int gid = threadIdx.x + blockIdx.x * blockDim.x;
  const int stride = blockDim.x * gridDim.x;
  const T s = scale[0];
  const T multiplier = FakeQuantMultiplier(bin_cnt, s);
  for (int i = gid; i < n; i += stride) {
    out[i] = FakeQuantClipAndRound(in[i], s, multiplier);
  }
}

// Clips and quantizes every element of `in` into `out` with the single
// scale stored in scale[0].
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    const int num = in.numel();
    const int block = kFakeQuantBlockSize;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // Nothing to quantize; a zero-sized grid would be an invalid launch.
    if (num <= 0) return;
    const int grid = (num + block - 1) / block;
    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, out_data);
  }
};

template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;

// Per-channel variant of ClipAndQuantKernel: block b quantizes the slice
// [b * (n / c), (b + 1) * (n / c)) of `in` with scale[b].
template <typename T>
__global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
                                          const int bin_cnt, const int n,
                                          const int c, T* out) {
  const int channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;

  const T s = scale[blockIdx.x];
  const T multiplier = FakeQuantMultiplier(bin_cnt, s);
  for (int i = threadIdx.x; i < channel_size; i += blockDim.x) {
    out_c[i] = FakeQuantClipAndRound(in_c[i], s, multiplier);
  }
}

// Clips and quantizes `in` channel by channel, using scale[k] for channel k.
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int channel,
                  framework::Tensor* out) {
    const int num = in.numel();
    const int block = kFakeQuantBlockSize;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // No channels means nothing to do; a zero-sized grid is also invalid.
    if (channel <= 0) return;
    const int grid = channel;
    ChannelClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, channel, out_data);
  }
};

template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
                                               float>;

// Single-thread kernel (launched <<<1, 1>>>) that updates the sliding window
// of per-iteration scales.
//
// Stores cur_scale[0] into slot iter[0] % window_size of scale_arr and writes
// max(last_scale[0], cur_scale[0]) to out_scale[0]. Relies on last_scale[0]
// being the maximum of the window before this update. That value is then the
// new window maximum unless the evicted slot held the previous maximum AND
// the incoming scale does not exceed it; only in that case is
// need_find_max[0] set to 1 and out_size[0] set to the number of filled slots
// for the host to re-scan. Otherwise need_find_max[0] = out_size[0] = 0.
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
                                            const T* last_scale,
                                            const int64_t* iter,
                                            const int window_size, T* scale_arr,
                                            T* out_scale, int* need_find_max,
                                            int* out_size) {
  int it = iter[0];
  int idx = it % window_size;
  T removed = scale_arr[idx];
  T cur = cur_scale[0];
  scale_arr[idx] = cur;
  T max = last_scale[0];
  out_scale[0] = max < cur ? cur : max;
  // When cur > max it dominates every other slot, so no re-scan is needed
  // (and a re-scan could miss it, see below).
  if (!(max < cur) && fabs(removed - max) < 1e-6) {
    need_find_max[0] = 1;
    // Assuming iter counts up from 0, slots 0..it have been written, so the
    // first min(it + 1, window_size) slots are valid -- this includes the
    // slot just filled with `cur` and is never 0.
    out_size[0] = it >= window_size ? window_size : it + 1;
  } else {
    need_find_max[0] = 0;
    out_size[0] = 0;
  }
}

// Range-abs-max scale update: pushes cur_scale into the window scales_arr and
// writes the window maximum to out_scale, re-scanning the window on the
// device only when the evicted entry was the previous maximum.
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter, const int window_size,
                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
    const auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());

    T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);

    framework::Tensor need_find_max, out_size;
    int* find_max = need_find_max.mutable_data<int>(gpu_place);
    int* out_size_data = out_size.mutable_data<int>(gpu_place);

    FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
        cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
        window_size, scale_arr, out_scale_data, find_max, out_size_data);

    // The re-scan decision is made on the host, so the flag must be read
    // back (and the stream drained) before continuing.
    int g_find_max;
    memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max,
                 sizeof(int), ctx.stream());
    ctx.Wait();
    if (g_find_max) {
      int len;
      memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data,
                   sizeof(int), ctx.stream());
      ctx.Wait();
      // Overwrites out_scale_data with the abs-max of the first `len` slots.
      FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
                                                          out_scale_data);
    }
  }
};

template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;

// Moving-average scale update, computed on the host:
//   state = rate * state + 1
//   accum = rate * accum + cur_scale
//   scale = accum / state
// When state and accum start at 0, scale is the rate-weighted average of the
// scales seen so far. Each input scalar is copied device->host, and each
// result is copied host->device. The final Wait keeps the host-side sources
// alive until the copies finish.
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_accum,
                  const framework::Tensor& in_state, const T* cur_scale,
                  const float rate, framework::Tensor* out_state,
                  framework::Tensor* out_accum, framework::Tensor* out_scale) {
    const auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());

    T accum;
    T state;
    T scale;
    memory::Copy(platform::CPUPlace(), &accum, gpu_place, in_accum.data<T>(),
                 sizeof(T), ctx.stream());
    memory::Copy(platform::CPUPlace(), &state, gpu_place, in_state.data<T>(),
                 sizeof(T), ctx.stream());
    memory::Copy(platform::CPUPlace(), &scale, gpu_place, cur_scale, sizeof(T),
                 ctx.stream());
    ctx.Wait();
    state = rate * state + 1;
    accum = rate * accum + scale;
    scale = accum / state;

    memory::Copy(gpu_place, out_accum->mutable_data<T>(gpu_place),
                 platform::CPUPlace(), &accum, sizeof(T), ctx.stream());
    memory::Copy(gpu_place, out_state->mutable_data<T>(gpu_place),
                 platform::CPUPlace(), &state, sizeof(T), ctx.stream());
    memory::Copy(gpu_place, out_scale->mutable_data<T>(gpu_place),
                 platform::CPUPlace(), &scale, sizeof(T), ctx.stream());
    ctx.Wait();
  }
};

template struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext,
                                               float>;

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
                        ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_moving_average_abs_max,
    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);