Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
99ab5932
A
akg
项目概览
MindSpore
/
akg
通知
58
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
99ab5932
编写于
7月 01, 2020
作者:
W
wYann
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
a new pass to split tail block
上级
be2f98cf
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
199 addition
and
0 deletion
+199
-0
src/api/api_pass.cc
src/api/api_pass.cc
+1
-0
src/codegen/build_module.cc
src/codegen/build_module.cc
+3
-0
src/emit_insn/insn_info.h
src/emit_insn/insn_info.h
+10
-0
src/include/ir_pass.h
src/include/ir_pass.h
+1
-0
src/pass/split_tail_block.cc
src/pass/split_tail_block.cc
+184
-0
未找到文件。
src/api/api_pass.cc
浏览文件 @
99ab5932
...
...
@@ -182,5 +182,6 @@ REGISTER_PASS(SinkAllocate);
REGISTER_PASS
(
SubstituteDivVar
);
REGISTER_PASS
(
CastFilter
);
REGISTER_PASS
(
ScalarComputeRewrite
);
REGISTER_PASS
(
SplitTail
);
}
// namespace ir
}
// namespace akg
src/codegen/build_module.cc
浏览文件 @
99ab5932
...
...
@@ -759,6 +759,9 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
stmt
=
NEXT_PASS
(
GatherLoopInfo
,
stmt
);
}
stmt
=
NEXT_PASS
(
CastFilter
,
stmt
);
if
(
!
is_dynamic
)
{
stmt
=
NEXT_PASS
(
SplitTail
,
stmt
);
}
stmt
=
NEXT_PASS
(
EmitInsn
,
stmt
,
global_attrs
.
GetBoolAttr
(
kEnableBisectOptimize
,
true
),
global_attrs
.
GetBoolAttr
(
kEnableCoverProtectOptimize
,
true
),
binds_0
,
is_dynamic
);
// must be after EmitInsn
...
...
src/emit_insn/insn_info.h
浏览文件 @
99ab5932
...
...
@@ -294,6 +294,16 @@ struct BisectionInfoWrapper {
Map
<
std
::
string
,
Expr
>
dma_arg_info_map_
;
};
struct
InsnAxis
{
int
min
{
0
};
int
extent
{
0
};
Var
var
;
int
dst_stride
{
0
};
int
src_stride
{
0
};
std
::
list
<
int
>
src_stride_list
;
std
::
list
<
int
>
stride_list
;
};
IterVar
GetCceAxis
();
int
CeilTo
(
int
value
,
int
target
);
...
...
src/include/ir_pass.h
浏览文件 @
99ab5932
...
...
@@ -335,6 +335,7 @@ Stmt ValueNumbering(Stmt stmt);
Stmt
MultiLastAxisReductions
(
Stmt
stmt
,
bool
is_dynamic
);
Stmt
AutoReorder
(
Stmt
stmt
);
Stmt
SplitTail
(
Stmt
stmt
);
Stmt
CopyPropagation
(
Stmt
stmt
,
const
Map
<
Tensor
,
Buffer
>
&
extern_buffer
);
...
...
src/pass/split_tail_block.cc
0 → 100644
浏览文件 @
99ab5932
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <ir_pass.h>
#include <algorithm>
#include "emit_insn/insn_info.h"
#include "emit_insn/insn_pattern.h"
namespace
akg
{
namespace
ir
{
class
TailSpliter
:
public
IRMutator
{
public:
TailSpliter
()
=
default
;
~
TailSpliter
()
override
=
default
;
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
"pragma_emit_insn"
)
{
auto
intrin_name
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
if
(
include_intrin_list_
.
find
(
intrin_name
)
==
include_intrin_list_
.
end
())
{
return
s
;
}
StmtInfoList
dst_info_list
;
StmtInfoList
src_info_list
;
StmtInfo
if_info
;
StmtInfo
for_info
;
GetCompactComputationInfo
(
op
->
body
,
dst_info_list
,
src_info_list
,
if_info
,
for_info
,
false
);
CHECK
(
!
dst_info_list
.
empty
());
auto
dst_info
=
dst_info_list
[
0
];
if
(
src_info_list
.
empty
())
{
src_info_list
=
{
dst_info
.
Copy
()};
}
auto
get_info_list
=
[](
const
StmtStoreInfo
&
dst_info
,
const
Array
<
StmtStoreInfo
>
&
src_info_list
)
{
Array
<
StmtStoreInfo
>
res
;
res
.
push_back
(
dst_info
.
Copy
());
for
(
auto
it
:
src_info_list
)
{
res
.
push_back
(
it
.
Copy
());
}
return
res
;
};
auto
info_list
=
get_info_list
(
dst_info
,
src_info_list
);
FillEmptyVar
(
info_list
);
auto
axis_list
=
GetAixsList
(
for_info
,
info_list
);
auto
get_last_axis_it
=
[](
const
std
::
list
<
InsnAxis
>
&
axis_list
)
{
for
(
auto
it
=
axis_list
.
begin
();
it
!=
axis_list
.
end
();
it
++
)
{
auto
stride_list
=
it
->
stride_list
;
if
(
!
(
std
::
any_of
(
stride_list
.
begin
(),
stride_list
.
end
(),
[](
int
stride
)
{
return
stride
>
1
;
})
||
std
::
all_of
(
stride_list
.
begin
(),
stride_list
.
end
(),
[](
int
stride
)
{
return
stride
==
0
;
})))
{
return
it
;
}
}
return
axis_list
.
end
();
};
auto
last_axis_it
=
get_last_axis_it
(
axis_list
);
if
(
last_axis_it
==
axis_list
.
end
())
{
return
s
;
}
auto
last_axis
=
*
last_axis_it
;
auto
last_axis_shape
=
last_axis
.
extent
;
int
dst_block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
int
src_block_size
=
GetUbBlkSize
(
src_info_list
[
0
]
->
dtype_
);
int
block_size
=
dst_block_size
>
src_block_size
?
dst_block_size
:
src_block_size
;
int
vec_max_len
=
block_size
*
FULL_BLOCK_NUM
;
if
(
last_axis_shape
>
vec_max_len
&&
last_axis_shape
%
vec_max_len
!=
0
)
{
return
Block
::
make
(
TailMake
(
s
,
last_axis
,
vec_max_len
,
false
),
TailMake
(
s
,
last_axis
,
vec_max_len
,
true
));
}
if
(
last_axis_shape
<
vec_max_len
*
tail_rate_
&&
last_axis_shape
>
block_size
&&
last_axis_shape
%
block_size
!=
0
&&
axis_list
.
size
()
>
1
)
{
return
Block
::
make
(
TailMake
(
s
,
last_axis
,
block_size
,
false
),
TailMake
(
s
,
last_axis
,
block_size
,
true
));
}
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
std
::
list
<
InsnAxis
>
GetAixsList
(
const
StmtInfo
&
for_info
,
const
Array
<
StmtStoreInfo
>
&
info_list
)
{
std
::
list
<
InsnAxis
>
axis_list
;
auto
GetStrideByAxis
=
[](
const
Array
<
Var
>
&
vars
,
const
Array
<
Expr
>
&
strides
,
Var
obj_var
)
{
int
index
=
0
;
for
(
auto
var_it
:
vars
)
{
if
(
Equal
(
var_it
,
obj_var
))
{
return
strides
[
index
];
}
index
++
;
}
return
Expr
(
0
);
};
for
(
auto
it
:
for_info
.
ops_
)
{
InsnAxis
axis
;
auto
for_stmt
=
it
.
as
<
For
>
();
CHECK
(
for_stmt
);
axis
.
var
=
for_stmt
->
loop_var
;
axis
.
extent
=
GetInt32Const
(
for_stmt
->
extent
);
axis
.
min
=
GetInt32Const
(
for_stmt
->
min
);
int
index
=
0
;
for
(
auto
it
:
info_list
)
{
auto
stride
=
GetInt32Const
(
GetStrideByAxis
(
it
->
var_
,
it
->
strides_
,
axis
.
var
));
axis
.
stride_list
.
push_back
(
stride
);
if
(
index
==
0
)
{
axis
.
dst_stride
=
stride
;
}
else
{
axis
.
src_stride_list
.
push_back
(
stride
);
}
index
++
;
}
axis_list
.
push_back
(
axis
);
}
return
axis_list
;
}
Stmt
TailMake
(
const
Stmt
&
s
,
const
InsnAxis
&
tail_axis
,
int
body_size
,
bool
is_tail
)
{
if
(
auto
attr_stmt
=
s
.
as
<
AttrStmt
>
())
{
return
AttrStmt
::
make
(
attr_stmt
->
node
,
attr_stmt
->
attr_key
,
attr_stmt
->
value
,
TailMake
(
attr_stmt
->
body
,
tail_axis
,
body_size
,
is_tail
));
}
if
(
auto
for_stmt
=
s
.
as
<
For
>
())
{
if
(
Equal
(
for_stmt
->
loop_var
,
tail_axis
.
var
)
&&
GetIntConst
(
for_stmt
->
extent
)
==
tail_axis
.
extent
)
{
if
(
is_tail
)
{
return
For
::
make
(
for_stmt
->
loop_var
,
for_stmt
->
min
,
Expr
(
tail_axis
.
extent
%
body_size
),
for_stmt
->
for_type
,
for_stmt
->
device_api
,
TailMake
(
for_stmt
->
body
,
tail_axis
,
body_size
,
is_tail
));
}
CHECK_NE
(
body_size
,
0
);
Expr
remain_extent
=
Expr
(
tail_axis
.
extent
/
body_size
*
body_size
);
return
For
::
make
(
for_stmt
->
loop_var
,
for_stmt
->
min
,
remain_extent
,
for_stmt
->
for_type
,
for_stmt
->
device_api
,
TailMake
(
for_stmt
->
body
,
tail_axis
,
body_size
,
is_tail
));
}
return
For
::
make
(
for_stmt
->
loop_var
,
for_stmt
->
min
,
for_stmt
->
extent
,
for_stmt
->
for_type
,
for_stmt
->
device_api
,
TailMake
(
for_stmt
->
body
,
tail_axis
,
body_size
,
is_tail
));
}
if
(
s
.
as
<
Store
>
()
&&
is_tail
)
{
return
substitute
(
tail_axis
.
var
,
Add
::
make
(
Expr
(
tail_axis
.
extent
/
body_size
*
body_size
),
tail_axis
.
var
),
s
);
}
return
s
;
}
private:
const
float
tail_rate_
{
0.6
};
const
std
::
set
<
std
::
string
>
include_intrin_list_
=
{
"vec_single_fabs"
,
"vec_single_log"
,
"vec_single_exp"
,
"vec_single_rec"
,
"vec_single_not"
,
"vec_single_sqrt"
,
"vec_single_rsqrt"
,
"vec_single_relu"
,
"vec_single_not"
,
// vector_scalar
"vec_single_muls"
,
"vec_single_adds"
,
// Mov
"broadcast"
,
// vector_cast
"vec_single_cast"
,
"vec_single_floor"
,
"vec_single_round"
,
"vec_single_ceil"
,
"vec_single_trunc"
,
};
};
Stmt
SplitTail
(
Stmt
stmt
)
{
return
TailSpliter
().
Mutate
(
stmt
);
}
}
// namespace ir
}
// namespace akg
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录