From 75280d36afe1e5e4aab0df51a9d7ee0828ee12fa Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Tue, 1 Mar 2022 10:24:17 +0800 Subject: [PATCH] remove dot infershape (#39945) --- paddle/fluid/operators/dot_op.cc | 55 ++++++-------------------------- 1 file changed, 9 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc index ed2b09796ee..a86a3bb3592 100644 --- a/paddle/fluid/operators/dot_op.cc +++ b/paddle/fluid/operators/dot_op.cc @@ -14,6 +14,10 @@ #include "paddle/fluid/operators/dot_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" + namespace paddle { namespace operators { @@ -21,51 +25,6 @@ class DotOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(true, ctx->HasInput("X"), - platform::errors::PreconditionNotMet( - "Input(X) of DotOp should not be null.")); - PADDLE_ENFORCE_EQ(true, ctx->HasInput("Y"), - platform::errors::PreconditionNotMet( - "Input(Y) of DotOp should not be null.")); - PADDLE_ENFORCE_EQ(true, ctx->HasOutput("Out"), - platform::errors::PreconditionNotMet( - "Output(Out) of DotOp should not be null.")); - - auto x_dims = ctx->GetInputDim("X"); - auto x_rank = static_cast(x_dims.size()); - PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank, - platform::errors::PreconditionNotMet( - "ShapeError: The dimensions of input tensor X (%s) " - "should be 1 or 2", - x_dims.to_str())); - - auto y_dims = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_EQ( - true, x_rank == (size_t)y_dims.size(), - platform::errors::PreconditionNotMet( - "ShapeError: The shape of input tensor Y: %s should match with " - "input tenosr X: %s", - y_dims.to_str(), x_dims.to_str())); - bool shape_match = true; - for (size_t i = 0; i < x_rank; ++i) { - if (x_dims[i] != y_dims[i]) { - shape_match = false; - break; - } - } - - PADDLE_ENFORCE_EQ(true, shape_match, - platform::errors::PreconditionNotMet( - "ShapeError: The shape of input tensor X: %s should " - "be exactly the same " - "with input tensor Y: %s", - x_dims.to_str(), y_dims.to_str())); - auto dims = vectorize(x_dims); - dims[dims.size() - 1] = 1; - ctx->SetOutputDim("Out", phi::make_ddim(dims)); - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( @@ -142,9 +101,13 @@ class DotOpGradMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(dot, DotInferShapeFunctor, + PT_INFER_META(phi::DotInferMeta)); + REGISTER_OPERATOR(dot, ops::DotOp, ops::DotOpMaker, ops::DotOpGradMaker, - ops::DotOpGradMaker); + ops::DotOpGradMaker, + DotInferShapeFunctor); REGISTER_OPERATOR(dot_grad, ops::DotGradOp); -- GitLab