未验证 提交 987cb97e 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] fix logical op infermeta (#56711)

* fix logical op infermeta

* add test

* adpat inplace api
上级 2ef4ec71
...@@ -324,4 +324,9 @@ std::ostream& operator<<(std::ostream& os, Attribute attr) { ...@@ -324,4 +324,9 @@ std::ostream& operator<<(std::ostream& os, Attribute attr) {
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const Program& prog) {
prog.Print(os);
return os;
}
} // namespace ir } // namespace ir
...@@ -71,4 +71,6 @@ class IR_API Program { ...@@ -71,4 +71,6 @@ class IR_API Program {
ParameterMap parameters_; ParameterMap parameters_;
}; };
std::ostream& operator<<(std::ostream& os, const Program& prog);
} // namespace ir } // namespace ir
...@@ -1515,7 +1515,7 @@ ...@@ -1515,7 +1515,7 @@
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : ElementwiseInferMeta func : LogicalBinaryInferMeta
kernel : kernel :
func : logical_and func : logical_and
data_type : x data_type : x
...@@ -1526,7 +1526,7 @@ ...@@ -1526,7 +1526,7 @@
args : (Tensor x) args : (Tensor x)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : UnchangedInferMeta func : LogicalNotInfermeta
kernel : kernel :
func : logical_not func : logical_not
data_type : x data_type : x
...@@ -1537,7 +1537,7 @@ ...@@ -1537,7 +1537,7 @@
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : ElementwiseInferMeta func : LogicalBinaryInferMeta
kernel : kernel :
func : logical_or func : logical_or
data_type : x data_type : x
...@@ -1548,7 +1548,7 @@ ...@@ -1548,7 +1548,7 @@
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : ElementwiseInferMeta func : LogicalBinaryInferMeta
kernel : kernel :
func : logical_xor func : logical_xor
data_type : x data_type : x
......
...@@ -1942,6 +1942,15 @@ void LogLossInferMeta(const MetaTensor& input, ...@@ -1942,6 +1942,15 @@ void LogLossInferMeta(const MetaTensor& input,
out->share_lod(input); out->share_lod(input);
} }
void LogicalBinaryInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out) {
ElementwiseInferMeta(x, y, out);
if (!(out->is_same_tensor(x))) {
out->set_dtype(DataType::BOOL);
}
}
void LUUnpackInferMeta(const MetaTensor& x, void LUUnpackInferMeta(const MetaTensor& x,
const MetaTensor& pivots, const MetaTensor& pivots,
bool unpack_ludata, bool unpack_ludata,
......
...@@ -300,6 +300,10 @@ void IndexAddInferMeta(const MetaTensor& x, ...@@ -300,6 +300,10 @@ void IndexAddInferMeta(const MetaTensor& x,
void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
void LogicalBinaryInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
void LogLossInferMeta(const MetaTensor& input, void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label, const MetaTensor& label,
float epsilon, float epsilon,
......
...@@ -2083,6 +2083,13 @@ void KthvalueInferMeta(const MetaTensor& x, ...@@ -2083,6 +2083,13 @@ void KthvalueInferMeta(const MetaTensor& x,
indices->set_dtype(x.dtype()); indices->set_dtype(x.dtype());
} }
void LogicalNotInfermeta(const MetaTensor& x, MetaTensor* out) {
UnchangedInferMeta(x, out);
if (!(out->is_same_tensor(x))) {
out->set_dtype(DataType::BOOL);
}
}
void LogsumexpInferMeta(const MetaTensor& input, void LogsumexpInferMeta(const MetaTensor& input,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keepdim, bool keepdim,
......
...@@ -298,6 +298,8 @@ void KthvalueInferMeta(const MetaTensor& x, ...@@ -298,6 +298,8 @@ void KthvalueInferMeta(const MetaTensor& x,
MetaTensor* indices, MetaTensor* indices,
MetaConfig = MetaConfig()); MetaConfig = MetaConfig());
void LogicalNotInfermeta(const MetaTensor& x, MetaTensor* out);
void LogsumexpInferMeta(const MetaTensor& input, void LogsumexpInferMeta(const MetaTensor& input,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keepdim, bool keepdim,
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <sstream>
#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" #include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
...@@ -193,7 +195,13 @@ TEST(program_test, program) { ...@@ -193,7 +195,13 @@ TEST(program_test, program) {
EXPECT_EQ(program.block()->size() == 5, true); EXPECT_EQ(program.block()->size() == 5, true);
EXPECT_EQ(program.parameters_num() == 3, true); EXPECT_EQ(program.parameters_num() == 3, true);
program.Print(std::cout); std::stringstream ss;
program.Print(ss);
std::stringstream ss_ostram;
ss_ostram << program;
EXPECT_EQ(ss.str(), ss_ostram.str());
} }
TEST(program_test, slice_combine_test) { TEST(program_test, slice_combine_test) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册