roll_op.cu 7.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

1
123malin 已提交
15
#pragma once
S
sunli 已提交
16
#include "paddle/fluid/framework/array.h"
1
123malin 已提交
17
#include "paddle/fluid/framework/op_registry.h"
18
#include "paddle/fluid/operators/roll_op.h"
1
123malin 已提交
19 20 21 22 23 24 25 26 27
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

S
sunli 已提交
28 29 30 31 32
// Element-wise circular shift ("roll") of a tensor along up to `Rank` axes.
//
// Launch layout: 1-D grid of 1-D blocks providing at least `N` threads in
// total. Thread `idx` reads the input element at flat offset `idx` and writes
// it to its rolled flat offset in `output`; threads with idx >= N exit
// immediately, so the grid may overshoot N. No shared memory is used.
//
// Per rolled axis i:
//   strides[i] - element stride of that axis (from framework::stride on the
//                host, or 1 for the flattened case),
//   sizes[i]   - extent of that axis,
//   shifts[i]  - shift the host has already normalized into [0, sizes[i]).
// The kernel relies on that normalization: it handles at most one
// wrap-around per axis.
template <typename T, size_t Rank>
__global__ void RollCudaKernel(const T* input, T* output, int64_t N,
                               paddle::framework::Array<int64_t, Rank> shifts,
                               paddle::framework::Array<int64_t, Rank> strides,
                               paddle::framework::Array<int64_t, Rank> sizes) {
  // Widen before multiplying: blockIdx.x * blockDim.x alone is computed in
  // 32-bit unsigned arithmetic and wraps for tensors with more than 2^32
  // elements.
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx >= N) {
    return;
  }

  int64_t output_idx = idx;

#pragma unroll
  for (size_t i = 0; i < Rank; i++) {
    // Coordinate of this element along axis i after applying the shift.
    int64_t new_dim_idx = (idx / strides[i]) % sizes[i] + shifts[i];
    // Shifting the coordinate by shifts[i] moves the flat offset by
    // shifts[i] * strides[i]; if it ran past the end of the axis, wrap it
    // back by sizes[i] * strides[i].
    if (new_dim_idx >= sizes[i]) {
      output_idx += (shifts[i] - sizes[i]) * strides[i];
    } else {
      output_idx += shifts[i] * strides[i];
    }
  }
  output[output_idx] = input[idx];
}

S
sunli 已提交
53 54 55
template <typename T>
class RollKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Forward roll on GPU: copies X into Out with each element circularly
  // shifted by shifts[i] along axis[i]. When `axis` is empty the tensor is
  // rolled as if flattened, using shifts[0].
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
    std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

    auto* in_data = in->data<T>();
    auto* out_data = out->mutable_data<T>(context.GetPlace());
    int64_t numel = in->numel();
    // An empty tensor has nothing to move. Returning here (after Out has been
    // allocated) also avoids the `% numel` / `% size` divisions by zero below
    // and a launch with zero blocks, which CUDA rejects.
    if (numel == 0) {
      return;
    }
    auto stream =
        context.template device_context<platform::CUDADeviceContext>().stream();

    size_t nums = shifts.size();
    auto input_dim = in->dims();
    auto stride_dim = framework::stride(input_dim);

    // Per rolled axis: element stride, extent, and shift normalized into
    // [0, size) so that RollCudaKernel wraps at most once per axis.
    std::vector<int64_t> strides(nums), sizes(nums);
    if (dims.size() == 0) {
      // Flattened roll: a single pseudo-axis of stride 1 covering all
      // elements.
      strides[0] = 1;
      sizes[0] = numel;
      shifts[0] = (shifts[0] % numel + numel) % numel;
    } else {
      for (size_t i = 0; i < nums; i++) {
        // Negative axis values count back from the last dimension.
        int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size();
        int64_t size = input_dim[dim];

        shifts[i] = (shifts[i] % size + size) % size;
        strides[i] = stride_dim[dim];
        sizes[i] = size;
      }
    }

// Copies the host-side vectors into fixed-size Arrays (passed to the kernel by
// value) and launches RollCudaKernel<T, N> with at least one thread per
// element on `stream`. Intentionally not #undef'd: RollGradKernel below
// expands it as well.
#define CALL_ROLL_CUDA_KERNEL(N)                                               \
  case N: {                                                                    \
    paddle::framework::Array<int64_t, N> _strides;                             \
    paddle::framework::Array<int64_t, N> _shifts;                              \
    paddle::framework::Array<int64_t, N> _sizes;                               \
    for (size_t idx = 0; idx < N; ++idx) {                                     \
      _strides[idx] = strides[idx];                                            \
      _shifts[idx] = shifts[idx];                                              \
      _sizes[idx] = sizes[idx];                                                \
    }                                                                          \
    RollCudaKernel<                                                            \
        T,                                                                     \
        N><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,  \
             PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, numel,   \
                                                   _shifts, _strides, _sizes); \
    break;                                                                     \
  }

    // The rank template argument must be a compile-time constant, so dispatch
    // on the number of rolled axes (1..9).
    switch (nums) {
      CALL_ROLL_CUDA_KERNEL(1);
      CALL_ROLL_CUDA_KERNEL(2);
      CALL_ROLL_CUDA_KERNEL(3);
      CALL_ROLL_CUDA_KERNEL(4);
      CALL_ROLL_CUDA_KERNEL(5);
      CALL_ROLL_CUDA_KERNEL(6);
      CALL_ROLL_CUDA_KERNEL(7);
      CALL_ROLL_CUDA_KERNEL(8);
      CALL_ROLL_CUDA_KERNEL(9);
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "shifts.size() should be less than 10, But received shifts.size() "
            "= %d",
            shifts.size()));
    }
  }
};

S
sunli 已提交
126 127 128
template <typename T>
class RollGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Backward roll on GPU: X@GRAD is Out@GRAD rolled by the inverse of each
  // shift along the same axes. When `axis` is empty, shifts[0] is inverted on
  // the flattened tensor.
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
    std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

    auto* in_data = in->data<T>();
    auto* out_data = out->mutable_data<T>(context.GetPlace());
    int64_t numel = in->numel();
    // An empty gradient has nothing to move. Returning here (after X@GRAD has
    // been allocated) also avoids modulo-by-zero when normalizing shifts and
    // a launch with zero blocks, which CUDA rejects.
    if (numel == 0) {
      return;
    }
    auto stream =
        context.template device_context<platform::CUDADeviceContext>().stream();
    size_t nums = shifts.size();
    auto input_dim = in->dims();
    auto stride_dim = framework::stride(input_dim);

    std::vector<int64_t> strides(nums), sizes(nums);
    if (dims.size() == 0) {
      // Flattened roll: a single pseudo-axis of stride 1 covering all
      // elements.
      strides[0] = 1;
      sizes[0] = numel;
      shifts[0] = InverseShift(shifts[0], numel);
    } else {
      for (size_t i = 0; i < nums; i++) {
        // Negative axis values count back from the last dimension.
        int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size();
        int64_t size = input_dim[dim];

        shifts[i] = InverseShift(shifts[i], size);
        strides[i] = stride_dim[dim];
        sizes[i] = size;
      }
    }

    // CALL_ROLL_CUDA_KERNEL is the launch macro #define'd inside the forward
    // RollKernel::Compute; it reads in_data, out_data, numel, stream, shifts,
    // strides and sizes from this scope.
    switch (nums) {
      CALL_ROLL_CUDA_KERNEL(1);
      CALL_ROLL_CUDA_KERNEL(2);
      CALL_ROLL_CUDA_KERNEL(3);
      CALL_ROLL_CUDA_KERNEL(4);
      CALL_ROLL_CUDA_KERNEL(5);
      CALL_ROLL_CUDA_KERNEL(6);
      CALL_ROLL_CUDA_KERNEL(7);
      CALL_ROLL_CUDA_KERNEL(8);
      CALL_ROLL_CUDA_KERNEL(9);
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "shifts.size() should be less than 10, But received shifts.size() "
            "= %d",
            shifts.size()));
    }
  }

 private:
  // Returns the shift in [0, size) that undoes a forward roll by `shift`
  // along an axis of extent `size` (requires size > 0). Normalizing first and
  // then subtracting from `size` avoids computing -shift, which is signed
  // overflow for INT64_MIN; for every other input the result equals
  // ((-shift) % size + size) % size.
  static int64_t InverseShift(int64_t shift, int64_t size) {
    int64_t forward = (shift % size + size) % size;
    return (size - forward) % size;
  }
};

}  // namespace operators
}  // namespace paddle
182 183 184

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// GPU kernels for the forward `roll` op, one instantiation per element type.
REGISTER_OP_CUDA_KERNEL(
    roll,
    ops::RollKernel<plat::CUDADeviceContext, float>,
    ops::RollKernel<plat::CUDADeviceContext, double>,
    ops::RollKernel<plat::CUDADeviceContext, int>,
    ops::RollKernel<plat::CUDADeviceContext, int64_t>);

// GPU kernels for `roll_grad`; the element types mirror the forward op.
REGISTER_OP_CUDA_KERNEL(
    roll_grad,
    ops::RollGradKernel<plat::CUDADeviceContext, float>,
    ops::RollGradKernel<plat::CUDADeviceContext, double>,
    ops::RollGradKernel<plat::CUDADeviceContext, int>,
    ops::RollGradKernel<plat::CUDADeviceContext, int64_t>);