diff --git a/paddle/fluid/operators/roll_op.cc b/paddle/fluid/operators/roll_op.cc index f82510556fde87fbf4aeb1904e29325358598791..898db4c22fed9cc97baa261b5b512a889290aff3 100644 --- a/paddle/fluid/operators/roll_op.cc +++ b/paddle/fluid/operators/roll_op.cc @@ -12,13 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/roll_op.h" - #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/platform/complex.h" +#include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -29,43 +32,6 @@ class RollOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::InvalidArgument( - "Input(X) of RollOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument( - "Output(Out) of RollOp should not be null.")); - - auto dims = ctx->Attrs().Get>("axis"); - auto shifts = ctx->Attrs().Get>("shifts"); - - if (!ctx->HasInput("ShiftsTensor")) { - if (dims.size() != 0) { - PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), - platform::errors::InvalidArgument( - "When dims.size() != 0, dims.size() " - "should be equal to " - "shifts.size(). But received " - "dims.size() = %d, shifts.size() = %d", - dims.size(), shifts.size())); - } else { - PADDLE_ENFORCE_EQ(shifts.size(), 1, - platform::errors::InvalidArgument( - "When dims.size() == 0, shifts.size() " - "should be equal to 1, But received " - "shifts.size() = %d", - shifts.size())); - } - } - - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - auto type = ctx->GetInputsVarType("X")[0]; - if (type == framework::proto::VarType::LOD_TENSOR) { - ctx->ShareLoD("X", /*->*/ "Out"); - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -149,29 +115,15 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(RollGradNoNeedBufferVarsInferer, "X"); } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(roll, RollInferShapeFunctor, + PD_INFER_META(phi::RollInferMeta)); + REGISTER_OPERATOR(roll, ops::RollOp, ops::RollOpMaker, ops::RollGradMaker, - ops::RollGradMaker); + ops::RollGradMaker, + RollInferShapeFunctor); REGISTER_OPERATOR(roll_grad, ops::RollGradOp, ops::RollGradNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL( - roll, ops::RollKernel, - ops::RollKernel, - ops::RollKernel, - ops::RollKernel, - ops::RollKernel>, - ops::RollKernel>); -REGISTER_OP_CPU_KERNEL( - roll_grad, ops::RollGradKernel, - ops::RollGradKernel, - ops::RollGradKernel, - ops::RollGradKernel, - ops::RollGradKernel>, - ops::RollGradKernel>); REGISTER_OP_VERSION(roll) .AddCheckpoint( diff --git a/paddle/fluid/operators/roll_op.cu b/paddle/fluid/operators/roll_op.cu deleted file mode 100644 index b9064c5450f9fbed64bcb65a2f9d15be2b56fbcf..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/roll_op.cu +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/roll_op.h" -#include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/phi/core/utils/array.h" - -namespace paddle { -namespace operators { - -using platform::PADDLE_CUDA_NUM_THREADS; -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -template -__global__ void RollCudaKernel(const T* input, T* output, int64_t N, - phi::Array shifts, - phi::Array strides, - phi::Array sizes) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) { - return; - } - - int64_t output_idx = idx; - int64_t new_dim_idx = 0; - -#pragma unroll - for (size_t i = 0; i < Rank; i++) { - new_dim_idx = (idx / strides[i]) % sizes[i] + shifts[i]; - if (new_dim_idx >= sizes[i]) { - output_idx += (shifts[i] - sizes[i]) * strides[i]; - } else { - output_idx += shifts[i] * strides[i]; - } - } - output[output_idx] = input[idx]; -} - -template -class RollKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - std::vector shifts = context.Attr>("shifts"); - if (context.HasInput("ShiftsTensor")) { - const auto* shifts_tensor = - context.Input("ShiftsTensor"); - PADDLE_ENFORCE_EQ( - shifts_tensor->dims().size(), 1, - platform::errors::InvalidArgument( - "The rank of ShiftsTensor is expected to be 1, got %s", - shifts_tensor->dims().size())); - shifts = GetDataFromTensor(shifts_tensor); - } - std::vector dims = context.Attr>("axis"); - - auto* in_data = in->data(); - auto* out_data = out->mutable_data(context.GetPlace()); - int64_t numel = in->numel(); - auto stream = - context.template device_context().stream(); - - size_t nums = shifts.size(); - auto input_dim = in->dims(); - auto stride_dim = phi::stride(input_dim); - - std::vector strides(nums), sizes(nums); - if (dims.size() == 0) { - strides[0] = 1; - sizes[0] = numel; - shifts[0] = (shifts[0] % numel + numel) % numel; - } else { - for (size_t i = 0; i < nums; i++) { - int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size(); - int64_t size = input_dim[dim]; - - if (size != 0) { - shifts[i] = (shifts[i] % size + size) % size; - strides[i] = stride_dim[dim]; - sizes[i] = size; - } - } - } - -#define CALL_ROLL_CUDA_KERNEL(N) \ - case N: { \ - phi::Array _strides; \ - phi::Array _shifts; \ - phi::Array _sizes; \ - for (size_t idx = 0; idx < N; ++idx) { \ - _strides[idx] = strides[idx]; \ - _shifts[idx] = shifts[idx]; \ - _sizes[idx] = sizes[idx]; \ - } \ - RollCudaKernel< \ - T, \ - N><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, \ - PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, numel, \ - _shifts, _strides, _sizes); \ - break; \ - } - - switch (nums) { - CALL_ROLL_CUDA_KERNEL(1); - CALL_ROLL_CUDA_KERNEL(2); - CALL_ROLL_CUDA_KERNEL(3); - CALL_ROLL_CUDA_KERNEL(4); - CALL_ROLL_CUDA_KERNEL(5); - CALL_ROLL_CUDA_KERNEL(6); - CALL_ROLL_CUDA_KERNEL(7); - CALL_ROLL_CUDA_KERNEL(8); - CALL_ROLL_CUDA_KERNEL(9); - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "shifts.size() should be less than 10, But received shifts.size() " - "= %d", - shifts.size())); - } - } -}; - -template -class RollGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input(framework::GradVarName("Out")); - auto* out = context.Output(framework::GradVarName("X")); - std::vector shifts = context.Attr>("shifts"); - if (context.HasInput("ShiftsTensor")) { - const auto* shifts_tensor = - context.Input("ShiftsTensor"); - PADDLE_ENFORCE_EQ( - shifts_tensor->dims().size(), 1, - platform::errors::InvalidArgument( - "The rank of ShiftsTensor is expected to be 1, got %s", - shifts_tensor->dims().size())); - shifts = GetDataFromTensor(shifts_tensor); - } - std::vector dims = context.Attr>("axis"); - - auto* in_data = in->data(); - auto* out_data = out->mutable_data(context.GetPlace()); - int64_t numel = in->numel(); - auto stream = - context.template device_context().stream(); - size_t nums = shifts.size(); - auto input_dim = in->dims(); - auto stride_dim = phi::stride(input_dim); - - std::vector strides(nums), sizes(nums); - if (dims.size() == 0) { - strides[0] = 1; - sizes[0] = numel; - shifts[0] = ((-shifts[0]) % numel + numel) % numel; - } else { - for (size_t i = 0; i < nums; i++) { - int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size(); - int64_t size = input_dim[dim]; - if (size != 0) { - shifts[i] = ((-shifts[i]) % size + size) % size; - strides[i] = stride_dim[dim]; - sizes[i] = size; - } - } - } - - switch (nums) { - CALL_ROLL_CUDA_KERNEL(1); - CALL_ROLL_CUDA_KERNEL(2); - CALL_ROLL_CUDA_KERNEL(3); - CALL_ROLL_CUDA_KERNEL(4); - CALL_ROLL_CUDA_KERNEL(5); - CALL_ROLL_CUDA_KERNEL(6); - CALL_ROLL_CUDA_KERNEL(7); - CALL_ROLL_CUDA_KERNEL(8); - CALL_ROLL_CUDA_KERNEL(9); - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "shifts.size() should be less than 10, But received shifts.size() " - "= %d", - shifts.size())); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - roll, ops::RollKernel, - ops::RollKernel, - ops::RollKernel, - ops::RollKernel, - ops::RollKernel>, - ops::RollKernel>); -REGISTER_OP_CUDA_KERNEL( - roll_grad, ops::RollGradKernel, - ops::RollGradKernel, - ops::RollGradKernel, - ops::RollGradKernel, - ops::RollGradKernel>, - ops::RollGradKernel>); diff --git a/paddle/fluid/operators/roll_op.h b/paddle/fluid/operators/roll_op.h deleted file mode 100644 index 413c7bcfc15eb1cae86c3fedf47ea4f677d1248c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/roll_op.h +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/utils.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; -using DDim = framework::DDim; - -template -inline void shift_along_dim(T* data, const DDim& input_dim, int64_t dim, - int64_t shift) { - if (dim < 0) { - dim += input_dim.size(); - } - if (input_dim[dim] == 0) { - return; - } - shift = shift % input_dim[dim]; - if (shift < 0) { - shift += input_dim[dim]; - } - - auto outer_loops = 1; - for (auto i = 0; i < dim; i++) { - outer_loops *= input_dim[i]; - } - auto slice_width = 1; - for (auto i = dim + 1; i < input_dim.size(); i++) { - slice_width *= input_dim[i]; - } - - VLOG(3) << "shift_along_dim_debug: input_dim: " << input_dim - << "; dim: " << dim << "; shift: " << shift - << "; outer_loops: " << outer_loops - << "; slice_width: " << slice_width; - if (shift == 0) { - return; - } - - std::vector head; - auto head_size = slice_width * (input_dim[dim] - shift); - head.resize(head_size); - - for (auto i = 0; i < outer_loops; i++) { - for (auto j = 0; j < head_size; j++) { - head[j] = data[i * input_dim[dim] * slice_width + j]; - } - for (auto j = input_dim[dim] - shift; j < input_dim[dim]; j++) { - auto dst_pos = j - input_dim[dim] + shift; - for (auto k = 0; k < slice_width; k++) { - data[(i * input_dim[dim] + dst_pos) * slice_width + k] = - data[(i * input_dim[dim] + j) * slice_width + k]; - } - } - for (auto j = 0; j < head_size; j++) { - data[(i * input_dim[dim] + shift) * slice_width + j] = head[j]; - } - } -} - -template -class RollKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input_var = context.InputVar("X"); - auto* output_var = context.OutputVar("Out"); - auto& input = input_var->Get(); - auto* output = output_var->GetMutable(); - std::vector shifts = context.Attr>("shifts"); - if (context.HasInput("ShiftsTensor")) { - const auto* shifts_tensor = - context.Input("ShiftsTensor"); - PADDLE_ENFORCE_EQ( - shifts_tensor->dims().size(), 1, - platform::errors::InvalidArgument( - "The rank of ShiftsTensor is expected to be 1, got %s", - shifts_tensor->dims().size())); - shifts = GetDataFromTensor(shifts_tensor); - } - std::vector dims = context.Attr>("axis"); - - std::vector out_vec; - paddle::framework::TensorToVector(input, context.device_context(), - &out_vec); - - size_t nums = shifts.size(); - DDim input_dim = input.dims(); - - // axis = none, reshape to 1-D tensor - if (dims.size() == 0) { - dims.push_back(0l); - input_dim = framework::Dim<1>(out_vec.size()); - } - - for (size_t i = 0; i < nums; i++) { - PADDLE_ENFORCE_EQ( - dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()), true, - platform::errors::OutOfRange( - "Attr(axis[%d]) is out of range, It's expected " - "to be in range of [-%d, %d]. But received Attr(axis[%d]) = %d.", - i, input_dim.size(), input_dim.size() - 1, i, dims[i])); - shift_along_dim(out_vec.data(), input_dim, dims[i], shifts[i]); - } - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(out_vec, context.device_context(), output); - output->Resize(input.dims()); - } -}; - -template -class RollGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input_var = context.InputVar(framework::GradVarName("Out")); - auto* output_var = context.OutputVar(framework::GradVarName("X")); - auto& input = input_var->Get(); - auto* output = output_var->GetMutable(); - std::vector shifts = context.Attr>("shifts"); - if (context.HasInput("ShiftsTensor")) { - const auto* shifts_tensor = - context.Input("ShiftsTensor"); - shifts = GetDataFromTensor(shifts_tensor); - } - std::vector dims = context.Attr>("axis"); - - std::vector out_vec; - paddle::framework::TensorToVector(input, context.device_context(), - &out_vec); - - size_t nums = shifts.size(); - DDim input_dim = input.dims(); - - // axis = none, reshape to 1-D tensor - if (dims.size() == 0) { - dims.push_back(0l); - input_dim = framework::Dim<1>(out_vec.size()); - } - - for (size_t i = 0; i < nums; i++) { - shift_along_dim(out_vec.data(), input_dim, dims[i], 0 - shifts[i]); - } - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(out_vec, context.device_context(), output); - output->Resize(input.dims()); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 752abae1b0333f46a749dca586936b0fca095720..262ada3eaf3169bebc919940e7630a75b0733cd9 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1016,6 +1016,37 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ReshapeInferMeta(x, shape, out, config); } +void RollInferMeta(const MetaTensor& x, + const ScalarArray& shifts, + const std::vector& axis, + MetaTensor* out) { + auto shifts_data = shifts.GetData(); + + if (axis.size() != 0) { + PADDLE_ENFORCE_EQ( + axis.size(), + shifts_data.size(), + phi::errors::InvalidArgument("When dims.size() != 0, dims.size() " + "should be equal to " + "shifts.size(). But received " + "dims.size() = %d, shifts.size() = %d", + axis.size(), + shifts_data.size())); + } else { + PADDLE_ENFORCE_EQ( + shifts_data.size(), + 1, + phi::errors::InvalidArgument("When dims.size() == 0, shifts.size() " + "should be equal to 1, But received " + "shifts.size() = %d", + shifts_data.size())); + } + + out->set_dims(x.dims()); + out->share_lod(x); + out->set_dtype(x.dtype()); +} + void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) { auto in_dim = input.dims(); out->set_dims(phi::make_ddim({in_dim.size()})); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index a9aefd1f12d67e994f6cc92c4bbb849654bb00b9..5447c9a573fbf3702dbb540f5052f2598899150e 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -164,6 +164,11 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void RollInferMeta(const MetaTensor& x, + const ScalarArray& shifts, + const std::vector& axis, + MetaTensor* out); + void ShapeInferMeta(const MetaTensor& input, MetaTensor* out); void ShardIndexInferMeta(const MetaTensor& in, diff --git a/paddle/phi/kernels/cpu/roll_grad_kernel.cc b/paddle/phi/kernels/cpu/roll_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..b0d0c0663e4a2eb71f4500baaf43bc8a891acddd --- /dev/null +++ b/paddle/phi/kernels/cpu/roll_grad_kernel.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roll_grad_kernel.h" + +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/roll_kernel_impl.h" + +namespace phi { + +template +void RollGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const ScalarArray& shifts, + const std::vector& axis, + DenseTensor* x_grad) { + std::vector out_vec; + paddle::framework::TensorToVector(out_grad, dev_ctx, &out_vec); + + auto shifts_data = shifts.GetData(); + size_t nums = shifts_data.size(); + DDim input_dim = out_grad.dims(); + auto dims = axis; + + // axis = none, reshape to 1-D tensor + if (dims.size() == 0) { + dims.push_back(0l); + input_dim = phi::Dim<1>(out_vec.size()); + } + + for (size_t i = 0; i < nums; i++) { + ShiftAlongDim(out_vec.data(), input_dim, dims[i], 0 - shifts_data[i]); + } + + dev_ctx.template Alloc(x_grad); + paddle::framework::TensorFromVector(out_vec, dev_ctx, x_grad); + x_grad->Resize(out_grad.dims()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(roll_grad, + CPU, + ALL_LAYOUT, + phi::RollGradKernel, + float, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/roll_kernel.cc b/paddle/phi/kernels/cpu/roll_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..25b64ef257dfb801f0050aad388b9fb0b3020ea5 --- /dev/null +++ b/paddle/phi/kernels/cpu/roll_kernel.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roll_kernel.h" + +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/roll_kernel_impl.h" + +namespace phi { + +template +void RollKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& shifts, + const std::vector& axis, + DenseTensor* out) { + std::vector out_vec; + paddle::framework::TensorToVector(x, dev_ctx, &out_vec); + + auto shifts_data = shifts.GetData(); + size_t nums = shifts_data.size(); + DDim input_dim = x.dims(); + auto dims = axis; + + // axis = none, reshape to 1-D tensor + if (dims.size() == 0) { + dims.push_back(0l); + input_dim = phi::Dim<1>(out_vec.size()); + } + + for (size_t i = 0; i < nums; i++) { + PADDLE_ENFORCE_EQ( + dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()), + true, + phi::errors::OutOfRange( + "Attr(axis[%d]) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(axis[%d]) = %d.", + i, + input_dim.size(), + input_dim.size() - 1, + i, + dims[i])); + ShiftAlongDim(out_vec.data(), input_dim, dims[i], shifts_data[i]); + } + dev_ctx.template Alloc(out); + paddle::framework::TensorFromVector(out_vec, dev_ctx, out); + out->Resize(x.dims()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(roll, + CPU, + ALL_LAYOUT, + phi::RollKernel, + float, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/roll_kernel_impl.h b/paddle/phi/kernels/cpu/roll_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..924e71aff31f3f874fb35586f496b9c5952c3757 --- /dev/null +++ b/paddle/phi/kernels/cpu/roll_kernel_impl.h @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +inline void ShiftAlongDim(T* data, + const DDim& input_dim, + int64_t dim, + int64_t shift) { + if (dim < 0) { + dim += input_dim.size(); + } + if (input_dim[dim] == 0) { + return; + } + shift = shift % input_dim[dim]; + if (shift < 0) { + shift += input_dim[dim]; + } + + auto outer_loops = 1; + for (auto i = 0; i < dim; i++) { + outer_loops *= input_dim[i]; + } + auto slice_width = 1; + for (auto i = dim + 1; i < input_dim.size(); i++) { + slice_width *= input_dim[i]; + } + + VLOG(3) << "shift_along_dim_debug: input_dim: " << input_dim + << "; dim: " << dim << "; shift: " << shift + << "; outer_loops: " << outer_loops + << "; slice_width: " << slice_width; + if (shift == 0) { + return; + } + + std::vector head; + auto head_size = slice_width * (input_dim[dim] - shift); + head.resize(head_size); + + for (auto i = 0; i < outer_loops; i++) { + for (auto j = 0; j < head_size; j++) { + head[j] = data[i * input_dim[dim] * slice_width + j]; + } + for (auto j = input_dim[dim] - shift; j < input_dim[dim]; j++) { + auto dst_pos = j - input_dim[dim] + shift; + for (auto k = 0; k < slice_width; k++) { + data[(i * input_dim[dim] + dst_pos) * slice_width + k] = + data[(i * input_dim[dim] + j) * slice_width + k]; + } + } + for (auto j = 0; j < head_size; j++) { + data[(i * input_dim[dim] + shift) * slice_width + j] = head[j]; + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/roll_grad_kernel.cu b/paddle/phi/kernels/gpu/roll_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..93e9e81882c9e6eacd5f9ee91fa7541495ef2663 --- /dev/null +++ b/paddle/phi/kernels/gpu/roll_grad_kernel.cu @@ -0,0 +1,88 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roll_grad_kernel.h" + +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/roll_kernel_impl.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +void RollGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const ScalarArray& shifts, + const std::vector& axis, + DenseTensor* x_grad) { + auto* in_data = out_grad.data(); + T* out_data = dev_ctx.template Alloc(x_grad); + int64_t numel = out_grad.numel(); + auto stream = dev_ctx.stream(); + + auto shifts_data = shifts.GetData(); + size_t nums = shifts_data.size(); + auto input_dim = out_grad.dims(); + auto stride_dim = phi::stride(input_dim); + + std::vector strides(nums), sizes(nums); + if (axis.size() == 0) { + strides[0] = 1; + sizes[0] = numel; + shifts_data[0] = ((-shifts_data[0]) % numel + numel) % numel; + } else { + for (size_t i = 0; i < nums; i++) { + int dim = axis[i] >= 0 ? axis[i] : axis[i] + input_dim.size(); + int64_t size = input_dim[dim]; + if (size != 0) { + shifts_data[i] = ((-shifts_data[i]) % size + size) % size; + strides[i] = stride_dim[dim]; + sizes[i] = size; + } + } + } + + switch (nums) { + CALL_ROLL_CUDA_KERNEL(1); + CALL_ROLL_CUDA_KERNEL(2); + CALL_ROLL_CUDA_KERNEL(3); + CALL_ROLL_CUDA_KERNEL(4); + CALL_ROLL_CUDA_KERNEL(5); + CALL_ROLL_CUDA_KERNEL(6); + CALL_ROLL_CUDA_KERNEL(7); + CALL_ROLL_CUDA_KERNEL(8); + CALL_ROLL_CUDA_KERNEL(9); + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "shifts.size() should be less than 10, But received shifts.size() " + "= %d", + shifts_data.size())); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(roll_grad, + GPU, + ALL_LAYOUT, + phi::RollGradKernel, + float, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/roll_kernel.cu b/paddle/phi/kernels/gpu/roll_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..1543335d3a0c5884d6b82394253bb4e8dda8cef0 --- /dev/null +++ b/paddle/phi/kernels/gpu/roll_kernel.cu @@ -0,0 +1,90 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roll_kernel.h" + +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/gpu/roll_kernel_impl.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +void RollKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& shifts, + const std::vector& axis, + DenseTensor* out) { + auto* in_data = x.data(); + T* out_data = dev_ctx.template Alloc(out); + int64_t numel = x.numel(); + auto stream = dev_ctx.stream(); + + auto shifts_data = shifts.GetData(); + + size_t nums = shifts_data.size(); + auto input_dim = x.dims(); + auto stride_dim = phi::stride(input_dim); + + std::vector strides(nums), sizes(nums); + if (axis.size() == 0) { + strides[0] = 1; + sizes[0] = numel; + shifts_data[0] = (shifts_data[0] % numel + numel) % numel; + } else { + for (size_t i = 0; i < nums; i++) { + int dim = axis[i] >= 0 ? axis[i] : axis[i] + input_dim.size(); + int64_t size = input_dim[dim]; + + if (size != 0) { + shifts_data[i] = (shifts_data[i] % size + size) % size; + strides[i] = stride_dim[dim]; + sizes[i] = size; + } + } + } + + switch (nums) { + CALL_ROLL_CUDA_KERNEL(1); + CALL_ROLL_CUDA_KERNEL(2); + CALL_ROLL_CUDA_KERNEL(3); + CALL_ROLL_CUDA_KERNEL(4); + CALL_ROLL_CUDA_KERNEL(5); + CALL_ROLL_CUDA_KERNEL(6); + CALL_ROLL_CUDA_KERNEL(7); + CALL_ROLL_CUDA_KERNEL(8); + CALL_ROLL_CUDA_KERNEL(9); + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "shifts.size() should be less than 10, But received shifts.size() " + "= %d", + shifts_data.size())); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(roll, + GPU, + ALL_LAYOUT, + phi::RollKernel, + float, + double, + int, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/roll_kernel_impl.h b/paddle/phi/kernels/gpu/roll_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..abe3ee470b4bc6b3951e1ad2da09544e319cbcac --- /dev/null +++ b/paddle/phi/kernels/gpu/roll_kernel_impl.h @@ -0,0 +1,71 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/primitive/kernel_primitives.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void RollCudaKernel(const T* input, + T* output, + int64_t N, + phi::Array shifts, + phi::Array strides, + phi::Array sizes) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t output_idx = idx; + int64_t new_dim_idx = 0; + +#pragma unroll + for (size_t i = 0; i < Rank; i++) { + new_dim_idx = (idx / strides[i]) % sizes[i] + shifts[i]; + if (new_dim_idx >= sizes[i]) { + output_idx += (shifts[i] - sizes[i]) * strides[i]; + } else { + output_idx += shifts[i] * strides[i]; + } + } + output[output_idx] = input[idx]; +} + +#define CALL_ROLL_CUDA_KERNEL(N) \ + case N: { \ + phi::Array _strides; \ + phi::Array _shifts; \ + phi::Array _sizes; \ + for (size_t idx = 0; idx < N; ++idx) { \ + _strides[idx] = strides[idx]; \ + _shifts[idx] = shifts_data[idx]; \ + _sizes[idx] = sizes[idx]; \ + } \ + RollCudaKernel< \ + T, \ + N><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, \ + PADDLE_CUDA_NUM_THREADS, \ + 0, \ + stream>>>(in_data, out_data, numel, _shifts, _strides, _sizes); \ + break; \ + } + +} // namespace phi diff --git a/paddle/phi/kernels/roll_grad_kernel.h b/paddle/phi/kernels/roll_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..331f3626e56574615a2d6b1680335638b060846d --- /dev/null +++ b/paddle/phi/kernels/roll_grad_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void RollGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const ScalarArray& shifts, + const std::vector& axis, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/roll_kernel.h b/paddle/phi/kernels/roll_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..56f32174a4c0005968acf147b2daf25914ff01b1 --- /dev/null +++ b/paddle/phi/kernels/roll_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void RollKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& shifts, + const std::vector& axis, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/roll_sig.cc b/paddle/phi/ops/compat/roll_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..a144f0e8e8a90eee0bf0a8a80455b1e19611880c --- /dev/null +++ b/paddle/phi/ops/compat/roll_sig.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature RollOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("ShiftsTensor")) { + return KernelSignature("roll", {"X"}, {"ShiftsTensor", "axis"}, {"Out"}); + } + return KernelSignature("roll", {"X"}, {"shifts", "axis"}, {"Out"}); +} + +KernelSignature RollGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("roll_grad", + {"X", GradVarName("Out")}, + {"shifts", "axis"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(roll, phi::RollOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(roll_grad, phi::RollGradOpArgumentMapping);