diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index 080e0bafc966a7f4d157661d5083cdffd1d51bed..25f23b31e2854195d3d8d351257a2304c35e3b07 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -324,4 +324,9 @@ std::ostream& operator<<(std::ostream& os, Attribute attr) { return os; } +std::ostream& operator<<(std::ostream& os, const Program& prog) { + prog.Print(os); + return os; +} + } // namespace ir diff --git a/paddle/ir/core/program.h b/paddle/ir/core/program.h index 0e2ecb58d918168b75e7d9f09961119eab3a90df..6f44a3fe4699cea22566182c43475eb817633696 100644 --- a/paddle/ir/core/program.h +++ b/paddle/ir/core/program.h @@ -71,4 +71,6 @@ class IR_API Program { ParameterMap parameters_; }; +std::ostream& operator<<(std::ostream& os, const Program& prog); + } // namespace ir diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index eca5b93e24f883164cb1e8524c5afe549132cc15..fbc058ff64e78f5e0309e19deeab262a7dbd82a5 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1515,7 +1515,7 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : ElementwiseInferMeta + func : LogicalBinaryInferMeta kernel : func : logical_and data_type : x @@ -1526,7 +1526,7 @@ args : (Tensor x) output : Tensor(out) infer_meta : - func : UnchangedInferMeta + func : LogicalNotInfermeta kernel : func : logical_not data_type : x @@ -1537,7 +1537,7 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : ElementwiseInferMeta + func : LogicalBinaryInferMeta kernel : func : logical_or data_type : x @@ -1548,7 +1548,7 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : ElementwiseInferMeta + func : LogicalBinaryInferMeta kernel : func : logical_xor data_type : x diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 0118db6041203d87938d9b8c1af0e397ee3f05b2..a9b14d2df3d17283cc3040c37cbeb01d08d04454 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1942,6 +1942,15 @@ void LogLossInferMeta(const MetaTensor& input, out->share_lod(input); } +void LogicalBinaryInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out) { + ElementwiseInferMeta(x, y, out); + if (!(out->is_same_tensor(x))) { + out->set_dtype(DataType::BOOL); + } +} + void LUUnpackInferMeta(const MetaTensor& x, const MetaTensor& pivots, bool unpack_ludata, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 8aa4114e74046024e00015cdc0d4acfec2f6c4bb..9060d2abc6564d40484cf6db7b44c5e735b9cb68 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -300,6 +300,10 @@ void IndexAddInferMeta(const MetaTensor& x, void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); +void LogicalBinaryInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out); + void LogLossInferMeta(const MetaTensor& input, const MetaTensor& label, float epsilon, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 4c952bb3cd2bc1ec5bc811dcf9ac8e87bcf0c725..28b80b58155fa4e72026ce8b59775e24a71ef1f0 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2083,6 +2083,13 @@ void KthvalueInferMeta(const MetaTensor& x, indices->set_dtype(x.dtype()); } +void LogicalNotInfermeta(const MetaTensor& x, MetaTensor* out) { + UnchangedInferMeta(x, out); + if (!(out->is_same_tensor(x))) { + out->set_dtype(DataType::BOOL); + } +} + void LogsumexpInferMeta(const MetaTensor& input, const std::vector& axis, bool keepdim, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 136b8c240e5f33df336abd03c2b33fed1b173092..2bf90048d30d36f9dd432d18433ebe3990f3e355 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -298,6 +298,8 @@ void KthvalueInferMeta(const MetaTensor& x, MetaTensor* indices, MetaConfig = MetaConfig()); +void LogicalNotInfermeta(const MetaTensor& x, MetaTensor* out); + void LogsumexpInferMeta(const MetaTensor& input, const std::vector& axis, bool keepdim, diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index dcc81e9c517827b3a31cf8c449f495d3377c4a2c..c7729ae89fde8a75adf029f508639facadcc6395 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -14,6 +14,8 @@ #include +#include + #include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" @@ -193,7 +195,13 @@ TEST(program_test, program) { EXPECT_EQ(program.block()->size() == 5, true); EXPECT_EQ(program.parameters_num() == 3, true); - program.Print(std::cout); + std::stringstream ss; + program.Print(ss); + + std::stringstream ss_ostram; + ss_ostram << program; + + EXPECT_EQ(ss.str(), ss_ostram.str()); } TEST(program_test, slice_combine_test) {