From 987cb97e36e8f72c48c13d1913506313b8bc338b Mon Sep 17 00:00:00 2001 From: kangguangli Date: Wed, 30 Aug 2023 17:38:54 +0800 Subject: [PATCH] [NewIR] fix logical op infermeta (#56711) * fix logical op infermeta * add test * adpat inplace api --- paddle/ir/core/ir_printer.cc | 5 +++++ paddle/ir/core/program.h | 2 ++ paddle/phi/api/yaml/ops.yaml | 8 ++++---- paddle/phi/infermeta/binary.cc | 9 +++++++++ paddle/phi/infermeta/binary.h | 4 ++++ paddle/phi/infermeta/unary.cc | 7 +++++++ paddle/phi/infermeta/unary.h | 2 ++ test/cpp/ir/core/ir_program_test.cc | 10 +++++++++- 8 files changed, 42 insertions(+), 5 deletions(-) diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index 080e0bafc96..25f23b31e28 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -324,4 +324,9 @@ std::ostream& operator<<(std::ostream& os, Attribute attr) { return os; } +std::ostream& operator<<(std::ostream& os, const Program& prog) { + prog.Print(os); + return os; +} + } // namespace ir diff --git a/paddle/ir/core/program.h b/paddle/ir/core/program.h index 0e2ecb58d91..6f44a3fe469 100644 --- a/paddle/ir/core/program.h +++ b/paddle/ir/core/program.h @@ -71,4 +71,6 @@ class IR_API Program { ParameterMap parameters_; }; +std::ostream& operator<<(std::ostream& os, const Program& prog); + } // namespace ir diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index eca5b93e24f..fbc058ff64e 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1515,7 +1515,7 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : ElementwiseInferMeta + func : LogicalBinaryInferMeta kernel : func : logical_and data_type : x @@ -1526,7 +1526,7 @@ args : (Tensor x) output : Tensor(out) infer_meta : - func : UnchangedInferMeta + func : LogicalNotInfermeta kernel : func : logical_not data_type : x @@ -1537,7 +1537,7 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : ElementwiseInferMeta + func : LogicalBinaryInferMeta kernel : func : logical_or data_type : x @@ -1548,7 +1548,7 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : ElementwiseInferMeta + func : LogicalBinaryInferMeta kernel : func : logical_xor data_type : x diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 0118db60412..a9b14d2df3d 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1942,6 +1942,15 @@ void LogLossInferMeta(const MetaTensor& input, out->share_lod(input); } +void LogicalBinaryInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out) { + ElementwiseInferMeta(x, y, out); + if (!(out->is_same_tensor(x))) { + out->set_dtype(DataType::BOOL); + } +} + void LUUnpackInferMeta(const MetaTensor& x, const MetaTensor& pivots, bool unpack_ludata, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 8aa4114e740..9060d2abc65 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -300,6 +300,10 @@ void IndexAddInferMeta(const MetaTensor& x, void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); +void LogicalBinaryInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out); + void LogLossInferMeta(const MetaTensor& input, const MetaTensor& label, float epsilon, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 4c952bb3cd2..28b80b58155 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2083,6 +2083,13 @@ void KthvalueInferMeta(const MetaTensor& x, indices->set_dtype(x.dtype()); } +void LogicalNotInfermeta(const MetaTensor& x, MetaTensor* out) { + UnchangedInferMeta(x, out); + if (!(out->is_same_tensor(x))) { + out->set_dtype(DataType::BOOL); + } +} + void LogsumexpInferMeta(const MetaTensor& input, const std::vector& axis, bool keepdim, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 136b8c240e5..2bf90048d30 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -298,6 +298,8 @@ void KthvalueInferMeta(const MetaTensor& x, MetaTensor* indices, MetaConfig = MetaConfig()); +void LogicalNotInfermeta(const MetaTensor& x, MetaTensor* out); + void LogsumexpInferMeta(const MetaTensor& input, const std::vector& axis, bool keepdim, diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index dcc81e9c517..c7729ae89fd 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -14,6 +14,8 @@ #include +#include + #include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" @@ -193,7 +195,13 @@ TEST(program_test, program) { EXPECT_EQ(program.block()->size() == 5, true); EXPECT_EQ(program.parameters_num() == 3, true); - program.Print(std::cout); + std::stringstream ss; + program.Print(ss); + + std::stringstream ss_ostram; + ss_ostram << program; + + EXPECT_EQ(ss.str(), ss_ostram.str()); } TEST(program_test, slice_combine_test) { -- GitLab