提交 05a96ace 编写于 作者: Z zhaiyukun

Add Loop Mutator & Relu to pass three address

上级 11ed37cc
此差异已折叠。
......@@ -36,12 +36,7 @@ TEST(ToThreeAddressTest, BuildCase1) {
UTTensorElementHelper th({16, 32, 1024});
using Add = air::ir::Add;
// a(ax1, ax2) + b(ax2) + c(ax0, ax1, ax2) + d(ax2)
air::Expr expr =
Add::make(
Add::make(
Add::make(th.Elem("a", 2), th.Elem("b", 1)),
th.Elem("c", 3)),
th.Elem("d", 1));
air::Expr expr = Add::make(Add::make(Add::make(th.Elem("a", 2), th.Elem("b", 1)), th.Elem("c", 3)), th.Elem("d", 1));
std::string dump_expr = UTDumpHelper::Dump(expr);
EXPECT_EQ(dump_expr, "(((a(ax1, ax2) + b(ax2)) + c(ax0, ax1, ax2)) + d(ax2))");
}
......@@ -49,16 +44,16 @@ TEST(ToThreeAddressTest, BuildCase1) {
class ThreeAddressExprMutatorTest : public testing::Test {
public:
ThreeAddressExprMutatorTest()
: mutator_(air::TensorNode::make(
UTExprBuilder::CreateShape(shape_), // shape
dtype_, // dtype
UTExprBuilder::PlaceholderOpNode("out", shape_), // op
0), // index
UTExprBuilder::CreateVars({"ax0", "ax1", "ax2"}), // args
UTExprBuilder::CreateShape(shape_), // shape
std::unordered_set<const Call *>(), // broadcast
false, // IsReductionOp
false) {} // cross_stmt_simplify
: mutator_(air::TensorNode::make(UTExprBuilder::CreateShape(shape_), // shape
dtype_, // dtype
UTExprBuilder::PlaceholderOpNode("out", shape_), // op
0), // index
UTExprBuilder::CreateVars({"ax0", "ax1", "ax2"}), // args
UTExprBuilder::CreateVars({"ax0", "ax1", "ax2"}), // args
UTExprBuilder::CreateShape(shape_), // shape
std::unordered_set<const Call *>(), // broadcast
false, // IsReductionOp
false) {} // cross_stmt_simplify
~ThreeAddressExprMutatorTest() = default;
std::vector<int32_t> shape_ = {16, 32, 1024};
......@@ -75,10 +70,8 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) {
}
class PassTestToThreeAddress1 : public ::testing::Test {
public:
PassTestToThreeAddress1() {
Construct();
}
public:
PassTestToThreeAddress1() { Construct(); }
~PassTestToThreeAddress1() = default;
void Construct() {
a_ = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16));
......@@ -86,20 +79,18 @@ class PassTestToThreeAddress1 : public ::testing::Test {
c_ = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16));
out_ = UTExprBuilder::PlaceholderOpNode("out", {32, 1024}, air::Float(16));
stmt = air::ir::AttrStmt::make(
out_, "", UTExprBuilder::IntImm(1),
UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_,
air::ir::ProducerConsumer::make(out_, true,
UTStmtBuilder::CreateFor(
"i", 0, 32,
out_, "", UTExprBuilder::IntImm(1),
UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_, air::ir::ProducerConsumer::make(
out_, true,
UTStmtBuilder::CreateFor(
"j", 0, 1024,
"i", 0, 32,
UTStmtBuilder::CreateFor(
"j", 0, 1024,
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
out_, {"i", "j"},
air::ir::Add::make(
UTExprBuilder::ElementOf(a_, {"j"}),
UTExprBuilder::ElementOf(b_, {"i", "j"})),
UTExprBuilder::ElementOf(c_, {"j"})))))));
out_, {"i", "j"},
air::ir::Add::make(UTExprBuilder::ElementOf(a_, {"j"}), UTExprBuilder::ElementOf(b_, {"i", "j"})),
UTExprBuilder::ElementOf(c_, {"j"})))))));
}
air::Operation a_;
......@@ -110,8 +101,8 @@ class PassTestToThreeAddress1 : public ::testing::Test {
}; // class PassTestToThreeAddress1
TEST_F(PassTestToThreeAddress1, CaseCheck) {
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs =
UTProvideCheckerForAssign().Find(stmt, "((a(j) + b(i, j)) + c(j))");
std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> infos_lhs =
UTProvideCheckerForAssign().Find(stmt, "((a(j) + b(i, j)) + c(j))");
ASSERT_EQ(infos_lhs.size(), 1);
EXPECT_EQ(std::get<0>(infos_lhs[0]), "out(i, j)");
EXPECT_EQ(std::get<2>(infos_lhs[0]), 32 * 1024);
......@@ -124,21 +115,19 @@ TEST_F(PassTestToThreeAddress1, TestPass) {
* out_3(i, j) = (a(j) + out_2(i, j))
* out(i, j) = (out_3(i, j) + c(j))
*/
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info1 =
UTProvideCheckerForAssign().Find(stmt_out, "b(i, j)");
std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> info1 =
UTProvideCheckerForAssign().Find(stmt_out, "b(i, j)");
ASSERT_EQ(info1.size(), 1);
std::string dump_b_target = std::get<0>(info1[0]);
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info2 =
UTProvideCheckerForBinary().Find(
stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(j)", dump_b_target);
std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> info2 =
UTProvideCheckerForBinary().Find(stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(j)", dump_b_target);
ASSERT_EQ(info2.size(), 1);
std::string dump_sum1_target = std::get<0>(info2[0]);
EXPECT_EQ(std::get<2>(info2[0]), 32 * 1024);
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info3 =
UTProvideCheckerForBinary().Find(
stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, dump_sum1_target, "c(j)");
std::vector<std::tuple<std::string, const air::ir::Provide *, uint64_t>> info3 =
UTProvideCheckerForBinary().Find(stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, dump_sum1_target, "c(j)");
ASSERT_EQ(info3.size(), 1);
EXPECT_EQ(std::get<0>(info3[0]), "out(i, j)");
EXPECT_EQ(std::get<2>(info3[0]), 32 * 1024);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册