提交 b71359b1 编写于 作者: L LuoYin

initial unittest_cpp (based on googletest)

include:
1. simple Expr/Tensor Builder
2. simple DumpHelper and Regx Check
3. test for Expr/Tensor Builder
4. simple pass test example
上级 2c8b15f1
[submodule "third_party/ktvm"]
path = third_party/ktvm
url = https://gitee.com/mindspore/ktvm.git
[submodule "third_party/googletest"]
path = third_party/googletest
url = https://github.com/google/googletest.git
......@@ -230,3 +230,8 @@ if(ENABLE_AKG)
#file(COPY ${CMAKE_CURRENT_BINARY_DIR}/libakg.so DESTINATION "${AKG_SOURCE_DIR}/output")
endif()
# unittest_cpp
add_subdirectory(${AKG_SOURCE_DIR}/third_party/googletest)
add_subdirectory(${AKG_SOURCE_DIR}/tests/unittest_cpp)
set(GTEST_DIR "${AKG_SOURCE_DIR}/third_party/googletest")
set(UNITTEST_CPP_DIR "${AKG_SOURCE_DIR}/tests/unittest_cpp")
......@@ -602,3 +602,67 @@ ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <http://unlicense.org>
Software: googletest 1.8.1
Copyright notice:
Copyright 2009, Google Inc.
Copyright 2008, Google Inc.
Copyright 2007 Google Inc.
Copyright 2007, Google Inc.
Copyright 2013, Google Inc.
Copyright 2015, Google Inc.
Copyright 2005, Google Inc.
Copyright 2008 Google Inc.
Copyright 2006, Google Inc.
Copyright 2009 Google Inc. All Rights Reserved.
Copyright 2013 Google Inc. All Rights Reserved.
Copyright 2017 Google Inc.
Copyright 2007 Neal Norwitz
Copyright 2008 Google Inc. All Rights Reserved.
Copyright 2009 Neal Norwitz All Rights Reserved.
Copyright 2003 Google Inc.
Copyright 2009 Google Inc.
Copyright 2008 Google Inc. All Rights Reserved.
Copyright [2007] Neal Norwitz
Portions Copyright [2007] Google Inc.
Copyright 2010 Google Inc. All Rights Reserved.
Copyright 2010, Google Inc.
Copyright 2005 Google Inc. All Rights Reserved.
Copyright 2018, Google Inc.
Copyright 2003, Google Inc.
Copyright 2009 Google Inc. All rights reserved.
Copyright 2015 Google Inc. All rights reserved.
Copyright 2009 Google Inc. All rights reserved.
Copyright 2018 Google LLC. All rights reserved.
Copyright 2018, Google LLC.
License: BSD 3-Clause License
Copyright 2008, Google Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
......@@ -1329,9 +1329,5 @@ Stmt ToThreeAddress(Stmt stmt, bool reuse_variable, int minimum_split, bool cros
stmt = ThreeAddressStmtMutator(reuse_variable, minimum_split, cross_stmt_simplify).Mutate(stmt);
return Simplify_cce(stmt);
}
TVM_REGISTER_API("ir_pass.ExprEqual").set_body([](const TVMArgs args, TVMRetValue *ret) {
*ret = Equal(args[0].operator Expr(), args[1].operator Expr());
});
} // namespace ir
} // namespace akg
include_directories(${GTEST_DIR}/googletest/include)
include_directories(${UNITTEST_CPP_SRC}include)
include_directories(${AKG_SOURCE_DIR}/src)
include_directories(${TVM_DIR}/include)
include_directories(${TVM_DIR}/topi/include)
file(
GLOB
UT_CPP_SRC
unittest_main.cc
src/base/*.cc
src/base_test/*.cc
src/pass_test/*.cc)
add_executable(unittest_main ${UT_CPP_SRC})
target_link_libraries(unittest_main gtest akg ${TVM_RUNTIME_LINKER_LIBS} ${CMAKE_DL_LIBS})
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UT_BASE_DUMP_HELPER_H_
#define UT_BASE_DUMP_HELPER_H_
#include <string>
#include <tvm/node/node.h>
#include <tvm/expr.h>
namespace akg {
class UTRegxMatch {
public:
UTRegxMatch() = default;
~UTRegxMatch() = default;
// match pattern 0x...
static bool RegxMatchHex(const std::string &str);
static const std::string pattern_hex_;
}; // UTRegxMatch
class UTDumpHelper {
public:
UTDumpHelper() = default;
~UTDumpHelper() = default;
static std::string Dump(const ktvm::NodeRef &node);
static bool RegxMatchPlaceholder(const std::string &str, const std::string &name);
}; // UTDumpHelper
} // namespace akg
#endif // UT_BASE_DUMP_HELPER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UT_BASE_EXPR_BUILDER_H_
#define UT_BASE_EXPR_BUILDER_H_
#include <string>
#include <vector>
#include "tvm/expr.h"
#include "tvm/operation.h"
namespace akg {
class UTExprBuilder {
public:
UTExprBuilder() = default;
~UTExprBuilder() = default;
static ktvm::Array<ktvm::Expr> CreateShape(const std::vector<int32_t> &shapes);
static ktvm::Var CreateVar(const std::string &name);
static ktvm::Array<ktvm::Expr> CreateVars(const std::vector<std::string> &names);
static ktvm::Operation PlaceholderOpNode(
const std::string &name,
const std::vector<int32_t> &shapes,
ktvm::DataType dtype = ktvm::Float(16));
static ktvm::Expr TensorElement(
const std::string &name,
const std::vector<int32_t> &shapes,
const std::vector<std::string> &axis_names,
ktvm::DataType dtype = ktvm::Float(16));
}; // UTExprBuilder
class UTTensorElementHelper {
public:
UTTensorElementHelper(const std::vector<int32_t> &shapes,
const std::string &axis_name_prefix = "ax");
~UTTensorElementHelper() = default;
ktvm::Expr Elem(const std::string &name,
uint32_t dim,
ktvm::DataType dtype = ktvm::Float(16)) const;
private:
std::vector<int32_t> shapes_;
std::string axis_name_prefix_;
std::vector<std::string> axis_names_;
}; // UTTensorElementHelper
} // namespace akg
#endif // UT_BASE_EXPR_BUILDER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <regex>
#include <sstream>
#include "base/dump_helper.h"
namespace akg {
const std::string UTRegxMatch::pattern_hex_ = "0[xX][0-9a-fA-F]+";
bool UTRegxMatch::RegxMatchHex(const std::string &str) {
return std::regex_match(str, std::regex(pattern_hex_));
}
std::string UTDumpHelper::Dump(const ktvm::NodeRef &node) {
std::stringstream ss;
ss << node;
return ss.str();
}
bool UTDumpHelper::RegxMatchPlaceholder(const std::string &str, const std::string &name) {
std::string pattern = "placeholder\\(" + name + ", " + UTRegxMatch::pattern_hex_ + "\\)";
return std::regex_match(str, std::regex(pattern));
}
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sstream>
#include "base/expr_builder.h"
namespace akg {
ktvm::Array<ktvm::Expr> UTExprBuilder::CreateShape(const std::vector<int32_t> &shapes) {
ktvm::Array<ktvm::Expr> res;
for (int32_t shape : shapes) {
ktvm::Integer imm = ktvm::IntImm::make(ktvm::Int(32), shape);
res.push_back(imm);
}
return res;
}
ktvm::Var UTExprBuilder::CreateVar(const std::string &name) {
return ktvm::Var(name);
}
ktvm::Array<ktvm::Expr> UTExprBuilder::CreateVars(const std::vector<std::string> &names) {
ktvm::Array<ktvm::Expr> vars;
for (const std::string &name : names) {
vars.push_back(std::move(CreateVar(name)));
}
return vars;
}
ktvm::Operation UTExprBuilder::PlaceholderOpNode(
const std::string &name,
const std::vector<int32_t> &shapes,
ktvm::DataType dtype) {
ktvm::Array<ktvm::Expr> expr_shapes = CreateShape(shapes);
return ktvm::PlaceholderOpNode::make(name, expr_shapes, dtype);
}
ktvm::Expr UTExprBuilder::TensorElement(
const std::string &name,
const std::vector<int32_t> &shapes,
const std::vector<std::string> &axis_names,
ktvm::DataType dtype) {
return ktvm::ir::Call::make(
dtype, // type
name, // name
CreateVars(axis_names), // args
ktvm::ir::Call::Halide, // call_type
PlaceholderOpNode(name, shapes, dtype), // func,
0); // value_index
}
UTTensorElementHelper::UTTensorElementHelper(const std::vector<int32_t> &shapes,
const std::string &axis_name_prefix)
: shapes_(shapes), axis_name_prefix_(axis_name_prefix) {
std::stringstream ss;
for (size_t i = 0; i < shapes_.size(); i++) {
ss << axis_name_prefix_ << i;
axis_names_.push_back(ss.str());
ss.str("");
}
}
ktvm::Expr UTTensorElementHelper::Elem(const std::string &name,
uint32_t dim,
ktvm::DataType dtype) const {
uint32_t start = shapes_.size() - dim;
return UTExprBuilder::TensorElement(
name,
std::vector<int32_t>(shapes_.begin() + start, shapes_.end()),
std::vector<std::string>(axis_names_.begin() + start, axis_names_.end()),
dtype);
}
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include "base/dump_helper.h"
namespace akg {
TEST(UTRegxMatch, RegxMatchHex) {
EXPECT_EQ(UTRegxMatch::RegxMatchHex("0x123"), true);
EXPECT_EQ(UTRegxMatch::RegxMatchHex("0xABC"), true);
EXPECT_EQ(UTRegxMatch::RegxMatchHex("0XABC"), true);
EXPECT_EQ(UTRegxMatch::RegxMatchHex("0x"), false);
}
TEST(UTDumpHelper, RegxMatchPlaceholder) {
EXPECT_EQ(UTDumpHelper::RegxMatchPlaceholder("placeholder(input, 0x1234abcd)", "input"), true);
}
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gtest/gtest.h"
#include "base/dump_helper.h"
#include "base/expr_builder.h"
namespace akg {
TEST(UTExprBuilder, CreateShape) {
ktvm::Array<ktvm::Expr> shape1 = UTExprBuilder::CreateShape({1024});
std::string dump_shape1 = UTDumpHelper::Dump(shape1);
EXPECT_EQ(dump_shape1, "[1024]");
ktvm::Array<ktvm::Expr> shape2 = UTExprBuilder::CreateShape({32, 1024});
std::string dump_shape2 = UTDumpHelper::Dump(shape2);
EXPECT_EQ(dump_shape2, "[32, 1024]");
ktvm::Array<ktvm::Expr> shape3 = UTExprBuilder::CreateShape({16, 32, 1024});
std::string dump_shape3 = UTDumpHelper::Dump(shape3);
EXPECT_EQ(dump_shape3, "[16, 32, 1024]");
}
TEST(UTExprBuilder, CreateVar) {
ktvm::Var var = UTExprBuilder::CreateVar("ax0");
std::string dump_var = UTDumpHelper::Dump(var);
EXPECT_EQ(dump_var, "ax0");
}
TEST(UTExprBuilder, CreateVars) {
ktvm::Array<ktvm::Expr> vars = UTExprBuilder::CreateVars({"ax0", "ax1", "ax2"});
std::string dump_vars = UTDumpHelper::Dump(vars);
EXPECT_EQ(dump_vars, "[ax0, ax1, ax2]");
}
TEST(UTExprBuilder, PlaceholderOpNode) {
ktvm::Operation node = UTExprBuilder::PlaceholderOpNode("input", {16, 32, 1024}, ktvm::Float(16));
std::string dump_node = UTDumpHelper::Dump(node);
EXPECT_EQ(UTDumpHelper::RegxMatchPlaceholder(dump_node, "input"), true);
}
TEST(UTExprBuilder, TensorElement) {
ktvm::Expr elem = UTExprBuilder::TensorElement("input", {16, 32, 1024}, {"ax0", "ax1", "ax2"}, ktvm::Float(16));
std::string dump_elem = UTDumpHelper::Dump(elem);
EXPECT_EQ(dump_elem, "input(ax0, ax1, ax2)");
}
TEST(UTTensorElementHelper, TensorElement) {
UTTensorElementHelper helper({16, 32, 1024});
std::string dump_elem1 = UTDumpHelper::Dump(helper.Elem("a", 3));
EXPECT_EQ(dump_elem1, "a(ax0, ax1, ax2)");
std::string dump_elem2 = UTDumpHelper::Dump(helper.Elem("b", 2));
EXPECT_EQ(dump_elem2, "b(ax1, ax2)");
std::string dump_elem3 = UTDumpHelper::Dump(helper.Elem("c", 1));
EXPECT_EQ(dump_elem3, "c(ax2)");
}
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <tvm/ir.h>
#include "base/expr_builder.h"
#include "base/dump_helper.h"
#define private public
#define protected public
#include "pass/to_three_address.cc"
#undef protected
#undef private
namespace akg {
class ToThreeAddressTest {
public:
ToThreeAddressTest() = default;
~ToThreeAddressTest() = default;
};
TEST(ToThreeAddressTest, BuildCase1) {
UTTensorElementHelper th({16, 32, 1024});
using Add = ktvm::ir::Add;
// a(ax1, ax2) + b(ax2) + c(ax0, ax1, ax2) + d(ax2)
ktvm::Expr expr =
Add::make(
Add::make(
Add::make(th.Elem("a", 2), th.Elem("b", 1)),
th.Elem("c", 3)),
th.Elem("d", 1));
std::string dump_expr = UTDumpHelper::Dump(expr);
EXPECT_EQ(dump_expr, "(((a(ax1, ax2) + b(ax2)) + c(ax0, ax1, ax2)) + d(ax2))");
}
class ThreeAddressExprMutatorTest : public testing::Test {
public:
ThreeAddressExprMutatorTest()
: mutator_(ktvm::TensorNode::make(
UTExprBuilder::CreateShape(shape_), // shape
dtype_, // dtype
UTExprBuilder::PlaceholderOpNode("out", shape_), // op
0), // index
UTExprBuilder::CreateVars({"ax0", "ax1", "ax2"}), // args
UTExprBuilder::CreateShape(shape_), // shape
std::unordered_set<const Call *>(), // broadcast
false, // IsReductionOp
false) {} // cross_stmt_simplify
~ThreeAddressExprMutatorTest() = default;
std::vector<int32_t> shape_ = {16, 32, 1024};
ktvm::DataType dtype_ = ktvm::Float(16);
ir::ThreeAddressExprMutator mutator_;
}; // ThreeAddressExprMutatorTest
TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) {
UTTensorElementHelper th(shape_);
using Add = ktvm::ir::Add;
ktvm::Expr expr = Add::make(th.Elem("a", 2), th.Elem("b", 1));
Expr expr_m = mutator_.Mutate(expr);
EXPECT_NE(mutator_.imm_ops.size(), 0);
}
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
int main(int argc, char **argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
Subproject commit 2fe3bd994b3189899d93f1d5a881e725e046fdc2
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册