/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

#include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_op.h"
// NOTE(review): the original comment here read "set cub base traits in order
// to handle float16"; presumably those traits come from the included headers
// (e.g. top_k_function_cuda.h) — confirm.
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Expands to a single `case` of a switch over the block size returned by
// GetDesiredBlockDim(). Inside the case, `kBlockDim` is a compile-time
// constant equal to `dim`, so it can be used as a template argument and as
// the launch block size.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

// Expands to the cases for every block size supported here (256, 128, 64,
// 32). Any other value reaches the caller's `default:` label.
#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

// GPU kernel for the `top_k` operator.
//
// Inputs:  X — values; the input is viewed as a [input_height, input_width]
//              matrix whose width is the last dimension of X.
//          K — optional tensor; its first int element overrides the `k`
//              attribute, and Out/Indices are resized so their last
//              dimension equals that k.
// Outputs: Out     — selected values (type T).
//          Indices — int64 positions of the selected values.
//
// Strategy: for narrow rows or large k, the cub-based SortTopk is tried
// first. If it reports failure (or is not attempted), the hand-written
// KeMatrixTopK kernel is launched on the device context's stream.
template <typename DeviceContext, typename T>
class TopkOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* indices = ctx.Output<Tensor>("Indices");
    int k = static_cast<int>(ctx.Attr<int>("k"));

    auto* k_t = ctx.Input<Tensor>("K");
    if (k_t) {
      // Copy K to the host synchronously so its value can be read here.
      Tensor k_host;
      framework::TensorCopySync(*k_t, platform::CPUPlace(), &k_host);
      k = k_host.data<int>()[0];
      // Reject non-positive runtime K values before they become an output
      // dimension and a kernel argument.
      PADDLE_ENFORCE_GE(
          k, 1, platform::errors::InvalidArgument(
                    "The value of input K must be >= 1, but got K = %d.", k));
      framework::DDim output_dims = output->dims();
      output_dims[output_dims.size() - 1] = k;
      output->Resize(output_dims);
      indices->Resize(output_dims);
    }

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // FIXME(typhoonzero): data is always converted to type T?

    // Flatten every leading dimension into the row count; the last
    // dimension is the row width that top-k is taken over.
    framework::DDim inputdims = input->dims();
    const int64_t input_height = framework::product(
        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
    const int64_t input_width = inputdims[inputdims.size() - 1];
    const auto& dev_ctx = ctx.cuda_device_context();

    if (input_width <= 1024 || k >= 128 || k == input_width) {
      if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
                      indices)) {
        // Succeeded, return.
        return;
      } else {
        LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
                     "default topk kernel.";
      }
    }
    int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
    if (k > input_width) k = input_width;

    // An empty batch has nothing to select; launching with a zero-sized grid
    // would be an invalid launch configuration.
    if (input_height == 0) return;

    // NOTE: pass lds and dim same to input width.
    // NOTE: old matrix implementation of stride is different to eigen.
    // TODO(typhoonzero): refine this kernel.
    const int kMaxHeight = 2048;
    // Cap the grid at kMaxHeight blocks; gridx and input_height are both
    // passed to the kernel so it can cover every row.
    const int gridx =
        static_cast<int>(std::min<int64_t>(input_height, kMaxHeight));
    switch (GetDesiredBlockDim(input_width)) {
      FIXED_BLOCK_DIM(
          KeMatrixTopK<T, 5,
                       kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
              output_data, k, indices_data, input_data, input_width,
              input_width, static_cast<int>(k), gridx, input_height));
      default:
        PADDLE_THROW(platform::errors::Unavailable(
            "Calculation error occurred in TopK Operator's CUDA Kernel."));
    }
  }
};

// GPU kernel for the `top_k_grad` operator.
//
// Reads X (for its shape), Out@GRAD and Indices, and writes X@GRAD by
// launching AssignGrad over a [row, col] view of X. Here col is the last
// dimension of X and k is the last dimension of Indices. AssignGrad
// presumably scatters each Out@GRAD value to the column recorded in Indices
// and zeros the rest — see top_k_function_cuda.h to confirm.
template <typename DeviceContext, typename T>
class TopkOpGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));
    auto* x = context.Input<Tensor>("X");
    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* indices = context.Input<Tensor>("Indices");
    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));

    T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
    const T* out_grad_data = out_grad->data<T>();
    const int64_t* indices_data = indices->data<int64_t>();
    size_t k = indices->dims()[indices->dims().size() - 1];

    framework::DDim xdims = x->dims();
    const size_t row =
        framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1));
    const size_t col = xdims[xdims.size() - 1];

    // An empty batch has no gradient to assign; a zero-sized grid would be
    // an invalid launch configuration.
    if (row == 0) return;

    const auto& dev_ctx = context.cuda_device_context();
    const int kMaxHeight = 2048;
    const int gridx = static_cast<int>(
        std::min<size_t>(row, static_cast<size_t>(kMaxHeight)));
    switch (GetDesiredBlockDim(col)) {
      FIXED_BLOCK_DIM(
          AssignGrad<T, 5,
                     kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
              x_grad_data, indices_data, out_grad_data, row, col, k));
      default:
        PADDLE_THROW(
            platform::errors::Unavailable("Error occurs when Assign Grad."));
    }
  }
};

#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM

}  // namespace operators
}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(
    top_k,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        float>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        double>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        int>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        int64_t>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::float16>);

REGISTER_OP_CUDA_KERNEL(
    top_k_grad,
    paddle::operators::TopkOpGradCUDAKernel<
        paddle::platform::CUDADeviceContext, float>,
    paddle::operators::TopkOpGradCUDAKernel<
        paddle::platform::CUDADeviceContext, double>,
    paddle::operators::TopkOpGradCUDAKernel<
        paddle::platform::CUDADeviceContext, int>,
    paddle::operators::TopkOpGradCUDAKernel<
        paddle::platform::CUDADeviceContext, int64_t>,
    paddle::operators::TopkOpGradCUDAKernel<
        paddle::platform::CUDADeviceContext, paddle::platform::float16>);