diff --git a/tests/unittest_cpp/CMakeLists.txt b/tests/unittest_cpp/CMakeLists.txt index efc575297bb62408bda3801d0cfa4954d86c9250..8148c8b5f97a23fb414639fff9c1949bec4a6ad1 100644 --- a/tests/unittest_cpp/CMakeLists.txt +++ b/tests/unittest_cpp/CMakeLists.txt @@ -16,6 +16,7 @@ file( unittest_main.cc src/base/*.cc src/base_test/*.cc + src/pass_test_base/*.cc src/pass_test/*.cc) link_directories(${CMAKE_BINARY_DIR}/googletest/googlemock/gtest) diff --git a/tests/unittest_cpp/include/base/dump_helper.h b/tests/unittest_cpp/include/base/dump_helper.h index a9a61cbde253a73eb40a41efe1255904dd26ce12..7c3244aa6bc6332745da3bf61e948aed0e38ccb5 100644 --- a/tests/unittest_cpp/include/base/dump_helper.h +++ b/tests/unittest_cpp/include/base/dump_helper.h @@ -29,7 +29,7 @@ class UTRegxMatch { static bool RegxMatchHex(const std::string &str); static const std::string pattern_hex_; -}; // UTRegxMatch +}; // class UTRegxMatch class UTDumpHelper { public: @@ -38,6 +38,6 @@ class UTDumpHelper { static std::string Dump(const air::NodeRef &node); static bool RegxMatchPlaceholder(const std::string &str, const std::string &name); -}; // UTDumpHelper +}; // class UTDumpHelper } // namespace akg #endif // UT_BASE_DUMP_HELPER_H_ diff --git a/tests/unittest_cpp/include/base/expr_builder.h b/tests/unittest_cpp/include/base/expr_builder.h index aaf923abfc998f83023b3568f463e20dd469c93d..24d8b13aaa14f9b62389ddf7abd6dfc5c3068e6b 100644 --- a/tests/unittest_cpp/include/base/expr_builder.h +++ b/tests/unittest_cpp/include/base/expr_builder.h @@ -19,6 +19,7 @@ #include #include "tvm/expr.h" #include "tvm/operation.h" +#include "tvm/tensor.h" namespace akg { class UTExprBuilder { @@ -26,9 +27,15 @@ class UTExprBuilder { UTExprBuilder() = default; ~UTExprBuilder() = default; + static air::Expr IntImm(int64_t value, air::DataType dtype = air::Int(32)); + static air::Expr UIntImm(uint64_t value, air::DataType dtype = air::UInt(32)); + static air::Expr BoolImm(bool value); static air::Array CreateShape(const std::vector &shapes); static air::Var CreateVar(const std::string &name); static air::Array CreateVars(const std::vector &names); + static air::Range CreateRange(int32_t min, int32_t max); + static air::Region CreateRegion(const std::vector &shapes); + static air::Region CreateRegion(const air::Array &shapes); static air::Operation PlaceholderOpNode( const std::string &name, const std::vector &shapes, @@ -38,6 +45,18 @@ class UTExprBuilder { const std::vector &shapes, const std::vector &axis_names, air::DataType dtype = air::Float(16)); + static air::Expr ElementOf( + const air::Operation &op, + const std::vector &axis_names); + static air::Expr ElementOfPlaceholderOp( + const air::Operation &op, + const std::vector &axis_names); + static air::Expr CreateCall( + const air::ir::FunctionRef func, + air::Array args, + air::ir::Call::CallType call_type = air::ir::Call::Halide, + int value_index = 0); + static air::Tensor CreateTensorByPlaceholder(const air::Operation op); }; // UTExprBuilder class UTTensorElementHelper { @@ -46,14 +65,13 @@ class UTTensorElementHelper { const std::string &axis_name_prefix = "ax"); ~UTTensorElementHelper() = default; air::Expr Elem(const std::string &name, - uint32_t dim, - air::DataType dtype = air::Float(16)) const; + uint32_t dim, + air::DataType dtype = air::Float(16)) const; private: std::vector shapes_; std::string axis_name_prefix_; std::vector axis_names_; -}; // UTTensorElementHelper - +}; // class UTTensorElementHelper } // namespace akg #endif // UT_BASE_EXPR_BUILDER_H_ diff --git a/tests/unittest_cpp/include/base/ir_checker.h b/tests/unittest_cpp/include/base/ir_checker.h new file mode 100644 index 0000000000000000000000000000000000000000..d5168ba1345ea97f4ade419d8f769d015ea305df --- /dev/null +++ b/tests/unittest_cpp/include/base/ir_checker.h @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef UT_IR_CHECKER_H_ +#define UT_IR_CHECKER_H_ +#include +#include +#include +#include +#include "base/dump_helper.h" +#include "base/expr_builder.h" + +namespace akg { +class UTIRCheckHelper { + public: + UTIRCheckHelper() = default; + ~UTIRCheckHelper() = default; + static int64_t GetValueFromImm(const air::Expr &expr); +}; // class UTIRCheckHelper + +class UTProvideChecker : public air::ir::IRVisitor { + public: + explicit UTProvideChecker(bool ignore_args = false) + : ignore_args_(ignore_args) {} + ~UTProvideChecker() = default; + void Visit_(const air::ir::For *op) override; + bool CompareDump(const std::string &dump, const std::string &target); + + protected: + bool ignore_args_{false}; + std::vector for_count_stack_; +}; // class UTProvideChecker + +class UTProvideCheckerForAssign : public UTProvideChecker { + public: + explicit UTProvideCheckerForAssign(bool ignore_args = false) + : UTProvideChecker(ignore_args) {} + ~UTProvideCheckerForAssign() = default; + std::vector> Find( + const air::NodeRef &node, + const std::string &dump_rhs); + void Visit_(const air::ir::Provide *op) override; + + private: + std::string dump_rhs_{""}; + std::vector> infos_lhs_; +}; // class UTProvideChecker + +class UTProvideCheckerForBinary : public UTProvideChecker { + public: + enum BinaryOpType : int { + kAdd, + kSub, + kMul, + kDiv, + kMod, + }; + + explicit UTProvideCheckerForBinary(bool ignore_args = false) + : UTProvideChecker(ignore_args) {} + ~UTProvideCheckerForBinary() = default; + std::vector> Find( + const air::NodeRef &node, + BinaryOpType op_type, + const std::string &dump_rhs1, + const std::string &dump_rhs2); + void Visit_(const air::ir::Provide *op) override; + + template + void CheckBinary(const air::ir::Provide *op) { + const T *expr_binary = op->value.as(); + if (expr_binary == nullptr) { + return; + } + std::string dump_expr_a = UTDumpHelper::Dump(expr_binary->a); + std::string dump_expr_b = UTDumpHelper::Dump(expr_binary->b); + if ((dump_rhs1_.empty() || CompareDump(dump_expr_a, dump_rhs1_)) && + (dump_rhs2_.empty() || CompareDump(dump_expr_b, dump_rhs2_))) { + air::Expr expr_call = UTExprBuilder::CreateCall(op->func, op->args); + infos_lhs_.push_back(std::make_tuple(UTDumpHelper::Dump(expr_call), op, for_count_stack_.back())); + } + } + + private: + BinaryOpType op_type_; + std::string dump_rhs1_{""}; + std::string dump_rhs2_{""}; + std::vector> infos_lhs_; +}; // class UTProvideCheckerForBinary +} // namespace akg +#endif // UT_IR_CHECKER_H_ diff --git a/tests/unittest_cpp/include/base/stmt_builder.h b/tests/unittest_cpp/include/base/stmt_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..4d59eed97b3226ad7bc6c5bffb00306a36184ff7 --- /dev/null +++ b/tests/unittest_cpp/include/base/stmt_builder.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef UT_BASE_STMT_BUILDER_H_ +#define UT_BASE_STMT_BUILDER_H_ +#include +#include +#include +#include "tvm/ir.h" +#include "base/expr_builder.h" + +namespace akg { +class UTStmtBuilder { + public: + UTStmtBuilder() = default; + ~UTStmtBuilder() = default; + static air::Stmt CreateFor( + const std::string &loop_var_name, + int32_t min, + int32_t extent, + air::Stmt body); + static air::Stmt CreateRealizeByPlaceholderOp( + const air::Operation &op, + air::Stmt body); + static air::Stmt CreateProvideAssign( + air::ir::FunctionRef func_dst, + const std::vector &vars, + air::Expr src, + int value_index = 0); + + template + static air::Stmt CreateProvideBinary( + air::ir::FunctionRef func_dst, + const std::vector &vars, + air::Expr src1, + air::Expr src2, + int value_index = 0) { + return air::ir::Provide::make( + func_dst, + value_index, + T::make(src1, src2), + UTExprBuilder::CreateVars(vars)); + } +}; // class UTStmtBuilder +} // namespace akg +#endif // UT_BASE_STMT_BUILDER_H_ diff --git a/tests/unittest_cpp/include/pass_test_base/auto_poly_test_base.h b/tests/unittest_cpp/include/pass_test_base/auto_poly_test_base.h new file mode 100644 index 0000000000000000000000000000000000000000..4fa21dc19b322575a1568914477d82e066f6610d --- /dev/null +++ b/tests/unittest_cpp/include/pass_test_base/auto_poly_test_base.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef UT_AUTO_POLY_TEST_BASE_H_ +#define UT_AUTO_POLY_TEST_BASE_H_ +#include +#include +#include +#include +#include "codegen/util.h" +#include "contrib/cce_parm/cceconf.h" +#include "base/expr_builder.h" + +namespace akg { +class AutoPolyTestBase : public ::testing::Test { + public: + AutoPolyTestBase() = default; + ~AutoPolyTestBase() = default; + static std::map InitMapMode(); + void RegisterTensor(const air::Tensor &tensor); + void SetRunMode(const std::string &mode); + void GlobalAttrSetIsDynamic(bool arg) { + global_attrs_.Set("is_dynamic", arg ? UTExprBuilder::IntImm(1, air::Int(1)) : + UTExprBuilder::IntImm(0, air::Int(1))); + } + + void GlobalAttrSetDynamic(bool arg) { + global_attrs_.Set("dynamic", arg ? UTExprBuilder::IntImm(1, air::Int(1)) : + UTExprBuilder::IntImm(0, air::Int(1))); + } + + void GlobalAttrSetDumpPassIR(bool arg) { + global_attrs_.Set("dump_pass_ir", arg ? UTExprBuilder::IntImm(1, air::Int(1)) : + UTExprBuilder::IntImm(0, air::Int(1))); + } + + void GlobalAttrSetDumpPolyDir(const std::string &path) { + global_attrs_.Set("dump_poly_dir", air::ir::StringImm::make(path)); + } + + void GlobalAttrSetKernalName(const std::string &name) { + global_attrs_.Set("kernel_name", air::ir::StringImm::make(name)); + } + + static std::map map_mode_; + + protected: + air::Map binds_; + AttrMap global_attrs_; +}; // class AutoPolyTestBase +} // namespace akg +#endif diff --git a/tests/unittest_cpp/src/base/expr_builder.cc b/tests/unittest_cpp/src/base/expr_builder.cc index efad73652d1f33a35c7356e4edee8fadba178553..be870ccbd5da591a2e275ed413f7a74ea061caea 100644 --- a/tests/unittest_cpp/src/base/expr_builder.cc +++ b/tests/unittest_cpp/src/base/expr_builder.cc @@ -14,9 +14,22 @@ * limitations under the License. */ #include +#include #include "base/expr_builder.h" namespace akg { +air::Expr UTExprBuilder::IntImm(int64_t value, air::DataType dtype) { + return air::IntImm::make(dtype, value); +} + +air::Expr UTExprBuilder::UIntImm(uint64_t value, air::DataType dtype) { + return air::ir::UIntImm::make(dtype, value); +} + +air::Expr UTExprBuilder::BoolImm(bool value) { + return air::ir::UIntImm::make(air::Bool(), value ? 1 : 0); +} + air::Array UTExprBuilder::CreateShape(const std::vector &shapes) { air::Array res; for (int32_t shape : shapes) { @@ -38,6 +51,28 @@ air::Array UTExprBuilder::CreateVars(const std::vector & return vars; } +air::Region UTExprBuilder::CreateRegion(const std::vector &shapes) { + air::Region region; + for (int32_t shape : shapes) { + region.push_back(CreateRange(0, shape)); + } + return region; +} + +air::Region UTExprBuilder::CreateRegion(const air::Array &shapes) { + air::Region region; + for (const air::Expr &shape : shapes) { + region.push_back(air::Range::make_by_min_extent(IntImm(0), shape)); + } + return region; +} + +air::Range UTExprBuilder::CreateRange(int32_t min, int32_t max) { + air::Integer imm_min = air::IntImm::make(air::Int(32), min); + air::Integer imm_max = air::IntImm::make(air::Int(32), max); + return air::Range(std::move(imm_min), std::move(imm_max)); +} + air::Operation UTExprBuilder::PlaceholderOpNode( const std::string &name, const std::vector &shapes, @@ -60,6 +95,56 @@ air::Expr UTExprBuilder::TensorElement( 0); // value_index } +air::Expr UTExprBuilder::ElementOf( + const air::Operation &op, + const std::vector &axis_names) { + if (op->template IsInstance()) { + return ElementOfPlaceholderOp(op, axis_names); + } else { + CHECK(false); + return air::ir::Any::make(); + } +} + +air::Expr UTExprBuilder::ElementOfPlaceholderOp( + const air::Operation &op, + const std::vector &axis_names) { + const air::PlaceholderOpNode *node = op.as(); + CHECK(node); + return air::ir::Call::make( + node->dtype, + node->name, + CreateVars(axis_names), + air::ir::Call::Halide, + op, + 0); +} +air::Expr UTExprBuilder::CreateCall( + const air::ir::FunctionRef func, + air::Array args, + air::ir::Call::CallType call_type, + int value_index) { + air::DataType type = air::Float(16); + const air::OperationNode *node_op = func.as(); + CHECK(node_op); + std::string name = node_op->name; + const air::PlaceholderOpNode *node_placeholder = func.as(); + if (node_placeholder != nullptr) { + type = node_placeholder->dtype; + } + return air::ir::Call::make(type, name, args, call_type, func, value_index); +} + +air::Tensor UTExprBuilder::CreateTensorByPlaceholder(const air::Operation op) { + const air::PlaceholderOpNode *node = op.as(); + CHECK(node); + return air::TensorNode::make( + node->shape, + node->dtype, + op, + 0); +} + UTTensorElementHelper::UTTensorElementHelper(const std::vector &shapes, const std::string &axis_name_prefix) : shapes_(shapes), axis_name_prefix_(axis_name_prefix) { @@ -72,8 +157,8 @@ UTTensorElementHelper::UTTensorElementHelper(const std::vector &shapes, } air::Expr UTTensorElementHelper::Elem(const std::string &name, - uint32_t dim, - air::DataType dtype) const { + uint32_t dim, + air::DataType dtype) const { uint32_t start = shapes_.size() - dim; return UTExprBuilder::TensorElement( name, diff --git a/tests/unittest_cpp/src/base/ir_checker.cc b/tests/unittest_cpp/src/base/ir_checker.cc new file mode 100644 index 0000000000000000000000000000000000000000..c3277022c7ba506bd8e679124dc0ef3c29f2ac09 --- /dev/null +++ b/tests/unittest_cpp/src/base/ir_checker.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "base/ir_checker.h" +#include +#include +#include "base/dump_helper.h" +#include "base/expr_builder.h" + +namespace akg { +int64_t UTIRCheckHelper::GetValueFromImm(const air::Expr &expr) { + const air::IntImm *imm_int = expr.as(); + if (imm_int != nullptr) { + return imm_int->value; + } + const air::ir::UIntImm *imm_uint = expr.as(); + if (imm_uint != nullptr) { + CHECK(imm_uint->value < INT64_MAX); + return static_cast(imm_uint->value); + } + return 0; +} + +void UTProvideChecker::Visit_(const air::ir::For *op) { + uint64_t count_top = for_count_stack_.back(); + int64_t min = UTIRCheckHelper::GetValueFromImm(op->min); + int64_t extent = UTIRCheckHelper::GetValueFromImm(op->extent); + CHECK(extent > min); + count_top *= static_cast(extent); + for_count_stack_.push_back(count_top); + IRVisitor::Visit_(op); + for_count_stack_.pop_back(); +} + +bool UTProvideChecker::CompareDump( + const std::string &dump, + const std::string &target) { + if (dump.compare(target) == 0) { + return true; + } + if (ignore_args_) { + size_t npos = dump.find("("); + return dump.substr(0, npos).compare(target) == 0; + } + return false; +} + +std::vector> UTProvideCheckerForAssign::Find( + const air::NodeRef &node, + const std::string &dump_rhs) { + dump_rhs_ = dump_rhs; + infos_lhs_.clear(); + for_count_stack_.clear(); + for_count_stack_.push_back(1); + Visit(node); + return infos_lhs_; +} + +void UTProvideCheckerForAssign::Visit_(const air::ir::Provide *op) { + std::string dump_expr = UTDumpHelper::Dump(op->value); + if (CompareDump(dump_expr, dump_rhs_)) { + air::Expr expr_call = UTExprBuilder::CreateCall(op->func, op->args); + infos_lhs_.push_back(std::make_tuple(UTDumpHelper::Dump(expr_call), op, for_count_stack_.back())); + } +} + +std::vector> UTProvideCheckerForBinary::Find( + const air::NodeRef &node, + UTProvideCheckerForBinary::BinaryOpType op_type, + const std::string &dump_rhs1, + const std::string &dump_rhs2) { + op_type_ = op_type; + dump_rhs1_ = dump_rhs1; + dump_rhs2_ = dump_rhs2; + infos_lhs_.clear(); + for_count_stack_.clear(); + for_count_stack_.push_back(1); + if (dump_rhs1_.empty() && dump_rhs2_.empty()) { + return infos_lhs_; + } + Visit(node); + return infos_lhs_; +} + +void UTProvideCheckerForBinary::Visit_(const air::ir::Provide *op) { + switch (op_type_) { + case BinaryOpType::kAdd: + CheckBinary(op); + break; + case BinaryOpType::kSub: + CheckBinary(op); + break; + case BinaryOpType::kMul: + CheckBinary(op); + break; + case BinaryOpType::kDiv: + CheckBinary(op); + break; + case BinaryOpType::kMod: + CheckBinary(op); + break; + default: + break; + } +} +} // namespace akg diff --git a/tests/unittest_cpp/src/base/stmt_builder.cc b/tests/unittest_cpp/src/base/stmt_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..72b6dd184383f2ffbe887571ef995d0b6c845cbc --- /dev/null +++ b/tests/unittest_cpp/src/base/stmt_builder.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "base/stmt_builder.h" + +namespace akg { +air::Stmt UTStmtBuilder::CreateFor( + const std::string &loop_var_name, + int32_t min, + int32_t extent, + air::Stmt body) { + return air::ir::For::make( + UTExprBuilder::CreateVar(loop_var_name), + UTExprBuilder::IntImm(min), + UTExprBuilder::IntImm(extent), + air::ir::ForType::Serial, + air::ir::DeviceAPI::None, + body); +} +air::Stmt UTStmtBuilder::CreateRealizeByPlaceholderOp( + const air::Operation &op, + air::Stmt body) { + const air::PlaceholderOpNode *node = op.as(); + CHECK(node); + return air::ir::Realize::make( + op, + 0, + node->dtype, + UTExprBuilder::CreateRegion(node->shape), + UTExprBuilder::BoolImm(true), + body); +} + +air::Stmt UTStmtBuilder::CreateProvideAssign( + air::ir::FunctionRef func_dst, + const std::vector &vars, + air::Expr src, + int value_index) { + return air::ir::Provide::make( + func_dst, + value_index, + src, + UTExprBuilder::CreateVars(vars)); +} +} // namespace akg diff --git a/tests/unittest_cpp/src/base_test/expr_builder_test.cc b/tests/unittest_cpp/src/base_test/expr_builder_test.cc index 0ba6b98373e7b78f4f246a3a33932e4c4dc5c5e9..f2635e20f9ef041188bceb79df7fa01eb953a245 100644 --- a/tests/unittest_cpp/src/base_test/expr_builder_test.cc +++ b/tests/unittest_cpp/src/base_test/expr_builder_test.cc @@ -18,6 +18,39 @@ #include "base/expr_builder.h" namespace akg { +TEST(UTExprBuilder, IntImm) { + air::Expr int1 = UTExprBuilder::IntImm(1024); + std::string dump_int1 = UTDumpHelper::Dump(int1); + EXPECT_EQ(dump_int1, "1024"); + air::Expr int2 = UTExprBuilder::IntImm(1024, air::Int(64)); + std::string dump_int2 = UTDumpHelper::Dump(int2); + EXPECT_EQ(dump_int2, "(int64)1024"); + air::Expr int3 = UTExprBuilder::IntImm(1024, air::Int(16)); + std::string dump_int3 = UTDumpHelper::Dump(int3); + EXPECT_EQ(dump_int3, "(int16)1024"); +} + +TEST(UTExprBuilder, UIntImm) { + air::Expr uint1 = UTExprBuilder::UIntImm(1024); + std::string dump_uint1 = UTDumpHelper::Dump(uint1); + EXPECT_EQ(dump_uint1, "(uint32)1024"); + air::Expr uint2 = UTExprBuilder::UIntImm(1024, air::UInt(64)); + std::string dump_uint2 = UTDumpHelper::Dump(uint2); + EXPECT_EQ(dump_uint2, "(uint64)1024"); + air::Expr uint3 = UTExprBuilder::UIntImm(1024, air::UInt(16)); + std::string dump_uint3 = UTDumpHelper::Dump(uint3); + EXPECT_EQ(dump_uint3, "(uint16)1024"); +} + +TEST(UTExprBuilder, Bool) { + air::Expr bool_true = UTExprBuilder::BoolImm(true); + std::string dump_bool_true = UTDumpHelper::Dump(bool_true); + EXPECT_EQ(dump_bool_true, "(bool)1"); + air::Expr bool_false = UTExprBuilder::BoolImm(false); + std::string dump_bool_false = UTDumpHelper::Dump(bool_false); + EXPECT_EQ(dump_bool_false, "(bool)0"); +} + TEST(UTExprBuilder, CreateShape) { air::Array shape1 = UTExprBuilder::CreateShape({1024}); std::string dump_shape1 = UTDumpHelper::Dump(shape1); @@ -44,6 +77,12 @@ TEST(UTExprBuilder, CreateVars) { EXPECT_EQ(dump_vars, "[ax0, ax1, ax2]"); } +TEST(UTExprBuilder, CreateRange) { + air::Range range = UTExprBuilder::CreateRange(0, 1024); + std::string dump_range = UTDumpHelper::Dump(range); + EXPECT_EQ(dump_range, "range(min=0, ext=1024)"); +} + TEST(UTExprBuilder, PlaceholderOpNode) { air::Operation node = UTExprBuilder::PlaceholderOpNode("input", {16, 32, 1024}, air::Float(16)); std::string dump_node = UTDumpHelper::Dump(node); @@ -56,6 +95,13 @@ TEST(UTExprBuilder, TensorElement) { EXPECT_EQ(dump_elem, "input(ax0, ax1, ax2)"); } +TEST(UTExprBuilder, ElememtOfPlaceholderOp) { + air::Operation op = UTExprBuilder::PlaceholderOpNode("input", {16, 32, 1024}, air::Float(16)); + air::Expr elem = UTExprBuilder::ElementOfPlaceholderOp(op, {"ax0", "ax1", "ax2"}); + std::string dump_elem = UTDumpHelper::Dump(elem); + EXPECT_EQ(dump_elem, "input(ax0, ax1, ax2)"); +} + TEST(UTTensorElementHelper, TensorElement) { UTTensorElementHelper helper({16, 32, 1024}); std::string dump_elem1 = UTDumpHelper::Dump(helper.Elem("a", 3)); diff --git a/tests/unittest_cpp/src/base_test/ir_checker_test.cc b/tests/unittest_cpp/src/base_test/ir_checker_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..321d71924d0aab6dc04625c3b3fd425a025cf7a8 --- /dev/null +++ b/tests/unittest_cpp/src/base_test/ir_checker_test.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "base/ir_checker.h" +#include "base/expr_builder.h" +#include "base/stmt_builder.h" + +namespace akg { +TEST(UTProvideChecker, CompareDump) { + EXPECT_EQ(UTProvideChecker().CompareDump("a(i, j)", "a"), false); + EXPECT_EQ(UTProvideChecker().CompareDump("a(i, j)", "a(i, j)"), true); + EXPECT_EQ(UTProvideChecker(true).CompareDump("a(i, j)", "a"), true); + EXPECT_EQ(UTProvideChecker(true).CompareDump("a(i, j)", "a(i, j)"), true); +} + +class UTProvideCheckerTest : public testing::Test { + public: + UTProvideCheckerTest() + : a_(UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16))), + b_(UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16))), + c_(UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16))) {} + ~UTProvideCheckerTest() = default; + air::Operation a_; + air::Operation b_; + air::Operation c_; +}; // class UTProvideCheckerTest + +TEST_F(UTProvideCheckerTest, UTProvideCheckerForAssign) { + // b(ax0) = a(ax0) + air::Stmt stmt = UTStmtBuilder::CreateProvideAssign( + b_, {"ax0"}, UTExprBuilder::ElementOf(a_, {"ax0"})); + std::vector> infos_lhs = + UTProvideCheckerForAssign().Find(stmt, "a(ax0)"); + ASSERT_EQ(infos_lhs.size(), 1); + EXPECT_EQ(std::get<0>(infos_lhs[0]), "b(ax0)"); + EXPECT_EQ(std::get<2>(infos_lhs[0]), 1); +} + +TEST_F(UTProvideCheckerTest, UTProvideCheckerForBinary) { + // c(ax0) = (a(ax0) + b(ax0)) + air::Stmt stmt = UTStmtBuilder::CreateProvideBinary( + c_, {"ax0"}, + UTExprBuilder::ElementOf(a_, {"ax0"}), + UTExprBuilder::ElementOf(b_, {"ax0"})); + std::vector> infos_lhs = + UTProvideCheckerForBinary().Find(stmt, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(ax0)", "b(ax0)"); + ASSERT_EQ(infos_lhs.size(), 1); + EXPECT_EQ(std::get<0>(infos_lhs[0]), "c(ax0)"); + EXPECT_EQ(std::get<2>(infos_lhs[0]), 1); +} + +class UTProvideCheckerTest2 : public testing::Test { + public: + UTProvideCheckerTest2() + : a_(UTExprBuilder::PlaceholderOpNode("a", {16, 32, 1024}, air::Float(16))), + b_(UTExprBuilder::PlaceholderOpNode("b", {16, 32, 1024}, air::Float(16))), + c_(UTExprBuilder::PlaceholderOpNode("c", {16, 32, 1024}, air::Float(16))) {} + ~UTProvideCheckerTest2() = default; + air::Operation a_; + air::Operation b_; + air::Operation c_; +}; // class UTProvideCheckerTest + +TEST_F(UTProvideCheckerTest2, UTProvideCheckerForBinary) { + air::Stmt stmt = UTStmtBuilder::CreateFor( + "i", 0, 16, + UTStmtBuilder::CreateFor( + "j", 0, 32, + UTStmtBuilder::CreateFor( + "k", 0, 1024, + UTStmtBuilder::CreateProvideBinary( + c_, {"i", "j", "k"}, + UTExprBuilder::ElementOf(a_, {"i", "j", "k"}), + UTExprBuilder::ElementOf(b_, {"i", "j", "k"}))))); + std::string dump_stmt = UTDumpHelper::Dump(stmt); + EXPECT_EQ(dump_stmt, + "for (i, 0, 16) {\n" + " for (j, 0, 32) {\n" + " for (k, 0, 1024) {\n" + " c(i, j, k) = (a(i, j, k) + b(i, j, k))\n" + " }\n" + " }\n" + "}\n"); + std::vector> infos_lhs = + UTProvideCheckerForBinary().Find( + stmt, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(i, j, k)", "b(i, j, k)"); + ASSERT_EQ(infos_lhs.size(), 1); + EXPECT_EQ(std::get<0>(infos_lhs[0]), "c(i, j, k)"); + EXPECT_EQ(std::get<2>(infos_lhs[0]), 1024 * 32 * 16); +} +} // namespace akg diff --git a/tests/unittest_cpp/src/base_test/stmt_builder_test.cc b/tests/unittest_cpp/src/base_test/stmt_builder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc732d8b9b7dffb158482fa2e6b87eb324143aa4 --- /dev/null +++ b/tests/unittest_cpp/src/base_test/stmt_builder_test.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "gtest/gtest.h" +#include "base/dump_helper.h" +#include "base/expr_builder.h" +#include "base/stmt_builder.h" + +namespace akg { +TEST(UTStmtBuilder, CreateProvideAssign) { + // b(ax0) = a(ax0) + air::Operation a = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16)); + air::Operation b = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16)); + air::Stmt stmt = UTStmtBuilder::CreateProvideAssign( + b, {"ax0"}, UTExprBuilder::ElementOf(a, {"ax0"})); + std::string dump_stmt = UTDumpHelper::Dump(stmt); + EXPECT_EQ(dump_stmt, "b(ax0) = a(ax0)\n"); +} + +TEST(UTStmtBuilder, CreateProvideBinary) { + // c(ax0) = a(ax0) + b(ax0) + air::Operation a = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16)); + air::Operation b = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16)); + air::Operation c = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16)); + air::Stmt stmt = UTStmtBuilder::CreateProvideBinary( + c, {"ax0"}, UTExprBuilder::ElementOf(a, {"ax0"}), UTExprBuilder::ElementOf(b, {"ax0"})); + std::string dump_stmt = UTDumpHelper::Dump(stmt); + EXPECT_EQ(dump_stmt, "c(ax0) = (a(ax0) + b(ax0))\n"); +} + +TEST(UTStmtBuilder, CreateFor) { + /* + * for (i, 0, 1024) { + * c(i) = (a(i) + b(i)) + * } + */ + air::Operation a = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16)); + air::Operation b = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16)); + air::Operation c = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16)); + air::Stmt stmt_for = UTStmtBuilder::CreateFor( + "i", 0, 1024, + UTStmtBuilder::CreateProvideBinary( + c, {"i"}, UTExprBuilder::ElementOf(a, {"i"}), UTExprBuilder::ElementOf(b, {"i"}))); + std::string dump_stmt_for = UTDumpHelper::Dump(stmt_for); + EXPECT_EQ(dump_stmt_for, + "for (i, 0, 1024) {\n" + " c(i) = (a(i) + b(i))\n" + "}\n"); +} + +TEST(UTStmtBuilder, CreateRealizeByPlaceholderOp) { + air::Operation a = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16)); + air::Operation b = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16)); + air::Operation c = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16)); + air::Stmt stmt_realize = UTStmtBuilder::CreateRealizeByPlaceholderOp( + c, + UTStmtBuilder::CreateFor( + "i", 0, 1024, + UTStmtBuilder::CreateProvideBinary( + c, {"i"}, UTExprBuilder::ElementOf(a, {"i"}), UTExprBuilder::ElementOf(b, {"i"})))); + std::string dump_stmt_realize = UTDumpHelper::Dump(stmt_realize); + EXPECT_EQ(dump_stmt_realize, + "realize c([0, 1024]) {\n" + " for (i, 0, 1024) {\n" + " c(i) = (a(i) + b(i))\n" + " }\n" + "}\n"); +} +} // namespace akg diff --git a/tests/unittest_cpp/src/pass_test/auto_poly_test.cc b/tests/unittest_cpp/src/pass_test/auto_poly_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ea1012371cb10c4af77d02073465b3080644c49 --- /dev/null +++ b/tests/unittest_cpp/src/pass_test/auto_poly_test.cc @@ -0,0 +1,309 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "base/dump_helper.h" +#include "base/expr_builder.h" +#include "base/ir_checker.h" +#include "base/stmt_builder.h" +#include "pass_test_base/auto_poly_test_base.h" +#define private public +#define protected public +#include "ir_pass.h" +#undef protected +#undef private +#include "codegen/util.h" +#include "contrib/cce_parm/cceconf.h" + +namespace akg { +/* AutoPolyTest1: test for to_three_address + * Input pattern: + * for (i0, 0, 32) { + * for (i1, 0, 1024) { + * out_0(i1) = b(i1) + c(i1) + * out(i0, i1) = out_0(i1) + a(i0, i1) + * } + * } + * + * Expect output: + * for (cc1, 0, 2) { + * for (cc2, 0, 16) { + * for (cc3, 0, 1024) { + * out_0_local_UB(cc2) = (b_local_UB(cc2) + c_local_UB(cc2)) + * out_local_UB(cc3, cc2) = (out_0_local_UB(cc2) + a_local_UB(cc3, cc2)) + * } + * } + * } + * + * IR Check: + * count for (b_local_UB + c_local_UB): 32 * 1024 + */ +class AutoPolyTest1 : public AutoPolyTestBase { + public: + AutoPolyTest1() { + Construct(); + } + ~AutoPolyTest1() = default; + void Construct() { + a_ = UTExprBuilder::PlaceholderOpNode("a", {32, 1024}, air::Float(16)); + b_ = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16)); + c_ = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16)); + out_ = UTExprBuilder::PlaceholderOpNode("out", {32, 1024}, air::Float(16)); + out_0_ = UTExprBuilder::PlaceholderOpNode("out_0", {1024}, air::Float(16)); + stmt_ = air::ir::AttrStmt::make( + out_0_, "realize_scope", air::ir::StringImm::make(""), + UTStmtBuilder::CreateRealizeByPlaceholderOp( + out_0_, + air::ir::AttrStmt::make( + out_, "realize_scope", air::ir::StringImm::make(""), + UTStmtBuilder::CreateRealizeByPlaceholderOp( + out_, + air::ir::ProducerConsumer::make(out_, true, + UTStmtBuilder::CreateFor( + "i0", 0, 32, + UTStmtBuilder::CreateFor( + "i1", 0, 1024, + air::ir::Block::make( + UTStmtBuilder::CreateProvideBinary( + out_0_, {"i1"}, + UTExprBuilder::ElementOf(b_, {"i1"}), + UTExprBuilder::ElementOf(c_, {"i1"})), + UTStmtBuilder::CreateProvideBinary( + out_, {"i0", "i1"}, + UTExprBuilder::ElementOf(out_0_, {"i1"}), + UTExprBuilder::ElementOf(a_, {"i0", "i1"})))))))))); + t_a_ = UTExprBuilder::CreateTensorByPlaceholder(a_); + t_b_ = UTExprBuilder::CreateTensorByPlaceholder(b_); + t_c_ = UTExprBuilder::CreateTensorByPlaceholder(c_); + t_out_ = UTExprBuilder::CreateTensorByPlaceholder(out_); + RegisterTensor(t_a_); + RegisterTensor(t_b_); + RegisterTensor(t_c_); + RegisterTensor(t_out_); + } + + air::Operation a_; + air::Operation b_; + air::Operation c_; + air::Tensor t_a_; + air::Tensor t_b_; + air::Tensor t_c_; + air::Operation out_; + air::Tensor t_out_; + air::Operation out_0_; + air::Stmt stmt_; +}; // class AutoPolyTest1 + +TEST_F(AutoPolyTest1, RunPass) { + SetRunMode("cloud"); + air::Array stmts_out = ir::AutoPoly(stmt_, binds_, global_attrs_, false, false); + ASSERT_EQ(stmts_out.size(), 2); + air::NodeRef stmt = stmts_out[0]; + std::vector> infos_lhs = + UTProvideCheckerForBinary(true).Find( + stmt, UTProvideCheckerForBinary::BinaryOpType::kAdd, "b_local_UB", "c_local_UB"); + ASSERT_EQ(infos_lhs.size(), 1); + EXPECT_EQ(std::get<2>(infos_lhs[0]), 2 * 16 * 1024); +} + +/* AutoPolyTest2: test for to_three_address + * Input pattern: + * for (i1, 0, 32) { + * out_0(i1) = b(i1) + c(i1) + * for (i0, 0, 1024) { + * out(i0, i1) = out_0(i1) + a(i0, i1) + * } + * } + * + * Expect output: + * for (cc1, 0, 2) { + * for (cc2, 0, 1024) { + * out_0_local_UB(cc2) = (b_local_UB(cc2) + c_local_UB(cc2)) + * } + * for (cc2, 0, 1024) { + * for (cc3, 0, 16) { + * out_local_UB(cc3, cc2) = (out_0_local_UB(cc2) + a_local_UB(cc3, cc2)) + * } + * } + * } + * + * IR Check: + * count for (b_local_UB + c_local_UB): 2 * 1024 + */ +class AutoPolyTest2 : public AutoPolyTestBase { + public: + AutoPolyTest2() { + Construct(); + } + ~AutoPolyTest2() = default; + void Construct() { + a_ = UTExprBuilder::PlaceholderOpNode("a", {32, 1024}, air::Float(16)); + b_ = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16)); + c_ = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16)); + out_ = UTExprBuilder::PlaceholderOpNode("out", {32, 1024}, air::Float(16)); + out_0_ = UTExprBuilder::PlaceholderOpNode("out_0", {1024}, air::Float(16)); + stmt_ = air::ir::AttrStmt::make( + out_0_, "realize_scope", air::ir::StringImm::make(""), + UTStmtBuilder::CreateRealizeByPlaceholderOp( + out_0_, + air::ir::AttrStmt::make( + out_, "realize_scope", air::ir::StringImm::make(""), + UTStmtBuilder::CreateRealizeByPlaceholderOp( + out_, + air::ir::ProducerConsumer::make(out_, true, + UTStmtBuilder::CreateFor( + "i1", 0, 1024, + air::ir::Block::make( + UTStmtBuilder::CreateProvideBinary( + out_0_, {"i1"}, + UTExprBuilder::ElementOf(b_, {"i1"}), + UTExprBuilder::ElementOf(c_, {"i1"})), + UTStmtBuilder::CreateFor( + "i0", 0, 32, + UTStmtBuilder::CreateProvideBinary( + out_, {"i0", "i1"}, + UTExprBuilder::ElementOf(out_0_, {"i1"}), + UTExprBuilder::ElementOf(a_, {"i0", "i1"})))))))))); + t_a_ = UTExprBuilder::CreateTensorByPlaceholder(a_); + t_b_ = UTExprBuilder::CreateTensorByPlaceholder(b_); + t_c_ = UTExprBuilder::CreateTensorByPlaceholder(c_); + t_out_ = UTExprBuilder::CreateTensorByPlaceholder(out_); + RegisterTensor(t_a_); + RegisterTensor(t_b_); + RegisterTensor(t_c_); + RegisterTensor(t_out_); + } + + air::Operation a_; + air::Operation b_; + air::Operation c_; + air::Tensor t_a_; + air::Tensor t_b_; + air::Tensor t_c_; + air::Operation out_; + air::Tensor t_out_; + air::Operation out_0_; + air::Stmt stmt_; +}; // class AutoPolyTest2 + +TEST_F(AutoPolyTest2, RunPass) { + SetRunMode("cloud"); + air::Array stmts_out = ir::AutoPoly(stmt_, binds_, global_attrs_, false, false); + ASSERT_EQ(stmts_out.size(), 2); + air::NodeRef stmt = stmts_out[0]; + std::vector> infos_lhs = + UTProvideCheckerForBinary(true).Find( + stmt, UTProvideCheckerForBinary::BinaryOpType::kAdd, "b_local_UB", "c_local_UB"); + ASSERT_EQ(infos_lhs.size(), 1); + EXPECT_EQ(std::get<2>(infos_lhs[0]), 2 * 1024); +} + +/* AutoPolyTest3: test for to_three_address + * Input pattern: + * for (i0, 0, 1024) { + * out_0(i0) = b(i0) + c(i0) + * } + * for (i1, 0, 32) { + * for (i0, 0, 1024) { + * out(i0, i1) = out_0(i1) + a(i0, i1) + * } + * } + * + * Expect output: + * for (cc1, 0, 2) { + * for (cc2, 0, 1024) { + * out_0_local_UB(cc2) = (b_local_UB(cc2) + c_local_UB(cc2)) + * } + * for (cc2, 0, 1024) { + * for (cc3, 0, 16) { + * out_local_UB(cc3, cc2) = (out_0_local_UB(cc2) + a_local_UB(cc3, cc2)) + * } + * } + * } + * + * IR Check: + * count for (b_local_UB + c_local_UB): 2 * 1024 + */ +class AutoPolyTest3 : public AutoPolyTestBase { + public: + AutoPolyTest3() { + Construct(); + } + ~AutoPolyTest3() = default; + void Construct() { + a_ = UTExprBuilder::PlaceholderOpNode("a", {32, 1024}, air::Float(16)); + b_ = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16)); + c_ = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16)); + out_ = UTExprBuilder::PlaceholderOpNode("out", {32, 1024}, air::Float(16)); + out_0_ = UTExprBuilder::PlaceholderOpNode("out_0", {1024}, air::Float(16)); + stmt_ = air::ir::AttrStmt::make( + out_0_, "realize_scope", air::ir::StringImm::make(""), + UTStmtBuilder::CreateRealizeByPlaceholderOp( + out_0_, + air::ir::AttrStmt::make( + out_, "realize_scope", air::ir::StringImm::make(""), + UTStmtBuilder::CreateRealizeByPlaceholderOp( + out_, + air::ir::ProducerConsumer::make(out_, true, + air::ir::Block::make( + UTStmtBuilder::CreateFor( + "i0", 0, 1024, + UTStmtBuilder::CreateProvideBinary( + out_0_, {"i0"}, + UTExprBuilder::ElementOf(b_, {"i0"}), + UTExprBuilder::ElementOf(c_, {"i0"}))), + UTStmtBuilder::CreateFor( + "i0", 0, 32, + UTStmtBuilder::CreateFor( + "i1", 0, 1024, + UTStmtBuilder::CreateProvideBinary( + out_, {"i0", "i1"}, + UTExprBuilder::ElementOf(out_0_, {"i1"}), + UTExprBuilder::ElementOf(a_, {"i0", "i1"})))))))))); + t_a_ = UTExprBuilder::CreateTensorByPlaceholder(a_); + t_b_ = UTExprBuilder::CreateTensorByPlaceholder(b_); + t_c_ = UTExprBuilder::CreateTensorByPlaceholder(c_); + t_out_ = UTExprBuilder::CreateTensorByPlaceholder(out_); + RegisterTensor(t_a_); + RegisterTensor(t_b_); + RegisterTensor(t_c_); + RegisterTensor(t_out_); + } + + air::Operation a_; + air::Operation b_; + air::Operation c_; + air::Tensor t_a_; + air::Tensor t_b_; + air::Tensor t_c_; + air::Operation out_; + air::Tensor t_out_; + air::Operation out_0_; + air::Stmt stmt_; +}; // class AutoPolyTest3 + +TEST_F(AutoPolyTest3, RunPass) { + SetRunMode("cloud"); + air::Array stmts_out = ir::AutoPoly(stmt_, binds_, global_attrs_, false, false); + ASSERT_EQ(stmts_out.size(), 2); + air::NodeRef stmt = stmts_out[0]; + std::vector> infos_lhs = + UTProvideCheckerForBinary(true).Find( + stmt, UTProvideCheckerForBinary::BinaryOpType::kAdd, "b_local_UB", "c_local_UB"); + ASSERT_EQ(infos_lhs.size(), 1); + EXPECT_EQ(std::get<2>(infos_lhs[0]), 2 * 1024); +} +} // namespace akg diff --git a/tests/unittest_cpp/src/pass_test/to_three_address_test.cc b/tests/unittest_cpp/src/pass_test/to_three_address_test.cc index d84e102557406031f2ea28459d87fc3dc7b5081e..ee10181cf0597fdf7ba1103403d0c458ce732c01 100644 --- a/tests/unittest_cpp/src/pass_test/to_three_address_test.cc +++ b/tests/unittest_cpp/src/pass_test/to_three_address_test.cc @@ -15,8 +15,10 @@ */ #include #include -#include "base/expr_builder.h" #include "base/dump_helper.h" +#include "base/expr_builder.h" +#include "base/ir_checker.h" +#include "base/stmt_builder.h" #define private public #define protected public #include "pass/to_three_address.cc" @@ -71,4 +73,74 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) { Expr expr_m = mutator_.Mutate(expr); EXPECT_NE(mutator_.imm_ops.size(), 0); } + +class PassTestToThreeAddress1 : public ::testing::Test { + public: + PassTestToThreeAddress1() { + Construct(); + } + ~PassTestToThreeAddress1() = default; + void Construct() { + a_ = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16)); + b_ = UTExprBuilder::PlaceholderOpNode("b", {32, 1024}, air::Float(16)); + c_ = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16)); + out_ = UTExprBuilder::PlaceholderOpNode("out", {32, 1024}, air::Float(16)); + stmt = air::ir::AttrStmt::make( + out_, "", UTExprBuilder::IntImm(1), + UTStmtBuilder::CreateRealizeByPlaceholderOp( + out_, + air::ir::ProducerConsumer::make(out_, true, + UTStmtBuilder::CreateFor( + "i", 0, 32, + UTStmtBuilder::CreateFor( + "j", 0, 1024, + UTStmtBuilder::CreateProvideBinary( + out_, {"i", "j"}, + air::ir::Add::make( + UTExprBuilder::ElementOf(a_, {"j"}), + UTExprBuilder::ElementOf(b_, {"i", "j"})), + UTExprBuilder::ElementOf(c_, {"j"}))))))); + } + + air::Operation a_; + air::Operation b_; + air::Operation c_; + air::Operation out_; + air::Stmt stmt; +}; // class PassTestToThreeAddress1 + +TEST_F(PassTestToThreeAddress1, CaseCheck) { + std::vector> infos_lhs = + UTProvideCheckerForAssign().Find(stmt, "((a(j) + b(i, j)) + c(j))"); + ASSERT_EQ(infos_lhs.size(), 1); + EXPECT_EQ(std::get<0>(infos_lhs[0]), "out(i, j)"); + EXPECT_EQ(std::get<2>(infos_lhs[0]), 32 * 1024); +} + +TEST_F(PassTestToThreeAddress1, TestPass) { + Stmt stmt_out = ir::ToThreeAddress(stmt, false, 0, true); + /* current implementation + * out_2(i, j) = b(i, j) + * out_3(i, j) = (a(j) + out_2(i, j)) + * out(i, j) = (out_3(i, j) + c(j)) + */ + std::vector> info1 = + UTProvideCheckerForAssign().Find(stmt_out, "b(i, j)"); + ASSERT_EQ(info1.size(), 1); + std::string dump_b_target = std::get<0>(info1[0]); + + std::vector> info2 = + UTProvideCheckerForBinary().Find( + stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(j)", dump_b_target); + ASSERT_EQ(info2.size(), 1); + std::string dump_sum1_target = std::get<0>(info2[0]); + EXPECT_EQ(std::get<2>(info2[0]), 32 * 1024); + + std::vector> info3 = + UTProvideCheckerForBinary().Find( + stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, dump_sum1_target, "c(j)"); + ASSERT_EQ(info3.size(), 1); + EXPECT_EQ(std::get<0>(info3[0]), "out(i, j)"); + EXPECT_EQ(std::get<2>(info3[0]), 32 * 1024); +} } // namespace akg diff --git a/tests/unittest_cpp/src/pass_test_base/auto_poly_test_base.cc b/tests/unittest_cpp/src/pass_test_base/auto_poly_test_base.cc new file mode 100644 index 0000000000000000000000000000000000000000..f0696fd3c6a4a72137ec75a29ad5677acb4ce596 --- /dev/null +++ b/tests/unittest_cpp/src/pass_test_base/auto_poly_test_base.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pass_test_base/auto_poly_test_base.h" + +namespace akg { +std::map AutoPolyTestBase::map_mode_ = + AutoPolyTestBase::InitMapMode(); + +std::map AutoPolyTestBase::InitMapMode() { + std::map res; + res["cloud"] = "1.6"; + res["mini"] = "1.1"; + res["phoenix"] = "3.5"; + res["orlando"] = "3.3"; + return res; +} + +void AutoPolyTestBase::SetRunMode(const std::string &mode) { + auto it = map_mode_.find(mode); + CHECK(it != map_mode_.end()); + cceconf::CceConf::getInstance()->setSection(it->second); +} + +void AutoPolyTestBase::RegisterTensor(const air::Tensor &tensor) { + const TensorNode *tensor_node = tensor.as(); + std::string name = tensor_node->op->name; + air::Buffer buf = air::BufferNode::make( + air::Variable::make(Handle(), name), + tensor_node->dtype, + tensor_node->shape, + Array(), + Expr(), + name, + "", + -1, + 0, + air::BufferType::kDefault); + binds_.Set(GetRef(tensor_node), buf); +} +} // namespace akg