// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/kthvalue_op.h" #include "paddle/fluid/operators/top_k_function_cuda.h" #include "paddle/fluid/operators/top_k_v2_op.h" #ifdef __NVCC__ #include "cub/cub.cuh" #endif #ifdef __HIPCC__ #include #endif namespace paddle { namespace operators { int getBlockSize(int col) { if (col > 512) return 1024; else if (col > 256 && col <= 512) return 512; else if (col > 128 && col <= 256) return 256; else if (col > 64 && col <= 128) return 128; else return 64; } template bool SortKthvalue(const platform::CUDADeviceContext& ctx, const framework::Tensor* input_tensor, const int64_t num_cols, const int64_t num_rows, const int k, framework::Tensor* out_tensor, framework::Tensor* indices_tensor) { auto cu_stream = ctx.stream(); framework::Tensor input_indices; const std::vector dims = {num_rows, num_cols}; auto dim = framework::make_ddim(dims); input_indices.Resize(dim); input_indices.mutable_data(ctx.GetPlace()); size_t temp_storage_bytes = -1; int block_size = getBlockSize(num_cols); unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x; unsigned int grid_size = num_rows < maxGridDimX ? static_cast(num_rows) : maxGridDimX; InitIndex<<>>( input_indices.data(), num_rows, num_cols); cub::CountingInputIterator counting_iter(0); cub::TransformInputIterator> segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols)); T* sorted_values_ptr; int64_t* sorted_indices_ptr; framework::Tensor temp_values, temp_indices; const T* input = input_tensor->data(); T* values = out_tensor->data(); int64_t* indices = indices_tensor->mutable_data(ctx.GetPlace()); temp_values.Resize(dim); temp_indices.Resize(dim); sorted_values_ptr = temp_values.mutable_data(ctx.GetPlace()); sorted_indices_ptr = temp_indices.mutable_data(ctx.GetPlace()); auto err = cub::DeviceSegmentedRadixSort::SortPairs( nullptr, temp_storage_bytes, input, sorted_values_ptr, input_indices.data(), sorted_indices_ptr, num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, cu_stream); #ifdef __HIPCC__ if (err != hipSuccess) { LOG(ERROR) << "KthvalueOP failed as could not launch " "hipcub::DeviceSegmentedRadixSort::SortPairs, status: " << hipGetErrorString(err); return false; } #else if (err != cudaSuccess) { LOG(ERROR) << "KthvalueOP failed as could not launch " "cub::DeviceSegmentedRadixSort::SortPairs, status: " << cudaGetErrorString(err); return false; } #endif framework::Tensor temp_storage; temp_storage.mutable_data(ctx.GetPlace(), temp_storage_bytes); err = cub::DeviceSegmentedRadixSort::SortPairs( temp_storage.data(), temp_storage_bytes, input, sorted_values_ptr, input_indices.data(), sorted_indices_ptr, num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, cu_stream); #ifdef __HIPCC__ if (err != hipSuccess) { LOG(ERROR) << "KthvalueOP failed as could not launch " "hipcub::DeviceSegmentedRadixSort::SortPairs, " << temp_storage_bytes << ", status: " << hipGetErrorString(err); return false; } #else if (err != cudaSuccess) { LOG(ERROR) << "KthvalueOP failed as could not launch " "cub::DeviceSegmentedRadixSort::SortPairs, " << temp_storage_bytes << ", status: " << cudaGetErrorString(err); return false; } #endif auto& dev = *ctx.eigen_device(); const Eigen::DSizes slice_indices{0, k - 1}; const Eigen::DSizes slice_sizes{num_rows, 1}; auto e_indices = framework::EigenMatrix::From(*indices_tensor, dim); auto e_tmp_indices = framework::EigenMatrix::From( static_cast(temp_indices)); std::vector odims = {static_cast(num_rows), static_cast(1)}; dim = framework::make_ddim(odims); auto e_values = framework::EigenMatrix::From(*out_tensor, dim); auto e_tmp_values = framework::EigenMatrix::From( static_cast(temp_values)); EigenSlice, int64_t, 2>::Eval( dev, e_indices, e_tmp_indices, slice_indices, slice_sizes); EigenSlice, T, 2>::Eval( dev, e_values, e_tmp_values, slice_indices, slice_sizes); return true; } template class KthvalueOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE_EQ( platform::is_gpu_place(ctx.GetPlace()), true, platform::errors::InvalidArgument( "It must use CUDAPlace, you must check your device set.")); auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); int k = static_cast(ctx.Attr("k")); int axis = static_cast(ctx.Attr("axis")); bool keepdim = static_cast(ctx.Attr("keepdim")); const auto& in_dims = input->dims(); if (axis < 0) axis += in_dims.size(); auto out_dims = output->dims(); const T* input_data = input->data(); T* output_data = output->mutable_data(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); if (axis == in_dims.size() - 1) { const int64_t& input_height = framework::product( framework::slice_ddim(in_dims, 0, in_dims.size() - 1)); const int64_t& input_width = in_dims[in_dims.size() - 1]; const auto& dev_ctx = ctx.cuda_device_context(); PADDLE_ENFORCE_EQ(SortKthvalue(dev_ctx, input, input_width, input_height, k, output, indices), true, platform::errors::External( "KthvalueOP: Error when use cub sorting")); return; } else { std::vector trans; for (int i = 0; i < axis; i++) { trans.emplace_back(i); } trans.emplace_back(in_dims.size() - 1); for (int i = axis + 1; i < in_dims.size() - 1; i++) { trans.emplace_back(i); } trans.emplace_back(axis); if (!keepdim) { std::vector tmp_out_shape; for (int i = 0; i < axis; i++) { tmp_out_shape.emplace_back(in_dims[i]); } tmp_out_shape.emplace_back(1); for (int i = axis + 1; i < in_dims.size(); i++) { tmp_out_shape.emplace_back(in_dims[i]); } framework::DDim tmp_out_dims = framework::make_ddim(tmp_out_shape); output->Resize(tmp_out_dims); indices->Resize(tmp_out_dims); } framework::DDim trans_dims(in_dims); framework::DDim trans_out_dims(in_dims); for (int i = 0; i < trans.size(); i++) { trans_dims[i] = in_dims[trans[i]]; trans_out_dims[i] = in_dims[trans[i]]; } trans_out_dims[in_dims.size() - 1] = 1; framework::Tensor trans_input; trans_input.mutable_data(trans_dims, ctx.GetPlace()); int ndims = trans.size(); const auto& dev_ctx = ctx.cuda_device_context(); TransCompute(ndims, dev_ctx, *input, &trans_input, trans); framework::Tensor trans_ind, trans_out; trans_ind.mutable_data(trans_out_dims, ctx.GetPlace()); trans_out.mutable_data(trans_out_dims, ctx.GetPlace()); const int64_t input_height = framework::product( framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); const int64_t input_width = trans_dims[trans_dims.size() - 1]; PADDLE_ENFORCE_EQ( SortKthvalue(dev_ctx, &trans_input, input_width, input_height, k, &trans_out, &trans_ind), true, platform::errors::External("KthvalueOP: Error when use cub sorting")); TransCompute( ndims, dev_ctx, trans_ind, indices, trans); TransCompute(ndims, dev_ctx, trans_out, output, trans); if (!keepdim) { output->Resize(out_dims); indices->Resize(out_dims); } } } }; template class KthvalueOpGradCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE_EQ( platform::is_gpu_place(context.GetPlace()), true, platform::errors::InvalidArgument( "It must use CUDAPlace, you must check your device set.")); auto* x = context.Input("X"); auto* out_grad = context.Input(framework::GradVarName("Out")); auto* indices = context.Input("Indices"); auto* x_grad = context.Output(framework::GradVarName("X")); int axis = context.Attr("axis"); int k = static_cast(context.Attr("k")); const auto& in_dims = x->dims(); auto out_dims = indices->dims(); if (axis < 0) axis += in_dims.size(); T* x_grad_data = x_grad->mutable_data(context.GetPlace()); const T* out_grad_data = out_grad->data(); const int64_t* indices_data = indices->data(); int pre, n, post; GetDims(in_dims, axis, &pre, &n, &post); auto& dev_ctx = context.cuda_device_context(); int block_size = getBlockSize(post * k); int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1); int grid_size = std::min(max_blocks, pre); AssignGradWithAxis<<>>( out_grad_data, indices_data, x_grad_data, pre, post, n, 1); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( kthvalue, ops::KthvalueOpCUDAKernel, ops::KthvalueOpCUDAKernel, ops::KthvalueOpCUDAKernel, ops::KthvalueOpCUDAKernel); REGISTER_OP_CUDA_KERNEL( kthvalue_grad, ops::KthvalueOpGradCUDAKernel, ops::KthvalueOpGradCUDAKernel, ops::KthvalueOpGradCUDAKernel, ops::KthvalueOpGradCUDAKernel);