Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
6a84977e
A
akg
项目概览
MindSpore
/
akg
通知
58
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6a84977e
编写于
7月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!66 update LLT-UT to support pass-level test
Merge pull request !66 from LuoYin/unittest_cpp
上级
d53e0c74
20e24526
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
1179 addition
and
9 deletion
+1179
-9
tests/unittest_cpp/CMakeLists.txt
tests/unittest_cpp/CMakeLists.txt
+1
-0
tests/unittest_cpp/include/base/dump_helper.h
tests/unittest_cpp/include/base/dump_helper.h
+2
-2
tests/unittest_cpp/include/base/expr_builder.h
tests/unittest_cpp/include/base/expr_builder.h
+22
-4
tests/unittest_cpp/include/base/ir_checker.h
tests/unittest_cpp/include/base/ir_checker.h
+103
-0
tests/unittest_cpp/include/base/stmt_builder.h
tests/unittest_cpp/include/base/stmt_builder.h
+58
-0
tests/unittest_cpp/include/pass_test_base/auto_poly_test_base.h
...unittest_cpp/include/pass_test_base/auto_poly_test_base.h
+64
-0
tests/unittest_cpp/src/base/expr_builder.cc
tests/unittest_cpp/src/base/expr_builder.cc
+87
-2
tests/unittest_cpp/src/base/ir_checker.cc
tests/unittest_cpp/src/base/ir_checker.cc
+118
-0
tests/unittest_cpp/src/base/stmt_builder.cc
tests/unittest_cpp/src/base/stmt_builder.cc
+57
-0
tests/unittest_cpp/src/base_test/expr_builder_test.cc
tests/unittest_cpp/src/base_test/expr_builder_test.cc
+46
-0
tests/unittest_cpp/src/base_test/ir_checker_test.cc
tests/unittest_cpp/src/base_test/ir_checker_test.cc
+105
-0
tests/unittest_cpp/src/base_test/stmt_builder_test.cc
tests/unittest_cpp/src/base_test/stmt_builder_test.cc
+81
-0
tests/unittest_cpp/src/pass_test/auto_poly_test.cc
tests/unittest_cpp/src/pass_test/auto_poly_test.cc
+309
-0
tests/unittest_cpp/src/pass_test/to_three_address_test.cc
tests/unittest_cpp/src/pass_test/to_three_address_test.cc
+73
-1
tests/unittest_cpp/src/pass_test_base/auto_poly_test_base.cc
tests/unittest_cpp/src/pass_test_base/auto_poly_test_base.cc
+53
-0
未找到文件。
tests/unittest_cpp/CMakeLists.txt
浏览文件 @
6a84977e
...
@@ -16,6 +16,7 @@ file(
...
@@ -16,6 +16,7 @@ file(
unittest_main.cc
unittest_main.cc
src/base/*.cc
src/base/*.cc
src/base_test/*.cc
src/base_test/*.cc
src/pass_test_base/*.cc
src/pass_test/*.cc
)
src/pass_test/*.cc
)
link_directories
(
${
CMAKE_BINARY_DIR
}
/googletest/googlemock/gtest
)
link_directories
(
${
CMAKE_BINARY_DIR
}
/googletest/googlemock/gtest
)
...
...
tests/unittest_cpp/include/base/dump_helper.h
浏览文件 @
6a84977e
...
@@ -29,7 +29,7 @@ class UTRegxMatch {
...
@@ -29,7 +29,7 @@ class UTRegxMatch {
static
bool
RegxMatchHex
(
const
std
::
string
&
str
);
static
bool
RegxMatchHex
(
const
std
::
string
&
str
);
static
const
std
::
string
pattern_hex_
;
static
const
std
::
string
pattern_hex_
;
};
// UTRegxMatch
};
//
class
UTRegxMatch
class
UTDumpHelper
{
class
UTDumpHelper
{
public:
public:
...
@@ -38,6 +38,6 @@ class UTDumpHelper {
...
@@ -38,6 +38,6 @@ class UTDumpHelper {
static
std
::
string
Dump
(
const
air
::
NodeRef
&
node
);
static
std
::
string
Dump
(
const
air
::
NodeRef
&
node
);
static
bool
RegxMatchPlaceholder
(
const
std
::
string
&
str
,
const
std
::
string
&
name
);
static
bool
RegxMatchPlaceholder
(
const
std
::
string
&
str
,
const
std
::
string
&
name
);
};
// UTDumpHelper
};
//
class
UTDumpHelper
}
// namespace akg
}
// namespace akg
#endif // UT_BASE_DUMP_HELPER_H_
#endif // UT_BASE_DUMP_HELPER_H_
tests/unittest_cpp/include/base/expr_builder.h
浏览文件 @
6a84977e
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <vector>
#include <vector>
#include "tvm/expr.h"
#include "tvm/expr.h"
#include "tvm/operation.h"
#include "tvm/operation.h"
#include "tvm/tensor.h"
namespace
akg
{
namespace
akg
{
class
UTExprBuilder
{
class
UTExprBuilder
{
...
@@ -26,9 +27,15 @@ class UTExprBuilder {
...
@@ -26,9 +27,15 @@ class UTExprBuilder {
UTExprBuilder
()
=
default
;
UTExprBuilder
()
=
default
;
~
UTExprBuilder
()
=
default
;
~
UTExprBuilder
()
=
default
;
static
air
::
Expr
IntImm
(
int64_t
value
,
air
::
DataType
dtype
=
air
::
Int
(
32
));
static
air
::
Expr
UIntImm
(
uint64_t
value
,
air
::
DataType
dtype
=
air
::
UInt
(
32
));
static
air
::
Expr
BoolImm
(
bool
value
);
static
air
::
Array
<
air
::
Expr
>
CreateShape
(
const
std
::
vector
<
int32_t
>
&
shapes
);
static
air
::
Array
<
air
::
Expr
>
CreateShape
(
const
std
::
vector
<
int32_t
>
&
shapes
);
static
air
::
Var
CreateVar
(
const
std
::
string
&
name
);
static
air
::
Var
CreateVar
(
const
std
::
string
&
name
);
static
air
::
Array
<
air
::
Expr
>
CreateVars
(
const
std
::
vector
<
std
::
string
>
&
names
);
static
air
::
Array
<
air
::
Expr
>
CreateVars
(
const
std
::
vector
<
std
::
string
>
&
names
);
static
air
::
Range
CreateRange
(
int32_t
min
,
int32_t
max
);
static
air
::
Region
CreateRegion
(
const
std
::
vector
<
int32_t
>
&
shapes
);
static
air
::
Region
CreateRegion
(
const
air
::
Array
<
air
::
Expr
>
&
shapes
);
static
air
::
Operation
PlaceholderOpNode
(
static
air
::
Operation
PlaceholderOpNode
(
const
std
::
string
&
name
,
const
std
::
string
&
name
,
const
std
::
vector
<
int32_t
>
&
shapes
,
const
std
::
vector
<
int32_t
>
&
shapes
,
...
@@ -38,6 +45,18 @@ class UTExprBuilder {
...
@@ -38,6 +45,18 @@ class UTExprBuilder {
const
std
::
vector
<
int32_t
>
&
shapes
,
const
std
::
vector
<
int32_t
>
&
shapes
,
const
std
::
vector
<
std
::
string
>
&
axis_names
,
const
std
::
vector
<
std
::
string
>
&
axis_names
,
air
::
DataType
dtype
=
air
::
Float
(
16
));
air
::
DataType
dtype
=
air
::
Float
(
16
));
static
air
::
Expr
ElementOf
(
const
air
::
Operation
&
op
,
const
std
::
vector
<
std
::
string
>
&
axis_names
);
static
air
::
Expr
ElementOfPlaceholderOp
(
const
air
::
Operation
&
op
,
const
std
::
vector
<
std
::
string
>
&
axis_names
);
static
air
::
Expr
CreateCall
(
const
air
::
ir
::
FunctionRef
func
,
air
::
Array
<
air
::
Expr
>
args
,
air
::
ir
::
Call
::
CallType
call_type
=
air
::
ir
::
Call
::
Halide
,
int
value_index
=
0
);
static
air
::
Tensor
CreateTensorByPlaceholder
(
const
air
::
Operation
op
);
};
// UTExprBuilder
};
// UTExprBuilder
class
UTTensorElementHelper
{
class
UTTensorElementHelper
{
...
@@ -46,14 +65,13 @@ class UTTensorElementHelper {
...
@@ -46,14 +65,13 @@ class UTTensorElementHelper {
const
std
::
string
&
axis_name_prefix
=
"ax"
);
const
std
::
string
&
axis_name_prefix
=
"ax"
);
~
UTTensorElementHelper
()
=
default
;
~
UTTensorElementHelper
()
=
default
;
air
::
Expr
Elem
(
const
std
::
string
&
name
,
air
::
Expr
Elem
(
const
std
::
string
&
name
,
uint32_t
dim
,
uint32_t
dim
,
air
::
DataType
dtype
=
air
::
Float
(
16
))
const
;
air
::
DataType
dtype
=
air
::
Float
(
16
))
const
;
private:
private:
std
::
vector
<
int32_t
>
shapes_
;
std
::
vector
<
int32_t
>
shapes_
;
std
::
string
axis_name_prefix_
;
std
::
string
axis_name_prefix_
;
std
::
vector
<
std
::
string
>
axis_names_
;
std
::
vector
<
std
::
string
>
axis_names_
;
};
// UTTensorElementHelper
};
// class UTTensorElementHelper
}
// namespace akg
}
// namespace akg
#endif // UT_BASE_EXPR_BUILDER_H_
#endif // UT_BASE_EXPR_BUILDER_H_
tests/unittest_cpp/include/base/ir_checker.h
0 → 100644
浏览文件 @
6a84977e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UT_IR_CHECKER_H_
#define UT_IR_CHECKER_H_
#include <string>
#include <tuple>
#include <vector>
#include <tvm/ir_visitor.h>
#include "base/dump_helper.h"
#include "base/expr_builder.h"
namespace
akg
{
class
UTIRCheckHelper
{
public:
UTIRCheckHelper
()
=
default
;
~
UTIRCheckHelper
()
=
default
;
static
int64_t
GetValueFromImm
(
const
air
::
Expr
&
expr
);
};
// class UTIRCheckHelper
class
UTProvideChecker
:
public
air
::
ir
::
IRVisitor
{
public:
explicit
UTProvideChecker
(
bool
ignore_args
=
false
)
:
ignore_args_
(
ignore_args
)
{}
~
UTProvideChecker
()
=
default
;
void
Visit_
(
const
air
::
ir
::
For
*
op
)
override
;
bool
CompareDump
(
const
std
::
string
&
dump
,
const
std
::
string
&
target
);
protected:
bool
ignore_args_
{
false
};
std
::
vector
<
uint64_t
>
for_count_stack_
;
};
// class UTProvideChecker
class
UTProvideCheckerForAssign
:
public
UTProvideChecker
{
public:
explicit
UTProvideCheckerForAssign
(
bool
ignore_args
=
false
)
:
UTProvideChecker
(
ignore_args
)
{}
~
UTProvideCheckerForAssign
()
=
default
;
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
Find
(
const
air
::
NodeRef
&
node
,
const
std
::
string
&
dump_rhs
);
void
Visit_
(
const
air
::
ir
::
Provide
*
op
)
override
;
private:
std
::
string
dump_rhs_
{
""
};
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs_
;
};
// class UTProvideChecker
class
UTProvideCheckerForBinary
:
public
UTProvideChecker
{
public:
enum
BinaryOpType
:
int
{
kAdd
,
kSub
,
kMul
,
kDiv
,
kMod
,
};
explicit
UTProvideCheckerForBinary
(
bool
ignore_args
=
false
)
:
UTProvideChecker
(
ignore_args
)
{}
~
UTProvideCheckerForBinary
()
=
default
;
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
Find
(
const
air
::
NodeRef
&
node
,
BinaryOpType
op_type
,
const
std
::
string
&
dump_rhs1
,
const
std
::
string
&
dump_rhs2
);
void
Visit_
(
const
air
::
ir
::
Provide
*
op
)
override
;
template
<
typename
T
>
void
CheckBinary
(
const
air
::
ir
::
Provide
*
op
)
{
const
T
*
expr_binary
=
op
->
value
.
as
<
T
>
();
if
(
expr_binary
==
nullptr
)
{
return
;
}
std
::
string
dump_expr_a
=
UTDumpHelper
::
Dump
(
expr_binary
->
a
);
std
::
string
dump_expr_b
=
UTDumpHelper
::
Dump
(
expr_binary
->
b
);
if
((
dump_rhs1_
.
empty
()
||
CompareDump
(
dump_expr_a
,
dump_rhs1_
))
&&
(
dump_rhs2_
.
empty
()
||
CompareDump
(
dump_expr_b
,
dump_rhs2_
)))
{
air
::
Expr
expr_call
=
UTExprBuilder
::
CreateCall
(
op
->
func
,
op
->
args
);
infos_lhs_
.
push_back
(
std
::
make_tuple
(
UTDumpHelper
::
Dump
(
expr_call
),
op
,
for_count_stack_
.
back
()));
}
}
private:
BinaryOpType
op_type_
;
std
::
string
dump_rhs1_
{
""
};
std
::
string
dump_rhs2_
{
""
};
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs_
;
};
// class UTProvideCheckerForBinary
}
// namespace akg
#endif // UT_IR_CHECKER_H_
tests/unittest_cpp/include/base/stmt_builder.h
0 → 100644
浏览文件 @
6a84977e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UT_BASE_STMT_BUILDER_H_
#define UT_BASE_STMT_BUILDER_H_
#include <list>
#include <string>
#include <vector>
#include "tvm/ir.h"
#include "base/expr_builder.h"
namespace
akg
{
class
UTStmtBuilder
{
public:
UTStmtBuilder
()
=
default
;
~
UTStmtBuilder
()
=
default
;
static
air
::
Stmt
CreateFor
(
const
std
::
string
&
loop_var_name
,
int32_t
min
,
int32_t
extent
,
air
::
Stmt
body
);
static
air
::
Stmt
CreateRealizeByPlaceholderOp
(
const
air
::
Operation
&
op
,
air
::
Stmt
body
);
static
air
::
Stmt
CreateProvideAssign
(
air
::
ir
::
FunctionRef
func_dst
,
const
std
::
vector
<
std
::
string
>
&
vars
,
air
::
Expr
src
,
int
value_index
=
0
);
template
<
typename
T
>
static
air
::
Stmt
CreateProvideBinary
(
air
::
ir
::
FunctionRef
func_dst
,
const
std
::
vector
<
std
::
string
>
&
vars
,
air
::
Expr
src1
,
air
::
Expr
src2
,
int
value_index
=
0
)
{
return
air
::
ir
::
Provide
::
make
(
func_dst
,
value_index
,
T
::
make
(
src1
,
src2
),
UTExprBuilder
::
CreateVars
(
vars
));
}
};
// class UTStmtBuilder
}
// namespace akg
#endif // UT_BASE_STMT_BUILDER_H_
tests/unittest_cpp/include/pass_test_base/auto_poly_test_base.h
0 → 100644
浏览文件 @
6a84977e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UT_AUTO_POLY_TEST_BASE_H_
#define UT_AUTO_POLY_TEST_BASE_H_
#include <map>
#include <string>
#include <gtest/gtest.h>
#include <tvm/expr.h>
#include "codegen/util.h"
#include "contrib/cce_parm/cceconf.h"
#include "base/expr_builder.h"
namespace
akg
{
class
AutoPolyTestBase
:
public
::
testing
::
Test
{
public:
AutoPolyTestBase
()
=
default
;
~
AutoPolyTestBase
()
=
default
;
static
std
::
map
<
std
::
string
,
std
::
string
>
InitMapMode
();
void
RegisterTensor
(
const
air
::
Tensor
&
tensor
);
void
SetRunMode
(
const
std
::
string
&
mode
);
void
GlobalAttrSetIsDynamic
(
bool
arg
)
{
global_attrs_
.
Set
(
"is_dynamic"
,
arg
?
UTExprBuilder
::
IntImm
(
1
,
air
::
Int
(
1
))
:
UTExprBuilder
::
IntImm
(
0
,
air
::
Int
(
1
)));
}
void
GlobalAttrSetDynamic
(
bool
arg
)
{
global_attrs_
.
Set
(
"dynamic"
,
arg
?
UTExprBuilder
::
IntImm
(
1
,
air
::
Int
(
1
))
:
UTExprBuilder
::
IntImm
(
0
,
air
::
Int
(
1
)));
}
void
GlobalAttrSetDumpPassIR
(
bool
arg
)
{
global_attrs_
.
Set
(
"dump_pass_ir"
,
arg
?
UTExprBuilder
::
IntImm
(
1
,
air
::
Int
(
1
))
:
UTExprBuilder
::
IntImm
(
0
,
air
::
Int
(
1
)));
}
void
GlobalAttrSetDumpPolyDir
(
const
std
::
string
&
path
)
{
global_attrs_
.
Set
(
"dump_poly_dir"
,
air
::
ir
::
StringImm
::
make
(
path
));
}
void
GlobalAttrSetKernalName
(
const
std
::
string
&
name
)
{
global_attrs_
.
Set
(
"kernel_name"
,
air
::
ir
::
StringImm
::
make
(
name
));
}
static
std
::
map
<
std
::
string
,
std
::
string
>
map_mode_
;
protected:
air
::
Map
<
air
::
Tensor
,
air
::
Buffer
>
binds_
;
AttrMap
global_attrs_
;
};
// class AutoPolyTestBase
}
// namespace akg
#endif
tests/unittest_cpp/src/base/expr_builder.cc
浏览文件 @
6a84977e
...
@@ -14,9 +14,22 @@
...
@@ -14,9 +14,22 @@
* limitations under the License.
* limitations under the License.
*/
*/
#include <sstream>
#include <sstream>
#include <tvm/operation.h>
#include "base/expr_builder.h"
#include "base/expr_builder.h"
namespace
akg
{
namespace
akg
{
air
::
Expr
UTExprBuilder
::
IntImm
(
int64_t
value
,
air
::
DataType
dtype
)
{
return
air
::
IntImm
::
make
(
dtype
,
value
);
}
air
::
Expr
UTExprBuilder
::
UIntImm
(
uint64_t
value
,
air
::
DataType
dtype
)
{
return
air
::
ir
::
UIntImm
::
make
(
dtype
,
value
);
}
air
::
Expr
UTExprBuilder
::
BoolImm
(
bool
value
)
{
return
air
::
ir
::
UIntImm
::
make
(
air
::
Bool
(),
value
?
1
:
0
);
}
air
::
Array
<
air
::
Expr
>
UTExprBuilder
::
CreateShape
(
const
std
::
vector
<
int32_t
>
&
shapes
)
{
air
::
Array
<
air
::
Expr
>
UTExprBuilder
::
CreateShape
(
const
std
::
vector
<
int32_t
>
&
shapes
)
{
air
::
Array
<
air
::
Expr
>
res
;
air
::
Array
<
air
::
Expr
>
res
;
for
(
int32_t
shape
:
shapes
)
{
for
(
int32_t
shape
:
shapes
)
{
...
@@ -38,6 +51,28 @@ air::Array<air::Expr> UTExprBuilder::CreateVars(const std::vector<std::string> &
...
@@ -38,6 +51,28 @@ air::Array<air::Expr> UTExprBuilder::CreateVars(const std::vector<std::string> &
return
vars
;
return
vars
;
}
}
air
::
Region
UTExprBuilder
::
CreateRegion
(
const
std
::
vector
<
int32_t
>
&
shapes
)
{
air
::
Region
region
;
for
(
int32_t
shape
:
shapes
)
{
region
.
push_back
(
CreateRange
(
0
,
shape
));
}
return
region
;
}
air
::
Region
UTExprBuilder
::
CreateRegion
(
const
air
::
Array
<
air
::
Expr
>
&
shapes
)
{
air
::
Region
region
;
for
(
const
air
::
Expr
&
shape
:
shapes
)
{
region
.
push_back
(
air
::
Range
::
make_by_min_extent
(
IntImm
(
0
),
shape
));
}
return
region
;
}
air
::
Range
UTExprBuilder
::
CreateRange
(
int32_t
min
,
int32_t
max
)
{
air
::
Integer
imm_min
=
air
::
IntImm
::
make
(
air
::
Int
(
32
),
min
);
air
::
Integer
imm_max
=
air
::
IntImm
::
make
(
air
::
Int
(
32
),
max
);
return
air
::
Range
(
std
::
move
(
imm_min
),
std
::
move
(
imm_max
));
}
air
::
Operation
UTExprBuilder
::
PlaceholderOpNode
(
air
::
Operation
UTExprBuilder
::
PlaceholderOpNode
(
const
std
::
string
&
name
,
const
std
::
string
&
name
,
const
std
::
vector
<
int32_t
>
&
shapes
,
const
std
::
vector
<
int32_t
>
&
shapes
,
...
@@ -60,6 +95,56 @@ air::Expr UTExprBuilder::TensorElement(
...
@@ -60,6 +95,56 @@ air::Expr UTExprBuilder::TensorElement(
0
);
// value_index
0
);
// value_index
}
}
air
::
Expr
UTExprBuilder
::
ElementOf
(
const
air
::
Operation
&
op
,
const
std
::
vector
<
std
::
string
>
&
axis_names
)
{
if
(
op
->
template
IsInstance
<
air
::
PlaceholderOpNode
>())
{
return
ElementOfPlaceholderOp
(
op
,
axis_names
);
}
else
{
CHECK
(
false
);
return
air
::
ir
::
Any
::
make
();
}
}
air
::
Expr
UTExprBuilder
::
ElementOfPlaceholderOp
(
const
air
::
Operation
&
op
,
const
std
::
vector
<
std
::
string
>
&
axis_names
)
{
const
air
::
PlaceholderOpNode
*
node
=
op
.
as
<
const
air
::
PlaceholderOpNode
>
();
CHECK
(
node
);
return
air
::
ir
::
Call
::
make
(
node
->
dtype
,
node
->
name
,
CreateVars
(
axis_names
),
air
::
ir
::
Call
::
Halide
,
op
,
0
);
}
air
::
Expr
UTExprBuilder
::
CreateCall
(
const
air
::
ir
::
FunctionRef
func
,
air
::
Array
<
air
::
Expr
>
args
,
air
::
ir
::
Call
::
CallType
call_type
,
int
value_index
)
{
air
::
DataType
type
=
air
::
Float
(
16
);
const
air
::
OperationNode
*
node_op
=
func
.
as
<
air
::
OperationNode
>
();
CHECK
(
node_op
);
std
::
string
name
=
node_op
->
name
;
const
air
::
PlaceholderOpNode
*
node_placeholder
=
func
.
as
<
air
::
PlaceholderOpNode
>
();
if
(
node_placeholder
!=
nullptr
)
{
type
=
node_placeholder
->
dtype
;
}
return
air
::
ir
::
Call
::
make
(
type
,
name
,
args
,
call_type
,
func
,
value_index
);
}
air
::
Tensor
UTExprBuilder
::
CreateTensorByPlaceholder
(
const
air
::
Operation
op
)
{
const
air
::
PlaceholderOpNode
*
node
=
op
.
as
<
air
::
PlaceholderOpNode
>
();
CHECK
(
node
);
return
air
::
TensorNode
::
make
(
node
->
shape
,
node
->
dtype
,
op
,
0
);
}
UTTensorElementHelper
::
UTTensorElementHelper
(
const
std
::
vector
<
int32_t
>
&
shapes
,
UTTensorElementHelper
::
UTTensorElementHelper
(
const
std
::
vector
<
int32_t
>
&
shapes
,
const
std
::
string
&
axis_name_prefix
)
const
std
::
string
&
axis_name_prefix
)
:
shapes_
(
shapes
),
axis_name_prefix_
(
axis_name_prefix
)
{
:
shapes_
(
shapes
),
axis_name_prefix_
(
axis_name_prefix
)
{
...
@@ -72,8 +157,8 @@ UTTensorElementHelper::UTTensorElementHelper(const std::vector<int32_t> &shapes,
...
@@ -72,8 +157,8 @@ UTTensorElementHelper::UTTensorElementHelper(const std::vector<int32_t> &shapes,
}
}
air
::
Expr
UTTensorElementHelper
::
Elem
(
const
std
::
string
&
name
,
air
::
Expr
UTTensorElementHelper
::
Elem
(
const
std
::
string
&
name
,
uint32_t
dim
,
uint32_t
dim
,
air
::
DataType
dtype
)
const
{
air
::
DataType
dtype
)
const
{
uint32_t
start
=
shapes_
.
size
()
-
dim
;
uint32_t
start
=
shapes_
.
size
()
-
dim
;
return
UTExprBuilder
::
TensorElement
(
return
UTExprBuilder
::
TensorElement
(
name
,
name
,
...
...
tests/unittest_cpp/src/base/ir_checker.cc
0 → 100644
浏览文件 @
6a84977e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "base/ir_checker.h"
#include <cinttypes>
#include <string>
#include "base/dump_helper.h"
#include "base/expr_builder.h"
namespace
akg
{
int64_t
UTIRCheckHelper
::
GetValueFromImm
(
const
air
::
Expr
&
expr
)
{
const
air
::
IntImm
*
imm_int
=
expr
.
as
<
air
::
IntImm
>
();
if
(
imm_int
!=
nullptr
)
{
return
imm_int
->
value
;
}
const
air
::
ir
::
UIntImm
*
imm_uint
=
expr
.
as
<
air
::
ir
::
UIntImm
>
();
if
(
imm_uint
!=
nullptr
)
{
CHECK
(
imm_uint
->
value
<
INT64_MAX
);
return
static_cast
<
int64_t
>
(
imm_uint
->
value
);
}
return
0
;
}
void
UTProvideChecker
::
Visit_
(
const
air
::
ir
::
For
*
op
)
{
uint64_t
count_top
=
for_count_stack_
.
back
();
int64_t
min
=
UTIRCheckHelper
::
GetValueFromImm
(
op
->
min
);
int64_t
extent
=
UTIRCheckHelper
::
GetValueFromImm
(
op
->
extent
);
CHECK
(
extent
>
min
);
count_top
*=
static_cast
<
uint64_t
>
(
extent
);
for_count_stack_
.
push_back
(
count_top
);
IRVisitor
::
Visit_
(
op
);
for_count_stack_
.
pop_back
();
}
bool
UTProvideChecker
::
CompareDump
(
const
std
::
string
&
dump
,
const
std
::
string
&
target
)
{
if
(
dump
.
compare
(
target
)
==
0
)
{
return
true
;
}
if
(
ignore_args_
)
{
size_t
npos
=
dump
.
find
(
"("
);
return
dump
.
substr
(
0
,
npos
).
compare
(
target
)
==
0
;
}
return
false
;
}
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
UTProvideCheckerForAssign
::
Find
(
const
air
::
NodeRef
&
node
,
const
std
::
string
&
dump_rhs
)
{
dump_rhs_
=
dump_rhs
;
infos_lhs_
.
clear
();
for_count_stack_
.
clear
();
for_count_stack_
.
push_back
(
1
);
Visit
(
node
);
return
infos_lhs_
;
}
void
UTProvideCheckerForAssign
::
Visit_
(
const
air
::
ir
::
Provide
*
op
)
{
std
::
string
dump_expr
=
UTDumpHelper
::
Dump
(
op
->
value
);
if
(
CompareDump
(
dump_expr
,
dump_rhs_
))
{
air
::
Expr
expr_call
=
UTExprBuilder
::
CreateCall
(
op
->
func
,
op
->
args
);
infos_lhs_
.
push_back
(
std
::
make_tuple
(
UTDumpHelper
::
Dump
(
expr_call
),
op
,
for_count_stack_
.
back
()));
}
}
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
UTProvideCheckerForBinary
::
Find
(
const
air
::
NodeRef
&
node
,
UTProvideCheckerForBinary
::
BinaryOpType
op_type
,
const
std
::
string
&
dump_rhs1
,
const
std
::
string
&
dump_rhs2
)
{
op_type_
=
op_type
;
dump_rhs1_
=
dump_rhs1
;
dump_rhs2_
=
dump_rhs2
;
infos_lhs_
.
clear
();
for_count_stack_
.
clear
();
for_count_stack_
.
push_back
(
1
);
if
(
dump_rhs1_
.
empty
()
&&
dump_rhs2_
.
empty
())
{
return
infos_lhs_
;
}
Visit
(
node
);
return
infos_lhs_
;
}
void
UTProvideCheckerForBinary
::
Visit_
(
const
air
::
ir
::
Provide
*
op
)
{
switch
(
op_type_
)
{
case
BinaryOpType
::
kAdd
:
CheckBinary
<
air
::
ir
::
Add
>
(
op
);
break
;
case
BinaryOpType
::
kSub
:
CheckBinary
<
air
::
ir
::
Sub
>
(
op
);
break
;
case
BinaryOpType
::
kMul
:
CheckBinary
<
air
::
ir
::
Mul
>
(
op
);
break
;
case
BinaryOpType
::
kDiv
:
CheckBinary
<
air
::
ir
::
Div
>
(
op
);
break
;
case
BinaryOpType
::
kMod
:
CheckBinary
<
air
::
ir
::
Mod
>
(
op
);
break
;
default:
break
;
}
}
}
// namespace akg
tests/unittest_cpp/src/base/stmt_builder.cc
0 → 100644
浏览文件 @
6a84977e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "base/stmt_builder.h"
namespace
akg
{
air
::
Stmt
UTStmtBuilder
::
CreateFor
(
const
std
::
string
&
loop_var_name
,
int32_t
min
,
int32_t
extent
,
air
::
Stmt
body
)
{
return
air
::
ir
::
For
::
make
(
UTExprBuilder
::
CreateVar
(
loop_var_name
),
UTExprBuilder
::
IntImm
(
min
),
UTExprBuilder
::
IntImm
(
extent
),
air
::
ir
::
ForType
::
Serial
,
air
::
ir
::
DeviceAPI
::
None
,
body
);
}
air
::
Stmt
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
const
air
::
Operation
&
op
,
air
::
Stmt
body
)
{
const
air
::
PlaceholderOpNode
*
node
=
op
.
as
<
const
air
::
PlaceholderOpNode
>
();
CHECK
(
node
);
return
air
::
ir
::
Realize
::
make
(
op
,
0
,
node
->
dtype
,
UTExprBuilder
::
CreateRegion
(
node
->
shape
),
UTExprBuilder
::
BoolImm
(
true
),
body
);
}
air
::
Stmt
UTStmtBuilder
::
CreateProvideAssign
(
air
::
ir
::
FunctionRef
func_dst
,
const
std
::
vector
<
std
::
string
>
&
vars
,
air
::
Expr
src
,
int
value_index
)
{
return
air
::
ir
::
Provide
::
make
(
func_dst
,
value_index
,
src
,
UTExprBuilder
::
CreateVars
(
vars
));
}
}
// namespace akg
tests/unittest_cpp/src/base_test/expr_builder_test.cc
浏览文件 @
6a84977e
...
@@ -18,6 +18,39 @@
...
@@ -18,6 +18,39 @@
#include "base/expr_builder.h"
#include "base/expr_builder.h"
namespace
akg
{
namespace
akg
{
TEST
(
UTExprBuilder
,
IntImm
)
{
air
::
Expr
int1
=
UTExprBuilder
::
IntImm
(
1024
);
std
::
string
dump_int1
=
UTDumpHelper
::
Dump
(
int1
);
EXPECT_EQ
(
dump_int1
,
"1024"
);
air
::
Expr
int2
=
UTExprBuilder
::
IntImm
(
1024
,
air
::
Int
(
64
));
std
::
string
dump_int2
=
UTDumpHelper
::
Dump
(
int2
);
EXPECT_EQ
(
dump_int2
,
"(int64)1024"
);
air
::
Expr
int3
=
UTExprBuilder
::
IntImm
(
1024
,
air
::
Int
(
16
));
std
::
string
dump_int3
=
UTDumpHelper
::
Dump
(
int3
);
EXPECT_EQ
(
dump_int3
,
"(int16)1024"
);
}
TEST
(
UTExprBuilder
,
UIntImm
)
{
air
::
Expr
uint1
=
UTExprBuilder
::
UIntImm
(
1024
);
std
::
string
dump_uint1
=
UTDumpHelper
::
Dump
(
uint1
);
EXPECT_EQ
(
dump_uint1
,
"(uint32)1024"
);
air
::
Expr
uint2
=
UTExprBuilder
::
UIntImm
(
1024
,
air
::
UInt
(
64
));
std
::
string
dump_uint2
=
UTDumpHelper
::
Dump
(
uint2
);
EXPECT_EQ
(
dump_uint2
,
"(uint64)1024"
);
air
::
Expr
uint3
=
UTExprBuilder
::
UIntImm
(
1024
,
air
::
UInt
(
16
));
std
::
string
dump_uint3
=
UTDumpHelper
::
Dump
(
uint3
);
EXPECT_EQ
(
dump_uint3
,
"(uint16)1024"
);
}
TEST
(
UTExprBuilder
,
Bool
)
{
air
::
Expr
bool_true
=
UTExprBuilder
::
BoolImm
(
true
);
std
::
string
dump_bool_true
=
UTDumpHelper
::
Dump
(
bool_true
);
EXPECT_EQ
(
dump_bool_true
,
"(bool)1"
);
air
::
Expr
bool_false
=
UTExprBuilder
::
BoolImm
(
false
);
std
::
string
dump_bool_false
=
UTDumpHelper
::
Dump
(
bool_false
);
EXPECT_EQ
(
dump_bool_false
,
"(bool)0"
);
}
TEST
(
UTExprBuilder
,
CreateShape
)
{
TEST
(
UTExprBuilder
,
CreateShape
)
{
air
::
Array
<
air
::
Expr
>
shape1
=
UTExprBuilder
::
CreateShape
({
1024
});
air
::
Array
<
air
::
Expr
>
shape1
=
UTExprBuilder
::
CreateShape
({
1024
});
std
::
string
dump_shape1
=
UTDumpHelper
::
Dump
(
shape1
);
std
::
string
dump_shape1
=
UTDumpHelper
::
Dump
(
shape1
);
...
@@ -44,6 +77,12 @@ TEST(UTExprBuilder, CreateVars) {
...
@@ -44,6 +77,12 @@ TEST(UTExprBuilder, CreateVars) {
EXPECT_EQ
(
dump_vars
,
"[ax0, ax1, ax2]"
);
EXPECT_EQ
(
dump_vars
,
"[ax0, ax1, ax2]"
);
}
}
TEST
(
UTExprBuilder
,
CreateRange
)
{
air
::
Range
range
=
UTExprBuilder
::
CreateRange
(
0
,
1024
);
std
::
string
dump_range
=
UTDumpHelper
::
Dump
(
range
);
EXPECT_EQ
(
dump_range
,
"range(min=0, ext=1024)"
);
}
TEST
(
UTExprBuilder
,
PlaceholderOpNode
)
{
TEST
(
UTExprBuilder
,
PlaceholderOpNode
)
{
air
::
Operation
node
=
UTExprBuilder
::
PlaceholderOpNode
(
"input"
,
{
16
,
32
,
1024
},
air
::
Float
(
16
));
air
::
Operation
node
=
UTExprBuilder
::
PlaceholderOpNode
(
"input"
,
{
16
,
32
,
1024
},
air
::
Float
(
16
));
std
::
string
dump_node
=
UTDumpHelper
::
Dump
(
node
);
std
::
string
dump_node
=
UTDumpHelper
::
Dump
(
node
);
...
@@ -56,6 +95,13 @@ TEST(UTExprBuilder, TensorElement) {
...
@@ -56,6 +95,13 @@ TEST(UTExprBuilder, TensorElement) {
EXPECT_EQ
(
dump_elem
,
"input(ax0, ax1, ax2)"
);
EXPECT_EQ
(
dump_elem
,
"input(ax0, ax1, ax2)"
);
}
}
TEST
(
UTExprBuilder
,
ElememtOfPlaceholderOp
)
{
air
::
Operation
op
=
UTExprBuilder
::
PlaceholderOpNode
(
"input"
,
{
16
,
32
,
1024
},
air
::
Float
(
16
));
air
::
Expr
elem
=
UTExprBuilder
::
ElementOfPlaceholderOp
(
op
,
{
"ax0"
,
"ax1"
,
"ax2"
});
std
::
string
dump_elem
=
UTDumpHelper
::
Dump
(
elem
);
EXPECT_EQ
(
dump_elem
,
"input(ax0, ax1, ax2)"
);
}
TEST
(
UTTensorElementHelper
,
TensorElement
)
{
TEST
(
UTTensorElementHelper
,
TensorElement
)
{
UTTensorElementHelper
helper
({
16
,
32
,
1024
});
UTTensorElementHelper
helper
({
16
,
32
,
1024
});
std
::
string
dump_elem1
=
UTDumpHelper
::
Dump
(
helper
.
Elem
(
"a"
,
3
));
std
::
string
dump_elem1
=
UTDumpHelper
::
Dump
(
helper
.
Elem
(
"a"
,
3
));
...
...
tests/unittest_cpp/src/base_test/ir_checker_test.cc
0 → 100644
浏览文件 @
6a84977e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <string>
#include "base/ir_checker.h"
#include "base/expr_builder.h"
#include "base/stmt_builder.h"
namespace
akg
{
TEST
(
UTProvideChecker
,
CompareDump
)
{
EXPECT_EQ
(
UTProvideChecker
().
CompareDump
(
"a(i, j)"
,
"a"
),
false
);
EXPECT_EQ
(
UTProvideChecker
().
CompareDump
(
"a(i, j)"
,
"a(i, j)"
),
true
);
EXPECT_EQ
(
UTProvideChecker
(
true
).
CompareDump
(
"a(i, j)"
,
"a"
),
true
);
EXPECT_EQ
(
UTProvideChecker
(
true
).
CompareDump
(
"a(i, j)"
,
"a(i, j)"
),
true
);
}
class
UTProvideCheckerTest
:
public
testing
::
Test
{
public:
UTProvideCheckerTest
()
:
a_
(
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
1024
},
air
::
Float
(
16
))),
b_
(
UTExprBuilder
::
PlaceholderOpNode
(
"b"
,
{
1024
},
air
::
Float
(
16
))),
c_
(
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
1024
},
air
::
Float
(
16
)))
{}
~
UTProvideCheckerTest
()
=
default
;
air
::
Operation
a_
;
air
::
Operation
b_
;
air
::
Operation
c_
;
};
// class UTProvideCheckerTest
TEST_F
(
UTProvideCheckerTest
,
UTProvideCheckerForAssign
)
{
// b(ax0) = a(ax0)
air
::
Stmt
stmt
=
UTStmtBuilder
::
CreateProvideAssign
(
b_
,
{
"ax0"
},
UTExprBuilder
::
ElementOf
(
a_
,
{
"ax0"
}));
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
UTProvideCheckerForAssign
().
Find
(
stmt
,
"a(ax0)"
);
ASSERT_EQ
(
infos_lhs
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
0
>
(
infos_lhs
[
0
]),
"b(ax0)"
);
EXPECT_EQ
(
std
::
get
<
2
>
(
infos_lhs
[
0
]),
1
);
}
TEST_F
(
UTProvideCheckerTest
,
UTProvideCheckerForBinary
)
{
// c(ax0) = (a(ax0) + b(ax0))
air
::
Stmt
stmt
=
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
c_
,
{
"ax0"
},
UTExprBuilder
::
ElementOf
(
a_
,
{
"ax0"
}),
UTExprBuilder
::
ElementOf
(
b_
,
{
"ax0"
}));
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
UTProvideCheckerForBinary
().
Find
(
stmt
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
"a(ax0)"
,
"b(ax0)"
);
ASSERT_EQ
(
infos_lhs
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
0
>
(
infos_lhs
[
0
]),
"c(ax0)"
);
EXPECT_EQ
(
std
::
get
<
2
>
(
infos_lhs
[
0
]),
1
);
}
class
UTProvideCheckerTest2
:
public
testing
::
Test
{
public:
UTProvideCheckerTest2
()
:
a_
(
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
16
,
32
,
1024
},
air
::
Float
(
16
))),
b_
(
UTExprBuilder
::
PlaceholderOpNode
(
"b"
,
{
16
,
32
,
1024
},
air
::
Float
(
16
))),
c_
(
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
16
,
32
,
1024
},
air
::
Float
(
16
)))
{}
~
UTProvideCheckerTest2
()
=
default
;
air
::
Operation
a_
;
air
::
Operation
b_
;
air
::
Operation
c_
;
};
// class UTProvideCheckerTest
TEST_F
(
UTProvideCheckerTest2
,
UTProvideCheckerForBinary
)
{
air
::
Stmt
stmt
=
UTStmtBuilder
::
CreateFor
(
"i"
,
0
,
16
,
UTStmtBuilder
::
CreateFor
(
"j"
,
0
,
32
,
UTStmtBuilder
::
CreateFor
(
"k"
,
0
,
1024
,
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
c_
,
{
"i"
,
"j"
,
"k"
},
UTExprBuilder
::
ElementOf
(
a_
,
{
"i"
,
"j"
,
"k"
}),
UTExprBuilder
::
ElementOf
(
b_
,
{
"i"
,
"j"
,
"k"
})))));
std
::
string
dump_stmt
=
UTDumpHelper
::
Dump
(
stmt
);
EXPECT_EQ
(
dump_stmt
,
"for (i, 0, 16) {
\n
"
" for (j, 0, 32) {
\n
"
" for (k, 0, 1024) {
\n
"
" c(i, j, k) = (a(i, j, k) + b(i, j, k))
\n
"
" }
\n
"
" }
\n
"
"}
\n
"
);
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
UTProvideCheckerForBinary
().
Find
(
stmt
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
"a(i, j, k)"
,
"b(i, j, k)"
);
ASSERT_EQ
(
infos_lhs
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
0
>
(
infos_lhs
[
0
]),
"c(i, j, k)"
);
EXPECT_EQ
(
std
::
get
<
2
>
(
infos_lhs
[
0
]),
1024
*
32
*
16
);
}
}
// namespace akg
tests/unittest_cpp/src/base_test/stmt_builder_test.cc
0 → 100644
浏览文件 @
6a84977e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gtest/gtest.h"
#include "base/dump_helper.h"
#include "base/expr_builder.h"
#include "base/stmt_builder.h"
namespace
akg
{
TEST
(
UTStmtBuilder
,
CreateProvideAssign
)
{
// b(ax0) = a(ax0)
air
::
Operation
a
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
1024
},
air
::
Float
(
16
));
air
::
Operation
b
=
UTExprBuilder
::
PlaceholderOpNode
(
"b"
,
{
1024
},
air
::
Float
(
16
));
air
::
Stmt
stmt
=
UTStmtBuilder
::
CreateProvideAssign
(
b
,
{
"ax0"
},
UTExprBuilder
::
ElementOf
(
a
,
{
"ax0"
}));
std
::
string
dump_stmt
=
UTDumpHelper
::
Dump
(
stmt
);
EXPECT_EQ
(
dump_stmt
,
"b(ax0) = a(ax0)
\n
"
);
}
TEST
(
UTStmtBuilder
,
CreateProvideBinary
)
{
// c(ax0) = a(ax0) + b(ax0)
air
::
Operation
a
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
1024
},
air
::
Float
(
16
));
air
::
Operation
b
=
UTExprBuilder
::
PlaceholderOpNode
(
"b"
,
{
1024
},
air
::
Float
(
16
));
air
::
Operation
c
=
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
1024
},
air
::
Float
(
16
));
air
::
Stmt
stmt
=
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
c
,
{
"ax0"
},
UTExprBuilder
::
ElementOf
(
a
,
{
"ax0"
}),
UTExprBuilder
::
ElementOf
(
b
,
{
"ax0"
}));
std
::
string
dump_stmt
=
UTDumpHelper
::
Dump
(
stmt
);
EXPECT_EQ
(
dump_stmt
,
"c(ax0) = (a(ax0) + b(ax0))
\n
"
);
}
TEST
(
UTStmtBuilder
,
CreateFor
)
{
/*
* for (i, 0, 1024) {
* c(i) = (a(i) + b(i))
* }
*/
air
::
Operation
a
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
1024
},
air
::
Float
(
16
));
air
::
Operation
b
=
UTExprBuilder
::
PlaceholderOpNode
(
"b"
,
{
1024
},
air
::
Float
(
16
));
air
::
Operation
c
=
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
1024
},
air
::
Float
(
16
));
air
::
Stmt
stmt_for
=
UTStmtBuilder
::
CreateFor
(
"i"
,
0
,
1024
,
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
c
,
{
"i"
},
UTExprBuilder
::
ElementOf
(
a
,
{
"i"
}),
UTExprBuilder
::
ElementOf
(
b
,
{
"i"
})));
std
::
string
dump_stmt_for
=
UTDumpHelper
::
Dump
(
stmt_for
);
EXPECT_EQ
(
dump_stmt_for
,
"for (i, 0, 1024) {
\n
"
" c(i) = (a(i) + b(i))
\n
"
"}
\n
"
);
}
TEST
(
UTStmtBuilder
,
CreateRealizeByPlaceholderOp
)
{
air
::
Operation
a
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
1024
},
air
::
Float
(
16
));
air
::
Operation
b
=
UTExprBuilder
::
PlaceholderOpNode
(
"b"
,
{
1024
},
air
::
Float
(
16
));
air
::
Operation
c
=
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
1024
},
air
::
Float
(
16
));
air
::
Stmt
stmt_realize
=
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
c
,
UTStmtBuilder
::
CreateFor
(
"i"
,
0
,
1024
,
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
c
,
{
"i"
},
UTExprBuilder
::
ElementOf
(
a
,
{
"i"
}),
UTExprBuilder
::
ElementOf
(
b
,
{
"i"
}))));
std
::
string
dump_stmt_realize
=
UTDumpHelper
::
Dump
(
stmt_realize
);
EXPECT_EQ
(
dump_stmt_realize
,
"realize c<float16>([0, 1024]) {
\n
"
" for (i, 0, 1024) {
\n
"
" c(i) = (a(i) + b(i))
\n
"
" }
\n
"
"}
\n
"
);
}
}
// namespace akg
tests/unittest_cpp/src/pass_test/auto_poly_test.cc
0 → 100644
浏览文件 @
6a84977e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <tvm/ir.h>
#include "base/dump_helper.h"
#include "base/expr_builder.h"
#include "base/ir_checker.h"
#include "base/stmt_builder.h"
#include "pass_test_base/auto_poly_test_base.h"
#define private public
#define protected public
#include "ir_pass.h"
#undef protected
#undef private
#include "codegen/util.h"
#include "contrib/cce_parm/cceconf.h"
namespace
akg
{
/* AutoPolyTest1: test for to_three_address
* Input pattern:
* for (i0, 0, 32) {
* for (i1, 0, 1024) {
* out_0(i1) = b(i1) + c(i1)
* out(i0, i1) = out_0(i1) + a(i0, i1)
* }
* }
*
* Expect output:
* for (cc1, 0, 2) {
* for (cc2, 0, 16) {
* for (cc3, 0, 1024) {
* out_0_local_UB(cc2) = (b_local_UB(cc2) + c_local_UB(cc2))
* out_local_UB(cc3, cc2) = (out_0_local_UB(cc2) + a_local_UB(cc3, cc2))
* }
* }
* }
*
* IR Check:
* count for (b_local_UB + c_local_UB): 32 * 1024
*/
class
AutoPolyTest1
:
public
AutoPolyTestBase
{
public:
AutoPolyTest1
()
{
Construct
();
}
~
AutoPolyTest1
()
=
default
;
void
Construct
()
{
a_
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
32
,
1024
},
air
::
Float
(
16
));
b_
=
UTExprBuilder
::
PlaceholderOpNode
(
"b"
,
{
1024
},
air
::
Float
(
16
));
c_
=
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
1024
},
air
::
Float
(
16
));
out_
=
UTExprBuilder
::
PlaceholderOpNode
(
"out"
,
{
32
,
1024
},
air
::
Float
(
16
));
out_0_
=
UTExprBuilder
::
PlaceholderOpNode
(
"out_0"
,
{
1024
},
air
::
Float
(
16
));
stmt_
=
air
::
ir
::
AttrStmt
::
make
(
out_0_
,
"realize_scope"
,
air
::
ir
::
StringImm
::
make
(
""
),
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
out_0_
,
air
::
ir
::
AttrStmt
::
make
(
out_
,
"realize_scope"
,
air
::
ir
::
StringImm
::
make
(
""
),
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
out_
,
air
::
ir
::
ProducerConsumer
::
make
(
out_
,
true
,
UTStmtBuilder
::
CreateFor
(
"i0"
,
0
,
32
,
UTStmtBuilder
::
CreateFor
(
"i1"
,
0
,
1024
,
air
::
ir
::
Block
::
make
(
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
out_0_
,
{
"i1"
},
UTExprBuilder
::
ElementOf
(
b_
,
{
"i1"
}),
UTExprBuilder
::
ElementOf
(
c_
,
{
"i1"
})),
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
out_
,
{
"i0"
,
"i1"
},
UTExprBuilder
::
ElementOf
(
out_0_
,
{
"i1"
}),
UTExprBuilder
::
ElementOf
(
a_
,
{
"i0"
,
"i1"
}))))))))));
t_a_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
a_
);
t_b_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
b_
);
t_c_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
c_
);
t_out_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
out_
);
RegisterTensor
(
t_a_
);
RegisterTensor
(
t_b_
);
RegisterTensor
(
t_c_
);
RegisterTensor
(
t_out_
);
}
air
::
Operation
a_
;
air
::
Operation
b_
;
air
::
Operation
c_
;
air
::
Tensor
t_a_
;
air
::
Tensor
t_b_
;
air
::
Tensor
t_c_
;
air
::
Operation
out_
;
air
::
Tensor
t_out_
;
air
::
Operation
out_0_
;
air
::
Stmt
stmt_
;
};
// class AutoPolyTest1
TEST_F
(
AutoPolyTest1
,
RunPass
)
{
SetRunMode
(
"cloud"
);
air
::
Array
<
air
::
NodeRef
>
stmts_out
=
ir
::
AutoPoly
(
stmt_
,
binds_
,
global_attrs_
,
false
,
false
);
ASSERT_EQ
(
stmts_out
.
size
(),
2
);
air
::
NodeRef
stmt
=
stmts_out
[
0
];
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
UTProvideCheckerForBinary
(
true
).
Find
(
stmt
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
"b_local_UB"
,
"c_local_UB"
);
ASSERT_EQ
(
infos_lhs
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
2
>
(
infos_lhs
[
0
]),
2
*
16
*
1024
);
}
/* AutoPolyTest2: test for to_three_address
* Input pattern:
* for (i1, 0, 32) {
* out_0(i1) = b(i1) + c(i1)
* for (i0, 0, 1024) {
* out(i0, i1) = out_0(i1) + a(i0, i1)
* }
* }
*
* Expect output:
* for (cc1, 0, 2) {
* for (cc2, 0, 1024) {
* out_0_local_UB(cc2) = (b_local_UB(cc2) + c_local_UB(cc2))
* }
* for (cc2, 0, 1024) {
* for (cc3, 0, 16) {
* out_local_UB(cc3, cc2) = (out_0_local_UB(cc2) + a_local_UB(cc3, cc2))
* }
* }
* }
*
* IR Check:
* count for (b_local_UB + c_local_UB): 2 * 1024
*/
class
AutoPolyTest2
:
public
AutoPolyTestBase
{
public:
AutoPolyTest2
()
{
Construct
();
}
~
AutoPolyTest2
()
=
default
;
void
Construct
()
{
a_
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
32
,
1024
},
air
::
Float
(
16
));
b_
=
UTExprBuilder
::
PlaceholderOpNode
(
"b"
,
{
1024
},
air
::
Float
(
16
));
c_
=
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
1024
},
air
::
Float
(
16
));
out_
=
UTExprBuilder
::
PlaceholderOpNode
(
"out"
,
{
32
,
1024
},
air
::
Float
(
16
));
out_0_
=
UTExprBuilder
::
PlaceholderOpNode
(
"out_0"
,
{
1024
},
air
::
Float
(
16
));
stmt_
=
air
::
ir
::
AttrStmt
::
make
(
out_0_
,
"realize_scope"
,
air
::
ir
::
StringImm
::
make
(
""
),
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
out_0_
,
air
::
ir
::
AttrStmt
::
make
(
out_
,
"realize_scope"
,
air
::
ir
::
StringImm
::
make
(
""
),
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
out_
,
air
::
ir
::
ProducerConsumer
::
make
(
out_
,
true
,
UTStmtBuilder
::
CreateFor
(
"i1"
,
0
,
1024
,
air
::
ir
::
Block
::
make
(
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
out_0_
,
{
"i1"
},
UTExprBuilder
::
ElementOf
(
b_
,
{
"i1"
}),
UTExprBuilder
::
ElementOf
(
c_
,
{
"i1"
})),
UTStmtBuilder
::
CreateFor
(
"i0"
,
0
,
32
,
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
out_
,
{
"i0"
,
"i1"
},
UTExprBuilder
::
ElementOf
(
out_0_
,
{
"i1"
}),
UTExprBuilder
::
ElementOf
(
a_
,
{
"i0"
,
"i1"
}))))))))));
t_a_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
a_
);
t_b_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
b_
);
t_c_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
c_
);
t_out_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
out_
);
RegisterTensor
(
t_a_
);
RegisterTensor
(
t_b_
);
RegisterTensor
(
t_c_
);
RegisterTensor
(
t_out_
);
}
air
::
Operation
a_
;
air
::
Operation
b_
;
air
::
Operation
c_
;
air
::
Tensor
t_a_
;
air
::
Tensor
t_b_
;
air
::
Tensor
t_c_
;
air
::
Operation
out_
;
air
::
Tensor
t_out_
;
air
::
Operation
out_0_
;
air
::
Stmt
stmt_
;
};
// class AutoPolyTest2
TEST_F
(
AutoPolyTest2
,
RunPass
)
{
SetRunMode
(
"cloud"
);
air
::
Array
<
air
::
NodeRef
>
stmts_out
=
ir
::
AutoPoly
(
stmt_
,
binds_
,
global_attrs_
,
false
,
false
);
ASSERT_EQ
(
stmts_out
.
size
(),
2
);
air
::
NodeRef
stmt
=
stmts_out
[
0
];
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
UTProvideCheckerForBinary
(
true
).
Find
(
stmt
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
"b_local_UB"
,
"c_local_UB"
);
ASSERT_EQ
(
infos_lhs
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
2
>
(
infos_lhs
[
0
]),
2
*
1024
);
}
/* AutoPolyTest3: test for to_three_address
* Input pattern:
* for (i0, 0, 1024) {
* out_0(i0) = b(i0) + c(i0)
* }
* for (i1, 0, 32) {
* for (i0, 0, 1024) {
* out(i0, i1) = out_0(i1) + a(i0, i1)
* }
* }
*
* Expect output:
* for (cc1, 0, 2) {
* for (cc2, 0, 1024) {
* out_0_local_UB(cc2) = (b_local_UB(cc2) + c_local_UB(cc2))
* }
* for (cc2, 0, 1024) {
* for (cc3, 0, 16) {
* out_local_UB(cc3, cc2) = (out_0_local_UB(cc2) + a_local_UB(cc3, cc2))
* }
* }
* }
*
* IR Check:
* count for (b_local_UB + c_local_UB): 2 * 1024
*/
class
AutoPolyTest3
:
public
AutoPolyTestBase
{
public:
AutoPolyTest3
()
{
Construct
();
}
~
AutoPolyTest3
()
=
default
;
void
Construct
()
{
a_
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
32
,
1024
},
air
::
Float
(
16
));
b_
=
UTExprBuilder
::
PlaceholderOpNode
(
"b"
,
{
1024
},
air
::
Float
(
16
));
c_
=
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
1024
},
air
::
Float
(
16
));
out_
=
UTExprBuilder
::
PlaceholderOpNode
(
"out"
,
{
32
,
1024
},
air
::
Float
(
16
));
out_0_
=
UTExprBuilder
::
PlaceholderOpNode
(
"out_0"
,
{
1024
},
air
::
Float
(
16
));
stmt_
=
air
::
ir
::
AttrStmt
::
make
(
out_0_
,
"realize_scope"
,
air
::
ir
::
StringImm
::
make
(
""
),
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
out_0_
,
air
::
ir
::
AttrStmt
::
make
(
out_
,
"realize_scope"
,
air
::
ir
::
StringImm
::
make
(
""
),
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
out_
,
air
::
ir
::
ProducerConsumer
::
make
(
out_
,
true
,
air
::
ir
::
Block
::
make
(
UTStmtBuilder
::
CreateFor
(
"i0"
,
0
,
1024
,
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
out_0_
,
{
"i0"
},
UTExprBuilder
::
ElementOf
(
b_
,
{
"i0"
}),
UTExprBuilder
::
ElementOf
(
c_
,
{
"i0"
}))),
UTStmtBuilder
::
CreateFor
(
"i0"
,
0
,
32
,
UTStmtBuilder
::
CreateFor
(
"i1"
,
0
,
1024
,
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
out_
,
{
"i0"
,
"i1"
},
UTExprBuilder
::
ElementOf
(
out_0_
,
{
"i1"
}),
UTExprBuilder
::
ElementOf
(
a_
,
{
"i0"
,
"i1"
}))))))))));
t_a_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
a_
);
t_b_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
b_
);
t_c_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
c_
);
t_out_
=
UTExprBuilder
::
CreateTensorByPlaceholder
(
out_
);
RegisterTensor
(
t_a_
);
RegisterTensor
(
t_b_
);
RegisterTensor
(
t_c_
);
RegisterTensor
(
t_out_
);
}
air
::
Operation
a_
;
air
::
Operation
b_
;
air
::
Operation
c_
;
air
::
Tensor
t_a_
;
air
::
Tensor
t_b_
;
air
::
Tensor
t_c_
;
air
::
Operation
out_
;
air
::
Tensor
t_out_
;
air
::
Operation
out_0_
;
air
::
Stmt
stmt_
;
};
// class AutoPolyTest3
TEST_F
(
AutoPolyTest3
,
RunPass
)
{
SetRunMode
(
"cloud"
);
air
::
Array
<
air
::
NodeRef
>
stmts_out
=
ir
::
AutoPoly
(
stmt_
,
binds_
,
global_attrs_
,
false
,
false
);
ASSERT_EQ
(
stmts_out
.
size
(),
2
);
air
::
NodeRef
stmt
=
stmts_out
[
0
];
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
UTProvideCheckerForBinary
(
true
).
Find
(
stmt
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
"b_local_UB"
,
"c_local_UB"
);
ASSERT_EQ
(
infos_lhs
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
2
>
(
infos_lhs
[
0
]),
2
*
1024
);
}
}
// namespace akg
tests/unittest_cpp/src/pass_test/to_three_address_test.cc
浏览文件 @
6a84977e
...
@@ -15,8 +15,10 @@
...
@@ -15,8 +15,10 @@
*/
*/
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <tvm/ir.h>
#include <tvm/ir.h>
#include "base/expr_builder.h"
#include "base/dump_helper.h"
#include "base/dump_helper.h"
#include "base/expr_builder.h"
#include "base/ir_checker.h"
#include "base/stmt_builder.h"
#define private public
#define private public
#define protected public
#define protected public
#include "pass/to_three_address.cc"
#include "pass/to_three_address.cc"
...
@@ -71,4 +73,74 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) {
...
@@ -71,4 +73,74 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) {
Expr
expr_m
=
mutator_
.
Mutate
(
expr
);
Expr
expr_m
=
mutator_
.
Mutate
(
expr
);
EXPECT_NE
(
mutator_
.
imm_ops
.
size
(),
0
);
EXPECT_NE
(
mutator_
.
imm_ops
.
size
(),
0
);
}
}
class
PassTestToThreeAddress1
:
public
::
testing
::
Test
{
public:
PassTestToThreeAddress1
()
{
Construct
();
}
~
PassTestToThreeAddress1
()
=
default
;
void
Construct
()
{
a_
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
1024
},
air
::
Float
(
16
));
b_
=
UTExprBuilder
::
PlaceholderOpNode
(
"b"
,
{
32
,
1024
},
air
::
Float
(
16
));
c_
=
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
1024
},
air
::
Float
(
16
));
out_
=
UTExprBuilder
::
PlaceholderOpNode
(
"out"
,
{
32
,
1024
},
air
::
Float
(
16
));
stmt
=
air
::
ir
::
AttrStmt
::
make
(
out_
,
""
,
UTExprBuilder
::
IntImm
(
1
),
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
out_
,
air
::
ir
::
ProducerConsumer
::
make
(
out_
,
true
,
UTStmtBuilder
::
CreateFor
(
"i"
,
0
,
32
,
UTStmtBuilder
::
CreateFor
(
"j"
,
0
,
1024
,
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
out_
,
{
"i"
,
"j"
},
air
::
ir
::
Add
::
make
(
UTExprBuilder
::
ElementOf
(
a_
,
{
"j"
}),
UTExprBuilder
::
ElementOf
(
b_
,
{
"i"
,
"j"
})),
UTExprBuilder
::
ElementOf
(
c_
,
{
"j"
})))))));
}
air
::
Operation
a_
;
air
::
Operation
b_
;
air
::
Operation
c_
;
air
::
Operation
out_
;
air
::
Stmt
stmt
;
};
// class PassTestToThreeAddress1
TEST_F
(
PassTestToThreeAddress1
,
CaseCheck
)
{
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
UTProvideCheckerForAssign
().
Find
(
stmt
,
"((a(j) + b(i, j)) + c(j))"
);
ASSERT_EQ
(
infos_lhs
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
0
>
(
infos_lhs
[
0
]),
"out(i, j)"
);
EXPECT_EQ
(
std
::
get
<
2
>
(
infos_lhs
[
0
]),
32
*
1024
);
}
TEST_F
(
PassTestToThreeAddress1
,
TestPass
)
{
Stmt
stmt_out
=
ir
::
ToThreeAddress
(
stmt
,
false
,
0
,
true
);
/* current implementation
* out_2(i, j) = b(i, j)
* out_3(i, j) = (a(j) + out_2(i, j))
* out(i, j) = (out_3(i, j) + c(j))
*/
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info1
=
UTProvideCheckerForAssign
().
Find
(
stmt_out
,
"b(i, j)"
);
ASSERT_EQ
(
info1
.
size
(),
1
);
std
::
string
dump_b_target
=
std
::
get
<
0
>
(
info1
[
0
]);
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info2
=
UTProvideCheckerForBinary
().
Find
(
stmt_out
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
"a(j)"
,
dump_b_target
);
ASSERT_EQ
(
info2
.
size
(),
1
);
std
::
string
dump_sum1_target
=
std
::
get
<
0
>
(
info2
[
0
]);
EXPECT_EQ
(
std
::
get
<
2
>
(
info2
[
0
]),
32
*
1024
);
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info3
=
UTProvideCheckerForBinary
().
Find
(
stmt_out
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
dump_sum1_target
,
"c(j)"
);
ASSERT_EQ
(
info3
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
0
>
(
info3
[
0
]),
"out(i, j)"
);
EXPECT_EQ
(
std
::
get
<
2
>
(
info3
[
0
]),
32
*
1024
);
}
}
// namespace akg
}
// namespace akg
tests/unittest_cpp/src/pass_test_base/auto_poly_test_base.cc
0 → 100644
浏览文件 @
6a84977e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pass_test_base/auto_poly_test_base.h"
namespace
akg
{
std
::
map
<
std
::
string
,
std
::
string
>
AutoPolyTestBase
::
map_mode_
=
AutoPolyTestBase
::
InitMapMode
();
std
::
map
<
std
::
string
,
std
::
string
>
AutoPolyTestBase
::
InitMapMode
()
{
std
::
map
<
std
::
string
,
std
::
string
>
res
;
res
[
"cloud"
]
=
"1.6"
;
res
[
"mini"
]
=
"1.1"
;
res
[
"phoenix"
]
=
"3.5"
;
res
[
"orlando"
]
=
"3.3"
;
return
res
;
}
void
AutoPolyTestBase
::
SetRunMode
(
const
std
::
string
&
mode
)
{
auto
it
=
map_mode_
.
find
(
mode
);
CHECK
(
it
!=
map_mode_
.
end
());
cceconf
::
CceConf
::
getInstance
()
->
setSection
(
it
->
second
);
}
void
AutoPolyTestBase
::
RegisterTensor
(
const
air
::
Tensor
&
tensor
)
{
const
TensorNode
*
tensor_node
=
tensor
.
as
<
TensorNode
>
();
std
::
string
name
=
tensor_node
->
op
->
name
;
air
::
Buffer
buf
=
air
::
BufferNode
::
make
(
air
::
Variable
::
make
(
Handle
(),
name
),
tensor_node
->
dtype
,
tensor_node
->
shape
,
Array
<
Expr
>
(),
Expr
(),
name
,
""
,
-
1
,
0
,
air
::
BufferType
::
kDefault
);
binds_
.
Set
(
GetRef
<
Tensor
>
(
tensor_node
),
buf
);
}
}
// namespace akg
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录