Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
59f460e7
A
akg
项目概览
MindSpore
/
akg
通知
58
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
59f460e7
编写于
7月 13, 2020
作者:
C
cy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix alignment and pragma
上级
0d9d3012
变更
23
展开全部
隐藏空白更改
内联
并排
Showing
23 changed file
with
1808 addition
and
1453 deletion
+1808
-1453
src/api/api_pass.cc
src/api/api_pass.cc
+3
-0
src/codegen/build_module.cc
src/codegen/build_module.cc
+8
-5
src/emit_insn/insn_builder.h
src/emit_insn/insn_builder.h
+11
-3
src/emit_insn/insn_builder_vector.cc
src/emit_insn/insn_builder_vector.cc
+19
-62
src/emit_insn/insn_emitter.cc
src/emit_insn/insn_emitter.cc
+53
-309
src/emit_insn/insn_emitter.h
src/emit_insn/insn_emitter.h
+94
-0
src/emit_insn/insn_info.cc
src/emit_insn/insn_info.cc
+39
-4
src/emit_insn/insn_info.h
src/emit_insn/insn_info.h
+9
-25
src/emit_insn/insn_pattern.h
src/emit_insn/insn_pattern.h
+1
-1
src/emit_insn/insn_with_variable.cc
src/emit_insn/insn_with_variable.cc
+12
-358
src/emit_insn/insn_with_variable.h
src/emit_insn/insn_with_variable.h
+1
-2
src/emit_insn/ir_transform.h
src/emit_insn/ir_transform.h
+481
-0
src/include/ir_pass.h
src/include/ir_pass.h
+4
-0
src/pass/analyze_align.h
src/pass/analyze_align.h
+752
-2
src/pass/analyze_align_dynamic.cc
src/pass/analyze_align_dynamic.cc
+1
-1
src/pass/analyze_align_static.cc
src/pass/analyze_align_static.cc
+34
-667
src/pass/merge_loops.cc
src/pass/merge_loops.cc
+1
-1
src/pass/multi_last_axis_reduction.cc
src/pass/multi_last_axis_reduction.cc
+2
-1
src/pass/optimize_pragma.cc
src/pass/optimize_pragma.cc
+1
-1
src/pass/rewrite_by_align_dynamic.cc
src/pass/rewrite_by_align_dynamic.cc
+2
-2
src/pass/rewrite_by_align_static.cc
src/pass/rewrite_by_align_static.cc
+5
-9
src/pass/store_pack.cc
src/pass/store_pack.cc
+52
-0
src/pass/store_recover.cc
src/pass/store_recover.cc
+223
-0
未找到文件。
src/api/api_pass.cc
浏览文件 @
59f460e7
...
...
@@ -115,6 +115,9 @@ REGISTER_PASS(AnalyzeMinAlignStatic);
REGISTER_PASS
(
AnalyzeMinAlignDynamic
);
REGISTER_PASS
(
RewriteBroadcastVector
);
REGISTER_PASS
(
OptimizePragma
);
REGISTER_PASS
(
PackStore
);
REGISTER_PASS
(
RecoverStore
);
REGISTER_PASS
(
MergeLoops
);
REGISTER_PASS
(
ExpandC0
);
REGISTER_PASS
(
ForEliminate
);
REGISTER_PASS
(
FixLoopExtent
);
...
...
src/codegen/build_module.cc
浏览文件 @
59f460e7
...
...
@@ -738,16 +738,19 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
if
(
global_attrs
.
GetBoolAttr
(
kDeadCodeElim
,
false
))
{
stmt
=
NEXT_PASS
(
DeadCodeElim
,
stmt
);
}
if
(
!
is_dynamic
)
{
stmt
=
NEXT_PASS
(
RewriteBroadcastVector
,
stmt
);
stmt
=
NEXT_PASS
(
OptimizePragma
,
stmt
);
}
if
(
is_dynamic
)
{
stmt
=
NEXT_PASS
(
AnalyzeMinAlignDynamic
,
stmt
,
global_attrs
.
GetIntAttr
(
kEnableConvAnalyzeAlign
,
true
),
global_attrs
.
GetIntAttr
(
kEnableScalarAlign
,
false
));
global_attrs
.
GetIntAttr
(
kEnableScalarAlign
,
false
));
}
else
{
stmt
=
NEXT_PASS
(
RewriteBroadcastVector
,
stmt
);
stmt
=
NEXT_PASS
(
OptimizePragma
,
stmt
);
stmt
=
NEXT_PASS
(
MergeLoops
,
stmt
,
false
);
stmt
=
NEXT_PASS
(
PackStore
,
stmt
);
stmt
=
NEXT_PASS
(
AnalyzeMinAlignStatic
,
stmt
);
stmt
=
NEXT_PASS
(
RecoverStore
,
stmt
);
}
stmt
=
NEXT_PASS
(
MultiLastAxisReductions
,
stmt
,
is_dynamic
);
stmt
=
NEXT_PASS
(
AutoReorder
,
stmt
);
if
(
enable_multicore
!=
0
)
{
...
...
src/emit_insn/insn_builder.h
浏览文件 @
59f460e7
...
...
@@ -25,6 +25,9 @@
#include "insn_info.h"
#include "cce_params.h"
namespace
akg
{
enum
SingleType
{
SIMD
,
Tensor_Scalar
,
Vector_Dump
};
struct
MutableMaskParams
{
Var
mask_var_
;
Expr
loop_var_
;
...
...
@@ -239,8 +242,11 @@ class VectorInsnBuilder : public InsnBuilder {
class
SingleVecInsnBuilder
:
public
VectorInsnBuilder
{
public:
SingleVecInsnBuilder
(
const
StmtStoreInfo
&
dst
,
const
StmtStoreInfo
&
src
,
const
ArgInfo
&
args
,
const
std
::
string
&
intrin_name
,
const
Buffer
&
tmp_buf
=
Buffer
())
:
VectorInsnBuilder
(
dst
,
{
src
},
args
,
intrin_name
),
src_info_
(
src_info_list_
[
0
]),
tmp_buffer_
(
tmp_buf
)
{
const
std
::
string
&
intrin_name
,
const
Expr
&
scalar_src
=
Expr
(),
const
SingleType
insn_type
=
SingleType
::
SIMD
)
:
VectorInsnBuilder
(
dst
,
{
src
},
args
,
intrin_name
),
src_info_
(
src_info_list_
[
0
]),
scalar_src_
(
scalar_src
),
insn_type_
(
insn_type
)
{
CHECK
(
src_info_
.
defined
());
}
~
SingleVecInsnBuilder
()
override
=
default
;
...
...
@@ -254,8 +260,10 @@ class SingleVecInsnBuilder : public VectorInsnBuilder {
Stmt
CreateBroadcast
(
const
VectorArgInfo
&
arg_info
,
const
Var
&
local_var
,
Stmt
stmt
);
StmtStoreInfo
src_info_
;
Buffer
tmp_buffer_
;
Buffer
broadcast_buffer_
;
Expr
scalar_src_
;
SingleType
insn_type_
;
// 0 simd : 1 vector_scalar : 2 vector_dup
};
class
MultiVecInsnBuilder
:
public
VectorInsnBuilder
{
...
...
src/emit_insn/insn_builder_vector.cc
浏览文件 @
59f460e7
...
...
@@ -92,9 +92,6 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) {
Expr
dst_offset
=
dst_info_
->
insn_offset_
;
Expr
src_offset
=
src_info_
->
insn_offset_
;
Var
local_var
=
Var
(
"broadcast_for_vec_local_UB"
,
Handle
());
stmt
=
CreateBroadcast
(
arg_info
,
local_var
,
stmt
);
// Handle stride_m1 loop of single vector intrin, if stride_m1 > 255, it will be separated
if
(
dst_stride_m1
>=
MAX_STRIDE_M1
||
src_stride_m1
>=
MAX_STRIDE_M1
)
{
auto
var
=
Var
(
"repeatStrideM1Idx"
);
...
...
@@ -112,14 +109,6 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) {
}
}
if
(
!
dst_info_
->
var_
.
empty
()
&&
src_info_
->
var_
.
empty
()
&&
intrin_name_
!=
INTRIN_NAME_VECTOR_DUP
)
{
// need to broadcast src first
stmt
=
Allocate
::
make
(
local_var
,
src_info_
->
dtype_
,
{
Expr
(
src_block_size
*
FULL_BLOCK_NUM
)},
const_true
(),
stmt
);
if
(
!
src_info_
->
scope_
.
empty
())
{
stmt
=
AttrStmt
::
make
(
local_var
,
STORAGE_SCOPE
,
StringImm
::
make
(
src_info_
->
scope_
),
stmt
);
}
}
CHECK
(
stmt
.
defined
())
<<
"Error: Stmt is undefined!"
;
return
stmt
;
...
...
@@ -131,70 +120,36 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) {
/// \return
Stmt
SingleVecInsnBuilder
::
EmitIntrinBody
(
const
VectorArgInfo
&
arg_info
,
const
Map
<
std
::
string
,
Expr
>
&
args
)
{
Stmt
body
;
CHECK
(
!
arg_info
->
src_stride_m0_list_
.
empty
());
CHECK
(
!
arg_info
->
src_stride_m1_list_
.
empty
());
auto
dst_buffer_id
=
GenBufferId
(
dst_info_
);
auto
src_buffer_id
=
GenBufferId
(
src_info_
);
Expr
repeat
=
args
[
"repeat"
];
auto
dst_buffer_id
=
GenBufferId
(
dst_info_
);
Expr
dst_offset
=
Sub
::
make
(
args
[
"dstOffset"
],
arg_info
->
block_offset_
);
Expr
src_offset
=
args
[
"srcOffset"
];
Expr
src_stride_m1
=
arg_info
->
src_stride_m1_list_
[
0
];
auto
dst
=
GetAccessPtr
(
dst_buffer_id
,
"w"
,
dst_offset
);
auto
src
=
GetAccessPtr
(
src_buffer_id
,
"r"
,
src_offset
);
if
(
broadcast_buffer_
.
defined
())
{
src_stride_m1
=
0
;
src
=
GetAccessPtr
(
broadcast_buffer_
,
"r"
,
Expr
(
0
));
Array
<
Expr
>
insn_args
{};
if
(
insn_type_
==
SingleType
::
Vector_Dump
)
{
insn_args
=
{
dst
,
scalar_src_
,
repeat
};
}
else
{
auto
src_buffer_id
=
GenBufferId
(
src_info_
);
Expr
src_offset
=
args
[
"srcOffset"
];
auto
src
=
GetAccessPtr
(
src_buffer_id
,
"r"
,
src_offset
);
if
(
insn_type_
==
SingleType
::
SIMD
)
{
insn_args
=
{
dst
,
src
,
repeat
};
}
else
if
(
insn_type_
==
SingleType
::
Tensor_Scalar
)
{
insn_args
=
{
dst
,
src
,
scalar_src_
,
repeat
};
}
else
{
CHECK
(
0
)
<<
"
\n
Unknown insn_type_
\n
"
;
}
}
Array
<
Expr
>
stride_args
=
{
arg_info
->
dst_stride_m0_
,
arg_info
->
src_stride_m0_list_
[
0
],
arg_info
->
dst_stride_m1_
,
src_stride_m1
};
Array
<
Expr
>
insn_args
=
{
dst
,
src
,
repeat
};
if
(
arg_info
->
scalar_
.
defined
())
{
auto
scalar
=
arg_info
->
scalar_
;
if
(
tmp_buffer_
.
defined
())
{
dst
=
GetAccessPtr
(
tmp_buffer_
,
"w"
,
dst_offset
);
}
insn_args
=
{
dst
,
scalar
,
repeat
};
if
(
intrin_name_
!=
INTRIN_NAME_VECTOR_DUP
)
{
Insert
(
insn_args
,
1
,
src
);
}
}
arg_info
->
src_stride_m1_list_
[
0
]};
insn_args
=
MergeTwo
(
insn_args
,
stride_args
);
body
=
EmitCceIntrinTemplate
(
Stmt
(),
dst
.
type
(),
insn_args
,
intrin_name_
);
return
body
;
}
/// Create broadcast intrin if src is scalar
/// \param arg_info
/// \param local_var
/// \param stmt
/// \return
Stmt
SingleVecInsnBuilder
::
CreateBroadcast
(
const
VectorArgInfo
&
arg_info
,
const
Var
&
local_var
,
Stmt
stmt
)
{
if
(
!
dst_info_
->
var_
.
empty
()
&&
src_info_
->
var_
.
empty
()
&&
intrin_name_
!=
INTRIN_NAME_VECTOR_DUP
)
{
// need to broadcast src first
auto
src_block_size
=
GetUbBlkSize
(
src_info_
->
dtype_
);
broadcast_buffer_
=
BufferNode
::
make
(
local_var
,
src_info_
->
dtype_
,
{
Expr
(
src_block_size
*
FULL_BLOCK_NUM
)},
{},
src_info_
->
elem_offset_
,
"broadcast_for_vec_local_UB"
,
src_info_
->
scope_
,
src_info_
->
data_alignment_
,
1
,
BufferType
::
kDefault
);
auto
broad_dst
=
GetAccessPtr
(
broadcast_buffer_
,
"w"
,
0
);
Array
<
Expr
>
args
=
{
broad_dst
,
GenBufferId
(
src_info_
).
vload
({
Expr
(
0
)},
src_info_
->
dtype_
),
Expr
(
1
),
Expr
(
1
),
Expr
(
1
),
Expr
(
0
),
Expr
(
0
)};
stmt
=
EmitSetVecMaskIntrin
(
stmt
,
src_info_
->
dtype_
,
GetAllMask
(
src_info_
->
dtype_
));
stmt
=
InsertBody
(
stmt
,
EmitCceIntrinTemplate
(
Stmt
(),
src_info_
->
dtype_
,
args
,
INTRIN_NAME_VECTOR_DUP
));
stmt
=
EmitSetVecMaskIntrin
(
stmt
,
dst_info_
->
dtype_
,
arg_info
->
vec_mask_
);
}
return
stmt
;
}
/// if repeat-size > cce_max_repeat, then split it into loop as "Davinci ISA User Guide t6.3 (8.2.2)" mentioned
/// max_cce_repeat = 255, considering params are about 2 cycles, set it to be 255 // 2 = 127
...
...
@@ -1250,8 +1205,10 @@ Stmt EmitCceBinaryVectorToReduceLastAxis(const StmtStoreInfo &dst_info, const St
auto
vec_dup_arg_info
=
GenReduceHelperArgInfo
(
vec_dup_dst_info
,
for_extent
,
scalar
,
"VecDup"
);
vec_dup_dst_info
.
GetNode
()
->
data_
=
final_var
;
vec_dup_dst_info
.
GetNode
()
->
name_
=
final_var
->
name_hint
;
SingleVecInsnBuilder
single_vec_builder
=
SingleVecInsnBuilder
(
vec_dup_dst_info
,
vec_dup_dst_info
,
vec_dup_arg_info
,
INTRIN_NAME_VECTOR_DUP
,
final_dst_buffer
);
INTRIN_NAME_VECTOR_DUP
,
scalar
,
SingleType
::
Vector_Dump
);
auto
insn_list
=
single_vec_builder
.
EmitIntrin
();
auto
stmt
=
std
::
accumulate
(
insn_list
.
begin
(),
insn_list
.
end
(),
Stmt
(),
[](
const
Stmt
&
s0
,
const
Stmt
&
s1
)
{
return
InsertBody
(
s0
,
s1
);
});
...
...
src/emit_insn/insn_emitter.cc
浏览文件 @
59f460e7
...
...
@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "
emit_insn/
insn_emitter.h"
#include "insn_emitter.h"
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
...
...
@@ -53,145 +53,68 @@ std::vector<size_t> SortIndexes(const std::vector<int> &v) {
/// \param intrin_name - The CCE intrin name
/// \param broadcast_last_axis - Tag of broadcast_last_axis mode
/// \return Stmt of emitted CCE intrin
Stmt
SingleVecEmitter
(
const
Stmt
&
op
,
std
::
string
intrin_name
,
bool
broadcast_last_axis
=
false
)
{
Stmt
SingleVecEmitter
(
const
Stmt
&
op
,
std
::
string
intrin_name
)
{
CHECK
(
op
);
Stmt
result
;
// optimization of copy_ubuf_to_ubuf
bool
is_dma_opt
=
false
;
if
(
intrin_name
==
INTRIN_NAME_COPY_UB_TO_UB
)
{
CommentManager
::
GetInstance
().
AddComment
(
"Insn_type"
,
"dma_copy"
);
CommentManager
::
GetInstance
().
AddComment
(
"Insn_name"
,
INTRIN_NAME_COPY_UB_TO_UB
);
CommentManager
::
GetInstance
().
AddComment
(
"Vadds_replace_copy"
,
"enable"
);
intrin_name
=
"vadds"
;
is_dma_opt
=
true
;
}
else
{
CommentManager
::
GetInstance
().
AddComment
(
"Insn_type"
,
"single_vector"
);
CommentManager
::
GetInstance
().
AddComment
(
"Insn_name"
,
intrin_name
);
}
CommentManager
::
GetInstance
().
AddComment
(
"Insn_type"
,
"single_vector"
);
CommentManager
::
GetInstance
().
AddComment
(
"Insn_name"
,
intrin_name
);
StmtInfoList
dst_info_list
;
StmtInfoList
src_info_list
;
StmtStoreInfo
scalar_info
;
StmtInfo
for_info
;
StmtInfo
if_info
;
std
::
string
mode
=
GetSingleVecComputationInfo
(
op
,
intrin_name
,
dst_info_list
,
src_info_list
,
if_info
,
for_info
);
bool
same_dtype
=
intrin_name
.
find
(
"vconv_"
)
==
std
::
string
::
npos
;
GetCompactComputationInfo
(
op
,
dst_info_list
,
src_info_list
,
if_info
,
for_info
,
same_dtype
,
true
);
CHECK
(
!
dst_info_list
.
empty
());
if
(
broadcast_last_axis
)
{
mode
=
"broadcast_last_axis"
;
// In this case, must come from binary vec, so must have two src
CHECK
(
src_info_list
.
size
()
>=
2
)
<<
"Broadcast last axis mode must have at least two srcs."
;
if
(
!
IsTwoItemEqual
(
src_info_list
[
0
]
->
var_
,
dst_info_list
[
0
]
->
var_
,
-
1
))
{
scalar_info
=
src_info_list
[
0
];
src_info_list
.
Set
(
0
,
src_info_list
[
1
]);
}
else
if
(
!
IsTwoItemEqual
(
src_info_list
[
1
]
->
var_
,
dst_info_list
[
0
]
->
var_
,
-
1
))
{
scalar_info
=
src_info_list
[
1
];
}
}
else
{
if
(
mode
==
"broadcast"
&&
!
src_info_list
.
empty
()
&&
dst_info_list
.
size
()
==
1
)
{
if
(
!
IsTwoItemEqual
(
src_info_list
[
0
]
->
var_
,
dst_info_list
[
0
]
->
var_
,
-
1
))
{
mode
=
"broadcast_last_axis"
;
Array
<
Expr
>
call_args
;
int
call_cnt
=
0
;
if
(
intrin_name
==
"vector_dup"
||
intrin_name
==
"vadds"
||
intrin_name
==
"vmuls"
||
intrin_name
==
"vaxpy"
)
{
auto
GetCallInfo
=
[
&
intrin_name
,
&
call_args
,
&
call_cnt
](
const
NodeRef
&
op
)
{
if
(
op
.
as
<
Call
>
()
&&
op
.
as
<
Call
>
()
->
name
==
intrin_name
)
{
call_args
=
op
.
as
<
Call
>
()
->
args
;
call_cnt
=
call_cnt
+
1
;
}
if
(
src_info_list
.
size
()
>
1
)
{
if
(
!
IsTwoItemEqual
(
src_info_list
[
1
]
->
var_
,
dst_info_list
[
0
]
->
var_
,
-
1
))
{
mode
=
"broadcast_last_axis"
;
}
else
{
scalar_info
=
src_info_list
[
0
];
src_info_list
.
Set
(
0
,
src_info_list
[
1
]);
}
}
}
}
if
(
broadcast_last_axis
)
{
mode
=
"broadcast_last_axis"
;
};
PostOrderVisit
(
op
,
GetCallInfo
);
CHECK_EQ
(
call_cnt
,
1
);
}
if
(
intrin_name
==
INTRIN_NAME_VECTOR_DUP
)
{
auto
dst_info
=
dst_info_list
[
0
];
if
(
dst_info
->
var_
.
size
()
>
1
&&
GetIntConst
(
GetItem
(
dst_info
->
strides_
,
-
1
))
==
GetIntConst
(
GetItem
(
dst_info
->
shape_
,
-
1
))
+
1
)
{
// diagnoal broadcast case
return
op
;
}
dst_info
.
CleanFlexVar
();
SingleType
insn_type
{
SingleType
::
SIMD
};
Expr
scalar_src
{};
if
(
intrin_name
==
"vector_dup"
)
{
insn_type
=
SingleType
::
Vector_Dump
;
src_info_list
=
{};
scalar_src
=
call_args
[
0
];
}
else
if
(
intrin_name
==
"vadds"
||
intrin_name
==
"vmuls"
||
intrin_name
==
"vaxpy"
)
{
insn_type
=
SingleType
::
Tensor_Scalar
;
src_info_list
=
{
src_info_list
[
0
]};
scalar_src
=
call_args
[
1
];
}
// check is single vector broadcast reduce mode exist
SingleVecPatternGenerator
generator
=
SingleVecPatternGenerator
(
dst_info_list
,
src_info_list
,
for_info
,
mode
);
SingleVecPatternGenerator
generator
=
SingleVecPatternGenerator
(
dst_info_list
,
src_info_list
,
for_info
);
auto
params
=
generator
.
GetInsnArgs
();
dst_info_list
=
params
.
dst_info_list
;
src_info_list
=
params
.
src_info_list
;
for_info
=
params
.
for_info
;
ArgInfo
arg_info
=
params
.
arg_info
;
CommentManager
::
GetInstance
().
AddComment
(
"Compute_type"
,
mod
e
);
CommentManager
::
GetInstance
().
AddComment
(
"Compute_type"
,
intrin_nam
e
);
CommentManager
::
GetInstance
().
AddComment
(
"Pattern"
,
arg_info
.
GetPattern
());
if
(
intrin_name
==
"vadds"
||
intrin_name
==
"vmuls"
||
intrin_name
==
INTRIN_NAME_VECTOR_DUP
)
{
auto
stores
=
GetStores
(
op
);
auto
store
=
stores
[
0
].
as
<
Store
>
();
auto
scalar
=
Expr
(
0
);
if
(
intrin_name
==
"vadds"
||
intrin_name
==
"vmuls"
)
{
if
(
!
dst_info_list
.
empty
())
{
scalar
=
FloatImm
::
make
(
dst_info_list
[
0
]
->
dtype_
,
0.000000
);
}
if
(
!
dst_info_list
[
0
]
->
dtype_
.
is_float
())
{
return
op
;
}
if
(
!
is_dma_opt
)
{
if
(
!
scalar_info
.
defined
())
{
auto
children
=
GetBinaryOpExprChildren
(
store
->
value
);
if
(
children
.
empty
())
{
LOG
(
FATAL
)
<<
store
->
value
<<
" is not binary op."
;
}
scalar
=
children
[
1
];
}
else
{
scalar
=
Load
::
make
(
scalar_info
->
dtype_
,
scalar_info
->
data_
,
scalar_info
->
index_
,
Expr
(
1
));
}
}
}
else
if
(
intrin_name
==
INTRIN_NAME_VECTOR_DUP
)
{
if
(
store
->
value
->
IsInstance
<
Load
>
())
{
// scale is load
scalar
=
Load
::
make
(
src_info_list
[
0
]
->
dtype_
,
store
->
value
.
as
<
Load
>
()
->
buffer_var
,
src_info_list
[
0
]
->
index_
,
Expr
(
1
));
}
else
{
// scale is imm
scalar
=
store
->
value
;
}
}
if
(
arg_info
->
body_arg_info_
.
defined
())
{
arg_info
->
body_arg_info_
.
GetNode
()
->
scalar_
=
scalar
;
}
if
(
arg_info
->
tail_arg_info_
.
defined
())
{
arg_info
->
tail_arg_info_
.
GetNode
()
->
scalar_
=
scalar
;
}
}
if
(
intrin_name
==
"vconv_deq"
)
{
result
=
InsertBody
(
result
,
Evaluate
::
make
(
Call
::
make
(
Float
(
16
),
"set_deqscale"
,
{
FloatImm
::
make
(
Float
(
16
),
1.0
)},
Call
::
Extern
)));
}
SingleVecInsnBuilder
single_vec_builder
=
SingleVecInsnBuilder
(
dst_info_list
[
0
],
src_info_list
[
0
],
arg_info
,
intrin_name
);
SingleVecInsnBuilder
(
dst_info_list
[
0
],
src_info_list
[
0
],
arg_info
,
intrin_name
,
scalar_src
,
insn_type
);
auto
insn_list
=
single_vec_builder
.
EmitIntrin
();
if
(
intrin_name
==
INTRIN_NAME_VECTOR_DUP
&&
dst_info_list
[
0
]
->
var_
.
empty
())
{
Stmt
store
;
auto
ScanStore
=
[
&
store
](
const
NodeRef
&
op
)
{
const
auto
e
=
op
.
as
<
Store
>
();
if
(
e
!=
nullptr
)
{
store
=
Store
::
make
(
e
->
buffer_var
,
e
->
value
,
e
->
index
,
e
->
predicate
);
}
};
air
::
ir
::
PostOrderVisit
(
op
,
ScanStore
);
store
=
EmitSetVecMaskIntrin
(
store
,
dst_info_list
[
0
]
->
dtype_
);
insn_list
=
{
store
};
}
return
FoldInsnWithForInfo
(
insn_list
,
if_info
,
for_info
,
result
);
auto
ret
=
FoldInsnWithForInfo
(
insn_list
,
if_info
,
for_info
,
result
);
return
ret
;
}
/// Function to emit binary vector intrin
...
...
@@ -211,11 +134,6 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
CommentManager
::
GetInstance
().
AddComment
(
"Insn_name"
,
intrin_name
);
switch
(
arg_info
->
arg_type_
)
{
case
ARG_VECTOR_BROADCAST_LAST_AXIS
:
{
CommentManager
::
GetInstance
().
CleanComments
();
intrin_name
+=
"s"
;
return
SingleVecEmitter
(
op
,
intrin_name
,
true
);
}
case
ARG_VECTOR_REDUCTION_LAST_AXIS
:
{
CommentManager
::
GetInstance
().
AddComment
(
"Compute_type"
,
"reduce_last_axis"
);
auto
dst_info
=
dst_info_list
[
0
];
...
...
@@ -928,83 +846,8 @@ Stmt DmaMovEmitter(const Stmt &op, bool enable_cover_protect) {
StmtInfo
for_info
;
GetDmaComputationInfo
(
op
,
dst_info_list
,
src_info_list
,
if_info
,
for_info
,
dma_mode
,
intrin_name
);
auto
check_alignment
=
[](
const
Expr
&
align
,
const
Array
<
Expr
>
&
shape
)
{
if
(
GetIntConst
(
align
)
==
1
||
shape
.
size
()
==
1u
)
{
return
true
;
}
if
(
shape
.
empty
())
{
return
false
;
}
Expr
sz
=
1
;
for
(
int
i
=
static_cast
<
int
>
(
shape
.
size
())
-
1
;
i
>=
0
;
--
i
)
{
sz
=
sz
*
shape
[
i
];
if
(
GetIntConst
(
align
)
==
GetIntConst
(
sz
))
{
return
true
;
}
}
return
false
;
};
const
auto
&
dst_info
=
dst_info_list
[
0
];
const
auto
&
src_info
=
src_info_list
[
0
];
int
block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
// check scalar to scalar
// check if dst is considered as scalar
// check if src is considered as scalar
bool
is_broadcast
=
(
dst_info
->
var_
.
empty
()
||
(
!
dst_info
->
strides_
.
empty
()
&&
GetIntConst
(
GetItem
(
dst_info
->
strides_
,
-
1
))
!=
1
))
&&
(
src_info
->
var_
.
empty
()
||
(
!
src_info
->
strides_
.
empty
()
&&
GetIntConst
(
GetItem
(
src_info
->
strides_
,
-
1
))
!=
1
));
// check vector to vector, but in scalar dma mode
bool
last_dim_equal
=
!
dst_info
->
var_
.
empty
()
&&
!
src_info
->
var_
.
empty
()
&&
!
dst_info
->
strides_
.
empty
()
&&
!
src_info
->
strides_
.
empty
()
&&
GetItem
(
dst_info
->
var_
,
-
1
).
get
()
==
GetItem
(
src_info
->
var_
,
-
1
).
get
()
&&
GetIntConst
(
GetItem
(
dst_info
->
strides_
,
-
1
))
!=
GetIntConst
(
GetItem
(
src_info
->
strides_
,
-
1
));
bool
broadcast_scalar
=
intrin_name
==
"broadcast"
&&
is_broadcast
;
bool
ubuf_scalar
=
intrin_name
==
INTRIN_NAME_COPY_UB_TO_UB
&&
(
is_broadcast
||
last_dim_equal
);
if
(
broadcast_scalar
||
ubuf_scalar
)
{
int
shape1
=
GetInt32Const
(
GetItem
(
dst_info
->
shape_
,
-
1
));
int
stride1
=
GetInt32Const
(
GetItem
(
dst_info
->
strides_
,
-
1
));
if
(
ubuf_scalar
&&
shape1
<
block_size
&&
stride1
==
block_size
&&
IsTwoItemEqual
(
dst_info
->
strides_
,
src_info
->
strides_
,
-
1
,
true
)
&&
src_info
->
dtype_
.
bits
()
!=
64
)
{
// if last dim small than blocksize, then use vadds
return
SingleVecEmitter
(
op
,
intrin_name
);
}
CommentManager
::
GetInstance
().
AddComment
(
"Insn_type"
,
"dma_copy"
);
CommentManager
::
GetInstance
().
AddComment
(
"Insn_name"
,
"scalar"
);
if
(
src_info
->
var_
.
empty
()
&&
dst_info
->
var_
.
empty
())
{
return
op
;
}
else
{
// check align
if
(
!
check_alignment
(
dst_info
->
data_alignment_
,
dst_info
->
shape_
))
{
return
op
;
}
Stmt
base_stmt
=
EmitScalarDmaIntrinTemplate
(
op
,
src_info
,
dst_info
);
return
GenIfAndFor
(
base_stmt
,
if_info
,
for_info
,
false
);
}
}
if
(
intrin_name
==
"broadcast"
)
{
return
SingleVecEmitter
(
op
,
INTRIN_NAME_VECTOR_DUP
);
}
else
if
(
intrin_name
==
INTRIN_NAME_COPY_UB_TO_UB
)
{
// Use vadds to optimize dma copy
if
(
if_info
.
vars_
.
empty
()
&&
dst_info
->
dtype_
.
is_float
()
&&
src_info
->
dtype_
.
is_float
())
{
if
((
dst_info
->
dtype_
.
bits
()
==
32
&&
src_info
->
dtype_
.
bits
()
==
32
)
||
(
dst_info
->
dtype_
.
bits
()
==
16
&&
src_info
->
dtype_
.
bits
()
==
16
))
{
int
repeat_len
=
block_size
*
FULL_BLOCK_NUM
;
CHECK_NE
(
block_size
,
0
);
int
shape1
=
GetInt32Const
(
GetItem
(
dst_info
->
shape_
,
-
1
));
if
((
shape1
>=
repeat_len
/
2
&&
shape1
<=
repeat_len
)
||
(
dst_info
->
shape_
.
size
()
>=
3
&&
shape1
<=
block_size
)
||
(
dst_info
->
shape_
.
size
()
>=
2
&&
shape1
%
block_size
==
0
))
{
// if last dim shape is too small, there is no need to opt
return
SingleVecEmitter
(
op
,
intrin_name
);
}
}
}
}
CommentManager
::
GetInstance
().
AddComment
(
"Insn_type"
,
"dma_copy"
);
...
...
@@ -1014,31 +857,10 @@ Stmt DmaMovEmitter(const Stmt &op, bool enable_cover_protect) {
Map
<
std
::
string
,
Expr
>
ub_copy_post
;
auto
arg_info_map
=
GetDmaCopyInsnArgs
(
intrin_name
,
dst_info_list
,
src_info_list
,
for_info
,
ub_copy_pre
,
ub_copy_post
);
if
(
intrin_name
==
"vtranspose_scalar"
)
{
base_stmt
=
EmitScalarDmaIntrinTemplate
(
op
,
src_info
,
dst_info
);
CommentManager
::
GetInstance
().
AddComment
(
"Insn_name"
,
"scalar"
);
}
else
if
(
intrin_name
==
"vtranspose"
)
{
Array
<
Expr
>
args
=
{
arg_info_map
[
"loop_width"
],
arg_info_map
[
"loop_height"
],
arg_info_map
[
"shape_width"
]};
Array
<
Expr
>
pre_ub_copy_args
;
if
(
!
ub_copy_pre
.
empty
())
{
pre_ub_copy_args
=
Array
<
Expr
>
(
{
ub_copy_pre
[
"nBurst"
],
ub_copy_pre
[
"lenBurst"
],
ub_copy_pre
[
"srcStride"
],
ub_copy_pre
[
"dstStride"
]});
}
Array
<
Expr
>
post_ub_copy_args
;
if
(
!
ub_copy_post
.
empty
())
{
post_ub_copy_args
=
Array
<
Expr
>
(
{
ub_copy_post
[
"nBurst"
],
ub_copy_post
[
"lenBurst"
],
ub_copy_post
[
"srcStride"
],
ub_copy_post
[
"dstStride"
]});
}
TransposeInsnBuilder
builder
=
TransposeInsnBuilder
(
dst_info
,
src_info
,
args
,
pre_ub_copy_args
,
post_ub_copy_args
);
base_stmt
=
builder
.
EmitSingleIntrin
();
CommentManager
::
GetInstance
().
AddComment
(
"Insn_name"
,
intrin_name
);
}
else
{
DmaInsnBuilder
dma_builder
=
DmaInsnBuilder
(
dst_info
,
src_info
,
intrin_name
,
arg_info_map
,
false
,
false
,
enable_cover_protect
);
base_stmt
=
dma_builder
.
EmitSingleIntrin
();
CommentManager
::
GetInstance
().
AddComment
(
"Insn_name"
,
intrin_name
);
}
DmaInsnBuilder
dma_builder
=
DmaInsnBuilder
(
dst_info
,
src_info
,
intrin_name
,
arg_info_map
,
false
,
false
,
enable_cover_protect
);
base_stmt
=
dma_builder
.
EmitSingleIntrin
();
CommentManager
::
GetInstance
().
AddComment
(
"Insn_name"
,
intrin_name
);
}
else
if
(
dma_mode
==
"cce_load"
)
{
auto
arg_info_map
=
GetDmaLoad2DInsnArgs
(
intrin_name
,
dst_info_list
,
src_info_list
,
for_info
);
DmaInsnBuilder
builder
=
DmaInsnBuilder
(
dst_info
,
src_info
,
intrin_name
,
arg_info_map
,
true
);
...
...
@@ -1104,6 +926,19 @@ Stmt DmaAtomicAddEmitter(const Stmt &op) {
return
stmt
;
}
Stmt
VTransposeEmitter
(
const
Stmt
&
op
)
{
StmtInfoList
dst_info_list
;
StmtInfoList
src_info_list
;
StmtInfo
for_info
;
StmtInfo
if_info
;
GetCompactComputationInfo
(
op
,
dst_info_list
,
src_info_list
,
if_info
,
for_info
,
true
,
true
);
auto
dst_buffer_id
=
GenBufferId
(
dst_info_list
[
0
]);
auto
src_buffer_id
=
GenBufferId
(
src_info_list
[
0
]);
auto
dst
=
GetAccessPtr
(
dst_buffer_id
,
"w"
,
0
);
auto
src
=
GetAccessPtr
(
src_buffer_id
,
"r"
,
0
);
return
Evaluate
::
make
(
Call
::
make
(
Float
(
16
),
"vtranspose"
,
{
dst
,
src
},
Call
::
Extern
));
}
/// Function to emit dropout intrin
/// \param op - The input stmt to be emitted as intrin
/// \return Stmt of emitted CCE intrin
...
...
@@ -1913,97 +1748,6 @@ Stmt ReduceCombineEmitter(const Stmt &op, bool enable_bisect) {
Stmt
InsnEmit
(
std
::
string
insn_name
,
const
Stmt
&
op
,
bool
enable_bisect
,
bool
enable_cover_protect
,
int
comment_level
)
{
CHECK
(
op
.
defined
());
static
const
std
::
map
<
std
::
string
,
std
::
string
>
ReplaceAttrPragmaMap
=
{
// vector binary
{
"binary_vcadd"
,
"vec_binary_add"
},
{
"vaxpy"
,
"vec_binary_axpy"
},
// vector single
{
"vec_single_fabs"
,
"vec_single_abs"
},
{
"broadcast"
,
"vec_broadcast"
},
// cube
{
"mad"
,
"cube_mad"
},
{
"ub2gm"
,
"cube_ub2gm"
},
{
"im2col"
,
"cube_img2col"
},
// special attrs
{
"vec_binary_proposal_sort"
,
"vec_proposal_sort"
},
{
"vec_binary_topk_sort"
,
"vec_topk_sort"
},
{
"vec_binary_dropout"
,
"vec_dropout"
},
{
"vec_binary_fargmax"
,
"vec_argmax"
},
{
"vec_binary_fargmin"
,
"vec_argmin"
},
{
"vec_binary_iou"
,
"vec_iou"
},
{
"vec_binary_nms"
,
"vec_nms"
},
{
"mask_broadcast"
,
"vec_broadcast"
},
};
static
const
std
::
map
<
std
::
string
,
std
::
string
>
BinaryVecInsnMap
=
{
// vadd.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadd.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadd.f32 support target:mini_v100 cloud_v100
// vadd contains two situations:
// 1. normal elewise vector add
// - all src[i].shape = dst.shape
// 2. reductive vector add
// - exist src[i].shape != dst.shape
{
"vec_binary_add"
,
"vadd"
},
// vsub.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vsub.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vsub.f32 support target:mini_v100 cloud_v100
{
"vec_binary_sub"
,
"vsub"
},
// vmul.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmul.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmul.f32 support target:mini_v100 cloud_v100
{
"vec_binary_mul"
,
"vmul"
},
// vmin.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmin.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmin.f32 support target:mini_v100 cloud_v100
{
"vec_binary_min"
,
"vmin"
},
// vmax.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmax.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmax.f32 support target:mini_v100 cloud_v100
{
"vec_binary_max"
,
"vmax"
},
{
"vec_binary_div"
,
"vdiv"
},
{
"vec_binary_and"
,
"vand"
},
{
"vec_binary_bitwise_and"
,
"vand"
},
{
"vec_binary_or"
,
"vor"
},
{
"vec_binary_bitwise_or"
,
"vor"
},
{
"vec_binary_vmadd"
,
"vmadd"
},
{
"vec_binary_vmaddrelu"
,
"vmaddrelu"
},
{
"vec_binary_vmla"
,
"vmla"
}};
static
const
std
::
map
<
std
::
string
,
std
::
string
>
SingleVecInsnMap
=
{
// vmuls.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmuls.f32 supporttarget:mini_v100 cloud_v100
{
"vec_single_muls"
,
"vmuls"
},
// vadds.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadds.f32 support target:mini_v100 cloud_v100
{
"vec_single_adds"
,
"vadds"
},
// vrelu.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
{
"vec_single_relu"
,
"vrelu"
},
// vabs.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vabs.f32 support target:mini_v100 cloud_v100
{
"vec_single_abs"
,
"vabs"
},
// vln.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vln.f32 support target:cloud_v100
{
"vec_single_log"
,
"vln"
},
// vexp.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vexp.f32 support target:cloud_v100
{
"vec_single_exp"
,
"vexp"
},
// vrec.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vrec.f32 support target:mini_v100 cloud_v100
{
"vec_single_rec"
,
"vrec"
},
// vnot support target:mini_v100 tiny_v100 lite_v100 cloud_v100
{
"vec_single_not"
,
"vnot"
},
{
"vec_single_bitwise_not"
,
"vnot"
},
// vsqrt support target:cloud_v100
{
"vec_single_sqrt"
,
"vsqrt"
},
{
"vec_single_rsqrt"
,
"vrsqrt"
},
{
"vec_broadcast"
,
"vector_dup"
}};
static
const
std
::
map
<
std
::
string
,
std
::
string
>
SingleCastInsnMap
=
{
{
"vec_single_floor"
,
"f"
},
{
"vec_single_round"
,
"r"
},
{
"vec_single_ceil"
,
"c"
},
{
"vec_single_trunc"
,
"z"
}};
static
const
std
::
set
<
std
::
string
>
ReturnOpInsnSet
=
{
"scalar_dma"
,
"scatter"
,
"vec_binary_select_loop_var"
};
static
const
std
::
map
<
std
::
string
,
std
::
function
<
Stmt
(
const
Stmt
&
)
>>
InsnFunctorMap
=
{
{
"dma_atomic_add"
,
DmaAtomicAddEmitter
},
{
"vec_single_cast"
,
SingleCastEmitter
},
...
...
@@ -2017,9 +1761,9 @@ Stmt InsnEmit(std::string insn_name, const Stmt &op, bool enable_bisect, bool en
{
"vec_dropout"
,
BinaryDropoutEmitter
},
{
"cube_mad"
,
MadEmitter
},
{
"vec_select_scalar"
,
SelectWithScalarEmitter
},
{
"vec_binary_axpy"
,
VaxpyEmitter
},
{
"opt_broadcast"
,
MultiMaskEmitter
},
{
"vec_single_four2five_nchw"
,
VnchwconvEmitter
}};
{
"vec_single_four2five_nchw"
,
VnchwconvEmitter
},
{
"vtranspose"
,
VTransposeEmitter
}};
if
(
ReplaceAttrPragmaMap
.
count
(
insn_name
)
!=
0
)
{
insn_name
=
ReplaceAttrPragmaMap
.
find
(
insn_name
)
->
second
;
...
...
src/emit_insn/insn_emitter.h
浏览文件 @
59f460e7
...
...
@@ -30,6 +30,100 @@
namespace
akg
{
namespace
ir
{
static
const
std
::
map
<
std
::
string
,
std
::
string
>
ReplaceAttrPragmaMap
=
{
// vector binary
{
"binary_vcadd"
,
"vec_binary_add"
},
// vector single
{
"vec_single_fabs"
,
"vec_single_abs"
},
{
"broadcast"
,
"vec_broadcast"
},
// cube
{
"mad"
,
"cube_mad"
},
{
"ub2gm"
,
"cube_ub2gm"
},
{
"im2col"
,
"cube_img2col"
},
// special attrs
{
"vec_binary_proposal_sort"
,
"vec_proposal_sort"
},
{
"vec_binary_topk_sort"
,
"vec_topk_sort"
},
{
"vec_binary_dropout"
,
"vec_dropout"
},
{
"vec_binary_fargmax"
,
"vec_argmax"
},
{
"vec_binary_fargmin"
,
"vec_argmin"
},
{
"vec_binary_iou"
,
"vec_iou"
},
{
"vec_binary_nms"
,
"vec_nms"
},
{
"mask_broadcast"
,
"vec_broadcast"
},
};
// Maps binary vector pragmas to their hardware vector intrinsic mnemonics.
// The per-entry comments list which chip targets support each data type.
static const std::map<std::string, std::string> BinaryVecInsnMap = {
  // vadd.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vadd.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vadd.f32 support target:mini_v100 cloud_v100
  // vadd contains two situations:
  // 1. normal elewise vector add
  //    - all src[i].shape = dst.shape
  // 2. reductive vector add
  //    - exist src[i].shape != dst.shape
  {"vec_binary_add", "vadd"},
  // vsub.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vsub.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vsub.f32 support target:mini_v100 cloud_v100
  {"vec_binary_sub", "vsub"},
  // vmul.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vmul.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vmul.f32 support target:mini_v100 cloud_v100
  {"vec_binary_mul", "vmul"},
  // vmin.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vmin.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vmin.f32 support target:mini_v100 cloud_v100
  {"vec_binary_min", "vmin"},
  // vmax.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vmax.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vmax.f32 support target:mini_v100 cloud_v100
  {"vec_binary_max", "vmax"},
  {"vec_binary_div", "vdiv"},
  // logical and bitwise and/or share the same hardware instruction
  {"vec_binary_and", "vand"},
  {"vec_binary_bitwise_and", "vand"},
  {"vec_binary_or", "vor"},
  {"vec_binary_bitwise_or", "vor"},
  {"vec_binary_vmadd", "vmadd"},
  {"vec_binary_vmaddrelu", "vmaddrelu"},
  {"vec_binary_vmla", "vmla"}};
// Maps single-operand (unary / scalar-immediate) vector pragmas to their
// hardware intrinsic mnemonics. Some entries are already intrinsic names and
// map to themselves (e.g. "vadds", "vmuls", "vector_dup").
static const std::map<std::string, std::string> SingleVecInsnMap = {
  // vmuls.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vmuls.f32 supporttarget:mini_v100 cloud_v100
  {"vec_single_muls", "vmuls"},
  // vadds.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vadds.f32 support target:mini_v100 cloud_v100
  {"vec_single_adds", "vadds"},
  // vrelu.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  {"vec_single_relu", "vrelu"},
  // vabs.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vabs.f32 support target:mini_v100 cloud_v100
  {"vec_single_abs", "vabs"},
  // vln.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vln.f32 support target:cloud_v100
  {"vec_single_log", "vln"},
  // vexp.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vexp.f32 support target:cloud_v100
  {"vec_single_exp", "vexp"},
  // vrec.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  // vrec.f32 support target:mini_v100 cloud_v100
  {"vec_single_rec", "vrec"},
  // vnot support target:mini_v100 tiny_v100 lite_v100 cloud_v100
  {"vec_single_not", "vnot"},
  {"vec_single_bitwise_not", "vnot"},
  // vsqrt support target:cloud_v100
  {"vec_single_sqrt", "vsqrt"},
  {"vec_single_rsqrt", "vrsqrt"},
  {"vaxpy", "vaxpy"},
  // broadcast is implemented with the vector_dup instruction
  {"vec_broadcast", "vector_dup"},
  {"vadds", "vadds"},
  {"vmuls", "vmuls"},
  {"vector_dup", "vector_dup"},
};
// Maps rounding-cast pragmas to the single-letter rounding-mode suffix used
// when building vconv-style cast intrinsics:
//   f = floor, r = round (to nearest), c = ceil, z = truncate (toward zero).
static const std::map<std::string, std::string> SingleCastInsnMap = {
  {"vec_single_floor", "f"}, {"vec_single_round", "r"}, {"vec_single_ceil", "c"}, {"vec_single_trunc", "z"}};
// Pragmas whose statements are returned/handled as-is by the emitter rather
// than being lowered through the vector-instruction tables above.
static const std::set<std::string> ReturnOpInsnSet = {"scalar_calc", "scalar_dma", "scatter",
                                                     "vec_binary_select_loop_var"};

/// Emit hardware instructions for a statement containing dynamic shapes.
/// \param s             - The statement to lower.
/// \param extern_buffer - Mapping from tensors to their external buffers.
/// \return The lowered statement.
Stmt EmitInsnWithDynamicShapes(const Stmt &s, const Map<Tensor, Buffer> &extern_buffer);
...
...
src/emit_insn/insn_info.cc
浏览文件 @
59f460e7
...
...
@@ -935,7 +935,7 @@ void GetCompactComputationInfo(const Stmt &stmt, StmtInfoList &dst_info_list, St
/// \param if_info - The if-condition as input
/// \param for_info - The for-loop info to be modified
void
CompactComputationInfoList
(
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
if_info
,
StmtInfo
&
for_info
)
{
StmtInfo
&
for_info
)
{
auto
MergeTwoVar
=
[](
const
Var
&
keep_var
,
const
Var
&
delete_var
,
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
for_info
)
{
for
(
auto
info
:
dst_info_list
)
{
...
...
@@ -1059,8 +1059,7 @@ void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_i
bool
find_merge
=
false
;
for
(
size_t
i
=
0
;
(
i
<
var_cnt
-
1
)
&&
(
!
find_merge
);
i
++
)
{
for
(
size_t
j
=
i
+
1
;
j
<
var_cnt
;
j
++
)
{
if
(
CanMergeTwoVar
(
for_info
.
vars_
[
i
],
for_info
.
vars_
[
j
],
dst_info_list
,
src_info_list
,
for_info
))
{
if
(
CanMergeTwoVar
(
for_info
.
vars_
[
i
],
for_info
.
vars_
[
j
],
dst_info_list
,
src_info_list
,
for_info
))
{
find_merge
=
true
;
break
;
}
...
...
@@ -1075,7 +1074,6 @@ void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_i
}
}
/// A helper function for single dst_info's compact
/// \param dst_info
/// \param src_info_list
...
...
@@ -1357,6 +1355,43 @@ int GetVectorizedVarPosition(const Expr &index, Array<Var> &loop_vars) {
return
pos
;
}
std
::
string
GetOpType
(
const
Expr
&
value
)
{
if
(
value
.
as
<
Add
>
())
{
return
value
.
as
<
Add
>
()
->
_type_key
;
}
if
(
value
.
as
<
Sub
>
())
{
return
value
.
as
<
Sub
>
()
->
_type_key
;
}
if
(
value
.
as
<
Mul
>
())
{
return
value
.
as
<
Mul
>
()
->
_type_key
;
}
if
(
value
.
as
<
Div
>
())
{
return
value
.
as
<
Div
>
()
->
_type_key
;
}
if
(
value
.
as
<
Mod
>
())
{
return
value
.
as
<
Mod
>
()
->
_type_key
;
}
if
(
value
.
as
<
FloorDiv
>
())
{
return
value
.
as
<
FloorDiv
>
()
->
_type_key
;
}
if
(
value
.
as
<
FloorMod
>
())
{
return
value
.
as
<
FloorMod
>
()
->
_type_key
;
}
if
(
value
.
as
<
Min
>
())
{
return
value
.
as
<
Min
>
()
->
_type_key
;
}
if
(
value
.
as
<
Max
>
())
{
return
value
.
as
<
Max
>
()
->
_type_key
;
}
if
(
value
.
as
<
Call
>
())
{
return
value
.
as
<
Call
>
()
->
name
;
}
if
(
value
.
as
<
Load
>
()
||
value
.
as
<
IntImm
>
()
||
value
.
as
<
FloatImm
>
())
{
return
"DMACopy"
;
}
return
"undefined"
;
}
/// TVM Function Register, enable python code to call these cpp function.
TVM_REGISTER_API
(
"cce_util.GetCceAxis"
).
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
GetCceAxis
();
});
...
...
src/emit_insn/insn_info.h
浏览文件 @
59f460e7
...
...
@@ -49,13 +49,7 @@ enum ArgType {
ARG_NOT_DEFINE
};
enum
PatternType
{
PATTERN_3D
=
1
,
PATTERN_PARTIAL_3D
,
PATTERN_2D
,
PATTERN_2D_BLOCK
,
PATTERN_1D
};
enum
PatternType
{
PATTERN_3D
=
1
,
PATTERN_PARTIAL_3D
,
PATTERN_2D
,
PATTERN_2D_BLOCK
,
PATTERN_1D
};
class
StmtStoreInfoNode
:
public
Node
{
public:
...
...
@@ -98,13 +92,9 @@ class StmtStoreInfo : public NodeRef {
explicit
StmtStoreInfo
(
const
ObjectPtr
<
Object
>
&
n
)
:
NodeRef
(
n
),
node_
(
n
)
{}
~
StmtStoreInfo
()
=
default
;
inline
StmtStoreInfoNode
*
GetNode
()
const
{
return
static_cast
<
StmtStoreInfoNode
*>
(
node_
.
get
());
}
inline
StmtStoreInfoNode
*
GetNode
()
const
{
return
static_cast
<
StmtStoreInfoNode
*>
(
node_
.
get
());
}
inline
const
StmtStoreInfoNode
*
operator
->
()
const
{
return
static_cast
<
const
StmtStoreInfoNode
*>
(
node_
.
get
());
}
inline
const
StmtStoreInfoNode
*
operator
->
()
const
{
return
static_cast
<
const
StmtStoreInfoNode
*>
(
node_
.
get
());
}
void
CleanFlexVar
();
...
...
@@ -188,13 +178,9 @@ class VectorArgInfo : public NodeRef {
explicit
VectorArgInfo
(
const
ObjectPtr
<
Object
>
&
n
)
:
NodeRef
(
n
),
node_
(
n
)
{}
~
VectorArgInfo
()
=
default
;
inline
VectorArgInfoNode
*
GetNode
()
const
{
return
static_cast
<
VectorArgInfoNode
*>
(
node_
.
get
());
}
inline
VectorArgInfoNode
*
GetNode
()
const
{
return
static_cast
<
VectorArgInfoNode
*>
(
node_
.
get
());
}
inline
const
VectorArgInfoNode
*
operator
->
()
const
{
return
static_cast
<
const
VectorArgInfoNode
*>
(
node_
.
get
());
}
inline
const
VectorArgInfoNode
*
operator
->
()
const
{
return
static_cast
<
const
VectorArgInfoNode
*>
(
node_
.
get
());
}
void
Print
()
const
{
LOG
(
DEBUG
)
<<
"[ body_num: "
<<
GetNode
()
->
body_num_
<<
", body_offset: "
<<
GetNode
()
->
body_offset_
...
...
@@ -235,13 +221,9 @@ class ArgInfo : public NodeRef {
explicit
ArgInfo
(
const
ObjectPtr
<
Object
>
&
n
)
:
NodeRef
(
n
),
node_
(
n
)
{}
~
ArgInfo
()
=
default
;
inline
ArgInfoNode
*
GetNode
()
const
{
return
static_cast
<
ArgInfoNode
*>
(
node_
.
get
());
}
inline
ArgInfoNode
*
GetNode
()
const
{
return
static_cast
<
ArgInfoNode
*>
(
node_
.
get
());
}
inline
const
ArgInfoNode
*
operator
->
()
const
{
return
static_cast
<
const
ArgInfoNode
*>
(
node_
.
get
());
}
inline
const
ArgInfoNode
*
operator
->
()
const
{
return
static_cast
<
const
ArgInfoNode
*>
(
node_
.
get
());
}
inline
std
::
string
GetPattern
()
const
{
switch
(
GetNode
()
->
pattern_
)
{
...
...
@@ -373,6 +355,8 @@ bool IsBisectionReduction(const StmtInfoList &dst_info_list, const StmtInfoList
bool
HasVars
(
const
Expr
&
index
,
const
Var
&
vec_var
);
int
GetVectorizedVarPosition
(
const
Expr
&
index
,
Array
<
Var
>
&
loop_vars
);
std
::
string
GetOpType
(
const
Expr
&
value
);
}
// namespace akg
namespace
air
{
...
...
src/emit_insn/insn_pattern.h
浏览文件 @
59f460e7
...
...
@@ -77,7 +77,7 @@ class PatternGenerator {
class
SingleVecPatternGenerator
:
public
PatternGenerator
{
public:
SingleVecPatternGenerator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
mode
)
const
StmtInfo
&
for_info
,
const
std
::
string
&
mode
=
"elewise"
)
:
PatternGenerator
(
dst_info_list
,
for_info
),
arg_info
(
ArgInfo
(
make_node
<
ArgInfoNode
>
())),
body_args
(
VectorArgInfo
()),
...
...
src/emit_insn/insn_with_variable.cc
浏览文件 @
59f460e7
...
...
@@ -33,9 +33,11 @@
#include "insn_info.h"
#include "insn_pattern.h"
#include "insn_emitter.h"
#include "ir_transform.h"
namespace
akg
{
namespace
ir
{
Expr
GetVarCoefExpr
(
const
Expr
&
index
,
const
Var
&
loop_var
)
{
Expr
ret
=
Expr
();
Array
<
Expr
>
coefs
=
air
::
arith
::
DetectLinearEquation
(
index
,
{
loop_var
});
...
...
@@ -203,7 +205,7 @@ class HasScalarVarValue : public IRVisitor {
class
AdjustPragma
:
public
IRMutator
{
public:
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
air
::
ir
::
attr
::
IsPragmaKey
(
op
->
attr_key
)
&&
op
->
attr_key
==
"pragma_emit_insn"
&&
op
->
value
.
as
<
StringImm
>
())
{
if
(
op
->
attr_key
==
"pragma_emit_insn"
&&
op
->
value
.
as
<
StringImm
>
())
{
is_candidate_
=
true
;
loop_vars_
=
{};
loop_extends_
=
{};
...
...
@@ -295,7 +297,7 @@ class AdjustPragma : public IRMutator {
Array
<
Expr
>
srcs
=
call_ptr
->
args
;
CHECK_EQ
(
srcs
.
size
(),
2
);
is_argmax_min_
=
true
;
reduce_type_
=
(
op
->
value
.
as
<
Call
>
()
->
name
==
"fargmin"
)
?
"
arg_min"
:
"arg_
max"
;
reduce_type_
=
(
op
->
value
.
as
<
Call
>
()
->
name
==
"fargmin"
)
?
"
reduce_fargmin"
:
"reduce_farg
max"
;
return
Store
::
make
(
op
->
buffer_var
,
Call
::
make
(
call_ptr
->
type
,
reduce_type_
,
{
srcs
[
1
]},
Call
::
CallType
::
Extern
),
op
->
index
,
op
->
predicate
);
}
else
if
((
op
->
value
.
as
<
FloatImm
>
()
||
op
->
value
.
as
<
IntImm
>
()
||
op
->
value
.
as
<
UIntImm
>
())
&&
...
...
@@ -484,353 +486,6 @@ class AdjustPragma : public IRMutator {
Array
<
Var
>
transpose_vars_
;
};
/// Rewrites transposing UB->UB copies found under a "dma_copy" pragma.
/// A copy whose destination and source are vectorized on different loop axes
/// (i.e. a transpose) is handled in two ways:
///   - a single 16x16 (TransAxisLen x TransAxisLen) transpose is left as-is
///     for the emitter to handle directly;
///   - a larger transpose whose height is a multiple of TransAxisLen is tiled
///     into 16x16 blocks staged through two freshly allocated UB buffers
///     (copy-in -> transpose -> copy-out), each step tagged "dma_copy".
class TransposeTransform : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        op->value.as<StringImm>()->value == "dma_copy") {
      // Reset per-pragma state; fresh staging buffers for this region.
      pre_transpose_buffer = Var("srcTranspose_local_UB");
      post_transpose_buffer = Var("dstTranspose_local_UB");
      loop_vars_ = {};
      loop_extends_ = {};
      is_candidate_ = true;
      is_block_transpose_ = false;
      auto body = this->Mutate(op->body);
      is_candidate_ = false;
      if (is_block_transpose_) {
        is_block_transpose_ = false;
        // Wrap the rewritten body in the two staging-buffer allocations,
        // each placed in local UB storage.
        auto allocate_pre_buffer = Allocate::make(pre_transpose_buffer, t_type, {TransTotalSize}, const_true(1), body);
        auto attr_pre_buffer =
          AttrStmt::make(pre_transpose_buffer, "storage_scope", Expr("local.UB"), allocate_pre_buffer);
        auto allocate_post_buffer =
          Allocate::make(post_transpose_buffer, t_type, {TransTotalSize}, const_true(1), attr_pre_buffer);
        auto attr_post_buffer =
          AttrStmt::make(post_transpose_buffer, "storage_scope", Expr("local.UB"), allocate_post_buffer);
        return attr_post_buffer;
      } else {
        // Not a block transpose: keep the original pragma around the
        // (possibly unchanged) body.
        return AttrStmt::make(op->node, op->attr_key, op->value, body);
      }
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (is_candidate_) {
      // Record the loop nest so the Store visitor can locate vectorized axes.
      loop_vars_.push_back(op->loop_var);
      loop_extends_.push_back(op->extent);
      Stmt body = this->Mutate(op->body);
      if (is_block_transpose_ && IsInArray(trans_vars_, op->loop_var)) {
        // The two transposed axes are re-materialized inside the rewritten
        // store, so their original loops are dropped here.
        return body;
      } else {
        return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (is_candidate_) {
      // Look through an optional Cast to the underlying Load.
      auto value = op->value;
      if (auto cast = op->value.as<Cast>()) {
        value = cast->value;
      }
      CHECK(value.as<Load>());
      auto src_ptr = value.as<Load>();
      if (GetBufferType(op->buffer_var) == SCOPE_UBUF && GetBufferType(src_ptr->buffer_var) == SCOPE_UBUF) {
        int dst_pos = GetVectorizedVarPosition(op->index, loop_vars_);
        int src_pos = GetVectorizedVarPosition(src_ptr->index, loop_vars_);
        // Transpose detection: dst and src are vectorized on different axes,
        // the dst-axis extent is a multiple of TransAxisLen, and the src-axis
        // stride in the dst index equals that extent.
        if (dst_pos != -1 && src_pos != -1 && dst_pos != src_pos &&
            floormod(loop_extends_[dst_pos], TransAxisLen).as<IntImm>() &&
            floormod(loop_extends_[dst_pos], TransAxisLen).as<IntImm>()->value == 0 &&
            Equal(GetVarCoefExpr(op->index, loop_vars_[src_pos]), loop_extends_[dst_pos])) {
          if (loop_extends_[dst_pos].as<IntImm>() && loop_extends_[dst_pos].as<IntImm>()->value == TransAxisLen &&
              loop_extends_[src_pos].as<IntImm>() && loop_extends_[src_pos].as<IntImm>()->value == TransAxisLen) {
            // Exactly one 16x16 tile: leave it for the emitter.
            return s;
          } else {
            is_block_transpose_ = true;
            t_type = src_ptr->type;
            trans_vars_ = {};
            trans_vars_.push_back(loop_vars_[src_pos]);
            trans_vars_.push_back(loop_vars_[dst_pos]);
            // Source layout: ori_h rows of width ori_w; tile both into
            // TransAxisLen-sized blocks.
            Expr ori_w = GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]);
            Expr ori_h = loop_extends_[dst_pos];
            Expr ori_block_w = floordiv(ori_w, TransAxisLen);
            Expr ori_block_h = floordiv(ori_h, TransAxisLen);
            Var loop_w = Var("block_w");
            Var loop_h = Var("block_h");
            // Base offsets with the transposed axes removed.
            Expr src_base_index = EliminateVarInExpr(src_ptr->index, trans_vars_);
            Expr dst_base_index = EliminateVarInExpr(op->index, trans_vars_);
            Var tt0 = Var("tt0");
            Var tt1 = Var("tt1");
            // Step 1: copy one 16x16 tile from the source into the staging
            // buffer.
            auto pre_copy = Store::make(
              pre_transpose_buffer,
              Load::make(t_type, src_ptr->buffer_var,
                         src_base_index + loop_h * TransAxisLen * ori_w + loop_w * TransAxisLen + tt1 * ori_w + tt0, 1),
              tt1 * TransAxisLen + tt0, 1);
            auto pre_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_copy);
            auto pre_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_l0);
            auto pre_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), pre_l1);
            // Step 2: transpose the tile between the two staging buffers.
            // (The literal 16 matches TransAxisLen.)
            auto transpose = Store::make(
              post_transpose_buffer, Load::make(t_type, pre_transpose_buffer, tt1 * TransAxisLen + tt0, 1),
              tt0 * 16 + tt1, 1);
            auto trans_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, transpose);
            auto trans_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, trans_l0);
            auto trans_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), trans_l1);
            // Step 3: copy the transposed tile to its destination position.
            auto post_copy = Store::make(
              op->buffer_var, Load::make(t_type, post_transpose_buffer, tt1 * TransAxisLen + tt0, 1),
              dst_base_index + loop_w * TransAxisLen * ori_h + loop_h * TransAxisLen + tt1 * ori_h + tt0, 1);
            auto post_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_copy);
            auto post_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_l0);
            auto post_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), post_l1);
            // Iterate the three steps over all tiles.
            auto full_inner = Block::make(Block::make(pre_attr, trans_attr), post_attr);
            auto inner_w = For::make(loop_w, 0, ori_block_w, ForType::Serial, DeviceAPI::None, full_inner);
            auto inner_h = For::make(loop_h, 0, ori_block_h, ForType::Serial, DeviceAPI::None, inner_w);
            return inner_h;
          }
        }
      }
    }
    // Stores are never descended into further.
    return s;
  }

  bool is_candidate_{false};       // inside a "dma_copy" pragma region
  bool is_block_transpose_{false}; // current region needs block tiling
  Array<Var> trans_vars_;          // the two transposed loop axes {src, dst}
  Array<Var> loop_vars_;           // loop nest of the current region
  Array<Expr> loop_extends_;       // extents matching loop_vars_
  Type t_type;                     // element type of the transposed data
  Var pre_transpose_buffer;        // staging buffer before the transpose
  Var post_transpose_buffer;       // staging buffer after the transpose
};
/// Hoists if-conditions out of the loops inside a non-"mad" emit_insn pragma:
/// loops whose variables appear in an if-condition are moved outside the
/// pragma, the conditions are re-made directly around the pragma, and the
/// remaining (condition-free) loops stay inside. Regions without any
/// if-statement are returned unchanged.
class IfReorder : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        op->value.as<StringImm>()->value != "mad") {
      // Reset per-region bookkeeping before descending.
      in_insn_ = true;
      for_vars_.clear();
      if_vars_.clear();
      for_vec_.clear();
      if_vec_.clear();
      auto body = this->Mutate(op->body);
      in_insn_ = false;
      if (!if_vec_.empty()) {
        // Re-attach the pragma, then the collected conditions, then the
        // loops (outermost first) whose variables feed those conditions.
        Stmt new_s = AttrStmt::make(op->node, op->attr_key, op->value, body);
        for (auto if_op : if_vec_) {
          new_s = IfThenElse::make(if_op->condition, new_s);
        }
        for (auto for_op = for_vec_.rbegin(); for_op != for_vec_.rend(); ++for_op) {
          // Only rebuild loops still present in for_vars_ (i.e. the ones
          // removed from inside the pragma during the traversal).
          bool find_flag = false;
          for (auto for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
            if (Equal((*for_iter), (*for_op)->loop_var)) {
              find_flag = true;
              break;
            }
          }
          if (find_flag) {
            new_s = For::make((*for_op)->loop_var, (*for_op)->min, (*for_op)->extent, ForType::Serial, DeviceAPI::None,
                              new_s);
          }
        }
        return new_s;
      } else {
        // No if-statement inside: nothing to reorder.
        return s;
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      for_vec_.push_back(op);
      for_vars_.push_back(op->loop_var);
      Stmt body = this->Mutate(op->body);
      // Locate this loop's variable in for_vars_ (always present: pushed
      // above).
      std::vector<Var>::iterator for_iter;
      for (for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
        if (Equal((*for_iter), op->loop_var)) {
          break;
        }
      }
      if (!if_vec_.empty()) {
        std::vector<Var>::iterator if_iter;
        bool find_flag = false;
        for (if_iter = if_vars_.begin(); if_iter != if_vars_.end(); ++if_iter) {
          if (Equal((*if_iter), op->loop_var)) {
            find_flag = true;
            break;
          }
        }
        if (find_flag) {
          // This loop's variable appears in an if-condition: strip the loop
          // here (it is re-made outside the pragma) and keep it in
          // for_vars_ so the AttrStmt visitor rebuilds it.
          return body;
        } else {
          // Condition-free loop: keep it in place and drop it from
          // for_vars_ so it is not duplicated outside.
          for_vars_.erase(for_iter);
          return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
        }
      } else {
        for_vars_.erase(for_iter);
        return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const IfThenElse *op, const Stmt &s) final {
    if (in_insn_) {
      // Record the condition and every enclosing loop variable it mentions;
      // the if itself is removed here and re-made around the pragma.
      // NOTE(review): only then_case is kept — an else branch, if present,
      // would be dropped.
      if_vec_.push_back(op);
      for (auto loop_var : for_vars_) {
        if (HasVars(op->condition, loop_var)) {
          if_vars_.push_back(loop_var);
        }
      }
      Stmt body = this->Mutate(op->then_case);
      return body;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (in_insn_) {
      // Leaf of the region: do not descend further.
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

  bool in_insn_{false};                     // inside a candidate pragma region
  std::vector<const IfThenElse *> if_vec_;  // if-nodes stripped from the body
  std::vector<Var> if_vars_;                // loop vars used by if-conditions
  std::vector<Var> for_vars_;               // loop vars to rebuild outside
  std::vector<const For *> for_vec_;        // all loops seen in the region
  std::vector<const For *> before_if_;      // NOTE(review): unused in this class
};
/// Reorders the loop nest inside an emit_insn pragma so the vectorized
/// ("key") axis ends up innermost. For elementwise/broadcast patterns the
/// key axis comes from the Store's destination index; for reduce patterns it
/// comes from the first source Load. Loops are collected during traversal
/// and re-made around the pragma in the chosen order.
class LoopReorder : public IRMutator {
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>()) {
      in_insn_ = true;
      pragma = op->value.as<StringImm>()->value;
      for_map_.clear();
      ori_vars_ = {};
      var_order_.clear();
      auto ret = this->Mutate(op->body);
      in_insn_ = false;
      if (!has_changed_) {
        // Store visitor decided no reorder is needed: keep the original tree.
        return s;
      } else {
        if (var_order_.empty()) {
          // Rebuild the nest in its original order, pragma innermost.
          ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
          for (size_t i = 0; i < ori_vars_.size(); ++i) {
            CHECK_GT(for_map_.count(ori_vars_[i].get()), 0);
            auto ptr = for_map_[ori_vars_[i].get()];
            ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
          }
        } else {
          // var_order_[0] is the key axis: it is wrapped first, so it ends
          // up innermost; the pragma is re-attached outside the loops.
          for (size_t i = 0; i < var_order_.size(); ++i) {
            CHECK_GT(for_map_.count(var_order_[i].get()), 0);
            auto ptr = for_map_[var_order_[i].get()];
            ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
          }
          ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
        }
        return ret;
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      // Strip the loop but remember it (keyed by its variable) so the
      // AttrStmt visitor can re-make it in the new order.
      for_map_[(op->loop_var).get()] = op;
      ori_vars_.push_back(op->loop_var);
      auto body = this->Mutate(op->body);
      return body;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    // Find the vectorized axis of the destination and of every source Load.
    int dst_pos = GetVectorizedVarPosition(op->index, ori_vars_);
    int len = static_cast<int>(ori_vars_.size());
    std::vector<const Load *> srcs;
    auto get_loads = [&srcs](const NodeRef &node) {
      if (const auto v = node.as<Load>()) {
        srcs.push_back(v);
      }
    };
    PostOrderVisit(op->value, get_loads);
    bool same_pos = true;
    std::vector<int> srcs_pos;
    for (int i = 0; i < static_cast<int>(srcs.size()); ++i) {
      int temp_pos = GetVectorizedVarPosition(srcs[i]->index, ori_vars_);
      srcs_pos.push_back(temp_pos);
      if (temp_pos != dst_pos) {
        same_pos = false;
      }
    }
    has_changed_ = false;
    if (dst_pos >= 0 && len >= 2 && dst_pos != (len - 1) && (same_pos || pragma == "broadcast")) {
      // Src Load empty; all Load and Dst has the same key axis; broadcast
      has_changed_ = true;
      var_order_.push_back(ori_vars_[dst_pos]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != dst_pos) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    } else if (pragma.find("reduce") != pragma.npos && len >= 2 && !srcs_pos.empty() && srcs_pos[0] >= 0 &&
               srcs_pos[0] != (len - 1)) {
      // based on dst key axis: reduce
      // BUGFIX: guard against an empty srcs_pos (value with no Load) and a
      // not-found position (-1); the original read srcs_pos[0] and indexed
      // ori_vars_ with it unconditionally, which is out-of-bounds in those
      // cases.
      has_changed_ = true;
      var_order_.push_back(ori_vars_[srcs_pos[0]]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != srcs_pos[0]) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    }
    return s;
  }

  std::unordered_map<const Variable *, const For *> for_map_;  // loop var -> original For node
  std::vector<Var> var_order_;  // new order; [0] is the key (innermost) axis
  Array<Var> ori_vars_;         // loop vars in original outer-to-inner order
  bool has_changed_{false};     // whether a reorder was decided
  bool in_insn_{false};         // inside a candidate pragma region
  std::string pragma;           // pragma value of the current region
};
/// Renames loop variables so every For in the tree binds a distinct variable.
/// The first occurrence of a variable is kept; any later For reusing it gets
/// a fresh variable ("ii<N>") substituted through its body.
class ForVarUnique : public IRMutator {
 public:
  Stmt Mutate_(const For *op, const Stmt &s) final {
    Stmt mutated_body = this->Mutate(op->body);
    const Variable *raw_var = op->loop_var.get();
    if (var_maps_.count(raw_var) == 0) {
      // First time this variable is seen: keep it and remember it.
      var_maps_[raw_var] = 1;
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, mutated_body);
    }
    // Duplicate binding: mint a fresh variable and substitute it in the body.
    Var fresh_var = Var("ii" + std::to_string(++index_));
    std::unordered_map<const Variable *, Expr> replace_map;
    replace_map[raw_var] = fresh_var;
    auto renamed_body = Substitute(mutated_body, replace_map);
    var_maps_[fresh_var.get()] = 1;
    return For::make(fresh_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, renamed_body);
  }

  std::unordered_map<const Variable *, int> var_maps_;  // variables already used by some loop
  int index_{0};                                        // counter for fresh-variable names
};
class
GenSIMD
{
public:
GenSIMD
(
CCEInfo
&
t_info
,
Map
<
std
::
string
,
Buffer
>
&
buffer_map
,
const
std
::
string
&
pragma
)
...
...
@@ -1520,9 +1175,9 @@ class GenReduce {
~
GenReduce
()
=
default
;
Stmt
Run
(
int
pre_index
)
{
is_arg_type_
=
(
pragma_
==
"
arg_max"
||
pragma_
==
"arg_
min"
);
is_arg_type_
=
(
pragma_
==
"
reduce_fargmax"
||
pragma_
==
"reduce_farg
min"
);
RemoveVectorizedIndex
(
t_info_
,
0
);
if
(
pragma_
.
find
(
"sum"
)
!=
std
::
string
::
npos
)
{
if
(
pragma_
.
find
(
"sum"
)
!=
std
::
string
::
npos
||
pragma_
.
find
(
"add"
)
!=
std
::
string
::
npos
)
{
insn_intrinsic_
=
"vcadd"
;
expansion_factor_
=
1
;
}
else
if
(
pragma_
.
find
(
"max"
)
!=
std
::
string
::
npos
)
{
...
...
@@ -1769,7 +1424,7 @@ class EmitVariableInsns : public IRMutator {
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
air
::
ir
::
attr
::
IsPragmaKey
(
op
->
attr_key
)
&&
op
->
attr_key
==
"pragma_emit_insn"
)
{
if
(
op
->
attr_key
==
"pragma_emit_insn"
)
{
CHECK
(
op
->
value
.
as
<
StringImm
>
());
pragma
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
Stmt
r
;
...
...
@@ -1791,8 +1446,7 @@ class EmitVariableInsns : public IRMutator {
if
(
!
r
.
same_as
(
s
))
{
return
r
;
}
}
else
if
(
air
::
ir
::
attr
::
IsPragmaKey
(
op
->
attr_key
)
&&
(
op
->
attr_key
==
"pragma_im2col"
||
op
->
attr_key
==
"pragma_load3d"
))
{
}
else
if
(
op
->
attr_key
==
"pragma_im2col"
||
op
->
attr_key
==
"pragma_load3d"
)
{
if
(
paramters_
.
defined
()
&&
Downcast
<
Map
<
std
::
string
,
NodeRef
>>
(
paramters_
).
count
(
"feature"
))
{
auto
feature
=
Downcast
<
Map
<
std
::
string
,
NodeRef
>>
(
paramters_
)[
"feature"
].
as
<
StringImm
>
();
CHECK
(
feature
);
...
...
@@ -1842,13 +1496,13 @@ class EmitVariableInsns : public IRMutator {
if
(
pragma
.
find
(
"vec_select"
)
!=
std
::
string
::
npos
)
{
EmitSelect
(
op
,
t_info
);
}
else
if
(
pragma
.
find
(
"dma_copy"
)
==
0
)
{
}
else
if
(
pragma
.
find
(
"dma_copy"
)
!=
std
::
string
::
npos
)
{
EmitDMA
(
t_info
);
}
else
if
(
pragma
.
find
(
"vec_binary"
)
==
0
||
pragma
.
find
(
"vec_single"
)
==
0
)
{
}
else
if
(
pragma
.
find
(
"vec_binary"
)
!=
std
::
string
::
npos
||
pragma
.
find
(
"vec_single"
)
!=
std
::
string
::
npos
)
{
EmitSIMD
(
t_info
);
}
else
if
(
pragma
.
find
(
"reduce"
)
==
0
||
pragma
.
find
(
"arg_"
)
==
0
)
{
}
else
if
(
pragma
.
find
(
"reduce"
)
!=
std
::
string
::
npos
||
pragma
.
find
(
"arg_"
)
!=
std
::
string
::
npos
)
{
EmitReduce
(
t_info
);
}
else
if
(
pragma
.
find
(
"broadcast"
)
==
0
)
{
}
else
if
(
pragma
.
find
(
"broadcast"
)
!=
std
::
string
::
npos
)
{
if
(
loops_vars_
.
empty
())
{
gen_cce
=
t_info
.
ori_stmt
;
}
else
{
...
...
src/emit_insn/insn_with_variable.h
浏览文件 @
59f460e7
...
...
@@ -31,8 +31,7 @@
namespace
akg
{
namespace
ir
{
const
int
TransTotalSize
=
256
;
const
int
TransAxisLen
=
16
;
const
int64_t
FullReduceMaskValue
=
6148914691236517205
;
class
CCEInsn
{
...
...
src/emit_insn/ir_transform.h
0 → 100644
浏览文件 @
59f460e7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef IR_TRANSFORM_H_
#define IR_TRANSFORM_H_
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <unordered_set>
#include <map>
#include <numeric>
#include <set>
#include <algorithm>
#include "ir_pass.h"
#include "common/array_api.h"
#include "insn_with_variable.h"
#include "insn_builder.h"
#include "insn_info.h"
#include "insn_pattern.h"
#include "../pass/analyze_align.h"
const
int
TransTotalSize
=
256
;
const
int
TransAxisLen
=
16
;
namespace
akg
{
namespace
ir
{
Expr
GetVarCoefExpr
(
const
Expr
&
index
,
const
Var
&
loop_var
);
std
::
string
GetBufferType
(
Expr
address
);
class
TransposeTransform
:
public
IRMutator
{
public:
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
"pragma_emit_insn"
&&
op
->
value
.
as
<
StringImm
>
()
&&
(
op
->
value
.
as
<
StringImm
>
()
->
value
==
"dma_copy"
))
{
pre_transpose_buffer_
=
Var
(
"srcTranspose_local_UB"
);
post_transpose_buffer_
=
Var
(
"dstTranspose_local_UB"
);
pre_trans_cast_
=
Var
(
"pre_trans_cast__local_UB"
);
post_trans_cast_
=
Var
(
"post_trans_cast__local_UB"
);
loop_vars_
=
{};
loop_extends_
=
{};
is_candidate_
=
true
;
is_block_transpose_
=
false
;
is_native_transpose_
=
false
;
align_value
=
FREE_ALIGN
;
remain_fors_
.
clear
();
auto
body
=
this
->
Mutate
(
op
->
body
);
is_candidate_
=
false
;
if
(
is_block_transpose_
)
{
is_block_transpose_
=
false
;
if
(
t_type_
==
Float
(
32
))
{
// need cast
body
=
Allocate
::
make
(
pre_trans_cast_
,
Float
(
16
),
{
TransTotalSize
},
const_true
(
1
),
body
);
body
=
AttrStmt
::
make
(
pre_trans_cast_
,
"storage_scope"
,
Expr
(
"local.UB"
),
body
);
body
=
Allocate
::
make
(
post_trans_cast_
,
Float
(
16
),
{
TransTotalSize
},
const_true
(
1
),
body
);
body
=
AttrStmt
::
make
(
post_trans_cast_
,
"storage_scope"
,
Expr
(
"local.UB"
),
body
);
}
auto
allocate_pre_buffer
=
Allocate
::
make
(
pre_transpose_buffer_
,
t_type_
,
{
TransTotalSize
},
const_true
(
1
),
body
);
auto
attr_pre_buffer
=
AttrStmt
::
make
(
pre_transpose_buffer_
,
"storage_scope"
,
Expr
(
"local.UB"
),
allocate_pre_buffer
);
auto
allocate_post_buffer
=
Allocate
::
make
(
post_transpose_buffer_
,
t_type_
,
{
TransTotalSize
},
const_true
(
1
),
attr_pre_buffer
);
auto
attr_post_buffer
=
AttrStmt
::
make
(
post_transpose_buffer_
,
"storage_scope"
,
Expr
(
"local.UB"
),
allocate_post_buffer
);
Stmt
ret
=
attr_post_buffer
;
if
(
align_value
!=
FREE_ALIGN
)
{
ret
=
AttrStmt
::
make
(
align_buffer_
,
"align_info"
,
Expr
(
align_value
),
ret
);
}
return
ret
;
}
if
(
is_native_transpose_
)
{
Stmt
ret
=
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
Expr
(
"dma_copy_transpose"
),
body
);
for
(
int
i
=
0
;
i
<=
static_cast
<
int
>
(
remain_fors_
.
size
())
-
1
;
++
i
)
{
ret
=
For
::
make
(
remain_fors_
[
i
]
->
loop_var
,
remain_fors_
[
i
]
->
min
,
remain_fors_
[
i
]
->
extent
,
ForType
::
Serial
,
DeviceAPI
::
None
,
ret
);
}
return
ret
;
}
return
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
op
->
value
,
body
);
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
if
(
is_candidate_
)
{
loop_vars_
.
push_back
(
op
->
loop_var
);
loop_extends_
.
push_back
(
op
->
extent
);
Stmt
body
=
this
->
Mutate
(
op
->
body
);
if
(
is_block_transpose_
&&
IsInArray
(
trans_vars_
,
op
->
loop_var
))
{
return
body
;
}
if
(
is_native_transpose_
)
{
if
(
IsInArray
(
trans_vars_
,
op
->
loop_var
))
{
return
For
::
make
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
ForType
::
Serial
,
DeviceAPI
::
None
,
body
);
}
remain_fors_
.
push_back
(
op
);
return
body
;
}
return
For
::
make
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
ForType
::
Serial
,
DeviceAPI
::
None
,
body
);
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
final
{
if
(
is_candidate_
)
{
auto
value
=
op
->
value
;
if
(
auto
cast
=
op
->
value
.
as
<
Cast
>
())
{
value
=
cast
->
value
;
}
CHECK
(
value
.
as
<
Load
>
());
auto
src_ptr
=
value
.
as
<
Load
>
();
if
(
GetBufferType
(
op
->
buffer_var
)
==
SCOPE_UBUF
&&
GetBufferType
(
src_ptr
->
buffer_var
)
==
SCOPE_UBUF
&&
src_ptr
->
type
==
Float
(
16
))
{
int
dst_pos
=
GetVectorizedVarPosition
(
op
->
index
,
loop_vars_
);
int
src_pos
=
GetVectorizedVarPosition
(
src_ptr
->
index
,
loop_vars_
);
if
(
dst_pos
!=
-
1
&&
src_pos
!=
-
1
&&
dst_pos
!=
src_pos
&&
HasVars
(
src_ptr
->
index
,
loop_vars_
[
dst_pos
])
&&
HasVars
(
op
->
index
,
loop_vars_
[
src_pos
])
&&
floormod
(
loop_extends_
[
dst_pos
],
TransAxisLen
).
as
<
IntImm
>
()
&&
floormod
(
loop_extends_
[
dst_pos
],
TransAxisLen
).
as
<
IntImm
>
()
->
value
==
0
&&
Equal
(
GetVarCoefExpr
(
op
->
index
,
loop_vars_
[
src_pos
]),
loop_extends_
[
dst_pos
]))
{
if
(
loop_extends_
[
dst_pos
].
as
<
IntImm
>
()
&&
loop_extends_
[
dst_pos
].
as
<
IntImm
>
()
->
value
==
TransAxisLen
&&
loop_extends_
[
src_pos
].
as
<
IntImm
>
()
&&
loop_extends_
[
src_pos
].
as
<
IntImm
>
()
->
value
==
TransAxisLen
)
{
trans_vars_
=
{};
trans_vars_
.
push_back
(
loop_vars_
[
src_pos
]);
trans_vars_
.
push_back
(
loop_vars_
[
dst_pos
]);
is_native_transpose_
=
true
;
return
s
;
}
is_block_transpose_
=
true
;
if
(
GetVarCoefExpr
(
src_ptr
->
index
,
loop_vars_
[
dst_pos
]).
as
<
IntImm
>
())
{
int
coef_t
=
GetVarCoefExpr
(
src_ptr
->
index
,
loop_vars_
[
dst_pos
]).
as
<
IntImm
>
()
->
value
;
if
(
coef_t
%
TransAxisLen
!=
0
)
{
align_value
=
coef_t
;
align_buffer_
=
src_ptr
->
buffer_var
;
}
}
t_type_
=
src_ptr
->
type
;
trans_vars_
=
{};
trans_vars_
.
push_back
(
loop_vars_
[
src_pos
]);
trans_vars_
.
push_back
(
loop_vars_
[
dst_pos
]);
Expr
ori_w
=
GetVarCoefExpr
(
src_ptr
->
index
,
loop_vars_
[
dst_pos
]);
Expr
ori_h
=
loop_extends_
[
dst_pos
];
Expr
ori_block_w
=
floordiv
(
ori_w
,
TransAxisLen
);
// padding the width
Expr
unit_width
=
TransAxisLen
;
if
(
!
Equal
(
floormod
(
ori_w
,
TransAxisLen
),
0
))
{
ori_block_w
=
ori_block_w
+
1
;
}
if
(
ori_w
.
as
<
IntImm
>
()
&&
ori_w
.
as
<
IntImm
>
()
->
value
<
TransAxisLen
)
{
unit_width
=
ori_w
;
}
Expr
ori_block_h
=
floordiv
(
ori_h
,
TransAxisLen
);
Var
loop_w
=
Var
(
"block_w"
);
Var
loop_h
=
Var
(
"block_h"
);
Expr
src_base_index
=
EliminateVarInExpr
(
src_ptr
->
index
,
trans_vars_
);
Expr
dst_base_index
=
EliminateVarInExpr
(
op
->
index
,
trans_vars_
);
Var
tt0
=
Var
(
"tt0"
);
Var
tt1
=
Var
(
"tt1"
);
auto
pre_copy
=
Store
::
make
(
pre_transpose_buffer_
,
Load
::
make
(
t_type_
,
src_ptr
->
buffer_var
,
src_base_index
+
loop_h
*
TransAxisLen
*
ori_w
+
loop_w
*
TransAxisLen
+
tt1
*
ori_w
+
tt0
,
1
),
tt1
*
TransAxisLen
+
tt0
,
1
);
auto
pre_l0
=
For
::
make
(
tt0
,
0
,
unit_width
,
ForType
::
Serial
,
DeviceAPI
::
None
,
pre_copy
);
auto
pre_l1
=
For
::
make
(
tt1
,
0
,
TransAxisLen
,
ForType
::
Serial
,
DeviceAPI
::
None
,
pre_l0
);
auto
pre_attr
=
AttrStmt
::
make
(
make_zero
(
Int
(
32
)),
"pragma_emit_insn"
,
Expr
(
"dma_copy"
),
pre_l1
);
Stmt
trans_attr
=
Stmt
();
if
(
t_type_
==
Float
(
16
))
{
auto
transpose
=
Store
::
make
(
post_transpose_buffer_
,
Load
::
make
(
t_type_
,
pre_transpose_buffer_
,
tt1
*
TransAxisLen
+
tt0
,
1
),
tt0
*
16
+
tt1
,
1
);
auto
trans_l0
=
For
::
make
(
tt0
,
0
,
TransAxisLen
,
ForType
::
Serial
,
DeviceAPI
::
None
,
transpose
);
auto
trans_l1
=
For
::
make
(
tt1
,
0
,
TransAxisLen
,
ForType
::
Serial
,
DeviceAPI
::
None
,
trans_l0
);
trans_attr
=
AttrStmt
::
make
(
make_zero
(
Int
(
32
)),
"pragma_emit_insn"
,
Expr
(
"dma_copy_transpose"
),
trans_l1
);
}
else
{
auto
pre_cast_store
=
Store
::
make
(
pre_trans_cast_
,
Cast
::
make
(
Float
(
16
),
Load
::
make
(
t_type_
,
pre_transpose_buffer_
,
tt0
,
1
)),
tt0
,
1
);
auto
pre_cast_for
=
For
::
make
(
tt0
,
0
,
TransTotalSize
,
ForType
::
Serial
,
DeviceAPI
::
None
,
pre_cast_store
);
auto
pre_cast_attr
=
AttrStmt
::
make
(
make_zero
(
Int
(
32
)),
"pragma_emit_insn"
,
Expr
(
"vec_single_cast"
),
pre_cast_for
);
auto
transpose
=
Store
::
make
(
post_trans_cast_
,
Load
::
make
(
Float
(
16
),
pre_trans_cast_
,
tt1
*
TransAxisLen
+
tt0
,
1
),
tt0
*
16
+
tt1
,
1
);
auto
trans_l0
=
For
::
make
(
tt0
,
0
,
TransAxisLen
,
ForType
::
Serial
,
DeviceAPI
::
None
,
transpose
);
auto
trans_l1
=
For
::
make
(
tt1
,
0
,
TransAxisLen
,
ForType
::
Serial
,
DeviceAPI
::
None
,
trans_l0
);
auto
trans_block
=
AttrStmt
::
make
(
make_zero
(
Int
(
32
)),
"pragma_emit_insn"
,
Expr
(
"dma_copy_transpose"
),
trans_l1
);
auto
post_cast_store
=
Store
::
make
(
post_transpose_buffer_
,
Cast
::
make
(
t_type_
,
Load
::
make
(
Float
(
16
),
post_trans_cast_
,
tt0
,
1
)),
tt0
,
1
);
auto
post_cast_for
=
For
::
make
(
tt0
,
0
,
TransTotalSize
,
ForType
::
Serial
,
DeviceAPI
::
None
,
post_cast_store
);
auto
post_cast_attr
=
AttrStmt
::
make
(
make_zero
(
Int
(
32
)),
"pragma_emit_insn"
,
Expr
(
"vec_single_cast"
),
post_cast_for
);
trans_attr
=
Block
::
make
(
Block
::
make
(
pre_cast_attr
,
trans_block
),
post_cast_attr
);
}
auto
post_copy
=
Store
::
make
(
op
->
buffer_var
,
Load
::
make
(
t_type_
,
post_transpose_buffer_
,
tt1
*
TransAxisLen
+
tt0
,
1
),
dst_base_index
+
loop_w
*
TransAxisLen
*
ori_h
+
loop_h
*
TransAxisLen
+
tt1
*
ori_h
+
tt0
,
1
);
auto
post_l0
=
For
::
make
(
tt0
,
0
,
TransAxisLen
,
ForType
::
Serial
,
DeviceAPI
::
None
,
post_copy
);
auto
post_l1
=
For
::
make
(
tt1
,
0
,
unit_width
,
ForType
::
Serial
,
DeviceAPI
::
None
,
post_l0
);
auto
post_attr
=
AttrStmt
::
make
(
make_zero
(
Int
(
32
)),
"pragma_emit_insn"
,
Expr
(
"dma_copy"
),
post_l1
);
auto
full_inner
=
Block
::
make
(
Block
::
make
(
pre_attr
,
trans_attr
),
post_attr
);
auto
inner_w
=
For
::
make
(
loop_w
,
0
,
ori_block_w
,
ForType
::
Serial
,
DeviceAPI
::
None
,
full_inner
);
if
(
ori_block_w
.
as
<
IntImm
>
()
&&
ori_block_w
.
as
<
IntImm
>
()
->
value
==
1
)
{
std
::
unordered_map
<
const
Variable
*
,
Expr
>
init
;
init
[
loop_w
.
get
()]
=
0
;
inner_w
=
Simplify
(
Substitute
(
full_inner
,
init
));
}
auto
inner_h
=
For
::
make
(
loop_h
,
0
,
ori_block_h
,
ForType
::
Serial
,
DeviceAPI
::
None
,
inner_w
);
if
(
ori_block_h
.
as
<
IntImm
>
()
&&
ori_block_h
.
as
<
IntImm
>
()
->
value
==
1
)
{
std
::
unordered_map
<
const
Variable
*
,
Expr
>
init
;
init
[
loop_h
.
get
()]
=
0
;
inner_h
=
Simplify
(
Substitute
(
inner_w
,
init
));
}
return
inner_h
;
}
}
}
return
s
;
}
private:
bool
is_candidate_
{
false
};
bool
is_native_transpose_
{
false
};
bool
is_block_transpose_
{
false
};
int
align_value
{
FREE_ALIGN
};
Var
align_buffer_
;
Array
<
Var
>
trans_vars_
;
Array
<
Var
>
loop_vars_
;
Array
<
Expr
>
loop_extends_
;
std
::
vector
<
const
For
*>
remain_fors_
;
Type
t_type_
;
Var
pre_transpose_buffer_
;
Var
pre_trans_cast_
;
Var
post_trans_cast_
;
Var
post_transpose_buffer_
;
};
class ForVarUnique : public IRMutator {
 public:
  /// Rewrites the loop nest so that every For statement carries a distinct
  /// loop variable: the first time a variable is seen it is recorded; any
  /// later For reusing the same variable gets a fresh "iiN" variable
  /// substituted through its body. All emitted loops are serialized
  /// (ForType::Serial / DeviceAPI::None), matching the original pass.
  Stmt Mutate_(const For *op, const Stmt &s) final {
    Stmt mutated = this->Mutate(op->body);
    const Variable *raw_var = op->loop_var.get();
    if (var_maps_.count(raw_var) == 0) {
      // First occurrence of this variable: keep it and remember it.
      var_maps_[raw_var] = 1;
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, mutated);
    }
    // Duplicate variable: mint a unique replacement and rewrite the body.
    Var fresh = Var("ii" + std::to_string(++index_));
    std::unordered_map<const Variable *, Expr> replace;
    replace[raw_var] = fresh;
    Stmt renamed = Substitute(mutated, replace);
    var_maps_[fresh.get()] = 1;
    return For::make(fresh, op->min, op->extent, ForType::Serial, DeviceAPI::None, renamed);
  }

 private:
  std::unordered_map<const Variable *, int> var_maps_;  // loop vars already emitted
  int index_{0};                                        // suffix counter for fresh "ii" names
};
// Reorders the loops inside a "pragma_emit_insn" region so that the
// vectorized ("key") axis becomes the innermost loop. The Store visitor
// decides the target order; the AttrStmt visitor peels the original loops
// (collected in for_map_/ori_vars_) and re-wraps the body following
// var_order_.
class LoopReorder : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>()) {
      in_insn_ = true;
      pragma_ = op->value.as<StringImm>()->value;
      // Reset per-insn state before walking the region.
      for_map_.clear();
      ori_vars_ = {};
      var_order_.clear();
      auto ret = this->Mutate(op->body);
      in_insn_ = false;
      if (!has_changed_) {
        // Store visitor found no profitable reorder: keep the original stmt.
        return s;
      }
      if (var_order_.empty()) {
        // Rebuild with the originally collected loops. NOTE(review): each
        // For::make wraps outside the previous one, so iterating i upward
        // places ori_vars_[size-1] outermost (reversed vs. the original
        // nest); this branch is only reachable if has_changed_ is set
        // without pushing to var_order_, which the current Store visitor
        // never does — confirm before relying on it.
        ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
        for (size_t i = 0; i < ori_vars_.size(); ++i) {
          CHECK_GT(for_map_.count(ori_vars_[i].get()), 0);
          auto ptr = for_map_[ori_vars_[i].get()];
          ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
        }
        return ret;
      }
      // Re-wrap the body following var_order_. var_order_[0] is the key
      // axis; wrapping in increasing i makes the last pushed var outermost,
      // i.e. the key axis ends up innermost.
      for (size_t i = 0; i < var_order_.size(); ++i) {
        CHECK_GT(for_map_.count(var_order_[i].get()), 0);
        auto ptr = for_map_[var_order_[i].get()];
        ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
      }
      ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
      return ret;
    }
    return IRMutator::Mutate_(op, s);
  }

  // Inside an insn region: record the loop and strip it (the AttrStmt
  // visitor re-wraps the loops afterwards).
  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      for_map_[(op->loop_var).get()] = op;
      ori_vars_.push_back(op->loop_var);
      auto body = this->Mutate(op->body);
      return body;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  // Decides the new loop order from the store/load index patterns.
  Stmt Mutate_(const Store *op, const Stmt &s) final {
    // Position (in ori_vars_) of the vectorized axis of the destination.
    int dst_pos = GetVectorizedVarPosition(op->index, ori_vars_);
    int len = static_cast<int>(ori_vars_.size());
    std::vector<const Load *> srcs;
    auto get_loads = [&srcs](const NodeRef &node) {
      if (const auto v = node.as<Load>()) {
        srcs.push_back(v);
      }
    };
    PostOrderVisit(op->value, get_loads);
    bool same_pos = true;
    std::vector<int> srcs_pos;
    for (int i = 0; i < static_cast<int>(srcs.size()); ++i) {
      int temp_pos = GetVectorizedVarPosition(srcs[i]->index, ori_vars_);
      srcs_pos.push_back(temp_pos);
      if (temp_pos != dst_pos) {
        same_pos = false;
      }
    }
    has_changed_ = false;
    if (dst_pos >= 0 && len >= 2 && dst_pos != (len - 1) && (same_pos || pragma_ == "broadcast")) {
      // Src Load empty; all Load and Dst has the same key axis; broadcast
      // Move the shared key axis innermost; keep the relative order of the
      // remaining axes (pushed from len-1 down, so they wrap outermost-first).
      has_changed_ = true;
      var_order_.push_back(ori_vars_[dst_pos]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != dst_pos) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    } else if (pragma_.find("reduce") != pragma_.npos && len >= 2 && srcs_pos[0] != (len - 1)) {
      // based on dst key axis: reduce
      // NOTE(review): srcs_pos[0] is read unconditionally here — this assumes
      // a reduce store always contains at least one Load; confirm upstream.
      has_changed_ = true;
      var_order_.push_back(ori_vars_[srcs_pos[0]]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != srcs_pos[0]) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    }
    return s;
  }

 private:
  std::unordered_map<const Variable *, const For *> for_map_;  // loop var -> original For node
  std::vector<Var> var_order_;                                 // new wrap order (index 0 = key axis)
  Array<Var> ori_vars_;                                        // loops in original outer-to-inner order
  bool has_changed_{false};
  bool in_insn_{false};
  std::string pragma_;  // pragma string of the current insn region
};
// Hoists IfThenElse statements out of "pragma_emit_insn" regions: loops
// whose variables appear in an if-condition are moved outside the attribute,
// the ifs are re-emitted above the attribute, and the remaining loops stay
// inside. Only the then-branch of each if is preserved (else-branches are
// dropped — this matches the original pass).
class IfReorder : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        !exclude_align_analyze_list.count(op->value.as<StringImm>()->value)) {
      in_insn_ = true;
      // Reset per-region bookkeeping.
      for_vars_.clear();
      if_vars_.clear();
      for_vec_.clear();
      if_vec_.clear();
      auto body = this->Mutate(op->body);
      in_insn_ = false;
      if (!if_vec_.empty()) {
        // Rebuild: attribute innermost, then the collected ifs, then any
        // loop whose variable is still in for_vars_ (i.e. referenced by a
        // hoisted if-condition), outermost-first via reverse iteration.
        Stmt new_s = AttrStmt::make(op->node, op->attr_key, op->value, body);
        for (auto if_op : if_vec_) {
          new_s = IfThenElse::make(if_op->condition, new_s);
        }
        for (auto for_op = for_vec_.rbegin(); for_op != for_vec_.rend(); ++for_op) {
          bool find_flag = false;
          for (auto for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
            if (Equal((*for_iter), (*for_op)->loop_var)) {
              find_flag = true;
              break;
            }
          }
          if (find_flag) {
            new_s = For::make((*for_op)->loop_var, (*for_op)->min, (*for_op)->extent, ForType::Serial,
                              DeviceAPI::None, new_s);
          }
        }
        return new_s;
      }
      // No ifs found inside the region: nothing to hoist.
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      for_vec_.push_back(op);
      for_vars_.push_back(op->loop_var);
      Stmt body = this->Mutate(op->body);
      // Locate this loop's variable again: nested visits may have erased
      // other entries, so the iterator must be re-found after recursion.
      std::vector<Var>::iterator for_iter;
      for (for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
        if (Equal((*for_iter), op->loop_var)) {
          break;
        }
      }
      if (!if_vec_.empty()) {
        std::vector<Var>::iterator if_iter;
        bool find_flag = false;
        for (if_iter = if_vars_.begin(); if_iter != if_vars_.end(); ++if_iter) {
          if (Equal((*if_iter), op->loop_var)) {
            find_flag = true;
            break;
          }
        }
        if (find_flag) {
          // This loop feeds an if-condition: strip it here; the AttrStmt
          // visitor re-emits it outside the attribute.
          return body;
        }
        // Loop not referenced by any if: keep it in place and drop it from
        // the hoist candidates.
        for_vars_.erase(for_iter);
        return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
      }
      for_vars_.erase(for_iter);
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const IfThenElse *op, const Stmt &s) final {
    if (in_insn_) {
      if_vec_.push_back(op);
      // Record every enclosing loop variable used by this condition so the
      // corresponding loops are hoisted along with the if.
      for (auto loop_var : for_vars_) {
        if (HasVars(op->condition, loop_var)) {
          if_vars_.push_back(loop_var);
        }
      }
      // Only the then-branch is kept; the if itself is re-created outside.
      Stmt body = this->Mutate(op->then_case);
      return body;
    }
    return IRMutator::Mutate_(op, s);
  }

  // Stores inside an insn region are leaves — nothing to rewrite.
  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (in_insn_) {
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

 private:
  bool in_insn_{false};
  std::vector<const IfThenElse *> if_vec_;  // ifs collected in the current region
  std::vector<Var> if_vars_;                // loop vars referenced by if-conditions
  std::vector<Var> for_vars_;               // loop vars still pending a hoist decision
  std::vector<const For *> for_vec_;        // loops in visit (outer-to-inner) order
  std::vector<const For *> before_if_;      // NOTE(review): currently unused
};
}
// namespace ir
}
// namespace akg
#endif // IR_TRANSFORM_H_
\ No newline at end of file
src/include/ir_pass.h
浏览文件 @
59f460e7
...
...
@@ -265,6 +265,10 @@ Stmt RewriteBroadcastVector(Stmt stmt);
Stmt
OptimizePragma
(
Stmt
stmt
);
Stmt
PackStore
(
Stmt
stmt
);
Stmt
RecoverStore
(
Stmt
stmt
);
Stmt
RewriteByAlignDynamic
(
Stmt
stmt
);
Stmt
EliminateAtomicDma
(
Stmt
stmt
);
...
...
src/pass/analyze_align.h
浏览文件 @
59f460e7
此差异已折叠。
点击以展开。
src/pass/analyze_align_dynamic.cc
浏览文件 @
59f460e7
...
...
@@ -466,7 +466,7 @@ class AlignVistor : public IRVisitor {
// only scan dma insns
if
(
op
->
attr_key
==
"pragma_ub_gm"
||
(
op
->
attr_key
==
"pragma_emit_insn"
&&
op
->
value
.
as
<
StringImm
>
()
&&
op
->
value
.
as
<
StringImm
>
()
->
value
!=
"vec_binary_dropout"
&&
exclude_list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)
==
0
))
{
exclude_
align_analyze_
list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)
==
0
))
{
bool
in_dma_copy
=
false
;
if
(
op
->
value
.
as
<
StringImm
>
()
&&
op
->
value
.
as
<
StringImm
>
()
->
value
==
"dma_copy"
)
{
in_dma_copy
=
true
;
...
...
src/pass/analyze_align_static.cc
浏览文件 @
59f460e7
此差异已折叠。
点击以展开。
src/pass/merge_loops.cc
浏览文件 @
59f460e7
...
...
@@ -43,7 +43,7 @@ class LoopsCompacter : public IRMutator {
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
"pragma_ub_gm"
||
(
op
->
attr_key
==
"pragma_emit_insn"
&&
op
->
value
->
IsInstance
<
StringImm
>
()
&&
!
exclude_list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)))
{
!
exclude_
align_analyze_
list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)))
{
stores_
=
Array
<
NodeRef
>
();
loads_
=
Array
<
NodeRef
>
();
GetStoreAndLoads
(
op
->
body
,
stores_
,
loads_
);
...
...
src/pass/multi_last_axis_reduction.cc
浏览文件 @
59f460e7
...
...
@@ -192,6 +192,7 @@ class MultiLastAxisReduction : public IRMutator {
lastResult
=
loadTmp
+
storeLeft
;
}
broadcastNum
=
Call
::
make
(
type_tmp
,
"vector_dup"
,
{
broadcastNum
},
Call
::
PureIntrinsic
);
Stmt
stForOnce
=
Store
::
make
(
tmpBuffer
,
storeResult
,
newIdx
,
storeTmp
->
predicate
);
Stmt
stForTwice
=
Store
::
make
(
storeTmp
->
buffer_var
,
lastResult
,
storeTmp
->
index
,
storeTmp
->
predicate
);
Stmt
stBroadcast
=
Store
::
make
(
tmpBuffer
,
broadcastNum
,
newIdx
,
storeTmp
->
predicate
);
...
...
@@ -212,7 +213,7 @@ class MultiLastAxisReduction : public IRMutator {
stForOnce
=
AttrStmt
::
make
(
VarExpr
(
"0"
,
Int
(
32
)),
"pragma_emit_insn"
,
Expr
(
str
),
stForOnce
);
stForTwice
=
AttrStmt
::
make
(
VarExpr
(
"0"
,
Int
(
32
)),
"pragma_emit_insn"
,
Expr
(
str
),
stForTwice
);
stBroadcast
=
AttrStmt
::
make
(
VarExpr
(
"0"
,
Int
(
32
)),
"pragma_emit_insn"
,
Expr
(
"
broadcast
"
),
stBroadcast
);
stBroadcast
=
AttrStmt
::
make
(
VarExpr
(
"0"
,
Int
(
32
)),
"pragma_emit_insn"
,
Expr
(
"
vector_dup
"
),
stBroadcast
);
stmt
=
Block
::
make
({
stBroadcast
,
stForOnce
,
stForTwice
});
stmt
=
Allocate
::
make
(
tmpBuffer
,
type_tmp
,
extentsArray
,
const_true
(),
stmt
);
...
...
src/pass/optimize_pragma.cc
浏览文件 @
59f460e7
...
...
@@ -147,7 +147,7 @@ class EstimateAlign : public IRMutator {
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
stmt
)
final
{
if
(
air
::
ir
::
attr
::
IsPragmaKey
(
op
->
attr_key
)
&&
op
->
value
.
as
<
StringImm
>
())
{
if
(
exclude_list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
))
{
if
(
exclude_
align_analyze_
list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
))
{
return
stmt
;
}
...
...
src/pass/rewrite_by_align_dynamic.cc
浏览文件 @
59f460e7
...
...
@@ -46,7 +46,7 @@ class AxisPartitioner : public IRMutator {
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
"pragma_ub_gm"
||
(
op
->
attr_key
==
"pragma_emit_insn"
&&
op
->
value
->
IsInstance
<
StringImm
>
()
&&
exclude_list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)
==
0
))
{
exclude_
index_fix_
list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)
==
0
))
{
in_insn_
=
true
;
counter_
=
0
;
auto
ret
=
IRMutator
::
Mutate_
(
op
,
s
);
...
...
@@ -180,7 +180,7 @@ class RewriteAllocateAndIndex : public IRMutator {
}
}
if
(
op
->
attr_key
==
"pragma_ub_gm"
||
(
op
->
attr_key
==
"pragma_emit_insn"
&&
op
->
value
->
IsInstance
<
StringImm
>
()
&&
(
exclude_list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)
==
0
||
(
exclude_
index_fix_
list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)
==
0
||
op
->
value
.
as
<
StringImm
>
()
->
value
==
"scatter"
)))
{
in_insn_
=
true
;
auto
ret
=
IRMutator
::
Mutate_
(
op
,
s
);
...
...
src/pass/rewrite_by_align_static.cc
浏览文件 @
59f460e7
...
...
@@ -46,7 +46,7 @@ class AxisPartitioner : public IRMutator {
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
"pragma_ub_gm"
||
(
op
->
attr_key
==
"pragma_emit_insn"
&&
op
->
value
->
IsInstance
<
StringImm
>
()
&&
exclude_list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)
==
0
))
{
exclude_
index_fix_
list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)
==
0
))
{
in_insn_
=
true
;
counter_
=
0
;
auto
ret
=
IRMutator
::
Mutate_
(
op
,
s
);
...
...
@@ -182,7 +182,7 @@ class RewriteAllocateAndIndex : public IRMutator {
}
}
if
(
op
->
attr_key
==
"pragma_ub_gm"
||
(
op
->
attr_key
==
"pragma_emit_insn"
&&
op
->
value
->
IsInstance
<
StringImm
>
()
&&
(
exclude_list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)
==
0
||
(
exclude_
index_fix_
list
.
count
(
op
->
value
.
as
<
StringImm
>
()
->
value
)
==
0
||
op
->
value
.
as
<
StringImm
>
()
->
value
==
"scatter"
)))
{
in_insn_
=
true
;
auto
ret
=
IRMutator
::
Mutate_
(
op
,
s
);
...
...
@@ -307,12 +307,7 @@ class RewriteAllocateAndIndex : public IRMutator {
CHECK_NE
(
align
,
0
);
int64_t
coef
=
GetIntConst
(
strides
[
0
]);
if
(
std
::
abs
(
coef
)
<
align
)
{
auto
it
=
var2ext_
.
find
(
v
.
get
());
if
(
it
!=
var2ext_
.
end
()
&&
std
::
abs
(
coef
*
it
->
second
)
<=
align
)
{
rst
+=
v
*
strides
[
0
];
}
else
{
return
SimpleFix
(
tmp_idx_bk
,
opt
.
var2expr
,
align
,
times
);
}
rst
+=
v
*
strides
[
0
];
}
else
if
(
coef
%
align
==
0
)
{
auto
new_coef
=
coef
*
times
/
align
;
rst
+=
v
*
Expr
(
static_cast
<
int32_t
>
(
new_coef
));
...
...
@@ -359,7 +354,8 @@ class RewriteAllocateAndIndex : public IRMutator {
Stmt
RewriteByAlignStatic
(
Stmt
stmt
)
{
stmt
=
AxisPartitioner
().
Run
(
stmt
);
stmt
=
RewriteAllocateAndIndex
().
Mutate
(
stmt
);
return
MergeLoops
(
stmt
);
stmt
=
MergeLoops
(
stmt
);
return
stmt
;
}
}
// namespace ir
}
// namespace akg
src/pass/store_pack.cc
0 → 100644
浏览文件 @
59f460e7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "emit_insn/insn_info.h"
#include "emit_insn/ir_transform.h"
#include "analyze_align.h"
namespace
akg
{
namespace
ir
{
class ReducePacker : public IRMutator {
 public:
  ReducePacker() = default;
  ~ReducePacker() override = default;

  /// For every ub->gm copy or non-excluded emit-insn region, parses the
  /// region into IRInfo and, when a last-dimension reduce can be packed,
  /// replaces the region with the regenerated statement tagged with the
  /// derived insn type. Regions that cannot be packed are kept untouched.
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    const bool insn_scope =
      op->attr_key == "pragma_ub_gm" ||
      (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
       !exclude_align_analyze_list.count(op->value.as<StringImm>()->value));
    if (!insn_scope) {
      return IRMutator::Mutate_(op, s);
    }
    IRInfo info;
    ParserVisitor(info, false).Run(s);
    if (!info.ChangeLastDimReduce()) {
      // Not a packable last-dim reduce: leave the region as-is.
      return s;
    }
    Stmt packed = info.GenStmt();
    return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(info.arith_info.insn_type), packed);
  }
};
/// Pass entry point: first normalizes transposes (TransposeTransform), then
/// packs last-dimension reductions (ReducePacker). Order matters — the
/// reduce packer runs on the transpose-normalized IR.
Stmt PackStore(Stmt stmt) {
  return ReducePacker().Mutate(TransposeTransform().Mutate(stmt));
}
}
// namespace ir
}
// namespace akg
\ No newline at end of file
src/pass/store_recover.cc
0 → 100644
浏览文件 @
59f460e7
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "emit_insn/insn_info.h"
#include "analyze_align.h"
#include "emit_insn/ir_transform.h"
namespace
akg
{
namespace
ir
{
// Recovers the packed "reduce_*" pseudo-instructions produced by PackStore
// back into explicit binary-op stores, and finalizes auxiliary pragmas:
//   "reduce_xxx"         -> "vec_binary_xxx" attribute + expanded store
//   "dma_copy_transpose" -> "vtranspose"
//   "align_info"         -> attribute dropped, body kept
class ReduceRecover : public IRMutator {
 public:
  ReduceRecover() = default;
  ~ReduceRecover() override = default;

  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
        op->value.as<StringImm>()->value.find("reduce_") != std::string::npos) {
      old_pragma_ = op->value.as<StringImm>()->value;
      // Map the reduce pragma to the matching vector-binary pragma. An
      // unrecognized "reduce_*" keeps the previous value of new_pragma_,
      // matching the behavior of the original if/else chain.
      if (old_pragma_ == "reduce_add") {
        new_pragma_ = "vec_binary_add";
      } else if (old_pragma_ == "reduce_max") {
        new_pragma_ = "vec_binary_max";
      } else if (old_pragma_ == "reduce_min") {
        new_pragma_ = "vec_binary_min";
      } else if (old_pragma_ == "reduce_fargmax") {
        new_pragma_ = "vec_binary_fargmax";
      } else if (old_pragma_ == "reduce_fargmin") {
        new_pragma_ = "vec_binary_fargmin";
      }
      in_reduce_ = true;
      auto body = this->Mutate(op->body);
      in_reduce_ = false;
      return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(new_pragma_), body);
    } else if (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
               op->value.as<StringImm>()->value == "dma_copy_transpose") {
      return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("vtranspose"), op->body);
    } else if (op->attr_key == "align_info") {
      // Alignment annotations are no longer needed at this stage.
      return this->Mutate(op->body);
    }
    return IRMutator::Mutate_(op, s);
  }

  // Inside a reduce region, rewrite `dst = call(arg)` into the explicit
  // reduction `dst = combine(dst, arg)` for the recognized reduce kinds.
  // (Deduplicated: the original had five copy-pasted branches.)
  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (!in_reduce_) {
      return IRMutator::Mutate_(op, s);
    }
    // Stores under an unrecognized reduce pragma are left untouched.
    if (old_pragma_ != "reduce_add" && old_pragma_ != "reduce_max" && old_pragma_ != "reduce_min" &&
        old_pragma_ != "reduce_fargmax" && old_pragma_ != "reduce_fargmin") {
      return s;
    }
    const Call *packed = op->value.as<Call>();
    // Fix: the original dereferenced as<Call>() unchecked.
    CHECK(packed != nullptr) << "packed reduce store must hold a Call value";
    // The packed form carries the reduce source as the call's first argument.
    Expr rhs = packed->args[0];
    // Re-load the destination so it can be combined with the source.
    Expr dst_load = Load::make(op->value.type(), op->buffer_var, op->index, op->predicate);
    Expr new_value;
    if (old_pragma_ == "reduce_fargmax") {
      new_value = Call::make(rhs.type(), "fargmax", {dst_load, rhs}, Call::CallType::PureIntrinsic);
    } else if (old_pragma_ == "reduce_fargmin") {
      new_value = Call::make(rhs.type(), "fargmin", {dst_load, rhs}, Call::CallType::PureIntrinsic);
    } else if (old_pragma_ == "reduce_add") {
      new_value = Add::make(dst_load, rhs);
    } else if (old_pragma_ == "reduce_max") {
      new_value = Max::make(dst_load, rhs);
    } else {  // reduce_min
      new_value = Min::make(dst_load, rhs);
    }
    return Store::make(op->buffer_var, new_value, op->index, op->predicate);
  }

 private:
  std::string old_pragma_;
  std::string new_pragma_;
  bool in_reduce_{false};  // fix: was uninitialized; read for stores seen before any reduce attr
};
/// Maps a high-level op-type name to the CCE intrinsic opcode used by
/// FinetunePragma. Unrecognized op types map to an empty string, which
/// callers treat as "no rewrite".
std::string GetOpCode(const std::string &op_type) {
  if (op_type == "Add") {
    return "vadds";
  }
  if (op_type == "Mul") {
    return "vmuls";
  }
  if (op_type == "vaxpy") {
    return "vaxpy";
  }
  if (op_type == "DMACopy") {
    return "vector_dup";
  }
  return std::string{};
}
// Refines the pragma / instruction form of emit-insn regions after parsing
// them into IRInfo: vector_scalar / vector_dup / scalar-immediate SIMD ops
// are re-emitted as intrinsic-call stores (GenStore); certain ub->ub copies
// become vadds with a zero scalar; scalar/discrete ub DMAs become
// "scalar_dma"; non-float scalar adds/muls become "scalar_calc".
class FinetunePragma : public IRMutator {
 public:
  FinetunePragma() = default;
  ~FinetunePragma() override = default;

  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if ((op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
         !exclude_align_analyze_list.count(op->value.as<StringImm>()->value))) {
      IRInfo info;
      ParserVisitor(info, true).Run(s);
      std::string op_code = GetOpCode(info.arith_info.op_type);
      // Only ub-resident ops with a known opcode are refined; everything
      // else (and ops whose first source is not in ub) is kept untouched.
      if (!info.arith_info.dst_info.IsUB() || op_code.empty() ||
          (!info.arith_info.src_info.empty() && !info.arith_info.src_info[0].IsUB())) {
        return s;
      }
      // Non-float adds/muls with one scalar immediate: fall back to scalar
      // computation (no float SIMD form applies).
      if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num == 1 &&
          (op_code == "vmuls" || op_code == "vadds") &&
          !info.arith_info.dst_info.p_store->value.type().is_float()) {
        return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("scalar_calc"), op->body);
      }
      // NOTE(review): "vector_dump" below looks like it may be a typo of
      // "vector_dup" — it must match whatever string ParserVisitor sets in
      // insn_type; confirm against the parser before changing.
      if (info.arith_info.insn_type == "vector_scalar" || info.arith_info.insn_type == "vector_dump") {
        // Scalar operand comes from a Load (scalar_type 0).
        return GenStore(info, op_code, 0);
      } else if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num > 0) {
        // Scalar operand is an immediate (scalar_type 1); exactly one expected.
        CHECK_EQ(info.arith_info.scalar_imm_num, 1);
        return GenStore(info, op_code, 1);
      } else if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num == 0 &&
                 info.arith_info.op_type == "DMACopy" && info.arith_info.dst_info.IsUB() &&
                 info.arith_info.src_info.size() == 1 && info.arith_info.src_info[0].IsUB() &&
                 info.arith_info.dst_info.p_store->value.type().is_float()) {
        /// change copy_ub_to_ub (fp16 or fp32) to adds (scalar = 0)
        op_code = "vadds";
        info.arith_info.scalar_imm_num = 1;
        info.arith_info.scalar_imm = FloatImm::make(info.arith_info.dst_info.p_store->value.type(), 0);
        return GenStore(info, op_code, 1);
      } else if (info.arith_info.op_type == "DMACopy" &&
                 (info.arith_info.insn_type == "scalar" || info.arith_info.insn_type == "discrete") &&
                 info.arith_info.dst_info.IsUB() &&
                 (info.arith_info.src_info.size() == 1 && info.arith_info.src_info[0].IsUB())) {
        // Scalar/discrete ub->ub copy: mark for element-wise scalar DMA.
        return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("scalar_dma"), op->body);
      } else if (info.arith_info.op_type == "DMACopy" &&
                 (info.arith_info.insn_type == "scalar" || info.arith_info.insn_type == "discrete") &&
                 info.arith_info.dst_info.IsUB() && info.arith_info.scalar_imm_num == 1) {
        return GenStore(info, op_code, 1);
      } else if (op->value.as<StringImm>()->value == "vec_single_muls" ||
                 op->value.as<StringImm>()->value == "vec_single_adds") {
        // Pragma already names the intrinsic; derive the opcode from it.
        if (op->value.as<StringImm>()->value == "vec_single_muls") {
          op_code = "vmuls";
        } else if (op->value.as<StringImm>()->value == "vec_single_adds") {
          op_code = "vadds";
        }
        return GenStore(info, op_code, 1);
      }
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

  // Emits `dst = intrin(src?, scalar)` wrapped in the region's loops.
  // scalar_type 0: scalar comes from scalar_load (a Load); loops the scalar
  //   depends on are emitted OUTSIDE the pragma, the rest inside.
  // scalar_type 1: scalar is the immediate scalar_imm; all loops go inside.
  Stmt GenStore(IRInfo &info, const std::string &intrin_name, const int scalar_type = 0) {
    CHECK(intrin_name == "vector_dup" || intrin_name == "vadds" || intrin_name == "vmuls" ||
          intrin_name == "vaxpy");
    /// scalar value
    Expr scalar_value = (scalar_type == 0) ? GetRef<Expr>(info.arith_info.scalar_load.p_load)
                                           : info.arith_info.scalar_imm;
    Array<Expr> call_args{};
    if (intrin_name == "vector_dup") {
      // vector_dup takes only the scalar; the others take (tensor, scalar).
      call_args = {scalar_value};
    } else {
      Expr tensor_value = GetRef<Expr>(info.arith_info.src_info[0].p_load);
      call_args = {tensor_value, scalar_value};
    }
    /// set store
    auto old_ptr = info.arith_info.dst_info.p_store;
    Expr new_value = Call::make(old_ptr->value.type(), intrin_name, call_args, Call::PureIntrinsic);
    Stmt ret = Store::make(old_ptr->buffer_var, new_value, old_ptr->index, old_ptr->predicate);
    if (scalar_type == 0) {
      auto scalar_vars = info.arith_info.scalar_load.vars;
      /// set inner for loop
      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
        if (!IsInArray(scalar_vars, info.for_info.vars[i])) {
          ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial,
                          DeviceAPI::None, ret);
        }
      }
      /// set attribute
      ret = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(intrin_name), ret);
      /// set outer for loop
      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
        if (IsInArray(scalar_vars, info.for_info.vars[i])) {
          ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial,
                          DeviceAPI::None, ret);
        }
      }
      return ret;
    } else {
      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
        ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial,
                        DeviceAPI::None, ret);
      }
      ret = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(intrin_name), ret);
      return ret;
    }
  }
};
/// Pass entry point: hoists ifs out of insn regions (IfReorder), refines
/// pragmas into intrinsic stores (FinetunePragma), and finally expands the
/// packed reduce pseudo-ops (ReduceRecover). Order is significant.
Stmt RecoverStore(Stmt stmt) {
  return ReduceRecover().Mutate(FinetunePragma().Mutate(IfReorder().Mutate(stmt)));
}
}
// namespace ir
}
// namespace akg
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录