未验证 提交 7fc0c619 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Move reverse kernel and infershape into phi (#40791)

* add reverse phi kernel

* add reverse infermeta

* remove original reverse op kernl & infershape
上级 67f2c9f7
......@@ -12,60 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reverse_op.h"
#include <memory>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class ReverseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Reverse");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Reverse");
auto x_var_type = ctx->GetInputsVarType("X")[0];
const auto& axis = ctx->Attrs().Get<std::vector<int>>("axis");
if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
PADDLE_ENFORCE_EQ(
axis.size(), 1,
platform::errors::InvalidArgument(
"The size of axis must be 1 when the Input(X) is LoDTensorArray, "
"but received %d.",
axis.size()));
PADDLE_ENFORCE_EQ(axis[0], 0, platform::errors::InvalidArgument(
"The value of axis should be 1 when "
"the Input(X) is LoDTensorArray, "
"but received %d.",
axis[0]));
// In runtime, shape is determined by RunImpl.
if (!ctx->IsRuntime()) {
const auto& x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
}
return;
}
const auto& x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_NE(axis.empty(), true, platform::errors::InvalidArgument(
"'axis' can not be empty."));
for (int a : axis) {
PADDLE_ENFORCE_LT(a, x_dims.size(),
paddle::platform::errors::OutOfRange(
"The axis must be less than input tensor's rank. "
"but got %d >= %d",
a, x_dims.size()));
PADDLE_ENFORCE_GE(
a, -x_dims.size(),
paddle::platform::errors::OutOfRange(
"The axis must be greater than the negative number of "
"input tensor's rank, but got %d < %d",
a, -x_dims.size()));
}
ctx->SetOutputDim("Out", x_dims);
}
};
class ReverseOpVarTypeInference : public framework::VarTypeInference {
......@@ -134,23 +94,10 @@ class ReverseGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(reverse, ReverseInferShapeFunctor,
PD_INFER_META(phi::ReverseInferMeta));
REGISTER_OPERATOR(reverse, ops::ReverseOp, ops::ReverseOpMaker,
ops::ReverseGradMaker<paddle::framework::OpDesc>,
ops::ReverseGradMaker<paddle::imperative::OpBase>,
ops::ReverseOpVarTypeInference);
ops::ReverseOpVarTypeInference, ReverseInferShapeFunctor);
REGISTER_OPERATOR(reverse_grad, ops::ReverseOp, ops::ReverseOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(
reverse, ops::ReverseKernel<paddle::platform::CPUDeviceContext, int>,
ops::ReverseKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ReverseKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ReverseKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ReverseKernel<paddle::platform::CPUDeviceContext, float>,
ops::ReverseKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
reverse, ops::ReverseKernel<paddle::platform::CUDADeviceContext, int>,
ops::ReverseKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::ReverseKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ReverseKernel<paddle::platform::CUDADeviceContext, bool>,
ops::ReverseKernel<paddle::platform::CUDADeviceContext, float>,
ops::ReverseKernel<paddle::platform::CUDADeviceContext, double>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T, int Rank>
struct ReverseFunctor {
void operator()(const DeviceContext& context, const framework::LoDTensor& in,
framework::LoDTensor* out, const std::vector<int>& axis) {
Eigen::DSizes<bool, Rank> reverse_axis;
for (int i = 0; i < Rank; ++i) {
reverse_axis[i] = false;
}
for (int a : axis) {
if (a >= 0) {
reverse_axis[a] = true;
} else {
reverse_axis[Rank + a] = true;
}
}
auto in_eigen = framework::EigenTensor<T, Rank>::From(in);
auto out_eigen = framework::EigenTensor<T, Rank>::From(*out);
auto& dev = *context.eigen_device();
EigenReverse<std::decay_t<decltype(dev)>, T, Rank>::Eval(
dev, out_eigen, in_eigen, reverse_axis);
}
};
template <typename DeviceContext, typename T>
class ReverseKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_var = context.InputVar("X");
const auto& axis = context.Attr<std::vector<int>>("axis");
if (x_var->IsType<framework::LoDTensorArray>()) {
auto& x_array = x_var->Get<framework::LoDTensorArray>();
auto* out_array = context.Output<framework::LoDTensorArray>("Out");
out_array->resize(x_array.size());
for (size_t offset = 0; offset < x_array.size(); offset++) {
auto& x_tensor = x_array.at(offset);
PADDLE_ENFORCE_GT(
x_tensor.memory_size(), 0,
platform::errors::PreconditionNotMet(
"The input LoDTensorArray X[%d] holds no memory.", offset));
auto out_offset = x_array.size() - offset - 1;
auto* out_tensor = &out_array->at(out_offset);
out_tensor->set_lod(x_tensor.lod());
paddle::framework::TensorCopy(x_tensor, context.GetPlace(), out_tensor);
}
return;
}
auto* x = context.Input<framework::LoDTensor>("X");
auto* out = context.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(context.GetPlace());
int rank = x->dims().size();
auto& dev_ctx = context.template device_context<DeviceContext>();
switch (rank) {
case 1:
ReverseFunctor<DeviceContext, T, 1> functor1;
functor1(dev_ctx, *x, out, axis);
break;
case 2:
ReverseFunctor<DeviceContext, T, 2> functor2;
functor2(dev_ctx, *x, out, axis);
break;
case 3:
ReverseFunctor<DeviceContext, T, 3> functor3;
functor3(dev_ctx, *x, out, axis);
break;
case 4:
ReverseFunctor<DeviceContext, T, 4> functor4;
functor4(dev_ctx, *x, out, axis);
break;
case 5:
ReverseFunctor<DeviceContext, T, 5> functor5;
functor5(dev_ctx, *x, out, axis);
break;
case 6:
ReverseFunctor<DeviceContext, T, 6> functor6;
functor6(dev_ctx, *x, out, axis);
break;
default:
PADDLE_THROW(paddle::platform::errors::OutOfRange(
"The reserve operator does not support input tensors"
"whose ranks are greater than 6."));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -1240,6 +1240,33 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
ReshapeInferMeta(x, shape, out, config);
}
void ReverseInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out) {
PADDLE_ENFORCE_NE(axis.empty(),
true,
phi::errors::InvalidArgument("'axis' can not be empty."));
const auto& x_dims = x.dims();
for (int a : axis) {
PADDLE_ENFORCE_LT(a,
x_dims.size(),
phi::errors::OutOfRange(
"The axis must be less than input tensor's rank. "
"but got %d >= %d",
a,
x_dims.size()));
PADDLE_ENFORCE_GE(
a,
-x_dims.size(),
phi::errors::OutOfRange(
"The axis must be greater than the negative number of "
"input tensor's rank, but got %d < %d",
a,
-x_dims.size()));
}
out->share_meta(x);
}
void RollInferMeta(const MetaTensor& x,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
......
......@@ -198,6 +198,10 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ReverseInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);
void RollInferMeta(const MetaTensor& x,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reverse_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/reverse_kernel_impl.h"
PD_REGISTER_KERNEL(reverse,
CPU,
ALL_LAYOUT,
phi::ReverseKernel,
int,
uint8_t,
int64_t,
bool,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reverse_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/reverse_kernel_impl.h"
PD_REGISTER_KERNEL(reverse,
GPU,
ALL_LAYOUT,
phi::ReverseKernel,
int,
uint8_t,
int64_t,
bool,
float,
double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/reverse_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename Context, typename T, int Rank>
struct ReverseFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor& in,
DenseTensor* out,
const std::vector<int>& axis) {
Eigen::DSizes<bool, Rank> reverse_axis;
for (int i = 0; i < Rank; ++i) {
reverse_axis[i] = false;
}
for (int a : axis) {
if (a >= 0) {
reverse_axis[a] = true;
} else {
reverse_axis[Rank + a] = true;
}
}
auto in_eigen = EigenTensor<T, Rank>::From(in);
auto out_eigen = EigenTensor<T, Rank>::From(*out);
auto& dev = *dev_ctx.eigen_device();
funcs::EigenReverse<std::decay_t<decltype(dev)>, T, Rank>::Eval(
dev, out_eigen, in_eigen, reverse_axis);
}
};
template <typename T, typename Context>
void ReverseKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
int rank = x.dims().size();
switch (rank) {
case 1:
ReverseFunctor<Context, T, 1> functor1;
functor1(dev_ctx, x, out, axis);
break;
case 2:
ReverseFunctor<Context, T, 2> functor2;
functor2(dev_ctx, x, out, axis);
break;
case 3:
ReverseFunctor<Context, T, 3> functor3;
functor3(dev_ctx, x, out, axis);
break;
case 4:
ReverseFunctor<Context, T, 4> functor4;
functor4(dev_ctx, x, out, axis);
break;
case 5:
ReverseFunctor<Context, T, 5> functor5;
functor5(dev_ctx, x, out, axis);
break;
case 6:
ReverseFunctor<Context, T, 6> functor6;
functor6(dev_ctx, x, out, axis);
break;
default:
PADDLE_THROW(phi::errors::OutOfRange(
"The reserve operator does not support input tensors"
"whose ranks are greater than 6."));
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reverse_kernel.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
template <typename T, typename Context>
void ReverseArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axis,
std::vector<DenseTensor*> out) {
PADDLE_ENFORCE_EQ(
x.size(),
out.size(),
phi::errors::InvalidArgument("The input size(%d) and output size(%d) of "
"ReverseArrayKernel is different.",
x.size(),
out.size()));
for (size_t offset = 0; offset < x.size(); ++offset) {
auto* x_tensor = x.at(offset);
PADDLE_ENFORCE_GT(
x_tensor->memory_size(),
0,
phi::errors::PreconditionNotMet(
"The input LoDTensorArray X[%d] holds no memory.", offset));
auto out_offset = x.size() - offset - 1;
auto* out_tensor = out.at(out_offset);
out_tensor->set_lod(x_tensor->lod());
phi::Copy<Context>(
dev_ctx, *x_tensor, dev_ctx.GetPlace(), false, out_tensor);
}
}
} // namespace phi
PD_REGISTER_KERNEL(reverse_array,
CPU,
ALL_LAYOUT,
phi::ReverseArrayKernel,
int,
uint8_t,
int64_t,
bool,
float,
double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(reverse_array,
GPU,
ALL_LAYOUT,
phi::ReverseArrayKernel,
int,
uint8_t,
int64_t,
bool,
float,
double) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ReverseKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out);
template <typename T, typename Context>
void ReverseArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axis,
std::vector<DenseTensor*> out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ReverseOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorVectorInput("X")) {
return KernelSignature("reverse_array", {"X"}, {"axis"}, {"Out"});
} else {
return KernelSignature("reverse", {"X"}, {"axis"}, {"Out"});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(reverse, phi::ReverseOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册