From a09a93a177b29bee75c1eaa99f96500d3d2087f2 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 16 Mar 2022 21:51:56 +0800 Subject: [PATCH] move determinant op infershape (#40624) --- paddle/fluid/operators/determinant_op.cc | 32 +++++++++--------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/determinant_op.cc b/paddle/fluid/operators/determinant_op.cc index 68083c7598..6959b5cf81 100644 --- a/paddle/fluid/operators/determinant_op.cc +++ b/paddle/fluid/operators/determinant_op.cc @@ -13,6 +13,10 @@ // limitations under the License. #include "paddle/fluid/operators/determinant_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -20,11 +24,6 @@ namespace operators { class DeterminantOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "determinant"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "determinant"); - } }; class DeterminantOpMaker : public framework::OpProtoAndCheckerMaker { @@ -44,19 +43,6 @@ class DeterminantGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", - "DeterminantGradOp"); - OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - framework::GradVarName("Out"), "DeterminantGradOp"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output", - framework::GradVarName("Input"), "DeterminantGradOp"); - - ctx->SetOutputDim(framework::GradVarName("Input"), - ctx->GetInputDim("Input")); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -162,11 +148,17 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer, namespace ops = paddle::operators; namespace plat = paddle::platform; +DECLARE_INFER_SHAPE_FUNCTOR(determinant, DeterminantInferShapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker, ops::DeterminantGradOpMaker, - ops::DeterminantGradOpMaker); + ops::DeterminantGradOpMaker, + DeterminantInferShapeFunctor); -REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp) +DECLARE_INFER_SHAPE_FUNCTOR(determinant_grad, DeterminantGradInferShapeFunctor, + PD_INFER_META(phi::GeneralUnaryGradInferMeta)); +REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp, + DeterminantGradInferShapeFunctor); REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp, ops::SlogDeterminantOpMaker, -- GitLab