// arg_min_max_op_base.cu.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef __NVCC__

#include <cub/cub.cuh>
#include <limits>
#include <string>
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

// File-local shorthand aliases used by the arg-min/arg-max code below.
namespace {  // NOLINT
// Alias for cub::KeyValuePair; in this file the key carries the position
// along the reduced axis and the value carries the element at that position.
template <typename K, typename V>
using KeyValuePair = cub::KeyValuePair<K, V>;
// Alias for the framework tensor type.
using Tensor = framework::Tensor;

}  // end namespace

// Emits one `case` of a switch over a power-of-two block size.
//
// The case label is (1 << log2_block_dim). Inside the case a compile-time
// constant `kBlockDim` with that same value is declared, so the statement(s)
// passed as the variadic arguments can use `kBlockDim` as a template argument
// and as the launch block size. The case ends with `break` (without a trailing
// semicolon, which the invocation site supplies).
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

// Expands to the full list of switch cases for block sizes 1024, 512, 256,
// 128, 64, 32, 16 and 8 (log2 from 10 down to 3), each running the same
// statement(s) with `kBlockDim` bound to the matching constant. Intended to be
// used as the body of a `switch` over a runtime block size.
//
// Note: `##__VA_ARGS__` is the GNU comma-elision form of __VA_ARGS__.
#define FIXED_BLOCK_DIM_CASE(...)               \
  FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);

// Arg-reduction kernel: for every output slot, finds the position along the
// reduced axis that `reducer` selects and stores it in `out`.
//
// The input is addressed as a [height / post_size, width, post_size] array:
// output slot `idx` maps to (h, w) = (idx / post_size, idx % post_size) and
// reduces over in[h * width * post_size + k * post_size + w] for k in
// [0, width).
//
// Parallelisation: blocks walk the output slots with a stride of gridDim.x;
// the threads of a block walk the reduced axis with a stride of blockDim.x and
// their partial results are combined with cub::BlockReduce, whose temporary
// storage lives in static shared memory. The host launcher in this file
// launches with blockDim.x equal to the BlockDim template argument.
//
// Parameters:
//   height    - number of output slots (caller passes pre * post).
//   width     - extent of the reduced axis.
//   post_size - number of elements following the reduced axis.
//   reducer   - binary functor combining two KeyValuePair<int, T> values.
//   init      - value used to seed every thread's running pair {-1, init}.
//   in        - input data.
//   out       - output indices, one per output slot.
//
// NOTE(review): the reduction key and the inner counter `k` are still `int`,
// so `width` itself is assumed to fit in a 32-bit int - confirm for the
// flatten path on very large tensors.
template <typename T, typename IndType, class Reducer, size_t BlockDim>
__global__ void ArgCUDAKernel(const int64_t height,     // pre * post
                              const int64_t width,      // reduced axis extent
                              const int64_t post_size,  // post
                              const Reducer reducer, const T init, const T* in,
                              IndType* out) {
  typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // 64-bit slot counter: `height` is int64_t, and a 32-bit counter would wrap
  // (and `h` / `w` would be truncated) once height exceeds INT_MAX.
  for (int64_t idx = blockIdx.x; idx < height; idx += gridDim.x) {
    KeyValuePair<int, T> kv_pair = {-1, init};
    const int64_t h = idx / post_size;
    const int64_t w = idx % post_size;
    // Offset of element k == 0 for this slot; invariant over the inner loop.
    const int64_t base = h * width * post_size + w;
    for (int k = threadIdx.x; k < width; k += blockDim.x) {
      kv_pair = reducer({k, in[base + k * post_size]}, kv_pair);
    }
    kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
    if (threadIdx.x == 0) {
      out[idx] = static_cast<IndType>(kv_pair.key);
    }
    // temp_storage is reused by the next iteration; wait until every thread
    // is done with it.
    __syncthreads();
  }
}

// Launches ArgCUDAKernel on the context's stream to compute arg-min/arg-max
// indices of `input` into `indices`.
//
// Parameters:
//   ctx     - CUDA device context supplying the stream, the maximum grid
//             size and the place used to allocate the output.
//   input   - tensor read as elements of type T.
//   indices - output tensor; allocated here as IndType.
//   pre     - product of the dimensions before the reduced axis.
//   post    - product of the dimensions after the reduced axis.
//   n       - extent of the reduced axis.
//
// One block handles one output slot at a time (pre * post slots in total),
// with the grid capped at the device's maximum x dimension. The block size is
// the smallest power of two >= n, clamped to [8, 1024].
template <typename T, typename IndType, class Reducer>
void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
                    Tensor* indices, const int64_t pre, const int64_t post,
                    const int64_t n) {
  auto cu_stream = ctx.stream();

  // Smallest power of two that is >= col, clamped to the range [8, 1024].
  auto ComputeBlockSize = [](int64_t col) {
    int block = 8;
    while (block < 1024 && block < col) {
      block <<= 1;
    }
    return block;
  };

  const int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
  const int64_t height = pre * post;
  const int64_t width = n;
  const int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;

  const T* in_data = input.data<T>();
  IndType* out_data = indices->mutable_data<IndType>(ctx.GetPlace());

  // No output slots: a grid with zero blocks is an invalid launch
  // configuration, so skip the launch entirely.
  if (grid_size <= 0) {
    return;
  }

  // Seed value for the reduction: the lowest representable T for arg-max,
  // the largest representable T otherwise.
  const T init = (typeid(Reducer) == typeid(cub::ArgMax))
                     ? std::numeric_limits<T>::lowest()
                     : std::numeric_limits<T>::max();

  switch (ComputeBlockSize(width)) {
    FIXED_BLOCK_DIM_CASE(
        ArgCUDAKernel<T, IndType, Reducer,
                      kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
            height, width, post, Reducer(), init, in_data, out_data));
  }
}

// Functor handed to framework::VisitDataType: `apply<IndType>()` is invoked
// with the index type selected at runtime and runs the CUDA arg-min/arg-max
// computation for the operator described by `ctx`.
template <typename T, class Reducer>
struct VisitDataCudaArgMinMaxFunctor {
  const framework::ExecutionContext& ctx;

  explicit VisitDataCudaArgMinMaxFunctor(const framework::ExecutionContext& ctx)
      : ctx(ctx) {}

  // Reads input "X", output "Out" and the "axis" / "flatten" attributes,
  // splits the input shape into (pre, n, post) around the reduced axis and
  // dispatches to ComputeFullArg with IndType as the output index type.
  template <typename IndType>
  void apply() const {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    int axis = ctx.Attr<int64_t>("axis");
    const bool& flatten = ctx.Attr<bool>("flatten");

    framework::DDim input_dims;
    if (flatten) {
      // When flattening, the input is treated as 1-D and reduced along axis 0.
      input_dims = framework::make_ddim({input->numel()});
      axis = 0;
    } else {
      input_dims = input->dims();
      // A negative axis counts from the last dimension.
      if (axis < 0) axis += input->dims().size();
    }

    // Extent of the reduced axis.
    const int64_t n = input_dims[axis];

    // Product of the dimensions before the reduced axis.
    int64_t pre = 1;
    for (int i = 0; i < axis; i++) {
      pre *= input_dims[i];
    }

    // Product of the dimensions after the reduced axis.
    int64_t post = 1;
    for (int i = axis + 1; i < input_dims.size(); i++) {
      post *= input_dims[i];
    }

    const auto& dev_ctx = ctx.cuda_device_context();
    ComputeFullArg<T, IndType, Reducer>(dev_ctx, *input, output, pre, post, n);
  }
};
// CUDA operator kernel for arg-min / arg-max. The "dtype" attribute chooses
// the element type of the produced indices; a negative value selects INT64.
template <typename T, class Reducer>
class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
 public:
  // Resolves the index type from the "dtype" attribute and dispatches to
  // VisitDataCudaArgMinMaxFunctor through framework::VisitDataType.
  void Compute(const framework::ExecutionContext& ctx) const override {
    using VarType = framework::proto::VarType;

    const auto& requested = ctx.Attr<int>("dtype");

    // Default to INT64; override only when the attribute is non-negative.
    VarType::Type index_type = VarType::INT64;
    if (requested >= 0) {
      index_type = static_cast<VarType::Type>(requested);
    }

    framework::VisitDataType(index_type,
                             VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
  }
};

}  // namespace operators
}  // namespace paddle

// The namespaces above are opened inside the `#ifdef __NVCC__` region, so
// they must also be closed inside it; otherwise the closing braces are
// unmatched whenever __NVCC__ is not defined.
#endif  // __NVCC__