From c76377005be429ad12e42e26cce39fbc16229521 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 16 Mar 2022 11:39:57 +0800 Subject: [PATCH] move isclose infershape (#40595) --- paddle/fluid/operators/isclose_op.cc | 41 ++++------------------------ paddle/phi/infermeta/binary.cc | 10 +++++++ paddle/phi/infermeta/binary.h | 5 ++++ 3 files changed, 21 insertions(+), 35 deletions(-) diff --git a/paddle/fluid/operators/isclose_op.cc b/paddle/fluid/operators/isclose_op.cc index 8668de4d3a6..1c79213757f 100644 --- a/paddle/fluid/operators/isclose_op.cc +++ b/paddle/fluid/operators/isclose_op.cc @@ -14,10 +14,13 @@ #include #include + +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -60,40 +63,6 @@ class IscloseOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Isclose"); - OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Isclose"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Isclose"); - - auto input_dim = ctx->GetInputDim("Input"); - auto other_dim = ctx->GetInputDim("Other"); - PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(), - platform::errors::PreconditionNotMet( - "Input(Input) and Input(Other) must have the same " - "dimension size.")); - int n = input_dim.size(); - bool is_runtime = ctx->IsRuntime(); - for (int i = 0; i < n; i++) { - if (is_runtime) { - PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i], - platform::errors::PreconditionNotMet( - "The value at dim %d of Input(Input) is not " - "equal to the Input(Other): %ld != %ld.", - i, input_dim[i], other_dim[i])); - } else { - if (!(input_dim[i] < 0 || other_dim[i] < 0)) { - PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i], - platform::errors::PreconditionNotMet( - "The value at dim %d of Input(Input) is not " - "equal to the Input(Other): %ld != %ld.", - i, input_dim[i], other_dim[i])); - } - } - } - - ctx->SetOutputDim("Out", input_dim); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -115,8 +84,10 @@ class IscloseOpVarTypeInference : public framework::VarTypeInference { namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(isclose, IscloseInferShapeFunctor, + PD_INFER_META(phi::ValueCompareInferMeta)); REGISTER_OPERATOR( isclose, ops::IscloseOp, ops::IscloseOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, - ops::IscloseOpVarTypeInference); + ops::IscloseOpVarTypeInference, IscloseInferShapeFunctor); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index ffb1ed54502..d103bef2d9e 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -862,6 +862,16 @@ void TriangularSolveInferMeta(const MetaTensor& x, out->share_lod(y); } +void ValueCompareInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out, + MetaConfig config) { + detail::BinarySameInputDimsCheck(x, y, config); + + out->set_dims(x.dims()); + out->set_dtype(DataType::BOOL); +} + } // namespace phi PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index d852db7a846..5d93bae3162 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -142,4 +142,9 @@ void TriangularSolveInferMeta(const MetaTensor& x, bool unitriangular, MetaTensor* out); +void ValueCompareInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out, + MetaConfig config = MetaConfig()); + } // namespace phi -- GitLab