提交 05a96ace 编写于 作者: Z zhaiyukun

Add Loop Mutator & Relu to pass three address

上级 11ed37cc
此差异已折叠。
...@@ -36,12 +36,7 @@ TEST(ToThreeAddressTest, BuildCase1) { ...@@ -36,12 +36,7 @@ TEST(ToThreeAddressTest, BuildCase1) {
UTTensorElementHelper th({16, 32, 1024}); UTTensorElementHelper th({16, 32, 1024});
using Add = air::ir::Add; using Add = air::ir::Add;
// a(ax1, ax2) + b(ax2) + c(ax0, ax1, ax2) + d(ax2) // a(ax1, ax2) + b(ax2) + c(ax0, ax1, ax2) + d(ax2)
air::Expr expr = air::Expr expr = Add::make(Add::make(Add::make(th.Elem("a", 2), th.Elem("b", 1)), th.Elem("c", 3)), th.Elem("d", 1));
Add::make(
Add::make(
Add::make(th.Elem("a", 2), th.Elem("b", 1)),
th.Elem("c", 3)),
th.Elem("d", 1));
std::string dump_expr = UTDumpHelper::Dump(expr); std::string dump_expr = UTDumpHelper::Dump(expr);
EXPECT_EQ(dump_expr, "(((a(ax1, ax2) + b(ax2)) + c(ax0, ax1, ax2)) + d(ax2))"); EXPECT_EQ(dump_expr, "(((a(ax1, ax2) + b(ax2)) + c(ax0, ax1, ax2)) + d(ax2))");
} }
...@@ -49,12 +44,12 @@ TEST(ToThreeAddressTest, BuildCase1) { ...@@ -49,12 +44,12 @@ TEST(ToThreeAddressTest, BuildCase1) {
class ThreeAddressExprMutatorTest : public testing::Test { class ThreeAddressExprMutatorTest : public testing::Test {
public: public:
ThreeAddressExprMutatorTest() ThreeAddressExprMutatorTest()
: mutator_(air::TensorNode::make( : mutator_(air::TensorNode::make(UTExprBuilder::CreateShape(shape_), // shape
UTExprBuilder::CreateShape(shape_), // shape
dtype_, // dtype dtype_, // dtype
UTExprBuilder::PlaceholderOpNode("out", shape_), // op UTExprBuilder::PlaceholderOpNode("out", shape_), // op
0), // index 0), // index
UTExprBuilder::CreateVars({"ax0", "ax1", "ax2"}), // args UTExprBuilder::CreateVars({"ax0", "ax1", "ax2"}), // args
UTExprBuilder::CreateVars({"ax0", "ax1", "ax2"}), // args
UTExprBuilder::CreateShape(shape_), // shape UTExprBuilder::CreateShape(shape_), // shape
std::unordered_set<const Call *>(), // broadcast std::unordered_set<const Call *>(), // broadcast
false, // IsReductionOp false, // IsReductionOp
...@@ -76,9 +71,7 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) { ...@@ -76,9 +71,7 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) {
class PassTestToThreeAddress1 : public ::testing::Test { class PassTestToThreeAddress1 : public ::testing::Test {
public: public:
PassTestToThreeAddress1() { PassTestToThreeAddress1() { Construct(); }
Construct();
}
~PassTestToThreeAddress1() = default; ~PassTestToThreeAddress1() = default;
void Construct() { void Construct() {
a_ = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16)); a_ = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16));
...@@ -88,17 +81,15 @@ class PassTestToThreeAddress1 : public ::testing::Test { ...@@ -88,17 +81,15 @@ class PassTestToThreeAddress1 : public ::testing::Test {
stmt = air::ir::AttrStmt::make( stmt = air::ir::AttrStmt::make(
out_, "", UTExprBuilder::IntImm(1), out_, "", UTExprBuilder::IntImm(1),
UTStmtBuilder::CreateRealizeByPlaceholderOp( UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_, out_, air::ir::ProducerConsumer::make(
air::ir::ProducerConsumer::make(out_, true, out_, true,
UTStmtBuilder::CreateFor( UTStmtBuilder::CreateFor(
"i", 0, 32, "i", 0, 32,
UTStmtBuilder::CreateFor( UTStmtBuilder::CreateFor(
"j", 0, 1024, "j", 0, 1024,
UTStmtBuilder::CreateProvideBinary<air::ir::Add>( UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
out_, {"i", "j"}, out_, {"i", "j"},
air::ir::Add::make( air::ir::Add::make(UTExprBuilder::ElementOf(a_, {"j"}), UTExprBuilder::ElementOf(b_, {"i", "j"})),
UTExprBuilder::ElementOf(a_, {"j"}),
UTExprBuilder::ElementOf(b_, {"i", "j"})),
UTExprBuilder::ElementOf(c_, {"j"}))))))); UTExprBuilder::ElementOf(c_, {"j"})))))));
} }
...@@ -110,7 +101,7 @@ class PassTestToThreeAddress1 : public ::testing::Test { ...@@ -110,7 +101,7 @@ class PassTestToThreeAddress1 : public ::testing::Test {
}; // class PassTestToThreeAddress1 }; // class PassTestToThreeAddress1
TEST_F(PassTestToThreeAddress1, CaseCheck) { TEST_F(PassTestToThreeAddress1, CaseCheck) {
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs = std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> infos_lhs =
UTProvideCheckerForAssign().Find(stmt, "((a(j) + b(i, j)) + c(j))"); UTProvideCheckerForAssign().Find(stmt, "((a(j) + b(i, j)) + c(j))");
ASSERT_EQ(infos_lhs.size(), 1); ASSERT_EQ(infos_lhs.size(), 1);
EXPECT_EQ(std::get<0>(infos_lhs[0]), "out(i, j)"); EXPECT_EQ(std::get<0>(infos_lhs[0]), "out(i, j)");
...@@ -124,21 +115,19 @@ TEST_F(PassTestToThreeAddress1, TestPass) { ...@@ -124,21 +115,19 @@ TEST_F(PassTestToThreeAddress1, TestPass) {
* out_3(i, j) = (a(j) + out_2(i, j)) * out_3(i, j) = (a(j) + out_2(i, j))
* out(i, j) = (out_3(i, j) + c(j)) * out(i, j) = (out_3(i, j) + c(j))
*/ */
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info1 = std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> info1 =
UTProvideCheckerForAssign().Find(stmt_out, "b(i, j)"); UTProvideCheckerForAssign().Find(stmt_out, "b(i, j)");
ASSERT_EQ(info1.size(), 1); ASSERT_EQ(info1.size(), 1);
std::string dump_b_target = std::get<0>(info1[0]); std::string dump_b_target = std::get<0>(info1[0]);
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info2 = std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> info2 =
UTProvideCheckerForBinary().Find( UTProvideCheckerForBinary().Find(stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(j)", dump_b_target);
stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(j)", dump_b_target);
ASSERT_EQ(info2.size(), 1); ASSERT_EQ(info2.size(), 1);
std::string dump_sum1_target = std::get<0>(info2[0]); std::string dump_sum1_target = std::get<0>(info2[0]);
EXPECT_EQ(std::get<2>(info2[0]), 32 * 1024); EXPECT_EQ(std::get<2>(info2[0]), 32 * 1024);
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info3 = std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> info3 =
UTProvideCheckerForBinary().Find( UTProvideCheckerForBinary().Find(stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, dump_sum1_target, "c(j)");
stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, dump_sum1_target, "c(j)");
ASSERT_EQ(info3.size(), 1); ASSERT_EQ(info3.size(), 1);
EXPECT_EQ(std::get<0>(info3[0]), "out(i, j)"); EXPECT_EQ(std::get<0>(info3[0]), "out(i, j)");
EXPECT_EQ(std::get<2>(info3[0]), 32 * 1024); EXPECT_EQ(std::get<2>(info3[0]), 32 * 1024);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册