未验证 提交 987cb97e 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] fix logical op infermeta (#56711)

* fix logical op infermeta

* add test

* adpat inplace api
上级 2ef4ec71
......@@ -324,4 +324,9 @@ std::ostream& operator<<(std::ostream& os, Attribute attr) {
return os;
}
std::ostream& operator<<(std::ostream& os, const Program& prog) {
prog.Print(os);
return os;
}
} // namespace ir
......@@ -71,4 +71,6 @@ class IR_API Program {
ParameterMap parameters_;
};
std::ostream& operator<<(std::ostream& os, const Program& prog);
} // namespace ir
......@@ -1515,7 +1515,7 @@
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
func : LogicalBinaryInferMeta
kernel :
func : logical_and
data_type : x
......@@ -1526,7 +1526,7 @@
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
func : LogicalNotInfermeta
kernel :
func : logical_not
data_type : x
......@@ -1537,7 +1537,7 @@
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
func : LogicalBinaryInferMeta
kernel :
func : logical_or
data_type : x
......@@ -1548,7 +1548,7 @@
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
func : LogicalBinaryInferMeta
kernel :
func : logical_xor
data_type : x
......
......@@ -1942,6 +1942,15 @@ void LogLossInferMeta(const MetaTensor& input,
out->share_lod(input);
}
void LogicalBinaryInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out) {
ElementwiseInferMeta(x, y, out);
if (!(out->is_same_tensor(x))) {
out->set_dtype(DataType::BOOL);
}
}
void LUUnpackInferMeta(const MetaTensor& x,
const MetaTensor& pivots,
bool unpack_ludata,
......
......@@ -300,6 +300,10 @@ void IndexAddInferMeta(const MetaTensor& x,
void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
void LogicalBinaryInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float epsilon,
......
......@@ -2083,6 +2083,13 @@ void KthvalueInferMeta(const MetaTensor& x,
indices->set_dtype(x.dtype());
}
void LogicalNotInfermeta(const MetaTensor& x, MetaTensor* out) {
UnchangedInferMeta(x, out);
if (!(out->is_same_tensor(x))) {
out->set_dtype(DataType::BOOL);
}
}
void LogsumexpInferMeta(const MetaTensor& input,
const std::vector<int64_t>& axis,
bool keepdim,
......
......@@ -298,6 +298,8 @@ void KthvalueInferMeta(const MetaTensor& x,
MetaTensor* indices,
MetaConfig = MetaConfig());
void LogicalNotInfermeta(const MetaTensor& x, MetaTensor* out);
void LogsumexpInferMeta(const MetaTensor& input,
const std::vector<int64_t>& axis,
bool keepdim,
......
......@@ -14,6 +14,8 @@
#include <gtest/gtest.h>
#include <sstream>
#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
......@@ -193,7 +195,13 @@ TEST(program_test, program) {
EXPECT_EQ(program.block()->size() == 5, true);
EXPECT_EQ(program.parameters_num() == 3, true);
program.Print(std::cout);
std::stringstream ss;
program.Print(ss);
std::stringstream ss_ostram;
ss_ostram << program;
EXPECT_EQ(ss.str(), ss_ostram.str());
}
TEST(program_test, slice_combine_test) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册