diff --git a/paddle/fluid/operators/abs_op.cc b/paddle/fluid/operators/abs_op.cc index 149a87fe32da16e850d5d64fb519c9bde7afef62..c28026a4bd43aac5b0c447e24a164e27233076e8 100644 --- a/paddle/fluid/operators/abs_op.cc +++ b/paddle/fluid/operators/abs_op.cc @@ -16,7 +16,10 @@ #include #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -27,16 +30,6 @@ namespace operators { class AbsOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "abs"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "abs"); - - auto in_dims = ctx->GetInputDim("X"); - - ctx->SetOutputDim("Out", in_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } }; class AbsOpMaker : public framework::OpProtoAndCheckerMaker { @@ -148,11 +141,15 @@ class AbsDoubleGradOp : public framework::OperatorWithKernel { } // namespace operators } // namespace paddle +DELCARE_INFER_SHAPE_FUNCTOR(abs, AbsInferShapeFunctor, + PT_INFER_META(phi::UnchangedInferMeta)); + namespace ops = paddle::operators; REGISTER_OPERATOR(abs, ops::AbsOp, ops::AbsOpMaker, ops::AbsGradMaker, - ops::AbsGradMaker); + ops::AbsGradMaker, + AbsInferShapeFunctor); REGISTER_OPERATOR(abs_grad, ops::AbsGradOp, ops::AbsDoubleGradMaker,