Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
aec205ff
A
akg
项目概览
MindSpore
/
akg
通知
59
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
aec205ff
编写于
8月 06, 2020
作者:
C
cy
提交者:
wYann
8月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rewrite insn pattern generator in EmitInsn
上级
11ed37cc
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
1432 addition
and
2512 deletion
+1432
-2512
src/emit_insn/insn_args_calculator.cc
src/emit_insn/insn_args_calculator.cc
+1089
-0
src/emit_insn/insn_args_calculator.h
src/emit_insn/insn_args_calculator.h
+199
-0
src/emit_insn/insn_binary_vec_pattern.cc
src/emit_insn/insn_binary_vec_pattern.cc
+0
-1335
src/emit_insn/insn_emitter.cc
src/emit_insn/insn_emitter.cc
+26
-30
src/emit_insn/insn_info.cc
src/emit_insn/insn_info.cc
+1
-4
src/emit_insn/insn_info.h
src/emit_insn/insn_info.h
+1
-9
src/emit_insn/insn_pattern.cc
src/emit_insn/insn_pattern.cc
+32
-42
src/emit_insn/insn_pattern.h
src/emit_insn/insn_pattern.h
+4
-215
src/emit_insn/insn_single_vec_pattern.cc
src/emit_insn/insn_single_vec_pattern.cc
+0
-802
src/pass/emit_insn.cc
src/pass/emit_insn.cc
+2
-0
src/pass/multi_last_axis_reduction.cc
src/pass/multi_last_axis_reduction.cc
+4
-0
src/pass/split_tail_block.cc
src/pass/split_tail_block.cc
+74
-75
未找到文件。
src/emit_insn/insn_args_calculator.cc
0 → 100644
浏览文件 @
aec205ff
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/base.h>
#include <tvm/ir_pass.h>
#include "common/array_api.h"
#include "pass/expr_alg_simplify.h"
#include "insn_pattern.h"
#include "insn_args_calculator.h"
namespace
akg
{
InsnAxis
::
InsnAxis
(
const
For
*
for_stmt
,
const
Array
<
StmtStoreInfo
>
&
info_list
)
{
this
->
var
=
for_stmt
->
loop_var
;
this
->
extent
=
GetInt32Const
(
for_stmt
->
extent
);
this
->
min
=
GetInt32Const
(
for_stmt
->
min
);
int
index
=
0
;
for
(
auto
it
:
info_list
)
{
auto
stride
=
GetInt32Const
(
GetStrideByAxis
(
it
->
var_
,
it
->
strides_
,
this
->
var
));
this
->
stride_list
.
push_back
(
stride
);
if
(
index
==
0
)
{
this
->
dst_stride
=
stride
;
}
else
{
this
->
src_stride_list
.
push_back
(
stride
);
}
index
++
;
}
}
Expr
InsnAxis
::
GetStrideByAxis
(
const
Array
<
Var
>
&
vars
,
const
Array
<
Expr
>
&
strides
,
Var
obj_var
)
{
int
index
=
0
;
for
(
auto
var_it
:
vars
)
{
if
(
Equal
(
var_it
,
obj_var
))
{
return
strides
[
index
];
}
index
++
;
}
return
Expr
(
0
);
};
bool
InsnAxis
::
IsValid
()
{
return
this
->
is_valid
;
}
void
InsnAxis
::
Print
(
const
std
::
string
&
name
)
{
if
(
!
name
.
empty
())
{
LOG
(
DEBUG
)
<<
"********** "
<<
name
<<
" ************"
;
}
auto
r_stride
=
this
->
src_stride_list
.
size
()
>
1
?
src_stride_list
[
1
]
:
99999
;
LOG
(
DEBUG
)
<<
"var:"
<<
this
->
var
<<
" extent:"
<<
this
->
extent
<<
" min:"
<<
this
->
min
<<
" dst_stride:"
<<
this
->
dst_stride
<<
" src_stride_l:"
<<
this
->
src_stride_list
.
front
()
<<
"src_stride_r:"
<<
r_stride
;
}
Array
<
StmtStoreInfo
>
GetInfoList
(
const
StmtStoreInfo
&
dst_info
,
const
Array
<
StmtStoreInfo
>
&
src_info_list
)
{
Array
<
StmtStoreInfo
>
res
;
res
.
push_back
(
dst_info
.
Copy
());
for
(
auto
it
:
src_info_list
)
{
res
.
push_back
(
it
.
Copy
());
}
return
res
;
};
std
::
list
<
InsnAxis
>
GetAxisList
(
const
StmtInfo
&
for_info
,
const
Array
<
StmtStoreInfo
>
&
info_list
)
{
std
::
list
<
InsnAxis
>
axis_list
;
for
(
auto
it
:
for_info
.
ops_
)
{
auto
for_stmt
=
it
.
as
<
For
>
();
CHECK
(
for_stmt
);
auto
axis
=
InsnAxis
(
for_stmt
,
info_list
);
axis_list
.
push_back
(
axis
);
}
return
axis_list
;
}
void
Print
(
std
::
list
<
InsnAxis
>
&
axis_list
)
{
LOG
(
DEBUG
)
<<
"+++++++++++++++++++ AXIS_LIST +++++++++++++++++++"
;
int
index
=
0
;
for
(
auto
it
:
axis_list
)
{
LOG
(
DEBUG
)
<<
"================== INDEX "
<<
index
<<
" ================="
;
it
.
Print
();
index
++
;
}
LOG
(
DEBUG
)
<<
"------------------ END ---------------------"
;
}
InsnArgsCalculator
::
InsnArgsCalculator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
)
:
dst_info_list_
(
dst_info_list
),
src_info_list_
(
src_info_list
),
for_info_
(
for_info
),
intrin_name_
(
intrin_name
)
{
InitArg
();
CalAxis
();
}
void
InsnArgsCalculator
::
CalAxis
()
{
CHECK
(
!
dst_info_list_
.
empty
());
dst_info_
=
dst_info_list_
[
0
];
if
(
src_info_list_
.
empty
())
{
src_info_list_
=
{
dst_info_
.
Copy
()};
}
auto
src_info
=
src_info_list_
[
0
];
dst_info_
.
Print
();
for
(
auto
src_info_it
:
src_info_list_
)
{
src_info_it
.
Print
();
if
(
src_info_it
->
name_
==
dst_info_
->
name_
)
{
meta_
.
same_dst_src
=
true
;
}
}
meta_
.
dst_block_size
=
GetUbBlkSize
(
dst_info_
->
dtype_
);
meta_
.
src_block_size
=
GetUbBlkSize
(
src_info
->
dtype_
);
meta_
.
cast
=
meta_
.
dst_block_size
!=
meta_
.
src_block_size
;
meta_
.
block_size
=
meta_
.
dst_block_size
<=
meta_
.
src_block_size
?
meta_
.
dst_block_size
:
meta_
.
src_block_size
;
meta_
.
src_dtype
=
src_info
->
dtype_
;
meta_
.
dst_dtype
=
dst_info_
->
dtype_
;
meta_
.
dtype
=
meta_
.
dst_dtype
.
bits
()
>=
meta_
.
src_dtype
.
bits
()
?
meta_
.
dst_dtype
:
meta_
.
src_dtype
;
auto
elem_offset_mod
=
ir
::
ExprSimplifier
().
Simplify
(
Mod
::
make
(
dst_info_
->
elem_offset_
,
meta_
.
block_size
));
if
(
elem_offset_mod
.
as
<
IntImm
>
())
{
meta_
.
block_offset
=
elem_offset_mod
.
as
<
IntImm
>
()
->
value
;
}
axis_list_
=
GetAxisList
(
for_info_
,
GetInfoList
(
dst_info_
,
src_info_list_
));
}
// namespace akg
void
InsnArgsCalculator
::
InitArg
()
{
arg_
.
src_m1_list
=
{
0
,
0
};
arg_
.
src_m0_list
=
{
1
,
1
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
InsnArgsCalculator
::
GetStrideLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_stride
=
[
&
](
int
stride
)
{
return
stride
%
meta_
.
block_size
==
0
;
};
auto
zero_stride
=
[
&
](
int
stride
)
{
return
stride
==
0
;
};
return
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
is_stride
)
&&
!
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
zero_stride
);
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
InsnArgsCalculator
::
GetM0LimitLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_limit
=
[
&
](
int
stride
)
{
return
stride
/
meta_
.
block_size
<
MAX_STRIDE_M0_SINGLE
;
};
return
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
is_limit
);
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
InsnArgsCalculator
::
GetBlockStrideLimitLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_limit
=
[
&
](
int
stride
)
{
return
stride
/
meta_
.
block_size
<=
max_block_stride_
;
};
return
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
is_limit
);
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
InsnArgsCalculator
::
GetM1LimitLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_limit
=
[
&
](
int
stride
)
{
return
stride
/
meta_
.
src_block_size
<
MAX_STRIDE_M1
;
};
return
axis
.
dst_stride
/
meta_
.
dst_block_size
<
MAX_STRIDE_M1
&&
std
::
all_of
(
axis
.
src_stride_list
.
begin
(),
axis
.
src_stride_list
.
end
(),
is_limit
);
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
And
(
const
std
::
list
<
std
::
function
<
bool
(
const
InsnAxis
&
)
>>
&
lambda_list
)
{
return
[
&
lambda_list
](
const
InsnAxis
&
axis
)
{
bool
res
=
true
;
for
(
auto
lambda_it
:
lambda_list
)
{
res
=
res
&&
lambda_it
(
axis
);
}
return
res
;
};
}
AxisIt
InsnArgsCalculator
::
GetAxisByLambda
(
const
std
::
function
<
bool
(
const
InsnAxis
&
)
>
&
lambda
)
{
for
(
auto
axis_it
=
axis_list_
.
begin
();
axis_it
!=
axis_list_
.
end
();
axis_it
++
)
{
if
(
lambda
(
*
axis_it
))
{
return
axis_it
;
}
}
return
axis_list_
.
end
();
}
InsnAxis
InsnArgsCalculator
::
ExtractAxis
(
AxisIt
&
it
)
{
InsnAxis
res
=
*
it
;
axis_list_
.
erase
(
it
);
return
res
;
}
bool
InsnArgsCalculator
::
IsValid
(
AxisIt
&
it
)
{
return
it
!=
axis_list_
.
end
();
}
void
AxisSort
(
std
::
list
<
InsnAxis
>
&
axis_arr
,
bool
order
=
true
)
{
auto
up_compare
=
[
&
](
InsnAxis
&
a
,
InsnAxis
&
b
)
{
return
a
.
extent
<
b
.
extent
;
};
auto
down_compare
=
[
&
](
InsnAxis
&
a
,
InsnAxis
&
b
)
{
return
a
.
extent
>
b
.
extent
;
};
if
(
order
)
{
axis_arr
.
sort
(
up_compare
);
}
else
{
axis_arr
.
sort
(
down_compare
);
}
}
AxisIt
InsnArgsCalculator
::
GetVecAxisIt
()
{
axis_list_
.
reverse
();
auto
IsVecAxis
=
[
&
](
const
InsnAxis
&
axis
)
{
return
!
(
std
::
any_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
[](
int
stride
)
{
return
stride
>
1
;
})
||
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
[](
int
stride
)
{
return
stride
==
0
;
}));
};
return
GetAxisByLambda
(
IsVecAxis
);
}
SplitStat
InsnArgsCalculator
::
SplitAxis
(
int
extent
,
InsnAxis
&
axis
)
{
if
(
axis
.
extent
<=
extent
)
{
return
NO_SPLIT
;
}
if
(
axis
.
extent
%
extent
!=
0
)
{
return
TAIL
;
}
InsnAxis
new_axis
;
new_axis
.
extent
=
axis
.
extent
/
extent
;
for
(
auto
stride
:
axis
.
stride_list
)
{
new_axis
.
stride_list
.
push_back
(
stride
*
extent
);
}
auto
temp_stride_list
=
new_axis
.
stride_list
;
CHECK
(
!
temp_stride_list
.
empty
());
new_axis
.
dst_stride
=
temp_stride_list
.
front
();
temp_stride_list
.
erase
(
temp_stride_list
.
begin
());
new_axis
.
src_stride_list
=
temp_stride_list
;
new_axis
.
var
=
Var
(
axis
.
var
->
name_hint
);
axis_list_
.
push_back
(
new_axis
);
axis
.
extent
=
extent
;
return
SUCCESS
;
}
AxisIt
InsnArgsCalculator
::
GetBlockAxis
()
{
AxisSort
(
axis_list_
);
auto
stride_lambda
=
GetStrideLambda
();
auto
m0_limit_lambda
=
GetM0LimitLambda
();
auto
block_stride_limit_lambda
=
GetBlockStrideLimitLambda
();
auto
axis_it
=
GetAxisByLambda
(
And
({
stride_lambda
,
m0_limit_lambda
,
block_stride_limit_lambda
,
[
&
](
const
InsnAxis
&
axis
)
{
return
axis
.
extent
>=
FULL_BLOCK_NUM
&&
axis
.
extent
%
FULL_BLOCK_NUM
==
0
;
}}));
if
(
IsValid
(
axis_it
))
{
return
axis_it
;
}
axis_it
=
GetAxisByLambda
(
And
({
stride_lambda
,
m0_limit_lambda
,
block_stride_limit_lambda
,
[
&
](
const
InsnAxis
&
axis
)
{
return
axis
.
extent
>=
FULL_BLOCK_NUM
;
}}));
if
(
IsValid
(
axis_it
)
&&
axis_list_
.
size
()
==
1
)
{
return
axis_it
;
}
axis_list_
.
reverse
();
axis_it
=
GetAxisByLambda
(
And
({
stride_lambda
,
m0_limit_lambda
,
block_stride_limit_lambda
,
[
&
](
const
InsnAxis
&
axis
)
{
return
axis
.
extent
<
FULL_BLOCK_NUM
;
}}));
if
(
IsValid
(
axis_it
))
{
return
axis_it
;
}
return
GetAxisByLambda
(
And
({
stride_lambda
,
m0_limit_lambda
,
[
&
](
const
InsnAxis
&
axis
)
{
return
axis
.
extent
<=
FULL_BLOCK_NUM
||
axis
.
extent
%
FULL_BLOCK_NUM
==
0
;
}}));
}
AxisIt
InsnArgsCalculator
::
GetRepeatAxisIt
()
{
AxisSort
(
axis_list_
);
auto
stride_lambda
=
GetStrideLambda
();
auto
m1_limit_lambda
=
GetM1LimitLambda
();
auto
axis_it
=
GetAxisByLambda
(
And
({
stride_lambda
,
m1_limit_lambda
,
[
&
](
const
InsnAxis
&
axis
)
{
return
axis
.
extent
>=
MAX_REPEAT
-
1
;
}}));
if
(
IsValid
(
axis_it
))
{
return
axis_it
;
}
axis_list_
.
reverse
();
return
GetAxisByLambda
(
And
({
stride_lambda
,
m1_limit_lambda
}));
}
void
InsnArgsCalculator
::
SetArgMask
(
int
len
)
{
SetArgBlockNum
(
1
);
SetArgBlockLen
(
len
);
}
void
InsnArgsCalculator
::
SetArgBlockNum
(
int
data_num
)
{
arg_
.
block_num
=
data_num
;
}
void
InsnArgsCalculator
::
SetArgBlockLen
(
int
data_len
)
{
arg_
.
block_len
=
data_len
;
}
void
InsnArgsCalculator
::
SetArgM0
(
int
dst_m0
,
int
lsrc_m0
,
int
rsrc_m0
=
0
)
{
arg_
.
dst_m0
=
dst_m0
;
arg_
.
src_m0_list
=
{
lsrc_m0
,
rsrc_m0
};
}
void
InsnArgsCalculator
::
SetArgM1
(
int
dst_m1
,
int
lsrc_m1
,
int
rsrc_m1
=
0
)
{
arg_
.
dst_m1
=
dst_m1
;
arg_
.
src_m1_list
=
{
lsrc_m1
,
rsrc_m1
};
}
void
InsnArgsCalculator
::
SetArgRepeat
(
int
repeat
)
{
arg_
.
repeat
=
repeat
;
}
void
InsnArgsCalculator
::
BlockAxisReduction
()
{
Print
(
axis_list_
);
auto
block_axis_it
=
GetBlockAxis
();
if
(
IsValid
(
block_axis_it
))
{
auto
origin_block_axis
=
*
block_axis_it
;
InsnAxis
block_axis
=
ExtractAxis
(
block_axis_it
);
if
(
block_axis
.
extent
%
FULL_BLOCK_NUM
!=
0
&&
block_axis
.
extent
>
FULL_BLOCK_NUM
)
{
arg_
.
tail_len
=
block_axis
.
extent
%
FULL_BLOCK_NUM
;
block_axis
.
extent
=
FloorTo
(
block_axis
.
extent
,
FULL_BLOCK_NUM
);
arg_
.
dst_tail_offset
=
block_axis
.
dst_stride
*
block_axis
.
extent
;
for
(
auto
stride
:
block_axis
.
src_stride_list
)
{
arg_
.
src_tail_offset_list
.
push_back
(
stride
*
block_axis
.
extent
);
}
SplitAxis
(
FULL_BLOCK_NUM
,
block_axis
);
auto
repeat_axis_it
=
GetRepeatAxisIt
();
if
(
!
IsValid
(
repeat_axis_it
)
&&
axis_list_
.
size
()
>
0
)
{
for
(
auto
it
=
axis_list_
.
begin
();
it
!=
axis_list_
.
end
();
it
++
)
{
if
(
it
->
var
->
name_hint
==
block_axis
.
var
->
name_hint
)
{
axis_list_
.
erase
(
it
);
break
;
}
}
axis_list_
.
push_back
(
origin_block_axis
);
return
;
}
}
else
{
SplitAxis
(
FULL_BLOCK_NUM
,
block_axis
);
}
block_axis
.
Print
(
"BLOCK_AXIS"
);
SetArgM0
(
block_axis
.
dst_stride
/
meta_
.
block_size
,
block_axis
.
src_stride_list
.
front
()
/
meta_
.
block_size
,
block_axis
.
src_stride_list
.
back
()
/
meta_
.
block_size
);
SetArgBlockNum
(
block_axis
.
extent
);
}
}
void
InsnArgsCalculator
::
RepeatAxisReduction
()
{
Print
(
axis_list_
);
auto
repeat_axis
=
GetRepeatAxis
();
if
(
repeat_axis
.
IsValid
())
{
repeat_axis
.
Print
(
"REPEAT_AXIS"
);
SetArgM1
(
repeat_axis
.
dst_stride
/
meta_
.
dst_block_size
,
repeat_axis
.
src_stride_list
.
front
()
/
meta_
.
src_block_size
,
repeat_axis
.
src_stride_list
.
back
()
/
meta_
.
src_block_size
);
SetArgRepeat
(
repeat_axis
.
extent
);
}
}
InsnAxis
InsnArgsCalculator
::
GetInvalidAxis
()
{
InsnAxis
res
;
res
.
is_valid
=
false
;
return
res
;
}
InsnAxis
InsnArgsCalculator
::
GetRepeatAxis
()
{
auto
repeat_axis_it
=
GetRepeatAxisIt
();
if
(
IsValid
(
repeat_axis_it
))
{
InsnAxis
repeat_axis
=
ExtractAxis
(
repeat_axis_it
);
SplitAxis
(
MAX_REPEAT
-
1
,
repeat_axis
);
return
repeat_axis
;
}
return
GetInvalidAxis
();
}
void
InsnArgsCalculator
::
CastCaseReduction
()
{
if
(
axis_list_
.
empty
())
{
return
;
}
Print
(
axis_list_
);
int
cast_block_size
=
meta_
.
dst_block_size
<
meta_
.
src_block_size
?
meta_
.
dst_block_size
:
meta_
.
src_block_size
;
auto
vec_axis_it
=
GetVecAxisIt
();
if
(
IsValid
(
vec_axis_it
))
{
InsnAxis
vec_axis
=
ExtractAxis
(
vec_axis_it
);
int
max_vec_len
=
cast_block_size
*
FULL_BLOCK_NUM
;
if
(
vec_axis
.
extent
>
cast_block_size
&&
vec_axis
.
extent
<
max_vec_len
)
{
SetArgMask
(
DivFloor
(
vec_axis
.
extent
,
cast_block_size
)
*
cast_block_size
);
SetArgM0
(
1
,
1
,
1
);
}
else
if
(
vec_axis
.
extent
>=
max_vec_len
)
{
SplitAxis
(
max_vec_len
,
vec_axis
);
SetArgMask
(
DivFloor
(
vec_axis
.
extent
,
cast_block_size
)
*
cast_block_size
);
SetArgM0
(
1
,
1
,
1
);
}
else
{
SetArgBlockLen
(
cast_block_size
);
}
}
RepeatAxisReduction
();
}
int
DivFloor
(
int
a
,
int
b
)
{
if
(
a
%
b
==
0
)
{
return
a
/
b
;
}
else
{
return
a
/
b
+
1
;
}
}
void
InsnArgsCalculator
::
InsnReduction
()
{
if
(
axis_list_
.
empty
())
{
return
;
}
Print
(
axis_list_
);
auto
vec_axis_it
=
GetVecAxisIt
();
meta_
.
scalar
=
!
IsValid
(
vec_axis_it
);
if
(
!
meta_
.
scalar
)
{
InsnAxis
vec_axis
=
ExtractAxis
(
vec_axis_it
);
int
max_vec_len
=
meta_
.
block_size
*
FULL_BLOCK_NUM
;
if
(
vec_axis
.
extent
>
meta_
.
block_size
&&
vec_axis
.
extent
<
max_vec_len
&&
(
vec_axis
.
extent
%
meta_
.
block_size
!=
0
||
vec_axis
.
extent
>
max_vec_len
*
meta_
.
vec_rate
))
{
vec_axis
.
Print
(
"VEC_BLOCK_AXIS"
);
SetArgMask
(
DivFloor
(
vec_axis
.
extent
,
meta_
.
block_size
)
*
meta_
.
block_size
);
SetArgM0
(
1
,
1
,
1
);
}
else
{
SplitAxis
(
meta_
.
block_size
,
vec_axis
);
vec_axis
.
Print
(
"VEC_AXIS"
);
SetArgBlockLen
(
meta_
.
block_size
);
BlockAxisReduction
();
}
RepeatAxisReduction
();
}
else
{
BlockAxisReduction
();
RepeatAxisReduction
();
}
Print
(
axis_list_
);
}
Expr
InsnArgsCalculator
::
GetOffset
(
int
stride_index
)
{
Expr
res
=
Expr
(
0
);
for
(
auto
axis_it
:
axis_list_
)
{
auto
stride
=
axis_it
.
stride_list
[
stride_index
];
auto
mul_expr
=
Mul
::
make
(
stride
,
axis_it
.
var
);
res
=
Add
::
make
(
mul_expr
,
res
);
}
return
Simplify
(
res
);
}
StmtInfo
InsnArgsCalculator
::
ExportForInfo
()
{
if
(
for_info_
.
ops_
.
empty
())
{
return
for_info_
;
}
int
last_index
=
for_info_
.
ops_
.
size
()
-
1
;
auto
last_for
=
for_info_
.
ops_
[
last_index
].
as
<
For
>
();
auto
store_stmt
=
last_for
->
body
;
Stmt
for_stmt
=
store_stmt
;
StmtInfo
result
;
for
(
auto
axis_it
:
axis_list_
)
{
for_stmt
=
For
::
make
(
axis_it
.
var
,
axis_it
.
min
,
axis_it
.
extent
,
last_for
->
for_type
,
last_for
->
device_api
,
for_stmt
);
result
.
ops_
.
push_back
(
for_stmt
);
result
.
vars_
.
push_back
(
axis_it
.
var
);
}
return
result
;
}
PatternResult
InsnArgsCalculator
::
ExportResult
()
{
PatternResult
res
;
auto
arg_info
=
ArgInfo
(
make_node
<
ArgInfoNode
>
());
auto
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
body_args
.
GetNode
()
->
body_num_
=
arg_
.
body_num
;
body_args
.
GetNode
()
->
body_offset_
=
meta_
.
block_size
*
FULL_BLOCK_NUM
;
body_args
.
GetNode
()
->
repeat_
=
Expr
(
arg_
.
repeat
);
body_args
.
GetNode
()
->
dst_stride_m0_
=
Expr
(
arg_
.
dst_m0
);
body_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
arg_
.
dst_m1
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
arg_
.
src_m0_list
;
body_args
.
GetNode
()
->
src_stride_m1_list_
=
arg_
.
src_m1_list
;
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
arg_
.
block_len
,
arg_
.
block_num
,
meta_
.
dtype
,
meta_
.
block_offset
);
body_args
.
GetNode
()
->
block_offset_
=
make_const
(
Int
(
32
),
meta_
.
block_offset
);
arg_info
.
GetNode
()
->
body_arg_info_
=
body_args
;
if
(
arg_
.
tail_len
>
0
)
{
auto
tail_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
tail_args
.
GetNode
()
->
dst_head_
=
Expr
(
arg_
.
dst_tail_offset
);
tail_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
arg_
.
dst_m1
);
tail_args
.
GetNode
()
->
src_stride_m1_list_
=
arg_
.
src_m1_list
;
tail_args
.
GetNode
()
->
repeat_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
src_head_list_
=
arg_
.
src_tail_offset_list
;
tail_args
.
GetNode
()
->
body_offset_
=
meta_
.
block_size
*
FULL_BLOCK_NUM
;
tail_args
.
GetNode
()
->
dst_stride_m0_
=
Expr
(
arg_
.
dst_m0
);
tail_args
.
GetNode
()
->
src_stride_m0_list_
=
arg_
.
src_m0_list
;
tail_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
arg_
.
block_len
,
arg_
.
tail_len
,
meta_
.
dtype
,
meta_
.
block_offset
);
tail_args
.
GetNode
()
->
block_offset_
=
make_const
(
Int
(
32
),
meta_
.
block_offset
);
arg_info
.
GetNode
()
->
tail_arg_info_
=
tail_args
;
}
StmtInfoList
info_list
=
GetInfoList
(
dst_info_
,
src_info_list_
);
CleanZeroStrides
(
info_list
);
for
(
size_t
i
=
0
;
i
<
info_list
.
size
();
i
++
)
{
info_list
[
i
].
GetNode
()
->
insn_offset_
=
GetOffset
(
i
);
}
info_list
[
1
].
Print
();
res
.
for_info
=
ExportForInfo
();
res
.
arg_info
=
arg_info
;
res
.
dst_info_list
=
{
info_list
[
0
]};
if
(
info_list
.
size
()
>
2
)
{
res
.
src_info_list
=
{
info_list
[
1
],
info_list
[
2
]};
}
else
{
res
.
src_info_list
=
{
info_list
[
1
]};
}
body_args
.
Print
();
if
(
arg_info
->
tail_arg_info_
.
defined
())
{
arg_info
->
tail_arg_info_
.
Print
();
}
return
res
;
}
SingleVecInsnArgsCalculator
::
SingleVecInsnArgsCalculator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
)
:
InsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
,
intrin_name
)
{}
PatternResult
SingleVecInsnArgsCalculator
::
GetInsnArgs
()
{
if
(
meta_
.
cast
)
{
CastCaseReduction
();
}
else
{
InsnReduction
();
}
return
ExportResult
();
}
BinaryVecInsnArgsCalculator
::
BinaryVecInsnArgsCalculator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
mode
,
const
std
::
string
&
intrin_name
,
bool
expand_mask
)
:
InsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
,
intrin_name
),
mode_
{
mode
},
expand_mask_
{
expand_mask
}
{
if
(
mode_
==
"reduction"
&&
src_info_list_
.
size
()
==
2
&&
src_info_list_
[
0
]
->
name_
==
dst_info_list
[
0
]
->
name_
)
{
auto
temp
=
src_info_list_
[
0
].
Copy
();
src_info_list_
.
Set
(
0
,
src_info_list_
[
1
].
Copy
());
src_info_list_
.
Set
(
1
,
temp
);
CalAxis
();
}
}
PatternResult
BinaryVecInsnArgsCalculator
::
GetInsnArgs
()
{
LOG
(
DEBUG
)
<<
"Binary vec Insn reduction"
;
InsnReduction
();
return
ExportResult
();
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
BinaryVecInsnArgsCalculator
::
GetM0LimitLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_limit
=
[
&
](
int
stride
)
{
return
stride
/
meta_
.
block_size
<
MAX_STRIDE_M0
;
};
return
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
is_limit
)
&&
axis
.
dst_stride
!=
0
;
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
BinaryVecInsnArgsCalculator
::
GetM1LimitLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_limit
=
[
&
](
int
stride
)
{
return
stride
/
meta_
.
src_block_size
<
MAX_STRIDE_M1
;
};
return
axis
.
dst_stride
/
meta_
.
dst_block_size
<
MAX_STRIDE_M1
&&
std
::
all_of
(
axis
.
src_stride_list
.
begin
(),
axis
.
src_stride_list
.
end
(),
is_limit
);
};
}
void
BinaryVecInsnArgsCalculator
::
InsnReduction
()
{
if
(
axis_list_
.
empty
())
{
return
;
}
Print
(
axis_list_
);
auto
vec_axis_it
=
GetVecAxisIt
();
meta_
.
scalar
=
!
IsValid
(
vec_axis_it
);
if
(
!
meta_
.
scalar
)
{
vec_axis_
=
*
vec_axis_it
;
InsnAxis
vec_axis
=
ExtractAxis
(
vec_axis_it
);
auto
bad_axis_lambda
=
[
&
](
const
InsnAxis
&
axis
)
{
int
min_stride
=
vec_axis_it
->
extent
;
auto
dst_name
=
dst_info_list_
[
0
]
->
name_
;
if
(
meta_
.
same_dst_src
&&
axis
.
dst_stride
<
min_stride
&&
axis
.
dst_stride
!=
0
)
{
return
true
;
}
return
false
;
};
auto
bad_axis_it
=
GetAxisByLambda
(
bad_axis_lambda
);
InsnAxis
bad_axis
;
bad_axis
.
is_valid
=
false
;
if
(
IsValid
(
bad_axis_it
))
{
bad_axis
=
ExtractAxis
(
bad_axis_it
);
}
int
max_vec_len
=
meta_
.
block_size
*
FULL_BLOCK_NUM
;
if
(
vec_axis
.
extent
>
meta_
.
block_size
&&
vec_axis
.
extent
<
max_vec_len
&&
(
vec_axis
.
extent
%
meta_
.
block_size
!=
0
||
vec_axis
.
extent
>
max_vec_len
*
meta_
.
vec_rate
))
{
vec_axis
.
Print
(
"VEC_BLOCK_AXIS"
);
if
(
expand_mask_
)
{
SetArgMask
(
DivFloor
(
vec_axis
.
extent
,
meta_
.
block_size
)
*
meta_
.
block_size
);
}
else
{
SetArgMask
(
vec_axis
.
extent
);
}
SetArgM0
(
1
,
1
,
1
);
}
else
{
SplitAxis
(
meta_
.
block_size
,
vec_axis
);
vec_axis
.
Print
(
"VEC_AXIS"
);
if
(
expand_mask_
&&
mode_
!=
"reduction"
)
{
SetArgBlockLen
(
meta_
.
block_size
);
}
else
{
SetArgBlockLen
(
vec_axis
.
extent
);
}
BlockAxisReduction
();
}
RepeatAxisReduction
();
if
(
bad_axis
.
IsValid
())
{
axis_list_
.
push_back
(
bad_axis
);
}
}
else
{
BlockAxisReduction
();
RepeatAxisReduction
();
}
Print
(
axis_list_
);
}
PatternResult
LastAxisReduceInsnArgsCalculator
::
GetInsnArgs
()
{
CalcParams
();
Array
<
Var
>
elim_var
;
elim_var
=
GetPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_1D
;
return
GenResult
(
elim_var
);
}
Array
<
Var
>
LastAxisReduceInsnArgsCalculator
::
GetPattern
()
{
int
body_len
=
params
.
last_dim_shape
/
params
.
vec_max_len
*
params
.
vec_max_len
;
int
tail_len
=
params
.
last_dim_shape
%
params
.
vec_max_len
;
int
cmd_body_len
=
0
;
bool
is_vadd
=
intrin_name
==
"vadd"
;
int
repeat_stride
=
FULL_BLOCK_NUM
;
if
(
is_vadd
)
{
repeat_stride
=
1
;
}
const
int
fp16_block_size
=
16
;
if
(
body_len
>
0
)
{
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
body_args
.
GetNode
()
->
body_num_
=
1
;
body_args
.
GetNode
()
->
body_offset_
=
params
.
vec_max_len
;
body_args
.
GetNode
()
->
repeat_
=
Expr
(
body_len
/
params
.
vec_max_len
);
// Here use dst_stride_m1 as dst_stride
body_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
1
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
FULL_BLOCK_NUM
)};
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
params
.
vec_max_len
,
1
,
dst_info
->
dtype_
);
cmd_body_len
+=
GetInt32Const
(
body_args
->
repeat_
)
*
repeat_stride
;
}
if
(
tail_len
>
0
)
{
tail_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
tail_args
.
GetNode
()
->
body_offset_
=
params
.
vec_max_len
;
tail_args
.
GetNode
()
->
dst_head_
=
Expr
(
cmd_body_len
);
tail_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
body_len
)};
tail_args
.
GetNode
()
->
repeat_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
tail_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
0
)};
tail_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
tail_len
,
1
,
dst_info
->
dtype_
);
if
(
is_vadd
)
{
cmd_body_len
+=
1
;
}
else
{
cmd_body_len
+=
tail_len
/
fp16_block_size
;
if
(
tail_len
%
fp16_block_size
!=
0
)
{
cmd_body_len
+=
1
;
}
}
}
// cmd_body_len > 1 means vcadd size greater than 128, need to use vcadd again to compute final result
// if cmd_body_len > 128, then need to recursively emit vcadd
while
(
cmd_body_len
>
1
)
{
int
cmd_tail_len
=
cmd_body_len
%
params
.
vec_max_len
;
cmd_body_len
=
cmd_body_len
/
params
.
vec_max_len
;
if
(
cmd_body_len
>
0
)
{
VectorArgInfo
mix_vec_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
mix_vec_args
.
GetNode
()
->
repeat_
=
Expr
(
cmd_body_len
);
mix_vec_args
.
GetNode
()
->
dst_head_
=
Expr
(
0
);
mix_vec_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
0
)};
mix_vec_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
1
);
mix_vec_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
mix_vec_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
FULL_BLOCK_NUM
)};
mix_vec_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
params
.
vec_max_len
,
1
,
dst_info
->
dtype_
);
mix_vec_arg_list
.
push_back
(
mix_vec_args
);
if
(
!
is_vadd
)
{
cmd_body_len
*=
FULL_BLOCK_NUM
;
}
}
if
(
cmd_tail_len
>
0
)
{
VectorArgInfo
mix_vec_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
mix_vec_args
.
GetNode
()
->
repeat_
=
Expr
(
1
);
mix_vec_args
.
GetNode
()
->
dst_head_
=
Expr
(
cmd_body_len
);
if
(
is_vadd
)
{
mix_vec_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
cmd_body_len
*
params
.
vec_max_len
)};
}
else
{
mix_vec_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
cmd_body_len
/
FULL_BLOCK_NUM
*
params
.
vec_max_len
)};
}
mix_vec_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
1
);
mix_vec_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
mix_vec_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
FULL_BLOCK_NUM
)};
mix_vec_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
cmd_tail_len
,
1
,
dst_info
->
dtype_
);
if
(
is_vadd
)
{
cmd_body_len
+=
1
;
}
else
{
cmd_body_len
+=
cmd_tail_len
/
fp16_block_size
;
if
(
cmd_tail_len
%
fp16_block_size
!=
0
)
{
cmd_body_len
+=
1
;
}
}
mix_vec_arg_list
.
push_back
(
mix_vec_args
);
}
}
params
.
insn_offset_scale_factor
=
Expr
(
params
.
block_size
);
int
max_num
=
body_len
/
params
.
vec_max_len
;
if
(
intrin_name
==
"vmax"
||
intrin_name
==
"vmin"
)
{
max_num
*=
FULL_BLOCK_NUM
;
}
if
(
max_num
>=
params
.
block_size
)
{
params
.
insn_offset_scale_factor
=
max_num
+
params
.
block_size
-
1
;
if
(
tail_len
>
0
)
{
params
.
insn_offset_scale_factor
+=
1
;
}
params
.
insn_offset_scale_factor
=
truncdiv
(
params
.
insn_offset_scale_factor
,
params
.
block_size
)
*
params
.
block_size
;
}
if
(
!
params
.
src_var
.
empty
())
{
return
GetRange
(
params
.
src_var
,
-
1
,
1
);
}
return
{};
}
PatternResult
LastAxisReduceInsnArgsCalculator
::
GenResult
(
const
Array
<
Var
>
&
elim_var
)
{
dst_info
.
GetNode
()
->
insn_offset_
=
GetInsnOffset
(
dst_info
,
elim_var
)
*
params
.
insn_offset_scale_factor
;
src_info
.
GetNode
()
->
insn_offset_
=
GetInsnOffset
(
src_info
,
elim_var
);
if
(
body_args
.
defined
())
{
body_args
.
GetNode
()
->
insn_offset_scale_factor_
=
params
.
insn_offset_scale_factor
;
}
if
(
tail_args
.
defined
())
{
tail_args
.
GetNode
()
->
insn_offset_scale_factor_
=
params
.
insn_offset_scale_factor
;
}
for
(
auto
&
arg
:
mix_vec_arg_list
)
{
arg
.
GetNode
()
->
insn_offset_scale_factor_
=
params
.
insn_offset_scale_factor
;
}
arg_info
.
GetNode
()
->
body_arg_info_
=
body_args
;
arg_info
.
GetNode
()
->
tail_arg_info_
=
tail_args
;
arg_info
.
GetNode
()
->
reduction_tail_args_
=
mix_vec_arg_list
;
CleanForInfoVars
(
for_info
,
elim_var
);
arg_info
.
GetNode
()
->
arg_type_
=
ARG_VECTOR_REDUCTION_LAST_AXIS
;
PatternResult
result
;
result
.
dst_info_list
=
{
dst_info
};
result
.
src_info_list
=
{
src_info
};
result
.
for_info
=
for_info
;
result
.
arg_info
=
arg_info
;
return
result
;
}
void
LastAxisReduceInsnArgsCalculator
::
CalcParams
()
{
// check shape len
if
(
dst_info
->
shape_
.
empty
()
||
src_info
->
shape_
.
empty
())
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."
;
}
// check data type
if
(
dst_info
->
dtype_
!=
src_info
->
dtype_
)
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type."
;
}
params
.
src_var
=
src_info
->
var_
;
params
.
block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
params
.
last_dim_shape
=
GetInt32Const
(
GetItem
(
src_info
->
shape_
,
-
1
));
params
.
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
CHECK_NE
(
params
.
block_size
,
0
);
CHECK_NE
(
params
.
vec_max_len
,
0
);
}
/// Generete info list for bisection intrin
/// \param dst_info_list
/// \param src_info_list
/// \param for_info
/// \param if_info
/// \param last_axis
/// \param postfix
/// \return
BisectionInfoWrapper
SeparateComInfoToBisectionInfoList
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
StmtInfo
&
if_info
,
bool
last_axis
,
int
postfix
=
0
)
{
CHECK_EQ
(
dst_info_list
.
size
(),
1
);
CHECK_EQ
(
src_info_list
.
size
(),
2
);
BisectionInfoWrapper
wrapper
;
// Separate com_info and for_info
int
compare_idx
=
1
;
int
var_idx
=
-
1
;
var_idx
=
GetBisectionReductionIdx
(
dst_info_list
,
src_info_list
,
compare_idx
);
StmtStoreInfo
dst_info
=
dst_info_list
[
0
];
CHECK_GE
(
compare_idx
,
0
);
StmtStoreInfo
src_info1
=
src_info_list
[
compare_idx
];
Var
reduce_var
=
GetItem
(
src_info1
->
var_
,
var_idx
);
int
stride_len
=
GetInt32Const
(
GetItem
(
src_info1
->
strides_
,
var_idx
));
size_t
for_idx
=
0
;
bool
suc
=
GetIndexOfElement
(
for_info
.
vars_
,
VarExpr
(
reduce_var
),
for_idx
);
CHECK
(
suc
);
auto
exist_for
=
GetItem
(
for_info
.
ops_
,
for_idx
).
as
<
For
>
();
CHECK
(
exist_for
);
int
extent
=
GetInt32Const
(
exist_for
->
extent
);
int
simd_len
=
1
;
const
std
::
string
un_def_var
=
"un_def_var"
;
Var
simd_var
=
Var
(
"un_def_var"
);
CHECK_GT
(
src_info1
->
strides_
.
size
(),
0
);
CHECK_EQ
(
src_info1
->
var_
.
size
(),
src_info1
->
strides_
.
size
());
for
(
size_t
i
=
0
;
i
<=
src_info1
->
strides_
.
size
()
-
1
;
i
++
)
{
if
(
GetInt32Const
(
src_info1
->
strides_
[
i
])
==
1
)
{
simd_var
=
src_info1
->
var_
[
i
];
size_t
simd_for_idx
=
0
;
bool
suc
=
GetIndexOfElement
(
for_info
.
vars_
,
VarExpr
(
simd_var
),
simd_for_idx
);
CHECK
(
suc
);
auto
simd_for
=
GetItem
(
for_info
.
ops_
,
simd_for_idx
).
as
<
For
>
();
CHECK
(
simd_for
);
simd_len
=
GetInt32Const
(
simd_for
->
extent
);
}
}
int
block_unit
=
GetUbBlkSize
(
src_info1
->
dtype_
);
int
last_dim_len
=
((
simd_len
-
1
)
/
block_unit
+
1
)
*
block_unit
;
Var
bisec_var
;
Buffer
bisec_buffer
;
std
::
string
bisec_pre_header
=
"bisec"
;
std
::
string
bisec_name
=
bisec_pre_header
+
"_local_UB"
;
if
(
postfix
>
0
)
{
bisec_name
=
bisec_name
+
"_"
+
std
::
to_string
(
postfix
);
}
int
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
CHECK_NE
(
vec_max_len
,
0
);
std
::
vector
<
int
>
pow2_list
=
{
0
,
1
,
2
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
,
32768
,
65536
};
int
origin_len
=
extent
;
for
(
int
i
:
pow2_list
)
{
if
(
extent
<=
i
)
{
extent
=
i
/
2
;
break
;
}
}
int
prolog_len
=
origin_len
-
extent
;
src_info1
.
Print
();
auto
src_vars
=
src_info1
->
var_
;
auto
src_strides
=
src_info1
->
strides_
;
auto
src_dims
=
src_info1
->
shape_
;
auto
new_vars
=
src_info1
->
var_
;
auto
new_strides
=
src_info1
->
strides_
;
auto
new_dims
=
src_info1
->
shape_
;
LOG
(
DEBUG
)
<<
"
\n
var_idx:"
<<
var_idx
<<
"
\n
"
;
var_idx
=
var_idx
+
src_info1
->
var_
.
size
();
if
(
var_idx
!=
static_cast
<
int
>
(
src_info1
->
var_
.
size
())
-
1
)
{
new_dims
.
Set
(
var_idx
,
extent
);
new_dims
.
Set
(
new_dims
.
size
()
-
1
,
last_dim_len
);
CHECK_GT
(
new_dims
.
size
(),
1
);
new_strides
.
Set
(
new_strides
.
size
()
-
1
,
1
);
for
(
int
i
=
static_cast
<
int
>
(
new_dims
.
size
())
-
2
;
i
>=
0
;
i
--
)
{
new_strides
.
Set
(
i
,
new_strides
[
i
+
1
]
*
new_dims
[
i
+
1
]);
}
new_dims
.
Set
(
new_dims
.
size
()
-
1
,
simd_len
);
}
else
{
new_dims
=
{
extent
};
}
// copy data from origin buffer to new temp buffer
Array
<
Expr
>
shape
=
new_dims
;
wrapper
.
original_shape_
=
new_dims
;
bisec_var
=
Var
(
bisec_name
,
Handle
());
bisec_buffer
=
BufferNode
::
make
(
bisec_var
,
dst_info
->
dtype_
,
shape
,
Array
<
Expr
>
(),
Expr
(),
bisec_name
,
SCOPE_UBUF
,
0
,
0
,
BufferType
::
kDefault
);
// Need to copy input to bisect buffer
StmtStoreInfo
copy_dst_info
{
src_info1
.
Copy
()};
StmtStoreInfo
copy_src_info
{
src_info1
.
Copy
()};
StmtInfoList
src_list
=
{
copy_src_info
};
auto
for_tmp_info
=
for_info
.
Copy
();
auto
new_for
=
GetItem
(
for_tmp_info
.
ops_
,
for_idx
).
as
<
For
>
();
CHECK
(
new_for
);
SetItem
(
for_tmp_info
.
ops_
,
static_cast
<
int
>
(
for_idx
),
For
::
make
(
new_for
->
loop_var
,
new_for
->
min
,
extent
,
new_for
->
for_type
,
new_for
->
device_api
,
new_for
->
body
));
ReplaceVarWithNewForInfo
(
copy_dst_info
,
for_info
,
for_tmp_info
);
ReplaceVarWithNewForInfo
(
copy_src_info
,
for_info
,
for_tmp_info
);
SetItem
(
copy_src_info
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
extent
));
SetItem
(
copy_dst_info
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
extent
));
SetItem
(
copy_dst_info
.
GetNode
()
->
strides_
,
var_idx
,
Expr
(
last_dim_len
));
if
(
simd_var
->
name_hint
!=
un_def_var
)
{
copy_dst_info
.
GetNode
()
->
index_
=
0
;
for
(
size_t
i
=
0
;
i
<=
new_vars
.
size
()
-
1
;
i
++
)
{
copy_dst_info
.
GetNode
()
->
index_
+=
new_vars
[
i
]
*
new_strides
[
i
];
}
}
else
{
copy_dst_info
.
GetNode
()
->
index_
=
last_dim_len
*
reduce_var
;
}
copy_dst_info
.
GetNode
()
->
elem_offset_
=
0
;
copy_dst_info
.
GetNode
()
->
name_
=
bisec_name
;
copy_dst_info
.
GetNode
()
->
buffer_
=
bisec_buffer
;
copy_dst_info
.
GetNode
()
->
data_
=
bisec_var
;
copy_dst_info
.
GetNode
()
->
strides_
=
new_strides
;
CompactComputationInfoList
(
copy_dst_info
,
src_list
,
if_info
,
for_tmp_info
);
wrapper
.
bisec_info_list_
.
emplace_back
(
StmtInfoList
{
copy_dst_info
,
copy_src_info
});
wrapper
.
for_info_list_
.
push_back
(
for_tmp_info
);
// Generate the vadd wrapper
while
(
extent
>=
0
)
{
StmtStoreInfo
dst_tmp_info
=
dst_info
.
Copy
();
StmtStoreInfo
src_tmp_info0
{
src_info1
.
Copy
()};
StmtStoreInfo
src_tmp_info1
{
src_info1
.
Copy
()};
auto
for_tmp_info
=
for_info
.
Copy
();
int
vadd_length
=
(
prolog_len
!=
0
)
?
prolog_len
:
extent
;
if
(
extent
>
0
)
{
dst_tmp_info
=
src_info1
.
Copy
();
dst_tmp_info
.
GetNode
()
->
data_alignment_
=
simd_len
;
dst_tmp_info
.
GetNode
()
->
name_
=
bisec_name
;
dst_tmp_info
.
GetNode
()
->
buffer_
=
bisec_buffer
;
dst_tmp_info
.
GetNode
()
->
data_
=
bisec_var
;
dst_tmp_info
.
GetNode
()
->
shape_
=
new_dims
;
SetItem
(
dst_tmp_info
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
vadd_length
));
dst_tmp_info
.
GetNode
()
->
strides_
=
new_strides
;
dst_tmp_info
.
GetNode
()
->
var_
=
new_vars
;
dst_tmp_info
.
GetNode
()
->
index_
=
0
;
for
(
size_t
i
=
0
;
i
<=
new_vars
.
size
()
-
1
;
i
++
)
{
dst_tmp_info
.
GetNode
()
->
index_
+=
new_vars
[
i
]
*
new_strides
[
i
];
}
if
(
prolog_len
==
0
)
{
src_tmp_info1
=
dst_tmp_info
.
Copy
();
src_tmp_info1
.
GetNode
()
->
index_
=
dst_tmp_info
.
GetNode
()
->
index_
+
extent
*
last_dim_len
;
}
else
{
SetItem
(
src_tmp_info1
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
vadd_length
));
src_tmp_info1
.
GetNode
()
->
index_
+=
extent
*
stride_len
;
}
}
src_tmp_info0
=
dst_tmp_info
.
Copy
();
auto
new_for
=
GetItem
(
for_tmp_info
.
ops_
,
for_idx
).
as
<
For
>
();
CHECK
(
new_for
);
int
temp_for_len
=
(
vadd_length
!=
0
)
?
vadd_length
:
1
;
SetItem
(
for_tmp_info
.
ops_
,
static_cast
<
int
>
(
for_idx
),
For
::
make
(
new_for
->
loop_var
,
new_for
->
min
,
temp_for_len
,
new_for
->
for_type
,
new_for
->
device_api
,
new_for
->
body
));
if
(
extent
==
0
)
{
src_tmp_info1
.
GetNode
()
->
name_
=
bisec_name
;
src_tmp_info1
.
GetNode
()
->
buffer_
=
bisec_buffer
;
src_tmp_info1
.
GetNode
()
->
data_
=
bisec_var
;
if
(
simd_var
->
name_hint
!=
un_def_var
)
{
src_tmp_info1
.
GetNode
()
->
shape_
=
RemoveItemAtIndex
(
new_dims
,
var_idx
);
src_tmp_info1
.
GetNode
()
->
strides_
=
RemoveItemAtIndex
(
new_strides
,
var_idx
);
src_tmp_info1
.
GetNode
()
->
var_
=
RemoveItemAtIndex
(
new_vars
,
var_idx
);
src_tmp_info1
.
GetNode
()
->
index_
=
0
;
for
(
size_t
i
=
0
;
i
<=
src_tmp_info1
->
var_
.
size
()
-
1
;
i
++
)
{
src_tmp_info1
.
GetNode
()
->
index_
+=
src_tmp_info1
->
var_
[
i
]
*
src_tmp_info1
->
strides_
[
i
];
}
}
else
{
src_tmp_info1
.
GetNode
()
->
shape_
=
dst_tmp_info
->
shape_
;
src_tmp_info1
.
GetNode
()
->
strides_
=
dst_tmp_info
->
strides_
;
src_tmp_info1
.
GetNode
()
->
var_
=
dst_tmp_info
->
var_
;
src_tmp_info1
.
GetNode
()
->
index_
=
dst_tmp_info
->
index_
;
}
}
ReplaceVarWithNewForInfo
(
dst_tmp_info
,
for_info
,
for_tmp_info
);
ReplaceVarWithNewForInfo
(
src_tmp_info0
,
for_info
,
for_tmp_info
);
ReplaceVarWithNewForInfo
(
src_tmp_info1
,
for_info
,
for_tmp_info
);
StmtInfoList
src_list
=
{
src_tmp_info0
,
src_tmp_info1
};
CompactComputationInfoList
(
dst_tmp_info
,
src_list
,
if_info
,
for_tmp_info
);
wrapper
.
for_info_list_
.
emplace_back
(
for_tmp_info
);
if
(
extent
==
0
)
{
// normally is bisect_tmp = bisect_tmp + bisect_tmp/src_tmp
wrapper
.
bisec_info_list_
.
emplace_back
(
StmtInfoList
{
dst_tmp_info
,
dst_tmp_info
,
src_tmp_info1
});
}
else
{
// normally is dst_tmp = dst_tmp + bisect_tmp
wrapper
.
bisec_info_list_
.
emplace_back
(
StmtInfoList
{
dst_tmp_info
,
src_tmp_info0
,
src_tmp_info1
});
}
if
(
extent
==
0
)
{
break
;
}
else
{
extent
=
extent
/
2
;
}
prolog_len
=
0
;
}
// Generate arg_info
for
(
size_t
i
=
0
;
i
<
wrapper
.
bisec_info_list_
.
size
();
++
i
)
{
auto
info_list
=
wrapper
.
bisec_info_list_
[
i
];
auto
new_for_info
=
wrapper
.
for_info_list_
[
i
];
ArgInfo
arg_info
;
auto
dst_list
=
GetRange
(
info_list
,
0
,
1
);
auto
src_list
=
GetRange
(
info_list
,
1
,
info_list
.
size
()
-
1
);
if
(
info_list
.
size
()
==
2
)
{
std
::
string
dma_intrin
=
INTRIN_NAME_COPY_UB_TO_UB
;
wrapper
.
dma_arg_info_map_
=
GetDmaCopyInsnArgs
(
dma_intrin
,
dst_list
,
src_list
,
new_for_info
);
}
else
{
// Bisect can't expand mask because it has inplace operation
if
(
i
!=
wrapper
.
bisec_info_list_
.
size
()
-
1
)
{
// Last round dont need to add
FillLastDim
(
dst_list
,
src_list
,
new_for_info
);
}
std
::
string
mode
=
GetBinaryVecMode
(
dst_list
,
src_list
,
"vadd"
,
false
);
BinaryVecInsnArgsCalculator
args_calculator
=
BinaryVecInsnArgsCalculator
(
dst_list
,
src_list
,
new_for_info
,
mode
,
""
,
false
);
PatternResult
params
=
args_calculator
.
GetInsnArgs
();
arg_info
=
params
.
arg_info
;
dst_list
=
params
.
dst_info_list
;
src_list
=
params
.
src_info_list
;
new_for_info
=
params
.
for_info
;
wrapper
.
bisec_info_list_
[
i
]
=
{
dst_list
[
0
],
src_list
[
0
],
src_list
[
1
]};
}
wrapper
.
arg_info_list_
.
push_back
(
arg_info
);
wrapper
.
for_info_list_
[
i
]
=
new_for_info
;
}
return
wrapper
;
}
/// Get CCE Binary Vector Insn Computation Info
/// \param stmt - operand stmt
/// \param intrin_name - vector intrin name
/// \param dst_info_list - dst computation info list
/// \param src_info_list - src computation info list
/// \param if_info - if info list
/// \param for_info - for info list
/// \return intrin args
ArgInfo
GetBinaryVecInsnArgs
(
const
Stmt
&
stmt
,
std
::
string
intrin_name
,
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
enable_bisect
)
{
// check intrin_name
std
::
set
<
std
::
string
>
intrin_name_list
=
{
"vadd"
,
"vmax"
,
"vmin"
,
"vmul"
,
"vdiv"
,
"vsel"
,
"vsub"
,
"vand"
,
"vor"
,
"vaxpy"
,
"argmax"
,
"argmin"
,
"vmadd"
,
"vmaddrelu"
,
"vmla"
};
if
(
intrin_name_list
.
count
(
intrin_name
)
==
0
)
{
LOG
(
FATAL
)
<<
"Error: CCE Binary Vector Insn doesn't support the given intrin_name."
;
}
// get and check dst and src
GetCompactComputationInfo
(
stmt
,
dst_info_list
,
src_info_list
,
if_info
,
for_info
,
true
);
// For vmadd/vmaddrelu/vmla we only need first two src
if
(
dst_info_list
.
size
()
!=
1
||
src_info_list
.
size
()
<
2
)
{
LOG
(
FATAL
)
<<
"CCE Binary Vector Insn only support ONE dst and TWO srcs."
;
}
src_info_list
=
GetRange
(
src_info_list
,
0
,
2
);
ArgInfo
arg_info
=
ArgInfo
(
make_node
<
ArgInfoNode
>
());
// detect vector op mode
std
::
string
mode
=
GetBinaryVecMode
(
dst_info_list
,
src_info_list
,
intrin_name
,
enable_bisect
);
if
(
mode
==
"reduce_last_axis"
)
{
size_t
src_var_list_size
=
src_info_list
[
1
]
->
var_
.
size
();
if
(
src_info_list
[
0
]
->
var_
.
size
()
>
src_info_list
[
1
]
->
var_
.
size
())
{
src_var_list_size
=
src_info_list
[
0
]
->
var_
.
size
();
}
CHECK
(
src_var_list_size
>
0
)
<<
"Error: src can not be a scalar."
;
if
(
src_var_list_size
-
dst_info_list
[
0
]
->
var_
.
size
()
==
1
)
{
arg_info
.
GetNode
()
->
arg_type_
=
ARG_VECTOR_REDUCTION_LAST_AXIS
;
}
else
{
LOG
(
FATAL
)
<<
"Error: cannot support multi-last-axis reduction."
;
}
}
else
if
(
mode
==
"reduce_bisection"
)
{
arg_info
.
GetNode
()
->
arg_type_
=
ARG_VECTOR_REDUCTION_BISECTION
;
}
else
{
if
(
mode
!=
"reduction"
&&
mode
!=
"broadcast"
)
{
FillLastDim
(
dst_info_list
,
src_info_list
,
for_info
);
}
// vmax/vmin can't expand mask because it may introduce dirty data
bool
can_expand_mask
=
intrin_name
!=
"vmax"
&&
intrin_name
!=
"vmin"
;
BinaryVecInsnArgsCalculator
args_calculator
=
BinaryVecInsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
,
mode
,
intrin_name
,
can_expand_mask
);
PatternResult
params
=
args_calculator
.
GetInsnArgs
();
arg_info
=
params
.
arg_info
;
dst_info_list
=
params
.
dst_info_list
;
src_info_list
=
params
.
src_info_list
;
for_info
=
params
.
for_info
;
if
(
mode
==
"broadcast"
)
{
bool
has_last_axis
=
false
;
if
((
arg_info
->
body_arg_info_
.
defined
()
&&
arg_info
->
body_arg_info_
->
last_axis_info_
.
src_index_
!=
-
1
)
||
(
arg_info
->
tail_arg_info_
.
defined
()
&&
arg_info
->
tail_arg_info_
->
last_axis_info_
.
src_index_
!=
-
1
))
{
has_last_axis
=
true
;
}
if
(
has_last_axis
&&
(
intrin_name
==
"vadd"
||
intrin_name
==
"vmul"
))
{
Array
<
NodeRef
>
stores
;
Array
<
NodeRef
>
loads
;
GetStoreAndLoads
(
stmt
,
stores
,
loads
);
intrin_name
=
intrin_name
+
"s"
;
if
(
arg_info
->
body_arg_info_
.
defined
())
{
arg_info
.
GetNode
()
->
body_arg_info_
.
GetNode
()
->
last_axis_info_
.
intrin_name_
=
intrin_name
;
arg_info
.
GetNode
()
->
body_arg_info_
.
GetNode
()
->
last_axis_info_
.
src_op_
=
Downcast
<
Expr
>
(
loads
[
arg_info
->
body_arg_info_
->
last_axis_info_
.
src_index_
]);
}
}
}
}
return
arg_info
;
}
}
// namespace akg
\ No newline at end of file
src/emit_insn/insn_args_calculator.h
0 → 100644
浏览文件 @
aec205ff
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef EMIT_INSN_ARGS_CALCULATOR_H_
#define EMIT_INSN_ARGS_CALCULATOR_H_
namespace
akg
{
struct
InsnArg
{
int
dst_m0
{
1
};
int
dst_m1
{
0
};
std
::
vector
<
Expr
>
src_m0_list
;
std
::
vector
<
Expr
>
src_m1_list
;
int
repeat
{
1
};
int
block_len
{
1
};
int
block_num
{
1
};
int
body_num
{
1
};
int
tail_len
{
0
};
int
dst_tail_offset
{
0
};
std
::
vector
<
Expr
>
src_tail_offset_list
;
};
struct
Meta
{
int
block_size
{
0
};
int
src_block_size
{
0
};
int
dst_block_size
{
0
};
int
block_offset
{
0
};
const
float
vec_rate
{
0.6
};
Type
src_dtype
;
Type
dst_dtype
;
Type
dtype
;
bool
cast
{
false
};
bool
tail
{
false
};
bool
scalar
{
false
};
bool
liner
{
false
};
bool
same_dst_src
{
false
};
};
enum
SplitStat
{
SUCCESS
,
NO_SPLIT
,
TAIL
};
class
InsnAxis
{
public:
InsnAxis
()
=
default
;
InsnAxis
(
const
For
*
for_stmt
,
const
Array
<
StmtStoreInfo
>
&
info_list
);
virtual
~
InsnAxis
()
=
default
;
bool
IsValid
();
void
Print
(
const
std
::
string
&
name
=
""
);
int
min
{
0
};
int
extent
{
0
};
Var
var
;
int
dst_stride
{
0
};
int
src_stride
{
0
};
std
::
vector
<
int
>
src_stride_list
;
std
::
vector
<
int
>
stride_list
;
bool
is_valid
{
true
};
private:
Expr
GetStrideByAxis
(
const
Array
<
Var
>
&
vars
,
const
Array
<
Expr
>
&
strides
,
Var
obj_var
);
};
using
AxisIt
=
std
::
list
<
InsnAxis
>::
iterator
;
std
::
list
<
InsnAxis
>
GetAxisList
(
const
StmtInfo
&
for_info
,
const
Array
<
StmtStoreInfo
>
&
info_list
);
Array
<
StmtStoreInfo
>
GetInfoList
(
const
StmtStoreInfo
&
dst_info
,
const
Array
<
StmtStoreInfo
>
&
src_info_list
);
int
DivFloor
(
int
a
,
int
b
);
void
Print
(
std
::
list
<
InsnAxis
>
&
axis_list
);
class
InsnArgsCalculator
{
public:
InsnArgsCalculator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
);
virtual
~
InsnArgsCalculator
()
=
default
;
PatternResult
ExportResult
();
void
CalAxis
();
void
InitArg
();
virtual
std
::
function
<
bool
(
const
InsnAxis
&
)
>
GetStrideLambda
();
virtual
std
::
function
<
bool
(
const
InsnAxis
&
)
>
GetM0LimitLambda
();
virtual
std
::
function
<
bool
(
const
InsnAxis
&
)
>
GetM1LimitLambda
();
std
::
function
<
bool
(
const
InsnAxis
&
)
>
GetBlockStrideLimitLambda
();
AxisIt
GetAxisByLambda
(
const
std
::
function
<
bool
(
const
InsnAxis
&
)
>
&
lambda
);
InsnAxis
ExtractAxis
(
AxisIt
&
it
);
bool
IsValid
(
AxisIt
&
it
);
AxisIt
GetVecAxisIt
();
AxisIt
GetBlockAxis
();
AxisIt
GetRepeatAxisIt
();
InsnAxis
GetRepeatAxis
();
void
SetArgMask
(
int
len
);
void
SetArgBlockNum
(
int
data_num
);
void
SetArgBlockLen
(
int
data_len
);
void
SetArgM0
(
int
dst_m0
,
int
lsrc_m0
,
int
rsrc_m0
);
void
SetArgM1
(
int
dst_m1
,
int
lsrc_m1
,
int
rsrc_m1
);
void
SetArgRepeat
(
int
repeat
);
void
BlockAxisReduction
();
void
RepeatAxisReduction
();
void
CastCaseReduction
();
virtual
void
InsnReduction
();
StmtInfo
ExportForInfo
();
Expr
GetOffset
(
int
stride_index
);
InsnAxis
GetInvalidAxis
();
SplitStat
SplitAxis
(
int
extent
,
InsnAxis
&
axis
);
std
::
list
<
InsnAxis
>
axis_list_
;
protected:
InsnArg
arg_
;
Meta
meta_
;
StmtInfoList
dst_info_list_
;
StmtInfoList
src_info_list_
;
StmtStoreInfo
dst_info_
;
StmtInfo
for_info_
;
const
std
::
string
intrin_name_
;
const
int
max_block_stride_
{
4
};
};
class
SingleVecInsnArgsCalculator
:
public
InsnArgsCalculator
{
public:
SingleVecInsnArgsCalculator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
=
""
);
virtual
~
SingleVecInsnArgsCalculator
()
override
=
default
;
PatternResult
GetInsnArgs
();
};
class
BinaryVecInsnArgsCalculator
:
public
InsnArgsCalculator
{
public:
BinaryVecInsnArgsCalculator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
mode
,
const
std
::
string
&
intrin_name
=
""
,
bool
expand_mask
=
true
);
virtual
~
BinaryVecInsnArgsCalculator
()
override
=
default
;
PatternResult
GetInsnArgs
();
std
::
function
<
bool
(
const
InsnAxis
&
)
>
GetM0LimitLambda
();
std
::
function
<
bool
(
const
InsnAxis
&
)
>
GetM1LimitLambda
();
void
InsnReduction
();
private:
std
::
string
mode_
;
bool
expand_mask_
;
InsnAxis
vec_axis_
;
};
class
LastAxisReduceInsnArgsCalculator
:
InsnArgsCalculator
{
public:
LastAxisReduceInsnArgsCalculator
(
const
StmtStoreInfo
&
dst_info
,
const
StmtStoreInfo
&
src_info
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
)
:
InsnArgsCalculator
({
dst_info
},
{
src_info
},
for_info
,
intrin_name
),
dst_info
(
dst_info
),
src_info
(
src_info
),
for_info
(
for_info
),
arg_info
(
ArgInfo
(
make_node
<
ArgInfoNode
>
())),
body_args
(
VectorArgInfo
()),
tail_args
(
VectorArgInfo
()),
intrin_name
(
intrin_name
)
{}
PatternResult
GetInsnArgs
();
~
LastAxisReduceInsnArgsCalculator
()
=
default
;
protected:
Array
<
Var
>
GetPattern
();
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
);
private:
void
CalcParams
();
struct
Params
{
Array
<
Var
>
src_var
;
int
block_size
=
0
;
int
vec_max_len
=
0
;
int
last_dim_shape
=
0
;
Expr
insn_offset_scale_factor
;
};
StmtStoreInfo
dst_info
;
StmtStoreInfo
src_info
;
StmtInfo
for_info
;
ArgInfo
arg_info
;
VectorArgInfo
body_args
;
VectorArgInfo
tail_args
;
Array
<
VectorArgInfo
>
mix_vec_arg_list
;
std
::
string
intrin_name
;
Params
params
;
};
BisectionInfoWrapper
SeparateComInfoToBisectionInfoList
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
StmtInfo
&
if_info
,
bool
last_axis
,
int
postfix
);
ArgInfo
GetBinaryVecInsnArgs
(
const
Stmt
&
stmt
,
std
::
string
intrin_name
,
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
enable_bisect
=
true
);
}
// namespace akg
#endif
\ No newline at end of file
src/emit_insn/insn_binary_vec_pattern.cc
已删除
100644 → 0
浏览文件 @
11ed37cc
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/base.h>
#include <tvm/ir_pass.h>
#include <set>
#include "ir_pass.h"
#include "contrib/cce_parm/cceconf.h"
#include "tvm.h"
#include "common/array_api.h"
#include "insn_pattern.h"
#include "insn_builder.h"
namespace
akg
{
std
::
string
GetBinaryVecMode
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
std
::
string
&
intrin_name
,
bool
enable_bisect
=
true
)
{
std
::
set
<
std
::
string
>
reduce_bisect_list
=
{
"vadd"
,
"vsub"
,
"vmul"
,
"vmax"
};
std
::
string
mode
=
"reduction"
;
if
(
IsElementwise
(
dst_info_list
,
src_info_list
))
{
mode
=
"elewise"
;
}
else
if
(
IsBroadcast
(
dst_info_list
,
src_info_list
))
{
mode
=
"broadcast"
;
}
else
if
(
IsLastAxisReduction
(
dst_info_list
,
src_info_list
))
{
mode
=
"reduce_last_axis"
;
}
else
if
(
enable_bisect
&&
reduce_bisect_list
.
count
(
intrin_name
)
!=
0
&&
IsBisectionReduction
(
dst_info_list
,
src_info_list
))
{
mode
=
"reduce_bisection"
;
}
return
mode
;
}
PatternResult
ReduceLastAxisPatternGenerator
::
GetInsnArgs
()
{
CalcParams
();
Array
<
Var
>
elim_var
;
float
rate2d
=
Compute2DBlockPatternMaskRate
();
if
(
rate2d
>
1.0
f
)
{
elim_var
=
Get2DBlockPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_2D_BLOCK
;
}
else
{
elim_var
=
Get1DPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_1D
;
}
return
GenResult
(
elim_var
);
}
float
ReduceLastAxisPatternGenerator
::
Compute2DBlockPatternMaskRate
()
{
const
float
is2_dpattern
=
1.0
f
;
if
(
intrin_name
==
"vadd"
||
intrin_name
==
"argmax"
||
intrin_name
==
"argmin"
)
{
return
not_this_pattern
;
}
// src var size must larger than 2
if
(
params
.
src_var
.
size
()
<
2
)
{
return
not_this_pattern
;
}
int
body_len
=
params
.
last_dim_shape
/
params
.
vec_max_len
*
params
.
vec_max_len
;
int
tail_len
=
params
.
last_dim_shape
%
params
.
vec_max_len
;
// there is no body in this mode
if
(
body_len
>
0
||
tail_len
>
params
.
block_size
)
{
return
not_this_pattern
;
}
return
is2_dpattern
;
}
Array
<
Var
>
ReduceLastAxisPatternGenerator
::
Get2DBlockPattern
()
{
int
sec_last_dim_shape
=
GetInt32Const
(
GetItem
(
src_info
->
shape_
,
-
2
));
int
body_len
=
sec_last_dim_shape
/
FULL_BLOCK_NUM
*
FULL_BLOCK_NUM
;
int
tail_len
=
sec_last_dim_shape
%
FULL_BLOCK_NUM
;
int
cmd_body_len
=
0
;
if
(
body_len
>
0
)
{
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
body_args
.
GetNode
()
->
body_num_
=
1
;
body_args
.
GetNode
()
->
repeat_
=
Expr
(
body_len
/
FULL_BLOCK_NUM
);
// Here use dst_stride_m1 as dst_stride
body_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
1
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
FULL_BLOCK_NUM
)};
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
params
.
last_dim_shape
,
FULL_BLOCK_NUM
,
dst_info
->
dtype_
);
cmd_body_len
+=
GetInt32Const
(
body_args
->
repeat_
)
*
FULL_BLOCK_NUM
;
}
if
(
tail_len
>
0
)
{
tail_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
tail_args
.
GetNode
()
->
dst_head_
=
Expr
(
cmd_body_len
);
tail_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
cmd_body_len
/
FULL_BLOCK_NUM
*
params
.
vec_max_len
)};
tail_args
.
GetNode
()
->
repeat_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
tail_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
0
)};
tail_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
params
.
last_dim_shape
,
tail_len
,
dst_info
->
dtype_
);
}
params
.
insn_offset_scale_factor
=
1
;
return
GetRange
(
params
.
src_var
,
-
2
,
2
);
}
Array
<
Var
>
ReduceLastAxisPatternGenerator
::
Get1DPattern
()
{
int
body_len
=
params
.
last_dim_shape
/
params
.
vec_max_len
*
params
.
vec_max_len
;
int
tail_len
=
params
.
last_dim_shape
%
params
.
vec_max_len
;
int
cmd_body_len
=
0
;
bool
is_vadd
=
intrin_name
==
"vadd"
;
int
repeat_stride
=
FULL_BLOCK_NUM
;
if
(
is_vadd
)
{
repeat_stride
=
1
;
}
const
int
fp16_block_size
=
16
;
if
(
body_len
>
0
)
{
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
body_args
.
GetNode
()
->
body_num_
=
1
;
body_args
.
GetNode
()
->
body_offset_
=
params
.
vec_max_len
;
body_args
.
GetNode
()
->
repeat_
=
Expr
(
body_len
/
params
.
vec_max_len
);
// Here use dst_stride_m1 as dst_stride
body_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
1
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
FULL_BLOCK_NUM
)};
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
params
.
vec_max_len
,
1
,
dst_info
->
dtype_
);
cmd_body_len
+=
GetInt32Const
(
body_args
->
repeat_
)
*
repeat_stride
;
}
if
(
tail_len
>
0
)
{
tail_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
tail_args
.
GetNode
()
->
body_offset_
=
params
.
vec_max_len
;
tail_args
.
GetNode
()
->
dst_head_
=
Expr
(
cmd_body_len
);
tail_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
body_len
)};
tail_args
.
GetNode
()
->
repeat_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
tail_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
0
)};
tail_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
tail_len
,
1
,
dst_info
->
dtype_
);
if
(
is_vadd
)
{
cmd_body_len
+=
1
;
}
else
{
cmd_body_len
+=
tail_len
/
fp16_block_size
;
if
(
tail_len
%
fp16_block_size
!=
0
)
{
cmd_body_len
+=
1
;
}
}
}
// cmd_body_len > 1 means vcadd size greater than 128, need to use vcadd again to compute final result
// if cmd_body_len > 128, then need to recursively emit vcadd
while
(
cmd_body_len
>
1
)
{
int
cmd_tail_len
=
cmd_body_len
%
params
.
vec_max_len
;
cmd_body_len
=
cmd_body_len
/
params
.
vec_max_len
;
if
(
cmd_body_len
>
0
)
{
VectorArgInfo
mix_vec_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
mix_vec_args
.
GetNode
()
->
repeat_
=
Expr
(
cmd_body_len
);
mix_vec_args
.
GetNode
()
->
dst_head_
=
Expr
(
0
);
mix_vec_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
0
)};
mix_vec_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
1
);
mix_vec_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
mix_vec_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
FULL_BLOCK_NUM
)};
mix_vec_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
params
.
vec_max_len
,
1
,
dst_info
->
dtype_
);
mix_vec_arg_list
.
push_back
(
mix_vec_args
);
if
(
!
is_vadd
)
{
cmd_body_len
*=
FULL_BLOCK_NUM
;
}
}
if
(
cmd_tail_len
>
0
)
{
VectorArgInfo
mix_vec_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
mix_vec_args
.
GetNode
()
->
repeat_
=
Expr
(
1
);
mix_vec_args
.
GetNode
()
->
dst_head_
=
Expr
(
cmd_body_len
);
if
(
is_vadd
)
{
mix_vec_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
cmd_body_len
*
params
.
vec_max_len
)};
}
else
{
mix_vec_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
cmd_body_len
/
FULL_BLOCK_NUM
*
params
.
vec_max_len
)};
}
mix_vec_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
1
);
mix_vec_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
mix_vec_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
FULL_BLOCK_NUM
)};
mix_vec_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
cmd_tail_len
,
1
,
dst_info
->
dtype_
);
if
(
is_vadd
)
{
cmd_body_len
+=
1
;
}
else
{
cmd_body_len
+=
cmd_tail_len
/
fp16_block_size
;
if
(
cmd_tail_len
%
fp16_block_size
!=
0
)
{
cmd_body_len
+=
1
;
}
}
mix_vec_arg_list
.
push_back
(
mix_vec_args
);
}
}
params
.
insn_offset_scale_factor
=
Expr
(
params
.
block_size
);
int
max_num
=
body_len
/
params
.
vec_max_len
;
if
(
intrin_name
==
"vmax"
||
intrin_name
==
"vmin"
)
{
max_num
*=
FULL_BLOCK_NUM
;
}
if
(
max_num
>=
params
.
block_size
)
{
params
.
insn_offset_scale_factor
=
max_num
+
params
.
block_size
-
1
;
if
(
tail_len
>
0
)
{
params
.
insn_offset_scale_factor
+=
1
;
}
params
.
insn_offset_scale_factor
=
truncdiv
(
params
.
insn_offset_scale_factor
,
params
.
block_size
)
*
params
.
block_size
;
}
if
(
!
params
.
src_var
.
empty
())
{
return
GetRange
(
params
.
src_var
,
-
1
,
1
);
}
return
{};
}
PatternResult
ReduceLastAxisPatternGenerator
::
GenResult
(
const
Array
<
Var
>
&
elim_var
)
{
dst_info
.
GetNode
()
->
insn_offset_
=
GetInsnOffset
(
dst_info
,
elim_var
)
*
params
.
insn_offset_scale_factor
;
src_info
.
GetNode
()
->
insn_offset_
=
GetInsnOffset
(
src_info
,
elim_var
);
if
(
body_args
.
defined
())
{
body_args
.
GetNode
()
->
insn_offset_scale_factor_
=
params
.
insn_offset_scale_factor
;
}
if
(
tail_args
.
defined
())
{
tail_args
.
GetNode
()
->
insn_offset_scale_factor_
=
params
.
insn_offset_scale_factor
;
}
for
(
auto
&
arg
:
mix_vec_arg_list
)
{
arg
.
GetNode
()
->
insn_offset_scale_factor_
=
params
.
insn_offset_scale_factor
;
}
arg_info
.
GetNode
()
->
body_arg_info_
=
body_args
;
arg_info
.
GetNode
()
->
tail_arg_info_
=
tail_args
;
arg_info
.
GetNode
()
->
reduction_tail_args_
=
mix_vec_arg_list
;
CleanForInfoVars
(
for_info
,
elim_var
);
arg_info
.
GetNode
()
->
arg_type_
=
ARG_VECTOR_REDUCTION_LAST_AXIS
;
PatternResult
result
;
result
.
dst_info_list
=
{
dst_info
};
result
.
src_info_list
=
{
src_info
};
result
.
for_info
=
for_info
;
result
.
arg_info
=
arg_info
;
return
result
;
}
void
ReduceLastAxisPatternGenerator
::
CalcParams
()
{
// check shape len
if
(
dst_info
->
shape_
.
empty
()
||
src_info
->
shape_
.
empty
())
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."
;
}
// check data type
if
(
dst_info
->
dtype_
!=
src_info
->
dtype_
)
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type."
;
}
params
.
src_var
=
src_info
->
var_
;
params
.
block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
params
.
last_dim_shape
=
GetInt32Const
(
GetItem
(
src_info
->
shape_
,
-
1
));
params
.
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
CHECK_NE
(
params
.
block_size
,
0
);
CHECK_NE
(
params
.
vec_max_len
,
0
);
}
/// Get CCE Binary Vector instructions args
/// \return
PatternResult
BinaryVecPatternGenerator
::
GetInsnArgs
()
{
CalcParams
();
if
(
arg_info
->
arg_type_
==
ARG_VECTOR_BROADCAST_LAST_AXIS
)
{
PatternResult
result
;
result
.
dst_info_list
=
{
dst_info
};
result
.
src_info_list
=
src_info_list
;
result
.
for_info
=
for_info
;
result
.
arg_info
=
arg_info
;
return
result
;
}
Array
<
Var
>
elim_var
=
{};
float
rate3d
=
Compute3DPatternMaskRate
();
float
rate2db
=
Compute2DBlockPatternMaskRate
();
float
rate2d
=
Compute2DPatternMaskRate
();
float
rate1d
=
Compute1DPatternMaskRate
();
if
(
rate3d
>=
rate2db
&&
rate3d
>
0
)
{
elim_var
=
Get3DPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_3D
;
}
else
if
(
rate2db
>=
rate2d
&&
rate2db
>=
rate1d
&&
rate2db
>
0
)
{
elim_var
=
Get2DBlockPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_PARTIAL_3D
;
}
else
if
(
rate2d
>
rate1d
&&
rate2d
>
0
)
{
elim_var
=
Get2DPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_2D
;
}
else
if
(
rate1d
>
0
)
{
elim_var
=
Get1DPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_1D
;
}
else
{
LOG
(
FATAL
)
<<
"Error: Cannot emit Binary-Vector-Insn with any pattern!"
;
}
std
::
string
mask_rate
=
"rate3d["
+
std
::
to_string
(
rate3d
)
+
"], rate2db["
+
std
::
to_string
(
rate2db
)
+
"], rate2d["
+
std
::
to_string
(
rate2d
)
+
"], rate1d["
+
std
::
to_string
(
rate1d
)
+
"]"
;
CommentManager
::
GetInstance
().
AddComment
(
"Mask_rate"
,
mask_rate
);
if
(
tail_args
.
defined
())
{
CommentManager
::
GetInstance
().
AddComment
(
"Contain_tail"
,
"true"
);
}
else
{
CommentManager
::
GetInstance
().
AddComment
(
"Contain_tail"
,
"false"
);
}
return
GenResult
(
elim_var
);
}
float
BinaryVecPatternGenerator
::
Compute3DPatternMaskRate
()
{
if
(
params
.
non_zero_shape3
==
1
||
params
.
non_zero_shape2
==
1
)
{
return
not_this_pattern
;
}
// in elemwise mode, the var is already checked to be equal, no need to check
if
(
params
.
dst_var
.
size
()
<
3
||
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
params
.
block_size
||
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
block_size
!=
0
||
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
3
))
%
params
.
block_size
!=
0
||
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
>
0
&&
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
0
&&
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
<
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
)))
||
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
3
))
>
0
&&
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))
>
0
&&
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
3
))
<
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))))
{
return
not_this_pattern
;
}
// check dst_stride_m0
// As described in ISL User Guide t6.3,
// dst_stride_m0 = 0 is treated as 1
auto
JudgeNot3D
=
[
this
](
const
StmtStoreInfo
&
info
)
{
auto
last_shape1
=
GetIntConst
(
GetItem
(
info
->
shape_
,
-
1
));
if
(
info
->
var_
.
size
()
<
3
||
last_shape1
>
params
.
block_size
)
{
return
true
;
}
auto
last_shape2
=
GetIntConst
(
GetItem
(
info
->
shape_
,
-
2
));
auto
last_stride2
=
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
));
auto
last_stride3
=
GetIntConst
(
GetItem
(
info
->
strides_
,
-
3
));
return
last_stride2
%
params
.
block_size
!=
0
||
last_stride3
%
params
.
block_size
!=
0
||
(
last_stride2
>
0
&&
last_shape1
>
0
&&
last_stride2
<
last_shape1
)
||
(
last_stride3
>
0
&&
last_shape2
>
0
&&
last_stride3
<
last_shape2
);
};
if
(
std
::
any_of
(
src_info_list
.
begin
(),
src_info_list
.
end
(),
JudgeNot3D
))
{
return
not_this_pattern
;
}
if
(
mode
==
"reduction"
)
{
// check same alignment
Array
<
Expr
>
shape_list
=
{
GetItem
(
params
.
dst_shape
,
-
1
)};
shape_list
.
push_back
(
GetItem
(
params
.
src_shape0
,
-
1
));
shape_list
.
push_back
(
GetItem
(
params
.
src_shape1
,
-
1
));
if
(
!
IsNonZeroShapeEqual
(
shape_list
))
{
return
not_this_pattern
;
}
}
// repeat axis is shape [-3], repeat once, has 8 loops
bool
is3_d
=
true
;
float
rate3d_mode1
=
not_this_pattern
;
float
rate3d_mode2
=
not_this_pattern
;
int
repeat_num
;
float
repeat_latency
;
auto
info_list
=
src_info_list
;
Insert
(
info_list
,
0
,
dst_info
);
for
(
auto
info
:
info_list
)
{
if
(
GetInt32Const
(
GetItem
(
info
->
shape_
,
-
2
))
>
FULL_BLOCK_NUM
||
GetInt32Const
(
GetItem
(
info
->
strides_
,
-
2
))
/
params
.
block_size
>=
MAX_STRIDE_M0
||
GetInt32Const
(
GetItem
(
info
->
strides_
,
-
3
))
/
params
.
block_size
>=
MAX_STRIDE_M0
)
{
is3_d
=
false
;
break
;
}
}
if
(
is3_d
)
{
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
repeat_num
=
params
.
non_zero_shape3
;
repeat_latency
=
((
repeat_num
-
1
)
/
MAX_REPEAT
)
*
repeat_latency_coef
;
rate3d_mode1
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
vec_max_len
/
(
repeat_num
+
repeat_latency
);
}
is3_d
=
true
;
// repeat axis is shape[-2]
for
(
auto
info
:
info_list
)
{
// stride_m0 should be less than 256
if
(
GetIntConst
(
GetItem
(
info
->
shape_
,
-
3
))
%
FULL_BLOCK_NUM
!=
0
||
GetIntConst
(
GetItem
(
info
->
strides_
,
-
3
))
/
params
.
block_size
>=
MAX_STRIDE_M0
)
{
is3_d
=
false
;
break
;
}
}
if
(
is3_d
)
{
if
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
3
))
==
0
)
{
return
not_this_pattern
;
}
repeat_num
=
params
.
non_zero_shape2
*
(
params
.
non_zero_shape3
/
FULL_BLOCK_NUM
);
repeat_latency
=
((
repeat_num
-
1
)
/
MAX_REPEAT
)
*
repeat_latency_coef
;
float
offset_latency
=
params
.
non_zero_shape3
/
FULL_BLOCK_NUM
>
1
?
params
.
non_zero_shape3
*
offset_latency_coef
:
0
;
rate3d_mode2
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
);
}
return
rate3d_mode1
>
rate3d_mode2
?
rate3d_mode1
:
rate3d_mode2
;
}
float
BinaryVecPatternGenerator
::
Compute2DBlockPatternMaskRate
()
{
if
(
params
.
non_zero_shape2
==
1
||
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
1
))
!=
1
)
{
return
not_this_pattern
;
}
if
(
params
.
dst_var
.
size
()
<
2
||
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
params
.
block_size
||
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
block_size
!=
0
||
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
/
params
.
block_size
>=
MAX_STRIDE_M0
||
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
>
0
&&
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
0
&&
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
<
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
))))
{
return
not_this_pattern
;
}
for
(
auto
info
:
src_info_list
)
{
if
(
info
->
var_
.
size
()
<
2
||
GetIntConst
(
GetItem
(
info
->
shape_
,
-
1
))
>
params
.
block_size
||
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
%
params
.
block_size
!=
0
||
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
/
params
.
block_size
>=
MAX_STRIDE_M0
||
(
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
>
0
&&
GetIntConst
(
GetItem
(
info
->
shape_
,
-
1
))
>
0
&&
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
<
GetIntConst
(
GetItem
(
info
->
shape_
,
-
1
))))
{
return
not_this_pattern
;
}
}
if
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
if
(
mode
==
"reduction"
)
{
if
(
params
.
dst_var
.
size
()
>
2
)
{
// if not elewise mode, then can not use partial 3D mode
if
(
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
3
))
==
0
)
{
return
not_this_pattern
;
}
for
(
auto
info
:
src_info_list
)
{
if
(
GetIntConst
(
GetItem
(
info
->
shape_
,
-
3
))
==
0
)
{
return
not_this_pattern
;
}
}
}
// check same alignment
Array
<
Expr
>
shape_list
=
{
GetItem
(
params
.
dst_shape
,
-
1
)};
shape_list
.
push_back
(
GetItem
(
params
.
src_shape0
,
-
1
));
shape_list
.
push_back
(
GetItem
(
params
.
src_shape1
,
-
1
));
// check dst_stride_m0
// As described in ISL User Guide t6.3,
// dst_stride_m0 = 0 is treated as 1
if
(
!
IsNonZeroShapeEqual
(
shape_list
))
{
return
not_this_pattern
;
}
}
int
repeat_body_num
=
params
.
non_zero_shape2
/
FULL_BLOCK_NUM
;
int
repeat_tail_num
=
(
params
.
non_zero_shape2
%
FULL_BLOCK_NUM
+
FULL_BLOCK_NUM
-
1
)
/
FULL_BLOCK_NUM
;
int
repeat_num
=
(
repeat_body_num
+
repeat_tail_num
)
*
params
.
non_zero_shape3
;
float
repeat_latency
=
(
std
::
max
(
repeat_body_num
-
1
,
0
)
/
MAX_REPEAT
+
std
::
max
(
repeat_tail_num
-
1
,
0
)
/
MAX_REPEAT
)
*
repeat_latency_coef
;
float
offset_latency
=
params
.
non_zero_shape3
>
1
?
params
.
non_zero_shape3
*
offset_latency_coef
:
0
;
float
split_latency
=
(
repeat_body_num
>
0
&&
repeat_tail_num
>
0
)
?
split_latency_coef
:
0
;
float
rate2db
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
+
split_latency
);
return
rate2db
;
}
float
BinaryVecPatternGenerator
::
Compute2DPatternMaskRate
()
{
if
(
params
.
non_zero_shape2
==
1
)
{
return
not_this_pattern
;
}
if
(
params
.
dst_var
.
size
()
<
2
||
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
block_size
!=
0
||
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
<
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
))
&&
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
)
>
0
)))
{
return
not_this_pattern
;
}
for
(
auto
info
:
src_info_list
)
{
if
(
info
->
var_
.
size
()
<
2
||
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
%
params
.
block_size
!=
0
||
(
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
<
GetIntConst
(
GetItem
(
info
->
shape_
,
-
1
))
&&
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
)
>
0
)))
{
return
not_this_pattern
;
}
}
// check num of insns, select 1D pattern or 2D pattern
int
tail_factor
=
0
;
if
(
mode
==
"reduction"
)
{
Array
<
Expr
>
shape_list
=
{
GetItem
(
params
.
dst_shape
,
-
1
)};
shape_list
.
push_back
(
GetItem
(
params
.
src_shape0
,
-
1
));
shape_list
.
push_back
(
GetItem
(
params
.
src_shape1
,
-
1
));
if
(
!
IsNonZeroShapeEqual
(
shape_list
))
{
return
not_this_pattern
;
}
}
// only cloud allow dst_stride_m1 = 0
cceconf
::
CceConf
*
conf
=
cceconf
::
CceConf
::
getInstance
();
const
std
::
string
product_name
=
conf
->
getProductName
();
if
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
&&
product_name
!=
"cloud"
)
{
return
not_this_pattern
;
}
CHECK_NE
(
params
.
vec_max_len
,
0
);
if
(
params
.
non_zero_shape1
/
params
.
vec_max_len
>
0
&&
params
.
non_zero_shape1
%
params
.
vec_max_len
>
0
)
{
tail_factor
=
1
;
}
if
(
GetIntConst
(
GetItem
(
dst_info
->
strides_
,
-
2
))
/
params
.
block_size
>=
MAX_STRIDE_M0
)
{
return
not_this_pattern
;
}
for
(
auto
info
:
src_info_list
)
{
if
(
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
/
params
.
block_size
>=
MAX_STRIDE_M0
)
{
return
not_this_pattern
;
}
}
int
shape1
=
(
params
.
non_zero_shape1
+
params
.
vec_max_len
-
1
)
/
params
.
vec_max_len
;
int
repeat_num
=
shape1
*
params
.
non_zero_shape2
*
params
.
non_zero_shape3
;
float
repeat_latency
=
(
std
::
max
(
params
.
non_zero_shape2
-
1
,
0
)
/
MAX_REPEAT
)
*
params
.
non_zero_shape3
*
shape1
*
repeat_latency_coef
;
float
offset_latency
=
shape1
*
params
.
non_zero_shape3
>
1
?
shape1
*
params
.
non_zero_shape3
*
offset_latency_coef
:
0
;
float
split_latency
=
tail_factor
*
split_latency_coef
;
float
rate2d
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
+
split_latency
);
return
rate2d
;
}
float
BinaryVecPatternGenerator
::
Compute1DPatternMaskRate
()
{
int
tail_factor
=
0
;
if
(
params
.
non_zero_shape1
/
params
.
vec_max_len
>
0
&&
params
.
non_zero_shape1
%
params
.
vec_max_len
>
0
)
{
tail_factor
=
1
;
}
int
shape1
=
(
params
.
non_zero_shape1
+
params
.
vec_max_len
-
1
)
/
params
.
vec_max_len
;
int
repeat_num
=
shape1
*
params
.
non_zero_shape2
*
params
.
non_zero_shape3
;
float
repeat_latency
=
std
::
max
((
shape1
-
1
)
/
MAX_REPEAT
,
0
)
*
params
.
non_zero_shape2
*
params
.
non_zero_shape3
*
repeat_latency_coef
;
float
offset_latency
=
params
.
non_zero_shape2
*
params
.
non_zero_shape3
>
1
?
params
.
non_zero_shape2
*
params
.
non_zero_shape3
*
offset_latency_coef
:
0
;
float
split_latency
=
tail_factor
*
split_latency_coef
;
float
rate1d
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
+
split_latency
);
return
rate1d
;
}
Array
<
Var
>
BinaryVecPatternGenerator
::
Get3DPattern
()
{
// repeat axis is shape [-2]
int
second_last_shape
=
GetInt32Const
(
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape0
,
-
2
),
GetItem
(
params
.
src_shape1
,
-
2
)));
int
third_last_shape
=
GetInt32Const
(
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
3
),
GetItem
(
params
.
src_shape0
,
-
3
),
GetItem
(
params
.
src_shape1
,
-
3
)));
if
(
second_last_shape
>
8
)
{
// split shape[-3]
if
(
third_last_shape
>
8
)
{
auto
info_list
=
src_info_list
;
Insert
(
info_list
,
0
,
dst_info
);
SplitAxis
(
info_list
,
for_info
,
GetItem
(
params
.
dst_var
,
-
3
),
FULL_BLOCK_NUM
);
FillEmptyVar
(
info_list
);
params
.
dst_var
=
info_list
[
0
]
->
var_
;
params
.
dst_shape
=
info_list
[
0
]
->
shape_
;
params
.
dst_strides
=
info_list
[
0
]
->
strides_
;
params
.
src_var0
=
info_list
[
1
]
->
var_
;
params
.
src_shape0
=
info_list
[
1
]
->
shape_
;
params
.
src_strides0
=
info_list
[
1
]
->
strides_
;
params
.
src_var1
=
info_list
[
2
]
->
var_
;
params
.
src_shape1
=
info_list
[
2
]
->
shape_
;
params
.
src_strides1
=
info_list
[
2
]
->
strides_
;
}
// consider original shape[-2] as repeat axis
GetShapeInfoAndSwap
(
params
.
dst_var
,
params
.
dst_shape
,
params
.
dst_strides
,
-
2
,
-
3
);
GetShapeInfoAndSwap
(
params
.
src_var0
,
params
.
src_shape0
,
params
.
src_strides0
,
-
2
,
-
3
);
GetShapeInfoAndSwap
(
params
.
src_var1
,
params
.
src_shape1
,
params
.
src_strides1
,
-
2
,
-
3
);
}
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
CHECK
(
body_args
.
GetNode
());
body_args
.
GetNode
()
->
body_num_
=
1
;
body_args
.
GetNode
()
->
repeat_
=
GetItem
(
params
.
dst_shape
,
-
3
);
body_args
.
GetNode
()
->
dst_stride_m0_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
block_size
);
body_args
.
GetNode
()
->
dst_stride_m1_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
3
),
params
.
block_size
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides0
,
-
2
),
params
.
block_size
),
truncdiv
(
GetItem
(
params
.
src_strides1
,
-
2
),
params
.
block_size
)};
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides0
,
-
3
),
params
.
block_size
),
truncdiv
(
GetItem
(
params
.
src_strides1
,
-
3
),
params
.
block_size
)};
int
data_num
=
GetInt32Const
(
GetItem
(
params
.
dst_shape
,
-
2
));
if
(
mode
==
"reduction"
)
{
body_args
.
GetNode
()
->
repeat_
=
Expr
(
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
3
),
GetItem
(
params
.
src_shape0
,
-
3
),
GetItem
(
params
.
src_shape1
,
-
3
)));
data_num
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape0
,
-
2
),
GetItem
(
params
.
src_shape1
,
-
2
));
}
int
data_len
=
expand_mask
?
CeilTo
(
params
.
last_dim_shape
,
params
.
block_size
)
:
params
.
last_dim_shape
;
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
data_num
,
dst_info
->
dtype_
);
return
GetRange
(
params
.
dst_var
,
-
3
,
3
);
}
Array
<
Var
>
BinaryVecPatternGenerator
::
Get2DBlockPattern
()
{
int
repeat_len
=
GetInt32Const
(
GetItem
(
params
.
dst_shape
,
-
2
));
if
(
mode
==
"reduction"
)
{
params
.
last_dim_shape
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
1
),
GetItem
(
params
.
src_shape0
,
-
1
),
GetItem
(
params
.
src_shape1
,
-
1
));
repeat_len
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape0
,
-
2
),
GetItem
(
params
.
src_shape1
,
-
2
));
}
int
repeat_body
=
repeat_len
/
FULL_BLOCK_NUM
;
int
repeat_tail
=
(
repeat_len
%
FULL_BLOCK_NUM
+
FULL_BLOCK_NUM
-
1
)
/
FULL_BLOCK_NUM
;
if
(
repeat_body
>
0
)
{
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
CHECK
(
body_args
.
GetNode
()
!=
nullptr
);
body_args
.
GetNode
()
->
body_num_
=
1
;
body_args
.
GetNode
()
->
repeat_
=
Expr
(
repeat_body
);
body_args
.
GetNode
()
->
dst_stride_m0_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
block_size
);
body_args
.
GetNode
()
->
dst_stride_m1_
=
body_args
->
dst_stride_m0_
*
FULL_BLOCK_NUM
;
Expr
src0_stride_m0
=
truncdiv
(
GetItem
(
params
.
src_strides0
,
-
2
),
params
.
block_size
);
Expr
src1_stride_m0
=
truncdiv
(
GetItem
(
params
.
src_strides1
,
-
2
),
params
.
block_size
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
src0_stride_m0
,
src1_stride_m0
};
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
src0_stride_m0
*
FULL_BLOCK_NUM
,
src1_stride_m0
*
FULL_BLOCK_NUM
};
int
data_len
=
expand_mask
?
CeilTo
(
params
.
last_dim_shape
,
params
.
block_size
)
:
params
.
last_dim_shape
;
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
FULL_BLOCK_NUM
,
dst_info
->
dtype_
);
}
if
(
repeat_tail
>
0
)
{
tail_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
CHECK
(
tail_args
.
GetNode
()
!=
nullptr
);
tail_args
.
GetNode
()
->
dst_head_
=
GetItem
(
params
.
dst_strides
,
-
2
)
*
repeat_body
*
FULL_BLOCK_NUM
;
tail_args
.
GetNode
()
->
src_head_list_
=
{
GetItem
(
params
.
src_strides0
,
-
2
)
*
repeat_body
*
FULL_BLOCK_NUM
,
GetItem
(
params
.
src_strides1
,
-
2
)
*
repeat_body
*
FULL_BLOCK_NUM
};
tail_args
.
GetNode
()
->
repeat_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
dst_stride_m0_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
block_size
);
tail_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
0
);
tail_args
.
GetNode
()
->
src_stride_m0_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides0
,
-
2
),
params
.
block_size
),
truncdiv
(
GetItem
(
params
.
src_strides1
,
-
2
),
params
.
block_size
)};
tail_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
0
),
Expr
(
0
)};
int
data_len
=
expand_mask
?
CeilTo
(
params
.
last_dim_shape
,
params
.
block_size
)
:
params
.
last_dim_shape
;
tail_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
repeat_len
%
FULL_BLOCK_NUM
,
dst_info
->
dtype_
);
}
return
GetRange
(
params
.
dst_var
,
-
2
,
2
);
}
Array
<
Var
>
BinaryVecPatternGenerator
::
Get2DPattern
()
{
if
(
mode
==
"reduction"
)
{
params
.
last_dim_shape
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
1
),
GetItem
(
params
.
src_shape0
,
-
1
),
GetItem
(
params
.
src_shape1
,
-
1
));
}
int
body_len
=
FloorTo
(
params
.
last_dim_shape
,
params
.
vec_max_len
);
int
tail_len
=
params
.
last_dim_shape
%
params
.
vec_max_len
;
if
(
body_len
>
0
)
{
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
CHECK
(
body_args
.
GetNode
()
!=
nullptr
);
body_args
.
GetNode
()
->
body_num_
=
body_len
/
params
.
vec_max_len
;
body_args
.
GetNode
()
->
body_offset_
=
params
.
vec_max_len
;
body_args
.
GetNode
()
->
repeat_
=
GetItem
(
params
.
dst_shape
,
-
2
);
if
(
mode
==
"reduction"
)
{
body_args
.
GetNode
()
->
repeat_
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape0
,
-
2
),
GetItem
(
params
.
src_shape1
,
-
2
));
}
body_args
.
GetNode
()
->
dst_stride_m0_
=
Expr
(
1
);
body_args
.
GetNode
()
->
dst_stride_m1_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
block_size
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
),
Expr
(
1
)};
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides0
,
-
2
),
params
.
block_size
),
truncdiv
(
GetItem
(
params
.
src_strides1
,
-
2
),
params
.
block_size
)};
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
params
.
vec_max_len
,
1
,
dst_info
->
dtype_
);
}
if
(
tail_len
>
0
)
{
tail_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
CHECK
(
tail_args
.
GetNode
()
!=
nullptr
);
tail_args
.
GetNode
()
->
dst_head_
=
Expr
(
body_len
);
tail_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
body_len
),
Expr
(
body_len
)};
tail_args
.
GetNode
()
->
repeat_
=
GetItem
(
params
.
dst_shape
,
-
2
);
if
(
mode
==
"reduction"
)
{
tail_args
.
GetNode
()
->
repeat_
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape0
,
-
2
),
GetItem
(
params
.
src_shape1
,
-
2
));
}
tail_args
.
GetNode
()
->
dst_stride_m0_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
dst_stride_m1_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
block_size
);
tail_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
),
Expr
(
1
)};
tail_args
.
GetNode
()
->
src_stride_m1_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides0
,
-
2
),
params
.
block_size
),
truncdiv
(
GetItem
(
params
.
src_strides1
,
-
2
),
params
.
block_size
)};
tail_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
tail_len
,
1
,
dst_info
->
dtype_
);
}
return
GetRange
(
params
.
dst_var
,
-
2
,
2
);
}
Array
<
Var
>
BinaryVecPatternGenerator
::
Get1DPattern
()
{
auto
info_list
=
src_info_list
;
Insert
(
info_list
,
0
,
dst_info
);
bool
is_scalar_mode
=
IsScalarMode
(
info_list
);
if
(
is_scalar_mode
)
{
params
.
last_dim_shape
=
1
;
}
if
(
mode
==
"reduction"
)
{
params
.
last_dim_shape
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
1
),
GetItem
(
params
.
src_shape0
,
-
1
),
GetItem
(
params
.
src_shape1
,
-
1
));
}
int
body_len
=
FloorTo
(
params
.
last_dim_shape
,
params
.
vec_max_len
);
int
tail_len
=
params
.
last_dim_shape
%
params
.
vec_max_len
;
int
last_axis
=
-
1
;
if
(
mode
==
"broadcast"
)
{
if
(
GetIntConst
(
GetItem
(
params
.
src_strides0
,
-
1
))
==
0
)
{
last_axis
=
0
;
}
if
(
GetIntConst
(
GetItem
(
params
.
src_strides1
,
-
1
))
==
0
)
{
last_axis
=
1
;
}
}
if
(
body_len
>
0
)
{
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
CHECK
(
body_args
.
GetNode
()
!=
nullptr
);
body_args
.
GetNode
()
->
last_axis_info_
.
src_index_
=
last_axis
;
body_args
.
GetNode
()
->
body_num_
=
1
;
body_args
.
GetNode
()
->
repeat_
=
body_len
/
params
.
vec_max_len
;
body_args
.
GetNode
()
->
dst_stride_m0_
=
Expr
(
1
);
body_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
FULL_BLOCK_NUM
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
),
Expr
(
1
)};
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
FULL_BLOCK_NUM
),
Expr
(
FULL_BLOCK_NUM
)};
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
params
.
vec_max_len
,
1
,
dst_info
->
dtype_
);
}
if
(
tail_len
>
0
)
{
tail_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
CHECK
(
tail_args
.
GetNode
()
!=
nullptr
);
tail_args
.
GetNode
()
->
last_axis_info_
.
src_index_
=
last_axis
;
tail_args
.
GetNode
()
->
dst_head_
=
Expr
(
body_len
);
tail_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
body_len
),
Expr
(
body_len
)};
tail_args
.
GetNode
()
->
repeat_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
dst_stride_m0_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
0
);
tail_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
),
Expr
(
1
)};
tail_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
0
),
Expr
(
0
)};
int
data_len
=
expand_mask
?
CeilTo
(
tail_len
,
params
.
block_size
)
:
tail_len
;
tail_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
1
,
dst_info
->
dtype_
);
}
// compute offset for cce instructions
Array
<
Var
>
elim_var
=
{};
if
(
mode
==
"elewise"
&&
params
.
dst_var
.
size
()
>=
2
&&
params
.
dst_strides
.
size
()
>=
2
&&
params
.
last_dim_shape
<=
params
.
vec_max_len
&&
for_info
.
ops_
.
size
()
>=
2
&&
params
.
last_dim_shape
>=
params
.
vec_max_len
-
params
.
block_size
&&
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
params
.
last_dim_shape
)
{
// in this case we can merge second last for extent to repeat
size_t
idx
=
0
;
bool
suc
=
GetIndexOfElement
(
for_info
.
vars_
,
GetItem
(
params
.
dst_var
,
-
2
),
idx
);
CHECK
(
suc
);
auto
latest_for
=
GetItem
(
for_info
.
ops_
,
idx
).
as
<
For
>
();
// there should not be if_op between for loop and compute stmt
if
(
latest_for
&&
!
latest_for
->
body
->
IsInstance
<
IfThenElse
>
())
{
if
(
!
params
.
dst_var
.
empty
()
&&
!
is_scalar_mode
)
{
if
(
body_args
.
defined
())
{
// last_dim_shape = vec_max_len
body_args
.
GetNode
()
->
repeat_
=
body_args
->
repeat_
*
latest_for
->
extent
;
}
else
if
(
tail_args
.
defined
())
{
// last_dim_shape < vec_max_len
tail_args
.
GetNode
()
->
repeat_
=
tail_args
->
repeat_
*
latest_for
->
extent
;
}
return
elim_var
=
GetRange
(
params
.
dst_var
,
-
2
,
2
);
}
}
}
if
(
!
params
.
dst_var
.
empty
()
&&
!
is_scalar_mode
)
{
elim_var
=
GetRange
(
params
.
dst_var
,
-
1
,
1
);
}
return
elim_var
;
}
PatternResult
BinaryVecPatternGenerator
::
GenResult
(
const
Array
<
Var
>
&
elim_var
)
{
arg_info
.
GetNode
()
->
body_arg_info_
=
body_args
;
arg_info
.
GetNode
()
->
tail_arg_info_
=
tail_args
;
auto
real_elim_var
=
elim_var
;
if
(
!
empty_var
->
name_hint
.
empty
())
{
bool
need_elim
=
true
;
for
(
auto
e
:
elim_var
)
{
if
(
e
->
name_hint
==
empty_var
->
name_hint
)
{
need_elim
=
false
;
break
;
}
}
if
(
need_elim
)
{
real_elim_var
.
push_back
(
empty_var
);
}
}
dst_info
.
GetNode
()
->
insn_offset_
=
GetInsnOffset
(
dst_info
,
real_elim_var
);
for
(
auto
&
info
:
src_info_list
)
{
info
.
GetNode
()
->
insn_offset_
=
GetInsnOffset
(
info
,
real_elim_var
);
}
CleanForInfoVars
(
for_info
,
real_elim_var
);
CleanZeroStrides
(
dst_info
);
CleanZeroStrides
(
src_info_list
);
if
(
mode
==
"elewise"
)
{
arg_info
.
GetNode
()
->
arg_type_
=
ARG_VECTOR_ELEWISE
;
}
else
if
(
mode
==
"broadcast"
)
{
arg_info
.
GetNode
()
->
arg_type_
=
ARG_VECTOR_BROADCAST
;
}
else
if
(
mode
==
"reduction"
)
{
arg_info
.
GetNode
()
->
arg_type_
=
ARG_VECTOR_REDUCTION
;
}
PatternResult
result
;
result
.
dst_info_list
=
{
dst_info
};
result
.
src_info_list
=
src_info_list
;
result
.
for_info
=
for_info
;
result
.
arg_info
=
arg_info
;
return
result
;
}
void
BinaryVecPatternGenerator
::
CalcParams
()
{
CHECK_GE
(
src_info_list
.
size
(),
2
);
StmtStoreInfo
src_info0
=
src_info_list
[
0
];
StmtStoreInfo
src_info1
=
src_info_list
[
1
];
StmtInfoList
info_list
=
{
dst_info
,
src_info0
,
src_info1
};
// check shape len
for
(
auto
info
:
info_list
)
{
if
(
info
->
shape_
.
empty
())
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."
;
}
}
// check data type
for
(
auto
src_info
:
src_info_list
)
{
if
(
dst_info
->
dtype_
!=
src_info
->
dtype_
)
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type."
;
}
}
params
.
last_dim_shape
=
GetInt32Const
(
GetItem
(
dst_info
->
shape_
,
-
1
));
AppendEmptyVar
(
info_list
);
if
(
arg_info
->
arg_type_
==
ARG_VECTOR_BROADCAST_LAST_AXIS
)
{
return
;
}
if
(
mode
==
"reduction"
||
mode
==
"broadcast"
)
{
FillEmptyVar
(
info_list
);
}
CHECK_EQ
(
info_list
.
size
(),
3
);
dst_info
=
info_list
[
0
];
src_info0
=
info_list
[
1
];
src_info1
=
info_list
[
2
];
params
.
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
params
.
block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
CHECK_NE
(
params
.
vec_max_len
,
0
);
CHECK_NE
(
params
.
block_size
,
0
);
params
.
dst_var
=
dst_info
->
var_
;
params
.
dst_shape
=
dst_info
->
shape_
;
params
.
dst_strides
=
dst_info
->
strides_
;
params
.
src_var0
=
src_info0
->
var_
;
params
.
src_var1
=
src_info1
->
var_
;
params
.
src_shape0
=
src_info0
->
shape_
;
params
.
src_shape1
=
src_info1
->
shape_
;
params
.
src_strides0
=
src_info0
->
strides_
;
params
.
src_strides1
=
src_info1
->
strides_
;
auto
GetNonZeroShapeByIdx
=
[
this
](
int
index
)
->
int
{
if
(
index
<=
static_cast
<
int
>
(
params
.
dst_var
.
size
()))
{
if
(
Equal
(
GetItem
(
params
.
dst_var
,
-
index
),
GetItem
(
params
.
src_var0
,
-
index
))
&&
Equal
(
GetItem
(
params
.
dst_var
,
-
index
),
GetItem
(
params
.
src_var1
,
-
index
)))
{
return
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
index
),
GetItem
(
params
.
src_shape0
,
-
index
),
GetItem
(
params
.
src_shape1
,
-
index
));
}
}
return
1
;
};
params
.
non_zero_shape1
=
GetNonZeroShapeByIdx
(
1
);
params
.
non_zero_shape2
=
GetNonZeroShapeByIdx
(
2
);
params
.
non_zero_shape3
=
GetNonZeroShapeByIdx
(
3
);
params
.
all_points
=
params
.
non_zero_shape1
*
params
.
non_zero_shape2
*
params
.
non_zero_shape3
;
}
bool
BinaryVecPatternGenerator
::
IsSamePatternComInfo
(
const
StmtStoreInfo
&
info_a
,
const
StmtStoreInfo
&
info_b
)
{
if
(
IsSame
(
info_a
->
var_
,
info_b
->
var_
))
{
if
(
info_a
->
shape_
.
size
()
!=
info_b
->
shape_
.
size
())
{
return
false
;
}
for
(
size_t
i
=
0
;
i
<
info_a
->
shape_
.
size
();
++
i
)
{
if
(
!
IsTwoItemEqual
(
info_a
->
shape_
,
info_b
->
shape_
,
static_cast
<
int
>
(
i
),
true
))
{
return
false
;
}
}
if
(
info_a
->
strides_
.
size
()
!=
info_b
->
strides_
.
size
())
{
return
false
;
}
for
(
size_t
i
=
0
;
i
<
info_a
->
strides_
.
size
();
++
i
)
{
if
(
!
IsTwoItemEqual
(
info_a
->
strides_
,
info_b
->
strides_
,
static_cast
<
int
>
(
i
),
true
))
{
return
false
;
}
}
return
true
;
}
return
false
;
}
bool
BinaryVecPatternGenerator
::
IsNonZeroShapeEqual
(
const
Array
<
Expr
>
&
shape_list
)
{
Array
<
Expr
>
non_zero_list
;
for
(
auto
shape
:
shape_list
)
{
if
(
GetIntConst
(
shape
)
!=
0
)
{
non_zero_list
.
push_back
(
shape
);
}
}
if
(
non_zero_list
.
empty
())
{
LOG
(
FATAL
)
<<
"Error: all shapes are equal to 0."
;
}
for
(
auto
shape
:
non_zero_list
)
{
if
(
GetIntConst
(
shape
)
!=
GetIntConst
(
non_zero_list
[
0
]))
{
return
false
;
}
}
return
true
;
}
void
BinaryVecPatternGenerator
::
AppendEmptyVar
(
StmtInfoList
&
info_list
)
{
auto
FillEmptyVarToLast
=
[](
const
StmtStoreInfo
com_info
,
const
Var
&
var
)
->
void
{
com_info
.
GetNode
()
->
var_
.
push_back
(
var
);
com_info
.
GetNode
()
->
shape_
.
push_back
(
Expr
(
1
));
com_info
.
GetNode
()
->
strides_
.
push_back
(
Expr
(
1
));
com_info
.
GetNode
()
->
index_
=
com_info
->
index_
+
GetItem
(
com_info
->
var_
,
-
1
);
};
auto
src_info0
=
src_info_list
[
0
];
auto
src_info1
=
src_info_list
[
1
];
if
(
mode
==
"reduction"
||
mode
==
"broadcast"
)
{
// ISA 8.1.2, strides of Xd must be equal to Xm, [Xd = dst, Xn = src0, Xm = src1]
if
(
IsSamePatternComInfo
(
dst_info
,
src_info0
))
{
auto
tmp
=
src_info0
;
src_info0
=
src_info1
;
src_info1
=
tmp
;
}
if
(
mode
==
"reduction"
)
{
if
(
src_info0
->
data_alignment_
==
1
)
{
empty_var
=
Var
(
"empty_cc"
);
FillEmptyVarToLast
(
src_info0
,
empty_var
);
}
}
else
if
(
mode
==
"broadcast"
)
{
// last dim broadcast, should use VS insn, such as vadds and vmuls
bool
less_var
=
!
dst_info
->
var_
.
empty
()
&&
!
src_info0
->
var_
.
empty
()
&&
!
src_info1
->
var_
.
empty
()
&&
(
!
IsTwoItemEqual
(
dst_info
->
var_
,
src_info0
->
var_
,
-
1
)
||
!
IsTwoItemEqual
(
dst_info
->
var_
,
src_info1
->
var_
,
-
1
));
bool
null_var
=
src_info0
->
var_
.
empty
()
||
src_info1
->
var_
.
empty
();
if
(
less_var
||
null_var
)
{
arg_info
.
GetNode
()
->
arg_type_
=
ARG_VECTOR_BROADCAST_LAST_AXIS
;
return
;
}
else
if
(
dst_info
->
data_alignment_
==
1
&&
src_info0
->
data_alignment_
==
1
)
{
empty_var
=
Var
(
"empty_cc"
);
FillEmptyVarToLast
(
dst_info
,
empty_var
);
FillEmptyVarToLast
(
src_info0
,
empty_var
);
FillEmptyVarToLast
(
src_info1
,
empty_var
);
params
.
last_dim_shape
=
1
;
}
}
src_info_list
=
{
src_info0
,
src_info1
};
info_list
=
{
dst_info
,
src_info0
,
src_info1
};
}
}
/// Get CCE Binary Vector Insn Computation Info
/// \param stmt - operand stmt
/// \param intrin_name - vector intrin name
/// \param dst_info_list - dst computation info list
/// \param src_info_list - src computation info list
/// \param if_info - if info list
/// \param for_info - for info list
/// \return intrin args
ArgInfo
GetBinaryVecInsnArgs
(
const
Stmt
&
stmt
,
std
::
string
intrin_name
,
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
enable_bisect
)
{
// check intrin_name
std
::
set
<
std
::
string
>
intrin_name_list
=
{
"vadd"
,
"vmax"
,
"vmin"
,
"vmul"
,
"vdiv"
,
"vsel"
,
"vsub"
,
"vand"
,
"vor"
,
"vaxpy"
,
"argmax"
,
"argmin"
,
"vmadd"
,
"vmaddrelu"
,
"vmla"
};
if
(
intrin_name_list
.
count
(
intrin_name
)
==
0
)
{
LOG
(
FATAL
)
<<
"Error: CCE Binary Vector Insn doesn't support the given intrin_name."
;
}
// get and check dst and src
GetCompactComputationInfo
(
stmt
,
dst_info_list
,
src_info_list
,
if_info
,
for_info
,
true
);
// For vmadd/vmaddrelu/vmla we only need first two src
if
(
dst_info_list
.
size
()
!=
1
||
src_info_list
.
size
()
<
2
)
{
LOG
(
FATAL
)
<<
"CCE Binary Vector Insn only support ONE dst and TWO srcs."
;
}
src_info_list
=
GetRange
(
src_info_list
,
0
,
2
);
ArgInfo
arg_info
=
ArgInfo
(
make_node
<
ArgInfoNode
>
());
// detect vector op mode
std
::
string
mode
=
GetBinaryVecMode
(
dst_info_list
,
src_info_list
,
intrin_name
,
enable_bisect
);
if
(
mode
==
"reduce_last_axis"
)
{
size_t
src_var_list_size
=
src_info_list
[
1
]
->
var_
.
size
();
if
(
src_info_list
[
0
]
->
var_
.
size
()
>
src_info_list
[
1
]
->
var_
.
size
())
{
src_var_list_size
=
src_info_list
[
0
]
->
var_
.
size
();
}
CHECK
(
src_var_list_size
>
0
)
<<
"Error: src can not be a scalar."
;
if
(
src_var_list_size
-
dst_info_list
[
0
]
->
var_
.
size
()
==
1
)
{
arg_info
.
GetNode
()
->
arg_type_
=
ARG_VECTOR_REDUCTION_LAST_AXIS
;
}
else
{
LOG
(
FATAL
)
<<
"Error: cannot support multi-last-axis reduction."
;
}
}
else
if
(
mode
==
"reduce_bisection"
)
{
arg_info
.
GetNode
()
->
arg_type_
=
ARG_VECTOR_REDUCTION_BISECTION
;
}
else
{
if
(
mode
!=
"reduction"
&&
mode
!=
"broadcast"
)
{
FillLastDim
(
dst_info_list
,
src_info_list
,
for_info
);
}
// vmax/vmin can't expand mask because it may introduce dirty data
bool
can_expand_mask
=
intrin_name
!=
"vmax"
&&
intrin_name
!=
"vmin"
;
BinaryVecPatternGenerator
generator
=
BinaryVecPatternGenerator
(
dst_info_list
,
src_info_list
,
for_info
,
mode
,
can_expand_mask
);
auto
params
=
generator
.
GetInsnArgs
();
arg_info
=
params
.
arg_info
;
dst_info_list
=
params
.
dst_info_list
;
src_info_list
=
params
.
src_info_list
;
for_info
=
params
.
for_info
;
if
(
mode
==
"broadcast"
)
{
bool
has_last_axis
=
false
;
if
((
arg_info
->
body_arg_info_
.
defined
()
&&
arg_info
->
body_arg_info_
->
last_axis_info_
.
src_index_
!=
-
1
)
||
(
arg_info
->
tail_arg_info_
.
defined
()
&&
arg_info
->
tail_arg_info_
->
last_axis_info_
.
src_index_
!=
-
1
))
{
has_last_axis
=
true
;
}
if
(
has_last_axis
&&
(
intrin_name
==
"vadd"
||
intrin_name
==
"vmul"
))
{
Array
<
NodeRef
>
stores
;
Array
<
NodeRef
>
loads
;
GetStoreAndLoads
(
stmt
,
stores
,
loads
);
intrin_name
=
intrin_name
+
"s"
;
if
(
arg_info
->
body_arg_info_
.
defined
())
{
arg_info
.
GetNode
()
->
body_arg_info_
.
GetNode
()
->
last_axis_info_
.
intrin_name_
=
intrin_name
;
arg_info
.
GetNode
()
->
body_arg_info_
.
GetNode
()
->
last_axis_info_
.
src_op_
=
Downcast
<
Expr
>
(
loads
[
arg_info
->
body_arg_info_
->
last_axis_info_
.
src_index_
]);
}
}
}
}
return
arg_info
;
}
/// Replace com_info's var with new for loop's var
/// \param info
/// \param old_for_info
/// \param new_for_info
void
ReplaceVarWithNewForInfo
(
StmtStoreInfo
&
info
,
const
StmtInfo
&
old_for_info
,
const
StmtInfo
&
new_for_info
)
{
for
(
size_t
i
=
0
;
i
<
new_for_info
.
vars_
.
size
();
++
i
)
{
for
(
size_t
j
=
0
;
j
<
info
->
var_
.
size
();
++
j
)
{
if
(
info
->
var_
[
j
]
->
name_hint
==
new_for_info
.
vars_
[
i
]
->
name_hint
)
{
SetItem
(
info
.
GetNode
()
->
var_
,
static_cast
<
int
>
(
j
),
new_for_info
.
vars_
[
i
]);
}
}
info
.
GetNode
()
->
index_
=
substitute
(
old_for_info
.
vars_
[
i
],
new_for_info
.
vars_
[
i
],
info
->
index_
);
}
}
/// Generete info list for bisection intrin
/// \param dst_info_list
/// \param src_info_list
/// \param for_info
/// \param if_info
/// \param last_axis
/// \param postfix
/// \return
BisectionInfoWrapper
SeparateComInfoToBisectionInfoList
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
StmtInfo
&
if_info
,
bool
last_axis
,
int
postfix
=
0
)
{
CHECK_EQ
(
dst_info_list
.
size
(),
1
);
CHECK_EQ
(
src_info_list
.
size
(),
2
);
BisectionInfoWrapper
wrapper
;
// Separate com_info and for_info
int
compare_idx
=
1
;
int
var_idx
=
-
1
;
if
(
last_axis
)
{
compare_idx
=
GetLastAxisReductionIdx
(
dst_info_list
,
src_info_list
);
}
else
{
var_idx
=
GetBisectionReductionIdx
(
dst_info_list
,
src_info_list
,
compare_idx
);
}
StmtStoreInfo
dst_info
=
dst_info_list
[
0
];
CHECK_GE
(
compare_idx
,
0
);
StmtStoreInfo
src_info1
=
src_info_list
[
compare_idx
];
Var
reduce_var
=
GetItem
(
src_info1
->
var_
,
var_idx
);
size_t
for_idx
=
0
;
bool
suc
=
GetIndexOfElement
(
for_info
.
vars_
,
VarExpr
(
reduce_var
),
for_idx
);
CHECK
(
suc
);
auto
exist_for
=
GetItem
(
for_info
.
ops_
,
for_idx
).
as
<
For
>
();
CHECK
(
exist_for
);
int
extent
=
GetInt32Const
(
exist_for
->
extent
);
std
::
string
prev_name
=
src_info1
->
name_
;
Var
prev_var
=
src_info1
->
data_
;
Buffer
prev_buffer
=
src_info1
->
buffer_
;
Var
bisec_var
;
Buffer
bisec_buffer
;
std
::
string
bisec_pre_header
=
last_axis
?
"bisec_last_axis"
:
"bisec"
;
std
::
string
bisec_name
=
bisec_pre_header
+
"_local_UB"
;
if
(
postfix
>
0
)
{
bisec_name
=
bisec_name
+
"_"
+
std
::
to_string
(
postfix
);
}
bool
first_round
=
true
;
int
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
int
remain_extent
=
extent
;
int
left_extent
=
0
;
CHECK_NE
(
vec_max_len
,
0
);
std
::
vector
<
int
>
pow2_list
=
{
0
,
1
,
2
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
,
32768
,
65536
};
while
(
extent
>
0
)
{
int
for_extent
;
if
(
last_axis
)
{
left_extent
=
remain_extent
/
2
+
remain_extent
%
2
;
for
(
int
i
:
pow2_list
)
{
if
(
left_extent
==
i
)
{
break
;
}
else
if
(
left_extent
<
i
)
{
left_extent
=
i
;
break
;
}
}
if
(
left_extent
<
vec_max_len
)
{
// When left_extent < vec_max_len, stop bisect and generate normal reduce intrin
left_extent
=
remain_extent
;
}
extent
=
remain_extent
-
left_extent
;
remain_extent
=
left_extent
;
for_extent
=
extent
==
0
?
vec_max_len
:
extent
;
}
else
{
for_extent
=
extent
==
1
?
extent
:
extent
/
2
;
extent
=
extent
%
2
==
0
||
extent
==
1
?
extent
/
2
:
(
extent
+
1
)
/
2
;
for
(
int
i
:
pow2_list
)
{
if
(
extent
==
i
)
{
break
;
}
else
if
(
extent
<
i
)
{
int
gap
=
i
-
extent
;
extent
=
i
;
for_extent
-=
gap
;
break
;
}
}
}
StmtStoreInfo
dst_tmp_info
=
dst_info
.
Copy
();
StmtStoreInfo
src_tmp_info0
{
src_info1
.
Copy
()};
StmtStoreInfo
src_tmp_info1
{
src_info1
.
Copy
()};
if
(
first_round
)
{
auto
shape
=
src_tmp_info1
->
shape_
;
if
(
last_axis
)
{
int
block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
SetItem
(
shape
,
-
1
,
Expr
(
CeilTo
(
GetIntConst
(
GetItem
(
shape
,
-
1
)),
block_size
)));
}
wrapper
.
original_shape_
=
shape
;
bisec_var
=
Var
(
bisec_name
,
Handle
());
bisec_buffer
=
BufferNode
::
make
(
bisec_var
,
dst_tmp_info
->
dtype_
,
shape
,
Array
<
Expr
>
(),
Expr
(),
bisec_name
,
SCOPE_UBUF
,
0
,
0
,
BufferType
::
kDefault
);
if
((
last_axis
&&
extent
!=
left_extent
)
||
(
!
last_axis
&&
extent
!=
for_extent
))
{
// Need to copy input to bisect buffer
StmtStoreInfo
copy_dst_info
{
src_info1
.
Copy
()};
StmtStoreInfo
copy_src_info
{
src_info1
.
Copy
()};
StmtInfoList
src_list
=
{
copy_src_info
};
auto
for_tmp_info
=
for_info
.
Copy
();
auto
new_for
=
GetItem
(
for_tmp_info
.
ops_
,
for_idx
).
as
<
For
>
();
CHECK
(
new_for
);
SetItem
(
for_tmp_info
.
ops_
,
static_cast
<
int
>
(
for_idx
),
For
::
make
(
new_for
->
loop_var
,
new_for
->
min
,
last_axis
?
left_extent
:
extent
,
new_for
->
for_type
,
new_for
->
device_api
,
new_for
->
body
));
ReplaceVarWithNewForInfo
(
copy_dst_info
,
for_info
,
for_tmp_info
);
ReplaceVarWithNewForInfo
(
copy_src_info
,
for_info
,
for_tmp_info
);
SetItem
(
copy_dst_info
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
last_axis
?
left_extent
:
extent
));
SetItem
(
copy_src_info
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
last_axis
?
left_extent
:
extent
));
CompactComputationInfoList
(
copy_dst_info
,
src_list
,
if_info
,
for_tmp_info
);
copy_dst_info
.
GetNode
()
->
name_
=
bisec_name
;
copy_dst_info
.
GetNode
()
->
buffer_
=
bisec_buffer
;
copy_dst_info
.
GetNode
()
->
data_
=
bisec_var
;
// Replace outside for variable in index
auto
vars
=
GetVarsInExpr
(
copy_dst_info
->
index_
);
for
(
auto
var
:
vars
)
{
if
(
!
IsInArray
(
copy_dst_info
->
var_
,
var
))
{
copy_dst_info
.
GetNode
()
->
index_
=
substitute
(
var
,
Expr
(
0
),
copy_dst_info
->
index_
);
}
}
wrapper
.
bisec_info_list_
.
emplace_back
(
StmtInfoList
{
copy_dst_info
,
copy_src_info
});
wrapper
.
for_info_list_
.
push_back
(
for_tmp_info
);
}
}
auto
for_tmp_info
=
for_info
.
Copy
();
auto
new_for
=
GetItem
(
for_tmp_info
.
ops_
,
for_idx
).
as
<
For
>
();
CHECK
(
new_for
);
SetItem
(
for_tmp_info
.
ops_
,
static_cast
<
int
>
(
for_idx
),
For
::
make
(
new_for
->
loop_var
,
new_for
->
min
,
for_extent
,
new_for
->
for_type
,
new_for
->
device_api
,
new_for
->
body
));
ReplaceVarWithNewForInfo
(
dst_tmp_info
,
for_info
,
for_tmp_info
);
ReplaceVarWithNewForInfo
(
src_tmp_info0
,
for_info
,
for_tmp_info
);
ReplaceVarWithNewForInfo
(
src_tmp_info1
,
for_info
,
for_tmp_info
);
SetItem
(
src_tmp_info0
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
for_extent
));
SetItem
(
src_tmp_info1
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
for_extent
));
if
(
extent
>
0
)
{
dst_tmp_info
.
GetNode
()
->
shape_
=
src_tmp_info1
->
shape_
;
dst_tmp_info
.
GetNode
()
->
strides_
=
src_tmp_info1
->
strides_
;
dst_tmp_info
.
GetNode
()
->
var_
=
src_tmp_info1
->
var_
;
dst_tmp_info
.
GetNode
()
->
index_
=
src_tmp_info1
->
index_
;
dst_tmp_info
.
GetNode
()
->
data_alignment_
=
src_tmp_info1
->
data_alignment_
;
dst_tmp_info
.
GetNode
()
->
name_
=
bisec_name
;
dst_tmp_info
.
GetNode
()
->
buffer_
=
bisec_buffer
;
dst_tmp_info
.
GetNode
()
->
data_
=
bisec_var
;
auto
src_extent
=
Expr
(
left_extent
);
if
(
!
last_axis
)
{
src_extent
=
GetItem
(
src_tmp_info1
->
strides_
,
var_idx
)
*
extent
;
}
src_tmp_info1
.
GetNode
()
->
index_
=
src_tmp_info1
->
index_
+
src_extent
;
}
src_tmp_info0
.
GetNode
()
->
name_
=
prev_name
;
src_tmp_info1
.
GetNode
()
->
name_
=
prev_name
;
src_tmp_info0
.
GetNode
()
->
buffer_
=
prev_buffer
;
src_tmp_info1
.
GetNode
()
->
buffer_
=
prev_buffer
;
src_tmp_info0
.
GetNode
()
->
data_
=
prev_var
;
src_tmp_info1
.
GetNode
()
->
data_
=
prev_var
;
// Replace outside for variable in index
for
(
auto
&
info
:
{
dst_tmp_info
,
src_tmp_info0
,
src_tmp_info1
})
{
if
(
info
->
name_
.
find
(
bisec_pre_header
)
==
std
::
string
::
npos
)
{
continue
;
}
auto
vars
=
GetVarsInExpr
(
info
->
index_
);
for
(
auto
var
:
vars
)
{
if
(
!
IsInArray
(
info
->
var_
,
var
))
{
info
.
GetNode
()
->
index_
=
substitute
(
var
,
Expr
(
0
),
info
->
index_
);
}
}
}
prev_name
=
bisec_name
;
prev_var
=
bisec_var
;
prev_buffer
=
bisec_buffer
;
StmtInfoList
src_list
=
{
src_tmp_info0
,
src_tmp_info1
};
CompactComputationInfoList
(
dst_tmp_info
,
src_list
,
if_info
,
for_tmp_info
);
wrapper
.
for_info_list_
.
emplace_back
(
for_tmp_info
);
if
(
extent
==
0
)
{
// last round should be dst = dst + src_tmp1
wrapper
.
bisec_info_list_
.
emplace_back
(
StmtInfoList
{
dst_tmp_info
,
dst_tmp_info
,
src_tmp_info1
});
}
else
{
// normally is dst_tmp = src_tmp0 + src_tmp1
wrapper
.
bisec_info_list_
.
emplace_back
(
StmtInfoList
{
dst_tmp_info
,
src_tmp_info0
,
src_tmp_info1
});
}
first_round
=
false
;
}
// Generate arg_info
for
(
size_t
i
=
0
;
i
<
wrapper
.
bisec_info_list_
.
size
();
++
i
)
{
auto
info_list
=
wrapper
.
bisec_info_list_
[
i
];
auto
new_for_info
=
wrapper
.
for_info_list_
[
i
];
ArgInfo
arg_info
;
auto
dst_list
=
GetRange
(
info_list
,
0
,
1
);
auto
src_list
=
GetRange
(
info_list
,
1
,
info_list
.
size
()
-
1
);
if
(
info_list
.
size
()
==
2
)
{
std
::
string
dma_intrin
=
INTRIN_NAME_COPY_UB_TO_UB
;
wrapper
.
dma_arg_info_map_
=
GetDmaCopyInsnArgs
(
dma_intrin
,
dst_list
,
src_list
,
new_for_info
);
}
else
if
(
last_axis
&&
i
==
wrapper
.
bisec_info_list_
.
size
()
-
1
)
{
auto
dst_tmp_info
=
dst_list
[
0
];
auto
src_tmp_info
=
src_list
[
1
];
ReduceLastAxisPatternGenerator
generator
=
ReduceLastAxisPatternGenerator
(
dst_tmp_info
,
src_tmp_info
,
new_for_info
,
"vadd"
);
auto
result
=
generator
.
GetInsnArgs
();
arg_info
=
result
.
arg_info
;
dst_tmp_info
=
result
.
dst_info_list
[
0
];
src_tmp_info
=
result
.
src_info_list
[
0
];
new_for_info
=
result
.
for_info
;
wrapper
.
bisec_info_list_
[
i
]
=
{
dst_tmp_info
,
dst_tmp_info
,
src_tmp_info
};
}
else
{
// Bisect can't expand mask because it has inplace operation
if
(
i
!=
wrapper
.
bisec_info_list_
.
size
()
-
1
)
{
// Last round dont need to add
FillLastDim
(
dst_list
,
src_list
,
new_for_info
);
}
std
::
string
mode
=
GetBinaryVecMode
(
dst_list
,
src_list
,
"vadd"
,
false
);
BinaryVecPatternGenerator
generator
=
BinaryVecPatternGenerator
(
dst_list
,
src_list
,
new_for_info
,
mode
,
false
);
auto
params
=
generator
.
GetInsnArgs
();
arg_info
=
params
.
arg_info
;
dst_list
=
params
.
dst_info_list
;
src_list
=
params
.
src_info_list
;
new_for_info
=
params
.
for_info
;
wrapper
.
bisec_info_list_
[
i
]
=
{
dst_list
[
0
],
src_list
[
0
],
src_list
[
1
]};
}
wrapper
.
arg_info_list_
.
push_back
(
arg_info
);
wrapper
.
for_info_list_
[
i
]
=
new_for_info
;
}
return
wrapper
;
}
}
// namespace akg
src/emit_insn/insn_emitter.cc
浏览文件 @
aec205ff
...
@@ -35,7 +35,7 @@
...
@@ -35,7 +35,7 @@
#include "insn_info.h"
#include "insn_info.h"
#include "insn_pattern.h"
#include "insn_pattern.h"
#include "insn_emitter_multimask.h"
#include "insn_emitter_multimask.h"
#include "insn_args_calculator.h"
namespace
akg
{
namespace
akg
{
namespace
ir
{
namespace
ir
{
/// Sort indexes
/// Sort indexes
...
@@ -71,8 +71,7 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
...
@@ -71,8 +71,7 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
Array
<
Expr
>
call_args
;
Array
<
Expr
>
call_args
;
int
call_cnt
=
0
;
int
call_cnt
=
0
;
if
(
intrin_name
==
"vector_dup"
||
intrin_name
==
"vadds"
||
if
(
intrin_name
==
"vector_dup"
||
intrin_name
==
"vadds"
||
intrin_name
==
"vmuls"
||
intrin_name
==
"vaxpy"
)
{
intrin_name
==
"vmuls"
||
intrin_name
==
"vaxpy"
)
{
auto
GetCallInfo
=
[
&
intrin_name
,
&
call_args
,
&
call_cnt
](
const
NodeRef
&
op
)
{
auto
GetCallInfo
=
[
&
intrin_name
,
&
call_args
,
&
call_cnt
](
const
NodeRef
&
op
)
{
if
(
op
.
as
<
Call
>
()
&&
op
.
as
<
Call
>
()
->
name
==
intrin_name
)
{
if
(
op
.
as
<
Call
>
()
&&
op
.
as
<
Call
>
()
->
name
==
intrin_name
)
{
call_args
=
op
.
as
<
Call
>
()
->
args
;
call_args
=
op
.
as
<
Call
>
()
->
args
;
...
@@ -82,8 +81,8 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
...
@@ -82,8 +81,8 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
PostOrderVisit
(
op
,
GetCallInfo
);
PostOrderVisit
(
op
,
GetCallInfo
);
CHECK_EQ
(
call_cnt
,
1
);
CHECK_EQ
(
call_cnt
,
1
);
}
}
SingleType
insn_type
{
SingleType
::
SIMD
};
SingleType
insn_type
{
SingleType
::
SIMD
};
Expr
scalar_src
{};
Expr
scalar_src
{};
if
(
intrin_name
==
"vector_dup"
)
{
if
(
intrin_name
==
"vector_dup"
)
{
insn_type
=
SingleType
::
Vector_Dump
;
insn_type
=
SingleType
::
Vector_Dump
;
src_info_list
=
{};
src_info_list
=
{};
...
@@ -93,10 +92,11 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
...
@@ -93,10 +92,11 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
src_info_list
=
{
src_info_list
[
0
]};
src_info_list
=
{
src_info_list
[
0
]};
scalar_src
=
call_args
[
1
];
scalar_src
=
call_args
[
1
];
}
}
// check is single vector broadcast reduce mode exist
// check is single vector broadcast reduce mode exist
SingleVecPatternGenerator
generator
=
SingleVecPatternGenerator
(
dst_info_list
,
src_info_list
,
for_info
);
auto
params
=
generator
.
GetInsnArgs
();
SingleVecInsnArgsCalculator
args_calculator
=
SingleVecInsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
,
intrin_name
);
PatternResult
params
=
args_calculator
.
GetInsnArgs
();
dst_info_list
=
params
.
dst_info_list
;
dst_info_list
=
params
.
dst_info_list
;
src_info_list
=
params
.
src_info_list
;
src_info_list
=
params
.
src_info_list
;
for_info
=
params
.
for_info
;
for_info
=
params
.
for_info
;
...
@@ -141,24 +141,17 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
...
@@ -141,24 +141,17 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
if
(
src_info_list
[
0
]
->
var_
.
size
()
>
src_info_list
[
1
]
->
var_
.
size
())
{
if
(
src_info_list
[
0
]
->
var_
.
size
()
>
src_info_list
[
1
]
->
var_
.
size
())
{
src_info
=
src_info_list
[
0
];
src_info
=
src_info_list
[
0
];
}
}
const
int
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
if
(
enable_bisect
&&
GetIntConst
(
GetItem
(
src_info
->
shape_
,
-
1
))
>
vec_max_len
)
{
CommentManager
::
GetInstance
().
AddComment
(
"Bisect_optimize"
,
"enabled"
);
auto
wrapper
=
SeparateComInfoToBisectionInfoList
(
dst_info_list
,
src_info_list
,
for_info
,
if_info
,
true
,
postfix
);
return
EmitCceBinaryVectorToBisectionReduction
(
wrapper
,
if_info
,
intrin_name
);
}
else
{
CommentManager
::
GetInstance
().
AddComment
(
"Pattern"
,
arg_info
.
GetPattern
());
CommentManager
::
GetInstance
().
AddComment
(
"Pattern"
,
arg_info
.
GetPattern
());
ReduceLastAxisPatternGenerator
generator
=
ReduceLastAxisPatternGenerator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
LastAxisReduceInsnArgsCalculator
args_calculator
=
LastAxisReduceInsnArgsCalculator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
auto
result
=
generator
.
GetInsnArgs
();
PatternResult
result
=
args_calculator
.
GetInsnArgs
();
arg_info
=
result
.
arg_info
;
arg_info
=
result
.
arg_info
;
dst_info
=
result
.
dst_info_list
[
0
];
dst_info
=
result
.
dst_info_list
[
0
];
src_info
=
result
.
src_info_list
[
0
];
src_info
=
result
.
src_info_list
[
0
];
for_info
=
result
.
for_info
;
for_info
=
result
.
for_info
;
return
EmitCceBinaryVectorToReduceLastAxis
(
dst_info
,
src_info
,
if_info
,
for_info
,
arg_info
,
intrin_name
);
return
EmitCceBinaryVectorToReduceLastAxis
(
dst_info
,
src_info
,
if_info
,
for_info
,
arg_info
,
intrin_name
);
}
}
}
case
ARG_VECTOR_REDUCTION_BISECTION
:
{
case
ARG_VECTOR_REDUCTION_BISECTION
:
{
CommentManager
::
GetInstance
().
AddComment
(
"Compute_type"
,
"reduction"
);
CommentManager
::
GetInstance
().
AddComment
(
"Compute_type"
,
"reduction"
);
CommentManager
::
GetInstance
().
AddComment
(
"Bisect_optimize"
,
"enabled"
);
CommentManager
::
GetInstance
().
AddComment
(
"Bisect_optimize"
,
"enabled"
);
...
@@ -192,7 +185,7 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
...
@@ -192,7 +185,7 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
return
FoldInsnWithForInfo
(
insn_list
,
if_info
,
for_info
,
stmt
);
return
FoldInsnWithForInfo
(
insn_list
,
if_info
,
for_info
,
stmt
);
}
}
}
}
}
}
// namespace ir
/// Function to emit scalar intrin
/// Function to emit scalar intrin
/// \param op - The input stmt to be emitted as intrin
/// \param op - The input stmt to be emitted as intrin
...
@@ -984,8 +977,9 @@ Stmt BinaryDropoutEmitter(const Stmt &op) {
...
@@ -984,8 +977,9 @@ Stmt BinaryDropoutEmitter(const Stmt &op) {
src1
.
GetNode
()
->
data_
=
mask
->
buffer_var
;
src1
.
GetNode
()
->
data_
=
mask
->
buffer_var
;
src1
.
GetNode
()
->
data_alignment_
=
GetInt32Const
(
mask
->
predicate
);
src1
.
GetNode
()
->
data_alignment_
=
GetInt32Const
(
mask
->
predicate
);
SingleVecPatternGenerator
generator
=
SingleVecPatternGenerator
(
dst_info_list
,
src_info_list
,
for_info
,
"elewise"
);
SingleVecInsnArgsCalculator
args_calculator
=
SingleVecInsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
);
auto
params
=
generator
.
GetInsnArgs
();
PatternResult
params
=
args_calculator
.
GetInsnArgs
();
dst_info_list
=
params
.
dst_info_list
;
dst_info_list
=
params
.
dst_info_list
;
src_info_list
=
params
.
src_info_list
;
src_info_list
=
params
.
src_info_list
;
for_info
=
params
.
for_info
;
for_info
=
params
.
for_info
;
...
@@ -1484,8 +1478,10 @@ Stmt BinaryArgOpEmitter(const Stmt &op, const std::string &intrin_name) {
...
@@ -1484,8 +1478,10 @@ Stmt BinaryArgOpEmitter(const Stmt &op, const std::string &intrin_name) {
if
(
src_info_list
[
0
]
->
var_
.
size
()
>
src_info_list
[
1
]
->
var_
.
size
())
{
if
(
src_info_list
[
0
]
->
var_
.
size
()
>
src_info_list
[
1
]
->
var_
.
size
())
{
src_info
=
src_info_list
[
0
];
src_info
=
src_info_list
[
0
];
}
}
ReduceLastAxisPatternGenerator
generator
=
ReduceLastAxisPatternGenerator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
auto
result
=
generator
.
GetInsnArgs
();
LastAxisReduceInsnArgsCalculator
args_calculator
=
LastAxisReduceInsnArgsCalculator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
PatternResult
result
=
args_calculator
.
GetInsnArgs
();
arg_info
=
result
.
arg_info
;
arg_info
=
result
.
arg_info
;
dst_info
=
result
.
dst_info_list
[
0
];
dst_info
=
result
.
dst_info_list
[
0
];
src_info
=
result
.
src_info_list
[
0
];
src_info
=
result
.
src_info_list
[
0
];
...
...
src/emit_insn/insn_info.cc
浏览文件 @
aec205ff
...
@@ -104,10 +104,7 @@ StmtStoreInfo StmtStoreInfo::Copy() const {
...
@@ -104,10 +104,7 @@ StmtStoreInfo StmtStoreInfo::Copy() const {
StmtInfo
StmtInfo
::
Copy
()
const
{
StmtInfo
StmtInfo
::
Copy
()
const
{
auto
stmt_info
=
StmtInfo
();
auto
stmt_info
=
StmtInfo
();
stmt_info
.
ops_
=
ops_
;
stmt_info
.
ops_
=
ops_
;
for
(
auto
var
:
vars_
)
{
stmt_info
.
vars_
=
vars_
;
auto
new_var
=
Variable
::
make
(
var
->
type
,
var
->
name_hint
);
stmt_info
.
vars_
.
push_back
(
new_var
);
}
for
(
size_t
i
=
0
;
i
<
vars_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
vars_
.
size
();
++
i
)
{
for
(
size_t
j
=
0
;
j
<
stmt_info
.
ops_
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
stmt_info
.
ops_
.
size
();
++
j
)
{
...
...
src/emit_insn/insn_info.h
浏览文件 @
aec205ff
...
@@ -276,15 +276,7 @@ struct BisectionInfoWrapper {
...
@@ -276,15 +276,7 @@ struct BisectionInfoWrapper {
Map
<
std
::
string
,
Expr
>
dma_arg_info_map_
;
Map
<
std
::
string
,
Expr
>
dma_arg_info_map_
;
};
};
struct
InsnAxis
{
int
min
{
0
};
int
extent
{
0
};
Var
var
;
int
dst_stride
{
0
};
int
src_stride
{
0
};
std
::
list
<
int
>
src_stride_list
;
std
::
list
<
int
>
stride_list
;
};
IterVar
GetCceAxis
();
IterVar
GetCceAxis
();
...
...
src/emit_insn/insn_pattern.cc
浏览文件 @
aec205ff
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
*/
*/
#include "insn_pattern.h"
#include "insn_pattern.h"
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/base.h>
#include <tvm/base.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
...
@@ -200,28 +199,6 @@ ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_
...
@@ -200,28 +199,6 @@ ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_
return
arg_info
;
return
arg_info
;
}
}
/// Get first non zero shape from input shapes
/// \param dst_shape
/// \param src0_shape
/// \param src1_shape
/// \return
int
PatternGenerator
::
GetNonZeroShape
(
const
Expr
&
dst_shape
,
const
Expr
&
src0_shape
,
const
Expr
&
src1_shape
)
{
int
shape
=
0
;
for
(
int
val
:
{
GetInt32Const
(
dst_shape
),
GetInt32Const
(
src0_shape
),
src1_shape
.
defined
()
?
GetInt32Const
(
src1_shape
)
:
0
})
{
if
(
val
==
0
)
{
continue
;
}
if
(
shape
!=
0
&&
val
!=
shape
)
{
LOG
(
FATAL
)
<<
"Error: same var has different shapes. "
<<
GetIntConst
(
dst_shape
)
<<
" "
<<
GetIntConst
(
src0_shape
);
}
shape
=
val
;
}
CHECK
(
shape
!=
0
)
<<
"Error: all shapes are equal to 0."
;
return
shape
;
}
/// In case
/// In case
/// for (cc3) {
/// for (cc3) {
/// A[(cc3*16)] = (B[(cc3*16)] - C[(cc3*16)])
/// A[(cc3*16)] = (B[(cc3*16)] - C[(cc3*16)])
...
@@ -432,25 +409,6 @@ void CleanZeroStrides(Array<StmtStoreInfo> &info_list) {
...
@@ -432,25 +409,6 @@ void CleanZeroStrides(Array<StmtStoreInfo> &info_list) {
}
}
}
}
/// Swap axis in Array
/// \param var
/// \param shape
/// \param strides
/// \param idx1
/// \param idx2
void
PatternGenerator
::
GetShapeInfoAndSwap
(
Array
<
Var
>
&
var
,
Array
<
Expr
>
&
shape
,
Array
<
Expr
>
&
strides
,
int
idx1
,
int
idx2
)
{
auto
tmp_var
=
GetItem
(
var
,
idx1
);
SetItem
(
var
,
idx1
,
GetItem
(
var
,
idx2
));
SetItem
(
var
,
idx2
,
tmp_var
);
auto
tmp_shape
=
GetItem
(
shape
,
idx1
);
SetItem
(
shape
,
idx1
,
GetItem
(
shape
,
idx2
));
SetItem
(
shape
,
idx2
,
tmp_shape
);
auto
tmp_stride
=
GetItem
(
strides
,
idx1
);
SetItem
(
strides
,
idx1
,
GetItem
(
strides
,
idx2
));
SetItem
(
strides
,
idx2
,
tmp_stride
);
}
/// Get insn args of load 2D intrin
/// Get insn args of load 2D intrin
/// \param intrin_name
/// \param intrin_name
/// \param dst_info_list
/// \param dst_info_list
...
@@ -856,6 +814,38 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
...
@@ -856,6 +814,38 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
return
arg_info_map
;
return
arg_info_map
;
}
}
/// Replace com_info's var with new for loop's var
/// \param info
/// \param old_for_info
/// \param new_for_info
void
ReplaceVarWithNewForInfo
(
StmtStoreInfo
&
info
,
const
StmtInfo
&
old_for_info
,
const
StmtInfo
&
new_for_info
)
{
for
(
size_t
i
=
0
;
i
<
new_for_info
.
vars_
.
size
();
++
i
)
{
for
(
size_t
j
=
0
;
j
<
info
->
var_
.
size
();
++
j
)
{
if
(
info
->
var_
[
j
]
->
name_hint
==
new_for_info
.
vars_
[
i
]
->
name_hint
)
{
SetItem
(
info
.
GetNode
()
->
var_
,
static_cast
<
int
>
(
j
),
new_for_info
.
vars_
[
i
]);
}
}
info
.
GetNode
()
->
index_
=
substitute
(
old_for_info
.
vars_
[
i
],
new_for_info
.
vars_
[
i
],
info
->
index_
);
}
}
std
::
string
GetBinaryVecMode
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
std
::
string
&
intrin_name
,
bool
enable_bisect
)
{
std
::
set
<
std
::
string
>
reduce_bisect_list
=
{
"vadd"
,
"vsub"
,
"vmul"
,
"vmax"
};
std
::
string
mode
=
"reduction"
;
if
(
IsElementwise
(
dst_info_list
,
src_info_list
))
{
mode
=
"elewise"
;
}
else
if
(
IsBroadcast
(
dst_info_list
,
src_info_list
))
{
mode
=
"broadcast"
;
}
else
if
(
IsLastAxisReduction
(
dst_info_list
,
src_info_list
))
{
mode
=
"reduce_last_axis"
;
}
else
if
(
enable_bisect
&&
reduce_bisect_list
.
count
(
intrin_name
)
!=
0
&&
IsBisectionReduction
(
dst_info_list
,
src_info_list
))
{
mode
=
"reduce_bisection"
;
}
return
mode
;
}
const
char
*
const
DummyLastVar
=
"cc_last"
;
const
char
*
const
DummyLastVar
=
"cc_last"
;
TVM_REGISTER_API
(
"cce_util.GetVecMask"
).
set_body
([](
const
TVMArgs
args
,
TVMRetValue
*
ret
)
{
TVM_REGISTER_API
(
"cce_util.GetVecMask"
).
set_body
([](
const
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
...
src/emit_insn/insn_pattern.h
浏览文件 @
aec205ff
...
@@ -37,220 +37,12 @@ struct PatternResult {
...
@@ -37,220 +37,12 @@ struct PatternResult {
StmtInfo
for_info
;
StmtInfo
for_info
;
};
};
class
PatternGenerator
{
public:
PatternGenerator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfo
&
for_info
)
:
for_info
(
for_info
),
not_this_pattern
(
-
1.0
f
),
split_latency_coef
(
10.0
f
),
repeat_latency_coef
(
3.0
f
),
offset_latency_coef
(
0.1
f
)
{
CHECK
(
!
dst_info_list
.
empty
());
dst_info
=
dst_info_list
[
0
];
}
virtual
~
PatternGenerator
()
=
default
;
virtual
PatternResult
GetInsnArgs
()
=
0
;
protected:
int
GetNonZeroShape
(
const
Expr
&
dst_shape
,
const
Expr
&
src0_shape
,
const
Expr
&
src1_shape
=
Expr
());
void
GetShapeInfoAndSwap
(
Array
<
Var
>
&
var
,
Array
<
Expr
>
&
shape
,
Array
<
Expr
>
&
strides
,
int
idx1
,
int
idx2
);
virtual
float
Compute3DPatternMaskRate
()
{
return
not_this_pattern
;
}
virtual
float
Compute2DBlockPatternMaskRate
()
{
return
not_this_pattern
;
}
virtual
float
Compute2DPatternMaskRate
()
{
return
not_this_pattern
;
}
virtual
float
Compute1DPatternMaskRate
()
{
return
not_this_pattern
;
}
virtual
Array
<
Var
>
Get3DPattern
()
{
return
{};
}
virtual
Array
<
Var
>
Get2DBlockPattern
()
{
return
{};
}
virtual
Array
<
Var
>
Get2DPattern
()
{
return
{};
}
virtual
Array
<
Var
>
Get1DPattern
()
{
return
{};
}
virtual
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
)
=
0
;
StmtStoreInfo
dst_info
;
StmtInfo
for_info
;
const
float
not_this_pattern
;
const
float
split_latency_coef
;
const
float
repeat_latency_coef
;
const
float
offset_latency_coef
;
};
class
SingleVecPatternGenerator
:
public
PatternGenerator
{
public:
SingleVecPatternGenerator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
mode
=
"elewise"
)
:
PatternGenerator
(
dst_info_list
,
for_info
),
arg_info
(
ArgInfo
(
make_node
<
ArgInfoNode
>
())),
body_args
(
VectorArgInfo
()),
tail_args
(
VectorArgInfo
()),
mode
(
mode
)
{
if
(
src_info_list
.
empty
())
{
src_info
=
dst_info
.
Copy
();
}
else
{
CHECK
(
!
src_info_list
.
empty
());
src_info
=
src_info_list
[
0
];
}
}
~
SingleVecPatternGenerator
()
override
=
default
;
PatternResult
GetInsnArgs
()
final
;
protected:
float
Compute3DPatternMaskRate
()
final
;
float
Compute2DBlockPatternMaskRate
()
final
;
float
Compute2DPatternMaskRate
()
final
;
float
Compute1DPatternMaskRate
()
final
;
float
Compute3DsPatternMaskRate
();
float
Compute2DRepeatPatternMaskRate
();
Array
<
Var
>
Get3DPattern
()
final
;
Array
<
Var
>
Get2DBlockPattern
()
final
;
Array
<
Var
>
Get2DPattern
()
final
;
Array
<
Var
>
Get1DPattern
()
final
;
Array
<
Var
>
Get3DsPattern
();
Array
<
Var
>
Get2DRepeatPattern
();
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
)
final
;
private:
void
CalcParams
();
int
GetLastDimShape
(
const
Expr
&
dst_shape
,
const
Expr
&
src_shape
);
struct
Params
{
Array
<
Var
>
dst_var
;
Array
<
Var
>
src_var
;
Array
<
Expr
>
dst_shape
;
Array
<
Expr
>
src_shape
;
Array
<
Expr
>
dst_strides
;
Array
<
Expr
>
src_strides
;
int
non_zero_shape1
=
0
;
int
non_zero_shape2
=
0
;
int
non_zero_shape3
=
0
;
int
all_points
=
0
;
int
dst_block_size
=
0
;
int
src_block_size
=
0
;
int
mask_block_size
=
0
;
int
dst_bits
=
0
;
int
src_bits
=
0
;
int
max_bits
=
0
;
int
dst_vec_max_len
=
0
;
int
vec_max_len
=
0
;
int
block_offset
=
0
;
};
StmtStoreInfo
src_info
;
Params
params
;
ArgInfo
arg_info
;
VectorArgInfo
body_args
;
VectorArgInfo
tail_args
;
std
::
string
mode
;
Type
data_type
;
};
class
BinaryVecPatternGenerator
:
public
PatternGenerator
{
public:
BinaryVecPatternGenerator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
mode
,
bool
expand_mask
=
true
)
:
PatternGenerator
(
dst_info_list
,
for_info
),
src_info_list
(
src_info_list
),
arg_info
(
ArgInfo
(
make_node
<
ArgInfoNode
>
())),
body_args
(
VectorArgInfo
()),
tail_args
(
VectorArgInfo
()),
empty_var
(
Var
(
""
)),
mode
(
mode
),
expand_mask
(
expand_mask
)
{}
~
BinaryVecPatternGenerator
()
override
=
default
;
PatternResult
GetInsnArgs
()
final
;
protected:
float
Compute3DPatternMaskRate
()
final
;
float
Compute2DBlockPatternMaskRate
()
final
;
float
Compute2DPatternMaskRate
()
final
;
float
Compute1DPatternMaskRate
()
final
;
Array
<
Var
>
Get3DPattern
()
final
;
Array
<
Var
>
Get2DBlockPattern
()
final
;
Array
<
Var
>
Get2DPattern
()
final
;
Array
<
Var
>
Get1DPattern
()
final
;
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
)
final
;
private:
void
CalcParams
();
bool
IsSamePatternComInfo
(
const
StmtStoreInfo
&
info_a
,
const
StmtStoreInfo
&
info_b
);
bool
IsNonZeroShapeEqual
(
const
Array
<
Expr
>
&
shape_list
);
void
AppendEmptyVar
(
StmtInfoList
&
info_list
);
struct
Params
{
Array
<
Var
>
dst_var
;
Array
<
Expr
>
dst_shape
;
Array
<
Expr
>
dst_strides
;
Array
<
Var
>
src_var0
;
Array
<
Expr
>
src_shape0
;
Array
<
Expr
>
src_strides0
;
Array
<
Var
>
src_var1
;
Array
<
Expr
>
src_shape1
;
Array
<
Expr
>
src_strides1
;
int
non_zero_shape1
=
0
;
int
non_zero_shape2
=
0
;
int
non_zero_shape3
=
0
;
int
all_points
=
0
;
int
block_size
=
0
;
int
last_dim_shape
=
0
;
int
vec_max_len
=
0
;
};
StmtInfoList
src_info_list
;
ArgInfo
arg_info
;
VectorArgInfo
body_args
;
VectorArgInfo
tail_args
;
Params
params
;
Var
empty_var
;
std
::
string
mode
;
bool
expand_mask
;
};
class
ReduceLastAxisPatternGenerator
:
public
PatternGenerator
{
public:
ReduceLastAxisPatternGenerator
(
const
StmtStoreInfo
&
dst_info
,
const
StmtStoreInfo
&
src_info
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
)
:
PatternGenerator
({
dst_info
},
for_info
),
src_info
(
src_info
),
arg_info
(
ArgInfo
(
make_node
<
ArgInfoNode
>
())),
body_args
(
VectorArgInfo
()),
tail_args
(
VectorArgInfo
()),
intrin_name
(
intrin_name
)
{}
PatternResult
GetInsnArgs
()
final
;
~
ReduceLastAxisPatternGenerator
()
override
=
default
;
protected:
float
Compute2DBlockPatternMaskRate
()
final
;
Array
<
Var
>
Get2DBlockPattern
()
final
;
Array
<
Var
>
Get1DPattern
()
final
;
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
)
final
;
private:
void
CalcParams
();
struct
Params
{
Array
<
Var
>
src_var
;
int
block_size
=
0
;
int
vec_max_len
=
0
;
int
last_dim_shape
=
0
;
Expr
insn_offset_scale_factor
;
};
StmtStoreInfo
src_info
;
ArgInfo
arg_info
;
VectorArgInfo
body_args
;
VectorArgInfo
tail_args
;
Array
<
VectorArgInfo
>
mix_vec_arg_list
;
std
::
string
intrin_name
;
Params
params
;
};
std
::
string
GetSingleVecComputationInfo
(
const
Stmt
&
stmt
,
const
std
::
string
&
intrin_name
,
std
::
string
GetSingleVecComputationInfo
(
const
Stmt
&
stmt
,
const
std
::
string
&
intrin_name
,
Array
<
StmtStoreInfo
>
&
dst_info_list
,
Array
<
StmtStoreInfo
>
&
src_info_list
,
Array
<
StmtStoreInfo
>
&
dst_info_list
,
Array
<
StmtStoreInfo
>
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
need_compact
=
true
);
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
need_compact
=
true
);
ArgInfo
GetBinaryVecInsnArgs
(
const
Stmt
&
stmt
,
std
::
string
intrin_name
,
StmtInfoList
&
dst_info_list
,
std
::
string
GetBinaryVecMode
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
,
bool
enable_bisect
=
true
);
bool
enable_bisect
=
true
);
ArgInfo
GetMultiVecInsnArgs
(
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
for_info
);
ArgInfo
GetMultiVecInsnArgs
(
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
for_info
);
...
@@ -277,10 +69,7 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
...
@@ -277,10 +69,7 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
const
StmtInfoList
&
src_info_list
,
StmtInfo
&
for_info
,
const
StmtInfoList
&
src_info_list
,
StmtInfo
&
for_info
,
Map
<
std
::
string
,
Expr
>
&
ub_copy_pre
,
Map
<
std
::
string
,
Expr
>
&
ub_copy_post
);
Map
<
std
::
string
,
Expr
>
&
ub_copy_pre
,
Map
<
std
::
string
,
Expr
>
&
ub_copy_post
);
BisectionInfoWrapper
SeparateComInfoToBisectionInfoList
(
const
StmtInfoList
&
dst_info_list
,
void
ReplaceVarWithNewForInfo
(
StmtStoreInfo
&
info
,
const
StmtInfo
&
old_for_info
,
const
StmtInfo
&
new_for_info
);
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
StmtInfo
&
if_info
,
bool
last_axis
,
int
postfix
);
extern
const
char
*
const
DummyLastVar
;
extern
const
char
*
const
DummyLastVar
;
}
// namespace akg
}
// namespace akg
#endif // EMIT_INSN_INSN_PATTERN_H_
#endif // EMIT_INSN_INSN_PATTERN_H_
src/emit_insn/insn_single_vec_pattern.cc
已删除
100644 → 0
浏览文件 @
11ed37cc
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/base.h>
#include <tvm/ir_pass.h>
#include <cmath>
#include <set>
#include "insn_builder.h"
#include "insn_pattern.h"
#include "common/array_api.h"
#include "pass/expr_alg_simplify.h"
namespace
akg
{
/// Get CCE Single Vector Insn mode
/// \param dst_info_list
/// \param src_info_list
/// \return
std
::
string
GetSingleVecMode
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
)
{
CHECK
(
!
dst_info_list
.
empty
());
auto
dst_var_list
=
dst_info_list
[
0
]
->
var_
;
Array
<
Var
>
src_var_list
;
if
(
!
src_info_list
.
empty
())
{
src_var_list
=
src_info_list
[
0
]
->
var_
;
}
if
(
IsSame
(
dst_var_list
,
src_var_list
))
{
return
"elewise"
;
}
else
if
(
dst_var_list
.
size
()
>=
src_var_list
.
size
())
{
return
"broadcast"
;
}
return
"reduction"
;
}
/// Get Single Vector Computation Info
/// \param stmt
/// \param intrin_name
/// \param dst_info_list
/// \param src_info_list
/// \param if_info
/// \param for_info
/// \param need_compact
/// \return
std
::
string
GetSingleVecComputationInfo
(
const
Stmt
&
stmt
,
const
std
::
string
&
intrin_name
,
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
need_compact
)
{
std
::
set
<
std
::
string
>
intrin_name_list
=
{
"vadds"
,
"vmuls"
,
"vrelu"
,
"vabs"
,
"vln"
,
"vexp"
,
"vrec"
,
"vector_dup"
,
"vnot"
,
"vsqrt"
,
"vrsqrt"
};
if
(
intrin_name_list
.
count
(
intrin_name
)
==
0
&&
intrin_name
.
find
(
"vconv_"
)
==
std
::
string
::
npos
)
{
LOG
(
FATAL
)
<<
"Error: CCE Single Vector Insn unsupported the given intrin_name. "
<<
intrin_name
;
return
""
;
}
bool
same_dtype
=
intrin_name
.
find
(
"vconv_"
)
==
std
::
string
::
npos
;
GetCompactComputationInfo
(
stmt
,
dst_info_list
,
src_info_list
,
if_info
,
for_info
,
same_dtype
,
need_compact
);
std
::
string
mode
=
GetSingleVecMode
(
dst_info_list
,
src_info_list
);
CHECK
(
dst_info_list
.
size
()
==
1
)
<<
"CCE Single Vector only support ONE dst."
;
return
mode
;
}
/// Get CCE Single vector instructions args.
/// \param dst_info_list
/// \param src_info_list
/// \param for_info
/// \param mode
/// \return
PatternResult
SingleVecPatternGenerator
::
GetInsnArgs
()
{
CalcParams
();
Array
<
Var
>
elim_var
=
{};
float
rate3d
=
Compute3DPatternMaskRate
();
float
rate2db
=
Compute2DBlockPatternMaskRate
();
float
rate2d
=
Compute2DPatternMaskRate
();
float
rate1d
=
Compute1DPatternMaskRate
();
float
rate3ds
=
Compute3DsPatternMaskRate
();
float
rate2ds
=
Compute2DRepeatPatternMaskRate
();
if
(
mode
==
"broadcast_last_axis"
)
{
elim_var
=
Get1DPattern
();
}
else
if
(
rate2ds
>
0
)
{
elim_var
=
Get2DRepeatPattern
();
}
else
if
(
rate3ds
>
0
)
{
elim_var
=
Get3DsPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_2D
;
}
else
if
(
rate3d
>=
rate2db
&&
rate3d
>
0
)
{
elim_var
=
Get3DPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_3D
;
}
else
if
(
rate2db
>=
rate2d
&&
rate2db
>=
rate1d
&&
rate2db
>
0
)
{
elim_var
=
Get2DBlockPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_PARTIAL_3D
;
}
else
if
(
rate2d
>
rate1d
&&
rate2d
>
0
)
{
elim_var
=
Get2DPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_2D
;
}
else
if
(
rate1d
>
0
)
{
elim_var
=
Get1DPattern
();
arg_info
.
GetNode
()
->
pattern_
=
PATTERN_1D
;
}
else
{
LOG
(
FATAL
)
<<
"Error: Cannot emit Single-Vector-Insn with any pattern!"
;
}
std
::
string
mask_rate
=
"rate3d["
+
std
::
to_string
(
rate3d
)
+
"], rate2db["
+
std
::
to_string
(
rate2db
)
+
"], rate2d["
+
std
::
to_string
(
rate2d
)
+
"], rate1d["
+
std
::
to_string
(
rate1d
)
+
"]"
;
CommentManager
::
GetInstance
().
AddComment
(
"Mask_rate"
,
mask_rate
);
if
(
arg_info
->
tail_arg_info_
.
defined
())
{
CommentManager
::
GetInstance
().
AddComment
(
"Contain_tail"
,
"true"
);
}
else
{
CommentManager
::
GetInstance
().
AddComment
(
"Contain_tail"
,
"false"
);
}
return
GenResult
(
elim_var
);
}
/// Calc params for pattern match
void
SingleVecPatternGenerator
::
CalcParams
()
{
Array
<
StmtStoreInfo
>
info_list
=
{
dst_info
,
src_info
};
// check shape len
for
(
auto
info
:
info_list
)
{
CHECK
(
!
info
->
shape_
.
empty
())
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."
;
}
FillEmptyVar
(
info_list
);
dst_info
=
info_list
[
0
];
src_info
=
info_list
[
1
];
int
dst_bits
=
dst_info
->
dtype_
.
bits
();
int
src_bits
=
src_info
->
dtype_
.
bits
();
CHECK_NE
(
dst_bits
,
0
);
CHECK_NE
(
src_bits
,
0
);
int
dst_block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
int
src_block_size
=
GetUbBlkSize
(
src_info
->
dtype_
);
CHECK_NE
(
dst_block_size
,
0
);
CHECK_NE
(
src_block_size
,
0
);
data_type
=
src_bits
>
dst_bits
?
src_info
->
dtype_
:
dst_info
->
dtype_
;
params
.
dst_var
=
dst_info
->
var_
;
params
.
src_var
=
src_info
->
var_
;
params
.
dst_shape
=
dst_info
->
shape_
;
params
.
src_shape
=
src_info
->
shape_
;
params
.
dst_strides
=
dst_info
->
strides_
;
params
.
src_strides
=
src_info
->
strides_
;
params
.
dst_block_size
=
dst_block_size
;
params
.
src_block_size
=
src_block_size
;
params
.
mask_block_size
=
src_bits
>
dst_bits
?
src_block_size
:
dst_block_size
;
params
.
dst_bits
=
dst_bits
;
params
.
src_bits
=
src_bits
;
params
.
max_bits
=
FULL_BLOCK_NUM
*
std
::
min
(
dst_bits
,
src_bits
);
params
.
dst_vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
params
.
vec_max_len
=
src_bits
>
dst_bits
?
GetVecMaxLen
(
src_info
->
dtype_
)
:
GetVecMaxLen
(
dst_info
->
dtype_
);
CHECK_NE
(
params
.
dst_vec_max_len
,
0
);
CHECK_NE
(
params
.
vec_max_len
,
0
);
auto
GetNonZeroShapeByIdx
=
[
this
](
int
index
)
->
int
{
if
(
index
<=
static_cast
<
int
>
(
params
.
dst_var
.
size
()))
{
if
(
Equal
(
GetItem
(
params
.
dst_var
,
-
index
),
GetItem
(
params
.
src_var
,
-
index
)))
{
return
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
index
),
GetItem
(
params
.
src_shape
,
-
index
));
}
}
return
1
;
};
params
.
non_zero_shape1
=
GetNonZeroShapeByIdx
(
1
);
params
.
non_zero_shape2
=
GetNonZeroShapeByIdx
(
2
);
params
.
non_zero_shape3
=
GetNonZeroShapeByIdx
(
3
);
params
.
all_points
=
params
.
non_zero_shape1
*
params
.
non_zero_shape2
*
params
.
non_zero_shape3
;
auto
elem_offset_mod
=
ir
::
ExprSimplifier
().
Simplify
(
Mod
::
make
(
dst_info
->
elem_offset_
,
dst_block_size
));
if
(
elem_offset_mod
.
as
<
IntImm
>
())
{
params
.
block_offset
=
elem_offset_mod
.
as
<
IntImm
>
()
->
value
;
}
}
int
SingleVecPatternGenerator
::
GetLastDimShape
(
const
Expr
&
dst_shape
,
const
Expr
&
src_shape
)
{
int
dst_last_dim
=
GetInt32Const
(
dst_shape
);
int
src_last_dim
=
GetInt32Const
(
src_shape
);
CHECK
(
dst_last_dim
!=
0
||
src_last_dim
!=
0
);
if
(
dst_last_dim
==
0
)
{
return
src_last_dim
;
}
if
(
src_last_dim
==
0
)
{
return
dst_last_dim
;
}
return
std
::
min
(
dst_last_dim
,
src_last_dim
);
}
bool
FindInShape
(
Array
<
Expr
>
&
shape
,
const
Expr
&
target
)
{
for
(
int
i
=
-
1
;
i
>=
-
3
;
--
i
)
{
if
(
Equal
(
GetItem
(
shape
,
i
),
target
))
{
return
true
;
}
}
return
false
;
}
float
SingleVecPatternGenerator
::
Compute2DRepeatPatternMaskRate
()
{
if
(
params
.
dst_var
.
size
()
<
3
)
{
return
not_this_pattern
;
}
for
(
int
i
=
-
1
;
i
>=
-
3
;
--
i
)
{
if
(
!
FindInShape
(
params
.
src_shape
,
GetItem
(
params
.
dst_shape
,
i
)))
{
return
not_this_pattern
;
}
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
3
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
3
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
&&
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
if
(
!
Equal
(
GetItem
(
params
.
dst_shape
,
-
3
),
GetItem
(
params
.
src_shape
,
-
2
))
||
!
Equal
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape
,
-
3
)))
{
return
not_this_pattern
;
}
if
(
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))
>
FULL_BLOCK_NUM
&&
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))
%
FULL_BLOCK_NUM
!=
0
)
{
return
not_this_pattern
;
}
if
(
params
.
dst_block_size
==
params
.
src_block_size
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_shape
,
-
1
))
<=
params
.
dst_block_size
&&
GetInt32Const
(
GetItem
(
params
.
src_shape
,
-
1
))
<=
params
.
src_block_size
)
{
return
not_this_pattern
;
}
return
1.0
;
}
float
SingleVecPatternGenerator
::
Compute3DsPatternMaskRate
()
{
if
(
params
.
dst_var
.
size
()
<
3
)
{
return
not_this_pattern
;
}
if
(
params
.
dst_block_size
!=
params
.
src_block_size
)
{
return
not_this_pattern
;
}
for
(
int
i
=
-
1
;
i
>=
-
3
;
--
i
)
{
if
(
!
FindInShape
(
params
.
src_shape
,
GetItem
(
params
.
dst_shape
,
i
)))
{
return
not_this_pattern
;
}
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
params
.
dst_block_size
||
GetInt32Const
(
GetItem
(
params
.
src_shape
,
-
1
))
>
params
.
src_block_size
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
3
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
3
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
&&
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
if
(
!
Equal
(
GetItem
(
params
.
dst_shape
,
-
3
),
GetItem
(
params
.
src_shape
,
-
2
))
||
!
Equal
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape
,
-
3
)))
{
return
not_this_pattern
;
}
if
(
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))
>
FULL_BLOCK_NUM
&&
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))
%
FULL_BLOCK_NUM
!=
0
)
{
return
not_this_pattern
;
}
return
1.0
;
}
float
SingleVecPatternGenerator
::
Compute3DPatternMaskRate
()
{
// in elemwise mode, the var is already checked to be equal, no need to check
if
(
params
.
dst_var
.
size
()
<
3
)
{
return
not_this_pattern
;
}
// do not support cast op in 3D pattern
if
(
params
.
dst_block_size
!=
params
.
src_block_size
)
{
return
not_this_pattern
;
}
for
(
int
i
=
-
1
;
i
>=
-
3
;
--
i
)
{
if
(
!
IsTwoItemEqual
(
params
.
dst_var
,
params
.
src_var
,
i
))
{
return
not_this_pattern
;
}
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
params
.
dst_block_size
||
GetInt32Const
(
GetItem
(
params
.
src_shape
,
-
1
))
>
params
.
src_block_size
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
3
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
3
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
&&
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
// repeat axis is shape [-3], repeat once, has 8 loops
bool
is3_d
=
true
;
float
rate3d_mode1
=
not_this_pattern
;
float
rate3d_mode2
=
not_this_pattern
;
int
repeat_num
;
float
repeat_latency
;
StmtInfoList
info_list
=
{
dst_info
,
src_info
};
for
(
auto
info
:
info_list
)
{
if
(
GetInt32Const
(
GetItem
(
info
->
shape_
,
-
2
))
>
FULL_BLOCK_NUM
||
GetInt32Const
(
GetItem
(
info
->
strides_
,
-
2
))
/
params
.
dst_block_size
>=
MAX_STRIDE_M0_SINGLE
||
GetInt32Const
(
GetItem
(
info
->
strides_
,
-
3
))
/
params
.
dst_block_size
>=
MAX_STRIDE_M1
)
{
is3_d
=
false
;
break
;
}
}
if
(
is3_d
)
{
if
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
repeat_num
=
params
.
non_zero_shape3
;
repeat_latency
=
((
repeat_num
-
1
)
/
MAX_REPEAT
)
*
repeat_latency_coef
;
rate3d_mode1
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
dst_vec_max_len
/
(
repeat_num
+
repeat_latency
);
}
is3_d
=
true
;
// repeat axis is shape[-2]
for
(
auto
info
:
info_list
)
{
// stride_m0 should less than 65536
if
(
GetInt32Const
(
GetItem
(
info
->
shape_
,
-
3
))
%
FULL_BLOCK_NUM
!=
0
||
GetInt32Const
(
GetItem
(
info
->
strides_
,
-
3
))
/
params
.
dst_block_size
>=
MAX_STRIDE_M0_SINGLE
)
{
is3_d
=
false
;
break
;
}
}
if
(
is3_d
)
{
if
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
3
))
==
0
)
{
return
not_this_pattern
;
}
repeat_num
=
params
.
non_zero_shape2
*
(
params
.
non_zero_shape3
/
FULL_BLOCK_NUM
);
repeat_latency
=
((
repeat_num
-
1
)
/
MAX_REPEAT
)
*
repeat_latency_coef
;
float
offset_latency
=
params
.
non_zero_shape3
/
FULL_BLOCK_NUM
>
1
?
params
.
non_zero_shape3
*
offset_latency_coef
:
0
;
rate3d_mode2
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
dst_vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
);
}
return
rate3d_mode1
>
rate3d_mode2
?
rate3d_mode1
:
rate3d_mode2
;
}
// Partial 3D Pattern
float
SingleVecPatternGenerator
::
Compute2DBlockPatternMaskRate
()
{
// in elemwise mode, the var is already checked to be equal, no need to check
if
(
params
.
dst_var
.
size
()
<
2
||
params
.
src_var
.
size
()
<
2
||
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
1
))
!=
1
)
{
return
not_this_pattern
;
}
// do not support cast op in Partial3D pattern
if
(
params
.
dst_block_size
!=
params
.
src_block_size
)
{
return
not_this_pattern
;
}
for
(
int
i
=
-
1
;
i
>=
-
2
;
--
i
)
{
if
(
!
Equal
(
GetItem
(
params
.
dst_var
,
i
),
GetItem
(
params
.
src_var
,
i
)))
{
return
not_this_pattern
;
}
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
params
.
dst_block_size
||
GetInt32Const
(
GetItem
(
params
.
src_shape
,
-
1
))
>
params
.
src_block_size
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
&&
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
/
params
.
dst_block_size
>=
MAX_STRIDE_M0_SINGLE
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
/
params
.
src_block_size
>=
MAX_STRIDE_M0_SINGLE
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
int
repeat_body_num
=
params
.
non_zero_shape2
/
FULL_BLOCK_NUM
;
int
repeat_tail_num
=
(
params
.
non_zero_shape2
%
FULL_BLOCK_NUM
+
FULL_BLOCK_NUM
-
1
)
/
FULL_BLOCK_NUM
;
int
repeat_num
=
(
repeat_body_num
+
repeat_tail_num
)
*
params
.
non_zero_shape3
;
float
repeat_latency
=
(
std
::
max
(
repeat_body_num
-
1
,
0
)
/
MAX_REPEAT
+
std
::
max
(
repeat_tail_num
-
1
,
0
)
/
MAX_REPEAT
)
*
repeat_latency_coef
;
float
offset_latency
=
params
.
non_zero_shape3
>
1
?
params
.
non_zero_shape3
*
offset_latency_coef
:
0
;
float
split_latency
=
(
repeat_body_num
>
0
&&
repeat_tail_num
>
0
)
?
split_latency_coef
:
0
;
float
rate2db
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
dst_vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
+
split_latency
);
return
rate2db
;
}
float
SingleVecPatternGenerator
::
Compute2DPatternMaskRate
()
{
// in elemwise mode, the var is already checked to be equal, no need to check
if
(
params
.
dst_var
.
size
()
<
2
||
params
.
src_var
.
size
()
<
2
)
{
return
not_this_pattern
;
}
if
(
src_info
->
data_alignment_
==
1
&&
GetInt32Const
(
GetItem
(
src_info
->
strides_
,
-
1
))
!=
params
.
dst_block_size
)
{
return
not_this_pattern
;
}
for
(
int
i
=
-
1
;
i
>=
-
2
;
--
i
)
{
if
(
!
Equal
(
GetItem
(
params
.
dst_var
,
i
),
GetItem
(
params
.
src_var
,
i
)))
{
return
not_this_pattern
;
}
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
/
params
.
dst_block_size
>=
MAX_STRIDE_M1
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
/
params
.
src_block_size
>=
MAX_STRIDE_M1
)
{
return
not_this_pattern
;
}
// check num of insns, select 1D pattern or 2D pattern
int
tail_factor
=
0
;
if
(
params
.
non_zero_shape1
/
params
.
dst_vec_max_len
>
0
&&
params
.
non_zero_shape1
%
params
.
dst_vec_max_len
>
0
)
{
tail_factor
=
1
;
}
int
offset_num
=
(
params
.
non_zero_shape1
+
params
.
dst_vec_max_len
-
1
)
/
params
.
dst_vec_max_len
*
params
.
non_zero_shape3
;
int
repeat_num
=
offset_num
*
params
.
non_zero_shape2
;
float
repeat_latency
=
(
std
::
max
(
params
.
non_zero_shape2
-
1
,
0
)
/
MAX_REPEAT
)
*
offset_num
*
repeat_latency_coef
;
float
offset_latency
=
offset_num
>
1
?
offset_num
*
offset_latency_coef
:
0
;
float
split_latency
=
tail_factor
*
split_latency_coef
;
float
rate2d
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
dst_vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
+
split_latency
);
return
rate2d
;
}
float
SingleVecPatternGenerator
::
Compute1DPatternMaskRate
()
{
int
tail_factor
=
0
;
if
(
params
.
non_zero_shape1
/
params
.
dst_vec_max_len
>
0
&&
params
.
non_zero_shape1
%
params
.
dst_vec_max_len
>
0
)
{
tail_factor
=
1
;
}
int
shape1
=
(
params
.
non_zero_shape1
+
params
.
dst_vec_max_len
-
1
)
/
params
.
dst_vec_max_len
;
int
repeat_num
=
shape1
*
params
.
non_zero_shape2
*
params
.
non_zero_shape3
;
float
repeat_latency
=
std
::
max
((
shape1
-
1
)
/
MAX_REPEAT
,
0
)
*
params
.
non_zero_shape2
*
params
.
non_zero_shape3
*
repeat_latency_coef
;
float
offset_latency
=
params
.
non_zero_shape2
*
params
.
non_zero_shape3
>
1
?
params
.
non_zero_shape2
*
params
.
non_zero_shape3
*
offset_latency_coef
:
0
;
float
split_latency
=
tail_factor
*
split_latency_coef
;
float
rate1d
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
dst_vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
+
split_latency
);
return
rate1d
;
}
Array
<
Var
>
SingleVecPatternGenerator
::
Get2DRepeatPattern
()
{
GetShapeInfoAndSwap
(
params
.
src_var
,
params
.
src_shape
,
params
.
src_strides
,
-
2
,
-
3
);
int
last_dim_shape
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
1
),
GetItem
(
params
.
src_shape
,
-
1
));
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
body_args
.
GetNode
()
->
body_num_
=
1
;
body_args
.
GetNode
()
->
dst_stride_m0_
=
1
;
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
1
};
body_args
.
GetNode
()
->
dst_stride_m1_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
dst_block_size
);
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides
,
-
2
),
params
.
src_block_size
)};
body_args
.
GetNode
()
->
repeat_
=
GetItem
(
params
.
dst_shape
,
-
2
);
int
data_len
=
CeilTo
(
last_dim_shape
,
params
.
dst_block_size
);
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
1
,
dst_info
->
dtype_
);
return
GetRange
(
params
.
dst_var
,
-
2
,
2
);
}
Array
<
Var
>
SingleVecPatternGenerator
::
Get3DsPattern
()
{
GetShapeInfoAndSwap
(
params
.
src_var
,
params
.
src_shape
,
params
.
src_strides
,
-
2
,
-
3
);
int
last_dim_shape
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
1
),
GetItem
(
params
.
src_shape
,
-
1
));
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
body_args
.
GetNode
()
->
body_num_
=
1
;
Expr
dst_stride_m0
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
dst_block_size
);
Expr
src_stride_m0
=
truncdiv
(
GetItem
(
params
.
src_strides
,
-
2
),
params
.
src_block_size
);
body_args
.
GetNode
()
->
dst_stride_m0_
=
dst_stride_m0
;
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
src_stride_m0
};
int
block_num
=
0
;
int
data_len
=
CeilTo
(
last_dim_shape
,
params
.
mask_block_size
);
if
(
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))
<=
FULL_BLOCK_NUM
)
{
block_num
=
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
));
body_args
.
GetNode
()
->
dst_stride_m1_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
3
),
params
.
dst_block_size
);
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides
,
-
3
),
params
.
src_block_size
)};
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
block_num
,
dst_info
->
dtype_
);
auto
repeat
=
GetItem
(
params
.
dst_shape
,
-
3
);
if
(
GetIntConst
(
repeat
)
<
MAX_STRIDE_M1
)
{
body_args
.
GetNode
()
->
repeat_
=
repeat
;
return
GetRange
(
params
.
dst_var
,
-
3
,
3
);
}
else
{
body_args
.
GetNode
()
->
repeat_
=
1
;
return
GetRange
(
params
.
dst_var
,
-
2
,
2
);
}
}
else
{
block_num
=
FULL_BLOCK_NUM
;
body_args
.
GetNode
()
->
dst_stride_m1_
=
dst_stride_m0
*
block_num
;
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
src_stride_m0
*
block_num
};
auto
repeat
=
truncdiv
(
GetItem
(
params
.
dst_shape
,
-
2
),
FULL_BLOCK_NUM
);
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
block_num
,
dst_info
->
dtype_
);
if
(
GetIntConst
(
repeat
)
<
MAX_STRIDE_M1
)
{
body_args
.
GetNode
()
->
repeat_
=
repeat
;
return
GetRange
(
params
.
dst_var
,
-
2
,
2
);
}
else
{
return
Get1DPattern
();
}
}
}
Array
<
Var
>
SingleVecPatternGenerator
::
Get3DPattern
()
{
if
(
GetIntConst
(
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape
,
-
2
)))
>
FULL_BLOCK_NUM
)
{
// split shape[-3]
if
(
GetIntConst
(
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
3
),
GetItem
(
params
.
src_shape
,
-
3
)))
>
FULL_BLOCK_NUM
)
{
StmtInfoList
info_list
=
{
dst_info
,
src_info
};
SplitAxis
(
info_list
,
for_info
,
GetItem
(
params
.
dst_var
,
-
3
),
FULL_BLOCK_NUM
);
FillEmptyVar
(
info_list
);
dst_info
=
info_list
[
0
];
src_info
=
info_list
[
1
];
params
.
dst_var
=
dst_info
->
var_
;
params
.
dst_shape
=
dst_info
->
shape_
;
params
.
dst_strides
=
dst_info
->
strides_
;
params
.
src_var
=
src_info
->
var_
;
params
.
src_shape
=
src_info
->
shape_
;
params
.
src_strides
=
src_info
->
strides_
;
}
// consider original shape[-2] as repeat axis
GetShapeInfoAndSwap
(
params
.
dst_var
,
params
.
dst_shape
,
params
.
dst_strides
,
-
2
,
-
3
);
GetShapeInfoAndSwap
(
params
.
src_var
,
params
.
src_shape
,
params
.
src_strides
,
-
2
,
-
3
);
}
int
last_dim_shape
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
1
),
GetItem
(
params
.
src_shape
,
-
1
));
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
body_args
.
GetNode
()
->
body_num_
=
1
;
body_args
.
GetNode
()
->
repeat_
=
make_const
(
Int
(
32
),
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
3
),
GetItem
(
params
.
src_shape
,
-
3
)));
body_args
.
GetNode
()
->
dst_stride_m0_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
dst_block_size
);
body_args
.
GetNode
()
->
dst_stride_m1_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
3
),
params
.
dst_block_size
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides
,
-
2
),
params
.
src_block_size
)};
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides
,
-
3
),
params
.
src_block_size
)};
body_args
.
GetNode
()
->
block_offset_
=
make_const
(
Int
(
32
),
params
.
block_offset
);
int
data_len
=
CeilTo
(
last_dim_shape
,
params
.
mask_block_size
);
int
data_num
=
GetInt32Const
(
GetItem
(
params
.
dst_shape
,
-
2
));
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
data_num
,
dst_info
->
dtype_
,
params
.
block_offset
);
return
GetRange
(
params
.
dst_var
,
-
3
,
3
);
}
Array
<
Var
>
SingleVecPatternGenerator
::
Get2DBlockPattern
()
{
int
last_dim_shape
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
1
),
GetItem
(
params
.
src_shape
,
-
1
));
int
repeat_len
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape
,
-
2
));
int
repeat_body
=
repeat_len
/
FULL_BLOCK_NUM
;
int
repeat_tail
=
(
repeat_len
%
FULL_BLOCK_NUM
+
FULL_BLOCK_NUM
-
1
)
/
FULL_BLOCK_NUM
;
int
data_len
=
CeilTo
(
last_dim_shape
,
params
.
dst_block_size
);
if
(
repeat_body
>
0
)
{
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
body_args
.
GetNode
()
->
body_num_
=
1
;
body_args
.
GetNode
()
->
repeat_
=
make_const
(
Int
(
32
),
repeat_body
);
auto
dst_stride_m0
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
dst_block_size
);
body_args
.
GetNode
()
->
dst_stride_m0_
=
dst_stride_m0
;
body_args
.
GetNode
()
->
dst_stride_m1_
=
dst_stride_m0
*
(
params
.
max_bits
/
params
.
src_bits
);
auto
src_stride_m0
=
truncdiv
(
GetItem
(
params
.
src_strides
,
-
2
),
params
.
src_block_size
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
src_stride_m0
};
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
src_stride_m0
*
(
params
.
max_bits
/
params
.
dst_bits
)};
body_args
.
GetNode
()
->
block_offset_
=
make_const
(
Int
(
32
),
params
.
block_offset
);
int
data_num
=
FULL_BLOCK_NUM
;
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
data_num
,
dst_info
->
dtype_
,
params
.
block_offset
);
}
if
(
repeat_tail
>
0
)
{
tail_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
tail_args
.
GetNode
()
->
dst_head_
=
GetItem
(
params
.
dst_strides
,
-
2
)
*
repeat_body
*
FULL_BLOCK_NUM
;
tail_args
.
GetNode
()
->
src_head_list_
=
{
GetItem
(
params
.
src_strides
,
-
2
)
*
repeat_body
*
FULL_BLOCK_NUM
};
tail_args
.
GetNode
()
->
repeat_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
dst_stride_m0_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
dst_block_size
);
tail_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
0
);
tail_args
.
GetNode
()
->
src_stride_m0_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides
,
-
2
),
params
.
src_block_size
)};
tail_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
0
)};
tail_args
.
GetNode
()
->
block_offset_
=
make_const
(
Int
(
32
),
params
.
block_offset
);
int
data_num
=
repeat_len
%
FULL_BLOCK_NUM
;
tail_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
data_num
,
dst_info
->
dtype_
,
params
.
block_offset
);
}
return
GetRange
(
params
.
dst_var
,
-
2
,
2
);
}
Array
<
Var
>
SingleVecPatternGenerator
::
Get2DPattern
()
{
const
int
data_num
=
1
;
int
last_dim_shape
=
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
1
),
GetItem
(
params
.
src_shape
,
-
1
));
if
(
GetInt32Const
(
GetItem
(
dst_info
->
strides_
,
-
1
))
==
params
.
dst_block_size
&&
IsTwoItemEqual
(
dst_info
->
strides_
,
src_info
->
strides_
,
-
1
,
true
))
{
last_dim_shape
*=
params
.
dst_block_size
;
}
int
body_len
=
FloorTo
(
last_dim_shape
,
params
.
vec_max_len
);
int
tail_len
=
last_dim_shape
%
params
.
vec_max_len
;
if
(
body_len
>
0
)
{
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
body_args
.
GetNode
()
->
body_num_
=
body_len
/
params
.
vec_max_len
;
body_args
.
GetNode
()
->
body_offset_
=
params
.
vec_max_len
;
body_args
.
GetNode
()
->
repeat_
=
GetItem
(
params
.
dst_shape
,
-
2
);
body_args
.
GetNode
()
->
dst_stride_m0_
=
Expr
(
1
);
body_args
.
GetNode
()
->
dst_stride_m1_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
dst_block_size
);
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides
,
-
2
),
params
.
src_block_size
)};
body_args
.
GetNode
()
->
block_offset_
=
make_const
(
Int
(
32
),
params
.
block_offset
);
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
params
.
vec_max_len
,
data_num
,
data_type
,
params
.
block_offset
);
}
// get tail params
if
(
tail_len
>
0
)
{
tail_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
tail_args
.
GetNode
()
->
dst_head_
=
Expr
(
body_len
);
tail_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
body_len
)};
tail_args
.
GetNode
()
->
repeat_
=
GetItem
(
params
.
dst_shape
,
-
2
);
tail_args
.
GetNode
()
->
dst_stride_m0_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
dst_stride_m1_
=
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
2
),
params
.
dst_block_size
);
tail_args
.
GetNode
()
->
src_stride_m0_list_
=
{
Expr
(
1
)};
tail_args
.
GetNode
()
->
src_stride_m1_list_
=
{
truncdiv
(
GetItem
(
params
.
src_strides
,
-
2
),
params
.
src_block_size
)};
tail_args
.
GetNode
()
->
block_offset_
=
make_const
(
Int
(
32
),
params
.
block_offset
);
int
data_len
=
CeilTo
(
tail_len
,
params
.
mask_block_size
);
tail_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
data_num
,
data_type
,
params
.
block_offset
);
}
return
GetRange
(
params
.
dst_var
,
-
2
,
2
);
}
Array
<
Var
>
SingleVecPatternGenerator
::
Get1DPattern
()
{
int
last_dim_shape
;
bool
linear_mode
=
false
;
if
((
params
.
dst_shape
.
empty
()
&&
params
.
src_shape
.
empty
())
||
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
))
==
0
)
{
last_dim_shape
=
1
;
}
else
if
(
!
IsTwoItemEqual
(
params
.
dst_var
,
params
.
src_var
,
-
1
))
{
last_dim_shape
=
1
;
}
else
{
last_dim_shape
=
GetLastDimShape
(
GetItem
(
params
.
dst_shape
,
-
1
),
GetItem
(
params
.
src_shape
,
-
1
));
linear_mode
=
params
.
dst_bits
==
params
.
src_bits
;
}
bool
is_scalar_mode
=
IsScalarMode
({
dst_info
,
src_info
});
if
(
is_scalar_mode
&&
params
.
dst_bits
!=
params
.
src_bits
)
{
last_dim_shape
=
1
;
}
int
vec_max_len
=
is_scalar_mode
?
FULL_BLOCK_NUM
:
params
.
vec_max_len
;
int
body_len
=
FloorTo
(
last_dim_shape
,
vec_max_len
);
int
tail_len
=
last_dim_shape
%
vec_max_len
;
auto
dst_stride_m0
=
is_scalar_mode
&&
linear_mode
?
truncdiv
(
GetItem
(
params
.
dst_strides
,
-
1
),
params
.
dst_block_size
)
:
Expr
(
1
);
auto
src_stride_m0
=
is_scalar_mode
&&
linear_mode
?
truncdiv
(
GetItem
(
params
.
src_strides
,
-
1
),
params
.
src_block_size
)
:
Expr
(
1
);
if
(
body_len
>
0
)
{
body_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
body_args
.
GetNode
()
->
body_num_
=
1
;
body_args
.
GetNode
()
->
body_offset_
=
vec_max_len
;
body_args
.
GetNode
()
->
repeat_
=
Expr
(
body_len
/
vec_max_len
);
body_args
.
GetNode
()
->
dst_stride_m0_
=
dst_stride_m0
;
auto
dst_block_num
=
is_scalar_mode
?
FULL_BLOCK_NUM
:
(
params
.
max_bits
/
params
.
src_bits
);
body_args
.
GetNode
()
->
dst_stride_m1_
=
dst_stride_m0
*
dst_block_num
;
body_args
.
GetNode
()
->
src_stride_m0_list_
=
{
src_stride_m0
};
auto
src_block_num
=
is_scalar_mode
?
FULL_BLOCK_NUM
:
(
params
.
max_bits
/
params
.
dst_bits
);
body_args
.
GetNode
()
->
src_stride_m1_list_
=
{
src_stride_m0
*
src_block_num
};
body_args
.
GetNode
()
->
block_offset_
=
make_const
(
Int
(
32
),
params
.
block_offset
);
// in cast case, data_num should be 1 because dst and src bit is not equal
int
data_len
=
is_scalar_mode
?
1
:
vec_max_len
;
int
data_num
=
is_scalar_mode
?
FULL_BLOCK_NUM
:
1
;
body_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
data_num
,
data_type
,
params
.
block_offset
);
}
// get tail params
if
(
tail_len
>
0
)
{
tail_args
=
VectorArgInfo
(
make_node
<
VectorArgInfoNode
>
());
tail_args
.
GetNode
()
->
body_offset_
=
vec_max_len
;
tail_args
.
GetNode
()
->
body_num_
=
1
;
tail_args
.
GetNode
()
->
dst_head_
=
Expr
(
body_len
*
(
is_scalar_mode
?
dst_stride_m0
*
params
.
dst_block_size
:
Expr
(
1
)));
tail_args
.
GetNode
()
->
src_head_list_
=
{
Expr
(
body_len
*
(
is_scalar_mode
?
src_stride_m0
*
params
.
src_block_size
:
Expr
(
1
)))};
tail_args
.
GetNode
()
->
repeat_
=
Expr
(
1
);
tail_args
.
GetNode
()
->
dst_stride_m0_
=
dst_stride_m0
;
tail_args
.
GetNode
()
->
dst_stride_m1_
=
Expr
(
0
);
tail_args
.
GetNode
()
->
src_stride_m0_list_
=
{
src_stride_m0
};
tail_args
.
GetNode
()
->
src_stride_m1_list_
=
{
Expr
(
0
)};
tail_args
.
GetNode
()
->
block_offset_
=
make_const
(
Int
(
32
),
params
.
block_offset
);
int
data_len
=
is_scalar_mode
&&
linear_mode
?
1
:
CeilTo
(
tail_len
,
params
.
mask_block_size
);
int
data_num
=
is_scalar_mode
&&
linear_mode
?
tail_len
:
1
;
data_num
=
data_num
==
0
?
1
:
data_num
;
tail_args
.
GetNode
()
->
vec_mask_
=
GetVecMask
(
data_len
,
data_num
,
data_type
,
params
.
block_offset
);
}
// compute offset for cce instructions
Array
<
Var
>
elim_var
=
{};
if
(
mode
==
"elewise"
&&
params
.
dst_var
.
size
()
>=
2
&&
params
.
dst_strides
.
size
()
>=
2
&&
for_info
.
ops_
.
size
()
>=
2
&&
last_dim_shape
<=
vec_max_len
&&
last_dim_shape
>=
vec_max_len
-
params
.
dst_block_size
&&
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
last_dim_shape
)
{
// in this case we can merge second last for extent to repeat
size_t
index
=
0
;
bool
suc
=
GetIndexOfElement
(
for_info
.
vars_
,
GetItem
(
params
.
dst_var
,
-
2
),
index
);
CHECK
(
suc
);
auto
latest_for
=
GetItem
(
for_info
.
ops_
,
index
).
as
<
For
>
();
// there should not be if_op between for loop and compute stmt
if
(
latest_for
&&
!
latest_for
->
body
->
IsInstance
<
IfThenElse
>
())
{
if
(
!
params
.
dst_var
.
empty
()
&&
(
!
is_scalar_mode
||
last_dim_shape
!=
1
))
{
if
(
body_args
.
defined
())
{
// last_dim_shape = vec_max_len
body_args
.
GetNode
()
->
repeat_
=
body_args
->
repeat_
*
latest_for
->
extent
;
}
else
if
(
tail_args
.
defined
())
{
// last_dim_shape < vec_max_len
tail_args
.
GetNode
()
->
repeat_
=
tail_args
->
repeat_
*
latest_for
->
extent
;
}
return
GetRange
(
params
.
dst_var
,
-
2
,
2
);
}
}
}
if
(
!
params
.
dst_var
.
empty
()
&&
(
!
is_scalar_mode
||
last_dim_shape
!=
1
||
linear_mode
)
&&
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
1
))
>
0
&&
(
params
.
src_var
.
empty
()
||
IsTwoItemEqual
(
params
.
dst_var
,
params
.
src_var
,
-
1
)))
{
elim_var
=
GetRange
(
params
.
dst_var
,
-
1
,
1
);
}
return
elim_var
;
}
PatternResult
SingleVecPatternGenerator
::
GenResult
(
const
Array
<
Var
>
&
elim_var
)
{
arg_info
.
GetNode
()
->
body_arg_info_
=
body_args
;
arg_info
.
GetNode
()
->
tail_arg_info_
=
tail_args
;
dst_info
.
GetNode
()
->
insn_offset_
=
GetInsnOffset
(
dst_info
,
elim_var
);
src_info
.
GetNode
()
->
insn_offset_
=
GetInsnOffset
(
src_info
,
elim_var
);
CleanForInfoVars
(
for_info
,
elim_var
);
StmtInfoList
info_list
=
{
dst_info
,
src_info
};
CleanZeroStrides
(
info_list
);
dst_info
=
info_list
[
0
];
src_info
=
info_list
[
1
];
PatternResult
result
;
result
.
dst_info_list
=
{
dst_info
};
result
.
src_info_list
=
{
src_info
};
result
.
for_info
=
for_info
;
result
.
arg_info
=
arg_info
;
return
result
;
}
}
// namespace akg
src/pass/emit_insn.cc
浏览文件 @
aec205ff
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include "pass/ir_util.h"
#include "pass/ir_util.h"
#include "poly/poly_util.h"
#include "poly/poly_util.h"
#include "emit_insn/insn_emitter.h"
#include "emit_insn/insn_emitter.h"
#include "emit_insn/ir_transform.h"
namespace
akg
{
namespace
akg
{
namespace
ir
{
namespace
ir
{
...
@@ -475,6 +476,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma
...
@@ -475,6 +476,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma
}
}
stmt
=
UnalignedMad
().
Mutate
(
stmt
);
stmt
=
UnalignedMad
().
Mutate
(
stmt
);
stmt
=
RegCondition
().
Mutate
(
stmt
);
stmt
=
RegCondition
().
Mutate
(
stmt
);
stmt
=
ForVarUnique
().
Mutate
(
stmt
);
return
stmt
;
return
stmt
;
}
}
}
// namespace ir
}
// namespace ir
...
...
src/pass/multi_last_axis_reduction.cc
浏览文件 @
aec205ff
...
@@ -343,8 +343,12 @@ class BroadcastCalculate : public IRMutator {
...
@@ -343,8 +343,12 @@ class BroadcastCalculate : public IRMutator {
};
};
Stmt
MultiLastAxisReductions
(
Stmt
stmt
,
bool
is_dynamic
=
false
)
{
Stmt
MultiLastAxisReductions
(
Stmt
stmt
,
bool
is_dynamic
=
false
)
{
auto
ori_stmt
=
stmt
;
stmt
=
MultiLastAxisReduction
().
Mutate
(
stmt
);
stmt
=
MultiLastAxisReduction
().
Mutate
(
stmt
);
stmt
=
BroadcastCalculate
(
is_dynamic
).
Mutate
(
stmt
);
stmt
=
BroadcastCalculate
(
is_dynamic
).
Mutate
(
stmt
);
if
(
!
is_dynamic
&&
!
Equal
(
ori_stmt
,
stmt
))
{
stmt
=
MergeLoops
(
stmt
);
}
return
stmt
;
return
stmt
;
}
}
}
// namespace ir
}
// namespace ir
...
...
src/pass/split_tail_block.cc
浏览文件 @
aec205ff
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
#include <algorithm>
#include <algorithm>
#include "emit_insn/insn_info.h"
#include "emit_insn/insn_info.h"
#include "emit_insn/insn_pattern.h"
#include "emit_insn/insn_pattern.h"
#include "emit_insn/insn_args_calculator.h"
namespace
akg
{
namespace
akg
{
namespace
ir
{
namespace
ir
{
...
@@ -48,85 +48,63 @@ class TailSpliter : public IRMutator {
...
@@ -48,85 +48,63 @@ class TailSpliter : public IRMutator {
if
(
src_info_list
.
empty
())
{
if
(
src_info_list
.
empty
())
{
src_info_list
=
{
dst_info
.
Copy
()};
src_info_list
=
{
dst_info
.
Copy
()};
}
}
auto
get_info_list
=
[](
const
StmtStoreInfo
&
dst_info
,
const
Array
<
StmtStoreInfo
>
&
src_info_list
)
{
Array
<
StmtStoreInfo
>
res
;
res
.
push_back
(
dst_info
.
Copy
());
for
(
auto
it
:
src_info_list
)
{
res
.
push_back
(
it
.
Copy
());
}
return
res
;
};
auto
info_list
=
get_info_list
(
dst_info
,
src_info_list
);
FillEmptyVar
(
info_list
);
auto
axis_list
=
GetAixsList
(
for_info
,
info_list
);
auto
get_last_axis_it
=
[](
const
std
::
list
<
InsnAxis
>
&
axis_list
)
{
for
(
auto
it
=
axis_list
.
begin
();
it
!=
axis_list
.
end
();
it
++
)
{
auto
stride_list
=
it
->
stride_list
;
if
(
!
(
std
::
any_of
(
stride_list
.
begin
(),
stride_list
.
end
(),
[](
int
stride
)
{
return
stride
>
1
;
})
||
std
::
all_of
(
stride_list
.
begin
(),
stride_list
.
end
(),
[](
int
stride
)
{
return
stride
==
0
;
})))
{
return
it
;
}
}
return
axis_list
.
end
();
};
auto
last_axis_it
=
get_last_axis_it
(
axis_list
);
auto
info_list
=
GetInfoList
(
dst_info
,
src_info_list
);
if
(
last_axis_it
==
axis_list
.
end
())
{
FillEmptyVar
(
info_list
);
return
s
;
}
auto
last_axis
=
*
last_axis_it
;
auto
last_axis_shape
=
last_axis
.
extent
;
int
dst_block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
int
dst_block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
int
src_block_size
=
GetUbBlkSize
(
src_info_list
[
0
]
->
dtype_
);
int
src_block_size
=
GetUbBlkSize
(
src_info_list
[
0
]
->
dtype_
);
int
block_size
=
dst_block_size
>
src_block_size
?
dst_block_size
:
src_block_size
;
int
block_size
=
dst_block_size
<
src_block_size
?
dst_block_size
:
src_block_size
;
int
cast_block_size
=
dst_block_size
>
src_block_size
?
dst_block_size
:
src_block_size
;
int
vec_max_len
=
block_size
*
FULL_BLOCK_NUM
;
int
vec_max_len
=
block_size
*
FULL_BLOCK_NUM
;
auto
args_calculator
=
InsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
,
""
);
if
(
last_axis_shape
>
vec_max_len
&&
last_axis_shape
%
vec_max_len
!=
0
)
{
auto
vec_axis_it
=
args_calculator
.
GetVecAxisIt
();
return
Block
::
make
(
TailMake
(
s
,
last_axis
,
vec_max_len
,
false
),
TailMake
(
s
,
last_axis
,
vec_max_len
,
true
));
bool
cast
=
dst_block_size
!=
src_block_size
;
if
(
args_calculator
.
IsValid
(
vec_axis_it
))
{
auto
vec_axis
=
*
vec_axis_it
;
auto
vec_axis_shape
=
vec_axis
.
extent
;
if
(
vec_axis_shape
>=
vec_max_len
)
{
if
(
vec_axis_shape
%
vec_max_len
!=
0
)
{
return
TailBlock
(
s
,
vec_axis
,
vec_max_len
);
}
}
if
(
last_axis_shape
<
vec_max_len
*
tail_rate_
&&
last_axis_shape
>
block_size
&&
}
else
{
last_axis_shape
%
block_size
!=
0
&&
axis_list
.
size
()
>
1
)
{
if
(
vec_axis_shape
<
vec_max_len
*
tail_rate_
&&
vec_axis_shape
>
cast_block_size
&&
return
Block
::
make
(
TailMake
(
s
,
last_axis
,
block_size
,
false
),
TailMake
(
s
,
last_axis
,
block_size
,
true
));
vec_axis_shape
%
cast_block_size
!=
0
&&
args_calculator
.
axis_list_
.
size
()
>
1
)
{
return
TailBlock
(
s
,
vec_axis
,
cast_block_size
);
}
}
}
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
if
(
!
cast
&&
(
!
args_calculator
.
IsValid
(
vec_axis_it
)
||
vec_axis_it
->
extent
<=
cast_block_size
*
tail_rate_
))
{
std
::
list
<
InsnAxis
>
GetAixsList
(
const
StmtInfo
&
for_info
,
const
Array
<
StmtStoreInfo
>
&
info_list
)
{
auto
get_block_axis
=
[
&
](
std
::
list
<
InsnAxis
>
&
axis_list
)
{
std
::
list
<
InsnAxis
>
axis_list
;
InsnAxis
block_axis
;
auto
GetStrideByAxis
=
[](
const
Array
<
Var
>
&
vars
,
const
Array
<
Expr
>
&
strides
,
Var
obj_var
)
{
block_axis
.
is_valid
=
false
;
int
index
=
0
;
std
::
vector
<
InsnAxis
>
temp_axis_set
;
for
(
auto
var_it
:
vars
)
{
auto
block_stride_lambda
=
[
&
](
int
stride
)
{
return
stride
%
block_size
==
0
&&
stride
/
block_size
<=
4
;
};
if
(
Equal
(
var_it
,
obj_var
))
{
for
(
auto
axis
:
axis_list
)
{
return
strides
[
index
];
if
(
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
block_stride_lambda
)
&&
axis
.
dst_stride
!=
0
&&
axis
.
extent
!=
0
&&
axis
.
extent
>
FULL_BLOCK_NUM
&&
axis
.
extent
%
FULL_BLOCK_NUM
!=
0
)
{
temp_axis_set
.
push_back
(
axis
);
}
}
index
++
;
}
}
return
Expr
(
0
);
if
(
!
temp_axis_set
.
empty
())
{
};
return
temp_axis_set
[
0
];
for
(
auto
it
:
for_info
.
ops_
)
{
InsnAxis
axis
;
auto
for_stmt
=
it
.
as
<
For
>
();
CHECK
(
for_stmt
);
axis
.
var
=
for_stmt
->
loop_var
;
axis
.
extent
=
GetInt32Const
(
for_stmt
->
extent
);
axis
.
min
=
GetInt32Const
(
for_stmt
->
min
);
int
index
=
0
;
for
(
auto
it
:
info_list
)
{
auto
stride
=
GetInt32Const
(
GetStrideByAxis
(
it
->
var_
,
it
->
strides_
,
axis
.
var
));
axis
.
stride_list
.
push_back
(
stride
);
if
(
index
==
0
)
{
axis
.
dst_stride
=
stride
;
}
else
{
}
else
{
axis
.
src_stride_list
.
push_back
(
stride
)
;
return
block_axis
;
}
}
index
++
;
};
auto
block_axis
=
get_block_axis
(
args_calculator
.
axis_list_
);
if
(
block_axis
.
IsValid
())
{
return
TailBlock
(
s
,
block_axis
,
FULL_BLOCK_NUM
);
}
}
axis_list
.
push_back
(
axis
);
}
}
return
axis_list
;
return
s
;
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Stmt
TailBlock
(
const
Stmt
&
s
,
const
InsnAxis
&
tail_axis
,
int
body_size
)
{
return
Block
::
make
(
TailMake
(
s
,
tail_axis
,
body_size
,
false
),
TailMake
(
s
,
tail_axis
,
body_size
,
true
));
}
Stmt
TailMake
(
const
Stmt
&
s
,
const
InsnAxis
&
tail_axis
,
int
body_size
,
bool
is_tail
)
{
Stmt
TailMake
(
const
Stmt
&
s
,
const
InsnAxis
&
tail_axis
,
int
body_size
,
bool
is_tail
)
{
if
(
auto
attr_stmt
=
s
.
as
<
AttrStmt
>
())
{
if
(
auto
attr_stmt
=
s
.
as
<
AttrStmt
>
())
{
return
AttrStmt
::
make
(
attr_stmt
->
node
,
attr_stmt
->
attr_key
,
attr_stmt
->
value
,
return
AttrStmt
::
make
(
attr_stmt
->
node
,
attr_stmt
->
attr_key
,
attr_stmt
->
value
,
...
@@ -145,7 +123,6 @@ class TailSpliter : public IRMutator {
...
@@ -145,7 +123,6 @@ class TailSpliter : public IRMutator {
}
}
return
For
::
make
(
for_stmt
->
loop_var
,
for_stmt
->
min
,
for_stmt
->
extent
,
for_stmt
->
for_type
,
for_stmt
->
device_api
,
return
For
::
make
(
for_stmt
->
loop_var
,
for_stmt
->
min
,
for_stmt
->
extent
,
for_stmt
->
for_type
,
for_stmt
->
device_api
,
TailMake
(
for_stmt
->
body
,
tail_axis
,
body_size
,
is_tail
));
TailMake
(
for_stmt
->
body
,
tail_axis
,
body_size
,
is_tail
));
}
}
if
(
s
.
as
<
Store
>
()
&&
is_tail
)
{
if
(
s
.
as
<
Store
>
()
&&
is_tail
)
{
return
substitute
(
tail_axis
.
var
,
Add
::
make
(
Expr
(
tail_axis
.
extent
/
body_size
*
body_size
),
tail_axis
.
var
),
s
);
return
substitute
(
tail_axis
.
var
,
Add
::
make
(
Expr
(
tail_axis
.
extent
/
body_size
*
body_size
),
tail_axis
.
var
),
s
);
...
@@ -156,6 +133,20 @@ class TailSpliter : public IRMutator {
...
@@ -156,6 +133,20 @@ class TailSpliter : public IRMutator {
private:
private:
const
float
tail_rate_
{
0.6
};
const
float
tail_rate_
{
0.6
};
const
std
::
set
<
std
::
string
>
include_intrin_list_
=
{
const
std
::
set
<
std
::
string
>
include_intrin_list_
=
{
// binary vec
"vec_binary_add"
,
"vec_binary_sub"
,
"vec_binary_mul"
,
"vec_binary_min"
,
"vec_binary_max"
,
"vec_binary_div"
,
"vec_binary_and"
,
"vec_binary_or"
,
"vec_binary_vmadd"
,
"vec_binary_vmaddrelu"
,
"vec_binary_vmla"
,
// single vec
"vec_single_fabs"
,
"vec_single_fabs"
,
"vec_single_log"
,
"vec_single_log"
,
"vec_single_exp"
,
"vec_single_exp"
,
...
@@ -165,20 +156,28 @@ class TailSpliter : public IRMutator {
...
@@ -165,20 +156,28 @@ class TailSpliter : public IRMutator {
"vec_single_rsqrt"
,
"vec_single_rsqrt"
,
"vec_single_relu"
,
"vec_single_relu"
,
"vec_single_not"
,
"vec_single_not"
,
// vector_scalar
"vec_single_muls"
,
"vec_single_adds"
,
// Mov
// Mov
"broadcast"
,
"broadcast"
,
"mask_broadcast"
,
// vector_cast
// vector_cast
"vec_single_cast"
,
"vec_single_cast"
,
"vec_single_floor"
,
"vec_single_floor"
,
"vec_single_round"
,
"vec_single_round"
,
"vec_single_ceil"
,
"vec_single_ceil"
,
"vec_single_trunc"
,
"vec_single_trunc"
,
// scalar case
"vector_dup"
,
"vmuls"
,
"vadds"
,
"vaxpy"
,
};
};
};
};
Stmt
SplitTail
(
Stmt
stmt
)
{
return
TailSpliter
().
Mutate
(
stmt
);
}
Stmt
SplitTail
(
Stmt
stmt
)
{
auto
tail_spliter
=
TailSpliter
();
auto
first_round
=
tail_spliter
.
Mutate
(
stmt
);
auto
second_round
=
tail_spliter
.
Mutate
(
stmt
);
return
second_round
;
}
}
// namespace ir
}
// namespace ir
}
// namespace akg
}
// namespace akg
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录