From 5d08a4471973e1c2b2a595781d0a0840875a0c77 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 15 Mar 2022 00:00:04 +0800 Subject: [PATCH] move allclose infershape (#40508) --- paddle/fluid/operators/allclose_op.cc | 41 ++++------------------ paddle/phi/infermeta/binary.cc | 50 +++++++++++++++++++++++++++ paddle/phi/infermeta/binary.h | 5 +++ 3 files changed, 61 insertions(+), 35 deletions(-) diff --git a/paddle/fluid/operators/allclose_op.cc b/paddle/fluid/operators/allclose_op.cc index 706a132878d..88d7cb7c1f5 100644 --- a/paddle/fluid/operators/allclose_op.cc +++ b/paddle/fluid/operators/allclose_op.cc @@ -15,10 +15,13 @@ #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -61,40 +64,6 @@ class AllcloseOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Allclose"); - OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Allclose"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Allclose"); - - auto input_dim = ctx->GetInputDim("Input"); - auto other_dim = ctx->GetInputDim("Other"); - PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(), - platform::errors::PreconditionNotMet( - "Input(Input) and Input(Other) must have the same " - "dimension size.")); - int n = input_dim.size(); - bool is_runtime = ctx->IsRuntime(); - for (int i = 0; i < n; i++) { - if (is_runtime) { - PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i], - platform::errors::PreconditionNotMet( - "The value at dim %d of Input(Input) is not " - "equal to the Input(Other): %ld != %ld.", - i, input_dim[i], other_dim[i])); - } else { - if (!(input_dim[i] < 0 || other_dim[i] < 0)) { - PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i], - platform::errors::PreconditionNotMet( - "The value at dim %d of Input(Input) is not " - "equal to the Input(Other): %ld != %ld.", - i, input_dim[i], other_dim[i])); - } - } - } - - ctx->SetOutputDim("Out", phi::make_ddim({1})); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -117,11 +86,13 @@ class AllcloseOpVarTypeInference : public framework::VarTypeInference { namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; +DECLARE_INFER_SHAPE_FUNCTOR(allclose, AllcloseInferShapeFunctor, + PD_INFER_META(phi::AllValueCompareInferMeta)); REGISTER_OPERATOR( allclose, ops::AllcloseOp, ops::AllcloseOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, - ops::AllcloseOpVarTypeInference); + ops::AllcloseOpVarTypeInference, AllcloseInferShapeFunctor); /* ========================== register checkpoint ===========================*/ REGISTER_OP_VERSION(allclose) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index b9d43224456..2947661517e 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -21,6 +21,56 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/common_shape.h" namespace phi { +namespace detail { + +static void BinarySameInputDimsCheck(const MetaTensor& x, + const MetaTensor& y, + MetaConfig config) { + auto input_dim = x.dims(); + auto other_dim = y.dims(); + PADDLE_ENFORCE_EQ(input_dim.size(), + other_dim.size(), + phi::errors::PreconditionNotMet( + "Input(Input) and Input(Other) must have the same " + "dimension size.")); + int n = input_dim.size(); + bool is_runtime = config.is_runtime; + for (int i = 0; i < n; i++) { + if (is_runtime) { + PADDLE_ENFORCE_EQ(input_dim[i], + other_dim[i], + phi::errors::PreconditionNotMet( + "The value at dim %d of Input(Input) is not " + "equal to the Input(Other): %ld != %ld.", + i, + input_dim[i], + other_dim[i])); + } else { + if (!(input_dim[i] < 0 || other_dim[i] < 0)) { + PADDLE_ENFORCE_EQ(input_dim[i], + other_dim[i], + phi::errors::PreconditionNotMet( + "The value at dim %d of Input(Input) is not " + "equal to the Input(Other): %ld != %ld.", + i, + input_dim[i], + other_dim[i])); + } + } + } +} + +} // namespace detail + +void AllValueCompareInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out, + MetaConfig config) { + detail::BinarySameInputDimsCheck(x, y, config); + + out->set_dims(phi::make_ddim({1})); + out->set_dtype(DataType::BOOL); +} void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { out->share_meta(x); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 307ecc29cac..cfae45cf04b 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -29,6 +29,11 @@ namespace phi { // Because functions in this file not only can infer shape, but also need // infer lod or other useful data. +void AllValueCompareInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); void BCELossInferMeta(const MetaTensor& input, -- GitLab