diff --git a/paddle/fluid/operators/roll_op.cc b/paddle/fluid/operators/roll_op.cc
index f470f41f1eb5c9d08af7802f943b3a1e54f30939..975cf83ffe8bed49a5359a3add009239ce62aa30 100644
--- a/paddle/fluid/operators/roll_op.cc
+++ b/paddle/fluid/operators/roll_op.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/roll_op.h"
+
 #include <memory>
 #include <vector>
 
diff --git a/paddle/fluid/operators/roll_op.cu b/paddle/fluid/operators/roll_op.cu
index 59178811061a25989dc2ed2fa9d7b059e8319350..09309c492d29225cb2b0ed42559e43e73ea49c7f 100644
--- a/paddle/fluid/operators/roll_op.cu
+++ b/paddle/fluid/operators/roll_op.cu
@@ -12,16 +12,170 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/roll_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+using platform::PADDLE_CUDA_NUM_THREADS;
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+
+template <typename T>
+__global__ void roll_cuda_kernel(const T* input, T* output, int64_t N,
+                                 int64_t* shifts, int64_t* strides,
+                                 int64_t* sizes, int64_t nums) {
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= N) {
+    return;
+  }
+  int64_t output_idx = idx;
+  int64_t dim_idx, dim_idx_shift;
+  for (int64_t i = 0; i < nums; i++) {
+    dim_idx = idx % (strides[i] * sizes[i]) / strides[i];
+    dim_idx_shift = (dim_idx + shifts[i]) % sizes[i];
+    output_idx = output_idx + (dim_idx_shift - dim_idx) * strides[i];
+  }
+  output[output_idx] = input[idx];
+}
+
+template <typename DeviceContext, typename T>
+class RollCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<LoDTensor>("X");
+    auto* out = context.Output<LoDTensor>("Out");
+    std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
+    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
+
+    auto* in_data = in->data<T>();
+    auto* out_data = out->mutable_data<T>(context.GetPlace());
+    int64_t numel = in->numel();
+    auto stream =
+        context.template device_context<platform::CUDADeviceContext>()
+            .stream();
+
+    size_t nums = shifts.size();
+    auto input_dim = in->dims();
+    auto stride_dim = framework::stride(input_dim);
+
+    int64_t dim, size;
+    size_t gpu_memory_size_ = sizeof(int64_t) * nums;
+    std::vector<int64_t> strides, sizes;
+    strides.resize(nums);
+    sizes.resize(nums);
+    paddle::memory::AllocationPtr shifts_gpu =
+        memory::Alloc(context.GetPlace(), gpu_memory_size_);
+    paddle::memory::AllocationPtr strides_gpu =
+        memory::Alloc(context.GetPlace(), gpu_memory_size_);
+    paddle::memory::AllocationPtr sizes_gpu =
+        memory::Alloc(context.GetPlace(), gpu_memory_size_);
+
+    for (size_t i = 0; i < nums; i++) {
+      dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size();
+      size = input_dim[dim];
+      shifts[i] = (shifts[i] % size + size) % size;
+      strides[i] = stride_dim[dim];
+      sizes[i] = size;
+    }
+    paddle::memory::Copy(
+        BOOST_GET_CONST(platform::CUDAPlace, shifts_gpu->place()),
+        shifts_gpu->ptr(), platform::CPUPlace(), shifts.data(),
+        gpu_memory_size_, stream);
+    paddle::memory::Copy(
+        BOOST_GET_CONST(platform::CUDAPlace, strides_gpu->place()),
+        strides_gpu->ptr(), platform::CPUPlace(), strides.data(),
+        gpu_memory_size_, stream);
+    paddle::memory::Copy(
+        BOOST_GET_CONST(platform::CUDAPlace, sizes_gpu->place()),
+        sizes_gpu->ptr(), platform::CPUPlace(), sizes.data(), gpu_memory_size_,
+        stream);
+    int64_t* shifts_ptr = reinterpret_cast<int64_t*>(shifts_gpu->ptr());
+    int64_t* strides_ptr = reinterpret_cast<int64_t*>(strides_gpu->ptr());
+    int64_t* sizes_ptr = reinterpret_cast<int64_t*>(sizes_gpu->ptr());
+
+    roll_cuda_kernel<T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
+                              PADDLE_CUDA_NUM_THREADS,
+                          PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+        in_data, out_data, numel, shifts_ptr, strides_ptr, sizes_ptr, nums);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class RollGradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
+    std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
+    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
+
+    auto* in_data = in->data<T>();
+    auto* out_data = out->mutable_data<T>(context.GetPlace());
+    int64_t numel = in->numel();
+    auto stream =
+        context.template device_context<platform::CUDADeviceContext>()
+            .stream();
+    size_t nums = shifts.size();
+    auto input_dim = in->dims();
+    auto stride_dim = framework::stride(input_dim);
+
+    int64_t dim, size;
+    size_t gpu_memory_size_ = sizeof(int64_t) * nums;
+    std::vector<int64_t> strides, sizes;
+    strides.resize(nums);
+    sizes.resize(nums);
+    paddle::memory::AllocationPtr shifts_gpu =
+        memory::Alloc(context.GetPlace(), gpu_memory_size_);
+    paddle::memory::AllocationPtr strides_gpu =
+        memory::Alloc(context.GetPlace(), gpu_memory_size_);
+    paddle::memory::AllocationPtr sizes_gpu =
+        memory::Alloc(context.GetPlace(), gpu_memory_size_);
+
+    for (size_t i = 0; i < nums; i++) {
+      dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size();
+      size = input_dim[dim];
+      shifts[i] = ((0 - shifts[i]) % size + size) % size;
+      strides[i] = stride_dim[dim];
+      sizes[i] = size;
+    }
+
+    paddle::memory::Copy(
+        BOOST_GET_CONST(platform::CUDAPlace, shifts_gpu->place()),
+        shifts_gpu->ptr(), platform::CPUPlace(), shifts.data(),
+        gpu_memory_size_, stream);
+    paddle::memory::Copy(
+        BOOST_GET_CONST(platform::CUDAPlace, strides_gpu->place()),
+        strides_gpu->ptr(), platform::CPUPlace(), strides.data(),
+        gpu_memory_size_, stream);
+    paddle::memory::Copy(
+        BOOST_GET_CONST(platform::CUDAPlace, sizes_gpu->place()),
+        sizes_gpu->ptr(), platform::CPUPlace(), sizes.data(), gpu_memory_size_,
+        stream);
+    int64_t* shifts_ptr = reinterpret_cast<int64_t*>(shifts_gpu->ptr());
+    int64_t* strides_ptr = reinterpret_cast<int64_t*>(strides_gpu->ptr());
+    int64_t* sizes_ptr = reinterpret_cast<int64_t*>(sizes_gpu->ptr());
+
+    roll_cuda_kernel<T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
+                              PADDLE_CUDA_NUM_THREADS,
+                          PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+        in_data, out_data, numel, shifts_ptr, strides_ptr, sizes_ptr, nums);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
-    roll, ops::RollKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::RollKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::RollKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::RollKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    roll, ops::RollCUDAKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::RollCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::RollCUDAKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::RollCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
-    roll_grad, ops::RollGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::RollGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::RollGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::RollGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    roll_grad,
+    ops::RollGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::RollGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::RollGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::RollGradCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
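
Review note (not part of the diff): the heart of the change is the index arithmetic in `roll_cuda_kernel`. For each flat index `idx`, the coordinate along every rolled axis is recovered from the strides, shifted modulo the axis length, and folded back into a flat output index. Below is a minimal host-side C++ sketch of the same computation; the function name `cpu_roll` and the 2x4 toy shapes are illustrative, not from the PR.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors roll_cuda_kernel's per-element computation on the host: each
// input element at flat index idx is written to the flat index whose
// coordinate along every rolled axis is shifted, with wraparound.
void cpu_roll(const int* input, int* output, int64_t N,
              const std::vector<int64_t>& shifts,
              const std::vector<int64_t>& strides,
              const std::vector<int64_t>& sizes) {
  for (int64_t idx = 0; idx < N; idx++) {
    int64_t output_idx = idx;
    for (size_t i = 0; i < shifts.size(); i++) {
      // Coordinate of idx along axis i, recovered from the flat offset.
      int64_t dim_idx = idx % (strides[i] * sizes[i]) / strides[i];
      // Shifted coordinate, wrapped around the axis length.
      int64_t dim_idx_shift = (dim_idx + shifts[i]) % sizes[i];
      output_idx += (dim_idx_shift - dim_idx) * strides[i];
    }
    output[output_idx] = input[idx];
  }
}

int main() {
  // A 2x4 row-major tensor rolled by 1 along axis 1; that axis has
  // stride 1 and size 4, so only one (stride, size, shift) triple is used.
  int in[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  int out[8] = {0};
  cpu_roll(in, out, 8, /*shifts=*/{1}, /*strides=*/{1}, /*sizes=*/{4});
  for (int r = 0; r < 2; r++) {
    for (int c = 0; c < 4; c++) printf("%d ", out[r * 4 + c]);
    printf("\n");  // prints "3 0 1 2" then "7 4 5 6": rows rotated right by one
  }
  return 0;
}
```

On the GPU the loop body is identical; the outer loop over `idx` is replaced by one thread per element, launched with a ceil-divided grid of `(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS` blocks on the device context's stream.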
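A second hedged sketch, on the shift preprocessing: the forward kernel normalizes each attribute shift into `[0, size)` with `(shifts[i] % size + size) % size` (C++ `%` can return a negative result for negative operands), while the gradient kernel negates the shift first, so the backward pass is simply a roll in the opposite direction and the two compose to an identity. The standalone check below is illustrative, not PR code.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  const int64_t size = 4;  // length of the rolled axis
  for (int64_t s : {-5, -1, 0, 3, 9}) {
    int64_t fwd = (s % size + size) % size;        // forward kernel's shift
    int64_t bwd = ((0 - s) % size + size) % size;  // grad kernel's shift
    // Both normalized shifts land in [0, size).
    assert(fwd >= 0 && fwd < size);
    assert(bwd >= 0 && bwd < size);
    // Rolling by fwd and then by bwd returns every element to its origin.
    assert((fwd + bwd) % size == 0);
  }
  return 0;
}
```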