roll_op.cu 7.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

1
123malin 已提交
15 16
#pragma once
#include "paddle/fluid/framework/op_registry.h"
17
#include "paddle/fluid/operators/roll_op.h"
1
123malin 已提交
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

// Scatters `input` into `output` so that every element moves shifts[i]
// positions (cyclically) along each of the `nums` rolled axes.
//
// Launch layout: 1-D grid of 1-D blocks providing at least N threads in
// total; one thread per element, and threads whose flat index is >= N exit
// immediately. No shared memory is used.
//
// `shifts`, `strides` and `sizes` are device arrays of length `nums`
// describing, per rolled axis, the shift amount, the element stride and the
// extent. Each shift must already be normalized into [0, sizes[i]): the
// kernel adds it and reduces with `%`, which would yield a negative
// coordinate for a negative shift. With nums == 0 the arrays are never read
// and the kernel performs a plain element-wise copy.
// NOTE(review): assumes `input` and `output` are distinct buffers sharing the
// same (contiguous) layout described by `strides` -- confirm no in-place use.
template <typename T>
__global__ void roll_cuda_kernel(const T* input, T* output, int64_t N,
                                 const int64_t* shifts, const int64_t* strides,
                                 const int64_t* sizes, int64_t nums) {
  // Widen before multiplying: blockIdx.x * blockDim.x is computed in 32-bit
  // unsigned arithmetic and would wrap for tensors with >= 2^32 elements.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx >= N) {
    return;
  }
  int64_t output_idx = idx;
  for (int64_t i = 0; i < nums; ++i) {
    const int64_t stride = strides[i];
    const int64_t size = sizes[i];
    // Coordinate of this element along rolled axis i, and where it moves to.
    const int64_t dim_idx = idx % (stride * size) / stride;
    const int64_t dim_idx_shift = (dim_idx + shifts[i]) % size;
    output_idx += (dim_idx_shift - dim_idx) * stride;
  }
  output[output_idx] = input[idx];
}

// CUDA kernel for the `roll` op (forward pass).
//
// Reads X and the `shifts` / `axis` attributes and writes Out, moving each
// element of X shifts[i] positions (cyclically) along axis[i]. Negative axes
// count from the last dimension; shifts are mapped into [0, size) on the host
// before launch. The per-axis shift, stride and size are staged into one
// device buffer and consumed by roll_cuda_kernel, which is launched on the
// context's CUDA stream with one thread per element.
template <typename DeviceContext, typename T>
class RollCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
    std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

    auto* in_data = in->data<T>();
    auto* out_data = out->mutable_data<T>(context.GetPlace());
    const int64_t numel = in->numel();
    // Empty tensor: nothing to move. Returning early also avoids launching a
    // zero-block grid (an invalid launch configuration) and evaluating
    // `shift % size` with size == 0 below.
    if (numel == 0) {
      return;
    }
    auto stream =
        context.template device_context<platform::CUDADeviceContext>().stream();

    const size_t nums = shifts.size();
    // dims[i] is read for every i < nums below; reject mismatched attributes
    // instead of reading past the end of `dims`.
    PADDLE_ENFORCE_EQ(
        dims.size(), nums,
        platform::errors::InvalidArgument(
            "Attr(axis).size() should be equal to Attr(shifts).size(), but "
            "received Attr(axis).size() = %d, Attr(shifts).size() = %d.",
            dims.size(), nums));
    auto input_dim = in->dims();
    auto stride_dim = framework::stride(input_dim);
    const int64_t rank = input_dim.size();

    // Host staging buffer laid out as [shifts | strides | sizes], each part
    // `nums` long, so a single device allocation and H2D copy suffice.
    std::vector<int64_t> params(3 * nums);
    for (size_t i = 0; i < nums; ++i) {
      PADDLE_ENFORCE_EQ(
          dims[i] < rank && dims[i] >= -rank, true,
          platform::errors::OutOfRange(
              "Attr(axis[%d]) is out of range, it's expected to be in range "
              "of [-%d, %d). But received Attr(axis[%d]) = %d.",
              i, rank, rank, i, dims[i]));
      const int64_t dim = dims[i] >= 0 ? dims[i] : dims[i] + rank;
      const int64_t size = input_dim[dim];
      // Normalize into [0, size): the device kernel adds the shift and
      // reduces with `%`, which would go negative for a negative shift.
      params[i] = (shifts[i] % size + size) % size;
      params[nums + i] = stride_dim[dim];
      params[2 * nums + i] = size;
    }

    // NOTE(review): params_gpu is released when Compute returns, possibly
    // before the asynchronous kernel below has read it. Like the original
    // per-array allocations, this relies on the allocator only reusing the
    // memory for work ordered after this stream -- confirm.
    paddle::memory::AllocationPtr params_gpu;
    int64_t* params_ptr = nullptr;
    if (nums > 0) {
      const size_t bytes = params.size() * sizeof(int64_t);
      params_gpu = memory::Alloc(context.GetPlace(), bytes);
      paddle::memory::Copy(
          BOOST_GET_CONST(platform::CUDAPlace, params_gpu->place()),
          params_gpu->ptr(), platform::CPUPlace(), params.data(), bytes,
          stream);
      params_ptr = reinterpret_cast<int64_t*>(params_gpu->ptr());
    }

    // With nums == 0 the kernel's per-axis loop never runs, so the (null)
    // metadata pointers are never dereferenced and the op is a plain copy.
    const int64_t blocks =
        (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
    roll_cuda_kernel<<<blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        in_data, out_data, numel, params_ptr, params_ptr + nums,
        params_ptr + 2 * nums, nums);
  }
};

// CUDA kernel for the `roll_grad` op (backward pass of `roll`).
//
// Reads Out@GRAD and the forward `shifts` / `axis` attributes and writes
// X@GRAD by rolling the incoming gradient by the *negated* shifts, which
// undoes the forward roll. Negative axes count from the last dimension;
// shifts are mapped into [0, size) on the host before launch. The per-axis
// shift, stride and size are staged into one device buffer and consumed by
// roll_cuda_kernel on the context's CUDA stream, one thread per element.
template <typename DeviceContext, typename T>
class RollGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
    std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

    auto* in_data = in->data<T>();
    auto* out_data = out->mutable_data<T>(context.GetPlace());
    const int64_t numel = in->numel();
    // Empty gradient: nothing to move. Returning early also avoids launching
    // a zero-block grid (an invalid launch configuration) and evaluating
    // `shift % size` with size == 0 below.
    if (numel == 0) {
      return;
    }
    auto stream =
        context.template device_context<platform::CUDADeviceContext>().stream();

    const size_t nums = shifts.size();
    // dims[i] is read for every i < nums below; reject mismatched attributes
    // instead of reading past the end of `dims`.
    PADDLE_ENFORCE_EQ(
        dims.size(), nums,
        platform::errors::InvalidArgument(
            "Attr(axis).size() should be equal to Attr(shifts).size(), but "
            "received Attr(axis).size() = %d, Attr(shifts).size() = %d.",
            dims.size(), nums));
    auto input_dim = in->dims();
    auto stride_dim = framework::stride(input_dim);
    const int64_t rank = input_dim.size();

    // Host staging buffer laid out as [shifts | strides | sizes], each part
    // `nums` long, so a single device allocation and H2D copy suffice.
    std::vector<int64_t> params(3 * nums);
    for (size_t i = 0; i < nums; ++i) {
      PADDLE_ENFORCE_EQ(
          dims[i] < rank && dims[i] >= -rank, true,
          platform::errors::OutOfRange(
              "Attr(axis[%d]) is out of range, it's expected to be in range "
              "of [-%d, %d). But received Attr(axis[%d]) = %d.",
              i, rank, rank, i, dims[i]));
      const int64_t dim = dims[i] >= 0 ? dims[i] : dims[i] + rank;
      const int64_t size = input_dim[dim];
      // Roll back by the forward shift, normalized into [0, size) because
      // the device kernel adds the shift and reduces with `%`.
      params[i] = ((0 - shifts[i]) % size + size) % size;
      params[nums + i] = stride_dim[dim];
      params[2 * nums + i] = size;
    }

    // NOTE(review): params_gpu is released when Compute returns, possibly
    // before the asynchronous kernel below has read it. Like the original
    // per-array allocations, this relies on the allocator only reusing the
    // memory for work ordered after this stream -- confirm.
    paddle::memory::AllocationPtr params_gpu;
    int64_t* params_ptr = nullptr;
    if (nums > 0) {
      const size_t bytes = params.size() * sizeof(int64_t);
      params_gpu = memory::Alloc(context.GetPlace(), bytes);
      paddle::memory::Copy(
          BOOST_GET_CONST(platform::CUDAPlace, params_gpu->place()),
          params_gpu->ptr(), platform::CPUPlace(), params.data(), bytes,
          stream);
      params_ptr = reinterpret_cast<int64_t*>(params_gpu->ptr());
    }

    // With nums == 0 the kernel's per-axis loop never runs, so the (null)
    // metadata pointers are never dereferenced and the op is a plain copy.
    const int64_t blocks =
        (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
    roll_cuda_kernel<<<blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        in_data, out_data, numel, params_ptr, params_ptr + nums,
        params_ptr + 2 * nums, nums);
  }
};

}  // namespace operators
}  // namespace paddle
169 170 171

namespace ops = paddle::operators;

// Forward `roll` kernels on CUDA for float, double, int and int64_t elements.
REGISTER_OP_CUDA_KERNEL(
    roll, ops::RollCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::RollCUDAKernel<paddle::platform::CUDADeviceContext, double>,
    ops::RollCUDAKernel<paddle::platform::CUDADeviceContext, int>,
    ops::RollCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);

// Backward `roll_grad` kernels on CUDA for the same element types.
REGISTER_OP_CUDA_KERNEL(
    roll_grad,
    ops::RollGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::RollGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
    ops::RollGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
    ops::RollGradCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);