Unverified Commit b6661d3a authored by Y Yang, committed by GitHub

[phi] move infershape: flip/maxout/take_along_axis/put_along_axis (#40974)

Parent 1431305e
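This commit continues the fluid-to-phi InferShape migration: each operator's hand-written InferShape override is deleted, the shape/dtype logic is re-implemented once as a phi InferMeta function, and the operator registration picks that function up through an InferShapeFunctor. A minimal sketch of the wiring, using a hypothetical my_op (the two macros and the registration form are the ones used in this diff; the op and function names are illustrative):

// In phi (e.g. paddle/phi/infermeta/unary.cc): shape/dtype inference is
// written once against MetaTensor, independent of the fluid registry.
void MyOpInferMeta(const MetaTensor& x, MetaTensor* out);

// In the fluid op file: adapt the phi function to the legacy registry and
// register it in place of an InferShape override on the op class.
DECLARE_INFER_SHAPE_FUNCTOR(my_op, MyOpInferShapeFunctor,
                            PD_INFER_META(phi::MyOpInferMeta));
REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker, MyOpInferShapeFunctor);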
@@ -16,8 +16,11 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -29,72 +32,6 @@ class FlipOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
// TODO: move to phi kernel
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of FlipOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of FlipOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto flip_dims = ctx->Attrs().Get<std::vector<int>>("axis");
size_t flip_dims_size = flip_dims.size();
if (flip_dims_size > 0) {
// check that each flip axis is within range
auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end());
PADDLE_ENFORCE_LT(
*min_max_d.first, x_dims.size(),
platform::errors::InvalidArgument(
"min(axes) should be less than the input tensor X's "
"axes of FlipOp. But received min(axes) = %d, "
"X's axes = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(*min_max_d.first, x_dims.size() * -1,
platform::errors::InvalidArgument(
"min(axes) should be greater than or equal to the "
"input tensor X's "
"axes of FlipOp times -1. But received "
"min(axes) = %d, X's "
"axes = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size() * -1, x_dims));
PADDLE_ENFORCE_LT(
*min_max_d.second, x_dims.size(),
platform::errors::InvalidArgument(
"max(axes) should be less than the input tensor X's "
"axes of FlipOp. But received max(axes) = %d, "
"X's axes = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(*min_max_d.second, x_dims.size() * -1,
platform::errors::InvalidArgument(
"max(axes) should be greater than or equal to the "
"input tensor X's "
"axes of FlipOp times -1. But received "
"max(axes) = %d, X's "
"axes = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size() * -1, x_dims));
// check duplicates in dims
flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()),
flip_dims.end());
PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size,
platform::errors::InvalidArgument(
"axes has duplicates, original flip axes size=%d, "
"but unique flip axes size=%d.)",
flip_dims_size, flip_dims.size()));
}
VLOG(3) << "flip operator x.shape=" << x_dims;
std::vector<int64_t> output_dims(x_dims.size());
for (int i = 0; i < x_dims.size(); ++i) {
output_dims[i] = x_dims[i];
}
ctx->SetOutputDim("Out", phi::make_ddim(output_dims));
ctx->ShareLoD("X", "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
@@ -148,9 +85,12 @@ class FlipOpGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(flip, FlipInferShapeFunctor,
PD_INFER_META(phi::FlipInferMeta));
REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
ops::FlipOpGradMaker<paddle::framework::OpDesc>,
ops::FlipOpGradMaker<paddle::imperative::OpBase>,
FlipInferShapeFunctor);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(flip)
......
@@ -14,8 +14,11 @@
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -71,50 +74,12 @@ Please refer to Paper:
class MaxOutOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "maxout");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "maxout");
auto in_x_dims = ctx->GetInputDim("X");
int groups = ctx->Attrs().Get<int>("groups");
int axis = ctx->Attrs().Get<int>("axis");
// check groups > 1
PADDLE_ENFORCE_GT(groups, 1, platform::errors::InvalidArgument(
"Attr(groups) of Op(maxout) should be "
"larger than 1. But received %d.",
groups));
PADDLE_ENFORCE_EQ(
axis == 1 || axis == -1 || axis == 3, true,
platform::errors::InvalidArgument(
"axis only supported 1, -1 or 3, but recevied axis is: %d", axis));
PADDLE_ENFORCE_EQ(in_x_dims.size(), 4,
platform::errors::InvalidArgument(
"x's dims should be 4, but received x's dims is: %d",
in_x_dims.size()));
if (axis < 0) {
axis += in_x_dims.size();
}
PADDLE_ENFORCE_EQ(
in_x_dims[axis] % groups, 0,
platform::errors::InvalidArgument(
"The number of input channels for Op(maxout) "
"should be divisible by Attr(groups). But received: the "
"input's channels is [%d], the shape of input is [%s], "
"the Attr(groups) is [%d], the Attr(axis) is [%d]. The "
"error may come from wrong Attr(groups) or Attr(axis) setting.",
in_x_dims[axis], in_x_dims, groups, axis));
std::vector<int64_t> output_shape(
{in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
output_shape[axis] = in_x_dims[axis] / groups;
ctx->SetOutputDim("Out", phi::make_ddim(output_shape));
}
};
class MaxOutOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "maxout_grad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
"X@Grad", "maxout_grad");
@@ -125,8 +90,11 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(maxout, MaxOutInferShapeFunctor,
PD_INFER_META(phi::MaxOutInferMeta));
REGISTER_OPERATOR(
maxout, ops::MaxOutOp, ops::MaxOutOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
MaxOutInferShapeFunctor);
REGISTER_OPERATOR(maxout_grad, ops::MaxOutOpGrad);
@@ -16,9 +16,12 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
@@ -27,18 +30,6 @@ class PutAlongAxisOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "PutAlongAxis");
OP_INOUT_CHECK(ctx->HasInput("Index"), "Input", "Index", "PutAlongAxis");
OP_INOUT_CHECK(ctx->HasInput("Value"), "Input", "Value", "PutAlongAxis");
OP_INOUT_CHECK(ctx->HasOutput("Result"), "Output", "Result",
"PutAlongAxis");
auto index_dim = ctx->GetInputDim("Index");
ctx->SetOutputDim("Result", index_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -118,9 +109,12 @@ DECLARE_INPLACE_OP_INFERER(PutAlongAxisInplaceInferer, {"Input", "Result"});
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(put_along_axis, PutAlongAxisInferShapeFunctor,
PD_INFER_META(phi::PutAlongAxisInferMeta));
REGISTER_OPERATOR(put_along_axis, ops::PutAlongAxisOp, ops::PutAlongAxisOpMaker,
ops::PutAlongAxisGradOpMaker<paddle::framework::OpDesc>,
ops::PutAlongAxisGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::PutAlongAxisInplaceInferer,
PutAlongAxisInferShapeFunctor);
REGISTER_OPERATOR(put_along_axis_grad, ops::PutAlongAxisGradOp);
@@ -16,9 +16,12 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -27,38 +30,6 @@ class TakeAlongAxisOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Input"), true,
platform::errors::InvalidArgument(
"Input(Input) of TakeAlongAxisOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Input(Index) of TakeAlongAxisOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Result"), true,
platform::errors::InvalidArgument(
"Output(Result) of TakeAlongAxisOp should not be null."));
auto input_dim = ctx->GetInputDim("Input");
auto index_dim = ctx->GetInputDim("Index");
PADDLE_ENFORCE_GT(input_dim.size(), 0,
platform::errors::InvalidArgument(
"Dimension of the input(Input) of TakeAlongAxisOp "
"should be greater than 0.",
input_dim));
PADDLE_ENFORCE_GT(index_dim.size(), 0,
platform::errors::InvalidArgument(
"Dimension of the input(Index) of TakeAlongAxisOp "
"should be greater than 0.",
index_dim));
ctx->SetOutputDim("Result", index_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -134,9 +105,12 @@ class TakeAlongAxisGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(take_along_axis, TakeAlongAxisInferShapeFunctor,
PD_INFER_META(phi::TakeAlongAxisInferMeta));
REGISTER_OPERATOR(take_along_axis, ops::TakeAlongAxisOp,
ops::TakeAlongAxisOpMaker,
ops::TakeAlongAxisGradOpMaker<paddle::framework::OpDesc>,
ops::TakeAlongAxisGradOpMaker<paddle::imperative::OpBase>,
TakeAlongAxisInferShapeFunctor);
REGISTER_OPERATOR(take_along_axis_grad, ops::TakeAlongAxisGradOp);
@@ -1583,6 +1583,31 @@ void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
out->share_lod(x);
}
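// take_along_axis gathers one element of x per entry of index, so the
// result below takes the shape of index and the dtype of x. Illustrative
// example (values assumed, not from this commit): x.dims() = [2, 3],
// index.dims() = [2, 1], axis = 1  =>  out->dims() = [2, 1] and
// out->dtype() == x.dtype(); axis itself does not affect the shape.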
void TakeAlongAxisInferMeta(const MetaTensor& x,
const MetaTensor& index,
int axis,
MetaTensor* out) {
auto input_dim = x.dims();
auto index_dim = index.dims();
PADDLE_ENFORCE_GT(input_dim.size(),
0,
phi::errors::InvalidArgument(
"Dimension of the input(Input) of TakeAlongAxisOp "
"should be greater than 0.",
input_dim));
PADDLE_ENFORCE_GT(index_dim.size(),
0,
phi::errors::InvalidArgument(
"Dimension of the input(Index) of TakeAlongAxisOp "
"should be greater than 0.",
index_dim));
out->set_dims(index_dim);
out->set_dtype(x.dtype());
}
void TriangularSolveInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool upper,
......
@@ -221,6 +221,11 @@ void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void TakeAlongAxisInferMeta(const MetaTensor& x,
const MetaTensor& index,
int axis,
MetaTensor* out);
void TriangularSolveInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool upper,
......
@@ -335,6 +335,16 @@ void NllLossRawInferMeta(const MetaTensor& input,
total_weight->set_dtype(input.dtype());
}
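// put_along_axis scatters value into a copy of x along axis, so the
// result below keeps x's shape and dtype; the index and value shapes do
// not affect the inferred output shape here. Illustrative example (values
// assumed, not from this commit): x.dims() = [4, 5], index.dims() =
// value.dims() = [4, 1], axis = 0  =>  out->dims() = [4, 5].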
void PutAlongAxisInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& value,
int axis,
const std::string& reduce,
MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(x.dtype());
}
void RoiAlignInferMeta(const MetaTensor& x,
const MetaTensor& boxes,
paddle::optional<const MetaTensor&> boxes_num,
......
@@ -74,6 +74,13 @@ void NllLossRawInferMeta(const MetaTensor& input,
MetaTensor* total_weight,
MetaConfig config = MetaConfig());
void PutAlongAxisInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& value,
int axis,
const std::string& reduce,
MetaTensor* out);
void RoiAlignInferMeta(const MetaTensor& x,
const MetaTensor& boxes,
paddle::optional<const MetaTensor&> boxes_num,
......
@@ -467,6 +467,81 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x,
xshape->share_lod(x);
}
void FlipInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out) {
auto x_dims = x.dims();
auto flip_dims = axis;
size_t flip_dims_size = axis.size();
if (flip_dims_size > 0) {
// check that each flip axis is within range
auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end());
PADDLE_ENFORCE_LT(*min_max_d.first,
x_dims.size(),
phi::errors::InvalidArgument(
"min(axes) should be less than the input tensor X's "
"axes of FlipOp. But received min(axes) = %d, "
"X's axes = %d, X's shape = [%s]",
*min_max_d.first,
x_dims.size(),
x_dims));
PADDLE_ENFORCE_GE(*min_max_d.first,
x_dims.size() * -1,
phi::errors::InvalidArgument(
"min(axes) should be greater than or equal to the "
"input tensor X's "
"axes of FlipOp times -1. But received "
"min(axes) = %d, X's "
"axes = %d, X's shape = [%s]",
*min_max_d.first,
x_dims.size() * -1,
x_dims));
PADDLE_ENFORCE_LT(*min_max_d.second,
x_dims.size(),
phi::errors::InvalidArgument(
"max(axes) should be less than the input tensor X's "
"axes of FlipOp. But received max(axes) = %d, "
"X's axes = %d, X's shape = [%s]",
*min_max_d.second,
x_dims.size(),
x_dims));
PADDLE_ENFORCE_GE(*min_max_d.second,
x_dims.size() * -1,
phi::errors::InvalidArgument(
"max(axes) should be greater than or equal to the "
"input tensor X's "
"axes of FlipOp times -1. But received "
"max(axes) = %d, X's "
"axes = %d, X's shape = [%s]",
*min_max_d.second,
x_dims.size() * -1,
x_dims));
// check duplicates in dims
flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()),
flip_dims.end());
PADDLE_ENFORCE_EQ(flip_dims.size(),
flip_dims_size,
phi::errors::InvalidArgument(
"axes has duplicates, original flip axes size=%d, "
"but unique flip axes size=%d.)",
flip_dims_size,
flip_dims.size()));
}
VLOG(3) << "flip operator x.shape=" << x_dims;
std::vector<int64_t> output_dims(x_dims.size());
for (int i = 0; i < x_dims.size(); ++i) {
output_dims[i] = x_dims[i];
}
out->set_dims(phi::make_ddim(output_dims));
out->set_dtype(x.dtype());
out->share_lod(x);
}
void FullBatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape,
const Scalar& val,
@@ -751,6 +826,52 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
out->set_dtype(x.dtype());
}
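// MaxOutInferMeta below enforces groups > 1, axis in {1, -1, 3}, a 4-D
// input, and dims[axis] divisible by groups; only dims[axis] shrinks, by
// a factor of groups. Illustrative example (values assumed, not from this
// commit): x.dims() = [8, 6, 32, 32], groups = 2, axis = 1  =>
// out->dims() = [8, 3, 32, 32].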
void MaxOutInferMeta(const MetaTensor& x,
int groups,
int axis,
MetaTensor* out) {
auto in_x_dims = x.dims();
// check groups > 1
PADDLE_ENFORCE_GT(
groups,
1,
phi::errors::InvalidArgument("Attr(groups) of Op(maxout) should be "
"larger than 1. But received %d.",
groups));
PADDLE_ENFORCE_EQ(
axis == 1 || axis == -1 || axis == 3,
true,
phi::errors::InvalidArgument(
"axis only supported 1, -1 or 3, but recevied axis is: %d", axis));
PADDLE_ENFORCE_EQ(in_x_dims.size(),
4,
phi::errors::InvalidArgument(
"x's dims should be 4, but received x's dims is: %d",
in_x_dims.size()));
if (axis < 0) {
axis += in_x_dims.size();
}
PADDLE_ENFORCE_EQ(
in_x_dims[axis] % groups,
0,
phi::errors::InvalidArgument(
"The number of input channels for Op(maxout) "
"should be divisible by Attr(groups). But received: the "
"input's channels is [%d], the shape of input is [%s], "
"the Attr(groups) is [%d], the Attr(axis) is [%d]. The "
"error may come from wrong Attr(groups) or Attr(axis) setting.",
in_x_dims[axis],
in_x_dims,
groups,
axis));
std::vector<int64_t> output_shape(
{in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
output_shape[axis] = in_x_dims[axis] / groups;
out->set_dims(phi::make_ddim(output_shape));
out->set_dtype(x.dtype());
}
void MaxPoolWithIndexInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_size,
const std::vector<int>& strides,
......
@@ -98,6 +98,10 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* xshape);
void FlipInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);
void FullBatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape,
const Scalar& val,
@@ -134,6 +138,11 @@ void KthvalueInferMeta(const MetaTensor& x,
void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
void MaxOutInferMeta(const MetaTensor& x,
int groups,
int axis,
MetaTensor* out);
void MaxPoolWithIndexInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_size,
const std::vector<int>& strides,
......