// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_sample_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/pten/kernels/funcs/math_function.h"

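// Upper bounds for the 2-D thread block used below: at most 512 threads along
// x (columns of the index tensor) and at most 1024 threads per block in total.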
#define PREDEFINED_BLOCK_SIZE_X 512
#define PREDEFINED_BLOCK_SIZE 1024
#define MIN(a, b) ((a) < (b) ? (a) : (b))

namespace paddle {
namespace operators {

namespace {
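// Clamp a launch grid to the device's maximum grid dimensions; the kernels
// below use grid-stride loops, so a clamped grid still covers all elements.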
void LimitGridDim(const framework::ExecutionContext& ctx, dim3* grid_dim) {
  auto max_grid_dim = ctx.template device_context<platform::CUDADeviceContext>()
                          .GetCUDAMaxGridDimSize();
  grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
  grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
}
}  // namespace

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

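// Forward kernel: out_data[j][i] = in_data[j][index[j][i]] for every row j of
// the batch and every column i of the index tensor.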
template <typename T, typename IndexT = int>
__global__ void IndexSampleForward(const IndexT* index, const T* in_data,
                                   T* out_data, size_t index_length,
                                   size_t input_length, size_t batch_size) {
  unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
  unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
  for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
    // Reset the column index for every batch row this thread handles;
    // otherwise only the first grid-stride pass over rows does any work.
    index_i = blockDim.x * blockIdx.x + threadIdx.x;
    for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
      unsigned int index_idx = index_j * index_length + index_i;
      unsigned int in_idx = index_j * input_length + index_i;
      IndexT sample_idx = index[index_idx];
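      // in_idx - index_i + sample_idx == index_j * input_length + sample_idx,
      // i.e. the sampled element in row index_j of the input.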
      out_data[index_idx] = in_data[in_idx - index_i + sample_idx];
    }
  }
}

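// Backward kernel: scatters out_grad back into in_grad at the sampled
// positions. When a row of the index tensor can reference the same input
// column more than once (index_length > 1), atomicAdd is used to accumulate.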
template <typename T, typename IndexT = int>
__global__ void IndexSampleGrad(const IndexT* index, T* in_grad,
                                const T* out_grad, size_t index_length,
                                size_t input_length, size_t batch_size,
                                bool same_data_in_row = true) {
  unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
  unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;

  for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
    // Reset the column index for every batch row this thread handles;
    // otherwise only the first grid-stride pass over rows does any work.
    index_i = blockDim.x * blockIdx.x + threadIdx.x;
    for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
      unsigned int index_idx = index_j * index_length + index_i;
      unsigned int in_idx = index_j * input_length + index_i;
      IndexT sample_idx = index[index_idx];
      if (same_data_in_row) {
        platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]),
                                out_grad[index_idx]);
      } else {
        in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx];
      }
    }
  }
}

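// CUDA kernel launcher for index_sample: X is [batch_size, input_length],
// Index and Out are [batch_size, index_length].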
template <typename T>
class IndexSampleKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("X");
    auto* index = ctx.Input<LoDTensor>("Index");
    auto* output = ctx.Output<LoDTensor>("Out");

    const auto& index_type = framework::TransToProtoVarType(index->dtype());
    bool index_type_match = index_type == framework::proto::VarType::INT64 ||
                            index_type == framework::proto::VarType::INT32;
    PADDLE_ENFORCE_EQ(index_type_match, true,
                      platform::errors::InvalidArgument(
                          "Input(Index) holds the wrong type, it holds %s, but "
                          "desires to be %s or %s",
                          paddle::framework::DataTypeToString(index_type),
                          paddle::framework::DataTypeToString(
                              framework::proto::VarType::INT32),
                          paddle::framework::DataTypeToString(
                              framework::proto::VarType::INT64)));
    const auto* in_data = input->data<T>();
    auto* out_data = output->mutable_data<T>(ctx.GetPlace());
    auto stream =
        ctx.template device_context<platform::CUDADeviceContext>().stream();

    auto input_dim = input->dims();
    auto index_dim = index->dims();
    size_t batch_size = input_dim[0];
    size_t input_length = input_dim[1];
    size_t index_length = index_dim[1];

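    // Choose a 2-D block: x covers index columns (capped at 512), y covers
    // batch rows, with x * y capped at 1024 threads per block.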
    auto block_width = platform::RoundToPowerOfTwo(index_length);
    block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X);
    int block_height =
        platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;
    block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width);
    dim3 block_dim(block_width, block_height);
    dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
                  (batch_size + block_dim.y - 1) / block_dim.y);
    LimitGridDim(ctx, &grid_dim);

    if (index_type == framework::proto::VarType::INT64) {
      const int64_t* index_data = index->data<int64_t>();
      IndexSampleForward<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
          index_data, in_data, out_data, index_length, input_length,
          batch_size);
    } else if (index_type == framework::proto::VarType::INT32) {
      const int* index_data = index->data<int>();
      IndexSampleForward<T, int><<<grid_dim, block_dim, 0, stream>>>(
          index_data, in_data, out_data, index_length, input_length,
          batch_size);
    }
  }
};

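// CUDA kernel launcher for index_sample_grad: computes Input(X)@GRAD from
// Output(Out)@GRAD and the Index tensor.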
template <typename T>
class IndexSampleGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* output_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* input_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    auto* index = ctx.Input<LoDTensor>("Index");

    const auto* output_grad_data = output_grad->data<T>();
    auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());

    const auto& index_type = framework::TransToProtoVarType(index->dtype());
    bool index_type_match = index_type == framework::proto::VarType::INT64 ||
                            index_type == framework::proto::VarType::INT32;
    PADDLE_ENFORCE_EQ(index_type_match, true,
                      platform::errors::InvalidArgument(
                          "Input(Index) holds the wrong type, it holds %s, but "
                          "desires to be %s or %s",
                          paddle::framework::DataTypeToString(index_type),
                          paddle::framework::DataTypeToString(
                              framework::proto::VarType::INT32),
                          paddle::framework::DataTypeToString(
                              framework::proto::VarType::INT64)));

    auto stream =
        ctx.template device_context<platform::CUDADeviceContext>().stream();
    auto input_num = input_grad->numel();
    auto input_dim = input_grad->dims();
    auto index_dim = index->dims();
    size_t batch_size = index_dim[0];
    size_t input_length = input_dim[1];
    size_t index_length = index_dim[1];
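    // With more than one index column, two columns in a row may point at the
    // same input element, so the gradient must be accumulated atomically.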
    bool same_data_in_index_row = index_length != 1;

    auto block_width = platform::RoundToPowerOfTwo(index_length);
    block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X);
    auto block_height =
        platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;
    block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width);
    dim3 block_dim(block_width, block_height);
    dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
                  (batch_size + block_dim.y - 1) / block_dim.y);
    LimitGridDim(ctx, &grid_dim);

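    // Zero-initialize the input gradient; the kernel only writes (or
    // accumulates into) the positions that were actually sampled.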
    pten::funcs::SetConstant<platform::CUDADeviceContext, T> set_zero;
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    set_zero(dev_ctx, input_grad, static_cast<T>(0));

    if (index_type == framework::proto::VarType::INT64) {
      const int64_t* index_data = index->data<int64_t>();
      IndexSampleGrad<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
          index_data, input_grad_data, output_grad_data, index_length,
          input_length, batch_size, same_data_in_index_row);
    } else if (index_type == framework::proto::VarType::INT32) {
      const int* index_data = index->data<int>();
      IndexSampleGrad<T, int><<<grid_dim, block_dim, 0, stream>>>(
          index_data, input_grad_data, output_grad_data, index_length,
          input_length, batch_size, same_data_in_index_row);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    index_sample,
    ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, float>,
    ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, double>,
    ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int>,
    ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    index_sample_grad,
    ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int64_t>);