/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// The guard now encloses the whole header (previously #endif directly
// followed #define, so nothing was protected).
#ifndef PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_
#define PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_

#include <algorithm>
#include <string>

#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {

// Type used for arithmetic inside the quantization kernels. By default it is
// the storage type itself; float16 is computed in float.
template <typename T>
struct QuantizeDataType {
  using type = T;
};

template <>
struct QuantizeDataType<paddle::platform::float16> {
  using type = float;
};

// Block-level reduction of max(|in[i]|), written to out[blockIdx.x].
//
// Launch requirements:
//  - 1-D grid / 1-D block; blockDim.x must be a power of two because the
//    tree reduction halves the active range each step.
//  - Dynamic shared memory of at least blockDim.x * sizeof(T) bytes.
//  - When gridDim.x == 1 each thread loads at most one element, so that
//    path requires n <= blockDim.x. With gridDim.x > 1 each thread first
//    folds elements bid, bid + blockDim.x * gridDim.x, ... into a private max.
template <typename T>
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;

  extern __shared__ char* shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
  if (gridDim.x > 1) {
    T local_max_data = T(0);
    for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
      T tmp = abs(in[i]);
      if (tmp > local_max_data) {
        local_max_data = tmp;
      }
    }
    shared_max_data[tid] = local_max_data;
  } else {
    if (bid < n) {
      shared_max_data[tid] = abs(in[bid]);
    } else {
      shared_max_data[tid] = T(0);
    }
  }
  __syncthreads();
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

// Computes max(|in[0..num)|) into out[0] using two launches on ctx.stream():
// a first pass producing at most 1024 per-block maxima into a temporary
// tensor, then a single-block pass reducing those partials.
// When num <= 0 there is nothing to reduce and `out` is left unchanged
// (launching a zero-sized grid would be an invalid configuration).
template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const T* in,
                  const int num,
                  T* out) {
    if (num <= 0) {
      return;
    }
    int block = 1024;
    int grid = (block - 1 + num) / block;
    // Cap the first pass at `block` blocks: the second, single-block pass
    // loads one partial per thread and therefore needs grid <= block.
    grid = (grid > block) ? block : grid;

    framework::Tensor max;
    T* max_data = max.mutable_data<T>(phi::make_ddim({grid}), ctx.GetPlace());
    FindAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(
        in, num, max_data);
    FindAbsMaxKernel<T><<<1, block, 1024 * sizeof(T), ctx.stream()>>>(
        max_data, grid, out);
  }
};

template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template struct FindAbsMaxFunctor<platform::CUDADeviceContext,
                                  paddle::platform::float16>;

// Per-channel abs-max when the channel axis is 0.
//
// Launch: one block per channel (gridDim.x == c); each block scans the
// n / c contiguous elements of its channel with a block-stride loop.
// blockDim.x must be a power of two and dynamic shared memory must hold
// blockDim.x elements of T. Writes out[blockIdx.x].
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in,
                                                  const int n,
                                                  const int c,
                                                  T* out) {
  int tid = threadIdx.x;
  int channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  extern __shared__ char* shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
  T local_max_data = T(0);
  for (int i = tid; i < channel_size; i += blockDim.x) {
    T tmp = static_cast<T>(
        fabs(static_cast<typename QuantizeDataType<T>::type>(in_c[i])));
    if (tmp > local_max_data) {
      local_max_data = tmp;
    }
  }
  shared_max_data[tid] = local_max_data;
  __syncthreads();
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

// Per-channel abs-max when the channel axis is 1 (input dims [cin, cout, ...]).
//
// Launch: one block per output channel (gridDim.x == cout). Thread t scans the
// wh_size contiguous elements of row t (dimension 0, relative to `in`) that
// belong to channel blockIdx.x, so one launch covers blockDim.x rows; the host
// launches repeatedly with `in` advanced when cin > blockDim.x.
// The block result is merged into out[blockIdx.x] with a running max, so
// `out` must be initialized (the functor zeroes it) before the first launch.
// Dynamic shared memory must hold blockDim.x elements of T.
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(
    const T* in, const int n, const int cin, const int cout, T* out) {
  extern __shared__ char* shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
  int cout_wh_size = n / cin;
  int wh_size = n / (cin * cout);

  int tid = threadIdx.x;
  int bid = blockIdx.x;
  const T* in_current = in + tid * cout_wh_size + bid * wh_size;
  T local_max_data = T(0);
  for (int i = 0; i < wh_size; i++) {
    T tmp = static_cast<T>(
        fabs(static_cast<typename QuantizeDataType<T>::type>(in_current[i])));
    if (tmp > local_max_data) {
      local_max_data = tmp;
    }
  }
  shared_max_data[tid] = local_max_data;
  __syncthreads();

  // Tree reduction that tolerates a non-power-of-two blockDim.x: each step
  // folds the upper ceil-half onto the lower part; after the i == 1 step,
  // i is forced to 0 so the loop terminates.
  int len = blockDim.x;
  for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) {
    if (tid < i && tid + i < len &&
        shared_max_data[tid] < shared_max_data[tid + i]) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    if (i == 1) {
      i = 0;  // break the loop
    }
    __syncthreads();
  }
  if (tid == 0 && shared_max_data[0] > out[bid]) {
    out[bid] = shared_max_data[0];
  }
}

// Computes the per-channel abs-max of `in_tensor` along `quant_axis`
// (0 or 1) into out_abs_max (in_dims[quant_axis] entries). All work is
// issued on ctx.stream().
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_tensor,
                  const int quant_axis,
                  T* out_abs_max) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1,
        true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    const int num = in_tensor.numel();
    auto in_dims = in_tensor.dims();
    const T* in_data = in_tensor.data<T>();
    if (quant_axis == 0) {
      int cout = in_dims[0];
      int grid = cout;
      int block = 1024;
      FindChannelAbsMaxKernelQuantAxis0<T>
          <<<grid, block, block * sizeof(T), ctx.stream()>>>(
              in_data, num, cout, out_abs_max);
    } else if (quant_axis == 1) {
      int cin = in_dims[0];
      int cout = in_dims[1];
      int grid = cout;
      int max_threads = 1024;

      // The axis-1 kernel merges into out_abs_max with a running max, so it
      // must start at zero. Clear it on the same stream as the launches below
      // so the ordering is guaranteed.
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(out_abs_max, 0, sizeof(T) * cout, ctx.stream()));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(out_abs_max, 0, sizeof(T) * cout, ctx.stream()));
#endif  // PADDLE_WITH_HIP

      // Each full launch covers `max_threads` consecutive rows of dimension 0
      // (one row per thread). Advance by that whole chunk of rows, not a
      // single row, so every row is visited exactly once.
      for (int i = 0; i < cin / max_threads; i++) {
        int block = max_threads;
        FindChannelAbsMaxKernelQuantAxis1<T>
            <<<grid, block, block * sizeof(T), ctx.stream()>>>(
                in_data, num, cin, cout, out_abs_max);
        in_data += max_threads * (num / cin);
      }
      // Remaining cin % max_threads rows.
      int block = cin % max_threads;
      if (block > 0) {
        FindChannelAbsMaxKernelQuantAxis1<T>
            <<<grid, block, block * sizeof(T), ctx.stream()>>>(
                in_data, num, in_dims[0], in_dims[1], out_abs_max);
      }
    }
  }
};

template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;

// Clips each in[i] to [-scale[0], scale[0]] and maps it to the integer grid
// round(bin_cnt * x / scale[0]). Elements are visited with stride
// blockDim.x * gridDim.x, so any positive grid size covers all n elements.
template <typename T>
__global__ void ClipAndQuantKernel(const T* in,
                                   const T* scale,
                                   const int bin_cnt,
                                   const int n,
                                   T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;

  using ComputeDataType = typename QuantizeDataType<T>::type;

  ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
  ComputeDataType inv_s = inverse(s);
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
    ComputeDataType v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt_t * inv_s * v;
    out[i] = static_cast<T>(round(v));
  }
}

// Same clipping/quantization as ClipAndQuantKernel, followed by dequantization
// back to the input range: out = round(...) * scale[0] / bin_cnt.
template <typename T>
__global__ void ClipAndQuantDequantKernel(const T* in,
                                          const T* scale,
                                          const int bin_cnt,
                                          const int n,
                                          T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;

  using ComputeDataType = typename QuantizeDataType<T>::type;

  ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
  ComputeDataType inv_s = inverse(s);
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
    x = x > s ? s : x;
    x = x < -s ? -s : x;
    x = bin_cnt_t * inv_s * x;
    x = round(x);
    out[i] = static_cast<T>((x * s) / bin_cnt_t);
  }
}

// Fake-quantizes `in` with a single scale tensor into `out` (allocated here).
// Empty inputs only allocate `out`; a zero-sized grid is never launched.
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in,
                  const framework::Tensor& scale,
                  const int bin_cnt,
                  framework::Tensor* out) {
    int num = in.numel();
    int block = 1024;
    int grid = (block - 1 + num) / block;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());
    if (grid == 0) {
      return;  // empty input: nothing to quantize
    }

    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, out_data);
  }
};

template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;

// Fake-quantizes then dequantizes `in` with a single scale tensor into `out`
// (allocated here). Empty inputs only allocate `out`.
template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in,
                  const framework::Tensor& scale,
                  const int bin_cnt,
                  framework::Tensor* out) {
    int num = in.numel();
    int block = 1024;
    int grid = (block - 1 + num) / block;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());
    if (grid == 0) {
      return;  // empty input: nothing to quantize
    }

    ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, out_data);
  }
};

// ChannelClipAndQuantKernel for quant_axis is 0.
// One block per channel (gridDim.x == c); channel blockIdx.x owns the n / c
// contiguous elements starting at blockIdx.x * (n / c) and uses
// scale[blockIdx.x].
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in,
                                                    const T* scale,
                                                    const int bin_cnt,
                                                    const int64_t n,
                                                    const int c,
                                                    T* out) {
  int tid = threadIdx.x;

  int64_t channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;

  using ComputeDataType = typename QuantizeDataType<T>::type;

  ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x]);
  ComputeDataType inv_s = inverse(s);
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  for (int64_t i = tid; i < channel_size; i += blockDim.x) {
    ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
    ComputeDataType v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt_t * inv_s * v;
    out_c[i] = static_cast<T>(round(v));
  }
}

// ChannelClipAndQuantKernel for quant_axis is N.
// Element i uses scale[(i / quant_stride) % nScale], where quant_stride is
// the product of the dimensions after the quantization axis and nScale is
// the size of that axis. Elements are visited with stride
// blockDim.x * gridDim.x.
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxisN(const T* in,
                                                    const T* scale,
                                                    const int bin_cnt,
                                                    const int64_t n,
                                                    const int nScale,
                                                    const int quant_stride,
                                                    T* out) {
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  using ComputeDataType = typename QuantizeDataType<T>::type;
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
  for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) {
    ComputeDataType s =
        static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]);
    ComputeDataType inv_s = inverse(s);
    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
    ComputeDataType v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt_t * inv_s * v;
    out[i] = static_cast<T>(round(v));
  }
}

// Channel-wise fake quantization along quant_axis (0 or 1) into `out`
// (allocated here). All launches are issued on ctx.stream().
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in,
                  const framework::Tensor& scale,
                  const int bin_cnt,
                  const int quant_axis,
                  framework::Tensor* out) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1,
        true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    int64_t num = in.numel();
    auto in_dims = in.dims();
    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());
    if (num == 0) {
      // Nothing to quantize. This also prevents block_size == 0 below, which
      // would divide by zero when computing max_blocks.
      return;
    }

    if (quant_axis == 0) {
      int grid = in_dims[0];
      int block = 1024;
      ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
          in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
    } else {
      int quant_stride = 1;
      for (int i = quant_axis + 1; i < in_dims.size(); i++) {
        quant_stride *= in_dims[i];
      }
      int64_t block_size =
          std::min(num, static_cast<int64_t>(ctx.GetMaxThreadsPerBlock() / 4));
      int64_t max_threads =
          ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
      const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
                                          static_cast<int64_t>(1));
      const int64_t grid_size =
          std::min(max_blocks, (num + block_size - 1) / block_size);

      ChannelClipAndQuantKernelQuantAxisN<T>
          <<<grid_size, block_size, 0, ctx.stream()>>>(in_data,
                                                       scale_data,
                                                       bin_cnt,
                                                       num,
                                                       in_dims[quant_axis],
                                                       quant_stride,
                                                       out_data);
    }
  }
};

template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
                                               float>;

// Single-thread kernel maintaining a sliding window of scales.
// scale_arr is written as a ring buffer indexed by iter[0] % window_size:
// the entry at that slot is replaced by cur_scale[0], and out_scale[0]
// becomes max(last_scale[0], cur_scale[0]). If the evicted entry equals
// last_scale[0] (within 1e-6), the window maximum may have left the window,
// so need_find_max[0] is set to 1 and out_size[0] to min(iter, window_size)
// for the host to recompute the max; otherwise need_find_max[0] is 0.
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
                                            const T* last_scale,
                                            const int64_t* iter,
                                            const int window_size,
                                            T* scale_arr,
                                            T* out_scale,
                                            int* need_find_max,
                                            int* out_size) {
  int it = iter[0];
  int idx = it % window_size;
  T removed = scale_arr[idx];
  T cur = cur_scale[0];
  scale_arr[idx] = cur;
  T max = last_scale[0];
  out_scale[0] = max < cur ? cur : max;
  if (fabs(static_cast<typename QuantizeDataType<T>::type>(removed - max)) <
      1e-6) {
    need_find_max[0] = 1;
    out_size[0] = it > window_size ? window_size : it;
  } else {
    need_find_max[0] = 0;
  }
}

// Updates the sliding-window scale array and out_scale. When the kernel
// reports that the previous maximum was evicted, the flag and window length
// are copied back to the host (each copy followed by ctx.Wait()) and the
// maximum is recomputed over the first `len` entries of the window.
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter,
                  const int window_size,
                  framework::Tensor* scales_arr,
                  framework::Tensor* out_scale) {
    const auto gpu_place = ctx.GetPlace();

    T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);

    framework::Tensor need_find_max, out_size;
    int* find_max = need_find_max.mutable_data<int>({1}, gpu_place);
    int* out_size_data = out_size.mutable_data<int>({1}, gpu_place);

    FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
        cur_scale.data<T>(),
        last_scale.data<T>(),
        iter.data<int64_t>(),
        window_size,
        scale_arr,
        out_scale_data,
        find_max,
        out_size_data);

    int g_find_max;
    memory::Copy(platform::CPUPlace(),
                 &g_find_max,
                 gpu_place,
                 find_max,
                 sizeof(int),
                 ctx.stream());
    ctx.Wait();
    if (g_find_max) {
      int len;
      memory::Copy(platform::CPUPlace(),
                   &len,
                   gpu_place,
                   out_size_data,
                   sizeof(int),
                   ctx.stream());
      ctx.Wait();
      FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(
          ctx, scale_arr, len, out_scale_data);
    }
  }
};

// Single-thread exponential moving average update:
//   state = rate * in_state + 1
//   accum = rate * in_accum + cur_scale
//   scale = accum / state
template <typename T>
__global__ void FindMovingAverageAbsMaxKernel(const T* in_state,
                                              const T* in_accum,
                                              const T* cur_scale,
                                              const T rate,
                                              T* out_state,
                                              T* out_accum,
                                              T* out_scale) {
  T state = rate * (*in_state) + T(1.0f);
  T accum = rate * (*in_accum) + (*cur_scale);
  *out_state = state;
  *out_accum = accum;
  *out_scale = accum / state;
}

template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;

// Launches FindMovingAverageAbsMaxKernel on ctx.stream(); cur_scale is a
// device pointer to a single value.
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_accum,
                  const framework::Tensor& in_state,
                  const T* cur_scale,
                  const float rate,
                  framework::Tensor* out_state,
                  framework::Tensor* out_accum,
                  framework::Tensor* out_scale) {
    const auto gpu_place = ctx.GetPlace();

    T rate_t = static_cast<T>(rate);
    T* out_state_data = out_state->mutable_data<T>(gpu_place);
    T* out_accum_data = out_accum->mutable_data<T>(gpu_place);
    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);

    FindMovingAverageAbsMaxKernel<T>
        <<<1, 1, 0, ctx.stream()>>>(in_state.data<T>(),
                                    in_accum.data<T>(),
                                    cur_scale,
                                    rate_t,
                                    out_state_data,
                                    out_accum_data,
                                    out_scale_data);
  }
};

// ChannelClipAndQuantDequantKernel for quant_axis is 0.
// One block per channel (gridDim.x == c) using scale[blockIdx.x]; arithmetic
// is done in QuantizeDataType<T>::type like the other quantization kernels.
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T* in,
                                                           const T* scale,
                                                           const int bin_cnt,
                                                           const int n,
                                                           const int c,
                                                           T* out) {
  int tid = threadIdx.x;

  int channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;

  using ComputeDataType = typename QuantizeDataType<T>::type;

  ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x]);
  ComputeDataType inv_s = inverse(s);
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  for (int i = tid; i < channel_size; i += blockDim.x) {
    ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
    ComputeDataType v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt_t * inv_s * v;
    out_c[i] = static_cast<T>(round(v) * s / bin_cnt_t);
  }
}

// ChannelClipAndQuantDequantKernel for quant_axis is 1.
// One block per (row, channel) pair: gridDim.x == cin * cout, block b handles
// the wh_size contiguous elements starting at b * wh_size with
// scale[b % cout].
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T* in,
                                                           const T* scale,
                                                           const int bin_cnt,
                                                           const int n,
                                                           const int cin,
                                                           const int cout,
                                                           T* out) {
  using ComputeDataType = typename QuantizeDataType<T>::type;

  ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x % cout]);
  ComputeDataType inv_s = inverse(s);
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  int wh_size = n / (cin * cout);
  const T* in_c = in + blockIdx.x * wh_size;
  T* out_c = out + blockIdx.x * wh_size;

  for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
    ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
    ComputeDataType v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt_t * inv_s * v;
    out_c[i] = static_cast<T>(round(v) * s / bin_cnt_t);
  }
}

// Channel-wise fake quantize + dequantize along quant_axis (0 or 1) into
// `out` (allocated here). Empty inputs only allocate `out`.
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in,
                  const framework::Tensor& scale,
                  const int bin_cnt,
                  const int quant_axis,
                  framework::Tensor* out) {
    // At present, channelwise quantization supports conv2d, depthwise_conv2d
    // conv2d_transpose and mul
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1,
        true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    int num = in.numel();
    auto in_dims = in.dims();

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());
    if (num == 0) {
      return;  // empty input: avoid launching a zero-sized grid
    }

    if (quant_axis == 0) {
      int grid = in_dims[0];
      int block = 1024;
      ChannelClipAndQuantDequantKernelQuantAxis0<T>
          <<<grid, block, 0, ctx.stream()>>>(
              in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
    } else if (quant_axis == 1) {
      int grid = in_dims[0] * in_dims[1];
      int block = 1024;

      ChannelClipAndQuantDequantKernelQuantAxis1<T>
          <<<grid, block, 0, ctx.stream()>>>(in_data,
                                             scale_data,
                                             bin_cnt,
                                             num,
                                             in_dims[0],
                                             in_dims[1],
                                             out_data);
    }
  }
};

template struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext,
                                                   float>;

}  // namespace operators
}  // namespace paddle

#endif  // PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_