diff --git a/paddle/fluid/operators/allclose_op.cc b/paddle/fluid/operators/allclose_op.cc
index 706a132878de441a4fa2d54d9517012c0df55fcd..88d7cb7c1f5f4bf47dc82f8632116424253d6d19 100644
--- a/paddle/fluid/operators/allclose_op.cc
+++ b/paddle/fluid/operators/allclose_op.cc
@@ -15,10 +15,13 @@
 #include <cmath>
 #include <string>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/binary.h"
 
 namespace paddle {
 namespace operators {
@@ -61,40 +64,6 @@ class AllcloseOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Allclose");
-    OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Allclose");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Allclose");
-
-    auto input_dim = ctx->GetInputDim("Input");
-    auto other_dim = ctx->GetInputDim("Other");
-    PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(),
-                      platform::errors::PreconditionNotMet(
-                          "Input(Input) and Input(Other) must have the same "
-                          "dimension size."));
-    int n = input_dim.size();
-    bool is_runtime = ctx->IsRuntime();
-    for (int i = 0; i < n; i++) {
-      if (is_runtime) {
-        PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
-                          platform::errors::PreconditionNotMet(
-                              "The value at dim %d of Input(Input) is not "
-                              "equal to the Input(Other): %ld != %ld.",
-                              i, input_dim[i], other_dim[i]));
-      } else {
-        if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
-          PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
-                            platform::errors::PreconditionNotMet(
-                                "The value at dim %d of Input(Input) is not "
-                                "equal to the Input(Other): %ld != %ld.",
-                                i, input_dim[i], other_dim[i]));
-        }
-      }
-    }
-
-    ctx->SetOutputDim("Out", phi::make_ddim({1}));
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -117,11 +86,13 @@ class AllcloseOpVarTypeInference : public framework::VarTypeInference {
 
 namespace ops = paddle::operators;
 using CPU = paddle::platform::CPUDeviceContext;
 
+DECLARE_INFER_SHAPE_FUNCTOR(allclose, AllcloseInferShapeFunctor,
+                            PD_INFER_META(phi::AllValueCompareInferMeta));
 REGISTER_OPERATOR(
     allclose, ops::AllcloseOp, ops::AllcloseOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    ops::AllcloseOpVarTypeInference);
+    ops::AllcloseOpVarTypeInference, AllcloseInferShapeFunctor);
 
 /* ========================== register checkpoint ===========================*/
 REGISTER_OP_VERSION(allclose)
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index b9d432244568365ff29dea43c465952bbfab1174..2947661517e78f328100d9b208cfa45c5f98dae1 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -21,6 +21,56 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/common_shape.h"
 
 namespace phi {
+namespace detail {
+
+static void BinarySameInputDimsCheck(const MetaTensor& x,
+                                     const MetaTensor& y,
+                                     MetaConfig config) {
+  auto input_dim = x.dims();
+  auto other_dim = y.dims();
+  PADDLE_ENFORCE_EQ(input_dim.size(),
+                    other_dim.size(),
+                    phi::errors::PreconditionNotMet(
+                        "Input(Input) and Input(Other) must have the same "
+                        "dimension size."));
+  int n = input_dim.size();
+  bool is_runtime = config.is_runtime;
+  for (int i = 0; i < n; i++) {
+    if (is_runtime) {
+      PADDLE_ENFORCE_EQ(input_dim[i],
+                        other_dim[i],
+                        phi::errors::PreconditionNotMet(
+                            "The value at dim %d of Input(Input) is not "
+                            "equal to the Input(Other): %ld != %ld.",
+                            i,
+                            input_dim[i],
+                            other_dim[i]));
+    } else {
+      if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
+        PADDLE_ENFORCE_EQ(input_dim[i],
+                          other_dim[i],
+                          phi::errors::PreconditionNotMet(
+                              "The value at dim %d of Input(Input) is not "
+                              "equal to the Input(Other): %ld != %ld.",
+                              i,
+                              input_dim[i],
+                              other_dim[i]));
+      }
+    }
+  }
+}
+
+}  // namespace detail
+
+void AllValueCompareInferMeta(const MetaTensor& x,
+                              const MetaTensor& y,
+                              MetaTensor* out,
+                              MetaConfig config) {
+  detail::BinarySameInputDimsCheck(x, y, config);
+
+  out->set_dims(phi::make_ddim({1}));
+  out->set_dtype(DataType::BOOL);
+}
 
 void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
   out->share_meta(x);
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 307ecc29cac7548a811b9bc0ae9a693c7062a354..cfae45cf04b87c287a174d172700a794c8c2a2a3 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -29,6 +29,11 @@ namespace phi {
 // Because functions in this file not only can infer shape, but also need
 // infer lod or other useful data.
 
+void AllValueCompareInferMeta(const MetaTensor& x,
+                              const MetaTensor& y,
+                              MetaTensor* out,
+                              MetaConfig config = MetaConfig());
+
 void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
 
 void BCELossInferMeta(const MetaTensor& input,
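
Note: the is_runtime branch in BinarySameInputDimsCheck above reduces to a simple rule: ranks must match, every extent must match at runtime, and at compile time a negative ("not yet inferred") extent on either side is accepted. The standalone sketch below illustrates only that rule; CheckSameDims and the whole program are hypothetical and are not part of the Paddle API touched by this patch.

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Standalone illustration of the shape-compatibility rule: same rank, and
// each extent equal, except that unknown (-1) extents are tolerated before
// runtime shapes are available.
static void CheckSameDims(const std::vector<int64_t>& x,
                          const std::vector<int64_t>& y,
                          bool is_runtime) {
  if (x.size() != y.size()) {
    throw std::invalid_argument("rank mismatch");
  }
  for (std::size_t i = 0; i < x.size(); ++i) {
    // At compile time (is_runtime == false) a negative extent means
    // "not yet inferred", so the comparison is skipped for that axis.
    if (!is_runtime && (x[i] < 0 || y[i] < 0)) {
      continue;
    }
    if (x[i] != y[i]) {
      throw std::invalid_argument("extent mismatch at axis " +
                                  std::to_string(i));
    }
  }
}

int main() {
  CheckSameDims({2, 3}, {2, 3}, /*is_runtime=*/true);    // passes
  CheckSameDims({-1, 3}, {2, 3}, /*is_runtime=*/false);  // passes: -1 unknown
  try {
    CheckSameDims({2, 3}, {2, 4}, /*is_runtime=*/true);  // throws
  } catch (const std::invalid_argument&) {
    // Expected: mismatched extents are rejected once shapes are concrete.
  }
  return 0;
}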