From b94cf8427a1d414dfea20ddb545362894732e81d Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 26 Mar 2022 10:22:55 +0800 Subject: [PATCH] [Phi] Move mean infershape into phi (#40922) * move mean infershape into phi * try to run ci * share layout for mkldnn * revert grad infershape * revert grad infershape --- paddle/fluid/operators/mean_op.cc | 15 ++++++++------- paddle/phi/infermeta/unary.cc | 6 ++++++ paddle/phi/infermeta/unary.h | 2 ++ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 44fe4f5193..6c02648973 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -16,7 +16,10 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -24,12 +27,6 @@ namespace operators { class MeanOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mean"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mean"); - ctx->SetOutputDim("Out", {1}); - } }; class MeanOpMaker : public framework::OpProtoAndCheckerMaker { @@ -90,8 +87,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInferer, "X"); } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(mean, MeanInferShapeFunctor, + PD_INFER_META(phi::MeanAllInferMeta)); REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType, ops::MeanGradMaker, - ops::MeanGradMaker); + ops::MeanGradMaker, + MeanInferShapeFunctor); + REGISTER_OPERATOR(mean_grad, ops::MeanGradOp, ops::MeanGradNoNeedBufferVarsInferer); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 09fdc321f7..c02adddac5 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -823,6 +823,12 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, mask->set_dtype(paddle::experimental::CppTypeToDataType::Type()); } +void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) { + out->set_dims(phi::make_ddim({1})); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); +} + void ModeInferMeta(const MetaTensor& x, int axis, bool keepdim, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index ac8d62db36..5094d03767 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -144,6 +144,8 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, MetaTensor* mask, MetaConfig config = MetaConfig()); +void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out); + void ModeInferMeta(const MetaTensor& x, int axis, bool keepdim, -- GitLab