// top_k_v2_op.cu (10.7 KB); last change committed by wawltor.
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_v2_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Expands to a single `case (dim):` label for use inside a `switch`.
// Within the case body, `kBlockDim` is a constexpr copy of `dim`, so the
// statement passed as the variadic argument can use it as a compile-time
// constant (e.g. as a template argument or a launch block size). The case
// ends with `break`; the terminating `;` is supplied by the user of the macro.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

// Emits four cases (256, 128, 64 and 32), each running the same statement
// with `kBlockDim` bound to that case's value. Must appear inside a `switch`;
// any other switch value has to be handled by the caller (e.g. a `default:`).
// `##__VA_ARGS__` is the GNU comma-pasting form of the variadic argument.
#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

// Forward CUDA kernel of the top_k_v2 operator.
//
// Inputs : "X" (values) and, optionally, "K" (a tensor whose first int
//          element overrides the "k" attribute).
// Outputs: "Out" (selected values) and "Indices" (int64 positions).
// Attrs  : "k", "axis" (negative values count from the back) and "largest".
//
// When `axis` is the last dimension the input is processed in place as a
// [input_height, input_width] matrix. Otherwise the selected axis is swapped
// with the last one through TransCompute, the top-k is computed on the
// transposed copy, and the results are transposed back into the outputs.
//
// NOTE(review): the "sorted" attribute is not consulted by this kernel; none
// of the calls below accept it. Confirm that this is intended.
template <typename DeviceContext, typename T>
class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::InvalidArgument(
            "It must use CUDAPlace, you must check your device set."));
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* indices = ctx.Output<Tensor>("Indices");

    // Read the attributes.
    int k = ctx.Attr<int>("k");
    int axis = ctx.Attr<int>("axis");
    const bool largest = ctx.Attr<bool>("largest");

    // Normalize a negative axis against the input rank.
    const auto& in_dims = input->dims();
    if (axis < 0) axis += in_dims.size();

    // An optional "K" tensor overrides the attribute. Its value is copied to
    // the host synchronously, and both outputs are resized along `axis`.
    auto* k_t = ctx.Input<Tensor>("K");
    if (k_t) {
      Tensor k_host;
      framework::TensorCopySync(*k_t, platform::CPUPlace(), &k_host);
      k = k_host.data<int>()[0];
      framework::DDim output_dims = output->dims();
      output_dims[axis] = k;
      output->Resize(output_dims);
      indices->Resize(output_dims);
    }

    const auto& out_dims = output->dims();

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());

    if (axis == in_dims.size() - 1) {
      // Top-k along the last axis: no transpose is required.
      const int64_t input_height = framework::product(
          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
      const int64_t input_width = in_dims[in_dims.size() - 1];
      const auto& dev_ctx = ctx.cuda_device_context();

      // k cannot exceed the number of candidates in a row.
      if (k > input_width) k = static_cast<int>(input_width);

      if (input_width <= 1024 || k >= 128 || k == input_width) {
        if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
                        indices, largest)) {
          // Succeeded, nothing left to do.
          return;
        } else {
          LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
                       "default topk kernel.";
        }
      }

      // NOTE: pass lds and dim same to input width.
      // NOTE: old matrix implementation of stride is different to eigen.
      LaunchMatrixTopK(dev_ctx, input_data, output_data, indices_data,
                       input_height, input_width, k, largest);
    } else {
      // Top-k along an inner axis: transpose so that `axis` becomes the last
      // dimension, compute the top-k there, then transpose the results back.

      // Step 1: build the permutation that swaps `axis` with the last axis.
      std::vector<int> trans;
      for (int i = 0; i < axis; i++) {
        trans.emplace_back(i);
      }
      trans.emplace_back(in_dims.size() - 1);
      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
        trans.emplace_back(i);
      }
      trans.emplace_back(axis);
      const int ndims = static_cast<int>(trans.size());

      framework::DDim trans_dims(in_dims);
      framework::DDim trans_out_dims(output->dims());
      for (int i = 0; i < ndims; i++) {
        trans_dims[i] = in_dims[trans[i]];
        trans_out_dims[i] = out_dims[trans[i]];
      }

      // Step 2: transpose the input into a temporary tensor.
      Tensor trans_input;
      trans_input.mutable_data<T>(trans_dims, ctx.GetPlace());
      const auto& dev_ctx = ctx.cuda_device_context();
      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
                                                   &trans_input, trans);

      // Step 3: allocate temporaries for the transposed results.
      Tensor trans_ind;
      trans_ind.mutable_data<int64_t>(trans_out_dims, ctx.GetPlace());
      Tensor trans_out;
      trans_out.mutable_data<T>(trans_out_dims, ctx.GetPlace());

      const int64_t input_height = framework::product(
          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
      const int64_t input_width = trans_dims[trans_dims.size() - 1];

      // k cannot exceed the number of candidates in a row.
      if (k > input_width) k = static_cast<int>(input_width);

      // Step 4: compute the top-k, trying the sort-based path first when the
      // problem size qualifies and falling back to KeMatrixTopK otherwise.
      bool computed = false;
      if ((input_width <= 1024 && input_height <= 2048) || k >= 128 ||
          k == input_width) {
        computed = SortTopk<T>(dev_ctx, &trans_input, input_width,
                               input_height, k, &trans_out, &trans_ind,
                               largest);
        if (!computed) {
          LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
                       "default topk kernel.";
        }
      }
      if (!computed) {
        LaunchMatrixTopK(dev_ctx, trans_input.data<T>(), trans_out.data<T>(),
                         trans_ind.data<int64_t>(), input_height, input_width,
                         k, largest);
      }

      // Step 5: transpose the indices and the values back into the outputs.
      TransCompute<platform::CUDADeviceContext, int64_t>(
          ndims, dev_ctx, trans_ind, indices, trans);
      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, trans_out,
                                                   output, trans);
    }
  }

 private:
  // Launches KeMatrixTopK on `dev_ctx`'s stream for an
  // [input_height, input_width] matrix. The grid size is input_height capped
  // at 2048 blocks; the block size is chosen by GetDesiredBlockDim and must be
  // one of the values FIXED_BLOCK_DIM handles, otherwise a Fatal error is
  // thrown.
  static void LaunchMatrixTopK(const platform::CUDADeviceContext& dev_ctx,
                               const T* input_data, T* output_data,
                               int64_t* indices_data, int64_t input_height,
                               int64_t input_width, int k, bool largest) {
    const int kMaxHeight = 2048;
    const int gridx = input_height < kMaxHeight
                          ? static_cast<int>(input_height)
                          : kMaxHeight;
    switch (GetDesiredBlockDim(input_width)) {
      FIXED_BLOCK_DIM(
          KeMatrixTopK<T, 5,
                       kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
              output_data, k, indices_data, input_data, input_width,
              input_width, static_cast<int>(k), gridx, input_height,
              largest));
      default:
        PADDLE_THROW(platform::errors::Fatal(
            "the input data shape has error in the topk cuda kernel."));
    }
  }
};

#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
// Backward CUDA kernel of the top_k_v2 operator.
//
// Inputs : "X", the gradient of "Out", and "Indices" (int64).
// Output : the gradient of "X".
// Attr   : "axis" (negative values count from the back).
//
// The input shape is split by GetDims into (pre, n, post) around `axis`, and
// AssignGradWithAxis is launched on the device context's stream to fill the
// gradient of "X" from the gradient of "Out" and the indices.
template <typename DeviceContext, typename T>
class TopkV2OpGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::InvalidArgument(
            "It must use CUDAPlace, you must check your device set."));
    auto* x = context.Input<Tensor>("X");
    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* indices = context.Input<Tensor>("Indices");
    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
    int axis = context.Attr<int>("axis");

    const auto& in_dims = x->dims();
    const auto& out_dims = indices->dims();

    // Normalize the axis; k is the extent of "Indices" along that axis.
    if (axis < 0) axis += in_dims.size();
    const int k = static_cast<int>(out_dims[axis]);

    // Allocate the output gradient on the current place.
    T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
    const T* out_grad_data = out_grad->data<T>();
    const int64_t* indices_data = indices->data<int64_t>();

    int pre, n, post;
    GetDims(in_dims, axis, &pre, &n, &post);

    // Choose the launch configuration: the block size is picked from
    // {64, 128, 256, 512, 1024} as the smallest entry that is >= col
    // (1024 for anything above 512).
    auto& dev_ctx = context.cuda_device_context();
    auto ComputeBlockSize = [](int col) {
      if (col > 512) return 1024;
      if (col > 256) return 512;
      if (col > 128) return 256;
      if (col > 64) return 128;
      return 64;
    };
    const int block_size = ComputeBlockSize(post * k);
    const int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    // Ceil-divide the physical thread count by the block size, at least 1.
    const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1);
    const int grid_size = std::min(max_blocks, pre);

    // Launch the kernel that assigns the gradient; 64 * 4 bytes of dynamic
    // shared memory are requested for the launch.
    AssignGradWithAxis<T><<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
        out_grad_data, indices_data, x_grad_data, pre, post, n, k);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward kernels of top_k_v2 for float, double, int, int64_t and float16.
REGISTER_OP_CUDA_KERNEL(
    top_k_v2, ops::TopkV2OpCUDAKernel<plat::CUDADeviceContext, float>,
    ops::TopkV2OpCUDAKernel<plat::CUDADeviceContext, double>,
    ops::TopkV2OpCUDAKernel<plat::CUDADeviceContext, int>,
    ops::TopkV2OpCUDAKernel<plat::CUDADeviceContext, int64_t>,
    ops::TopkV2OpCUDAKernel<plat::CUDADeviceContext, plat::float16>);

// Backward kernels of top_k_v2 for the same element types.
REGISTER_OP_CUDA_KERNEL(
    top_k_v2_grad,
    ops::TopkV2OpGradCUDAKernel<plat::CUDADeviceContext, float>,
    ops::TopkV2OpGradCUDAKernel<plat::CUDADeviceContext, double>,
    ops::TopkV2OpGradCUDAKernel<plat::CUDADeviceContext, int>,
    ops::TopkV2OpGradCUDAKernel<plat::CUDADeviceContext, int64_t>,
    ops::TopkV2OpGradCUDAKernel<plat::CUDADeviceContext, plat::float16>);