Unverified commit 44d46d03, authored by chenenquan, committed by GitHub

[PHI] Migrate roll op (#40257)

* [PHI] Migrate roll op

* [PHI] Migrate eigh op to phi (#40213)

* migrate eigh to phi

* optimize code

* modify code according to comment

* conflict resolution

* [PHI] Migrate roll op

* [PHI] Fix coverage of roll_sig

* [PHI] Fix infermeta of roll_sig

* [Phi] Fix unittest coverage of roll op

* [PHI] Fix infermeta in unary

* [PHI] Fix parameter type of roll op

* [PHI] Fix parameter type of roll op

* [PHI] Fix parameter of roll op
Co-authored-by: crystal <62974595+Zjq9409@users.noreply.github.com>
Parent 99452af7
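For readers unfamiliar with the op being migrated: roll rotates the elements of a tensor along an axis, so an element at coordinate i on a dimension of size n moves to coordinate (i + shift) mod n, with the shift normalized into [0, n) the same way the kernels in this change do. The snippet below is an illustrative sketch only and not part of this change; Roll1D is a hypothetical helper name.

#include <cstdint>
#include <vector>

// Hypothetical helper, for illustration only: a 1-D roll with the same
// shift normalization as the kernels in this change.
std::vector<int> Roll1D(const std::vector<int>& in, int64_t shift) {
  const int64_t n = static_cast<int64_t>(in.size());
  std::vector<int> out(in.size());
  if (n == 0) return out;
  shift = ((shift % n) + n) % n;  // handle negative shifts
  for (int64_t i = 0; i < n; ++i) {
    out[(i + shift) % n] = in[i];
  }
  return out;
}
// Roll1D({1, 2, 3, 4}, 1)  -> {4, 1, 2, 3}
// Roll1D({1, 2, 3, 4}, -1) -> {2, 3, 4, 1}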
@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/roll_op.h"
#include <memory>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -29,43 +32,6 @@ class RollOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of RollOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of RollOp should not be null."));
auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");
if (!ctx->HasInput("ShiftsTensor")) {
if (dims.size() != 0) {
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
platform::errors::InvalidArgument(
"When dims.size() != 0, dims.size() "
"should be equal to "
"shifts.size(). But received "
"dims.size() = %d, shifts.size() = %d",
dims.size(), shifts.size()));
} else {
PADDLE_ENFORCE_EQ(shifts.size(), 1,
platform::errors::InvalidArgument(
"When dims.size() == 0, shifts.size() "
"should be equal to 1, But received "
"shifts.size() = %d",
shifts.size()));
}
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -149,29 +115,15 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(RollGradNoNeedBufferVarsInferer, "X");
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(roll, RollInferShapeFunctor,
PD_INFER_META(phi::RollInferMeta));
REGISTER_OPERATOR(roll, ops::RollOp, ops::RollOpMaker,
ops::RollGradMaker<paddle::framework::OpDesc>,
ops::RollGradMaker<paddle::imperative::OpBase>);
ops::RollGradMaker<paddle::imperative::OpBase>,
RollInferShapeFunctor);
REGISTER_OPERATOR(roll_grad, ops::RollGradOp,
ops::RollGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
roll, ops::RollKernel<paddle::platform::CPUDeviceContext, float>,
ops::RollKernel<paddle::platform::CPUDeviceContext, double>,
ops::RollKernel<paddle::platform::CPUDeviceContext, int>,
ops::RollKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::RollKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::RollKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
roll_grad, ops::RollGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::RollGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(roll)
.AddCheckpoint(
...
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/roll_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/utils/array.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, size_t Rank>
__global__ void RollCudaKernel(const T* input, T* output, int64_t N,
phi::Array<int64_t, Rank> shifts,
phi::Array<int64_t, Rank> strides,
phi::Array<int64_t, Rank> sizes) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t output_idx = idx;
int64_t new_dim_idx = 0;
#pragma unroll
for (size_t i = 0; i < Rank; i++) {
new_dim_idx = (idx / strides[i]) % sizes[i] + shifts[i];
if (new_dim_idx >= sizes[i]) {
output_idx += (shifts[i] - sizes[i]) * strides[i];
} else {
output_idx += shifts[i] * strides[i];
}
}
output[output_idx] = input[idx];
}
template <typename T>
class RollKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = in->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
size_t nums = shifts.size();
auto input_dim = in->dims();
auto stride_dim = phi::stride(input_dim);
std::vector<int64_t> strides(nums), sizes(nums);
if (dims.size() == 0) {
strides[0] = 1;
sizes[0] = numel;
shifts[0] = (shifts[0] % numel + numel) % numel;
} else {
for (size_t i = 0; i < nums; i++) {
int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size();
int64_t size = input_dim[dim];
if (size != 0) {
shifts[i] = (shifts[i] % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
}
}
}
#define CALL_ROLL_CUDA_KERNEL(N) \
case N: { \
phi::Array<int64_t, N> _strides; \
phi::Array<int64_t, N> _shifts; \
phi::Array<int64_t, N> _sizes; \
for (size_t idx = 0; idx < N; ++idx) { \
_strides[idx] = strides[idx]; \
_shifts[idx] = shifts[idx]; \
_sizes[idx] = sizes[idx]; \
} \
RollCudaKernel< \
T, \
N><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, \
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, numel, \
_shifts, _strides, _sizes); \
break; \
}
switch (nums) {
CALL_ROLL_CUDA_KERNEL(1);
CALL_ROLL_CUDA_KERNEL(2);
CALL_ROLL_CUDA_KERNEL(3);
CALL_ROLL_CUDA_KERNEL(4);
CALL_ROLL_CUDA_KERNEL(5);
CALL_ROLL_CUDA_KERNEL(6);
CALL_ROLL_CUDA_KERNEL(7);
CALL_ROLL_CUDA_KERNEL(8);
CALL_ROLL_CUDA_KERNEL(9);
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"shifts.size() should be less than 10, But received shifts.size() "
"= %d",
shifts.size()));
}
}
};
template <typename T>
class RollGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = in->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
size_t nums = shifts.size();
auto input_dim = in->dims();
auto stride_dim = phi::stride(input_dim);
std::vector<int64_t> strides(nums), sizes(nums);
if (dims.size() == 0) {
strides[0] = 1;
sizes[0] = numel;
shifts[0] = ((-shifts[0]) % numel + numel) % numel;
} else {
for (size_t i = 0; i < nums; i++) {
int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size();
int64_t size = input_dim[dim];
if (size != 0) {
shifts[i] = ((-shifts[i]) % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
}
}
}
switch (nums) {
CALL_ROLL_CUDA_KERNEL(1);
CALL_ROLL_CUDA_KERNEL(2);
CALL_ROLL_CUDA_KERNEL(3);
CALL_ROLL_CUDA_KERNEL(4);
CALL_ROLL_CUDA_KERNEL(5);
CALL_ROLL_CUDA_KERNEL(6);
CALL_ROLL_CUDA_KERNEL(7);
CALL_ROLL_CUDA_KERNEL(8);
CALL_ROLL_CUDA_KERNEL(9);
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"shifts.size() should be less than 10, But received shifts.size() "
"= %d",
shifts.size()));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
roll, ops::RollKernel<paddle::platform::CUDADeviceContext, float>,
ops::RollKernel<paddle::platform::CUDADeviceContext, double>,
ops::RollKernel<paddle::platform::CUDADeviceContext, int>,
ops::RollKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::RollKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::RollKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
roll_grad, ops::RollGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::RollGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
@@ -1016,6 +1016,37 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
ReshapeInferMeta(x, shape, out, config);
}
void RollInferMeta(const MetaTensor& x,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
MetaTensor* out) {
auto shifts_data = shifts.GetData();
if (axis.size() != 0) {
PADDLE_ENFORCE_EQ(
axis.size(),
shifts_data.size(),
phi::errors::InvalidArgument("When dims.size() != 0, dims.size() "
"should be equal to "
"shifts.size(). But received "
"dims.size() = %d, shifts.size() = %d",
axis.size(),
shifts_data.size()));
} else {
PADDLE_ENFORCE_EQ(
shifts_data.size(),
1,
phi::errors::InvalidArgument("When dims.size() == 0, shifts.size() "
"should be equal to 1, But received "
"shifts.size() = %d",
shifts_data.size()));
}
out->set_dims(x.dims());
out->share_lod(x);
out->set_dtype(x.dtype());
}
void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) {
auto in_dim = input.dims();
out->set_dims(phi::make_ddim({in_dim.size()}));
...
@@ -164,6 +164,11 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void RollInferMeta(const MetaTensor& x,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
MetaTensor* out);
void ShapeInferMeta(const MetaTensor& input, MetaTensor* out);
void ShardIndexInferMeta(const MetaTensor& in,
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roll_grad_kernel.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/roll_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void RollGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
DenseTensor* x_grad) {
std::vector<T> out_vec;
paddle::framework::TensorToVector(out_grad, dev_ctx, &out_vec);
auto shifts_data = shifts.GetData();
size_t nums = shifts_data.size();
DDim input_dim = out_grad.dims();
auto dims = axis;
// axis = none, reshape to 1-D tensor
if (dims.size() == 0) {
dims.push_back(0l);
input_dim = phi::Dim<1>(out_vec.size());
}
for (size_t i = 0; i < nums; i++) {
ShiftAlongDim(out_vec.data(), input_dim, dims[i], 0 - shifts_data[i]);
}
dev_ctx.template Alloc<T>(x_grad);
paddle::framework::TensorFromVector(out_vec, dev_ctx, x_grad);
x_grad->Resize(out_grad.dims());
}
} // namespace phi
PD_REGISTER_KERNEL(roll_grad,
CPU,
ALL_LAYOUT,
phi::RollGradKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roll_kernel.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/roll_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void RollKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
DenseTensor* out) {
std::vector<T> out_vec;
paddle::framework::TensorToVector(x, dev_ctx, &out_vec);
auto shifts_data = shifts.GetData();
size_t nums = shifts_data.size();
DDim input_dim = x.dims();
auto dims = axis;
// axis = none, reshape to 1-D tensor
if (dims.size() == 0) {
dims.push_back(0l);
input_dim = phi::Dim<1>(out_vec.size());
}
for (size_t i = 0; i < nums; i++) {
PADDLE_ENFORCE_EQ(
dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()),
true,
phi::errors::OutOfRange(
"Attr(axis[%d]) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis[%d]) = %d.",
i,
input_dim.size(),
input_dim.size() - 1,
i,
dims[i]));
ShiftAlongDim(out_vec.data(), input_dim, dims[i], shifts_data[i]);
}
dev_ctx.template Alloc<T>(out);
paddle::framework::TensorFromVector(out_vec, dev_ctx, out);
out->Resize(x.dims());
}
} // namespace phi
PD_REGISTER_KERNEL(roll,
CPU,
ALL_LAYOUT,
phi::RollKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -13,22 +13,17 @@
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T>
inline void shift_along_dim(T* data, const DDim& input_dim, int64_t dim,
int64_t shift) {
inline void ShiftAlongDim(T* data,
const DDim& input_dim,
int64_t dim,
int64_t shift) {
if (dim < 0) {
dim += input_dim.size();
}
@@ -78,92 +73,4 @@ inline void shift_along_dim(T* data, const DDim& input_dim, int64_t dim,
}
}
} // namespace phi
template <typename DeviceContext, typename T>
class RollKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input_var = context.InputVar("X");
auto* output_var = context.OutputVar("Out");
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
std::vector<T> out_vec;
paddle::framework::TensorToVector(input, context.device_context(),
&out_vec);
size_t nums = shifts.size();
DDim input_dim = input.dims();
// axis = none, reshape to 1-D tensor
if (dims.size() == 0) {
dims.push_back(0l);
input_dim = framework::Dim<1>(out_vec.size());
}
for (size_t i = 0; i < nums; i++) {
PADDLE_ENFORCE_EQ(
dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()), true,
platform::errors::OutOfRange(
"Attr(axis[%d]) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis[%d]) = %d.",
i, input_dim.size(), input_dim.size() - 1, i, dims[i]));
shift_along_dim(out_vec.data(), input_dim, dims[i], shifts[i]);
}
output->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), output);
output->Resize(input.dims());
}
};
template <typename DeviceContext, typename T>
class RollGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input_var = context.InputVar(framework::GradVarName("Out"));
auto* output_var = context.OutputVar(framework::GradVarName("X"));
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
std::vector<T> out_vec;
paddle::framework::TensorToVector(input, context.device_context(),
&out_vec);
size_t nums = shifts.size();
DDim input_dim = input.dims();
// axis = none, reshape to 1-D tensor
if (dims.size() == 0) {
dims.push_back(0l);
input_dim = framework::Dim<1>(out_vec.size());
}
for (size_t i = 0; i < nums; i++) {
shift_along_dim(out_vec.data(), input_dim, dims[i], 0 - shifts[i]);
}
output->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), output);
output->Resize(input.dims());
}
};
} // namespace operators
} // namespace paddle
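The body of shift_along_dim/ShiftAlongDim is collapsed in the hunk above. The sketch below is an assumption about what shifting along one dimension of a row-major buffer amounts to, not the actual Paddle implementation; ShiftAlongDimSketch is a hypothetical name and the dimension index is assumed to be already non-negative.

#include <cstdint>
#include <vector>

// Simplified sketch (assumed behavior): rotate a row-major buffer by `shift`
// positions along dimension `dim`.
template <typename T>
void ShiftAlongDimSketch(T* data, const std::vector<int64_t>& dims,
                         int64_t dim, int64_t shift) {
  const int64_t size = dims[dim];
  if (size == 0) return;
  shift = ((shift % size) + size) % size;  // normalize, handles negatives
  if (shift == 0) return;
  int64_t stride = 1;  // stride of `dim` in a row-major layout
  for (size_t i = static_cast<size_t>(dim) + 1; i < dims.size(); ++i) {
    stride *= dims[i];
  }
  int64_t numel = 1;
  for (int64_t d : dims) numel *= d;
  std::vector<T> src(data, data + numel);
  for (int64_t i = 0; i < numel; ++i) {
    int64_t coord = (i / stride) % size;  // coordinate along `dim`
    int64_t dst = i + (((coord + shift) % size) - coord) * stride;
    data[dst] = src[i];
  }
}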
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roll_grad_kernel.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/roll_kernel_impl.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename Context>
void RollGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
DenseTensor* x_grad) {
auto* in_data = out_grad.data<T>();
T* out_data = dev_ctx.template Alloc<T>(x_grad);
int64_t numel = out_grad.numel();
auto stream = dev_ctx.stream();
auto shifts_data = shifts.GetData();
size_t nums = shifts_data.size();
auto input_dim = out_grad.dims();
auto stride_dim = phi::stride(input_dim);
std::vector<int64_t> strides(nums), sizes(nums);
if (axis.size() == 0) {
strides[0] = 1;
sizes[0] = numel;
shifts_data[0] = ((-shifts_data[0]) % numel + numel) % numel;
} else {
for (size_t i = 0; i < nums; i++) {
int dim = axis[i] >= 0 ? axis[i] : axis[i] + input_dim.size();
int64_t size = input_dim[dim];
if (size != 0) {
shifts_data[i] = ((-shifts_data[i]) % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
}
}
}
switch (nums) {
CALL_ROLL_CUDA_KERNEL(1);
CALL_ROLL_CUDA_KERNEL(2);
CALL_ROLL_CUDA_KERNEL(3);
CALL_ROLL_CUDA_KERNEL(4);
CALL_ROLL_CUDA_KERNEL(5);
CALL_ROLL_CUDA_KERNEL(6);
CALL_ROLL_CUDA_KERNEL(7);
CALL_ROLL_CUDA_KERNEL(8);
CALL_ROLL_CUDA_KERNEL(9);
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"shifts.size() should be less than 10, But received shifts.size() "
"= %d",
shifts_data.size()));
}
}
} // namespace phi
PD_REGISTER_KERNEL(roll_grad,
GPU,
ALL_LAYOUT,
phi::RollGradKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roll_kernel.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/array.h"
#include "paddle/phi/kernels/gpu/roll_kernel_impl.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename Context>
void RollKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
DenseTensor* out) {
auto* in_data = x.data<T>();
T* out_data = dev_ctx.template Alloc<T>(out);
int64_t numel = x.numel();
auto stream = dev_ctx.stream();
auto shifts_data = shifts.GetData();
size_t nums = shifts_data.size();
auto input_dim = x.dims();
auto stride_dim = phi::stride(input_dim);
std::vector<int64_t> strides(nums), sizes(nums);
if (axis.size() == 0) {
strides[0] = 1;
sizes[0] = numel;
shifts_data[0] = (shifts_data[0] % numel + numel) % numel;
} else {
for (size_t i = 0; i < nums; i++) {
int dim = axis[i] >= 0 ? axis[i] : axis[i] + input_dim.size();
int64_t size = input_dim[dim];
if (size != 0) {
shifts_data[i] = (shifts_data[i] % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
}
}
}
switch (nums) {
CALL_ROLL_CUDA_KERNEL(1);
CALL_ROLL_CUDA_KERNEL(2);
CALL_ROLL_CUDA_KERNEL(3);
CALL_ROLL_CUDA_KERNEL(4);
CALL_ROLL_CUDA_KERNEL(5);
CALL_ROLL_CUDA_KERNEL(6);
CALL_ROLL_CUDA_KERNEL(7);
CALL_ROLL_CUDA_KERNEL(8);
CALL_ROLL_CUDA_KERNEL(9);
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"shifts.size() should be less than 10, But received shifts.size() "
"= %d",
shifts_data.size()));
}
}
} // namespace phi
PD_REGISTER_KERNEL(roll,
GPU,
ALL_LAYOUT,
phi::RollKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/utils/array.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, size_t Rank>
__global__ void RollCudaKernel(const T* input,
T* output,
int64_t N,
phi::Array<int64_t, Rank> shifts,
phi::Array<int64_t, Rank> strides,
phi::Array<int64_t, Rank> sizes) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t output_idx = idx;
int64_t new_dim_idx = 0;
#pragma unroll
for (size_t i = 0; i < Rank; i++) {
new_dim_idx = (idx / strides[i]) % sizes[i] + shifts[i];
if (new_dim_idx >= sizes[i]) {
output_idx += (shifts[i] - sizes[i]) * strides[i];
} else {
output_idx += shifts[i] * strides[i];
}
}
output[output_idx] = input[idx];
}
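The index computation in RollCudaKernel above can be read on the host side as follows; this is an illustrative sketch only, and RolledOutputIndex is a hypothetical name, assuming (as the callers arrange) that each shift has already been reduced into [0, sizes[i]).

#include <cstdint>
#include <vector>

// Host-side equivalent of the per-thread mapping: for each dimension, the
// coordinate (idx / strides[i]) % sizes[i] advances by the normalized shift;
// subtracting sizes[i] when it overflows is the modular wrap-around.
int64_t RolledOutputIndex(int64_t idx,
                          const std::vector<int64_t>& shifts,
                          const std::vector<int64_t>& strides,
                          const std::vector<int64_t>& sizes) {
  int64_t output_idx = idx;
  for (size_t i = 0; i < shifts.size(); ++i) {
    int64_t new_dim_idx = (idx / strides[i]) % sizes[i] + shifts[i];
    if (new_dim_idx >= sizes[i]) {
      output_idx += (shifts[i] - sizes[i]) * strides[i];  // wrapped around
    } else {
      output_idx += shifts[i] * strides[i];
    }
  }
  return output_idx;
}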
#define CALL_ROLL_CUDA_KERNEL(N) \
case N: { \
phi::Array<int64_t, N> _strides; \
phi::Array<int64_t, N> _shifts; \
phi::Array<int64_t, N> _sizes; \
for (size_t idx = 0; idx < N; ++idx) { \
_strides[idx] = strides[idx]; \
_shifts[idx] = shifts_data[idx]; \
_sizes[idx] = sizes[idx]; \
} \
RollCudaKernel< \
T, \
N><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, \
PADDLE_CUDA_NUM_THREADS, \
0, \
stream>>>(in_data, out_data, numel, _shifts, _strides, _sizes); \
break; \
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void RollGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void RollKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature RollOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("ShiftsTensor")) {
return KernelSignature("roll", {"X"}, {"ShiftsTensor", "axis"}, {"Out"});
}
return KernelSignature("roll", {"X"}, {"shifts", "axis"}, {"Out"});
}
KernelSignature RollGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("roll_grad",
{"X", GradVarName("Out")},
{"shifts", "axis"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(roll, phi::RollOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(roll_grad, phi::RollGradOpArgumentMapping);