diff --git a/paddle/fluid/operators/isclose_op.cc b/paddle/fluid/operators/isclose_op.cc index 8668de4d3a6288841ad191f3e47b87a76eeb1d63..1c79213757fdfa8d9ef0d7c7ab315d03f94b0c57 100644 --- a/paddle/fluid/operators/isclose_op.cc +++ b/paddle/fluid/operators/isclose_op.cc @@ -14,10 +14,13 @@ #include #include + +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -60,40 +63,6 @@ class IscloseOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Isclose"); - OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Isclose"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Isclose"); - - auto input_dim = ctx->GetInputDim("Input"); - auto other_dim = ctx->GetInputDim("Other"); - PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(), - platform::errors::PreconditionNotMet( - "Input(Input) and Input(Other) must have the same " - "dimension size.")); - int n = input_dim.size(); - bool is_runtime = ctx->IsRuntime(); - for (int i = 0; i < n; i++) { - if (is_runtime) { - PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i], - platform::errors::PreconditionNotMet( - "The value at dim %d of Input(Input) is not " - "equal to the Input(Other): %ld != %ld.", - i, input_dim[i], other_dim[i])); - } else { - if (!(input_dim[i] < 0 || other_dim[i] < 0)) { - PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i], - platform::errors::PreconditionNotMet( - "The value at dim %d of Input(Input) is not " - "equal to the Input(Other): %ld != %ld.", - i, input_dim[i], other_dim[i])); - } - } - } - - ctx->SetOutputDim("Out", input_dim); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -115,8 +84,10 @@ class IscloseOpVarTypeInference : public framework::VarTypeInference { namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(isclose, IscloseInferShapeFunctor, + PD_INFER_META(phi::ValueCompareInferMeta)); REGISTER_OPERATOR( isclose, ops::IscloseOp, ops::IscloseOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, - ops::IscloseOpVarTypeInference); + ops::IscloseOpVarTypeInference, IscloseInferShapeFunctor); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index ffb1ed5450232b9e5972dfac47750364bfae5ed1..d103bef2d9ed331b7b9e4d3489e11fbc2c720072 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -862,6 +862,16 @@ void TriangularSolveInferMeta(const MetaTensor& x, out->share_lod(y); } +void ValueCompareInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out, + MetaConfig config) { + detail::BinarySameInputDimsCheck(x, y, config); + + out->set_dims(x.dims()); + out->set_dtype(DataType::BOOL); +} + } // namespace phi PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index d852db7a8462db684114f0b826c4200afb069eb3..5d93bae316238ce02c33351ea4a320b3fc79a877 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -142,4 +142,9 @@ void TriangularSolveInferMeta(const MetaTensor& x, bool unitriangular, MetaTensor* out); +void ValueCompareInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out, + MetaConfig config = MetaConfig()); + } // namespace phi