From 5dbe9e597c4550d1d95cdbe51911e2a3881f8ff1 Mon Sep 17 00:00:00 2001 From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com> Date: Mon, 2 Dec 2019 16:49:15 +0800 Subject: [PATCH] [cherry-pick] Improve topk performance. (#21087) (#21441) * Improve topk performance. give 200000 data to compute topk, before opt: cost 1s after opt: cost 0.0028s. * Refine return value. * Add cuda util funtions. * Fix ComputeBlockSize bug & refine comments. Signed-off-by: zhaoyuchen --- paddle/fluid/operators/top_k_op.cc | 5 + paddle/fluid/operators/top_k_op.cu | 185 +++++++++++++++++++++++- paddle/fluid/platform/device_context.cc | 5 + paddle/fluid/platform/device_context.h | 4 + paddle/fluid/platform/gpu_info.cc | 27 ++++ paddle/fluid/platform/gpu_info.h | 3 + 6 files changed, 223 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index fdf5148eb87..e394871cccf 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -42,6 +42,11 @@ class TopkOp : public framework::OperatorWithKernel { framework::DDim dims = input_dims; dims[dims.size() - 1] = k; + // If has K as tensor, set k=-1 as not know real size at this time. + if (ctx->HasInput("K")) { + dims[dims.size() - 1] = -1; + } + ctx->SetOutputDim("Out", dims); ctx->SetOutputDim("Indices", dims); ctx->ShareLoD("X", "Out"); diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index fe243a3b87d..0375611dfb5 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -12,11 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "cub/cub.cuh" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/float16.h" +// set cub base traits in order to handle float16 +namespace cub { +template <> +struct NumericTraits + : BaseTraits {}; +} // namespace cub + namespace paddle { namespace operators { @@ -303,6 +312,160 @@ inline static int GetDesiredBlockDim(int dim) { } } +// Iter for move to next row +struct SegmentOffsetIter { + EIGEN_DEVICE_FUNC + explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const { + return idx * num_cols_; + } + + int num_cols_; +}; + +// Iter using into a column +struct ColumnIndexIter { + explicit ColumnIndexIter(int num_cols) : num_cols_(num_cols) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()( + const Eigen::array& ix) const { + return ix[0] % num_cols_; + } + + int num_cols_; +}; + +__global__ void InitIndex(int64_t* indices, int64_t num_rows, + int64_t num_cols) { + int col_id = threadIdx.x; + int row_id = blockIdx.x; + + for (int64_t j = row_id; j < num_rows; j += gridDim.x) { + for (int64_t i = col_id; i < num_cols; i += blockDim.x) { + indices[j * num_cols + i] = i; + } + } +} + +template +bool SortTopk(const platform::CUDADeviceContext& ctx, + const framework::Tensor* input_tensor, const int64_t num_cols, + const int64_t num_rows, const int k, + framework::Tensor* out_tensor, + framework::Tensor* indices_tensor) { + auto cu_stream = ctx.stream(); + + Tensor input_indices; + const std::vector dims = {num_rows, num_cols}; + auto dim = framework::make_ddim(dims); + input_indices.Resize(dim); + // input_indices.Resize(num_rows*num_cols); + input_indices.mutable_data(ctx.GetPlace()); + size_t temp_storage_bytes = -1; + + auto ComputeBlockSize = [](int col) { + if (col > 512) + return 1024; + else if (col > 256 && col <= 512) + return 512; + else if (col > 128 && col <= 256) + return 256; + else if (col > 64 && col <= 128) + return 128; + else + return 64; + }; + + int block_size = ComputeBlockSize(num_cols); + + unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x; + // actually, int num_rows < max_grid_size + unsigned int grid_size = num_rows < maxGridDimX + ? static_cast(num_rows) + : maxGridDimX; + // Init a index array + InitIndex<<>>( + input_indices.data(), num_rows, num_cols); + + // create iter for counting input + cub::CountingInputIterator counting_iter(0); + // segment_offset is used for move to next row + cub::TransformInputIterator> + segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols)); + + T* sorted_values_ptr; + int64_t* sorted_indices_ptr; + + Tensor temp_values; + Tensor temp_indices; + + const T* input = input_tensor->data(); + T* values = out_tensor->data(); + int64_t* indices = indices_tensor->mutable_data(ctx.GetPlace()); + + if (k == num_cols) { + // Doing a full sort. + sorted_values_ptr = values; + sorted_indices_ptr = indices; + } else { + temp_values.Resize(dim); + temp_indices.Resize(dim); + sorted_values_ptr = temp_values.mutable_data(ctx.GetPlace()); + sorted_indices_ptr = temp_indices.mutable_data(ctx.GetPlace()); + } + + // Get temp storage buffer size, maybe can allocate a fixed buffer to save + // time. + auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, input, sorted_values_ptr, + input_indices.data(), sorted_indices_ptr, num_cols * num_rows, + num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, + cu_stream); + if (err != cudaSuccess) { + LOG(ERROR) + << "TopKOP failed as could not launch " + "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate " + "temp_storage_bytes, status: " + << cudaGetErrorString(err); + return false; + } + Tensor temp_storage; + temp_storage.mutable_data(ctx.GetPlace(), temp_storage_bytes); + + err = cub::DeviceSegmentedRadixSort::SortPairsDescending( + temp_storage.data(), temp_storage_bytes, input, + sorted_values_ptr, input_indices.data(), sorted_indices_ptr, + num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1, + 0, sizeof(T) * 8, cu_stream); + if (err != cudaSuccess) { + LOG(ERROR) + << "TopKOP failed as could not launch " + "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, " + "temp_storage_bytes: " + << temp_storage_bytes << ", status: " << cudaGetErrorString(err); + return false; + } + auto& dev = *ctx.eigen_device(); + if (k < num_cols) { + // copy sliced data to output. + const Eigen::DSizes slice_indices{0, 0}; + const Eigen::DSizes slice_sizes{num_rows, k}; + auto e_indices = EigenMatrix::From(*indices_tensor, dim); + auto e_tmp_indices = EigenMatrix::From(temp_indices); + + std::vector odims = {static_cast(num_rows), static_cast(k)}; + auto dim = framework::make_ddim(odims); + auto e_values = EigenMatrix::From(*out_tensor, dim); + auto e_tmp_values = EigenMatrix::From(temp_values); + + e_indices.device(dev) = e_tmp_indices.slice(slice_indices, slice_sizes); + e_values.device(dev) = e_tmp_values.slice(slice_indices, slice_sizes); + } + return true; +} + #define FIXED_BLOCK_DIM_BASE(dim, ...) \ case (dim): { \ constexpr auto kBlockDim = (dim); \ @@ -324,7 +487,7 @@ class TopkOpCUDAKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); - size_t k = static_cast(ctx.Attr("k")); + int k = static_cast(ctx.Attr("k")); auto* k_t = ctx.Input("K"); if (k_t) { @@ -340,13 +503,24 @@ class TopkOpCUDAKernel : public framework::OpKernel { const T* input_data = input->data(); T* output_data = output->mutable_data(ctx.GetPlace()); // FIXME(typhoonzero): data is always converted to type T? - int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); framework::DDim inputdims = input->dims(); - const size_t input_height = framework::product( + const int64_t input_height = framework::product( framework::slice_ddim(inputdims, 0, inputdims.size() - 1)); - const size_t input_width = inputdims[inputdims.size() - 1]; - + const int64_t input_width = inputdims[inputdims.size() - 1]; + const auto& dev_ctx = ctx.cuda_device_context(); + + if ((input_width <= 1024 || k >= 128 || k == input_width)) { + if (SortTopk(dev_ctx, input, input_width, input_height, k, output, + indices)) { + // Successed, return. + return; + } else { + LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use " + "default topk kernel."; + } + } + int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); if (k > input_width) k = input_width; // NOTE: pass lds and dim same to input width. @@ -354,7 +528,6 @@ class TopkOpCUDAKernel : public framework::OpKernel { // TODO(typhoonzero): refine this kernel. const int kMaxHeight = 2048; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; - auto& dev_ctx = ctx.cuda_device_context(); switch (GetDesiredBlockDim(input_width)) { FIXED_BLOCK_DIM( KeMatrixTopKReinitialize(&stream_, place); @@ -351,6 +352,10 @@ bool CUDADeviceContext::tensor_core_available() const { return cublas_tensor_core_handle_ != nullptr; } +dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const { + return max_grid_dim_size_; +} + cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 4fec2e9350f..935de171159 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -96,6 +96,9 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return the max physical thread count in the device context */ int GetMaxPhysicalThreadCount() const; + /*! \brief Return the max grid dim size in the device context */ + dim3 GetCUDAMaxGridDimSize() const; + /*! \brief Return eigen device in the device context. */ Eigen::GpuDevice* eigen_device() const; @@ -184,6 +187,7 @@ class CUDADeviceContext : public DeviceContext { int driver_version_; int multi_process_; int max_threads_per_mp_; + dim3 max_grid_dim_size_; // StreamCallbackManager is thread-safe std::unique_ptr callback_manager_; diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index add5cabd444..c68ce998f36 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -85,6 +85,33 @@ int GetCUDAComputeCapability(int id) { return device_prop.major * 10 + device_prop.minor; } +dim3 GetGpuMaxGridDimSize(int id) { + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); + dim3 ret; + int size; + auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id); + PADDLE_ENFORCE_EQ(error_code_x, 0, + "cudaDevAttrMaxGridDimX failed in " + "paddle::platform::GpuMaxGridDimSize, error code : %d, %s", + error_code_x, CudaErrorWebsite()); + ret.x = size; + + auto error_code_y = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimY, id); + PADDLE_ENFORCE_EQ(error_code_y, 0, + "cudaDevAttrMaxGridDimY failed in " + "paddle::platform::GpuMaxGridDimSize, error code : %d, %s", + error_code_y, CudaErrorWebsite()); + ret.y = size; + + auto error_code_z = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimZ, id); + PADDLE_ENFORCE_EQ(error_code_z, 0, + "cudaDevAttrMaxGridDimZ failed in " + "paddle::platform::GpuMaxGridDimSize, error code : %d, %s", + error_code_z, CudaErrorWebsite()); + ret.z = size; + return ret; +} + int GetCUDARuntimeVersion(int id) { PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); int runtime_version = 0; diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index e468c4aab0b..78ad1f7b3ed 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -48,6 +48,9 @@ int GetCUDAMaxThreadsPerMultiProcessor(int i); //! Get the current GPU device id in system. int GetCurrentDeviceId(); +//! Get the maximum GridDim size for GPU buddy allocator. +dim3 GetGpuMaxGridDimSize(int); + //! Get a list of device ids from environment variable or use all. std::vector GetSelectedDevices(); -- GitLab