diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 44fe4f5193420670c21b62f71d82e7e8f5b868a6..6c026489733fd3740c3463bb0f3549d276bf04eb 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -16,7 +16,10 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -24,12 +27,6 @@ namespace operators { class MeanOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mean"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mean"); - ctx->SetOutputDim("Out", {1}); - } }; class MeanOpMaker : public framework::OpProtoAndCheckerMaker { @@ -90,8 +87,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInferer, "X"); } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(mean, MeanInferShapeFunctor, + PD_INFER_META(phi::MeanAllInferMeta)); REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType, ops::MeanGradMaker, - ops::MeanGradMaker); + ops::MeanGradMaker, + MeanInferShapeFunctor); + REGISTER_OPERATOR(mean_grad, ops::MeanGradOp, ops::MeanGradNoNeedBufferVarsInferer); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 09fdc321f7081f25020e2d3d65ed60b15bbc490e..c02adddac54a72915d5092f0416b4548eaef17c7 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -823,6 +823,12 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, mask->set_dtype(paddle::experimental::CppTypeToDataType::Type()); } +void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) { + out->set_dims(phi::make_ddim({1})); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); +} + void ModeInferMeta(const MetaTensor& x, int axis, bool keepdim, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index ac8d62db363a2a4e928b017369f1b0ae1fca14b5..5094d0376771c477434984f12456a73a20be4d26 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -144,6 +144,8 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, MetaTensor* mask, MetaConfig config = MetaConfig()); +void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out); + void ModeInferMeta(const MetaTensor& x, int axis, bool keepdim,