Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
05313741
A
akg
项目概览
MindSpore
/
akg
通知
58
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
05313741
编写于
8月 06, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 06, 2020
浏览文件
操作
浏览文件
下载
差异文件
!102 Add Loop Mutator & Relu to pass three address
Merge pull request !102 from ConnZhai/conn
上级
bf3a6c3a
05a96ace
变更
2
展开全部
隐藏空白更改
内联
并排
Showing
2 changed file
with
477 addition
and
309 deletion
+477
-309
src/pass/to_three_address.cc
src/pass/to_three_address.cc
+446
-267
tests/unittest_cpp/src/pass_test/to_three_address_test.cc
tests/unittest_cpp/src/pass_test/to_three_address_test.cc
+31
-42
未找到文件。
src/pass/to_three_address.cc
浏览文件 @
05313741
此差异已折叠。
点击以展开。
tests/unittest_cpp/src/pass_test/to_three_address_test.cc
浏览文件 @
05313741
...
@@ -36,12 +36,7 @@ TEST(ToThreeAddressTest, BuildCase1) {
...
@@ -36,12 +36,7 @@ TEST(ToThreeAddressTest, BuildCase1) {
UTTensorElementHelper
th
({
16
,
32
,
1024
});
UTTensorElementHelper
th
({
16
,
32
,
1024
});
using
Add
=
air
::
ir
::
Add
;
using
Add
=
air
::
ir
::
Add
;
// a(ax1, ax2) + b(ax2) + c(ax0, ax1, ax2) + d(ax2)
// a(ax1, ax2) + b(ax2) + c(ax0, ax1, ax2) + d(ax2)
air
::
Expr
expr
=
air
::
Expr
expr
=
Add
::
make
(
Add
::
make
(
Add
::
make
(
th
.
Elem
(
"a"
,
2
),
th
.
Elem
(
"b"
,
1
)),
th
.
Elem
(
"c"
,
3
)),
th
.
Elem
(
"d"
,
1
));
Add
::
make
(
Add
::
make
(
Add
::
make
(
th
.
Elem
(
"a"
,
2
),
th
.
Elem
(
"b"
,
1
)),
th
.
Elem
(
"c"
,
3
)),
th
.
Elem
(
"d"
,
1
));
std
::
string
dump_expr
=
UTDumpHelper
::
Dump
(
expr
);
std
::
string
dump_expr
=
UTDumpHelper
::
Dump
(
expr
);
EXPECT_EQ
(
dump_expr
,
"(((a(ax1, ax2) + b(ax2)) + c(ax0, ax1, ax2)) + d(ax2))"
);
EXPECT_EQ
(
dump_expr
,
"(((a(ax1, ax2) + b(ax2)) + c(ax0, ax1, ax2)) + d(ax2))"
);
}
}
...
@@ -49,16 +44,16 @@ TEST(ToThreeAddressTest, BuildCase1) {
...
@@ -49,16 +44,16 @@ TEST(ToThreeAddressTest, BuildCase1) {
class
ThreeAddressExprMutatorTest
:
public
testing
::
Test
{
class
ThreeAddressExprMutatorTest
:
public
testing
::
Test
{
public:
public:
ThreeAddressExprMutatorTest
()
ThreeAddressExprMutatorTest
()
:
mutator_
(
air
::
TensorNode
::
make
(
:
mutator_
(
air
::
TensorNode
::
make
(
UTExprBuilder
::
CreateShape
(
shape_
),
// shape
UTExprBuilder
::
CreateShape
(
shape_
),
// sha
pe
dtype_
,
// dty
pe
dtype_
,
// dtype
UTExprBuilder
::
PlaceholderOpNode
(
"out"
,
shape_
),
// op
UTExprBuilder
::
PlaceholderOpNode
(
"out"
,
shape_
),
// op
0
),
// index
0
),
// index
UTExprBuilder
::
CreateVars
({
"ax0"
,
"ax1"
,
"ax2"
}),
// args
UTExprBuilder
::
CreateVars
({
"ax0"
,
"ax1"
,
"ax2"
}),
// args
UTExprBuilder
::
CreateVars
({
"ax0"
,
"ax1"
,
"ax2"
}),
// args
UTExprBuilder
::
CreateShape
(
shape_
),
// shape
UTExprBuilder
::
CreateShape
(
shape_
),
// shape
std
::
unordered_set
<
const
Call
*>
(),
// broadcast
std
::
unordered_set
<
const
Call
*>
(),
// broadcast
false
,
// IsReductionOp
false
,
// IsReductionOp
false
)
{}
// cross_stmt_simplify
false
)
{}
// cross_stmt_simplify
~
ThreeAddressExprMutatorTest
()
=
default
;
~
ThreeAddressExprMutatorTest
()
=
default
;
std
::
vector
<
int32_t
>
shape_
=
{
16
,
32
,
1024
};
std
::
vector
<
int32_t
>
shape_
=
{
16
,
32
,
1024
};
...
@@ -75,10 +70,8 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) {
...
@@ -75,10 +70,8 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) {
}
}
class
PassTestToThreeAddress1
:
public
::
testing
::
Test
{
class
PassTestToThreeAddress1
:
public
::
testing
::
Test
{
public:
public:
PassTestToThreeAddress1
()
{
PassTestToThreeAddress1
()
{
Construct
();
}
Construct
();
}
~
PassTestToThreeAddress1
()
=
default
;
~
PassTestToThreeAddress1
()
=
default
;
void
Construct
()
{
void
Construct
()
{
a_
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
1024
},
air
::
Float
(
16
));
a_
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
1024
},
air
::
Float
(
16
));
...
@@ -86,20 +79,18 @@ class PassTestToThreeAddress1 : public ::testing::Test {
...
@@ -86,20 +79,18 @@ class PassTestToThreeAddress1 : public ::testing::Test {
c_
=
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
1024
},
air
::
Float
(
16
));
c_
=
UTExprBuilder
::
PlaceholderOpNode
(
"c"
,
{
1024
},
air
::
Float
(
16
));
out_
=
UTExprBuilder
::
PlaceholderOpNode
(
"out"
,
{
32
,
1024
},
air
::
Float
(
16
));
out_
=
UTExprBuilder
::
PlaceholderOpNode
(
"out"
,
{
32
,
1024
},
air
::
Float
(
16
));
stmt
=
air
::
ir
::
AttrStmt
::
make
(
stmt
=
air
::
ir
::
AttrStmt
::
make
(
out_
,
""
,
UTExprBuilder
::
IntImm
(
1
),
out_
,
""
,
UTExprBuilder
::
IntImm
(
1
),
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
out_
,
out_
,
air
::
ir
::
ProducerConsumer
::
make
(
air
::
ir
::
ProducerConsumer
::
make
(
out_
,
true
,
out_
,
true
,
UTStmtBuilder
::
CreateFor
(
"i"
,
0
,
32
,
UTStmtBuilder
::
CreateFor
(
UTStmtBuilder
::
CreateFor
(
"j"
,
0
,
1024
,
"i"
,
0
,
32
,
UTStmtBuilder
::
CreateFor
(
"j"
,
0
,
1024
,
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
out_
,
{
"i"
,
"j"
},
out_
,
{
"i"
,
"j"
},
air
::
ir
::
Add
::
make
(
air
::
ir
::
Add
::
make
(
UTExprBuilder
::
ElementOf
(
a_
,
{
"j"
}),
UTExprBuilder
::
ElementOf
(
b_
,
{
"i"
,
"j"
})),
UTExprBuilder
::
ElementOf
(
a_
,
{
"j"
}),
UTExprBuilder
::
ElementOf
(
c_
,
{
"j"
})))))));
UTExprBuilder
::
ElementOf
(
b_
,
{
"i"
,
"j"
})),
UTExprBuilder
::
ElementOf
(
c_
,
{
"j"
})))))));
}
}
air
::
Operation
a_
;
air
::
Operation
a_
;
...
@@ -110,8 +101,8 @@ class PassTestToThreeAddress1 : public ::testing::Test {
...
@@ -110,8 +101,8 @@ class PassTestToThreeAddress1 : public ::testing::Test {
};
// class PassTestToThreeAddress1
};
// class PassTestToThreeAddress1
TEST_F
(
PassTestToThreeAddress1
,
CaseCheck
)
{
TEST_F
(
PassTestToThreeAddress1
,
CaseCheck
)
{
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
UTProvideCheckerForAssign
().
Find
(
stmt
,
"((a(j) + b(i, j)) + c(j))"
);
UTProvideCheckerForAssign
().
Find
(
stmt
,
"((a(j) + b(i, j)) + c(j))"
);
ASSERT_EQ
(
infos_lhs
.
size
(),
1
);
ASSERT_EQ
(
infos_lhs
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
0
>
(
infos_lhs
[
0
]),
"out(i, j)"
);
EXPECT_EQ
(
std
::
get
<
0
>
(
infos_lhs
[
0
]),
"out(i, j)"
);
EXPECT_EQ
(
std
::
get
<
2
>
(
infos_lhs
[
0
]),
32
*
1024
);
EXPECT_EQ
(
std
::
get
<
2
>
(
infos_lhs
[
0
]),
32
*
1024
);
...
@@ -124,21 +115,19 @@ TEST_F(PassTestToThreeAddress1, TestPass) {
...
@@ -124,21 +115,19 @@ TEST_F(PassTestToThreeAddress1, TestPass) {
* out_3(i, j) = (a(j) + out_2(i, j))
* out_3(i, j) = (a(j) + out_2(i, j))
* out(i, j) = (out_3(i, j) + c(j))
* out(i, j) = (out_3(i, j) + c(j))
*/
*/
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info1
=
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info1
=
UTProvideCheckerForAssign
().
Find
(
stmt_out
,
"b(i, j)"
);
UTProvideCheckerForAssign
().
Find
(
stmt_out
,
"b(i, j)"
);
ASSERT_EQ
(
info1
.
size
(),
1
);
ASSERT_EQ
(
info1
.
size
(),
1
);
std
::
string
dump_b_target
=
std
::
get
<
0
>
(
info1
[
0
]);
std
::
string
dump_b_target
=
std
::
get
<
0
>
(
info1
[
0
]);
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info2
=
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info2
=
UTProvideCheckerForBinary
().
Find
(
UTProvideCheckerForBinary
().
Find
(
stmt_out
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
"a(j)"
,
dump_b_target
);
stmt_out
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
"a(j)"
,
dump_b_target
);
ASSERT_EQ
(
info2
.
size
(),
1
);
ASSERT_EQ
(
info2
.
size
(),
1
);
std
::
string
dump_sum1_target
=
std
::
get
<
0
>
(
info2
[
0
]);
std
::
string
dump_sum1_target
=
std
::
get
<
0
>
(
info2
[
0
]);
EXPECT_EQ
(
std
::
get
<
2
>
(
info2
[
0
]),
32
*
1024
);
EXPECT_EQ
(
std
::
get
<
2
>
(
info2
[
0
]),
32
*
1024
);
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info3
=
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info3
=
UTProvideCheckerForBinary
().
Find
(
UTProvideCheckerForBinary
().
Find
(
stmt_out
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
dump_sum1_target
,
"c(j)"
);
stmt_out
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
dump_sum1_target
,
"c(j)"
);
ASSERT_EQ
(
info3
.
size
(),
1
);
ASSERT_EQ
(
info3
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
0
>
(
info3
[
0
]),
"out(i, j)"
);
EXPECT_EQ
(
std
::
get
<
0
>
(
info3
[
0
]),
"out(i, j)"
);
EXPECT_EQ
(
std
::
get
<
2
>
(
info3
[
0
]),
32
*
1024
);
EXPECT_EQ
(
std
::
get
<
2
>
(
info3
[
0
]),
32
*
1024
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录