From 3fc56aa07ef43b22d83b16d61791c4ef103ab838 Mon Sep 17 00:00:00 2001 From: sunli <466530738@qq.com> Date: Thu, 1 Jul 2021 16:30:29 +0800 Subject: [PATCH] roll optimize (#32880) --- paddle/fluid/operators/roll_op.cc | 35 ++- paddle/fluid/operators/roll_op.cu | 199 +++++++++--------- paddle/fluid/operators/roll_op.h | 20 +- .../fluid/tests/unittests/test_roll_op.py | 1 + python/paddle/tensor/manipulation.py | 14 +- 5 files changed, 149 insertions(+), 120 deletions(-) diff --git a/paddle/fluid/operators/roll_op.cc b/paddle/fluid/operators/roll_op.cc index b1fe952036..a0c28ae6cb 100644 --- a/paddle/fluid/operators/roll_op.cc +++ b/paddle/fluid/operators/roll_op.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "paddle/fluid/operators/roll_op.h" + #include #include + #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -37,12 +39,22 @@ class RollOp : public framework::OperatorWithKernel { auto dims = ctx->Attrs().Get>("axis"); auto shifts = ctx->Attrs().Get>("shifts"); - PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), - platform::errors::InvalidArgument( - "Attr(dims).size() should be equl to " - "Attr(shifts).size(). But received " - "Attr(dims).size() = %d, Attr(shifts).size() = %d", - dims.size(), shifts.size())); + if (dims.size() != 0) { + PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), + platform::errors::InvalidArgument( + "When dims.size() != 0, dims.size() " + "should be equal to " + "shifts.size(). But received " + "dims.size() = %d, shifts.size() = %d", + dims.size(), shifts.size())); + } else { + PADDLE_ENFORCE_EQ(shifts.size(), 1, + platform::errors::InvalidArgument( + "When dims.size() == 0, shifts.size() " + "should be equal to 1, But received " + "shifts.size() = %d", + shifts.size())); + } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); auto type = ctx->GetInputsVarType("X")[0]; @@ -95,7 +107,7 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>( "axis", "Axis along which to roll. It must have the same size " - "with shifts.") + "with shifts or size == 0") .SetDefault({}); AddComment(R"DOC( Roll the tensor along the given dimension(s). @@ -151,8 +163,9 @@ REGISTER_OP_VERSION(roll) paddle::framework::compatible::OpVersionDesc() .NewAttr("axis", "(std::vector) Axis along which to roll. " - "It must have the same size with shifts.", + "It must have the same size with shifts, or size = 0.", std::vector()) - .DeleteAttr("dims", - "(std::vector) Dims along which to roll. " - "It must have the same size with shifts.")); + .DeleteAttr( + "dims", + "(std::vector) Dims along which to roll. " + "It must have the same size with shifts, or size = 0.")); diff --git a/paddle/fluid/operators/roll_op.cu b/paddle/fluid/operators/roll_op.cu index 09309c492d..ce93c5f984 100644 --- a/paddle/fluid/operators/roll_op.cu +++ b/paddle/fluid/operators/roll_op.cu @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include "paddle/fluid/framework/array.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/roll_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -24,26 +25,31 @@ using platform::PADDLE_CUDA_NUM_THREADS; using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -template -__global__ void roll_cuda_kernel(const T* input, T* output, int64_t N, - int64_t* shifts, int64_t* strides, - int64_t* sizes, int64_t nums) { +template +__global__ void RollCudaKernel(const T* input, T* output, int64_t N, + paddle::framework::Array shifts, + paddle::framework::Array strides, + paddle::framework::Array sizes) { int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= N) { return; } + int64_t output_idx = idx; int64_t dim_idx, dim_idx_shift; - for (int64_t i = 0; i < nums; i++) { - dim_idx = idx % (strides[i] * sizes[i]) / strides[i]; + +#pragma unroll Rank + for (size_t i = 0; i < Rank; i++) { + dim_idx = (idx / strides[i]) % sizes[i]; dim_idx_shift = (dim_idx + shifts[i]) % sizes[i]; output_idx = output_idx + (dim_idx_shift - dim_idx) * strides[i]; } output[output_idx] = input[idx]; } -template -class RollCUDAKernel : public framework::OpKernel { +template +class RollKernel + : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); @@ -61,50 +67,62 @@ class RollCUDAKernel : public framework::OpKernel { auto input_dim = in->dims(); auto stride_dim = framework::stride(input_dim); - int64_t dim, size; - size_t gpu_memory_size_ = sizeof(int64_t) * nums; - std::vector strides, sizes; - strides.resize(nums); - sizes.resize(nums); - paddle::memory::AllocationPtr shifts_gpu = - memory::Alloc(context.GetPlace(), gpu_memory_size_); - paddle::memory::AllocationPtr strides_gpu = - memory::Alloc(context.GetPlace(), gpu_memory_size_); - paddle::memory::AllocationPtr sizes_gpu = - memory::Alloc(context.GetPlace(), gpu_memory_size_); - - for (size_t i = 0; i < nums; i++) { - dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size(); - size = input_dim[dim]; - shifts[i] = (shifts[i] % size + size) % size; - strides[i] = stride_dim[dim]; - sizes[i] = size; + std::vector strides(nums), sizes(nums); + if (dims.size() == 0) { + strides[0] = 1; + sizes[0] = numel; + shifts[0] = (shifts[0] % numel + numel) % numel; + } else { + for (size_t i = 0; i < nums; i++) { + int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size(); + int64_t size = input_dim[dim]; + + shifts[i] = (shifts[i] % size + size) % size; + strides[i] = stride_dim[dim]; + sizes[i] = size; + } + } + +#define CALL_ROLL_CUDA_KERNEL(N) \ + case N: { \ + paddle::framework::Array _strides; \ + paddle::framework::Array _shifts; \ + paddle::framework::Array _sizes; \ + for (size_t idx = 0; idx < N; ++idx) { \ + _strides[idx] = strides[idx]; \ + _shifts[idx] = shifts[idx]; \ + _sizes[idx] = sizes[idx]; \ + } \ + RollCudaKernel< \ + T, \ + N><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, \ + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, numel, \ + _shifts, _strides, _sizes); \ + break; \ + } + + switch (nums) { + CALL_ROLL_CUDA_KERNEL(1); + CALL_ROLL_CUDA_KERNEL(2); + CALL_ROLL_CUDA_KERNEL(3); + CALL_ROLL_CUDA_KERNEL(4); + CALL_ROLL_CUDA_KERNEL(5); + CALL_ROLL_CUDA_KERNEL(6); + CALL_ROLL_CUDA_KERNEL(7); + CALL_ROLL_CUDA_KERNEL(8); + CALL_ROLL_CUDA_KERNEL(9); + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "shifts.size() should be less than 10, But received shifts.size() " + "= %d", + shifts.size())); } - paddle::memory::Copy( - BOOST_GET_CONST(platform::CUDAPlace, shifts_gpu->place()), - shifts_gpu->ptr(), platform::CPUPlace(), shifts.data(), - gpu_memory_size_, stream); - paddle::memory::Copy( - BOOST_GET_CONST(platform::CUDAPlace, strides_gpu->place()), - strides_gpu->ptr(), platform::CPUPlace(), strides.data(), - gpu_memory_size_, stream); - paddle::memory::Copy( - BOOST_GET_CONST(platform::CUDAPlace, sizes_gpu->place()), - sizes_gpu->ptr(), platform::CPUPlace(), sizes.data(), gpu_memory_size_, - stream); - int64_t* shifts_ptr = reinterpret_cast(shifts_gpu->ptr()); - int64_t* strides_ptr = reinterpret_cast(strides_gpu->ptr()); - int64_t* sizes_ptr = reinterpret_cast(sizes_gpu->ptr()); - - roll_cuda_kernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / - PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, 0, stream>>>( - in_data, out_data, numel, shifts_ptr, strides_ptr, sizes_ptr, nums); } }; -template -class RollGradCUDAKernel : public framework::OpKernel { +template +class RollGradKernel + : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input(framework::GradVarName("Out")); @@ -121,46 +139,38 @@ class RollGradCUDAKernel : public framework::OpKernel { auto input_dim = in->dims(); auto stride_dim = framework::stride(input_dim); - int64_t dim, size; - size_t gpu_memory_size_ = sizeof(int64_t) * nums; - std::vector strides, sizes; - strides.resize(nums); - sizes.resize(nums); - paddle::memory::AllocationPtr shifts_gpu = - memory::Alloc(context.GetPlace(), gpu_memory_size_); - paddle::memory::AllocationPtr strides_gpu = - memory::Alloc(context.GetPlace(), gpu_memory_size_); - paddle::memory::AllocationPtr sizes_gpu = - memory::Alloc(context.GetPlace(), gpu_memory_size_); - - for (size_t i = 0; i < nums; i++) { - dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size(); - size = input_dim[dim]; - shifts[i] = ((0 - shifts[i]) % size + size) % size; - strides[i] = stride_dim[dim]; - sizes[i] = size; + std::vector strides(nums), sizes(nums); + if (dims.size() == 0) { + strides[0] = 1; + sizes[0] = numel; + shifts[0] = ((-shifts[0]) % numel + numel) % numel; + } else { + for (size_t i = 0; i < nums; i++) { + int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size(); + int64_t size = input_dim[dim]; + + shifts[i] = ((-shifts[i]) % size + size) % size; + strides[i] = stride_dim[dim]; + sizes[i] = size; + } } - paddle::memory::Copy( - BOOST_GET_CONST(platform::CUDAPlace, shifts_gpu->place()), - shifts_gpu->ptr(), platform::CPUPlace(), shifts.data(), - gpu_memory_size_, stream); - paddle::memory::Copy( - BOOST_GET_CONST(platform::CUDAPlace, strides_gpu->place()), - strides_gpu->ptr(), platform::CPUPlace(), strides.data(), - gpu_memory_size_, stream); - paddle::memory::Copy( - BOOST_GET_CONST(platform::CUDAPlace, sizes_gpu->place()), - sizes_gpu->ptr(), platform::CPUPlace(), sizes.data(), gpu_memory_size_, - stream); - int64_t* shifts_ptr = reinterpret_cast(shifts_gpu->ptr()); - int64_t* strides_ptr = reinterpret_cast(strides_gpu->ptr()); - int64_t* sizes_ptr = reinterpret_cast(sizes_gpu->ptr()); - - roll_cuda_kernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / - PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, 0, stream>>>( - in_data, out_data, numel, shifts_ptr, strides_ptr, sizes_ptr, nums); + switch (nums) { + CALL_ROLL_CUDA_KERNEL(1); + CALL_ROLL_CUDA_KERNEL(2); + CALL_ROLL_CUDA_KERNEL(3); + CALL_ROLL_CUDA_KERNEL(4); + CALL_ROLL_CUDA_KERNEL(5); + CALL_ROLL_CUDA_KERNEL(6); + CALL_ROLL_CUDA_KERNEL(7); + CALL_ROLL_CUDA_KERNEL(8); + CALL_ROLL_CUDA_KERNEL(9); + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "shifts.size() should be less than 10, But received shifts.size() " + "= %d", + shifts.size())); + } } }; @@ -169,13 +179,12 @@ class RollGradCUDAKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - roll, ops::RollCUDAKernel, - ops::RollCUDAKernel, - ops::RollCUDAKernel, - ops::RollCUDAKernel); + roll, ops::RollKernel, + ops::RollKernel, + ops::RollKernel, + ops::RollKernel); REGISTER_OP_CUDA_KERNEL( - roll_grad, - ops::RollGradCUDAKernel, - ops::RollGradCUDAKernel, - ops::RollGradCUDAKernel, - ops::RollGradCUDAKernel); + roll_grad, ops::RollGradKernel, + ops::RollGradKernel, + ops::RollGradKernel, + ops::RollGradKernel); diff --git a/paddle/fluid/operators/roll_op.h b/paddle/fluid/operators/roll_op.h index 74dd37ed83..da4f335ca7 100644 --- a/paddle/fluid/operators/roll_op.h +++ b/paddle/fluid/operators/roll_op.h @@ -88,7 +88,13 @@ class RollKernel : public framework::OpKernel { TensorToVector(input, context.device_context(), &out_vec); size_t nums = shifts.size(); - const DDim input_dim = input.dims(); + DDim input_dim = input.dims(); + + // axis = none, reshape to 1-D tensor + if (dims.size() == 0) { + dims.push_back(0l); + input_dim = framework::Dim<1>(out_vec.size()); + } for (size_t i = 0; i < nums; i++) { PADDLE_ENFORCE_EQ( @@ -101,7 +107,7 @@ class RollKernel : public framework::OpKernel { } output->mutable_data(context.GetPlace()); framework::TensorFromVector(out_vec, context.device_context(), output); - output->Resize(input_dim); + output->Resize(input.dims()); } }; @@ -120,14 +126,20 @@ class RollGradKernel : public framework::OpKernel { TensorToVector(input, context.device_context(), &out_vec); size_t nums = shifts.size(); - const DDim input_dim = input.dims(); + DDim input_dim = input.dims(); + + // axis = none, reshape to 1-D tensor + if (dims.size() == 0) { + dims.push_back(0l); + input_dim = framework::Dim<1>(out_vec.size()); + } for (size_t i = 0; i < nums; i++) { shift_along_dim(out_vec.data(), input_dim, dims[i], 0 - shifts[i]); } output->mutable_data(context.GetPlace()); framework::TensorFromVector(out_vec, context.device_context(), output); - output->Resize(input_dim); + output->Resize(input.dims()); } }; diff --git a/python/paddle/fluid/tests/unittests/test_roll_op.py b/python/paddle/fluid/tests/unittests/test_roll_op.py index b20293adf4..99121d2953 100644 --- a/python/paddle/fluid/tests/unittests/test_roll_op.py +++ b/python/paddle/fluid/tests/unittests/test_roll_op.py @@ -63,6 +63,7 @@ class TestRollAPI(unittest.TestCase): def test_roll_op_api(self): self.input_data() + paddle.enable_static() # case 1: with program_guard(Program(), Program()): x = fluid.layers.data(name='x', shape=[-1, 3]) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 981baecb64..6d6d2c9f9a 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -459,28 +459,22 @@ def roll(x, shifts, axis=None, name=None): if axis: check_type(axis, 'axis', (list, tuple), 'roll') + else: + axis = [] + check_type(shifts, 'shifts', (list, tuple), 'roll') if in_dygraph_mode(): - if axis is None: - x = core.ops.reshape(x, 'shape', [-1, 1]) - axis = [0] - out = core.ops.roll(x, 'axis', axis, 'shifts', shifts) - return core.ops.reshape(out, 'shape', origin_shape) + return core.ops.roll(x, 'axis', axis, 'shifts', shifts) out = helper.create_variable_for_type_inference(x.dtype) - if axis is None: - x = reshape(x, shape=[-1, 1]) - axis = [0] - helper.append_op( type='roll', inputs={'X': x}, outputs={'Out': out}, attrs={'axis': axis, 'shifts': shifts}) - out = layers.reshape(out, shape=origin_shape) return out -- GitLab