提交 20e24526 编写于 作者: L LuoYin

update LLT-UT

support
1. Pass level UT for ToThreeAddress
2. Pass level UT for AutoPoly
  test 3 patterns for ToThreeAddress pass
上级 4712acdc
......@@ -16,6 +16,7 @@ file(
unittest_main.cc
src/base/*.cc
src/base_test/*.cc
src/pass_test_base/*.cc
src/pass_test/*.cc)
link_directories(${CMAKE_BINARY_DIR}/googletest/googlemock/gtest)
......
......@@ -29,7 +29,7 @@ class UTRegxMatch {
static bool RegxMatchHex(const std::string &str);
static const std::string pattern_hex_;
}; // UTRegxMatch
}; // class UTRegxMatch
class UTDumpHelper {
public:
......@@ -38,6 +38,6 @@ class UTDumpHelper {
static std::string Dump(const air::NodeRef &node);
static bool RegxMatchPlaceholder(const std::string &str, const std::string &name);
}; // UTDumpHelper
}; // class UTDumpHelper
} // namespace akg
#endif // UT_BASE_DUMP_HELPER_H_
......@@ -19,6 +19,7 @@
#include <vector>
#include "tvm/expr.h"
#include "tvm/operation.h"
#include "tvm/tensor.h"
namespace akg {
class UTExprBuilder {
......@@ -26,9 +27,15 @@ class UTExprBuilder {
UTExprBuilder() = default;
~UTExprBuilder() = default;
static air::Expr IntImm(int64_t value, air::DataType dtype = air::Int(32));
static air::Expr UIntImm(uint64_t value, air::DataType dtype = air::UInt(32));
static air::Expr BoolImm(bool value);
static air::Array<air::Expr> CreateShape(const std::vector<int32_t> &shapes);
static air::Var CreateVar(const std::string &name);
static air::Array<air::Expr> CreateVars(const std::vector<std::string> &names);
static air::Range CreateRange(int32_t min, int32_t max);
static air::Region CreateRegion(const std::vector<int32_t> &shapes);
static air::Region CreateRegion(const air::Array<air::Expr> &shapes);
static air::Operation PlaceholderOpNode(
const std::string &name,
const std::vector<int32_t> &shapes,
......@@ -38,6 +45,18 @@ class UTExprBuilder {
const std::vector<int32_t> &shapes,
const std::vector<std::string> &axis_names,
air::DataType dtype = air::Float(16));
static air::Expr ElementOf(
const air::Operation &op,
const std::vector<std::string> &axis_names);
static air::Expr ElementOfPlaceholderOp(
const air::Operation &op,
const std::vector<std::string> &axis_names);
static air::Expr CreateCall(
const air::ir::FunctionRef func,
air::Array<air::Expr> args,
air::ir::Call::CallType call_type = air::ir::Call::Halide,
int value_index = 0);
static air::Tensor CreateTensorByPlaceholder(const air::Operation op);
}; // UTExprBuilder
class UTTensorElementHelper {
......@@ -46,14 +65,13 @@ class UTTensorElementHelper {
const std::string &axis_name_prefix = "ax");
~UTTensorElementHelper() = default;
air::Expr Elem(const std::string &name,
uint32_t dim,
air::DataType dtype = air::Float(16)) const;
uint32_t dim,
air::DataType dtype = air::Float(16)) const;
private:
std::vector<int32_t> shapes_;
std::string axis_name_prefix_;
std::vector<std::string> axis_names_;
}; // UTTensorElementHelper
}; // class UTTensorElementHelper
} // namespace akg
#endif // UT_BASE_EXPR_BUILDER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UT_IR_CHECKER_H_
#define UT_IR_CHECKER_H_
#include <string>
#include <tuple>
#include <vector>
#include <tvm/ir_visitor.h>
#include "base/dump_helper.h"
#include "base/expr_builder.h"
namespace akg {
class UTIRCheckHelper {
public:
UTIRCheckHelper() = default;
~UTIRCheckHelper() = default;
static int64_t GetValueFromImm(const air::Expr &expr);
}; // class UTIRCheckHelper
class UTProvideChecker : public air::ir::IRVisitor {
public:
explicit UTProvideChecker(bool ignore_args = false)
: ignore_args_(ignore_args) {}
~UTProvideChecker() = default;
void Visit_(const air::ir::For *op) override;
bool CompareDump(const std::string &dump, const std::string &target);
protected:
bool ignore_args_{false};
std::vector<uint64_t> for_count_stack_;
}; // class UTProvideChecker
class UTProvideCheckerForAssign : public UTProvideChecker {
public:
explicit UTProvideCheckerForAssign(bool ignore_args = false)
: UTProvideChecker(ignore_args) {}
~UTProvideCheckerForAssign() = default;
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> Find(
const air::NodeRef &node,
const std::string &dump_rhs);
void Visit_(const air::ir::Provide *op) override;
private:
std::string dump_rhs_{""};
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs_;
}; // class UTProvideChecker
class UTProvideCheckerForBinary : public UTProvideChecker {
public:
enum BinaryOpType : int {
kAdd,
kSub,
kMul,
kDiv,
kMod,
};
explicit UTProvideCheckerForBinary(bool ignore_args = false)
: UTProvideChecker(ignore_args) {}
~UTProvideCheckerForBinary() = default;
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> Find(
const air::NodeRef &node,
BinaryOpType op_type,
const std::string &dump_rhs1,
const std::string &dump_rhs2);
void Visit_(const air::ir::Provide *op) override;
template <typename T>
void CheckBinary(const air::ir::Provide *op) {
const T *expr_binary = op->value.as<T>();
if (expr_binary == nullptr) {
return;
}
std::string dump_expr_a = UTDumpHelper::Dump(expr_binary->a);
std::string dump_expr_b = UTDumpHelper::Dump(expr_binary->b);
if ((dump_rhs1_.empty() || CompareDump(dump_expr_a, dump_rhs1_)) &&
(dump_rhs2_.empty() || CompareDump(dump_expr_b, dump_rhs2_))) {
air::Expr expr_call = UTExprBuilder::CreateCall(op->func, op->args);
infos_lhs_.push_back(std::make_tuple(UTDumpHelper::Dump(expr_call), op, for_count_stack_.back()));
}
}
private:
BinaryOpType op_type_;
std::string dump_rhs1_{""};
std::string dump_rhs2_{""};
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs_;
}; // class UTProvideCheckerForBinary
} // namespace akg
#endif // UT_IR_CHECKER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UT_BASE_STMT_BUILDER_H_
#define UT_BASE_STMT_BUILDER_H_
#include <list>
#include <string>
#include <vector>
#include "tvm/ir.h"
#include "base/expr_builder.h"
namespace akg {
class UTStmtBuilder {
public:
UTStmtBuilder() = default;
~UTStmtBuilder() = default;
static air::Stmt CreateFor(
const std::string &loop_var_name,
int32_t min,
int32_t extent,
air::Stmt body);
static air::Stmt CreateRealizeByPlaceholderOp(
const air::Operation &op,
air::Stmt body);
static air::Stmt CreateProvideAssign(
air::ir::FunctionRef func_dst,
const std::vector<std::string> &vars,
air::Expr src,
int value_index = 0);
template <typename T>
static air::Stmt CreateProvideBinary(
air::ir::FunctionRef func_dst,
const std::vector<std::string> &vars,
air::Expr src1,
air::Expr src2,
int value_index = 0) {
return air::ir::Provide::make(
func_dst,
value_index,
T::make(src1, src2),
UTExprBuilder::CreateVars(vars));
}
}; // class UTStmtBuilder
} // namespace akg
#endif // UT_BASE_STMT_BUILDER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UT_AUTO_POLY_TEST_BASE_H_
#define UT_AUTO_POLY_TEST_BASE_H_
#include <map>
#include <string>
#include <gtest/gtest.h>
#include <tvm/expr.h>
#include "codegen/util.h"
#include "contrib/cce_parm/cceconf.h"
#include "base/expr_builder.h"
namespace akg {
class AutoPolyTestBase : public ::testing::Test {
public:
AutoPolyTestBase() = default;
~AutoPolyTestBase() = default;
static std::map<std::string, std::string> InitMapMode();
void RegisterTensor(const air::Tensor &tensor);
void SetRunMode(const std::string &mode);
void GlobalAttrSetIsDynamic(bool arg) {
global_attrs_.Set("is_dynamic", arg ? UTExprBuilder::IntImm(1, air::Int(1)) :
UTExprBuilder::IntImm(0, air::Int(1)));
}
void GlobalAttrSetDynamic(bool arg) {
global_attrs_.Set("dynamic", arg ? UTExprBuilder::IntImm(1, air::Int(1)) :
UTExprBuilder::IntImm(0, air::Int(1)));
}
void GlobalAttrSetDumpPassIR(bool arg) {
global_attrs_.Set("dump_pass_ir", arg ? UTExprBuilder::IntImm(1, air::Int(1)) :
UTExprBuilder::IntImm(0, air::Int(1)));
}
void GlobalAttrSetDumpPolyDir(const std::string &path) {
global_attrs_.Set("dump_poly_dir", air::ir::StringImm::make(path));
}
void GlobalAttrSetKernalName(const std::string &name) {
global_attrs_.Set("kernel_name", air::ir::StringImm::make(name));
}
static std::map<std::string, std::string> map_mode_;
protected:
air::Map<air::Tensor, air::Buffer> binds_;
AttrMap global_attrs_;
}; // class AutoPolyTestBase
} // namespace akg
#endif
......@@ -14,9 +14,22 @@
* limitations under the License.
*/
#include <sstream>
#include <tvm/operation.h>
#include "base/expr_builder.h"
namespace akg {
air::Expr UTExprBuilder::IntImm(int64_t value, air::DataType dtype) {
return air::IntImm::make(dtype, value);
}
air::Expr UTExprBuilder::UIntImm(uint64_t value, air::DataType dtype) {
return air::ir::UIntImm::make(dtype, value);
}
air::Expr UTExprBuilder::BoolImm(bool value) {
return air::ir::UIntImm::make(air::Bool(), value ? 1 : 0);
}
air::Array<air::Expr> UTExprBuilder::CreateShape(const std::vector<int32_t> &shapes) {
air::Array<air::Expr> res;
for (int32_t shape : shapes) {
......@@ -38,6 +51,28 @@ air::Array<air::Expr> UTExprBuilder::CreateVars(const std::vector<std::string> &
return vars;
}
air::Region UTExprBuilder::CreateRegion(const std::vector<int32_t> &shapes) {
air::Region region;
for (int32_t shape : shapes) {
region.push_back(CreateRange(0, shape));
}
return region;
}
air::Region UTExprBuilder::CreateRegion(const air::Array<air::Expr> &shapes) {
air::Region region;
for (const air::Expr &shape : shapes) {
region.push_back(air::Range::make_by_min_extent(IntImm(0), shape));
}
return region;
}
air::Range UTExprBuilder::CreateRange(int32_t min, int32_t max) {
air::Integer imm_min = air::IntImm::make(air::Int(32), min);
air::Integer imm_max = air::IntImm::make(air::Int(32), max);
return air::Range(std::move(imm_min), std::move(imm_max));
}
air::Operation UTExprBuilder::PlaceholderOpNode(
const std::string &name,
const std::vector<int32_t> &shapes,
......@@ -60,6 +95,56 @@ air::Expr UTExprBuilder::TensorElement(
0); // value_index
}
air::Expr UTExprBuilder::ElementOf(
const air::Operation &op,
const std::vector<std::string> &axis_names) {
if (op->template IsInstance<air::PlaceholderOpNode>()) {
return ElementOfPlaceholderOp(op, axis_names);
} else {
CHECK(false);
return air::ir::Any::make();
}
}
air::Expr UTExprBuilder::ElementOfPlaceholderOp(
const air::Operation &op,
const std::vector<std::string> &axis_names) {
const air::PlaceholderOpNode *node = op.as<const air::PlaceholderOpNode>();
CHECK(node);
return air::ir::Call::make(
node->dtype,
node->name,
CreateVars(axis_names),
air::ir::Call::Halide,
op,
0);
}
air::Expr UTExprBuilder::CreateCall(
const air::ir::FunctionRef func,
air::Array<air::Expr> args,
air::ir::Call::CallType call_type,
int value_index) {
air::DataType type = air::Float(16);
const air::OperationNode *node_op = func.as<air::OperationNode>();
CHECK(node_op);
std::string name = node_op->name;
const air::PlaceholderOpNode *node_placeholder = func.as<air::PlaceholderOpNode>();
if (node_placeholder != nullptr) {
type = node_placeholder->dtype;
}
return air::ir::Call::make(type, name, args, call_type, func, value_index);
}
air::Tensor UTExprBuilder::CreateTensorByPlaceholder(const air::Operation op) {
const air::PlaceholderOpNode *node = op.as<air::PlaceholderOpNode>();
CHECK(node);
return air::TensorNode::make(
node->shape,
node->dtype,
op,
0);
}
UTTensorElementHelper::UTTensorElementHelper(const std::vector<int32_t> &shapes,
const std::string &axis_name_prefix)
: shapes_(shapes), axis_name_prefix_(axis_name_prefix) {
......@@ -72,8 +157,8 @@ UTTensorElementHelper::UTTensorElementHelper(const std::vector<int32_t> &shapes,
}
air::Expr UTTensorElementHelper::Elem(const std::string &name,
uint32_t dim,
air::DataType dtype) const {
uint32_t dim,
air::DataType dtype) const {
uint32_t start = shapes_.size() - dim;
return UTExprBuilder::TensorElement(
name,
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "base/ir_checker.h"
#include <cinttypes>
#include <string>
#include "base/dump_helper.h"
#include "base/expr_builder.h"
namespace akg {
int64_t UTIRCheckHelper::GetValueFromImm(const air::Expr &expr) {
const air::IntImm *imm_int = expr.as<air::IntImm>();
if (imm_int != nullptr) {
return imm_int->value;
}
const air::ir::UIntImm *imm_uint = expr.as<air::ir::UIntImm>();
if (imm_uint != nullptr) {
CHECK(imm_uint->value < INT64_MAX);
return static_cast<int64_t>(imm_uint->value);
}
return 0;
}
void UTProvideChecker::Visit_(const air::ir::For *op) {
uint64_t count_top = for_count_stack_.back();
int64_t min = UTIRCheckHelper::GetValueFromImm(op->min);
int64_t extent = UTIRCheckHelper::GetValueFromImm(op->extent);
CHECK(extent > min);
count_top *= static_cast<uint64_t>(extent);
for_count_stack_.push_back(count_top);
IRVisitor::Visit_(op);
for_count_stack_.pop_back();
}
bool UTProvideChecker::CompareDump(
const std::string &dump,
const std::string &target) {
if (dump.compare(target) == 0) {
return true;
}
if (ignore_args_) {
size_t npos = dump.find("(");
return dump.substr(0, npos).compare(target) == 0;
}
return false;
}
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> UTProvideCheckerForAssign::Find(
const air::NodeRef &node,
const std::string &dump_rhs) {
dump_rhs_ = dump_rhs;
infos_lhs_.clear();
for_count_stack_.clear();
for_count_stack_.push_back(1);
Visit(node);
return infos_lhs_;
}
void UTProvideCheckerForAssign::Visit_(const air::ir::Provide *op) {
std::string dump_expr = UTDumpHelper::Dump(op->value);
if (CompareDump(dump_expr, dump_rhs_)) {
air::Expr expr_call = UTExprBuilder::CreateCall(op->func, op->args);
infos_lhs_.push_back(std::make_tuple(UTDumpHelper::Dump(expr_call), op, for_count_stack_.back()));
}
}
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> UTProvideCheckerForBinary::Find(
const air::NodeRef &node,
UTProvideCheckerForBinary::BinaryOpType op_type,
const std::string &dump_rhs1,
const std::string &dump_rhs2) {
op_type_ = op_type;
dump_rhs1_ = dump_rhs1;
dump_rhs2_ = dump_rhs2;
infos_lhs_.clear();
for_count_stack_.clear();
for_count_stack_.push_back(1);
if (dump_rhs1_.empty() && dump_rhs2_.empty()) {
return infos_lhs_;
}
Visit(node);
return infos_lhs_;
}
void UTProvideCheckerForBinary::Visit_(const air::ir::Provide *op) {
switch (op_type_) {
case BinaryOpType::kAdd:
CheckBinary<air::ir::Add>(op);
break;
case BinaryOpType::kSub:
CheckBinary<air::ir::Sub>(op);
break;
case BinaryOpType::kMul:
CheckBinary<air::ir::Mul>(op);
break;
case BinaryOpType::kDiv:
CheckBinary<air::ir::Div>(op);
break;
case BinaryOpType::kMod:
CheckBinary<air::ir::Mod>(op);
break;
default:
break;
}
}
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "base/stmt_builder.h"
namespace akg {
air::Stmt UTStmtBuilder::CreateFor(
const std::string &loop_var_name,
int32_t min,
int32_t extent,
air::Stmt body) {
return air::ir::For::make(
UTExprBuilder::CreateVar(loop_var_name),
UTExprBuilder::IntImm(min),
UTExprBuilder::IntImm(extent),
air::ir::ForType::Serial,
air::ir::DeviceAPI::None,
body);
}
air::Stmt UTStmtBuilder::CreateRealizeByPlaceholderOp(
const air::Operation &op,
air::Stmt body) {
const air::PlaceholderOpNode *node = op.as<const air::PlaceholderOpNode>();
CHECK(node);
return air::ir::Realize::make(
op,
0,
node->dtype,
UTExprBuilder::CreateRegion(node->shape),
UTExprBuilder::BoolImm(true),
body);
}
air::Stmt UTStmtBuilder::CreateProvideAssign(
air::ir::FunctionRef func_dst,
const std::vector<std::string> &vars,
air::Expr src,
int value_index) {
return air::ir::Provide::make(
func_dst,
value_index,
src,
UTExprBuilder::CreateVars(vars));
}
} // namespace akg
......@@ -18,6 +18,39 @@
#include "base/expr_builder.h"
namespace akg {
TEST(UTExprBuilder, IntImm) {
air::Expr int1 = UTExprBuilder::IntImm(1024);
std::string dump_int1 = UTDumpHelper::Dump(int1);
EXPECT_EQ(dump_int1, "1024");
air::Expr int2 = UTExprBuilder::IntImm(1024, air::Int(64));
std::string dump_int2 = UTDumpHelper::Dump(int2);
EXPECT_EQ(dump_int2, "(int64)1024");
air::Expr int3 = UTExprBuilder::IntImm(1024, air::Int(16));
std::string dump_int3 = UTDumpHelper::Dump(int3);
EXPECT_EQ(dump_int3, "(int16)1024");
}
TEST(UTExprBuilder, UIntImm) {
air::Expr uint1 = UTExprBuilder::UIntImm(1024);
std::string dump_uint1 = UTDumpHelper::Dump(uint1);
EXPECT_EQ(dump_uint1, "(uint32)1024");
air::Expr uint2 = UTExprBuilder::UIntImm(1024, air::UInt(64));
std::string dump_uint2 = UTDumpHelper::Dump(uint2);
EXPECT_EQ(dump_uint2, "(uint64)1024");
air::Expr uint3 = UTExprBuilder::UIntImm(1024, air::UInt(16));
std::string dump_uint3 = UTDumpHelper::Dump(uint3);
EXPECT_EQ(dump_uint3, "(uint16)1024");
}
TEST(UTExprBuilder, Bool) {
air::Expr bool_true = UTExprBuilder::BoolImm(true);
std::string dump_bool_true = UTDumpHelper::Dump(bool_true);
EXPECT_EQ(dump_bool_true, "(bool)1");
air::Expr bool_false = UTExprBuilder::BoolImm(false);
std::string dump_bool_false = UTDumpHelper::Dump(bool_false);
EXPECT_EQ(dump_bool_false, "(bool)0");
}
TEST(UTExprBuilder, CreateShape) {
air::Array<air::Expr> shape1 = UTExprBuilder::CreateShape({1024});
std::string dump_shape1 = UTDumpHelper::Dump(shape1);
......@@ -44,6 +77,12 @@ TEST(UTExprBuilder, CreateVars) {
EXPECT_EQ(dump_vars, "[ax0, ax1, ax2]");
}
TEST(UTExprBuilder, CreateRange) {
air::Range range = UTExprBuilder::CreateRange(0, 1024);
std::string dump_range = UTDumpHelper::Dump(range);
EXPECT_EQ(dump_range, "range(min=0, ext=1024)");
}
TEST(UTExprBuilder, PlaceholderOpNode) {
air::Operation node = UTExprBuilder::PlaceholderOpNode("input", {16, 32, 1024}, air::Float(16));
std::string dump_node = UTDumpHelper::Dump(node);
......@@ -56,6 +95,13 @@ TEST(UTExprBuilder, TensorElement) {
EXPECT_EQ(dump_elem, "input(ax0, ax1, ax2)");
}
TEST(UTExprBuilder, ElememtOfPlaceholderOp) {
air::Operation op = UTExprBuilder::PlaceholderOpNode("input", {16, 32, 1024}, air::Float(16));
air::Expr elem = UTExprBuilder::ElementOfPlaceholderOp(op, {"ax0", "ax1", "ax2"});
std::string dump_elem = UTDumpHelper::Dump(elem);
EXPECT_EQ(dump_elem, "input(ax0, ax1, ax2)");
}
TEST(UTTensorElementHelper, TensorElement) {
UTTensorElementHelper helper({16, 32, 1024});
std::string dump_elem1 = UTDumpHelper::Dump(helper.Elem("a", 3));
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <string>
#include "base/ir_checker.h"
#include "base/expr_builder.h"
#include "base/stmt_builder.h"
namespace akg {
TEST(UTProvideChecker, CompareDump) {
EXPECT_EQ(UTProvideChecker().CompareDump("a(i, j)", "a"), false);
EXPECT_EQ(UTProvideChecker().CompareDump("a(i, j)", "a(i, j)"), true);
EXPECT_EQ(UTProvideChecker(true).CompareDump("a(i, j)", "a"), true);
EXPECT_EQ(UTProvideChecker(true).CompareDump("a(i, j)", "a(i, j)"), true);
}
class UTProvideCheckerTest : public testing::Test {
public:
UTProvideCheckerTest()
: a_(UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16))),
b_(UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16))),
c_(UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16))) {}
~UTProvideCheckerTest() = default;
air::Operation a_;
air::Operation b_;
air::Operation c_;
}; // class UTProvideCheckerTest
TEST_F(UTProvideCheckerTest, UTProvideCheckerForAssign) {
// b(ax0) = a(ax0)
air::Stmt stmt = UTStmtBuilder::CreateProvideAssign(
b_, {"ax0"}, UTExprBuilder::ElementOf(a_, {"ax0"}));
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs =
UTProvideCheckerForAssign().Find(stmt, "a(ax0)");
ASSERT_EQ(infos_lhs.size(), 1);
EXPECT_EQ(std::get<0>(infos_lhs[0]), "b(ax0)");
EXPECT_EQ(std::get<2>(infos_lhs[0]), 1);
}
TEST_F(UTProvideCheckerTest, UTProvideCheckerForBinary) {
// c(ax0) = (a(ax0) + b(ax0))
air::Stmt stmt = UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
c_, {"ax0"},
UTExprBuilder::ElementOf(a_, {"ax0"}),
UTExprBuilder::ElementOf(b_, {"ax0"}));
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs =
UTProvideCheckerForBinary().Find(stmt, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(ax0)", "b(ax0)");
ASSERT_EQ(infos_lhs.size(), 1);
EXPECT_EQ(std::get<0>(infos_lhs[0]), "c(ax0)");
EXPECT_EQ(std::get<2>(infos_lhs[0]), 1);
}
class UTProvideCheckerTest2 : public testing::Test {
public:
UTProvideCheckerTest2()
: a_(UTExprBuilder::PlaceholderOpNode("a", {16, 32, 1024}, air::Float(16))),
b_(UTExprBuilder::PlaceholderOpNode("b", {16, 32, 1024}, air::Float(16))),
c_(UTExprBuilder::PlaceholderOpNode("c", {16, 32, 1024}, air::Float(16))) {}
~UTProvideCheckerTest2() = default;
air::Operation a_;
air::Operation b_;
air::Operation c_;
}; // class UTProvideCheckerTest
TEST_F(UTProvideCheckerTest2, UTProvideCheckerForBinary) {
air::Stmt stmt = UTStmtBuilder::CreateFor(
"i", 0, 16,
UTStmtBuilder::CreateFor(
"j", 0, 32,
UTStmtBuilder::CreateFor(
"k", 0, 1024,
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
c_, {"i", "j", "k"},
UTExprBuilder::ElementOf(a_, {"i", "j", "k"}),
UTExprBuilder::ElementOf(b_, {"i", "j", "k"})))));
std::string dump_stmt = UTDumpHelper::Dump(stmt);
EXPECT_EQ(dump_stmt,
"for (i, 0, 16) {\n"
" for (j, 0, 32) {\n"
" for (k, 0, 1024) {\n"
" c(i, j, k) = (a(i, j, k) + b(i, j, k))\n"
" }\n"
" }\n"
"}\n");
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs =
UTProvideCheckerForBinary().Find(
stmt, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(i, j, k)", "b(i, j, k)");
ASSERT_EQ(infos_lhs.size(), 1);
EXPECT_EQ(std::get<0>(infos_lhs[0]), "c(i, j, k)");
EXPECT_EQ(std::get<2>(infos_lhs[0]), 1024 * 32 * 16);
}
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gtest/gtest.h"
#include "base/dump_helper.h"
#include "base/expr_builder.h"
#include "base/stmt_builder.h"
namespace akg {
TEST(UTStmtBuilder, CreateProvideAssign) {
// b(ax0) = a(ax0)
air::Operation a = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16));
air::Operation b = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16));
air::Stmt stmt = UTStmtBuilder::CreateProvideAssign(
b, {"ax0"}, UTExprBuilder::ElementOf(a, {"ax0"}));
std::string dump_stmt = UTDumpHelper::Dump(stmt);
EXPECT_EQ(dump_stmt, "b(ax0) = a(ax0)\n");
}
TEST(UTStmtBuilder, CreateProvideBinary) {
// c(ax0) = a(ax0) + b(ax0)
air::Operation a = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16));
air::Operation b = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16));
air::Operation c = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16));
air::Stmt stmt = UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
c, {"ax0"}, UTExprBuilder::ElementOf(a, {"ax0"}), UTExprBuilder::ElementOf(b, {"ax0"}));
std::string dump_stmt = UTDumpHelper::Dump(stmt);
EXPECT_EQ(dump_stmt, "c(ax0) = (a(ax0) + b(ax0))\n");
}
TEST(UTStmtBuilder, CreateFor) {
/*
* for (i, 0, 1024) {
* c(i) = (a(i) + b(i))
* }
*/
air::Operation a = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16));
air::Operation b = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16));
air::Operation c = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16));
air::Stmt stmt_for = UTStmtBuilder::CreateFor(
"i", 0, 1024,
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
c, {"i"}, UTExprBuilder::ElementOf(a, {"i"}), UTExprBuilder::ElementOf(b, {"i"})));
std::string dump_stmt_for = UTDumpHelper::Dump(stmt_for);
EXPECT_EQ(dump_stmt_for,
"for (i, 0, 1024) {\n"
" c(i) = (a(i) + b(i))\n"
"}\n");
}
TEST(UTStmtBuilder, CreateRealizeByPlaceholderOp) {
air::Operation a = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16));
air::Operation b = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16));
air::Operation c = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16));
air::Stmt stmt_realize = UTStmtBuilder::CreateRealizeByPlaceholderOp(
c,
UTStmtBuilder::CreateFor(
"i", 0, 1024,
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
c, {"i"}, UTExprBuilder::ElementOf(a, {"i"}), UTExprBuilder::ElementOf(b, {"i"}))));
std::string dump_stmt_realize = UTDumpHelper::Dump(stmt_realize);
EXPECT_EQ(dump_stmt_realize,
"realize c<float16>([0, 1024]) {\n"
" for (i, 0, 1024) {\n"
" c(i) = (a(i) + b(i))\n"
" }\n"
"}\n");
}
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <tvm/ir.h>
#include "base/dump_helper.h"
#include "base/expr_builder.h"
#include "base/ir_checker.h"
#include "base/stmt_builder.h"
#include "pass_test_base/auto_poly_test_base.h"
#define private public
#define protected public
#include "ir_pass.h"
#undef protected
#undef private
#include "codegen/util.h"
#include "contrib/cce_parm/cceconf.h"
namespace akg {
/* AutoPolyTest1: test for to_three_address
* Input pattern:
* for (i0, 0, 32) {
* for (i1, 0, 1024) {
* out_0(i1) = b(i1) + c(i1)
* out(i0, i1) = out_0(i1) + a(i0, i1)
* }
* }
*
* Expect output:
* for (cc1, 0, 2) {
* for (cc2, 0, 16) {
* for (cc3, 0, 1024) {
* out_0_local_UB(cc2) = (b_local_UB(cc2) + c_local_UB(cc2))
* out_local_UB(cc3, cc2) = (out_0_local_UB(cc2) + a_local_UB(cc3, cc2))
* }
* }
* }
*
* IR Check:
* count for (b_local_UB + c_local_UB): 32 * 1024
*/
class AutoPolyTest1 : public AutoPolyTestBase {
public:
AutoPolyTest1() {
Construct();
}
~AutoPolyTest1() = default;
void Construct() {
a_ = UTExprBuilder::PlaceholderOpNode("a", {32, 1024}, air::Float(16));
b_ = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16));
c_ = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16));
out_ = UTExprBuilder::PlaceholderOpNode("out", {32, 1024}, air::Float(16));
out_0_ = UTExprBuilder::PlaceholderOpNode("out_0", {1024}, air::Float(16));
stmt_ = air::ir::AttrStmt::make(
out_0_, "realize_scope", air::ir::StringImm::make(""),
UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_0_,
air::ir::AttrStmt::make(
out_, "realize_scope", air::ir::StringImm::make(""),
UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_,
air::ir::ProducerConsumer::make(out_, true,
UTStmtBuilder::CreateFor(
"i0", 0, 32,
UTStmtBuilder::CreateFor(
"i1", 0, 1024,
air::ir::Block::make(
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
out_0_, {"i1"},
UTExprBuilder::ElementOf(b_, {"i1"}),
UTExprBuilder::ElementOf(c_, {"i1"})),
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
out_, {"i0", "i1"},
UTExprBuilder::ElementOf(out_0_, {"i1"}),
UTExprBuilder::ElementOf(a_, {"i0", "i1"}))))))))));
t_a_ = UTExprBuilder::CreateTensorByPlaceholder(a_);
t_b_ = UTExprBuilder::CreateTensorByPlaceholder(b_);
t_c_ = UTExprBuilder::CreateTensorByPlaceholder(c_);
t_out_ = UTExprBuilder::CreateTensorByPlaceholder(out_);
RegisterTensor(t_a_);
RegisterTensor(t_b_);
RegisterTensor(t_c_);
RegisterTensor(t_out_);
}
air::Operation a_;
air::Operation b_;
air::Operation c_;
air::Tensor t_a_;
air::Tensor t_b_;
air::Tensor t_c_;
air::Operation out_;
air::Tensor t_out_;
air::Operation out_0_;
air::Stmt stmt_;
}; // class AutoPolyTest1
TEST_F(AutoPolyTest1, RunPass) {
SetRunMode("cloud");
air::Array<air::NodeRef> stmts_out = ir::AutoPoly(stmt_, binds_, global_attrs_, false, false);
ASSERT_EQ(stmts_out.size(), 2);
air::NodeRef stmt = stmts_out[0];
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs =
UTProvideCheckerForBinary(true).Find(
stmt, UTProvideCheckerForBinary::BinaryOpType::kAdd, "b_local_UB", "c_local_UB");
ASSERT_EQ(infos_lhs.size(), 1);
EXPECT_EQ(std::get<2>(infos_lhs[0]), 2 * 16 * 1024);
}
/* AutoPolyTest2: test for to_three_address
* Input pattern:
* for (i1, 0, 32) {
* out_0(i1) = b(i1) + c(i1)
* for (i0, 0, 1024) {
* out(i0, i1) = out_0(i1) + a(i0, i1)
* }
* }
*
* Expect output:
* for (cc1, 0, 2) {
* for (cc2, 0, 1024) {
* out_0_local_UB(cc2) = (b_local_UB(cc2) + c_local_UB(cc2))
* }
* for (cc2, 0, 1024) {
* for (cc3, 0, 16) {
* out_local_UB(cc3, cc2) = (out_0_local_UB(cc2) + a_local_UB(cc3, cc2))
* }
* }
* }
*
* IR Check:
* count for (b_local_UB + c_local_UB): 2 * 1024
*/
class AutoPolyTest2 : public AutoPolyTestBase {
public:
AutoPolyTest2() {
Construct();
}
~AutoPolyTest2() = default;
void Construct() {
a_ = UTExprBuilder::PlaceholderOpNode("a", {32, 1024}, air::Float(16));
b_ = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16));
c_ = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16));
out_ = UTExprBuilder::PlaceholderOpNode("out", {32, 1024}, air::Float(16));
out_0_ = UTExprBuilder::PlaceholderOpNode("out_0", {1024}, air::Float(16));
stmt_ = air::ir::AttrStmt::make(
out_0_, "realize_scope", air::ir::StringImm::make(""),
UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_0_,
air::ir::AttrStmt::make(
out_, "realize_scope", air::ir::StringImm::make(""),
UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_,
air::ir::ProducerConsumer::make(out_, true,
UTStmtBuilder::CreateFor(
"i1", 0, 1024,
air::ir::Block::make(
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
out_0_, {"i1"},
UTExprBuilder::ElementOf(b_, {"i1"}),
UTExprBuilder::ElementOf(c_, {"i1"})),
UTStmtBuilder::CreateFor(
"i0", 0, 32,
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
out_, {"i0", "i1"},
UTExprBuilder::ElementOf(out_0_, {"i1"}),
UTExprBuilder::ElementOf(a_, {"i0", "i1"}))))))))));
t_a_ = UTExprBuilder::CreateTensorByPlaceholder(a_);
t_b_ = UTExprBuilder::CreateTensorByPlaceholder(b_);
t_c_ = UTExprBuilder::CreateTensorByPlaceholder(c_);
t_out_ = UTExprBuilder::CreateTensorByPlaceholder(out_);
RegisterTensor(t_a_);
RegisterTensor(t_b_);
RegisterTensor(t_c_);
RegisterTensor(t_out_);
}
air::Operation a_;
air::Operation b_;
air::Operation c_;
air::Tensor t_a_;
air::Tensor t_b_;
air::Tensor t_c_;
air::Operation out_;
air::Tensor t_out_;
air::Operation out_0_;
air::Stmt stmt_;
}; // class AutoPolyTest2
TEST_F(AutoPolyTest2, RunPass) {
SetRunMode("cloud");
air::Array<air::NodeRef> stmts_out = ir::AutoPoly(stmt_, binds_, global_attrs_, false, false);
ASSERT_EQ(stmts_out.size(), 2);
air::NodeRef stmt = stmts_out[0];
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs =
UTProvideCheckerForBinary(true).Find(
stmt, UTProvideCheckerForBinary::BinaryOpType::kAdd, "b_local_UB", "c_local_UB");
ASSERT_EQ(infos_lhs.size(), 1);
EXPECT_EQ(std::get<2>(infos_lhs[0]), 2 * 1024);
}
/* AutoPolyTest3: test for to_three_address
* Input pattern:
* for (i0, 0, 1024) {
* out_0(i0) = b(i0) + c(i0)
* }
* for (i1, 0, 32) {
* for (i0, 0, 1024) {
* out(i0, i1) = out_0(i1) + a(i0, i1)
* }
* }
*
* Expect output:
* for (cc1, 0, 2) {
* for (cc2, 0, 1024) {
* out_0_local_UB(cc2) = (b_local_UB(cc2) + c_local_UB(cc2))
* }
* for (cc2, 0, 1024) {
* for (cc3, 0, 16) {
* out_local_UB(cc3, cc2) = (out_0_local_UB(cc2) + a_local_UB(cc3, cc2))
* }
* }
* }
*
* IR Check:
* count for (b_local_UB + c_local_UB): 2 * 1024
*/
class AutoPolyTest3 : public AutoPolyTestBase {
public:
AutoPolyTest3() {
Construct();
}
~AutoPolyTest3() = default;
void Construct() {
a_ = UTExprBuilder::PlaceholderOpNode("a", {32, 1024}, air::Float(16));
b_ = UTExprBuilder::PlaceholderOpNode("b", {1024}, air::Float(16));
c_ = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16));
out_ = UTExprBuilder::PlaceholderOpNode("out", {32, 1024}, air::Float(16));
out_0_ = UTExprBuilder::PlaceholderOpNode("out_0", {1024}, air::Float(16));
stmt_ = air::ir::AttrStmt::make(
out_0_, "realize_scope", air::ir::StringImm::make(""),
UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_0_,
air::ir::AttrStmt::make(
out_, "realize_scope", air::ir::StringImm::make(""),
UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_,
air::ir::ProducerConsumer::make(out_, true,
air::ir::Block::make(
UTStmtBuilder::CreateFor(
"i0", 0, 1024,
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
out_0_, {"i0"},
UTExprBuilder::ElementOf(b_, {"i0"}),
UTExprBuilder::ElementOf(c_, {"i0"}))),
UTStmtBuilder::CreateFor(
"i0", 0, 32,
UTStmtBuilder::CreateFor(
"i1", 0, 1024,
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
out_, {"i0", "i1"},
UTExprBuilder::ElementOf(out_0_, {"i1"}),
UTExprBuilder::ElementOf(a_, {"i0", "i1"}))))))))));
t_a_ = UTExprBuilder::CreateTensorByPlaceholder(a_);
t_b_ = UTExprBuilder::CreateTensorByPlaceholder(b_);
t_c_ = UTExprBuilder::CreateTensorByPlaceholder(c_);
t_out_ = UTExprBuilder::CreateTensorByPlaceholder(out_);
RegisterTensor(t_a_);
RegisterTensor(t_b_);
RegisterTensor(t_c_);
RegisterTensor(t_out_);
}
air::Operation a_;
air::Operation b_;
air::Operation c_;
air::Tensor t_a_;
air::Tensor t_b_;
air::Tensor t_c_;
air::Operation out_;
air::Tensor t_out_;
air::Operation out_0_;
air::Stmt stmt_;
}; // class AutoPolyTest3
TEST_F(AutoPolyTest3, RunPass) {
SetRunMode("cloud");
air::Array<air::NodeRef> stmts_out = ir::AutoPoly(stmt_, binds_, global_attrs_, false, false);
ASSERT_EQ(stmts_out.size(), 2);
air::NodeRef stmt = stmts_out[0];
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs =
UTProvideCheckerForBinary(true).Find(
stmt, UTProvideCheckerForBinary::BinaryOpType::kAdd, "b_local_UB", "c_local_UB");
ASSERT_EQ(infos_lhs.size(), 1);
EXPECT_EQ(std::get<2>(infos_lhs[0]), 2 * 1024);
}
} // namespace akg
......@@ -15,8 +15,10 @@
*/
#include <gtest/gtest.h>
#include <tvm/ir.h>
#include "base/expr_builder.h"
#include "base/dump_helper.h"
#include "base/expr_builder.h"
#include "base/ir_checker.h"
#include "base/stmt_builder.h"
#define private public
#define protected public
#include "pass/to_three_address.cc"
......@@ -71,4 +73,74 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) {
Expr expr_m = mutator_.Mutate(expr);
EXPECT_NE(mutator_.imm_ops.size(), 0);
}
class PassTestToThreeAddress1 : public ::testing::Test {
public:
PassTestToThreeAddress1() {
Construct();
}
~PassTestToThreeAddress1() = default;
void Construct() {
a_ = UTExprBuilder::PlaceholderOpNode("a", {1024}, air::Float(16));
b_ = UTExprBuilder::PlaceholderOpNode("b", {32, 1024}, air::Float(16));
c_ = UTExprBuilder::PlaceholderOpNode("c", {1024}, air::Float(16));
out_ = UTExprBuilder::PlaceholderOpNode("out", {32, 1024}, air::Float(16));
stmt = air::ir::AttrStmt::make(
out_, "", UTExprBuilder::IntImm(1),
UTStmtBuilder::CreateRealizeByPlaceholderOp(
out_,
air::ir::ProducerConsumer::make(out_, true,
UTStmtBuilder::CreateFor(
"i", 0, 32,
UTStmtBuilder::CreateFor(
"j", 0, 1024,
UTStmtBuilder::CreateProvideBinary<air::ir::Add>(
out_, {"i", "j"},
air::ir::Add::make(
UTExprBuilder::ElementOf(a_, {"j"}),
UTExprBuilder::ElementOf(b_, {"i", "j"})),
UTExprBuilder::ElementOf(c_, {"j"})))))));
}
air::Operation a_;
air::Operation b_;
air::Operation c_;
air::Operation out_;
air::Stmt stmt;
}; // class PassTestToThreeAddress1
TEST_F(PassTestToThreeAddress1, CaseCheck) {
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> infos_lhs =
UTProvideCheckerForAssign().Find(stmt, "((a(j) + b(i, j)) + c(j))");
ASSERT_EQ(infos_lhs.size(), 1);
EXPECT_EQ(std::get<0>(infos_lhs[0]), "out(i, j)");
EXPECT_EQ(std::get<2>(infos_lhs[0]), 32 * 1024);
}
TEST_F(PassTestToThreeAddress1, TestPass) {
Stmt stmt_out = ir::ToThreeAddress(stmt, false, 0, true);
/* current implementation
* out_2(i, j) = b(i, j)
* out_3(i, j) = (a(j) + out_2(i, j))
* out(i, j) = (out_3(i, j) + c(j))
*/
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info1 =
UTProvideCheckerForAssign().Find(stmt_out, "b(i, j)");
ASSERT_EQ(info1.size(), 1);
std::string dump_b_target = std::get<0>(info1[0]);
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info2 =
UTProvideCheckerForBinary().Find(
stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, "a(j)", dump_b_target);
ASSERT_EQ(info2.size(), 1);
std::string dump_sum1_target = std::get<0>(info2[0]);
EXPECT_EQ(std::get<2>(info2[0]), 32 * 1024);
std::vector<std::tuple<std::string, const air::ir::Provide*, uint64_t>> info3 =
UTProvideCheckerForBinary().Find(
stmt_out, UTProvideCheckerForBinary::BinaryOpType::kAdd, dump_sum1_target, "c(j)");
ASSERT_EQ(info3.size(), 1);
EXPECT_EQ(std::get<0>(info3[0]), "out(i, j)");
EXPECT_EQ(std::get<2>(info3[0]), 32 * 1024);
}
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pass_test_base/auto_poly_test_base.h"
namespace akg {
std::map<std::string, std::string> AutoPolyTestBase::map_mode_ =
AutoPolyTestBase::InitMapMode();
std::map<std::string, std::string> AutoPolyTestBase::InitMapMode() {
std::map<std::string, std::string> res;
res["cloud"] = "1.6";
res["mini"] = "1.1";
res["phoenix"] = "3.5";
res["orlando"] = "3.3";
return res;
}
void AutoPolyTestBase::SetRunMode(const std::string &mode) {
auto it = map_mode_.find(mode);
CHECK(it != map_mode_.end());
cceconf::CceConf::getInstance()->setSection(it->second);
}
void AutoPolyTestBase::RegisterTensor(const air::Tensor &tensor) {
const TensorNode *tensor_node = tensor.as<TensorNode>();
std::string name = tensor_node->op->name;
air::Buffer buf = air::BufferNode::make(
air::Variable::make(Handle(), name),
tensor_node->dtype,
tensor_node->shape,
Array<Expr>(),
Expr(),
name,
"",
-1,
0,
air::BufferType::kDefault);
binds_.Set(GetRef<Tensor>(tensor_node), buf);
}
} // namespace akg
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册