Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
0871623a
A
akg
项目概览
MindSpore
/
akg
通知
58
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0871623a
编写于
8月 24, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
!110 rewrite insn pattern generator in EmitInsn
Merge pull request !110 from wYann/insn_reduction_final
上级
d237aa7d
aec205ff
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
1432 addition
and
2512 deletion
+1432
-2512
src/emit_insn/insn_args_calculator.cc
src/emit_insn/insn_args_calculator.cc
+1089
-0
src/emit_insn/insn_args_calculator.h
src/emit_insn/insn_args_calculator.h
+199
-0
src/emit_insn/insn_binary_vec_pattern.cc
src/emit_insn/insn_binary_vec_pattern.cc
+0
-1335
src/emit_insn/insn_emitter.cc
src/emit_insn/insn_emitter.cc
+26
-30
src/emit_insn/insn_info.cc
src/emit_insn/insn_info.cc
+1
-4
src/emit_insn/insn_info.h
src/emit_insn/insn_info.h
+1
-9
src/emit_insn/insn_pattern.cc
src/emit_insn/insn_pattern.cc
+32
-42
src/emit_insn/insn_pattern.h
src/emit_insn/insn_pattern.h
+4
-215
src/emit_insn/insn_single_vec_pattern.cc
src/emit_insn/insn_single_vec_pattern.cc
+0
-802
src/pass/emit_insn.cc
src/pass/emit_insn.cc
+2
-0
src/pass/multi_last_axis_reduction.cc
src/pass/multi_last_axis_reduction.cc
+4
-0
src/pass/split_tail_block.cc
src/pass/split_tail_block.cc
+74
-75
未找到文件。
src/emit_insn/insn_args_calculator.cc
0 → 100644
浏览文件 @
0871623a
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/base.h>
#include <tvm/ir_pass.h>
#include "common/array_api.h"
#include "pass/expr_alg_simplify.h"
#include "insn_pattern.h"
#include "insn_args_calculator.h"
namespace
akg
{
/// Build an axis descriptor from one For loop.
///
/// \param for_stmt  loop whose variable, min and extent describe the axis
/// \param info_list store infos to inspect; the first entry is treated as the destination,
///                  every following entry as a source
///
/// For every info the stride associated with the loop variable is looked up (0 when the
/// variable is not used by that info) and recorded in stride_list; the first one is also
/// stored as dst_stride and the remaining ones are appended to src_stride_list.
InsnAxis::InsnAxis(const For *for_stmt, const Array<StmtStoreInfo> &info_list) {
  var = for_stmt->loop_var;
  extent = GetInt32Const(for_stmt->extent);
  min = GetInt32Const(for_stmt->min);

  bool is_dst = true;
  for (const auto &info : info_list) {
    const auto axis_stride = GetInt32Const(GetStrideByAxis(info->var_, info->strides_, var));
    stride_list.push_back(axis_stride);
    if (is_dst) {
      dst_stride = axis_stride;
      is_dst = false;
    } else {
      src_stride_list.push_back(axis_stride);
    }
  }
}
/// Look up the stride that belongs to a given loop variable.
///
/// \param vars    variables of a store info
/// \param strides strides of the same store info, indexed in parallel with vars
/// \param obj_var variable to search for
/// \return the stride at the position of the first variable equal to obj_var,
///         or Expr(0) when no variable matches
Expr InsnAxis::GetStrideByAxis(const Array<Var> &vars, const Array<Expr> &strides, Var obj_var) {
  for (size_t pos = 0; pos < vars.size(); ++pos) {
    if (Equal(vars[pos], obj_var)) {
      return strides[pos];
    }
  }
  return Expr(0);
}
bool
InsnAxis
::
IsValid
()
{
return
this
->
is_valid
;
}
/// Dump this axis to the DEBUG log.
///
/// \param name optional banner printed before the axis fields; skipped when empty
///
/// The left/right source strides are the first and second entries of src_stride_list.
/// When an entry does not exist the placeholder 99999 is printed instead, so calling
/// Print on an axis without source strides no longer reads from an empty container.
void InsnAxis::Print(const std::string &name) {
  if (!name.empty()) {
    LOG(DEBUG) << "********** " << name << " ************";
  }
  // Guard both lookups: front() on an empty container is undefined behaviour.
  auto l_stride = !this->src_stride_list.empty() ? src_stride_list.front() : 99999;
  auto r_stride = this->src_stride_list.size() > 1 ? src_stride_list[1] : 99999;
  LOG(DEBUG) << "var:" << this->var << " extent:" << this->extent << " min:" << this->min
             << " dst_stride:" << this->dst_stride << " src_stride_l:" << l_stride
             << " src_stride_r:" << r_stride;
}
/// Merge destination and source infos into one list.
///
/// \param dst_info      destination store info; a copy becomes element 0
/// \param src_info_list source store infos; copies follow in their original order
/// \return a new array holding copies (made with Copy()) of all given infos
Array<StmtStoreInfo> GetInfoList(const StmtStoreInfo &dst_info, const Array<StmtStoreInfo> &src_info_list) {
  Array<StmtStoreInfo> merged;
  merged.push_back(dst_info.Copy());
  for (size_t i = 0; i < src_info_list.size(); ++i) {
    merged.push_back(src_info_list[i].Copy());
  }
  return merged;
}
std
::
list
<
InsnAxis
>
GetAxisList
(
const
StmtInfo
&
for_info
,
const
Array
<
StmtStoreInfo
>
&
info_list
)
{
std
::
list
<
InsnAxis
>
axis_list
;
for
(
auto
it
:
for_info
.
ops_
)
{
auto
for_stmt
=
it
.
as
<
For
>
();
CHECK
(
for_stmt
);
auto
axis
=
InsnAxis
(
for_stmt
,
info_list
);
axis_list
.
push_back
(
axis
);
}
return
axis_list
;
}
void
Print
(
std
::
list
<
InsnAxis
>
&
axis_list
)
{
LOG
(
DEBUG
)
<<
"+++++++++++++++++++ AXIS_LIST +++++++++++++++++++"
;
int
index
=
0
;
for
(
auto
it
:
axis_list
)
{
LOG
(
DEBUG
)
<<
"================== INDEX "
<<
index
<<
" ================="
;
it
.
Print
();
index
++
;
}
LOG
(
DEBUG
)
<<
"------------------ END ---------------------"
;
}
/// Store the inputs, then initialise the default arguments (InitArg) and derive the
/// meta information and axis list from the inputs (CalAxis).
///
/// \param dst_info_list destination store infos; CalAxis CHECKs that it is not empty
/// \param src_info_list source store infos
/// \param for_info      loops surrounding the statement
/// \param intrin_name   name of the intrinsic, stored in intrin_name_
InsnArgsCalculator::InsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
                                       const StmtInfo &for_info, const std::string &intrin_name)
    : dst_info_list_(dst_info_list), src_info_list_(src_info_list), for_info_(for_info), intrin_name_(intrin_name) {
  InitArg();
  CalAxis();
}
/// Derive meta_ and axis_list_ from the stored destination/source infos.
///
/// - The first destination info becomes dst_info_; when there is no source info, a copy
///   of the destination is used as the only source.
/// - meta_.same_dst_src is set when any source has the same name as the destination.
/// - Block sizes and dtypes are taken from the destination and the first source;
///   meta_.block_size is the smaller block size, meta_.dtype the type with more bits
///   (destination wins a tie), meta_.cast is set when the block sizes differ.
/// - meta_.block_offset is updated only when elem_offset_ % block_size simplifies to an
///   integer immediate.
/// - axis_list_ is rebuilt from for_info_.
void InsnArgsCalculator::CalAxis() {
  CHECK(!dst_info_list_.empty());
  dst_info_ = dst_info_list_[0];
  if (src_info_list_.empty()) {
    src_info_list_ = {dst_info_.Copy()};
  }
  auto first_src = src_info_list_[0];

  dst_info_.Print();
  for (auto src : src_info_list_) {
    src.Print();
    if (src->name_ == dst_info_->name_) {
      meta_.same_dst_src = true;
    }
  }

  meta_.dst_dtype = dst_info_->dtype_;
  meta_.src_dtype = first_src->dtype_;
  meta_.dst_block_size = GetUbBlkSize(dst_info_->dtype_);
  meta_.src_block_size = GetUbBlkSize(first_src->dtype_);
  meta_.cast = meta_.dst_block_size != meta_.src_block_size;
  if (meta_.dst_block_size <= meta_.src_block_size) {
    meta_.block_size = meta_.dst_block_size;
  } else {
    meta_.block_size = meta_.src_block_size;
  }
  if (meta_.dst_dtype.bits() >= meta_.src_dtype.bits()) {
    meta_.dtype = meta_.dst_dtype;
  } else {
    meta_.dtype = meta_.src_dtype;
  }

  auto offset_in_block = ir::ExprSimplifier().Simplify(Mod::make(dst_info_->elem_offset_, meta_.block_size));
  if (const auto *imm = offset_in_block.as<IntImm>()) {
    meta_.block_offset = imm->value;
  }

  axis_list_ = GetAxisList(for_info_, GetInfoList(dst_info_, src_info_list_));
}
// namespace akg
/// Reset the source stride arguments to their defaults:
/// m0 strides {1, 1} and m1 strides {0, 0}.
void InsnArgsCalculator::InitArg() {
  arg_.src_m0_list = {1, 1};
  arg_.src_m1_list = {0, 0};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
InsnArgsCalculator
::
GetStrideLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_stride
=
[
&
](
int
stride
)
{
return
stride
%
meta_
.
block_size
==
0
;
};
auto
zero_stride
=
[
&
](
int
stride
)
{
return
stride
==
0
;
};
return
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
is_stride
)
&&
!
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
zero_stride
);
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
InsnArgsCalculator
::
GetM0LimitLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_limit
=
[
&
](
int
stride
)
{
return
stride
/
meta_
.
block_size
<
MAX_STRIDE_M0_SINGLE
;
};
return
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
is_limit
);
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
InsnArgsCalculator
::
GetBlockStrideLimitLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_limit
=
[
&
](
int
stride
)
{
return
stride
/
meta_
.
block_size
<=
max_block_stride_
;
};
return
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
is_limit
);
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
InsnArgsCalculator
::
GetM1LimitLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_limit
=
[
&
](
int
stride
)
{
return
stride
/
meta_
.
src_block_size
<
MAX_STRIDE_M1
;
};
return
axis
.
dst_stride
/
meta_
.
dst_block_size
<
MAX_STRIDE_M1
&&
std
::
all_of
(
axis
.
src_stride_list
.
begin
(),
axis
.
src_stride_list
.
end
(),
is_limit
);
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
And
(
const
std
::
list
<
std
::
function
<
bool
(
const
InsnAxis
&
)
>>
&
lambda_list
)
{
return
[
&
lambda_list
](
const
InsnAxis
&
axis
)
{
bool
res
=
true
;
for
(
auto
lambda_it
:
lambda_list
)
{
res
=
res
&&
lambda_it
(
axis
);
}
return
res
;
};
}
/// Find the first axis of axis_list_ accepted by the predicate.
///
/// \param lambda predicate applied to the axes in list order
/// \return iterator to the first accepted axis, or axis_list_.end() when none is accepted
AxisIt InsnArgsCalculator::GetAxisByLambda(const std::function<bool(const InsnAxis &)> &lambda) {
  return std::find_if(axis_list_.begin(), axis_list_.end(), lambda);
}
/// Remove the axis referenced by it from axis_list_ and hand back a copy of it.
///
/// \param it iterator into axis_list_; the element is erased, so the iterator must not
///           be dereferenced by the caller afterwards
/// \return copy of the removed axis
InsnAxis InsnArgsCalculator::ExtractAxis(AxisIt &it) {
  InsnAxis extracted(*it);
  axis_list_.erase(it);
  return extracted;
}
/// \return true when it refers to an element, i.e. differs from axis_list_.end().
bool InsnArgsCalculator::IsValid(AxisIt &it) {
  const bool at_end = (it == axis_list_.end());
  return !at_end;
}
void
AxisSort
(
std
::
list
<
InsnAxis
>
&
axis_arr
,
bool
order
=
true
)
{
auto
up_compare
=
[
&
](
InsnAxis
&
a
,
InsnAxis
&
b
)
{
return
a
.
extent
<
b
.
extent
;
};
auto
down_compare
=
[
&
](
InsnAxis
&
a
,
InsnAxis
&
b
)
{
return
a
.
extent
>
b
.
extent
;
};
if
(
order
)
{
axis_arr
.
sort
(
up_compare
);
}
else
{
axis_arr
.
sort
(
down_compare
);
}
}
/// Reverse axis_list_ and return the first axis usable as vector axis.
///
/// An axis qualifies when none of its strides is greater than 1 and not all of its
/// strides are 0 (an axis with an empty stride_list does not qualify).
///
/// \return iterator to the axis, or axis_list_.end() when no axis qualifies.
///         Note that the list stays reversed after the call.
AxisIt InsnArgsCalculator::GetVecAxisIt() {
  axis_list_.reverse();
  auto is_vec_axis = [](const InsnAxis &axis) {
    bool all_zero = true;
    for (int stride : axis.stride_list) {
      if (stride > 1) {
        return false;
      }
      if (stride != 0) {
        all_zero = false;
      }
    }
    return !all_zero;
  };
  return GetAxisByLambda(is_vec_axis);
}
/// Split axis into an inner part of the given extent and an outer axis.
///
/// \param extent requested inner extent
/// \param axis   axis to split; on SUCCESS its extent is set to the requested extent
/// \return NO_SPLIT when axis.extent <= extent (nothing changes),
///         TAIL when axis.extent is not a multiple of extent (nothing changes),
///         SUCCESS otherwise
///
/// On SUCCESS a new outer axis with extent axis.extent / extent, strides multiplied by
/// extent and a fresh Var carrying the same name hint is appended to axis_list_.
SplitStat InsnArgsCalculator::SplitAxis(int extent, InsnAxis &axis) {
  if (axis.extent <= extent) {
    return NO_SPLIT;
  }
  if (axis.extent % extent != 0) {
    return TAIL;
  }

  InsnAxis outer;
  outer.extent = axis.extent / extent;
  for (auto inner_stride : axis.stride_list) {
    outer.stride_list.push_back(inner_stride * extent);
  }

  // First stride belongs to the destination, the rest to the sources.
  auto remaining = outer.stride_list;
  CHECK(!remaining.empty());
  outer.dst_stride = remaining.front();
  remaining.erase(remaining.begin());
  outer.src_stride_list = remaining;

  outer.var = Var(axis.var->name_hint);
  axis_list_.push_back(outer);
  axis.extent = extent;
  return SUCCESS;
}
/// Pick the axis that should be mapped onto the block dimension.
///
/// axis_list_ is sorted by ascending extent first. The candidates are then tried in
/// this order, each search returning the first matching axis:
///   1. aligned strides, m0 limit, block stride limit, extent a non-zero multiple of
///      FULL_BLOCK_NUM;
///   2. same stride conditions with extent >= FULL_BLOCK_NUM, accepted only when
///      axis_list_ holds exactly one axis;
///   3. (after reversing axis_list_) same stride conditions with extent < FULL_BLOCK_NUM;
///   4. aligned strides and m0 limit only, with extent <= FULL_BLOCK_NUM or a multiple
///      of FULL_BLOCK_NUM.
///
/// \return iterator to the chosen axis, or axis_list_.end() when nothing matches
AxisIt InsnArgsCalculator::GetBlockAxis() {
  using AxisPred = std::function<bool(const InsnAxis &)>;

  AxisSort(axis_list_);
  const AxisPred aligned = GetStrideLambda();
  const AxisPred m0_ok = GetM0LimitLambda();
  const AxisPred block_stride_ok = GetBlockStrideLimitLambda();

  const AxisPred whole_blocks = [&](const InsnAxis &axis) {
    return axis.extent >= FULL_BLOCK_NUM && axis.extent % FULL_BLOCK_NUM == 0;
  };
  const AxisPred at_least_one_block = [&](const InsnAxis &axis) { return axis.extent >= FULL_BLOCK_NUM; };
  const AxisPred partial_block = [&](const InsnAxis &axis) { return axis.extent < FULL_BLOCK_NUM; };
  const AxisPred fits_or_whole = [&](const InsnAxis &axis) {
    return axis.extent <= FULL_BLOCK_NUM || axis.extent % FULL_BLOCK_NUM == 0;
  };

  // Each And(...) is built inside the GetAxisByLambda call so the combined predicate is
  // consumed within the same expression as the predicate list it was built from.
  AxisIt found = GetAxisByLambda(And({aligned, m0_ok, block_stride_ok, whole_blocks}));
  if (IsValid(found)) {
    return found;
  }

  found = GetAxisByLambda(And({aligned, m0_ok, block_stride_ok, at_least_one_block}));
  if (IsValid(found) && axis_list_.size() == 1) {
    return found;
  }

  axis_list_.reverse();
  found = GetAxisByLambda(And({aligned, m0_ok, block_stride_ok, partial_block}));
  if (IsValid(found)) {
    return found;
  }

  return GetAxisByLambda(And({aligned, m0_ok, fits_or_whole}));
}
/// Pick the axis that should be mapped onto the repeat dimension.
///
/// axis_list_ is sorted by ascending extent. The first axis with aligned strides that
/// satisfies the m1 limit and has extent >= MAX_REPEAT - 1 is preferred; when there is
/// none, the list is reversed and the first axis satisfying only the stride and m1
/// conditions is returned.
///
/// \return iterator to the chosen axis, or axis_list_.end() when nothing matches
AxisIt InsnArgsCalculator::GetRepeatAxisIt() {
  using AxisPred = std::function<bool(const InsnAxis &)>;

  AxisSort(axis_list_);
  const AxisPred aligned = GetStrideLambda();
  const AxisPred m1_ok = GetM1LimitLambda();
  const AxisPred fills_repeat = [&](const InsnAxis &axis) { return axis.extent >= MAX_REPEAT - 1; };

  AxisIt found = GetAxisByLambda(And({aligned, m1_ok, fills_repeat}));
  if (IsValid(found)) {
    return found;
  }

  axis_list_.reverse();
  return GetAxisByLambda(And({aligned, m1_ok}));
}
void
InsnArgsCalculator
::
SetArgMask
(
int
len
)
{
SetArgBlockNum
(
1
);
SetArgBlockLen
(
len
);
}
void
InsnArgsCalculator
::
SetArgBlockNum
(
int
data_num
)
{
arg_
.
block_num
=
data_num
;
}
void
InsnArgsCalculator
::
SetArgBlockLen
(
int
data_len
)
{
arg_
.
block_len
=
data_len
;
}
void
InsnArgsCalculator
::
SetArgM0
(
int
dst_m0
,
int
lsrc_m0
,
int
rsrc_m0
=
0
)
{
arg_
.
dst_m0
=
dst_m0
;
arg_
.
src_m0_list
=
{
lsrc_m0
,
rsrc_m0
};
}
void
InsnArgsCalculator
::
SetArgM1
(
int
dst_m1
,
int
lsrc_m1
,
int
rsrc_m1
=
0
)
{
arg_
.
dst_m1
=
dst_m1
;
arg_
.
src_m1_list
=
{
lsrc_m1
,
rsrc_m1
};
}
void
InsnArgsCalculator
::
SetArgRepeat
(
int
repeat
)
{
arg_
.
repeat
=
repeat
;
}
/// Select a block axis (GetBlockAxis), remove it from axis_list_ and derive the block
/// arguments (m0 strides and block number) from it. Does nothing besides logging when
/// no block axis is found.
void InsnArgsCalculator::BlockAxisReduction() {
  Print(axis_list_);
  auto block_axis_it = GetBlockAxis();
  if (IsValid(block_axis_it)) {
    // Keep an untouched copy so the axis can be restored by the rollback path below.
    auto origin_block_axis = *block_axis_it;
    InsnAxis block_axis = ExtractAxis(block_axis_it);
    if (block_axis.extent % FULL_BLOCK_NUM != 0 && block_axis.extent > FULL_BLOCK_NUM) {
      // Extent is larger than FULL_BLOCK_NUM but not a multiple of it: record the
      // remainder as tail and continue with the extent adjusted by FloorTo.
      arg_.tail_len = block_axis.extent % FULL_BLOCK_NUM;
      block_axis.extent = FloorTo(block_axis.extent, FULL_BLOCK_NUM);
      // Tail offsets are stride * adjusted extent, for the destination and every source.
      arg_.dst_tail_offset = block_axis.dst_stride * block_axis.extent;
      for (auto stride : block_axis.src_stride_list) {
        arg_.src_tail_offset_list.push_back(stride * block_axis.extent);
      }
      SplitAxis(FULL_BLOCK_NUM, block_axis);
      auto repeat_axis_it = GetRepeatAxisIt();
      if (!IsValid(repeat_axis_it) && axis_list_.size() > 0) {
        // Rollback: no repeat axis is available. Drop the first axis whose name hint
        // matches the block axis (presumably the outer axis appended by SplitAxis) and
        // put the original block axis back.
        // NOTE(review): arg_.tail_len and the tail offsets set above are not reset on
        // this path - confirm that ExportResult emitting tail args here is intended.
        for (auto it = axis_list_.begin(); it != axis_list_.end(); it++) {
          if (it->var->name_hint == block_axis.var->name_hint) {
            axis_list_.erase(it);
            break;
          }
        }
        axis_list_.push_back(origin_block_axis);
        return;
      }
    } else {
      SplitAxis(FULL_BLOCK_NUM, block_axis);
    }
    block_axis.Print("BLOCK_AXIS");
    // m0 strides are expressed in units of blocks; the first and the last source stride
    // are used as left/right source.
    SetArgM0(block_axis.dst_stride / meta_.block_size, block_axis.src_stride_list.front() / meta_.block_size,
             block_axis.src_stride_list.back() / meta_.block_size);
    SetArgBlockNum(block_axis.extent);
  }
}
/// Select a repeat axis (GetRepeatAxis) and derive the repeat arguments from it:
/// m1 strides in units of the destination/source block size, and the repeat count from
/// the axis extent. Does nothing besides logging when no repeat axis is available.
void InsnArgsCalculator::RepeatAxisReduction() {
  Print(axis_list_);
  auto repeat_axis = GetRepeatAxis();
  if (!repeat_axis.IsValid()) {
    return;
  }
  repeat_axis.Print("REPEAT_AXIS");
  const auto dst_m1 = repeat_axis.dst_stride / meta_.dst_block_size;
  const auto lsrc_m1 = repeat_axis.src_stride_list.front() / meta_.src_block_size;
  const auto rsrc_m1 = repeat_axis.src_stride_list.back() / meta_.src_block_size;
  SetArgM1(dst_m1, lsrc_m1, rsrc_m1);
  SetArgRepeat(repeat_axis.extent);
}
/// \return a default-constructed axis whose is_valid flag is cleared.
InsnAxis InsnArgsCalculator::GetInvalidAxis() {
  InsnAxis invalid_axis;
  invalid_axis.is_valid = false;
  return invalid_axis;
}
/// Extract the repeat axis from axis_list_ and split it at MAX_REPEAT - 1.
///
/// \return the extracted axis, or an invalid axis (see GetInvalidAxis) when
///         GetRepeatAxisIt finds no candidate
InsnAxis InsnArgsCalculator::GetRepeatAxis() {
  auto candidate_it = GetRepeatAxisIt();
  if (!IsValid(candidate_it)) {
    return GetInvalidAxis();
  }
  InsnAxis repeat_axis = ExtractAxis(candidate_it);
  SplitAxis(MAX_REPEAT - 1, repeat_axis);
  return repeat_axis;
}
/// Argument reduction used when destination and source block sizes differ.
///
/// The smaller of the two block sizes is used as working block size. When a vector
/// axis exists it is extracted and handled by extent:
///   - extent >= block size * FULL_BLOCK_NUM: the axis is split at that length first,
///     then handled like the next case;
///   - extent > block size: a mask of the extent rounded up to whole blocks is set and
///     all m0 strides become 1;
///   - otherwise only the block length is set to the block size.
/// The repeat axis is reduced afterwards in every case. Nothing happens when
/// axis_list_ is empty.
void InsnArgsCalculator::CastCaseReduction() {
  if (axis_list_.empty()) {
    return;
  }
  Print(axis_list_);

  int cast_block_size = meta_.src_block_size;
  if (meta_.dst_block_size < meta_.src_block_size) {
    cast_block_size = meta_.dst_block_size;
  }

  auto vec_axis_it = GetVecAxisIt();
  if (IsValid(vec_axis_it)) {
    InsnAxis vec_axis = ExtractAxis(vec_axis_it);
    const int max_vec_len = cast_block_size * FULL_BLOCK_NUM;
    const bool needs_split = vec_axis.extent >= max_vec_len;
    if (needs_split || vec_axis.extent > cast_block_size) {
      if (needs_split) {
        SplitAxis(max_vec_len, vec_axis);
      }
      SetArgMask(DivFloor(vec_axis.extent, cast_block_size) * cast_block_size);
      SetArgM0(1, 1, 1);
    } else {
      SetArgBlockLen(cast_block_size);
    }
  }
  RepeatAxisReduction();
}
/// Integer division that adds 1 to the truncated quotient whenever a % b is non-zero.
/// Despite the name this rounds up for positive operands (e.g. DivFloor(11, 5) == 3).
///
/// \param a dividend
/// \param b divisor; must not be 0
int DivFloor(int a, int b) {
  const int quotient = a / b;
  const bool has_remainder = (a % b != 0);
  return has_remainder ? quotient + 1 : quotient;
}
/// Generic argument reduction (vector axis, then block axis, then repeat axis).
///
/// meta_.scalar is set when no vector axis exists. With a vector axis:
///   - if its extent lies strictly between block_size and block_size * FULL_BLOCK_NUM
///     and is either not a multiple of block_size or larger than
///     max_vec_len * meta_.vec_rate, the axis is consumed as a mask (extent rounded up
///     to whole blocks, m0 strides 1) and no block axis is reduced;
///   - otherwise the axis is split at block_size, the block length is set to block_size
///     and a block axis is reduced.
/// Without a vector axis a block axis is reduced. The repeat axis is reduced last in
/// all cases. Nothing happens when axis_list_ is empty.
void InsnArgsCalculator::InsnReduction() {
  if (axis_list_.empty()) {
    return;
  }
  Print(axis_list_);

  auto vec_axis_it = GetVecAxisIt();
  meta_.scalar = !IsValid(vec_axis_it);

  bool reduce_block_axis = true;
  if (!meta_.scalar) {
    InsnAxis vec_axis = ExtractAxis(vec_axis_it);
    int max_vec_len = meta_.block_size * FULL_BLOCK_NUM;
    const bool mask_only =
      vec_axis.extent > meta_.block_size && vec_axis.extent < max_vec_len &&
      (vec_axis.extent % meta_.block_size != 0 || vec_axis.extent > max_vec_len * meta_.vec_rate);
    if (mask_only) {
      vec_axis.Print("VEC_BLOCK_AXIS");
      SetArgMask(DivFloor(vec_axis.extent, meta_.block_size) * meta_.block_size);
      SetArgM0(1, 1, 1);
      reduce_block_axis = false;
    } else {
      SplitAxis(meta_.block_size, vec_axis);
      vec_axis.Print("VEC_AXIS");
      SetArgBlockLen(meta_.block_size);
    }
  }

  if (reduce_block_axis) {
    BlockAxisReduction();
  }
  RepeatAxisReduction();
  Print(axis_list_);
}
/// Build the offset expression contributed by the axes still left in axis_list_.
///
/// \param stride_index position inside each axis' stride_list (0 is the destination);
///                     no bounds check is performed
/// \return the simplified sum of stride * loop variable over all remaining axes
Expr InsnArgsCalculator::GetOffset(int stride_index) {
  Expr offset = Expr(0);
  for (const auto &axis : axis_list_) {
    offset = Add::make(Mul::make(axis.stride_list[stride_index], axis.var), offset);
  }
  return Simplify(offset);
}
/// Rebuild the loop nest from the axes left in axis_list_.
///
/// When for_info_ has no loops it is returned unchanged. Otherwise the body of the last
/// recorded loop is wrapped, axis by axis in list order, into new For statements that
/// reuse the for_type and device_api of that last loop; each new loop and its variable
/// are appended to the result.
StmtInfo InsnArgsCalculator::ExportForInfo() {
  if (for_info_.ops_.empty()) {
    return for_info_;
  }

  const auto *innermost = for_info_.ops_[for_info_.ops_.size() - 1].as<For>();
  Stmt nest = innermost->body;

  StmtInfo exported;
  for (const auto &axis : axis_list_) {
    nest = For::make(axis.var, axis.min, axis.extent, innermost->for_type, innermost->device_api, nest);
    exported.ops_.push_back(nest);
    exported.vars_.push_back(axis.var);
  }
  return exported;
}
/// Pack the computed arguments into a PatternResult.
///
/// - Body arguments are always produced from arg_ and meta_.
/// - Tail arguments are produced only when arg_.tail_len > 0; they reuse the body
///   strides, use repeat 1, the recorded tail offsets as heads and a mask built from
///   tail_len.
/// - The store infos (destination first) are copied, cleaned with CleanZeroStrides and
///   get their insn_offset_ from the axes left in axis_list_.
/// - At most two source infos are exported.
PatternResult InsnArgsCalculator::ExportResult() {
  PatternResult res;
  auto arg_info = ArgInfo(make_node<ArgInfoNode>());

  auto body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
  auto body_node = body_args.GetNode();
  body_node->body_num_ = arg_.body_num;
  body_node->body_offset_ = meta_.block_size * FULL_BLOCK_NUM;
  body_node->repeat_ = Expr(arg_.repeat);
  body_node->dst_stride_m0_ = Expr(arg_.dst_m0);
  body_node->dst_stride_m1_ = Expr(arg_.dst_m1);
  body_node->src_stride_m0_list_ = arg_.src_m0_list;
  body_node->src_stride_m1_list_ = arg_.src_m1_list;
  body_node->vec_mask_ = GetVecMask(arg_.block_len, arg_.block_num, meta_.dtype, meta_.block_offset);
  body_node->block_offset_ = make_const(Int(32), meta_.block_offset);
  arg_info.GetNode()->body_arg_info_ = body_args;

  if (arg_.tail_len > 0) {
    auto tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto tail_node = tail_args.GetNode();
    tail_node->dst_head_ = Expr(arg_.dst_tail_offset);
    tail_node->src_head_list_ = arg_.src_tail_offset_list;
    tail_node->repeat_ = Expr(1);
    tail_node->body_offset_ = meta_.block_size * FULL_BLOCK_NUM;
    tail_node->dst_stride_m0_ = Expr(arg_.dst_m0);
    tail_node->dst_stride_m1_ = Expr(arg_.dst_m1);
    tail_node->src_stride_m0_list_ = arg_.src_m0_list;
    tail_node->src_stride_m1_list_ = arg_.src_m1_list;
    tail_node->vec_mask_ = GetVecMask(arg_.block_len, arg_.tail_len, meta_.dtype, meta_.block_offset);
    tail_node->block_offset_ = make_const(Int(32), meta_.block_offset);
    arg_info.GetNode()->tail_arg_info_ = tail_args;
  }

  StmtInfoList info_list = GetInfoList(dst_info_, src_info_list_);
  CleanZeroStrides(info_list);
  for (size_t idx = 0; idx < info_list.size(); ++idx) {
    info_list[idx].GetNode()->insn_offset_ = GetOffset(static_cast<int>(idx));
  }
  info_list[1].Print();

  res.for_info = ExportForInfo();
  res.arg_info = arg_info;
  res.dst_info_list = {info_list[0]};
  if (info_list.size() <= 2) {
    res.src_info_list = {info_list[1]};
  } else {
    res.src_info_list = {info_list[1], info_list[2]};
  }

  body_args.Print();
  if (arg_info->tail_arg_info_.defined()) {
    arg_info->tail_arg_info_.Print();
  }
  return res;
}
/// Forward all arguments unchanged to the InsnArgsCalculator constructor; no additional
/// state is set up here.
SingleVecInsnArgsCalculator::SingleVecInsnArgsCalculator(const StmtInfoList &dst_info_list,
                                                         const StmtInfoList &src_info_list, const StmtInfo &for_info,
                                                         const std::string &intrin_name)
    : InsnArgsCalculator(dst_info_list, src_info_list, for_info, intrin_name) {}
/// Run the reduction matching the instruction kind and export the result:
/// CastCaseReduction when meta_.cast is set, InsnReduction otherwise.
PatternResult SingleVecInsnArgsCalculator::GetInsnArgs() {
  if (!meta_.cast) {
    InsnReduction();
  } else {
    CastCaseReduction();
  }
  return ExportResult();
}
/// Construct the calculator for a two-source vector instruction.
///
/// \param mode        stored in mode_; the value "reduction" enables the swap below
/// \param expand_mask stored in expand_mask_
///
/// In "reduction" mode with exactly two sources, when the first source has the same
/// name as the first destination, the two sources are swapped and CalAxis is run again
/// so that meta data and axes reflect the new order.
BinaryVecInsnArgsCalculator::BinaryVecInsnArgsCalculator(const StmtInfoList &dst_info_list,
                                                         const StmtInfoList &src_info_list, const StmtInfo &for_info,
                                                         const std::string &mode, const std::string &intrin_name,
                                                         bool expand_mask)
    : InsnArgsCalculator(dst_info_list, src_info_list, for_info, intrin_name), mode_{mode}, expand_mask_{expand_mask} {
  const bool first_src_is_dst = mode_ == "reduction" && src_info_list_.size() == 2 &&
                                src_info_list_[0]->name_ == dst_info_list[0]->name_;
  if (!first_src_is_dst) {
    return;
  }
  auto former_first = src_info_list_[0].Copy();
  auto former_second = src_info_list_[1].Copy();
  src_info_list_.Set(0, former_second);
  src_info_list_.Set(1, former_first);
  CalAxis();
}
/// Run the binary-vector reduction and export the result.
PatternResult BinaryVecInsnArgsCalculator::GetInsnArgs() {
  LOG(DEBUG) << "Binary vec Insn reduction";
  this->InsnReduction();
  PatternResult exported = this->ExportResult();
  return exported;
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
BinaryVecInsnArgsCalculator
::
GetM0LimitLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_limit
=
[
&
](
int
stride
)
{
return
stride
/
meta_
.
block_size
<
MAX_STRIDE_M0
;
};
return
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
is_limit
)
&&
axis
.
dst_stride
!=
0
;
};
}
std
::
function
<
bool
(
const
InsnAxis
&
)
>
BinaryVecInsnArgsCalculator
::
GetM1LimitLambda
()
{
return
[
&
](
const
InsnAxis
&
axis
)
{
auto
is_limit
=
[
&
](
int
stride
)
{
return
stride
/
meta_
.
src_block_size
<
MAX_STRIDE_M1
;
};
return
axis
.
dst_stride
/
meta_
.
dst_block_size
<
MAX_STRIDE_M1
&&
std
::
all_of
(
axis
.
src_stride_list
.
begin
(),
axis
.
src_stride_list
.
end
(),
is_limit
);
};
}
void
BinaryVecInsnArgsCalculator
::
InsnReduction
()
{
if
(
axis_list_
.
empty
())
{
return
;
}
Print
(
axis_list_
);
auto
vec_axis_it
=
GetVecAxisIt
();
meta_
.
scalar
=
!
IsValid
(
vec_axis_it
);
if
(
!
meta_
.
scalar
)
{
vec_axis_
=
*
vec_axis_it
;
InsnAxis
vec_axis
=
ExtractAxis
(
vec_axis_it
);
auto
bad_axis_lambda
=
[
&
](
const
InsnAxis
&
axis
)
{
int
min_stride
=
vec_axis_it
->
extent
;
auto
dst_name
=
dst_info_list_
[
0
]
->
name_
;
if
(
meta_
.
same_dst_src
&&
axis
.
dst_stride
<
min_stride
&&
axis
.
dst_stride
!=
0
)
{
return
true
;
}
return
false
;
};
auto
bad_axis_it
=
GetAxisByLambda
(
bad_axis_lambda
);
InsnAxis
bad_axis
;
bad_axis
.
is_valid
=
false
;
if
(
IsValid
(
bad_axis_it
))
{
bad_axis
=
ExtractAxis
(
bad_axis_it
);
}
int
max_vec_len
=
meta_
.
block_size
*
FULL_BLOCK_NUM
;
if
(
vec_axis
.
extent
>
meta_
.
block_size
&&
vec_axis
.
extent
<
max_vec_len
&&
(
vec_axis
.
extent
%
meta_
.
block_size
!=
0
||
vec_axis
.
extent
>
max_vec_len
*
meta_
.
vec_rate
))
{
vec_axis
.
Print
(
"VEC_BLOCK_AXIS"
);
if
(
expand_mask_
)
{
SetArgMask
(
DivFloor
(
vec_axis
.
extent
,
meta_
.
block_size
)
*
meta_
.
block_size
);
}
else
{
SetArgMask
(
vec_axis
.
extent
);
}
SetArgM0
(
1
,
1
,
1
);
}
else
{
SplitAxis
(
meta_
.
block_size
,
vec_axis
);
vec_axis
.
Print
(
"VEC_AXIS"
);
if
(
expand_mask_
&&
mode_
!=
"reduction"
)
{
SetArgBlockLen
(
meta_
.
block_size
);
}
else
{
SetArgBlockLen
(
vec_axis
.
extent
);
}
BlockAxisReduction
();
}
RepeatAxisReduction
();
if
(
bad_axis
.
IsValid
())
{
axis_list_
.
push_back
(
bad_axis
);
}
}
else
{
BlockAxisReduction
();
RepeatAxisReduction
();
}
Print
(
axis_list_
);
}
/// Compute the arguments for a last-axis reduction: derive the parameters, build the
/// body/tail/mix argument sets, mark the pattern as PATTERN_1D and assemble the result.
PatternResult LastAxisReduceInsnArgsCalculator::GetInsnArgs() {
  CalcParams();
  const Array<Var> elim_var = GetPattern();
  arg_info.GetNode()->pattern_ = PATTERN_1D;
  return GenResult(elim_var);
}
/// Build the argument sets for reducing the last axis.
///
/// The last dimension is divided into a body (whole multiples of params.vec_max_len)
/// and a tail (the remainder). body_args and tail_args are filled when the respective
/// part exists. cmd_body_len tracks how many partial results the emitted instructions
/// leave behind; while more than one is left, further argument sets are appended to
/// mix_vec_arg_list. For "vadd" every instruction repeat leaves one result; for other
/// intrinsics the count is tracked in units of FULL_BLOCK_NUM per repeat and blocks of
/// 16 elements for tails.
///
/// Finally params.insn_offset_scale_factor is computed from the number of body repeats.
///
/// \return GetRange(params.src_var, -1, 1) when params.src_var is not empty,
///         otherwise an empty array
Array<Var> LastAxisReduceInsnArgsCalculator::GetPattern() {
  const int tail_len = params.last_dim_shape % params.vec_max_len;
  const int body_len = params.last_dim_shape - tail_len;
  const bool is_vadd = intrin_name == "vadd";
  const int repeat_stride = is_vadd ? 1 : FULL_BLOCK_NUM;
  const int fp16_block_size = 16;
  int cmd_body_len = 0;

  // Number of results a tail of the given length leaves behind.
  auto tail_result_len = [is_vadd, fp16_block_size](int len) {
    if (is_vadd) {
      return 1;
    }
    return len / fp16_block_size + (len % fp16_block_size != 0 ? 1 : 0);
  };

  // Argument set for one follow-up instruction over the partial results.
  auto make_mix_args = [&](int repeat, int dst_head, int src_head, int mask_len) {
    VectorArgInfo mix_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto mix_node = mix_args.GetNode();
    mix_node->repeat_ = Expr(repeat);
    mix_node->dst_head_ = Expr(dst_head);
    mix_node->src_head_list_ = {Expr(src_head)};
    mix_node->dst_stride_m1_ = Expr(1);
    mix_node->src_stride_m0_list_ = {Expr(1)};
    mix_node->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
    mix_node->vec_mask_ = GetVecMask(mask_len, 1, dst_info->dtype_);
    return mix_args;
  };

  if (body_len > 0) {
    body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto body_node = body_args.GetNode();
    body_node->body_num_ = 1;
    body_node->body_offset_ = params.vec_max_len;
    body_node->repeat_ = Expr(body_len / params.vec_max_len);
    // dst_stride_m1 doubles as the destination stride here.
    body_node->dst_stride_m1_ = Expr(1);
    body_node->src_stride_m0_list_ = {Expr(1)};
    body_node->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
    body_node->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
    cmd_body_len += GetInt32Const(body_args->repeat_) * repeat_stride;
  }

  if (tail_len > 0) {
    tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto tail_node = tail_args.GetNode();
    tail_node->body_offset_ = params.vec_max_len;
    tail_node->dst_head_ = Expr(cmd_body_len);
    tail_node->src_head_list_ = {Expr(body_len)};
    tail_node->repeat_ = Expr(1);
    tail_node->dst_stride_m1_ = Expr(1);
    tail_node->src_stride_m0_list_ = {Expr(1)};
    tail_node->src_stride_m1_list_ = {Expr(0)};
    tail_node->vec_mask_ = GetVecMask(tail_len, 1, dst_info->dtype_);
    cmd_body_len += tail_result_len(tail_len);
  }

  // More than one partial result is left: keep emitting reduction steps over the
  // partial results until a single one remains.
  while (cmd_body_len > 1) {
    const int cmd_tail_len = cmd_body_len % params.vec_max_len;
    cmd_body_len /= params.vec_max_len;
    if (cmd_body_len > 0) {
      mix_vec_arg_list.push_back(make_mix_args(cmd_body_len, 0, 0, params.vec_max_len));
      if (!is_vadd) {
        cmd_body_len *= FULL_BLOCK_NUM;
      }
    }
    if (cmd_tail_len > 0) {
      const int src_head =
        is_vadd ? cmd_body_len * params.vec_max_len : cmd_body_len / FULL_BLOCK_NUM * params.vec_max_len;
      mix_vec_arg_list.push_back(make_mix_args(1, cmd_body_len, src_head, cmd_tail_len));
      cmd_body_len += tail_result_len(cmd_tail_len);
    }
  }

  params.insn_offset_scale_factor = Expr(params.block_size);
  int max_num = body_len / params.vec_max_len;
  if (intrin_name == "vmax" || intrin_name == "vmin") {
    max_num *= FULL_BLOCK_NUM;
  }
  if (max_num >= params.block_size) {
    // Round (max_num [+ 1 for a tail]) up to a multiple of block_size.
    params.insn_offset_scale_factor = max_num + params.block_size - 1;
    if (tail_len > 0) {
      params.insn_offset_scale_factor += 1;
    }
    params.insn_offset_scale_factor = truncdiv(params.insn_offset_scale_factor, params.block_size) * params.block_size;
  }

  if (params.src_var.empty()) {
    return {};
  }
  return GetRange(params.src_var, -1, 1);
}
/// Assemble the PatternResult of a last-axis reduction.
///
/// \param elim_var variables passed to GetInsnOffset and removed from for_info via
///                 CleanForInfoVars
///
/// The destination offset is scaled by params.insn_offset_scale_factor, the source
/// offset is not. The same scale factor is written into every defined argument set
/// (body, tail and all mix entries) before they are attached to arg_info.
PatternResult LastAxisReduceInsnArgsCalculator::GenResult(const Array<Var> &elim_var) {
  const auto &scale_factor = params.insn_offset_scale_factor;

  dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, elim_var) * scale_factor;
  src_info.GetNode()->insn_offset_ = GetInsnOffset(src_info, elim_var);

  if (body_args.defined()) {
    body_args.GetNode()->insn_offset_scale_factor_ = scale_factor;
  }
  if (tail_args.defined()) {
    tail_args.GetNode()->insn_offset_scale_factor_ = scale_factor;
  }
  for (auto &mix_args : mix_vec_arg_list) {
    mix_args.GetNode()->insn_offset_scale_factor_ = scale_factor;
  }

  auto info_node = arg_info.GetNode();
  info_node->body_arg_info_ = body_args;
  info_node->tail_arg_info_ = tail_args;
  info_node->reduction_tail_args_ = mix_vec_arg_list;
  CleanForInfoVars(for_info, elim_var);
  info_node->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS;

  PatternResult result;
  result.arg_info = arg_info;
  result.for_info = for_info;
  result.dst_info_list = {dst_info};
  result.src_info_list = {src_info};
  return result;
}
void
LastAxisReduceInsnArgsCalculator
::
CalcParams
()
{
// check shape len
if
(
dst_info
->
shape_
.
empty
()
||
src_info
->
shape_
.
empty
())
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."
;
}
// check data type
if
(
dst_info
->
dtype_
!=
src_info
->
dtype_
)
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type."
;
}
params
.
src_var
=
src_info
->
var_
;
params
.
block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
params
.
last_dim_shape
=
GetInt32Const
(
GetItem
(
src_info
->
shape_
,
-
1
));
params
.
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
CHECK_NE
(
params
.
block_size
,
0
);
CHECK_NE
(
params
.
vec_max_len
,
0
);
}
/// Generate info list for bisection intrin
/// \param dst_info_list
/// \param src_info_list
/// \param for_info
/// \param if_info
/// \param last_axis
/// \param postfix
/// \return
BisectionInfoWrapper
SeparateComInfoToBisectionInfoList
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
StmtInfo
&
if_info
,
bool
last_axis
,
int
postfix
=
0
)
{
CHECK_EQ
(
dst_info_list
.
size
(),
1
);
CHECK_EQ
(
src_info_list
.
size
(),
2
);
BisectionInfoWrapper
wrapper
;
// Separate com_info and for_info
int
compare_idx
=
1
;
int
var_idx
=
-
1
;
var_idx
=
GetBisectionReductionIdx
(
dst_info_list
,
src_info_list
,
compare_idx
);
StmtStoreInfo
dst_info
=
dst_info_list
[
0
];
CHECK_GE
(
compare_idx
,
0
);
StmtStoreInfo
src_info1
=
src_info_list
[
compare_idx
];
Var
reduce_var
=
GetItem
(
src_info1
->
var_
,
var_idx
);
int
stride_len
=
GetInt32Const
(
GetItem
(
src_info1
->
strides_
,
var_idx
));
size_t
for_idx
=
0
;
bool
suc
=
GetIndexOfElement
(
for_info
.
vars_
,
VarExpr
(
reduce_var
),
for_idx
);
CHECK
(
suc
);
auto
exist_for
=
GetItem
(
for_info
.
ops_
,
for_idx
).
as
<
For
>
();
CHECK
(
exist_for
);
int
extent
=
GetInt32Const
(
exist_for
->
extent
);
int
simd_len
=
1
;
const
std
::
string
un_def_var
=
"un_def_var"
;
Var
simd_var
=
Var
(
"un_def_var"
);
CHECK_GT
(
src_info1
->
strides_
.
size
(),
0
);
CHECK_EQ
(
src_info1
->
var_
.
size
(),
src_info1
->
strides_
.
size
());
for
(
size_t
i
=
0
;
i
<=
src_info1
->
strides_
.
size
()
-
1
;
i
++
)
{
if
(
GetInt32Const
(
src_info1
->
strides_
[
i
])
==
1
)
{
simd_var
=
src_info1
->
var_
[
i
];
size_t
simd_for_idx
=
0
;
bool
suc
=
GetIndexOfElement
(
for_info
.
vars_
,
VarExpr
(
simd_var
),
simd_for_idx
);
CHECK
(
suc
);
auto
simd_for
=
GetItem
(
for_info
.
ops_
,
simd_for_idx
).
as
<
For
>
();
CHECK
(
simd_for
);
simd_len
=
GetInt32Const
(
simd_for
->
extent
);
}
}
int
block_unit
=
GetUbBlkSize
(
src_info1
->
dtype_
);
int
last_dim_len
=
((
simd_len
-
1
)
/
block_unit
+
1
)
*
block_unit
;
Var
bisec_var
;
Buffer
bisec_buffer
;
std
::
string
bisec_pre_header
=
"bisec"
;
std
::
string
bisec_name
=
bisec_pre_header
+
"_local_UB"
;
if
(
postfix
>
0
)
{
bisec_name
=
bisec_name
+
"_"
+
std
::
to_string
(
postfix
);
}
int
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
CHECK_NE
(
vec_max_len
,
0
);
std
::
vector
<
int
>
pow2_list
=
{
0
,
1
,
2
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
,
32768
,
65536
};
int
origin_len
=
extent
;
for
(
int
i
:
pow2_list
)
{
if
(
extent
<=
i
)
{
extent
=
i
/
2
;
break
;
}
}
int
prolog_len
=
origin_len
-
extent
;
src_info1
.
Print
();
auto
src_vars
=
src_info1
->
var_
;
auto
src_strides
=
src_info1
->
strides_
;
auto
src_dims
=
src_info1
->
shape_
;
auto
new_vars
=
src_info1
->
var_
;
auto
new_strides
=
src_info1
->
strides_
;
auto
new_dims
=
src_info1
->
shape_
;
LOG
(
DEBUG
)
<<
"
\n
var_idx:"
<<
var_idx
<<
"
\n
"
;
var_idx
=
var_idx
+
src_info1
->
var_
.
size
();
if
(
var_idx
!=
static_cast
<
int
>
(
src_info1
->
var_
.
size
())
-
1
)
{
new_dims
.
Set
(
var_idx
,
extent
);
new_dims
.
Set
(
new_dims
.
size
()
-
1
,
last_dim_len
);
CHECK_GT
(
new_dims
.
size
(),
1
);
new_strides
.
Set
(
new_strides
.
size
()
-
1
,
1
);
for
(
int
i
=
static_cast
<
int
>
(
new_dims
.
size
())
-
2
;
i
>=
0
;
i
--
)
{
new_strides
.
Set
(
i
,
new_strides
[
i
+
1
]
*
new_dims
[
i
+
1
]);
}
new_dims
.
Set
(
new_dims
.
size
()
-
1
,
simd_len
);
}
else
{
new_dims
=
{
extent
};
}
// copy data from origin buffer to new temp buffer
Array
<
Expr
>
shape
=
new_dims
;
wrapper
.
original_shape_
=
new_dims
;
bisec_var
=
Var
(
bisec_name
,
Handle
());
bisec_buffer
=
BufferNode
::
make
(
bisec_var
,
dst_info
->
dtype_
,
shape
,
Array
<
Expr
>
(),
Expr
(),
bisec_name
,
SCOPE_UBUF
,
0
,
0
,
BufferType
::
kDefault
);
// Need to copy input to bisect buffer
StmtStoreInfo
copy_dst_info
{
src_info1
.
Copy
()};
StmtStoreInfo
copy_src_info
{
src_info1
.
Copy
()};
StmtInfoList
src_list
=
{
copy_src_info
};
auto
for_tmp_info
=
for_info
.
Copy
();
auto
new_for
=
GetItem
(
for_tmp_info
.
ops_
,
for_idx
).
as
<
For
>
();
CHECK
(
new_for
);
SetItem
(
for_tmp_info
.
ops_
,
static_cast
<
int
>
(
for_idx
),
For
::
make
(
new_for
->
loop_var
,
new_for
->
min
,
extent
,
new_for
->
for_type
,
new_for
->
device_api
,
new_for
->
body
));
ReplaceVarWithNewForInfo
(
copy_dst_info
,
for_info
,
for_tmp_info
);
ReplaceVarWithNewForInfo
(
copy_src_info
,
for_info
,
for_tmp_info
);
SetItem
(
copy_src_info
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
extent
));
SetItem
(
copy_dst_info
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
extent
));
SetItem
(
copy_dst_info
.
GetNode
()
->
strides_
,
var_idx
,
Expr
(
last_dim_len
));
if
(
simd_var
->
name_hint
!=
un_def_var
)
{
copy_dst_info
.
GetNode
()
->
index_
=
0
;
for
(
size_t
i
=
0
;
i
<=
new_vars
.
size
()
-
1
;
i
++
)
{
copy_dst_info
.
GetNode
()
->
index_
+=
new_vars
[
i
]
*
new_strides
[
i
];
}
}
else
{
copy_dst_info
.
GetNode
()
->
index_
=
last_dim_len
*
reduce_var
;
}
copy_dst_info
.
GetNode
()
->
elem_offset_
=
0
;
copy_dst_info
.
GetNode
()
->
name_
=
bisec_name
;
copy_dst_info
.
GetNode
()
->
buffer_
=
bisec_buffer
;
copy_dst_info
.
GetNode
()
->
data_
=
bisec_var
;
copy_dst_info
.
GetNode
()
->
strides_
=
new_strides
;
CompactComputationInfoList
(
copy_dst_info
,
src_list
,
if_info
,
for_tmp_info
);
wrapper
.
bisec_info_list_
.
emplace_back
(
StmtInfoList
{
copy_dst_info
,
copy_src_info
});
wrapper
.
for_info_list_
.
push_back
(
for_tmp_info
);
// Generate the vadd wrapper
while
(
extent
>=
0
)
{
StmtStoreInfo
dst_tmp_info
=
dst_info
.
Copy
();
StmtStoreInfo
src_tmp_info0
{
src_info1
.
Copy
()};
StmtStoreInfo
src_tmp_info1
{
src_info1
.
Copy
()};
auto
for_tmp_info
=
for_info
.
Copy
();
int
vadd_length
=
(
prolog_len
!=
0
)
?
prolog_len
:
extent
;
if
(
extent
>
0
)
{
dst_tmp_info
=
src_info1
.
Copy
();
dst_tmp_info
.
GetNode
()
->
data_alignment_
=
simd_len
;
dst_tmp_info
.
GetNode
()
->
name_
=
bisec_name
;
dst_tmp_info
.
GetNode
()
->
buffer_
=
bisec_buffer
;
dst_tmp_info
.
GetNode
()
->
data_
=
bisec_var
;
dst_tmp_info
.
GetNode
()
->
shape_
=
new_dims
;
SetItem
(
dst_tmp_info
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
vadd_length
));
dst_tmp_info
.
GetNode
()
->
strides_
=
new_strides
;
dst_tmp_info
.
GetNode
()
->
var_
=
new_vars
;
dst_tmp_info
.
GetNode
()
->
index_
=
0
;
for
(
size_t
i
=
0
;
i
<=
new_vars
.
size
()
-
1
;
i
++
)
{
dst_tmp_info
.
GetNode
()
->
index_
+=
new_vars
[
i
]
*
new_strides
[
i
];
}
if
(
prolog_len
==
0
)
{
src_tmp_info1
=
dst_tmp_info
.
Copy
();
src_tmp_info1
.
GetNode
()
->
index_
=
dst_tmp_info
.
GetNode
()
->
index_
+
extent
*
last_dim_len
;
}
else
{
SetItem
(
src_tmp_info1
.
GetNode
()
->
shape_
,
var_idx
,
Expr
(
vadd_length
));
src_tmp_info1
.
GetNode
()
->
index_
+=
extent
*
stride_len
;
}
}
src_tmp_info0
=
dst_tmp_info
.
Copy
();
auto
new_for
=
GetItem
(
for_tmp_info
.
ops_
,
for_idx
).
as
<
For
>
();
CHECK
(
new_for
);
int
temp_for_len
=
(
vadd_length
!=
0
)
?
vadd_length
:
1
;
SetItem
(
for_tmp_info
.
ops_
,
static_cast
<
int
>
(
for_idx
),
For
::
make
(
new_for
->
loop_var
,
new_for
->
min
,
temp_for_len
,
new_for
->
for_type
,
new_for
->
device_api
,
new_for
->
body
));
if
(
extent
==
0
)
{
src_tmp_info1
.
GetNode
()
->
name_
=
bisec_name
;
src_tmp_info1
.
GetNode
()
->
buffer_
=
bisec_buffer
;
src_tmp_info1
.
GetNode
()
->
data_
=
bisec_var
;
if
(
simd_var
->
name_hint
!=
un_def_var
)
{
src_tmp_info1
.
GetNode
()
->
shape_
=
RemoveItemAtIndex
(
new_dims
,
var_idx
);
src_tmp_info1
.
GetNode
()
->
strides_
=
RemoveItemAtIndex
(
new_strides
,
var_idx
);
src_tmp_info1
.
GetNode
()
->
var_
=
RemoveItemAtIndex
(
new_vars
,
var_idx
);
src_tmp_info1
.
GetNode
()
->
index_
=
0
;
for
(
size_t
i
=
0
;
i
<=
src_tmp_info1
->
var_
.
size
()
-
1
;
i
++
)
{
src_tmp_info1
.
GetNode
()
->
index_
+=
src_tmp_info1
->
var_
[
i
]
*
src_tmp_info1
->
strides_
[
i
];
}
}
else
{
src_tmp_info1
.
GetNode
()
->
shape_
=
dst_tmp_info
->
shape_
;
src_tmp_info1
.
GetNode
()
->
strides_
=
dst_tmp_info
->
strides_
;
src_tmp_info1
.
GetNode
()
->
var_
=
dst_tmp_info
->
var_
;
src_tmp_info1
.
GetNode
()
->
index_
=
dst_tmp_info
->
index_
;
}
}
ReplaceVarWithNewForInfo
(
dst_tmp_info
,
for_info
,
for_tmp_info
);
ReplaceVarWithNewForInfo
(
src_tmp_info0
,
for_info
,
for_tmp_info
);
ReplaceVarWithNewForInfo
(
src_tmp_info1
,
for_info
,
for_tmp_info
);
StmtInfoList
src_list
=
{
src_tmp_info0
,
src_tmp_info1
};
CompactComputationInfoList
(
dst_tmp_info
,
src_list
,
if_info
,
for_tmp_info
);
wrapper
.
for_info_list_
.
emplace_back
(
for_tmp_info
);
if
(
extent
==
0
)
{
// normally is bisect_tmp = bisect_tmp + bisect_tmp/src_tmp
wrapper
.
bisec_info_list_
.
emplace_back
(
StmtInfoList
{
dst_tmp_info
,
dst_tmp_info
,
src_tmp_info1
});
}
else
{
// normally is dst_tmp = dst_tmp + bisect_tmp
wrapper
.
bisec_info_list_
.
emplace_back
(
StmtInfoList
{
dst_tmp_info
,
src_tmp_info0
,
src_tmp_info1
});
}
if
(
extent
==
0
)
{
break
;
}
else
{
extent
=
extent
/
2
;
}
prolog_len
=
0
;
}
// Generate arg_info
for
(
size_t
i
=
0
;
i
<
wrapper
.
bisec_info_list_
.
size
();
++
i
)
{
auto
info_list
=
wrapper
.
bisec_info_list_
[
i
];
auto
new_for_info
=
wrapper
.
for_info_list_
[
i
];
ArgInfo
arg_info
;
auto
dst_list
=
GetRange
(
info_list
,
0
,
1
);
auto
src_list
=
GetRange
(
info_list
,
1
,
info_list
.
size
()
-
1
);
if
(
info_list
.
size
()
==
2
)
{
std
::
string
dma_intrin
=
INTRIN_NAME_COPY_UB_TO_UB
;
wrapper
.
dma_arg_info_map_
=
GetDmaCopyInsnArgs
(
dma_intrin
,
dst_list
,
src_list
,
new_for_info
);
}
else
{
// Bisect can't expand mask because it has inplace operation
if
(
i
!=
wrapper
.
bisec_info_list_
.
size
()
-
1
)
{
// Last round dont need to add
FillLastDim
(
dst_list
,
src_list
,
new_for_info
);
}
std
::
string
mode
=
GetBinaryVecMode
(
dst_list
,
src_list
,
"vadd"
,
false
);
BinaryVecInsnArgsCalculator
args_calculator
=
BinaryVecInsnArgsCalculator
(
dst_list
,
src_list
,
new_for_info
,
mode
,
""
,
false
);
PatternResult
params
=
args_calculator
.
GetInsnArgs
();
arg_info
=
params
.
arg_info
;
dst_list
=
params
.
dst_info_list
;
src_list
=
params
.
src_info_list
;
new_for_info
=
params
.
for_info
;
wrapper
.
bisec_info_list_
[
i
]
=
{
dst_list
[
0
],
src_list
[
0
],
src_list
[
1
]};
}
wrapper
.
arg_info_list_
.
push_back
(
arg_info
);
wrapper
.
for_info_list_
[
i
]
=
new_for_info
;
}
return
wrapper
;
}
/// Compute the instruction arguments for a CCE binary vector intrinsic.
///
/// \param stmt          statement the operand infos are extracted from
/// \param intrin_name   intrinsic name; anything outside the supported set aborts via LOG(FATAL)
/// \param dst_info_list [out] destination infos (filled here, then replaced by the calculator result)
/// \param src_info_list [out] source infos, truncated to the first two entries
/// \param if_info       [out] condition info filled by GetCompactComputationInfo
/// \param for_info      [out] loop info filled by GetCompactComputationInfo, then replaced by the calculator result
/// \param enable_bisect forwarded to GetBinaryVecMode
/// \return intrin args; for the two reduce modes only arg_type_ is set on a fresh ArgInfo
ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
                             StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info, bool enable_bisect) {
  // Only these intrinsics are accepted.
  const std::set<std::string> supported_intrins = {"vadd",  "vmax",   "vmin",   "vmul",  "vdiv",
                                                   "vsel",  "vsub",   "vand",   "vor",   "vaxpy",
                                                   "argmax", "argmin", "vmadd", "vmaddrelu", "vmla"};
  if (supported_intrins.find(intrin_name) == supported_intrins.end()) {
    LOG(FATAL) << "Error: CCE Binary Vector Insn doesn't support the given intrin_name.";
  }

  // Extract dst/src and make sure there is one dst and at least two srcs.
  GetCompactComputationInfo(stmt, dst_info_list, src_info_list, if_info, for_info, true);
  const bool operand_count_ok = dst_info_list.size() == 1 && !(src_info_list.size() < 2);
  if (!operand_count_ok) {
    LOG(FATAL) << "CCE Binary Vector Insn only support ONE dst and TWO srcs.";
  }
  // vmadd/vmaddrelu/vmla carry extra sources; only the first two are needed.
  src_info_list = GetRange(src_info_list, 0, 2);

  ArgInfo arg_info = ArgInfo(make_node<ArgInfoNode>());
  const std::string mode = GetBinaryVecMode(dst_info_list, src_info_list, intrin_name, enable_bisect);

  if (mode == "reduce_last_axis") {
    // Use the larger of the two source ranks.
    const size_t rank0 = src_info_list[0]->var_.size();
    const size_t rank1 = src_info_list[1]->var_.size();
    const size_t src_rank = (rank0 > rank1) ? rank0 : rank1;
    CHECK(src_rank > 0) << "Error: src can not be a scalar.";

    // Exactly one axis may be reduced away.
    if (src_rank - dst_info_list[0]->var_.size() != 1) {
      LOG(FATAL) << "Error: cannot support multi-last-axis reduction.";
    } else {
      arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS;
    }
    return arg_info;
  }

  if (mode == "reduce_bisection") {
    arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_BISECTION;
    return arg_info;
  }

  const bool is_broadcast = (mode == "broadcast");
  if (!(mode == "reduction" || is_broadcast)) {
    FillLastDim(dst_info_list, src_info_list, for_info);
  }

  // vmax/vmin can't expand mask because it may introduce dirty data
  const bool can_expand_mask = !(intrin_name == "vmax" || intrin_name == "vmin");
  BinaryVecInsnArgsCalculator args_calculator =
    BinaryVecInsnArgsCalculator(dst_info_list, src_info_list, for_info, mode, intrin_name, can_expand_mask);
  PatternResult calc_result = args_calculator.GetInsnArgs();
  arg_info = calc_result.arg_info;
  dst_info_list = calc_result.dst_info_list;
  src_info_list = calc_result.src_info_list;
  for_info = calc_result.for_info;

  if (!is_broadcast) {
    return arg_info;
  }

  // A src_index_ other than -1 on the body or tail args marks a last-axis broadcast source.
  const bool body_has_last_axis =
    arg_info->body_arg_info_.defined() && arg_info->body_arg_info_->last_axis_info_.src_index_ != -1;
  const bool tail_has_last_axis =
    arg_info->tail_arg_info_.defined() && arg_info->tail_arg_info_->last_axis_info_.src_index_ != -1;
  if (!(body_has_last_axis || tail_has_last_axis)) {
    return arg_info;
  }
  if (intrin_name != "vadd" && intrin_name != "vmul") {
    return arg_info;
  }

  // Switch vadd/vmul to the "s"-suffixed intrinsic and record which load supplies the operand.
  Array<NodeRef> stores;
  Array<NodeRef> loads;
  GetStoreAndLoads(stmt, stores, loads);
  intrin_name = intrin_name + "s";
  if (arg_info->body_arg_info_.defined()) {
    auto body_node = arg_info.GetNode()->body_arg_info_.GetNode();
    body_node->last_axis_info_.intrin_name_ = intrin_name;
    body_node->last_axis_info_.src_op_ =
      Downcast<Expr>(loads[arg_info->body_arg_info_->last_axis_info_.src_index_]);
  }
  return arg_info;
}
}
// namespace akg
\ No newline at end of file
src/emit_insn/insn_args_calculator.h
0 → 100644
浏览文件 @
0871623a
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef EMIT_INSN_ARGS_CALCULATOR_H_
#define EMIT_INSN_ARGS_CALCULATOR_H_
namespace
akg
{
// Plain bundle of instruction arguments. The initializers below are the values a
// default-constructed InsnArg carries; the three vectors start out empty.
struct InsnArg {
  int dst_m0 = 1;
  int dst_m1 = 0;
  std::vector<Expr> src_m0_list;
  std::vector<Expr> src_m1_list;
  int repeat = 1;
  int block_len = 1;
  int block_num = 1;
  int body_num = 1;
  int tail_len = 0;
  int dst_tail_offset = 0;
  std::vector<Expr> src_tail_offset_list;
};
// Per-instruction metadata: block sizes, dtypes and a set of boolean flags.
// All sizes start at 0 and all flags at false; vec_rate is a fixed constant (0.6).
struct Meta {
  int block_size = 0;
  int src_block_size = 0;
  int dst_block_size = 0;
  int block_offset = 0;
  const float vec_rate = 0.6f;
  Type src_dtype;
  Type dst_dtype;
  Type dtype;
  bool cast = false;
  bool tail = false;
  bool scalar = false;
  bool liner = false;
  bool same_dst_src = false;
};
// Outcome codes returned by InsnArgsCalculator::SplitAxis.
// Values are spelled out; they match the implicit numbering (0, 1, 2).
enum SplitStat {
  SUCCESS = 0,
  NO_SPLIT = 1,
  TAIL = 2,
};
// One loop axis together with the strides that dst/src operands have along it.
// Method bodies are defined outside this header.
class InsnAxis {
 public:
  InsnAxis() = default;
  // Builds the axis from a For node and the operand infos in info_list.
  InsnAxis(const For *for_stmt, const Array<StmtStoreInfo> &info_list);
  virtual ~InsnAxis() = default;

  bool IsValid();
  void Print(const std::string &name = "");

  // Loop range and loop variable.
  int min = 0;
  int extent = 0;
  Var var;

  // Stride bookkeeping; every integer field defaults to 0 and the lists start empty.
  int dst_stride = 0;
  int src_stride = 0;
  std::vector<int> src_stride_list;
  std::vector<int> stride_list;

  // Defaults to true.
  bool is_valid = true;

 private:
  // NOTE(review): presumably looks up the stride paired with obj_var in (vars, strides) - confirm in the .cc.
  Expr GetStrideByAxis(const Array<Var> &vars, const Array<Expr> &strides, Var obj_var);
};
// Iterator type for the std::list<InsnAxis> containers used by the calculators.
using AxisIt = std::list<InsnAxis>::iterator;

// Returns a list of InsnAxis derived from for_info and info_list (defined outside this header).
std::list<InsnAxis> GetAxisList(const StmtInfo &for_info, const Array<StmtStoreInfo> &info_list);

// Returns one info array built from dst_info and src_info_list.
// NOTE(review): presumably dst first, then the sources - confirm against the definition.
Array<StmtStoreInfo> GetInfoList(const StmtStoreInfo &dst_info, const Array<StmtStoreInfo> &src_info_list);

// Integer division helper.
// NOTE(review): the name suggests rounding toward negative infinity - confirm against the definition.
int DivFloor(int a, int b);

// Debug helper operating on a whole axis list.
void Print(std::list<InsnAxis> &axis_list);
// Base class of the instruction-argument calculators. It stores the operand infos, the loop info
// and the intrinsic name, and declares the axis-selection / argument-setting steps that the
// derived calculators reuse or override. All member functions are defined outside this header.
class InsnArgsCalculator {
 public:
  InsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, const StmtInfo &for_info,
                     const std::string &intrin_name);
  virtual ~InsnArgsCalculator() = default;

  PatternResult ExportResult();
  void CalAxis();
  void InitArg();

  // Predicates over InsnAxis; the three virtual ones can be specialised by derived calculators.
  virtual std::function<bool(const InsnAxis &)> GetStrideLambda();
  virtual std::function<bool(const InsnAxis &)> GetM0LimitLambda();
  virtual std::function<bool(const InsnAxis &)> GetM1LimitLambda();
  std::function<bool(const InsnAxis &)> GetBlockStrideLimitLambda();

  // Axis lookup / extraction on axis_list_.
  AxisIt GetAxisByLambda(const std::function<bool(const InsnAxis &)> &lambda);
  InsnAxis ExtractAxis(AxisIt &it);
  bool IsValid(AxisIt &it);
  AxisIt GetVecAxisIt();
  AxisIt GetBlockAxis();
  AxisIt GetRepeatAxisIt();
  InsnAxis GetRepeatAxis();

  // Setters for individual instruction arguments.
  void SetArgMask(int len);
  void SetArgBlockNum(int data_num);
  void SetArgBlockLen(int data_len);
  void SetArgM0(int dst_m0, int lsrc_m0, int rsrc_m0);
  void SetArgM1(int dst_m1, int lsrc_m1, int rsrc_m1);
  void SetArgRepeat(int repeat);

  // Reduction steps; InsnReduction is the overridable entry point.
  void BlockAxisReduction();
  void RepeatAxisReduction();
  void CastCaseReduction();
  virtual void InsnReduction();

  StmtInfo ExportForInfo();
  Expr GetOffset(int stride_index);
  InsnAxis GetInvalidAxis();
  SplitStat SplitAxis(int extent, InsnAxis &axis);

  // Publicly accessible list of the axes being worked on.
  std::list<InsnAxis> axis_list_;

 protected:
  InsnArg arg_;
  Meta meta_;
  StmtInfoList dst_info_list_;
  StmtInfoList src_info_list_;
  StmtStoreInfo dst_info_;
  StmtInfo for_info_;
  const std::string intrin_name_;
  const int max_block_stride_ = 4;
};
class
SingleVecInsnArgsCalculator
:
public
InsnArgsCalculator
{
public:
SingleVecInsnArgsCalculator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
=
""
);
virtual
~
SingleVecInsnArgsCalculator
()
override
=
default
;
PatternResult
GetInsnArgs
();
};
class
BinaryVecInsnArgsCalculator
:
public
InsnArgsCalculator
{
public:
BinaryVecInsnArgsCalculator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
mode
,
const
std
::
string
&
intrin_name
=
""
,
bool
expand_mask
=
true
);
virtual
~
BinaryVecInsnArgsCalculator
()
override
=
default
;
PatternResult
GetInsnArgs
();
std
::
function
<
bool
(
const
InsnAxis
&
)
>
GetM0LimitLambda
();
std
::
function
<
bool
(
const
InsnAxis
&
)
>
GetM1LimitLambda
();
void
InsnReduction
();
private:
std
::
string
mode_
;
bool
expand_mask_
;
InsnAxis
vec_axis_
;
};
class
LastAxisReduceInsnArgsCalculator
:
InsnArgsCalculator
{
public:
LastAxisReduceInsnArgsCalculator
(
const
StmtStoreInfo
&
dst_info
,
const
StmtStoreInfo
&
src_info
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
)
:
InsnArgsCalculator
({
dst_info
},
{
src_info
},
for_info
,
intrin_name
),
dst_info
(
dst_info
),
src_info
(
src_info
),
for_info
(
for_info
),
arg_info
(
ArgInfo
(
make_node
<
ArgInfoNode
>
())),
body_args
(
VectorArgInfo
()),
tail_args
(
VectorArgInfo
()),
intrin_name
(
intrin_name
)
{}
PatternResult
GetInsnArgs
();
~
LastAxisReduceInsnArgsCalculator
()
=
default
;
protected:
Array
<
Var
>
GetPattern
();
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
);
private:
void
CalcParams
();
struct
Params
{
Array
<
Var
>
src_var
;
int
block_size
=
0
;
int
vec_max_len
=
0
;
int
last_dim_shape
=
0
;
Expr
insn_offset_scale_factor
;
};
StmtStoreInfo
dst_info
;
StmtStoreInfo
src_info
;
StmtInfo
for_info
;
ArgInfo
arg_info
;
VectorArgInfo
body_args
;
VectorArgInfo
tail_args
;
Array
<
VectorArgInfo
>
mix_vec_arg_list
;
std
::
string
intrin_name
;
Params
params
;
};
// Builds the copy + vadd step lists for a bisection reduction.
// dst_info_list must hold one entry and src_info_list two (checked in the definition).
// NOTE(review): `postfix` has no default here, while the definition in insn_args_calculator.cc
// declares `int postfix = 0`; callers that only see this header must pass it explicitly.
BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list,
                                                        const StmtInfoList &src_info_list, const StmtInfo &for_info,
                                                        StmtInfo &if_info, bool last_axis, int postfix);

// Computes the instruction args for a CCE binary vector intrinsic extracted from `stmt`.
// The four list/info reference parameters are written by the call; enable_bisect defaults to true.
ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
                             StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info,
                             bool enable_bisect = true);
}
// namespace akg
#endif
\ No newline at end of file
src/emit_insn/insn_binary_vec_pattern.cc
已删除
100644 → 0
浏览文件 @
d237aa7d
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/base.h>
#include <tvm/ir_pass.h>
#include <set>
#include "ir_pass.h"
#include "contrib/cce_parm/cceconf.h"
#include "tvm.h"
#include "common/array_api.h"
#include "insn_pattern.h"
#include "insn_builder.h"
namespace
akg
{
// Classify a binary vector op by the relation between its dst and src infos.
// Checks run in a fixed order and the first match wins:
//   "elewise" -> "broadcast" -> "reduce_last_axis" -> "reduce_bisection" -> "reduction" (fallback).
// "reduce_bisection" is only considered when enable_bisect is set and the intrinsic is one of
// vadd / vsub / vmul / vmax.
std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
                             const std::string &intrin_name, bool enable_bisect = true) {
  const std::set<std::string> bisect_capable = {"vadd", "vsub", "vmul", "vmax"};

  if (IsElementwise(dst_info_list, src_info_list)) {
    return "elewise";
  }
  if (IsBroadcast(dst_info_list, src_info_list)) {
    return "broadcast";
  }
  if (IsLastAxisReduction(dst_info_list, src_info_list)) {
    return "reduce_last_axis";
  }

  const bool bisect_allowed = enable_bisect && bisect_capable.find(intrin_name) != bisect_capable.end();
  if (bisect_allowed && IsBisectionReduction(dst_info_list, src_info_list)) {
    return "reduce_bisection";
  }
  return "reduction";
}
// Entry point: fill `params`, choose between the 2D-block and the 1D pattern, and pack the result.
// NOTE(review): the 2D-block branch requires a rate strictly above 1.0f, while the visible
// Compute2DBlockPatternMaskRate returns either not_this_pattern or exactly 1.0f - confirm whether
// the 2D-block path is meant to be reachable.
PatternResult ReduceLastAxisPatternGenerator::GetInsnArgs() {
  CalcParams();

  const bool use_2d_block = Compute2DBlockPatternMaskRate() > 1.0f;
  Array<Var> eliminated_vars;
  if (!use_2d_block) {
    eliminated_vars = Get1DPattern();
    arg_info.GetNode()->pattern_ = PATTERN_1D;
  } else {
    eliminated_vars = Get2DBlockPattern();
    arg_info.GetNode()->pattern_ = PATTERN_2D_BLOCK;
  }
  return GenResult(eliminated_vars);
}
// Score the 2D-block pattern: 1.0f when it applies, not_this_pattern otherwise.
// It is rejected for vadd/argmax/argmin, for sources with fewer than two vars, and whenever the
// last dimension has a full-vector body or a tail longer than one block.
float ReduceLastAxisPatternGenerator::Compute2DBlockPatternMaskRate() {
  const float is2_dpattern = 1.0f;

  const bool excluded_intrin = intrin_name == "vadd" || intrin_name == "argmax" || intrin_name == "argmin";
  if (excluded_intrin) {
    return not_this_pattern;
  }

  // At least two source vars are required.
  if (!(params.src_var.size() >= 2)) {
    return not_this_pattern;
  }

  // Split the last dimension into whole vectors (body) and the remainder (tail).
  const int whole_vec_len = params.last_dim_shape / params.vec_max_len * params.vec_max_len;
  const int remainder_len = params.last_dim_shape % params.vec_max_len;

  // This mode has no body and needs the tail to fit in one block.
  const bool fits = whole_vec_len <= 0 && remainder_len <= params.block_size;
  return fits ? is2_dpattern : not_this_pattern;
}
// Build body/tail vector args for the 2D-block pattern, where the second-to-last source dimension
// is processed FULL_BLOCK_NUM rows at a time. Sets the offset scale factor to 1 and returns the
// last two source vars.
Array<Var> ReduceLastAxisPatternGenerator::Get2DBlockPattern() {
  const int rows = GetInt32Const(GetItem(src_info->shape_, -2));
  const int full_rows = rows / FULL_BLOCK_NUM * FULL_BLOCK_NUM;
  const int leftover_rows = rows % FULL_BLOCK_NUM;
  int processed_len = 0;

  if (full_rows > 0) {
    body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto body = body_args.GetNode();
    body->body_num_ = 1;
    body->repeat_ = Expr(full_rows / FULL_BLOCK_NUM);
    // dst_stride_m1 doubles as the dst stride here
    body->dst_stride_m1_ = Expr(1);
    body->src_stride_m0_list_ = {Expr(1)};
    body->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
    body->vec_mask_ = GetVecMask(params.last_dim_shape, FULL_BLOCK_NUM, dst_info->dtype_);
    processed_len += GetInt32Const(body_args->repeat_) * FULL_BLOCK_NUM;
  }

  if (leftover_rows > 0) {
    tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto tail = tail_args.GetNode();
    // The tail starts right after what the body covered.
    tail->dst_head_ = Expr(processed_len);
    tail->src_head_list_ = {Expr(processed_len / FULL_BLOCK_NUM * params.vec_max_len)};
    tail->repeat_ = Expr(1);
    tail->dst_stride_m1_ = Expr(1);
    tail->src_stride_m0_list_ = {Expr(1)};
    tail->src_stride_m1_list_ = {Expr(0)};
    tail->vec_mask_ = GetVecMask(params.last_dim_shape, leftover_rows, dst_info->dtype_);
  }

  params.insn_offset_scale_factor = 1;
  return GetRange(params.src_var, -2, 2);
}
// Build the vector args for the 1D pattern: a body over whole vectors, a tail for the remainder,
// and extra follow-up passes (appended to mix_vec_arg_list) for as long as more than one partial
// result is left. Also derives insn_offset_scale_factor. Returns the last source var, or an
// empty array when there are no source vars.
Array<Var> ReduceLastAxisPatternGenerator::Get1DPattern() {
  const int body_len = params.last_dim_shape / params.vec_max_len * params.vec_max_len;
  const int tail_len = params.last_dim_shape % params.vec_max_len;
  const bool is_vadd = intrin_name == "vadd";
  // vadd yields one value per repeat, everything else FULL_BLOCK_NUM values.
  const int repeat_stride = is_vadd ? 1 : FULL_BLOCK_NUM;
  const int fp16_block_size = 16;

  // Number of partial results a pass over `len` elements leaves behind:
  // one for vadd, otherwise len divided by fp16_block_size, rounded up.
  auto partial_results = [is_vadd, fp16_block_size](int len) {
    if (is_vadd) {
      return 1;
    }
    return len / fp16_block_size + ((len % fp16_block_size != 0) ? 1 : 0);
  };

  int cmd_body_len = 0;

  if (body_len > 0) {
    body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto body = body_args.GetNode();
    body->body_num_ = 1;
    body->body_offset_ = params.vec_max_len;
    body->repeat_ = Expr(body_len / params.vec_max_len);
    // dst_stride_m1 doubles as the dst stride here
    body->dst_stride_m1_ = Expr(1);
    body->src_stride_m0_list_ = {Expr(1)};
    body->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
    body->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
    cmd_body_len += GetInt32Const(body_args->repeat_) * repeat_stride;
  }

  if (tail_len > 0) {
    tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto tail = tail_args.GetNode();
    tail->body_offset_ = params.vec_max_len;
    tail->dst_head_ = Expr(cmd_body_len);
    tail->src_head_list_ = {Expr(body_len)};
    tail->repeat_ = Expr(1);
    tail->dst_stride_m1_ = Expr(1);
    tail->src_stride_m0_list_ = {Expr(1)};
    tail->src_stride_m1_list_ = {Expr(0)};
    tail->vec_mask_ = GetVecMask(tail_len, 1, dst_info->dtype_);
    cmd_body_len += partial_results(tail_len);
  }

  // cmd_body_len > 1 means vcadd size greater than 128, need to use vcadd again to compute final result
  // if cmd_body_len > 128, then need to recursively emit vcadd
  while (cmd_body_len > 1) {
    const int cmd_tail_len = cmd_body_len % params.vec_max_len;
    cmd_body_len = cmd_body_len / params.vec_max_len;

    if (cmd_body_len > 0) {
      VectorArgInfo full_pass = VectorArgInfo(make_node<VectorArgInfoNode>());
      auto node = full_pass.GetNode();
      node->repeat_ = Expr(cmd_body_len);
      node->dst_head_ = Expr(0);
      node->src_head_list_ = {Expr(0)};
      node->dst_stride_m1_ = Expr(1);
      node->src_stride_m0_list_ = {Expr(1)};
      node->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
      node->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
      mix_vec_arg_list.push_back(full_pass);
      if (!is_vadd) {
        cmd_body_len *= FULL_BLOCK_NUM;
      }
    }

    if (cmd_tail_len > 0) {
      VectorArgInfo partial_pass = VectorArgInfo(make_node<VectorArgInfoNode>());
      auto node = partial_pass.GetNode();
      node->repeat_ = Expr(1);
      node->dst_head_ = Expr(cmd_body_len);
      const int src_head = is_vadd ? cmd_body_len * params.vec_max_len
                                   : cmd_body_len / FULL_BLOCK_NUM * params.vec_max_len;
      node->src_head_list_ = {Expr(src_head)};
      node->dst_stride_m1_ = Expr(1);
      node->src_stride_m0_list_ = {Expr(1)};
      node->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
      node->vec_mask_ = GetVecMask(cmd_tail_len, 1, dst_info->dtype_);
      cmd_body_len += partial_results(cmd_tail_len);
      mix_vec_arg_list.push_back(partial_pass);
    }
  }

  // Default scale is one block; enlarge it (rounded to whole blocks) when the body alone
  // already produces at least a block worth of partial results.
  params.insn_offset_scale_factor = Expr(params.block_size);
  int max_num = body_len / params.vec_max_len;
  if (intrin_name == "vmax" || intrin_name == "vmin") {
    max_num *= FULL_BLOCK_NUM;
  }
  if (max_num >= params.block_size) {
    params.insn_offset_scale_factor = max_num + params.block_size - 1;
    if (tail_len > 0) {
      params.insn_offset_scale_factor += 1;
    }
    params.insn_offset_scale_factor =
      truncdiv(params.insn_offset_scale_factor, params.block_size) * params.block_size;
  }

  if (params.src_var.empty()) {
    return {};
  }
  return GetRange(params.src_var, -1, 1);
}
// Finalise the computation: store the instruction offsets on dst/src, stamp the offset scale
// factor on every defined vector-arg node, attach those nodes to arg_info, drop the eliminated
// vars from for_info and pack everything into a PatternResult.
PatternResult ReduceLastAxisPatternGenerator::GenResult(const Array<Var> &elim_var) {
  // Only the destination offset is multiplied by the scale factor.
  dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, elim_var) * params.insn_offset_scale_factor;
  src_info.GetNode()->insn_offset_ = GetInsnOffset(src_info, elim_var);

  if (body_args.defined()) {
    body_args.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor;
  }
  if (tail_args.defined()) {
    tail_args.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor;
  }
  for (auto &extra_pass : mix_vec_arg_list) {
    extra_pass.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor;
  }

  auto info_node = arg_info.GetNode();
  info_node->body_arg_info_ = body_args;
  info_node->tail_arg_info_ = tail_args;
  info_node->reduction_tail_args_ = mix_vec_arg_list;

  CleanForInfoVars(for_info, elim_var);
  arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS;

  PatternResult packed;
  packed.dst_info_list = {dst_info};
  packed.src_info_list = {src_info};
  packed.for_info = for_info;
  packed.arg_info = arg_info;
  return packed;
}
void
ReduceLastAxisPatternGenerator
::
CalcParams
()
{
// check shape len
if
(
dst_info
->
shape_
.
empty
()
||
src_info
->
shape_
.
empty
())
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."
;
}
// check data type
if
(
dst_info
->
dtype_
!=
src_info
->
dtype_
)
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type."
;
}
params
.
src_var
=
src_info
->
var_
;
params
.
block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
params
.
last_dim_shape
=
GetInt32Const
(
GetItem
(
src_info
->
shape_
,
-
1
));
params
.
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
CHECK_NE
(
params
.
block_size
,
0
);
CHECK_NE
(
params
.
vec_max_len
,
0
);
}
/// Compute the args of a CCE binary vector instruction.
///
/// Last-axis broadcasts are returned right after CalcParams. Otherwise the four candidate
/// patterns are scored and the first one that qualifies (3D, 2D-block, 2D, 1D - in that order)
/// is built; the scores and the presence of a tail are recorded through CommentManager.
/// \return the dst/src/for/arg infos packed into a PatternResult
PatternResult BinaryVecPatternGenerator::GetInsnArgs() {
  CalcParams();

  if (arg_info->arg_type_ == ARG_VECTOR_BROADCAST_LAST_AXIS) {
    // Nothing more to compute for this arg type; hand back the current state.
    PatternResult early;
    early.dst_info_list = {dst_info};
    early.src_info_list = src_info_list;
    early.for_info = for_info;
    early.arg_info = arg_info;
    return early;
  }

  Array<Var> elim_var = {};
  const float rate3d = Compute3DPatternMaskRate();
  const float rate2db = Compute2DBlockPatternMaskRate();
  const float rate2d = Compute2DPatternMaskRate();
  const float rate1d = Compute1DPatternMaskRate();

  // Each pattern has to score above zero and must not lose to the ones it is compared with.
  if (rate3d > 0 && rate3d >= rate2db) {
    elim_var = Get3DPattern();
    arg_info.GetNode()->pattern_ = PATTERN_3D;
  } else if (rate2db > 0 && rate2db >= rate2d && rate2db >= rate1d) {
    elim_var = Get2DBlockPattern();
    arg_info.GetNode()->pattern_ = PATTERN_PARTIAL_3D;
  } else if (rate2d > 0 && rate2d > rate1d) {
    elim_var = Get2DPattern();
    arg_info.GetNode()->pattern_ = PATTERN_2D;
  } else if (rate1d > 0) {
    elim_var = Get1DPattern();
    arg_info.GetNode()->pattern_ = PATTERN_1D;
  } else {
    LOG(FATAL) << "Error: Cannot emit Binary-Vector-Insn with any pattern!";
  }

  // Record the scores for diagnostics.
  std::string mask_rate = "rate3d[";
  mask_rate += std::to_string(rate3d);
  mask_rate += "], rate2db[";
  mask_rate += std::to_string(rate2db);
  mask_rate += "], rate2d[";
  mask_rate += std::to_string(rate2d);
  mask_rate += "], rate1d[";
  mask_rate += std::to_string(rate1d);
  mask_rate += "]";
  CommentManager::GetInstance().AddComment("Mask_rate", mask_rate);

  if (!tail_args.defined()) {
    CommentManager::GetInstance().AddComment("Contain_tail", "false");
  } else {
    CommentManager::GetInstance().AddComment("Contain_tail", "true");
  }

  return GenResult(elim_var);
}
/// Score the 3D pattern for the current operands.
///
/// Two variants are scored: one using shape[-3] as the repeat axis and one using shape[-2]; the larger
/// score is returned.
/// \return not_this_pattern when the operands do not satisfy the 3D layout checks, otherwise the score
float BinaryVecPatternGenerator::Compute3DPatternMaskRate() {
  if (params.non_zero_shape3 == 1 || params.non_zero_shape2 == 1) {
    return not_this_pattern;
  }

  // dst layout checks. In elemwise mode the vars were already checked to be equal, so only dst is
  // inspected here. Checks are done one by one so that an item is only read once the previous checks passed.
  if (params.dst_var.size() < 3) {
    return not_this_pattern;
  }
  const auto dst_last_shape = GetIntConst(GetItem(params.dst_shape, -1));
  if (dst_last_shape > params.block_size) {
    return not_this_pattern;
  }
  const auto dst_stride2 = GetIntConst(GetItem(params.dst_strides, -2));
  if (dst_stride2 % params.block_size != 0) {
    return not_this_pattern;
  }
  const auto dst_stride3 = GetIntConst(GetItem(params.dst_strides, -3));
  if (dst_stride3 % params.block_size != 0) {
    return not_this_pattern;
  }
  if (dst_stride2 > 0 && dst_last_shape > 0 && dst_stride2 < dst_last_shape) {
    return not_this_pattern;
  }
  if (dst_stride3 > 0) {
    const auto dst_shape2 = GetIntConst(GetItem(params.dst_shape, -2));
    if (dst_shape2 > 0 && dst_stride3 < dst_shape2) {
      return not_this_pattern;
    }
  }

  // Original note: check dst_stride_m0 - as described in ISL User Guide t6.3,
  // dst_stride_m0 = 0 is treated as 1.
  // Same layout checks as above, applied to one src operand; true means "3D is not possible".
  const auto SrcBreaks3D = [this](const StmtStoreInfo &info) -> bool {
    const auto shape1 = GetIntConst(GetItem(info->shape_, -1));
    if (info->var_.size() < 3 || shape1 > params.block_size) {
      return true;
    }
    const auto shape2 = GetIntConst(GetItem(info->shape_, -2));
    const auto stride2 = GetIntConst(GetItem(info->strides_, -2));
    const auto stride3 = GetIntConst(GetItem(info->strides_, -3));
    if (stride2 % params.block_size != 0 || stride3 % params.block_size != 0) {
      return true;
    }
    const bool stride2_too_small = stride2 > 0 && shape1 > 0 && stride2 < shape1;
    const bool stride3_too_small = stride3 > 0 && shape2 > 0 && stride3 < shape2;
    return stride2_too_small || stride3_too_small;
  };
  if (std::any_of(src_info_list.begin(), src_info_list.end(), SrcBreaks3D)) {
    return not_this_pattern;
  }

  if (mode == "reduction") {
    // check same alignment: all non-zero last-dim shapes must agree
    Array<Expr> shape_list = {GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1),
                              GetItem(params.src_shape1, -1)};
    if (!IsNonZeroShapeEqual(shape_list)) {
      return not_this_pattern;
    }
  }

  auto info_list = src_info_list;
  Insert(info_list, 0, dst_info);

  // Variant 1: repeat axis is shape[-3].
  float rate3d_mode1 = not_this_pattern;
  const auto ExceedsMode1Limits = [this](const StmtStoreInfo &info) -> bool {
    return GetInt32Const(GetItem(info->shape_, -2)) > FULL_BLOCK_NUM ||
           GetInt32Const(GetItem(info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0 ||
           GetInt32Const(GetItem(info->strides_, -3)) / params.block_size >= MAX_STRIDE_M0;
  };
  if (std::none_of(info_list.begin(), info_list.end(), ExceedsMode1Limits)) {
    if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0) {
      return not_this_pattern;
    }
    const int repeat_num = params.non_zero_shape3;
    const float repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef;
    rate3d_mode1 = static_cast<float>(params.all_points) / params.vec_max_len / (repeat_num + repeat_latency);
  }

  // Variant 2: repeat axis is shape[-2]; shape[-3] must be a multiple of FULL_BLOCK_NUM and the
  // block stride must stay below MAX_STRIDE_M0.
  float rate3d_mode2 = not_this_pattern;
  const auto ExceedsMode2Limits = [this](const StmtStoreInfo &info) -> bool {
    return GetIntConst(GetItem(info->shape_, -3)) % FULL_BLOCK_NUM != 0 ||
           GetIntConst(GetItem(info->strides_, -3)) / params.block_size >= MAX_STRIDE_M0;
  };
  if (std::none_of(info_list.begin(), info_list.end(), ExceedsMode2Limits)) {
    if (GetIntConst(GetItem(params.dst_strides, -3)) == 0) {
      return not_this_pattern;
    }
    const int repeat_num = params.non_zero_shape2 * (params.non_zero_shape3 / FULL_BLOCK_NUM);
    const float repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef;
    float offset_latency = 0;
    if (params.non_zero_shape3 / FULL_BLOCK_NUM > 1) {
      offset_latency = params.non_zero_shape3 * offset_latency_coef;
    }
    rate3d_mode2 = static_cast<float>(params.all_points) / params.vec_max_len /
                   (repeat_num + repeat_latency + offset_latency);
  }

  // Argument order keeps the original tie/compare behaviour: mode1 wins only when strictly greater.
  return std::max(rate3d_mode2, rate3d_mode1);
}
float
BinaryVecPatternGenerator
::
Compute2DBlockPatternMaskRate
()
{
if
(
params
.
non_zero_shape2
==
1
||
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
1
))
!=
1
)
{
return
not_this_pattern
;
}
if
(
params
.
dst_var
.
size
()
<
2
||
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
params
.
block_size
||
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
block_size
!=
0
||
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
/
params
.
block_size
>=
MAX_STRIDE_M0
||
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
>
0
&&
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
0
&&
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
<
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
1
))))
{
return
not_this_pattern
;
}
for
(
auto
info
:
src_info_list
)
{
if
(
info
->
var_
.
size
()
<
2
||
GetIntConst
(
GetItem
(
info
->
shape_
,
-
1
))
>
params
.
block_size
||
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
%
params
.
block_size
!=
0
||
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
/
params
.
block_size
>=
MAX_STRIDE_M0
||
(
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
>
0
&&
GetIntConst
(
GetItem
(
info
->
shape_
,
-
1
))
>
0
&&
GetIntConst
(
GetItem
(
info
->
strides_
,
-
2
))
<
GetIntConst
(
GetItem
(
info
->
shape_
,
-
1
))))
{
return
not_this_pattern
;
}
}
if
(
GetIntConst
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
if
(
mode
==
"reduction"
)
{
if
(
params
.
dst_var
.
size
()
>
2
)
{
// if not elewise mode, then can not use partial 3D mode
if
(
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
3
))
==
0
)
{
return
not_this_pattern
;
}
for
(
auto
info
:
src_info_list
)
{
if
(
GetIntConst
(
GetItem
(
info
->
shape_
,
-
3
))
==
0
)
{
return
not_this_pattern
;
}
}
}
// check same alignment
Array
<
Expr
>
shape_list
=
{
GetItem
(
params
.
dst_shape
,
-
1
)};
shape_list
.
push_back
(
GetItem
(
params
.
src_shape0
,
-
1
));
shape_list
.
push_back
(
GetItem
(
params
.
src_shape1
,
-
1
));
// check dst_stride_m0
// As described in ISL User Guide t6.3,
// dst_stride_m0 = 0 is treated as 1
if
(
!
IsNonZeroShapeEqual
(
shape_list
))
{
return
not_this_pattern
;
}
}
int
repeat_body_num
=
params
.
non_zero_shape2
/
FULL_BLOCK_NUM
;
int
repeat_tail_num
=
(
params
.
non_zero_shape2
%
FULL_BLOCK_NUM
+
FULL_BLOCK_NUM
-
1
)
/
FULL_BLOCK_NUM
;
int
repeat_num
=
(
repeat_body_num
+
repeat_tail_num
)
*
params
.
non_zero_shape3
;
float
repeat_latency
=
(
std
::
max
(
repeat_body_num
-
1
,
0
)
/
MAX_REPEAT
+
std
::
max
(
repeat_tail_num
-
1
,
0
)
/
MAX_REPEAT
)
*
repeat_latency_coef
;
float
offset_latency
=
params
.
non_zero_shape3
>
1
?
params
.
non_zero_shape3
*
offset_latency_coef
:
0
;
float
split_latency
=
(
repeat_body_num
>
0
&&
repeat_tail_num
>
0
)
?
split_latency_coef
:
0
;
float
rate2db
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
+
split_latency
);
return
rate2db
;
}
/// Score the 2D pattern for the current operands.
///
/// The pattern is rejected (not_this_pattern) when an operand has fewer than two vars, when a
/// second-last stride is not a multiple of the block size, when a positive second-last stride is
/// smaller than the last-dim shape, when the stride in blocks reaches MAX_STRIDE_M0, or when the dst
/// second-last stride is 0 on a product other than "cloud".
/// \return not_this_pattern when rejected, otherwise the score
float BinaryVecPatternGenerator::Compute2DPatternMaskRate() {
  if (params.non_zero_shape2 == 1) {
    return not_this_pattern;
  }

  // dst layout checks. The stride value itself is compared against 0; previously the comparison
  // was built as an Expr inside GetIntConst(...) because of a misplaced parenthesis.
  if (params.dst_var.size() < 2) {
    return not_this_pattern;
  }
  const auto dst_stride2 = GetIntConst(GetItem(params.dst_strides, -2));
  if (dst_stride2 % params.block_size != 0) {
    return not_this_pattern;
  }
  if (dst_stride2 < GetIntConst(GetItem(params.dst_shape, -1)) && dst_stride2 > 0) {
    return not_this_pattern;
  }

  // The same layout checks for every src operand.
  for (const auto &info : src_info_list) {
    if (info->var_.size() < 2) {
      return not_this_pattern;
    }
    const auto stride2 = GetIntConst(GetItem(info->strides_, -2));
    if (stride2 % params.block_size != 0) {
      return not_this_pattern;
    }
    if (stride2 < GetIntConst(GetItem(info->shape_, -1)) && stride2 > 0) {
      return not_this_pattern;
    }
  }

  if (mode == "reduction") {
    // all non-zero last-dim shapes must agree
    Array<Expr> shape_list = {GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1),
                              GetItem(params.src_shape1, -1)};
    if (!IsNonZeroShapeEqual(shape_list)) {
      return not_this_pattern;
    }
  }

  // only cloud allow dst_stride_m1 = 0
  cceconf::CceConf *conf = cceconf::CceConf::getInstance();
  const std::string product_name = conf->getProductName();
  if (dst_stride2 == 0 && product_name != "cloud") {
    return not_this_pattern;
  }

  CHECK_NE(params.vec_max_len, 0);

  // A tail instruction is needed when the last dim has at least one full vector plus a remainder.
  int tail_factor = 0;
  if (params.non_zero_shape1 / params.vec_max_len > 0 && params.non_zero_shape1 % params.vec_max_len > 0) {
    tail_factor = 1;
  }

  // The second-last stride, counted in blocks, must stay below MAX_STRIDE_M0 for every operand.
  if (GetIntConst(GetItem(dst_info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0) {
    return not_this_pattern;
  }
  for (const auto &info : src_info_list) {
    if (GetIntConst(GetItem(info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0) {
      return not_this_pattern;
    }
  }

  // Cost model: number of issued repeats plus latency penalties.
  const int shape1 = (params.non_zero_shape1 + params.vec_max_len - 1) / params.vec_max_len;
  const int repeat_num = shape1 * params.non_zero_shape2 * params.non_zero_shape3;
  const float repeat_latency = (std::max(params.non_zero_shape2 - 1, 0) / MAX_REPEAT) * params.non_zero_shape3 *
                               shape1 * repeat_latency_coef;
  const float offset_latency =
    shape1 * params.non_zero_shape3 > 1 ? shape1 * params.non_zero_shape3 * offset_latency_coef : 0;
  const float split_latency = tail_factor * split_latency_coef;

  return static_cast<float>(params.all_points) / params.vec_max_len /
         (repeat_num + repeat_latency + offset_latency + split_latency);
}
/// Score the 1D pattern for the current operands.
///
/// Unlike the other scoring functions this one performs no layout checks; it only evaluates the cost model.
/// \return the score
float BinaryVecPatternGenerator::Compute1DPatternMaskRate() {
  // A tail instruction is needed when the last dim has at least one full vector plus a remainder.
  const bool has_full_vector = params.non_zero_shape1 / params.vec_max_len > 0;
  const bool has_remainder = params.non_zero_shape1 % params.vec_max_len > 0;
  const int tail_factor = (has_full_vector && has_remainder) ? 1 : 0;

  // Number of vectors needed to cover the last dim (rounded up).
  const int shape1 = (params.non_zero_shape1 + params.vec_max_len - 1) / params.vec_max_len;
  const int repeat_num = shape1 * params.non_zero_shape2 * params.non_zero_shape3;

  const float repeat_latency = std::max((shape1 - 1) / MAX_REPEAT, 0) * params.non_zero_shape2 *
                               params.non_zero_shape3 * repeat_latency_coef;

  float offset_latency = 0;
  if (params.non_zero_shape2 * params.non_zero_shape3 > 1) {
    offset_latency = params.non_zero_shape2 * params.non_zero_shape3 * offset_latency_coef;
  }

  const float split_latency = tail_factor * split_latency_coef;

  return static_cast<float>(params.all_points) / params.vec_max_len /
         (repeat_num + repeat_latency + offset_latency + split_latency);
}
/// Build the body arguments for the 3D pattern.
///
/// When the non-zero shape of axis [-2] is larger than 8, axes [-2] and [-3] are swapped in params so
/// the original [-2] axis becomes the repeat axis; if axis [-3] is also larger than 8 it is first
/// split by FULL_BLOCK_NUM. Only body_args is produced, no tail.
/// \return the last three dst vars, which the caller eliminates from the loop nest
Array<Var> BinaryVecPatternGenerator::Get3DPattern() {
  const int shape_m2 = GetInt32Const(
    GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2)));
  const int shape_m3 = GetInt32Const(
    GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape0, -3), GetItem(params.src_shape1, -3)));

  if (shape_m2 > 8) {
    if (shape_m3 > 8) {
      // split shape[-3], then refresh params from the updated infos (dst first, then both srcs)
      auto info_list = src_info_list;
      Insert(info_list, 0, dst_info);
      SplitAxis(info_list, for_info, GetItem(params.dst_var, -3), FULL_BLOCK_NUM);
      FillEmptyVar(info_list);

      const auto dst = info_list[0];
      const auto src0 = info_list[1];
      const auto src1 = info_list[2];
      params.dst_var = dst->var_;
      params.dst_shape = dst->shape_;
      params.dst_strides = dst->strides_;
      params.src_var0 = src0->var_;
      params.src_shape0 = src0->shape_;
      params.src_strides0 = src0->strides_;
      params.src_var1 = src1->var_;
      params.src_shape1 = src1->shape_;
      params.src_strides1 = src1->strides_;
    }
    // consider original shape[-2] as repeat axis
    GetShapeInfoAndSwap(params.dst_var, params.dst_shape, params.dst_strides, -2, -3);
    GetShapeInfoAndSwap(params.src_var0, params.src_shape0, params.src_strides0, -2, -3);
    GetShapeInfoAndSwap(params.src_var1, params.src_shape1, params.src_strides1, -2, -3);
  }

  body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
  CHECK(body_args.GetNode());
  auto body = body_args.GetNode();

  body->body_num_ = 1;
  body->repeat_ = GetItem(params.dst_shape, -3);
  // Strides are expressed in blocks: axis [-2] feeds stride_m0, axis [-3] feeds stride_m1.
  body->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size);
  body->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -3), params.block_size);
  body->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size),
                               truncdiv(GetItem(params.src_strides1, -2), params.block_size)};
  body->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides0, -3), params.block_size),
                               truncdiv(GetItem(params.src_strides1, -3), params.block_size)};

  int data_num = GetInt32Const(GetItem(params.dst_shape, -2));
  if (mode == "reduction") {
    // In reduction mode the extents come from the non-zero shape across dst and both srcs.
    body->repeat_ = Expr(
      GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape0, -3), GetItem(params.src_shape1, -3)));
    data_num =
      GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2));
  }

  int data_len = params.last_dim_shape;
  if (expand_mask) {
    data_len = CeilTo(params.last_dim_shape, params.block_size);
  }
  body->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_);

  return GetRange(params.dst_var, -3, 3);
}
/// Build body/tail arguments for the 2D block (partial 3D) pattern.
///
/// The second-last axis is processed FULL_BLOCK_NUM rows per repeat: body_args covers the full groups,
/// tail_args covers the remaining rows (if any) with a single repeat.
/// \return the last two dst vars, which the caller eliminates from the loop nest
Array<Var> BinaryVecPatternGenerator::Get2DBlockPattern() {
  int repeat_len = GetInt32Const(GetItem(params.dst_shape, -2));
  if (mode == "reduction") {
    params.last_dim_shape =
      GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1), GetItem(params.src_shape1, -1));
    repeat_len =
      GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2));
  }

  const int repeat_body = repeat_len / FULL_BLOCK_NUM;
  const int repeat_tail = (repeat_len % FULL_BLOCK_NUM + FULL_BLOCK_NUM - 1) / FULL_BLOCK_NUM;

  // Mask length along the last dim, rounded up to a block when mask expansion is allowed.
  const auto MaskLen = [this]() -> int {
    return expand_mask ? CeilTo(params.last_dim_shape, params.block_size) : params.last_dim_shape;
  };

  if (repeat_body > 0) {
    body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    CHECK(body_args.GetNode() != nullptr);
    auto body = body_args.GetNode();

    body->body_num_ = 1;
    body->repeat_ = Expr(repeat_body);
    body->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size);
    // One repeat advances by FULL_BLOCK_NUM rows.
    body->dst_stride_m1_ = body_args->dst_stride_m0_ * FULL_BLOCK_NUM;

    Expr src0_stride_m0 = truncdiv(GetItem(params.src_strides0, -2), params.block_size);
    Expr src1_stride_m0 = truncdiv(GetItem(params.src_strides1, -2), params.block_size);
    body->src_stride_m0_list_ = {src0_stride_m0, src1_stride_m0};
    body->src_stride_m1_list_ = {src0_stride_m0 * FULL_BLOCK_NUM, src1_stride_m0 * FULL_BLOCK_NUM};

    const int data_len = MaskLen();
    body->vec_mask_ = GetVecMask(data_len, FULL_BLOCK_NUM, dst_info->dtype_);
  }

  if (repeat_tail > 0) {
    tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    CHECK(tail_args.GetNode() != nullptr);
    auto tail = tail_args.GetNode();

    // The tail starts right after the rows handled by the body.
    tail->dst_head_ = GetItem(params.dst_strides, -2) * repeat_body * FULL_BLOCK_NUM;
    tail->src_head_list_ = {GetItem(params.src_strides0, -2) * repeat_body * FULL_BLOCK_NUM,
                            GetItem(params.src_strides1, -2) * repeat_body * FULL_BLOCK_NUM};
    tail->repeat_ = Expr(1);
    tail->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size);
    tail->dst_stride_m1_ = Expr(0);
    tail->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size),
                                 truncdiv(GetItem(params.src_strides1, -2), params.block_size)};
    tail->src_stride_m1_list_ = {Expr(0), Expr(0)};

    const int data_len = MaskLen();
    tail->vec_mask_ = GetVecMask(data_len, repeat_len % FULL_BLOCK_NUM, dst_info->dtype_);
  }

  return GetRange(params.dst_var, -2, 2);
}
/// Build body/tail arguments for the 2D pattern.
///
/// The last dim is cut into full vectors (body_args) and a remainder (tail_args); the second-last axis
/// is used as the repeat axis for both.
/// \return the last two dst vars, which the caller eliminates from the loop nest
Array<Var> BinaryVecPatternGenerator::Get2DPattern() {
  if (mode == "reduction") {
    params.last_dim_shape =
      GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1), GetItem(params.src_shape1, -1));
  }

  const int body_len = FloorTo(params.last_dim_shape, params.vec_max_len);
  const int tail_len = params.last_dim_shape % params.vec_max_len;

  // Repeat count and strides are identical for body and tail.
  const auto FillRepeatAndStrides = [this](VectorArgInfo &args) {
    auto node = args.GetNode();
    node->repeat_ = GetItem(params.dst_shape, -2);
    if (mode == "reduction") {
      node->repeat_ = GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2),
                                      GetItem(params.src_shape1, -2));
    }
    node->dst_stride_m0_ = Expr(1);
    node->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size);
    node->src_stride_m0_list_ = {Expr(1), Expr(1)};
    node->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size),
                                 truncdiv(GetItem(params.src_strides1, -2), params.block_size)};
  };

  if (body_len > 0) {
    body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    CHECK(body_args.GetNode() != nullptr);
    auto body = body_args.GetNode();
    body->body_num_ = body_len / params.vec_max_len;
    body->body_offset_ = params.vec_max_len;
    FillRepeatAndStrides(body_args);
    body->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
  }

  if (tail_len > 0) {
    tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    CHECK(tail_args.GetNode() != nullptr);
    auto tail = tail_args.GetNode();
    // The tail starts right after the elements covered by the body.
    tail->dst_head_ = Expr(body_len);
    tail->src_head_list_ = {Expr(body_len), Expr(body_len)};
    FillRepeatAndStrides(tail_args);
    tail->vec_mask_ = GetVecMask(tail_len, 1, dst_info->dtype_);
  }

  return GetRange(params.dst_var, -2, 2);
}
/// Build body/tail arguments for the 1D pattern.
///
/// The last dim is cut into full vectors (body_args, one instruction with several repeats) and a
/// remainder (tail_args, a single repeat). In elewise mode the second-last loop may additionally be
/// folded into the repeat count when the checks below hold.
/// \return the dst vars to eliminate: the last two when the second-last loop was folded, the last one
///         otherwise, or an empty array in scalar mode / when dst has no var
Array<Var> BinaryVecPatternGenerator::Get1DPattern() {
  auto info_list = src_info_list;
  Insert(info_list, 0, dst_info);
  const bool is_scalar_mode = IsScalarMode(info_list);

  if (is_scalar_mode) {
    params.last_dim_shape = 1;
  }
  if (mode == "reduction") {
    params.last_dim_shape =
      GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1), GetItem(params.src_shape1, -1));
  }

  const int body_len = FloorTo(params.last_dim_shape, params.vec_max_len);
  const int tail_len = params.last_dim_shape % params.vec_max_len;

  // Index of the src whose last-axis stride is 0 in broadcast mode (-1 when there is none);
  // src1 takes precedence when both strides are 0.
  int last_axis = -1;
  if (mode == "broadcast") {
    const bool src0_zero_stride = GetIntConst(GetItem(params.src_strides0, -1)) == 0;
    const bool src1_zero_stride = GetIntConst(GetItem(params.src_strides1, -1)) == 0;
    if (src1_zero_stride) {
      last_axis = 1;
    } else if (src0_zero_stride) {
      last_axis = 0;
    }
  }

  if (body_len > 0) {
    body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    CHECK(body_args.GetNode() != nullptr);
    auto body = body_args.GetNode();
    body->last_axis_info_.src_index_ = last_axis;
    body->body_num_ = 1;
    body->repeat_ = body_len / params.vec_max_len;
    body->dst_stride_m0_ = Expr(1);
    body->dst_stride_m1_ = Expr(FULL_BLOCK_NUM);
    body->src_stride_m0_list_ = {Expr(1), Expr(1)};
    body->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM), Expr(FULL_BLOCK_NUM)};
    body->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
  }

  if (tail_len > 0) {
    tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    CHECK(tail_args.GetNode() != nullptr);
    auto tail = tail_args.GetNode();
    tail->last_axis_info_.src_index_ = last_axis;
    // The tail starts right after the elements covered by the body.
    tail->dst_head_ = Expr(body_len);
    tail->src_head_list_ = {Expr(body_len), Expr(body_len)};
    tail->repeat_ = Expr(1);
    tail->dst_stride_m0_ = Expr(1);
    tail->dst_stride_m1_ = Expr(0);
    tail->src_stride_m0_list_ = {Expr(1), Expr(1)};
    tail->src_stride_m1_list_ = {Expr(0), Expr(0)};
    int data_len = tail_len;
    if (expand_mask) {
      data_len = CeilTo(tail_len, params.block_size);
    }
    tail->vec_mask_ = GetVecMask(data_len, 1, dst_info->dtype_);
  }

  // compute offset for cce instructions
  const bool can_eliminate_var = !params.dst_var.empty() && !is_scalar_mode;

  // The second-last loop can be merged into the repeat count only in elewise mode, when the last dim
  // fits in one vector, leaves less than one block unused, and rows are stored back to back.
  const bool merge_candidate = mode == "elewise" && params.dst_var.size() >= 2 && params.dst_strides.size() >= 2 &&
                               params.last_dim_shape <= params.vec_max_len && for_info.ops_.size() >= 2 &&
                               params.last_dim_shape >= params.vec_max_len - params.block_size;
  if (merge_candidate && GetIntConst(GetItem(params.dst_strides, -2)) == params.last_dim_shape) {
    // in this case we can merge second last for extent to repeat
    size_t idx = 0;
    bool suc = GetIndexOfElement(for_info.vars_, GetItem(params.dst_var, -2), idx);
    CHECK(suc);
    auto latest_for = GetItem(for_info.ops_, idx).as<For>();
    // there should not be if_op between for loop and compute stmt
    const bool loop_is_mergeable = latest_for && !latest_for->body->IsInstance<IfThenElse>();
    if (loop_is_mergeable && can_eliminate_var) {
      if (body_args.defined()) {
        // last_dim_shape = vec_max_len
        body_args.GetNode()->repeat_ = body_args->repeat_ * latest_for->extent;
      } else if (tail_args.defined()) {
        // last_dim_shape < vec_max_len
        tail_args.GetNode()->repeat_ = tail_args->repeat_ * latest_for->extent;
      }
      return GetRange(params.dst_var, -2, 2);
    }
  }

  Array<Var> elim_var = {};
  if (can_eliminate_var) {
    elim_var = GetRange(params.dst_var, -1, 1);
  }
  return elim_var;
}
/// Finalize the generated pattern and pack it into a PatternResult.
///
/// Stores body/tail args in arg_info, computes the instruction offsets of dst and every src with the
/// eliminated vars (plus empty_var when it has a name and is not already listed), cleans the for info
/// and zero strides, and sets arg_type_ according to mode.
/// \param elim_var - vars consumed by the selected pattern
/// \return the packed result
PatternResult BinaryVecPatternGenerator::GenResult(const Array<Var> &elim_var) {
  arg_info.GetNode()->body_arg_info_ = body_args;
  arg_info.GetNode()->tail_arg_info_ = tail_args;

  auto real_elim_var = elim_var;
  if (!empty_var->name_hint.empty()) {
    // Append the empty var unless a var with the same name is already being eliminated.
    const bool already_listed = std::any_of(
      elim_var.begin(), elim_var.end(), [this](const Var &e) { return e->name_hint == empty_var->name_hint; });
    if (!already_listed) {
      real_elim_var.push_back(empty_var);
    }
  }

  dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, real_elim_var);
  for (const auto &src : src_info_list) {
    src.GetNode()->insn_offset_ = GetInsnOffset(src, real_elim_var);
  }

  CleanForInfoVars(for_info, real_elim_var);
  CleanZeroStrides(dst_info);
  CleanZeroStrides(src_info_list);

  // Any other mode leaves arg_type_ untouched.
  auto arg_node = arg_info.GetNode();
  if (mode == "elewise") {
    arg_node->arg_type_ = ARG_VECTOR_ELEWISE;
  } else if (mode == "broadcast") {
    arg_node->arg_type_ = ARG_VECTOR_BROADCAST;
  } else if (mode == "reduction") {
    arg_node->arg_type_ = ARG_VECTOR_REDUCTION;
  }

  PatternResult packed;
  packed.dst_info_list = {dst_info};
  packed.src_info_list = src_info_list;
  packed.for_info = for_info;
  packed.arg_info = arg_info;
  return packed;
}
void
BinaryVecPatternGenerator
::
CalcParams
()
{
CHECK_GE
(
src_info_list
.
size
(),
2
);
StmtStoreInfo
src_info0
=
src_info_list
[
0
];
StmtStoreInfo
src_info1
=
src_info_list
[
1
];
StmtInfoList
info_list
=
{
dst_info
,
src_info0
,
src_info1
};
// check shape len
for
(
auto
info
:
info_list
)
{
if
(
info
->
shape_
.
empty
())
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."
;
}
}
// check data type
for
(
auto
src_info
:
src_info_list
)
{
if
(
dst_info
->
dtype_
!=
src_info
->
dtype_
)
{
LOG
(
FATAL
)
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type."
;
}
}
params
.
last_dim_shape
=
GetInt32Const
(
GetItem
(
dst_info
->
shape_
,
-
1
));
AppendEmptyVar
(
info_list
);
if
(
arg_info
->
arg_type_
==
ARG_VECTOR_BROADCAST_LAST_AXIS
)
{
return
;
}
if
(
mode
==
"reduction"
||
mode
==
"broadcast"
)
{
FillEmptyVar
(
info_list
);
}
CHECK_EQ
(
info_list
.
size
(),
3
);
dst_info
=
info_list
[
0
];
src_info0
=
info_list
[
1
];
src_info1
=
info_list
[
2
];
params
.
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
params
.
block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
CHECK_NE
(
params
.
vec_max_len
,
0
);
CHECK_NE
(
params
.
block_size
,
0
);
params
.
dst_var
=
dst_info
->
var_
;
params
.
dst_shape
=
dst_info
->
shape_
;
params
.
dst_strides
=
dst_info
->
strides_
;
params
.
src_var0
=
src_info0
->
var_
;
params
.
src_var1
=
src_info1
->
var_
;
params
.
src_shape0
=
src_info0
->
shape_
;
params
.
src_shape1
=
src_info1
->
shape_
;
params
.
src_strides0
=
src_info0
->
strides_
;
params
.
src_strides1
=
src_info1
->
strides_
;
auto
GetNonZeroShapeByIdx
=
[
this
](
int
index
)
->
int
{
if
(
index
<=
static_cast
<
int
>
(
params
.
dst_var
.
size
()))
{
if
(
Equal
(
GetItem
(
params
.
dst_var
,
-
index
),
GetItem
(
params
.
src_var0
,
-
index
))
&&
Equal
(
GetItem
(
params
.
dst_var
,
-
index
),
GetItem
(
params
.
src_var1
,
-
index
)))
{
return
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
index
),
GetItem
(
params
.
src_shape0
,
-
index
),
GetItem
(
params
.
src_shape1
,
-
index
));
}
}
return
1
;
};
params
.
non_zero_shape1
=
GetNonZeroShapeByIdx
(
1
);
params
.
non_zero_shape2
=
GetNonZeroShapeByIdx
(
2
);
params
.
non_zero_shape3
=
GetNonZeroShapeByIdx
(
3
);
params
.
all_points
=
params
.
non_zero_shape1
*
params
.
non_zero_shape2
*
params
.
non_zero_shape3
;
}
/// Check whether two computation infos describe the same access pattern.
/// \param info_a - first computation info
/// \param info_b - second computation info
/// \return true when vars are the same and shapes and strides match item by item
bool BinaryVecPatternGenerator::IsSamePatternComInfo(const StmtStoreInfo &info_a, const StmtStoreInfo &info_b) {
  if (!IsSame(info_a->var_, info_b->var_)) {
    return false;
  }

  // Shapes: same length and equal items.
  const size_t shape_count = info_a->shape_.size();
  if (shape_count != info_b->shape_.size()) {
    return false;
  }
  for (size_t dim = 0; dim != shape_count; ++dim) {
    if (!IsTwoItemEqual(info_a->shape_, info_b->shape_, static_cast<int>(dim), true)) {
      return false;
    }
  }

  // Strides: same length and equal items.
  const size_t stride_count = info_a->strides_.size();
  if (stride_count != info_b->strides_.size()) {
    return false;
  }
  for (size_t dim = 0; dim != stride_count; ++dim) {
    if (!IsTwoItemEqual(info_a->strides_, info_b->strides_, static_cast<int>(dim), true)) {
      return false;
    }
  }

  return true;
}
bool
BinaryVecPatternGenerator
::
IsNonZeroShapeEqual
(
const
Array
<
Expr
>
&
shape_list
)
{
Array
<
Expr
>
non_zero_list
;
for
(
auto
shape
:
shape_list
)
{
if
(
GetIntConst
(
shape
)
!=
0
)
{
non_zero_list
.
push_back
(
shape
);
}
}
if
(
non_zero_list
.
empty
())
{
LOG
(
FATAL
)
<<
"Error: all shapes are equal to 0."
;
}
for
(
auto
shape
:
non_zero_list
)
{
if
(
GetIntConst
(
shape
)
!=
GetIntConst
(
non_zero_list
[
0
]))
{
return
false
;
}
}
return
true
;
}
/// Prepare the operands of reduction/broadcast statements before pattern selection.
///
/// For these two modes the srcs are swapped when dst has the same pattern as src0, and an "empty_cc"
/// var (shape 1, stride 1) may be appended as a new last axis. A broadcast whose last axis differs
/// between dst and a src, or with a var-less src, is marked ARG_VECTOR_BROADCAST_LAST_AXIS and the
/// function returns without updating src_info_list or info_list. Other modes are left untouched.
/// \param info_list - rewritten to {dst, src0, src1} when the mode is reduction or broadcast
void BinaryVecPatternGenerator::AppendEmptyVar(StmtInfoList &info_list) {
  // Append `var` as a new innermost axis with shape 1 and stride 1, and add it to the index.
  const auto AppendAxis = [](const StmtStoreInfo com_info, const Var &var) -> void {
    auto node = com_info.GetNode();
    node->var_.push_back(var);
    node->shape_.push_back(Expr(1));
    node->strides_.push_back(Expr(1));
    node->index_ = com_info->index_ + GetItem(com_info->var_, -1);
  };

  auto src_info0 = src_info_list[0];
  auto src_info1 = src_info_list[1];

  const bool is_reduction = mode == "reduction";
  const bool is_broadcast = mode == "broadcast";
  if (!is_reduction && !is_broadcast) {
    return;
  }

  // ISA 8.1.2, strides of Xd must be equal to Xm, [Xd = dst, Xn = src0, Xm = src1]
  if (IsSamePatternComInfo(dst_info, src_info0)) {
    std::swap(src_info0, src_info1);
  }

  if (is_reduction) {
    if (src_info0->data_alignment_ == 1) {
      empty_var = Var("empty_cc");
      AppendAxis(src_info0, empty_var);
    }
  } else {
    // last dim broadcast, should use VS insn, such as vadds and vmuls
    const bool all_have_var = !dst_info->var_.empty() && !src_info0->var_.empty() && !src_info1->var_.empty();
    const bool less_var = all_have_var && (!IsTwoItemEqual(dst_info->var_, src_info0->var_, -1) ||
                                           !IsTwoItemEqual(dst_info->var_, src_info1->var_, -1));
    const bool null_var = src_info0->var_.empty() || src_info1->var_.empty();
    if (less_var || null_var) {
      arg_info.GetNode()->arg_type_ = ARG_VECTOR_BROADCAST_LAST_AXIS;
      return;
    }
    if (dst_info->data_alignment_ == 1 && src_info0->data_alignment_ == 1) {
      empty_var = Var("empty_cc");
      AppendAxis(dst_info, empty_var);
      AppendAxis(src_info0, empty_var);
      AppendAxis(src_info1, empty_var);
      params.last_dim_shape = 1;
    }
  }

  src_info_list = {src_info0, src_info1};
  info_list = {dst_info, src_info0, src_info1};
}
/// Get CCE Binary Vector Insn Computation Info
/// \param stmt - operand stmt
/// \param intrin_name - vector intrin name, must be one of the supported binary intrinsics
/// \param dst_info_list - dst computation info list, replaced by the generated one
/// \param src_info_list - src computation info list, truncated to the first two srcs
/// \param if_info - if info list
/// \param for_info - for info list, replaced by the generated one
/// \param enable_bisect - forwarded to GetBinaryVecMode
/// \return intrin args
ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
                             StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info,
                             bool enable_bisect) {
  // check intrin_name
  const std::set<std::string> supported_intrins = {"vadd", "vmax",  "vmin",   "vmul",   "vdiv",
                                                   "vsel", "vsub",  "vand",   "vor",    "vaxpy",
                                                   "argmax", "argmin", "vmadd", "vmaddrelu", "vmla"};
  if (supported_intrins.find(intrin_name) == supported_intrins.end()) {
    LOG(FATAL) << "Error: CCE Binary Vector Insn doesn't support the given intrin_name.";
  }

  // get and check dst and src
  GetCompactComputationInfo(stmt, dst_info_list, src_info_list, if_info, for_info, true);
  const bool operand_count_ok = dst_info_list.size() == 1 && src_info_list.size() >= 2;
  if (!operand_count_ok) {
    LOG(FATAL) << "CCE Binary Vector Insn only support ONE dst and TWO srcs.";
  }
  // For vmadd/vmaddrelu/vmla we only need first two src
  src_info_list = GetRange(src_info_list, 0, 2);

  ArgInfo arg_info = ArgInfo(make_node<ArgInfoNode>());

  // detect vector op mode
  const std::string mode = GetBinaryVecMode(dst_info_list, src_info_list, intrin_name, enable_bisect);

  if (mode == "reduce_last_axis") {
    // Use the longer of the two src var lists; it must exceed the dst var count by exactly one.
    size_t src_var_list_size = std::max(src_info_list[0]->var_.size(), src_info_list[1]->var_.size());
    CHECK(src_var_list_size > 0) << "Error: src can not be a scalar.";
    if (src_var_list_size - dst_info_list[0]->var_.size() == 1) {
      arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS;
    } else {
      LOG(FATAL) << "Error: cannot support multi-last-axis reduction.";
    }
    return arg_info;
  }

  if (mode == "reduce_bisection") {
    arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_BISECTION;
    return arg_info;
  }

  if (mode != "reduction" && mode != "broadcast") {
    FillLastDim(dst_info_list, src_info_list, for_info);
  }

  // vmax/vmin can't expand mask because it may introduce dirty data
  const bool can_expand_mask = intrin_name != "vmax" && intrin_name != "vmin";

  BinaryVecPatternGenerator generator =
    BinaryVecPatternGenerator(dst_info_list, src_info_list, for_info, mode, can_expand_mask);
  auto generated = generator.GetInsnArgs();
  arg_info = generated.arg_info;
  dst_info_list = generated.dst_info_list;
  src_info_list = generated.src_info_list;
  for_info = generated.for_info;

  if (mode != "broadcast") {
    return arg_info;
  }

  // True when the given body/tail args exist and name a last-axis broadcast src.
  const auto NamesLastAxisSrc = [](const VectorArgInfo &vec_args) -> bool {
    return vec_args.defined() && vec_args->last_axis_info_.src_index_ != -1;
  };
  const bool has_last_axis = NamesLastAxisSrc(arg_info->body_arg_info_) || NamesLastAxisSrc(arg_info->tail_arg_info_);

  if (has_last_axis && (intrin_name == "vadd" || intrin_name == "vmul")) {
    Array<NodeRef> stores;
    Array<NodeRef> loads;
    GetStoreAndLoads(stmt, stores, loads);
    // Record the "...s" intrin variant (vadds / vmuls) and the broadcast src operand.
    // NOTE(review): only the body args are updated here, the tail args are left as generated - confirm intended.
    intrin_name = intrin_name + "s";
    if (arg_info->body_arg_info_.defined()) {
      auto body_node = arg_info.GetNode()->body_arg_info_.GetNode();
      body_node->last_axis_info_.intrin_name_ = intrin_name;
      body_node->last_axis_info_.src_op_ =
        Downcast<Expr>(loads[arg_info->body_arg_info_->last_axis_info_.src_index_]);
    }
  }

  return arg_info;
}
/// Rebind a store info onto the loop variables of a new for-info.
/// For every loop position of new_for_info, any entry of info's var list whose
/// name_hint equals the new loop var's name_hint is replaced by that new var, and
/// the old loop var at the same position is substituted by the new one in index_.
/// \param info          store info updated in place
/// \param old_for_info  for-info whose vars currently appear in info's index
/// \param new_for_info  for-info providing the replacement vars (same positions)
void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo &new_for_info) {
  const size_t loop_count = new_for_info.vars_.size();
  for (size_t loop_pos = 0; loop_pos < loop_count; ++loop_pos) {
    const auto &fresh_var = new_for_info.vars_[loop_pos];

    // Swap in the new var object wherever the names match.
    for (size_t var_pos = 0; var_pos < info->var_.size(); ++var_pos) {
      if (info->var_[var_pos]->name_hint != fresh_var->name_hint) {
        continue;
      }
      SetItem(info.GetNode()->var_, static_cast<int>(var_pos), new_for_info.vars_[loop_pos]);
    }

    // Keep the index expression consistent with the replaced variable.
    info.GetNode()->index_ = substitute(old_for_info.vars_[loop_pos], new_for_info.vars_[loop_pos], info->index_);
  }
}
/// Generate info list for bisection intrin
///
/// The reduce loop is located in for_info and its extent is split round by round.
/// Each round appends one StmtInfoList {dst, src0, src1} (plus the matching for-info)
/// to the wrapper; the final round is recorded as {dst, dst, src1}. An optional first
/// entry {dst, src} describes a copy of the input into the temporary bisection buffer.
/// A second pass then computes the instruction arguments for every recorded entry.
///
/// \param dst_info_list  must contain exactly one dst (checked)
/// \param src_info_list  must contain exactly two srcs (checked)
/// \param for_info       loop info containing the reduce loop; copied per round, not modified
/// \param if_info        forwarded to CompactComputationInfoList
/// \param last_axis      true selects the last-axis variant ("bisec_last_axis" buffer,
///                       final round handled by ReduceLastAxisPatternGenerator)
/// \param postfix        when > 0 it is appended to the temporary buffer name
/// \return wrapper holding per-round info lists, for-infos, arg infos, the shape of the
///         temporary buffer and, if a copy entry exists, its dma arguments
BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list,
                                                        const StmtInfoList &src_info_list, const StmtInfo &for_info,
                                                        StmtInfo &if_info, bool last_axis, int postfix = 0) {
  CHECK_EQ(dst_info_list.size(), 1);
  CHECK_EQ(src_info_list.size(), 2);

  BisectionInfoWrapper wrapper;
  // Separate com_info and for_info
  // compare_idx selects which src is bisected; var_idx is the position of the reduce
  // var inside that src (-1 addresses the last entry via GetItem/SetItem).
  int compare_idx = 1;
  int var_idx = -1;
  if (last_axis) {
    compare_idx = GetLastAxisReductionIdx(dst_info_list, src_info_list);
  } else {
    var_idx = GetBisectionReductionIdx(dst_info_list, src_info_list, compare_idx);
  }
  StmtStoreInfo dst_info = dst_info_list[0];
  CHECK_GE(compare_idx, 0);
  StmtStoreInfo src_info1 = src_info_list[compare_idx];

  // Find the For node that iterates the reduce var and read its (constant) extent.
  Var reduce_var = GetItem(src_info1->var_, var_idx);
  size_t for_idx = 0;
  bool suc = GetIndexOfElement(for_info.vars_, VarExpr(reduce_var), for_idx);
  CHECK(suc);
  auto exist_for = GetItem(for_info.ops_, for_idx).as<For>();
  CHECK(exist_for);
  int extent = GetInt32Const(exist_for->extent);

  // prev_* describe the buffer read by the current round: the original src in the
  // first round, the bisection buffer afterwards.
  std::string prev_name = src_info1->name_;
  Var prev_var = src_info1->data_;
  Buffer prev_buffer = src_info1->buffer_;
  Var bisec_var;
  Buffer bisec_buffer;
  std::string bisec_pre_header = last_axis ? "bisec_last_axis" : "bisec";
  std::string bisec_name = bisec_pre_header + "_local_UB";
  if (postfix > 0) {
    bisec_name = bisec_name + "_" + std::to_string(postfix);
  }
  bool first_round = true;
  int vec_max_len = GetVecMaxLen(dst_info->dtype_);
  int remain_extent = extent;
  int left_extent = 0;
  CHECK_NE(vec_max_len, 0);
  // Candidate sizes used to round an extent up to the next power of two.
  // NOTE(review): extents above 65536 are left unrounded by the loops below.
  std::vector<int> pow2_list = {0,   1,   2,    4,    8,    16,   32,    64,    128,
                                256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};

  while (extent > 0) {
    // for_extent: loop extent of this round's instruction.
    // extent: after the update below, the offset/size of the half that is kept.
    int for_extent;
    if (last_axis) {
      // Keep the upper-rounded half, rounded up to a power of two.
      left_extent = remain_extent / 2 + remain_extent % 2;
      for (int i : pow2_list) {
        if (left_extent == i) {
          break;
        } else if (left_extent < i) {
          left_extent = i;
          break;
        }
      }
      if (left_extent < vec_max_len) {
        // When left_extent < vec_max_len, stop bisect and generate normal reduce intrin
        left_extent = remain_extent;
      }
      extent = remain_extent - left_extent;
      remain_extent = left_extent;
      // extent == 0 marks the final round (it also ends the while loop).
      for_extent = extent == 0 ? vec_max_len : extent;
    } else {
      for_extent = extent == 1 ? extent : extent / 2;
      extent = extent % 2 == 0 || extent == 1 ? extent / 2 : (extent + 1) / 2;
      // Round the kept half up to a power of two and shrink the folded part accordingly.
      for (int i : pow2_list) {
        if (extent == i) {
          break;
        } else if (extent < i) {
          int gap = i - extent;
          extent = i;
          for_extent -= gap;
          break;
        }
      }
    }

    // Fresh copies per round; both sources start from the bisected src.
    StmtStoreInfo dst_tmp_info = dst_info.Copy();
    StmtStoreInfo src_tmp_info0{src_info1.Copy()};
    StmtStoreInfo src_tmp_info1{src_info1.Copy()};

    if (first_round) {
      // Create the temporary bisection buffer with the src shape
      // (last dim rounded up to the UB block size in last-axis mode).
      auto shape = src_tmp_info1->shape_;
      if (last_axis) {
        int block_size = GetUbBlkSize(dst_info->dtype_);
        SetItem(shape, -1, Expr(CeilTo(GetIntConst(GetItem(shape, -1)), block_size)));
      }
      wrapper.original_shape_ = shape;
      bisec_var = Var(bisec_name, Handle());
      bisec_buffer = BufferNode::make(bisec_var, dst_tmp_info->dtype_, shape, Array<Expr>(), Expr(), bisec_name,
                                      SCOPE_UBUF, 0, 0, BufferType::kDefault);

      if ((last_axis && extent != left_extent) || (!last_axis && extent != for_extent)) {
        // Need to copy input to bisect buffer
        StmtStoreInfo copy_dst_info{src_info1.Copy()};
        StmtStoreInfo copy_src_info{src_info1.Copy()};
        StmtInfoList src_list = {copy_src_info};
        auto for_tmp_info = for_info.Copy();
        auto new_for = GetItem(for_tmp_info.ops_, for_idx).as<For>();
        CHECK(new_for);
        SetItem(for_tmp_info.ops_, static_cast<int>(for_idx),
                For::make(new_for->loop_var, new_for->min, last_axis ? left_extent : extent, new_for->for_type,
                          new_for->device_api, new_for->body));
        ReplaceVarWithNewForInfo(copy_dst_info, for_info, for_tmp_info);
        ReplaceVarWithNewForInfo(copy_src_info, for_info, for_tmp_info);
        SetItem(copy_dst_info.GetNode()->shape_, var_idx, Expr(last_axis ? left_extent : extent));
        SetItem(copy_src_info.GetNode()->shape_, var_idx, Expr(last_axis ? left_extent : extent));
        CompactComputationInfoList(copy_dst_info, src_list, if_info, for_tmp_info);
        copy_dst_info.GetNode()->name_ = bisec_name;
        copy_dst_info.GetNode()->buffer_ = bisec_buffer;
        copy_dst_info.GetNode()->data_ = bisec_var;
        // Replace outside for variable in index
        // (vars not owned by this info are set to 0 in the bisection buffer index)
        auto vars = GetVarsInExpr(copy_dst_info->index_);
        for (auto var : vars) {
          if (!IsInArray(copy_dst_info->var_, var)) {
            copy_dst_info.GetNode()->index_ = substitute(var, Expr(0), copy_dst_info->index_);
          }
        }
        // A two-element entry marks the copy; the arg pass below detects it by its size.
        wrapper.bisec_info_list_.emplace_back(StmtInfoList{copy_dst_info, copy_src_info});
        wrapper.for_info_list_.push_back(for_tmp_info);
      }
    }

    // Build this round's for-info with the reduce loop shrunk to for_extent.
    auto for_tmp_info = for_info.Copy();
    auto new_for = GetItem(for_tmp_info.ops_, for_idx).as<For>();
    CHECK(new_for);
    SetItem(for_tmp_info.ops_, static_cast<int>(for_idx),
            For::make(new_for->loop_var, new_for->min, for_extent, new_for->for_type, new_for->device_api,
                      new_for->body));
    ReplaceVarWithNewForInfo(dst_tmp_info, for_info, for_tmp_info);
    ReplaceVarWithNewForInfo(src_tmp_info0, for_info, for_tmp_info);
    ReplaceVarWithNewForInfo(src_tmp_info1, for_info, for_tmp_info);
    SetItem(src_tmp_info0.GetNode()->shape_, var_idx, Expr(for_extent));
    SetItem(src_tmp_info1.GetNode()->shape_, var_idx, Expr(for_extent));
    if (extent > 0) {
      // Intermediate round: the result goes to the bisection buffer with the src layout.
      dst_tmp_info.GetNode()->shape_ = src_tmp_info1->shape_;
      dst_tmp_info.GetNode()->strides_ = src_tmp_info1->strides_;
      dst_tmp_info.GetNode()->var_ = src_tmp_info1->var_;
      dst_tmp_info.GetNode()->index_ = src_tmp_info1->index_;
      dst_tmp_info.GetNode()->data_alignment_ = src_tmp_info1->data_alignment_;
      dst_tmp_info.GetNode()->name_ = bisec_name;
      dst_tmp_info.GetNode()->buffer_ = bisec_buffer;
      dst_tmp_info.GetNode()->data_ = bisec_var;

      // The second source starts behind the kept half.
      auto src_extent = Expr(left_extent);
      if (!last_axis) {
        src_extent = GetItem(src_tmp_info1->strides_, var_idx) * extent;
      }
      src_tmp_info1.GetNode()->index_ = src_tmp_info1->index_ + src_extent;
    }
    src_tmp_info0.GetNode()->name_ = prev_name;
    src_tmp_info1.GetNode()->name_ = prev_name;
    src_tmp_info0.GetNode()->buffer_ = prev_buffer;
    src_tmp_info1.GetNode()->buffer_ = prev_buffer;
    src_tmp_info0.GetNode()->data_ = prev_var;
    src_tmp_info1.GetNode()->data_ = prev_var;
    // Replace outside for variable in index
    // (only for infos that refer to the bisection buffer, identified by name)
    for (auto &info : {dst_tmp_info, src_tmp_info0, src_tmp_info1}) {
      if (info->name_.find(bisec_pre_header) == std::string::npos) {
        continue;
      }
      auto vars = GetVarsInExpr(info->index_);
      for (auto var : vars) {
        if (!IsInArray(info->var_, var)) {
          info.GetNode()->index_ = substitute(var, Expr(0), info->index_);
        }
      }
    }
    // Every later round reads from the bisection buffer.
    prev_name = bisec_name;
    prev_var = bisec_var;
    prev_buffer = bisec_buffer;
    StmtInfoList src_list = {src_tmp_info0, src_tmp_info1};
    CompactComputationInfoList(dst_tmp_info, src_list, if_info, for_tmp_info);
    wrapper.for_info_list_.emplace_back(for_tmp_info);
    if (extent == 0) {
      // last round should be dst = dst + src_tmp1
      wrapper.bisec_info_list_.emplace_back(StmtInfoList{dst_tmp_info, dst_tmp_info, src_tmp_info1});
    } else {
      // normally is dst_tmp = src_tmp0 + src_tmp1
      wrapper.bisec_info_list_.emplace_back(StmtInfoList{dst_tmp_info, src_tmp_info0, src_tmp_info1});
    }
    first_round = false;
  }

  // Generate arg_info
  for (size_t i = 0; i < wrapper.bisec_info_list_.size(); ++i) {
    auto info_list = wrapper.bisec_info_list_[i];
    auto new_for_info = wrapper.for_info_list_[i];

    // Stays default-constructed for the copy entry; still pushed to keep indices aligned.
    ArgInfo arg_info;
    auto dst_list = GetRange(info_list, 0, 1);
    auto src_list = GetRange(info_list, 1, info_list.size() - 1);
    if (info_list.size() == 2) {
      // Copy entry {dst, src}: only dma arguments are needed.
      std::string dma_intrin = INTRIN_NAME_COPY_UB_TO_UB;
      wrapper.dma_arg_info_map_ = GetDmaCopyInsnArgs(dma_intrin, dst_list, src_list, new_for_info);
    } else if (last_axis && i == wrapper.bisec_info_list_.size() - 1) {
      // Final round of the last-axis variant: a regular last-axis reduction of src1 into dst.
      auto dst_tmp_info = dst_list[0];
      auto src_tmp_info = src_list[1];
      ReduceLastAxisPatternGenerator generator =
        ReduceLastAxisPatternGenerator(dst_tmp_info, src_tmp_info, new_for_info, "vadd");
      auto result = generator.GetInsnArgs();
      arg_info = result.arg_info;
      dst_tmp_info = result.dst_info_list[0];
      src_tmp_info = result.src_info_list[0];
      new_for_info = result.for_info;
      wrapper.bisec_info_list_[i] = {dst_tmp_info, dst_tmp_info, src_tmp_info};
    } else {
      // Bisect can't expand mask because it has inplace operation
      if (i != wrapper.bisec_info_list_.size() - 1) {
        // The last round does not need FillLastDim
        FillLastDim(dst_list, src_list, new_for_info);
      }
      std::string mode = GetBinaryVecMode(dst_list, src_list, "vadd", false);
      BinaryVecPatternGenerator generator = BinaryVecPatternGenerator(dst_list, src_list, new_for_info, mode, false);
      auto params = generator.GetInsnArgs();
      arg_info = params.arg_info;
      dst_list = params.dst_info_list;
      src_list = params.src_info_list;
      new_for_info = params.for_info;
      wrapper.bisec_info_list_[i] = {dst_list[0], src_list[0], src_list[1]};
    }
    wrapper.arg_info_list_.push_back(arg_info);
    wrapper.for_info_list_[i] = new_for_info;
  }
  return wrapper;
}
}
// namespace akg
src/emit_insn/insn_emitter.cc
浏览文件 @
0871623a
...
...
@@ -35,7 +35,7 @@
#include "insn_info.h"
#include "insn_pattern.h"
#include "insn_emitter_multimask.h"
#include "insn_args_calculator.h"
namespace
akg
{
namespace
ir
{
/// Sort indexes
...
...
@@ -71,8 +71,7 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
Array
<
Expr
>
call_args
;
int
call_cnt
=
0
;
if
(
intrin_name
==
"vector_dup"
||
intrin_name
==
"vadds"
||
intrin_name
==
"vmuls"
||
intrin_name
==
"vaxpy"
)
{
if
(
intrin_name
==
"vector_dup"
||
intrin_name
==
"vadds"
||
intrin_name
==
"vmuls"
||
intrin_name
==
"vaxpy"
)
{
auto
GetCallInfo
=
[
&
intrin_name
,
&
call_args
,
&
call_cnt
](
const
NodeRef
&
op
)
{
if
(
op
.
as
<
Call
>
()
&&
op
.
as
<
Call
>
()
->
name
==
intrin_name
)
{
call_args
=
op
.
as
<
Call
>
()
->
args
;
...
...
@@ -82,8 +81,8 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
PostOrderVisit
(
op
,
GetCallInfo
);
CHECK_EQ
(
call_cnt
,
1
);
}
SingleType
insn_type
{
SingleType
::
SIMD
};
Expr
scalar_src
{};
SingleType
insn_type
{
SingleType
::
SIMD
};
Expr
scalar_src
{};
if
(
intrin_name
==
"vector_dup"
)
{
insn_type
=
SingleType
::
Vector_Dump
;
src_info_list
=
{};
...
...
@@ -93,10 +92,11 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
src_info_list
=
{
src_info_list
[
0
]};
scalar_src
=
call_args
[
1
];
}
// check is single vector broadcast reduce mode exist
SingleVecPatternGenerator
generator
=
SingleVecPatternGenerator
(
dst_info_list
,
src_info_list
,
for_info
);
auto
params
=
generator
.
GetInsnArgs
();
SingleVecInsnArgsCalculator
args_calculator
=
SingleVecInsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
,
intrin_name
);
PatternResult
params
=
args_calculator
.
GetInsnArgs
();
dst_info_list
=
params
.
dst_info_list
;
src_info_list
=
params
.
src_info_list
;
for_info
=
params
.
for_info
;
...
...
@@ -141,23 +141,16 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
if
(
src_info_list
[
0
]
->
var_
.
size
()
>
src_info_list
[
1
]
->
var_
.
size
())
{
src_info
=
src_info_list
[
0
];
}
const
int
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
if
(
enable_bisect
&&
GetIntConst
(
GetItem
(
src_info
->
shape_
,
-
1
))
>
vec_max_len
)
{
CommentManager
::
GetInstance
().
AddComment
(
"Bisect_optimize"
,
"enabled"
);
auto
wrapper
=
SeparateComInfoToBisectionInfoList
(
dst_info_list
,
src_info_list
,
for_info
,
if_info
,
true
,
postfix
);
return
EmitCceBinaryVectorToBisectionReduction
(
wrapper
,
if_info
,
intrin_name
);
}
else
{
CommentManager
::
GetInstance
().
AddComment
(
"Pattern"
,
arg_info
.
GetPattern
());
ReduceLastAxisPatternGenerator
generator
=
ReduceLastAxisPatternGenerator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
auto
result
=
generator
.
GetInsnArgs
();
arg_info
=
result
.
arg_info
;
dst_info
=
result
.
dst_info_list
[
0
];
src_info
=
result
.
src_info_list
[
0
];
for_info
=
result
.
for_info
;
return
EmitCceBinaryVectorToReduceLastAxis
(
dst_info
,
src_info
,
if_info
,
for_info
,
arg_info
,
intrin_name
);
}
CommentManager
::
GetInstance
().
AddComment
(
"Pattern"
,
arg_info
.
GetPattern
());
LastAxisReduceInsnArgsCalculator
args_calculator
=
LastAxisReduceInsnArgsCalculator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
PatternResult
result
=
args_calculator
.
GetInsnArgs
();
arg_info
=
result
.
arg_info
;
dst_info
=
result
.
dst_info_list
[
0
];
src_info
=
result
.
src_info_list
[
0
];
for_info
=
result
.
for_info
;
return
EmitCceBinaryVectorToReduceLastAxis
(
dst_info
,
src_info
,
if_info
,
for_info
,
arg_info
,
intrin_name
);
}
case
ARG_VECTOR_REDUCTION_BISECTION
:
{
CommentManager
::
GetInstance
().
AddComment
(
"Compute_type"
,
"reduction"
);
...
...
@@ -192,7 +185,7 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
return
FoldInsnWithForInfo
(
insn_list
,
if_info
,
for_info
,
stmt
);
}
}
}
}
// namespace ir
/// Function to emit scalar intrin
/// \param op - The input stmt to be emitted as intrin
...
...
@@ -984,8 +977,9 @@ Stmt BinaryDropoutEmitter(const Stmt &op) {
src1
.
GetNode
()
->
data_
=
mask
->
buffer_var
;
src1
.
GetNode
()
->
data_alignment_
=
GetInt32Const
(
mask
->
predicate
);
SingleVecPatternGenerator
generator
=
SingleVecPatternGenerator
(
dst_info_list
,
src_info_list
,
for_info
,
"elewise"
);
auto
params
=
generator
.
GetInsnArgs
();
SingleVecInsnArgsCalculator
args_calculator
=
SingleVecInsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
);
PatternResult
params
=
args_calculator
.
GetInsnArgs
();
dst_info_list
=
params
.
dst_info_list
;
src_info_list
=
params
.
src_info_list
;
for_info
=
params
.
for_info
;
...
...
@@ -1484,8 +1478,10 @@ Stmt BinaryArgOpEmitter(const Stmt &op, const std::string &intrin_name) {
if
(
src_info_list
[
0
]
->
var_
.
size
()
>
src_info_list
[
1
]
->
var_
.
size
())
{
src_info
=
src_info_list
[
0
];
}
ReduceLastAxisPatternGenerator
generator
=
ReduceLastAxisPatternGenerator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
auto
result
=
generator
.
GetInsnArgs
();
LastAxisReduceInsnArgsCalculator
args_calculator
=
LastAxisReduceInsnArgsCalculator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
PatternResult
result
=
args_calculator
.
GetInsnArgs
();
arg_info
=
result
.
arg_info
;
dst_info
=
result
.
dst_info_list
[
0
];
src_info
=
result
.
src_info_list
[
0
];
...
...
src/emit_insn/insn_info.cc
浏览文件 @
0871623a
...
...
@@ -104,10 +104,7 @@ StmtStoreInfo StmtStoreInfo::Copy() const {
StmtInfo
StmtInfo
::
Copy
()
const
{
auto
stmt_info
=
StmtInfo
();
stmt_info
.
ops_
=
ops_
;
for
(
auto
var
:
vars_
)
{
auto
new_var
=
Variable
::
make
(
var
->
type
,
var
->
name_hint
);
stmt_info
.
vars_
.
push_back
(
new_var
);
}
stmt_info
.
vars_
=
vars_
;
for
(
size_t
i
=
0
;
i
<
vars_
.
size
();
++
i
)
{
for
(
size_t
j
=
0
;
j
<
stmt_info
.
ops_
.
size
();
++
j
)
{
...
...
src/emit_insn/insn_info.h
浏览文件 @
0871623a
...
...
@@ -276,15 +276,7 @@ struct BisectionInfoWrapper {
Map
<
std
::
string
,
Expr
>
dma_arg_info_map_
;
};
/// Plain record for one axis of an instruction.
/// All integer fields default to 0; the lists default to empty.
/// NOTE(review): field meanings below are taken from their names only — confirm with users.
struct InsnAxis {
  int min = 0;                     // presumably the loop minimum
  int extent = 0;                  // presumably the loop extent
  Var var;                         // presumably the loop variable of this axis
  int dst_stride = 0;              // stride on the dst side
  int src_stride = 0;              // stride on the src side
  std::list<int> src_stride_list;  // per-source strides
  std::list<int> stride_list;      // strides of all operands
};
IterVar
GetCceAxis
();
...
...
src/emit_insn/insn_pattern.cc
浏览文件 @
0871623a
...
...
@@ -15,7 +15,6 @@
*/
#include "insn_pattern.h"
#include <tvm/runtime/packed_func.h>
#include <tvm/base.h>
#include <tvm/ir_pass.h>
...
...
@@ -200,28 +199,6 @@ ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_
return
arg_info
;
}
/// Get first non zero shape from input shapes
/// \param dst_shape
/// \param src0_shape
/// \param src1_shape
/// \return
int
PatternGenerator
::
GetNonZeroShape
(
const
Expr
&
dst_shape
,
const
Expr
&
src0_shape
,
const
Expr
&
src1_shape
)
{
int
shape
=
0
;
for
(
int
val
:
{
GetInt32Const
(
dst_shape
),
GetInt32Const
(
src0_shape
),
src1_shape
.
defined
()
?
GetInt32Const
(
src1_shape
)
:
0
})
{
if
(
val
==
0
)
{
continue
;
}
if
(
shape
!=
0
&&
val
!=
shape
)
{
LOG
(
FATAL
)
<<
"Error: same var has different shapes. "
<<
GetIntConst
(
dst_shape
)
<<
" "
<<
GetIntConst
(
src0_shape
);
}
shape
=
val
;
}
CHECK
(
shape
!=
0
)
<<
"Error: all shapes are equal to 0."
;
return
shape
;
}
/// In case
/// for (cc3) {
/// A[(cc3*16)] = (B[(cc3*16)] - C[(cc3*16)])
...
...
@@ -432,25 +409,6 @@ void CleanZeroStrides(Array<StmtStoreInfo> &info_list) {
}
}
/// Exchange the entries at two positions in the var, shape and strides arrays.
/// \param var      variable list, modified in place
/// \param shape    shape list, modified in place
/// \param strides  stride list, modified in place
/// \param idx1     first position
/// \param idx2     second position
void PatternGenerator::GetShapeInfoAndSwap(Array<Var> &var, Array<Expr> &shape, Array<Expr> &strides, int idx1,
                                           int idx2) {
  // Read both entries before writing so each array is swapped from a consistent snapshot.
  const auto first_var = GetItem(var, idx1);
  const auto second_var = GetItem(var, idx2);
  SetItem(var, idx1, second_var);
  SetItem(var, idx2, first_var);

  const auto first_shape = GetItem(shape, idx1);
  const auto second_shape = GetItem(shape, idx2);
  SetItem(shape, idx1, second_shape);
  SetItem(shape, idx2, first_shape);

  const auto first_stride = GetItem(strides, idx1);
  const auto second_stride = GetItem(strides, idx2);
  SetItem(strides, idx1, second_stride);
  SetItem(strides, idx2, first_stride);
}
/// Get insn args of load 2D intrin
/// \param intrin_name
/// \param dst_info_list
...
...
@@ -856,6 +814,38 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
return
arg_info_map
;
}
/// Rebind a store info onto the loop variables of a new for-info.
/// For every loop position of new_for_info, any entry of info's var list whose
/// name_hint equals the new loop var's name_hint is replaced by that new var, and
/// the old loop var at the same position is substituted by the new one in index_.
/// \param info          store info updated in place
/// \param old_for_info  for-info whose vars currently appear in info's index
/// \param new_for_info  for-info providing the replacement vars (same positions)
void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo &new_for_info) {
  const size_t loop_count = new_for_info.vars_.size();
  for (size_t loop_pos = 0; loop_pos < loop_count; ++loop_pos) {
    const auto &fresh_var = new_for_info.vars_[loop_pos];

    // Swap in the new var object wherever the names match.
    for (size_t var_pos = 0; var_pos < info->var_.size(); ++var_pos) {
      if (info->var_[var_pos]->name_hint != fresh_var->name_hint) {
        continue;
      }
      SetItem(info.GetNode()->var_, static_cast<int>(var_pos), new_for_info.vars_[loop_pos]);
    }

    // Keep the index expression consistent with the replaced variable.
    info.GetNode()->index_ = substitute(old_for_info.vars_[loop_pos], new_for_info.vars_[loop_pos], info->index_);
  }
}
/// Classify a binary vector computation.
/// The predicates are tried in a fixed order and the first match wins:
/// "elewise", "broadcast", "reduce_last_axis", "reduce_bisection"; otherwise "reduction".
/// \param dst_info_list  dst info
/// \param src_info_list  src info
/// \param intrin_name    intrinsic name; bisection is only considered for vadd/vsub/vmul/vmax
/// \param enable_bisect  allow the "reduce_bisection" mode
/// \return mode string
std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
                             const std::string &intrin_name, bool enable_bisect) {
  if (IsElementwise(dst_info_list, src_info_list)) {
    return "elewise";
  }
  if (IsBroadcast(dst_info_list, src_info_list)) {
    return "broadcast";
  }
  if (IsLastAxisReduction(dst_info_list, src_info_list)) {
    return "reduce_last_axis";
  }

  std::set<std::string> bisect_capable = {"vadd", "vsub", "vmul", "vmax"};
  const bool use_bisection = enable_bisect && bisect_capable.count(intrin_name) != 0 &&
                             IsBisectionReduction(dst_info_list, src_info_list);
  if (use_bisection) {
    return "reduce_bisection";
  }
  return "reduction";
}
const
char
*
const
DummyLastVar
=
"cc_last"
;
TVM_REGISTER_API
(
"cce_util.GetVecMask"
).
set_body
([](
const
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
...
src/emit_insn/insn_pattern.h
浏览文件 @
0871623a
...
...
@@ -37,220 +37,12 @@ struct PatternResult {
StmtInfo
for_info
;
};
/// Abstract base of the instruction pattern generators.
/// It stores the first dst info and the for-info, and provides default pattern hooks
/// that report "not this pattern" (-1.0f) / an empty var list; subclasses override
/// the hooks they support and must implement GetInsnArgs and GenResult.
class PatternGenerator {
 public:
  /// \param dst_info_list  must not be empty; only its first element is kept
  /// \param for_info       loop info, stored by copy
  PatternGenerator(const StmtInfoList &dst_info_list, const StmtInfo &for_info)
      : for_info(for_info),
        not_this_pattern(-1.0f),
        split_latency_coef(10.0f),
        repeat_latency_coef(3.0f),
        offset_latency_coef(0.1f) {
    CHECK(!dst_info_list.empty());
    dst_info = dst_info_list[0];
  }
  virtual ~PatternGenerator() = default;
  /// Entry point: compute the instruction arguments for the stored computation.
  virtual PatternResult GetInsnArgs() = 0;

 protected:
  /// Common non-zero extent of up to three shapes (src1_shape is optional).
  int GetNonZeroShape(const Expr &dst_shape, const Expr &src0_shape, const Expr &src1_shape = Expr());
  /// Swap positions idx1 and idx2 in var, shape and strides.
  void GetShapeInfoAndSwap(Array<Var> &var, Array<Expr> &shape, Array<Expr> &strides, int idx1, int idx2);
  // Per-pattern rate hooks; the defaults return not_this_pattern.
  virtual float Compute3DPatternMaskRate() { return not_this_pattern; }
  virtual float Compute2DBlockPatternMaskRate() { return not_this_pattern; }
  virtual float Compute2DPatternMaskRate() { return not_this_pattern; }
  virtual float Compute1DPatternMaskRate() { return not_this_pattern; }
  // Per-pattern var-list hooks; the defaults return an empty list.
  virtual Array<Var> Get3DPattern() { return {}; }
  virtual Array<Var> Get2DBlockPattern() { return {}; }
  virtual Array<Var> Get2DPattern() { return {}; }
  virtual Array<Var> Get1DPattern() { return {}; }
  /// Build the final result from the vars returned by the selected Get*Pattern hook.
  virtual PatternResult GenResult(const Array<Var> &elim_var) = 0;

  StmtStoreInfo dst_info;  // first element of the dst list given to the constructor
  StmtInfo for_info;
  const float not_this_pattern;     // -1.0f: value the default rate hooks return
  const float split_latency_coef;   // 10.0f
  const float repeat_latency_coef;  // 3.0f
  const float offset_latency_coef;  // 0.1f
};
/// Pattern generator for single-source vector instructions.
class SingleVecPatternGenerator : public PatternGenerator {
 public:
  /// \param dst_info_list  dst list (non-empty, checked by the base class)
  /// \param src_info_list  src list; when empty, a copy of the dst info is used as src
  /// \param for_info       loop info
  /// \param mode           computation mode, "elewise" by default
  SingleVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
                            const StmtInfo &for_info, const std::string &mode = "elewise")
      : PatternGenerator(dst_info_list, for_info),
        arg_info(ArgInfo(make_node<ArgInfoNode>())),
        body_args(VectorArgInfo()),
        tail_args(VectorArgInfo()),
        mode(mode) {
    if (src_info_list.empty()) {
      src_info = dst_info.Copy();
    } else {
      // Always true in this branch; kept as a guard in front of the element access.
      CHECK(!src_info_list.empty());
      src_info = src_info_list[0];
    }
  }
  ~SingleVecPatternGenerator() override = default;
  PatternResult GetInsnArgs() final;

 protected:
  // Overrides of the base-class pattern hooks.
  float Compute3DPatternMaskRate() final;
  float Compute2DBlockPatternMaskRate() final;
  float Compute2DPatternMaskRate() final;
  float Compute1DPatternMaskRate() final;
  // Additional patterns that exist only for single-vector instructions.
  float Compute3DsPatternMaskRate();
  float Compute2DRepeatPatternMaskRate();
  Array<Var> Get3DPattern() final;
  Array<Var> Get2DBlockPattern() final;
  Array<Var> Get2DPattern() final;
  Array<Var> Get1DPattern() final;
  Array<Var> Get3DsPattern();
  Array<Var> Get2DRepeatPattern();
  PatternResult GenResult(const Array<Var> &elim_var) final;

 private:
  void CalcParams();
  int GetLastDimShape(const Expr &dst_shape, const Expr &src_shape);

  // Working values filled by CalcParams (definition not shown here); all ints start at 0.
  struct Params {
    Array<Var> dst_var;
    Array<Var> src_var;
    Array<Expr> dst_shape;
    Array<Expr> src_shape;
    Array<Expr> dst_strides;
    Array<Expr> src_strides;
    int non_zero_shape1 = 0;
    int non_zero_shape2 = 0;
    int non_zero_shape3 = 0;
    int all_points = 0;
    int dst_block_size = 0;
    int src_block_size = 0;
    int mask_block_size = 0;
    int dst_bits = 0;
    int src_bits = 0;
    int max_bits = 0;
    int dst_vec_max_len = 0;
    int vec_max_len = 0;
    int block_offset = 0;
  };

  StmtStoreInfo src_info;  // first src, or a copy of dst when no src was given
  Params params;
  ArgInfo arg_info;
  VectorArgInfo body_args;
  VectorArgInfo tail_args;
  std::string mode;
  Type data_type;
};
/// Pattern generator for two-source (binary) vector instructions.
class BinaryVecPatternGenerator : public PatternGenerator {
 public:
  /// \param dst_info_list  dst list (non-empty, checked by the base class)
  /// \param src_info_list  src list, stored by copy
  /// \param for_info       loop info
  /// \param mode           computation mode string (e.g. as returned by GetBinaryVecMode)
  /// \param expand_mask    stored flag, true by default; callers in this file pass false
  ///                       for vmax/vmin and for bisection rounds
  BinaryVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
                            const StmtInfo &for_info, const std::string &mode, bool expand_mask = true)
      : PatternGenerator(dst_info_list, for_info),
        src_info_list(src_info_list),
        arg_info(ArgInfo(make_node<ArgInfoNode>())),
        body_args(VectorArgInfo()),
        tail_args(VectorArgInfo()),
        empty_var(Var("")),
        mode(mode),
        expand_mask(expand_mask) {}
  ~BinaryVecPatternGenerator() override = default;
  PatternResult GetInsnArgs() final;

 protected:
  // Overrides of the base-class pattern hooks.
  float Compute3DPatternMaskRate() final;
  float Compute2DBlockPatternMaskRate() final;
  float Compute2DPatternMaskRate() final;
  float Compute1DPatternMaskRate() final;
  Array<Var> Get3DPattern() final;
  Array<Var> Get2DBlockPattern() final;
  Array<Var> Get2DPattern() final;
  Array<Var> Get1DPattern() final;
  PatternResult GenResult(const Array<Var> &elim_var) final;

 private:
  void CalcParams();
  bool IsSamePatternComInfo(const StmtStoreInfo &info_a, const StmtStoreInfo &info_b);
  bool IsNonZeroShapeEqual(const Array<Expr> &shape_list);
  void AppendEmptyVar(StmtInfoList &info_list);

  // Working values filled by CalcParams (definition not shown here); all ints start at 0.
  struct Params {
    Array<Var> dst_var;
    Array<Expr> dst_shape;
    Array<Expr> dst_strides;
    Array<Var> src_var0;
    Array<Expr> src_shape0;
    Array<Expr> src_strides0;
    Array<Var> src_var1;
    Array<Expr> src_shape1;
    Array<Expr> src_strides1;
    int non_zero_shape1 = 0;
    int non_zero_shape2 = 0;
    int non_zero_shape3 = 0;
    int all_points = 0;
    int block_size = 0;
    int last_dim_shape = 0;
    int vec_max_len = 0;
  };

  StmtInfoList src_info_list;
  ArgInfo arg_info;
  VectorArgInfo body_args;
  VectorArgInfo tail_args;
  Params params;
  Var empty_var;  // var with an empty name, created in the constructor
  std::string mode;
  bool expand_mask;
};
/// Pattern generator for reductions over the last axis (one dst, one src).
class ReduceLastAxisPatternGenerator : public PatternGenerator {
 public:
  /// \param dst_info     dst info; wrapped into a one-element list for the base class
  /// \param src_info     src info that is reduced
  /// \param for_info     loop info
  /// \param intrin_name  name of the reduction intrinsic, stored by copy
  ReduceLastAxisPatternGenerator(const StmtStoreInfo &dst_info, const StmtStoreInfo &src_info,
                                 const StmtInfo &for_info, const std::string &intrin_name)
      : PatternGenerator({dst_info}, for_info),
        src_info(src_info),
        arg_info(ArgInfo(make_node<ArgInfoNode>())),
        body_args(VectorArgInfo()),
        tail_args(VectorArgInfo()),
        intrin_name(intrin_name) {}
  PatternResult GetInsnArgs() final;
  ~ReduceLastAxisPatternGenerator() override = default;

 protected:
  // Only the 2D-block rate and the 2D-block / 1D var lists are overridden;
  // the remaining hooks keep the base-class defaults.
  float Compute2DBlockPatternMaskRate() final;
  Array<Var> Get2DBlockPattern() final;
  Array<Var> Get1DPattern() final;
  PatternResult GenResult(const Array<Var> &elim_var) final;

 private:
  void CalcParams();

  // Working values filled by CalcParams (definition not shown here).
  struct Params {
    Array<Var> src_var;
    int block_size = 0;
    int vec_max_len = 0;
    int last_dim_shape = 0;
    Expr insn_offset_scale_factor;
  };

  StmtStoreInfo src_info;
  ArgInfo arg_info;
  VectorArgInfo body_args;
  VectorArgInfo tail_args;
  Array<VectorArgInfo> mix_vec_arg_list;
  std::string intrin_name;
  Params params;
};
std
::
string
GetSingleVecComputationInfo
(
const
Stmt
&
stmt
,
const
std
::
string
&
intrin_name
,
Array
<
StmtStoreInfo
>
&
dst_info_list
,
Array
<
StmtStoreInfo
>
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
need_compact
=
true
);
ArgInfo
GetBinaryVecInsnArgs
(
const
Stmt
&
stmt
,
std
::
string
intrin_name
,
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
enable_bisect
=
true
);
std
::
string
GetBinaryVecMode
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
std
::
string
&
intrin_name
,
bool
enable_bisect
=
true
);
ArgInfo
GetMultiVecInsnArgs
(
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
for_info
);
...
...
@@ -277,10 +69,7 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
const
StmtInfoList
&
src_info_list
,
StmtInfo
&
for_info
,
Map
<
std
::
string
,
Expr
>
&
ub_copy_pre
,
Map
<
std
::
string
,
Expr
>
&
ub_copy_post
);
BisectionInfoWrapper
SeparateComInfoToBisectionInfoList
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
StmtInfo
&
if_info
,
bool
last_axis
,
int
postfix
);
void
ReplaceVarWithNewForInfo
(
StmtStoreInfo
&
info
,
const
StmtInfo
&
old_for_info
,
const
StmtInfo
&
new_for_info
);
extern
const
char
*
const
DummyLastVar
;
}
// namespace akg
#endif // EMIT_INSN_INSN_PATTERN_H_
src/emit_insn/insn_single_vec_pattern.cc
已删除
100644 → 0
浏览文件 @
d237aa7d
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/base.h>
#include <tvm/ir_pass.h>
#include <cmath>
#include <set>
#include "insn_builder.h"
#include "insn_pattern.h"
#include "common/array_api.h"
#include "pass/expr_alg_simplify.h"
namespace
akg
{
/// Get CCE Single Vector Insn mode
/// \param dst_info_list
/// \param src_info_list
/// \return
std
::
string
GetSingleVecMode
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
)
{
CHECK
(
!
dst_info_list
.
empty
());
auto
dst_var_list
=
dst_info_list
[
0
]
->
var_
;
Array
<
Var
>
src_var_list
;
if
(
!
src_info_list
.
empty
())
{
src_var_list
=
src_info_list
[
0
]
->
var_
;
}
if
(
IsSame
(
dst_var_list
,
src_var_list
))
{
return
"elewise"
;
}
else
if
(
dst_var_list
.
size
()
>=
src_var_list
.
size
())
{
return
"broadcast"
;
}
return
"reduction"
;
}
/// Get Single Vector Computation Info
/// \param stmt
/// \param intrin_name
/// \param dst_info_list
/// \param src_info_list
/// \param if_info
/// \param for_info
/// \param need_compact
/// \return
std
::
string
GetSingleVecComputationInfo
(
const
Stmt
&
stmt
,
const
std
::
string
&
intrin_name
,
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
need_compact
)
{
std
::
set
<
std
::
string
>
intrin_name_list
=
{
"vadds"
,
"vmuls"
,
"vrelu"
,
"vabs"
,
"vln"
,
"vexp"
,
"vrec"
,
"vector_dup"
,
"vnot"
,
"vsqrt"
,
"vrsqrt"
};
if
(
intrin_name_list
.
count
(
intrin_name
)
==
0
&&
intrin_name
.
find
(
"vconv_"
)
==
std
::
string
::
npos
)
{
LOG
(
FATAL
)
<<
"Error: CCE Single Vector Insn unsupported the given intrin_name. "
<<
intrin_name
;
return
""
;
}
bool
same_dtype
=
intrin_name
.
find
(
"vconv_"
)
==
std
::
string
::
npos
;
GetCompactComputationInfo
(
stmt
,
dst_info_list
,
src_info_list
,
if_info
,
for_info
,
same_dtype
,
need_compact
);
std
::
string
mode
=
GetSingleVecMode
(
dst_info_list
,
src_info_list
);
CHECK
(
dst_info_list
.
size
()
==
1
)
<<
"CCE Single Vector only support ONE dst."
;
return
mode
;
}
/// Get CCE Single vector instructions args.
///
/// Takes no parameters: it works on the generator's members (dst_info, src_info, for_info, mode). It refreshes
/// params via CalcParams(), scores every candidate pattern, runs the builder of the selected one and packs the
/// outcome with GenResult().
///
/// Selection order: mode "broadcast_last_axis" forces the 1D builder; otherwise 2D-repeat, 3Ds, 3D, 2D-block,
/// 2D and 1D are tried in that order using the comparisons below. No applicable pattern is a fatal error.
/// \return the PatternResult produced by GenResult() for the variables the selected pattern eliminates
PatternResult SingleVecPatternGenerator::GetInsnArgs() {
  CalcParams();

  Array<Var> elim_var = {};

  float rate3d = Compute3DPatternMaskRate();
  float rate2db = Compute2DBlockPatternMaskRate();
  float rate2d = Compute2DPatternMaskRate();
  float rate1d = Compute1DPatternMaskRate();
  float rate3ds = Compute3DsPatternMaskRate();
  float rate2ds = Compute2DRepeatPatternMaskRate();

  if (mode == "broadcast_last_axis") {
    elim_var = Get1DPattern();
  } else if (rate2ds > 0) {
    elim_var = Get2DRepeatPattern();
  } else if (rate3ds > 0) {
    elim_var = Get3DsPattern();
    arg_info.GetNode()->pattern_ = PATTERN_2D;
  } else if (rate3d >= rate2db && rate3d > 0) {
    elim_var = Get3DPattern();
    arg_info.GetNode()->pattern_ = PATTERN_3D;
  } else if (rate2db >= rate2d && rate2db >= rate1d && rate2db > 0) {
    elim_var = Get2DBlockPattern();
    arg_info.GetNode()->pattern_ = PATTERN_PARTIAL_3D;
  } else if (rate2d > rate1d && rate2d > 0) {
    elim_var = Get2DPattern();
    arg_info.GetNode()->pattern_ = PATTERN_2D;
  } else if (rate1d > 0) {
    elim_var = Get1DPattern();
    arg_info.GetNode()->pattern_ = PATTERN_1D;
  } else {
    LOG(FATAL) << "Error: Cannot emit Single-Vector-Insn with any pattern!";
  }

  std::string mask_rate = "rate3d[" + std::to_string(rate3d) + "], rate2db[" + std::to_string(rate2db) +
                          "], rate2d[" + std::to_string(rate2d) + "], rate1d[" + std::to_string(rate1d) + "]";
  CommentManager::GetInstance().AddComment("Mask_rate", mask_rate);

  // The builders store the tail in tail_args; arg_info->tail_arg_info_ is only assigned later inside
  // GenResult(), so tail_args is what tells whether this result carries a tail.
  if (tail_args.defined()) {
    CommentManager::GetInstance().AddComment("Contain_tail", "true");
  } else {
    CommentManager::GetInstance().AddComment("Contain_tail", "false");
  }

  return GenResult(elim_var);
}
/// Calc params for pattern match
void
SingleVecPatternGenerator
::
CalcParams
()
{
Array
<
StmtStoreInfo
>
info_list
=
{
dst_info
,
src_info
};
// check shape len
for
(
auto
info
:
info_list
)
{
CHECK
(
!
info
->
shape_
.
empty
())
<<
"CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."
;
}
FillEmptyVar
(
info_list
);
dst_info
=
info_list
[
0
];
src_info
=
info_list
[
1
];
int
dst_bits
=
dst_info
->
dtype_
.
bits
();
int
src_bits
=
src_info
->
dtype_
.
bits
();
CHECK_NE
(
dst_bits
,
0
);
CHECK_NE
(
src_bits
,
0
);
int
dst_block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
int
src_block_size
=
GetUbBlkSize
(
src_info
->
dtype_
);
CHECK_NE
(
dst_block_size
,
0
);
CHECK_NE
(
src_block_size
,
0
);
data_type
=
src_bits
>
dst_bits
?
src_info
->
dtype_
:
dst_info
->
dtype_
;
params
.
dst_var
=
dst_info
->
var_
;
params
.
src_var
=
src_info
->
var_
;
params
.
dst_shape
=
dst_info
->
shape_
;
params
.
src_shape
=
src_info
->
shape_
;
params
.
dst_strides
=
dst_info
->
strides_
;
params
.
src_strides
=
src_info
->
strides_
;
params
.
dst_block_size
=
dst_block_size
;
params
.
src_block_size
=
src_block_size
;
params
.
mask_block_size
=
src_bits
>
dst_bits
?
src_block_size
:
dst_block_size
;
params
.
dst_bits
=
dst_bits
;
params
.
src_bits
=
src_bits
;
params
.
max_bits
=
FULL_BLOCK_NUM
*
std
::
min
(
dst_bits
,
src_bits
);
params
.
dst_vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
params
.
vec_max_len
=
src_bits
>
dst_bits
?
GetVecMaxLen
(
src_info
->
dtype_
)
:
GetVecMaxLen
(
dst_info
->
dtype_
);
CHECK_NE
(
params
.
dst_vec_max_len
,
0
);
CHECK_NE
(
params
.
vec_max_len
,
0
);
auto
GetNonZeroShapeByIdx
=
[
this
](
int
index
)
->
int
{
if
(
index
<=
static_cast
<
int
>
(
params
.
dst_var
.
size
()))
{
if
(
Equal
(
GetItem
(
params
.
dst_var
,
-
index
),
GetItem
(
params
.
src_var
,
-
index
)))
{
return
GetNonZeroShape
(
GetItem
(
params
.
dst_shape
,
-
index
),
GetItem
(
params
.
src_shape
,
-
index
));
}
}
return
1
;
};
params
.
non_zero_shape1
=
GetNonZeroShapeByIdx
(
1
);
params
.
non_zero_shape2
=
GetNonZeroShapeByIdx
(
2
);
params
.
non_zero_shape3
=
GetNonZeroShapeByIdx
(
3
);
params
.
all_points
=
params
.
non_zero_shape1
*
params
.
non_zero_shape2
*
params
.
non_zero_shape3
;
auto
elem_offset_mod
=
ir
::
ExprSimplifier
().
Simplify
(
Mod
::
make
(
dst_info
->
elem_offset_
,
dst_block_size
));
if
(
elem_offset_mod
.
as
<
IntImm
>
())
{
params
.
block_offset
=
elem_offset_mod
.
as
<
IntImm
>
()
->
value
;
}
}
/// Pick the extent to use for the last axis from the dst and src extents.
/// \param dst_shape constant extent of the destination axis
/// \param src_shape constant extent of the source axis
/// \return the non-zero extent when exactly one of them is zero, otherwise the smaller of the two;
///         both being zero is a fatal error
int SingleVecPatternGenerator::GetLastDimShape(const Expr &dst_shape, const Expr &src_shape) {
  int dst_last_dim = GetInt32Const(dst_shape);
  int src_last_dim = GetInt32Const(src_shape);
  CHECK(dst_last_dim != 0 || src_last_dim != 0);

  if (dst_last_dim == 0 || src_last_dim == 0) {
    // Exactly one side is zero here; return the other one.
    return dst_last_dim == 0 ? src_last_dim : dst_last_dim;
  }
  return std::min(dst_last_dim, src_last_dim);
}
/// Tell whether `target` equals one of the last three entries of `shape`.
/// \param shape shape array; entries -1, -2 and -3 are read through GetItem()
/// \param target expression compared with Equal()
/// \return true on the first match, false when none of the three entries matches
bool FindInShape(Array<Expr> &shape, const Expr &target) {
  for (int back = 1; back <= 3; ++back) {
    if (Equal(GetItem(shape, -back), target)) {
      return true;
    }
  }
  return false;
}
float
SingleVecPatternGenerator
::
Compute2DRepeatPatternMaskRate
()
{
if
(
params
.
dst_var
.
size
()
<
3
)
{
return
not_this_pattern
;
}
for
(
int
i
=
-
1
;
i
>=
-
3
;
--
i
)
{
if
(
!
FindInShape
(
params
.
src_shape
,
GetItem
(
params
.
dst_shape
,
i
)))
{
return
not_this_pattern
;
}
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
3
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
3
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
&&
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
if
(
!
Equal
(
GetItem
(
params
.
dst_shape
,
-
3
),
GetItem
(
params
.
src_shape
,
-
2
))
||
!
Equal
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape
,
-
3
)))
{
return
not_this_pattern
;
}
if
(
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))
>
FULL_BLOCK_NUM
&&
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))
%
FULL_BLOCK_NUM
!=
0
)
{
return
not_this_pattern
;
}
if
(
params
.
dst_block_size
==
params
.
src_block_size
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_shape
,
-
1
))
<=
params
.
dst_block_size
&&
GetInt32Const
(
GetItem
(
params
.
src_shape
,
-
1
))
<=
params
.
src_block_size
)
{
return
not_this_pattern
;
}
return
1.0
;
}
float
SingleVecPatternGenerator
::
Compute3DsPatternMaskRate
()
{
if
(
params
.
dst_var
.
size
()
<
3
)
{
return
not_this_pattern
;
}
if
(
params
.
dst_block_size
!=
params
.
src_block_size
)
{
return
not_this_pattern
;
}
for
(
int
i
=
-
1
;
i
>=
-
3
;
--
i
)
{
if
(
!
FindInShape
(
params
.
src_shape
,
GetItem
(
params
.
dst_shape
,
i
)))
{
return
not_this_pattern
;
}
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
params
.
dst_block_size
||
GetInt32Const
(
GetItem
(
params
.
src_shape
,
-
1
))
>
params
.
src_block_size
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
3
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
3
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
&&
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
if
(
!
Equal
(
GetItem
(
params
.
dst_shape
,
-
3
),
GetItem
(
params
.
src_shape
,
-
2
))
||
!
Equal
(
GetItem
(
params
.
dst_shape
,
-
2
),
GetItem
(
params
.
src_shape
,
-
3
)))
{
return
not_this_pattern
;
}
if
(
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))
>
FULL_BLOCK_NUM
&&
GetIntConst
(
GetItem
(
params
.
dst_shape
,
-
2
))
%
FULL_BLOCK_NUM
!=
0
)
{
return
not_this_pattern
;
}
return
1.0
;
}
/// Score the 3D pattern for the current params.
///
/// After the common guards, two layouts are scored independently and the larger score is returned:
///   - mode 1 uses axis -3 as the repeat axis,
///   - mode 2 uses axis -2 as the repeat axis.
/// A layout that does not qualify keeps the score not_this_pattern. Note that a qualifying layout whose dst
/// stride on its block axis is zero makes the whole function return not_this_pattern immediately.
/// \return the larger of the two layout scores, or not_this_pattern
float SingleVecPatternGenerator::Compute3DPatternMaskRate() {
  // in elemwise mode, the var is already checked to be equal, no need to check
  if (params.dst_var.size() < 3) {
    return not_this_pattern;
  }
  // do not support cast op in 3D pattern
  if (params.dst_block_size != params.src_block_size) {
    return not_this_pattern;
  }
  // the last three axes of dst and src must pass IsTwoItemEqual on their variables
  for (int i = -1; i >= -3; --i) {
    if (!IsTwoItemEqual(params.dst_var, params.src_var, i)) {
      return not_this_pattern;
    }
  }
  // the last axis of both operands must fit into one block
  if (GetInt32Const(GetItem(params.dst_shape, -1)) > params.dst_block_size ||
      GetInt32Const(GetItem(params.src_shape, -1)) > params.src_block_size) {
    return not_this_pattern;
  }
  // strides of axis -2 must be whole multiples of the block size
  if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 ||
      GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) {
    return not_this_pattern;
  }
  // strides of axis -3 must be whole multiples of the block size
  if (GetInt32Const(GetItem(params.dst_strides, -3)) % params.dst_block_size != 0 ||
      GetInt32Const(GetItem(params.src_strides, -3)) % params.src_block_size != 0) {
    return not_this_pattern;
  }
  if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0 && GetInt32Const(GetItem(params.src_strides, -2)) == 0) {
    return not_this_pattern;
  }

  // repeat axis is shape [-3], repeat once, has 8 loops
  bool is3_d = true;
  float rate3d_mode1 = not_this_pattern;
  float rate3d_mode2 = not_this_pattern;
  // both are assigned before use inside each `if (is3_d)` branch below
  int repeat_num;
  float repeat_latency;

  StmtInfoList info_list = {dst_info, src_info};
  // mode 1 needs, for dst and src alike: extent of axis -2 within FULL_BLOCK_NUM and the block strides of
  // axes -2 / -3 below MAX_STRIDE_M0_SINGLE / MAX_STRIDE_M1 (dst_block_size equals src_block_size here)
  for (auto info : info_list) {
    if (GetInt32Const(GetItem(info->shape_, -2)) > FULL_BLOCK_NUM ||
        GetInt32Const(GetItem(info->strides_, -2)) / params.dst_block_size >= MAX_STRIDE_M0_SINGLE ||
        GetInt32Const(GetItem(info->strides_, -3)) / params.dst_block_size >= MAX_STRIDE_M1) {
      is3_d = false;
      break;
    }
  }
  if (is3_d) {
    if (GetIntConst(GetItem(params.dst_strides, -2)) == 0) {
      return not_this_pattern;
    }
    repeat_num = params.non_zero_shape3;
    repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef;
    rate3d_mode1 = static_cast<float>(params.all_points) / params.dst_vec_max_len / (repeat_num + repeat_latency);
  }

  is3_d = true;
  // repeat axis is shape[-2]
  for (auto info : info_list) {
    // stride_m0 must stay below MAX_STRIDE_M0_SINGLE (the original note says 65536; the constant's value is
    // not visible here), and the extent of axis -3 must be a multiple of FULL_BLOCK_NUM
    if (GetInt32Const(GetItem(info->shape_, -3)) % FULL_BLOCK_NUM != 0 ||
        GetInt32Const(GetItem(info->strides_, -3)) / params.dst_block_size >= MAX_STRIDE_M0_SINGLE) {
      is3_d = false;
      break;
    }
  }
  if (is3_d) {
    if (GetIntConst(GetItem(params.dst_strides, -3)) == 0) {
      return not_this_pattern;
    }
    repeat_num = params.non_zero_shape2 * (params.non_zero_shape3 / FULL_BLOCK_NUM);
    repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef;
    // an offset penalty is only charged when axis -3 spans more than one group of FULL_BLOCK_NUM
    float offset_latency =
      params.non_zero_shape3 / FULL_BLOCK_NUM > 1 ? params.non_zero_shape3 * offset_latency_coef : 0;
    rate3d_mode2 = static_cast<float>(params.all_points) / params.dst_vec_max_len /
                   (repeat_num + repeat_latency + offset_latency);
  }

  return rate3d_mode1 > rate3d_mode2 ? rate3d_mode1 : rate3d_mode2;
}
// Partial 3D Pattern
float
SingleVecPatternGenerator
::
Compute2DBlockPatternMaskRate
()
{
// in elemwise mode, the var is already checked to be equal, no need to check
if
(
params
.
dst_var
.
size
()
<
2
||
params
.
src_var
.
size
()
<
2
||
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
1
))
!=
1
)
{
return
not_this_pattern
;
}
// do not support cast op in Partial3D pattern
if
(
params
.
dst_block_size
!=
params
.
src_block_size
)
{
return
not_this_pattern
;
}
for
(
int
i
=
-
1
;
i
>=
-
2
;
--
i
)
{
if
(
!
Equal
(
GetItem
(
params
.
dst_var
,
i
),
GetItem
(
params
.
src_var
,
i
)))
{
return
not_this_pattern
;
}
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_shape
,
-
1
))
>
params
.
dst_block_size
||
GetInt32Const
(
GetItem
(
params
.
src_shape
,
-
1
))
>
params
.
src_block_size
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
&&
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
/
params
.
dst_block_size
>=
MAX_STRIDE_M0_SINGLE
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
/
params
.
src_block_size
>=
MAX_STRIDE_M0_SINGLE
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
==
0
)
{
return
not_this_pattern
;
}
int
repeat_body_num
=
params
.
non_zero_shape2
/
FULL_BLOCK_NUM
;
int
repeat_tail_num
=
(
params
.
non_zero_shape2
%
FULL_BLOCK_NUM
+
FULL_BLOCK_NUM
-
1
)
/
FULL_BLOCK_NUM
;
int
repeat_num
=
(
repeat_body_num
+
repeat_tail_num
)
*
params
.
non_zero_shape3
;
float
repeat_latency
=
(
std
::
max
(
repeat_body_num
-
1
,
0
)
/
MAX_REPEAT
+
std
::
max
(
repeat_tail_num
-
1
,
0
)
/
MAX_REPEAT
)
*
repeat_latency_coef
;
float
offset_latency
=
params
.
non_zero_shape3
>
1
?
params
.
non_zero_shape3
*
offset_latency_coef
:
0
;
float
split_latency
=
(
repeat_body_num
>
0
&&
repeat_tail_num
>
0
)
?
split_latency_coef
:
0
;
float
rate2db
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
dst_vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
+
split_latency
);
return
rate2db
;
}
float
SingleVecPatternGenerator
::
Compute2DPatternMaskRate
()
{
// in elemwise mode, the var is already checked to be equal, no need to check
if
(
params
.
dst_var
.
size
()
<
2
||
params
.
src_var
.
size
()
<
2
)
{
return
not_this_pattern
;
}
if
(
src_info
->
data_alignment_
==
1
&&
GetInt32Const
(
GetItem
(
src_info
->
strides_
,
-
1
))
!=
params
.
dst_block_size
)
{
return
not_this_pattern
;
}
for
(
int
i
=
-
1
;
i
>=
-
2
;
--
i
)
{
if
(
!
Equal
(
GetItem
(
params
.
dst_var
,
i
),
GetItem
(
params
.
src_var
,
i
)))
{
return
not_this_pattern
;
}
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
%
params
.
dst_block_size
!=
0
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
%
params
.
src_block_size
!=
0
)
{
return
not_this_pattern
;
}
if
(
GetInt32Const
(
GetItem
(
params
.
dst_strides
,
-
2
))
/
params
.
dst_block_size
>=
MAX_STRIDE_M1
||
GetInt32Const
(
GetItem
(
params
.
src_strides
,
-
2
))
/
params
.
src_block_size
>=
MAX_STRIDE_M1
)
{
return
not_this_pattern
;
}
// check num of insns, select 1D pattern or 2D pattern
int
tail_factor
=
0
;
if
(
params
.
non_zero_shape1
/
params
.
dst_vec_max_len
>
0
&&
params
.
non_zero_shape1
%
params
.
dst_vec_max_len
>
0
)
{
tail_factor
=
1
;
}
int
offset_num
=
(
params
.
non_zero_shape1
+
params
.
dst_vec_max_len
-
1
)
/
params
.
dst_vec_max_len
*
params
.
non_zero_shape3
;
int
repeat_num
=
offset_num
*
params
.
non_zero_shape2
;
float
repeat_latency
=
(
std
::
max
(
params
.
non_zero_shape2
-
1
,
0
)
/
MAX_REPEAT
)
*
offset_num
*
repeat_latency_coef
;
float
offset_latency
=
offset_num
>
1
?
offset_num
*
offset_latency_coef
:
0
;
float
split_latency
=
tail_factor
*
split_latency_coef
;
float
rate2d
=
static_cast
<
float
>
(
params
.
all_points
)
/
params
.
dst_vec_max_len
/
(
repeat_num
+
repeat_latency
+
offset_latency
+
split_latency
);
return
rate2d
;
}
/// Score the 1D pattern for the current params.
///
/// There are no guards: the score is always all_points / dst_vec_max_len divided by the estimated cost
/// (repeat count plus repeat, offset and split latencies).
/// \return the score
float SingleVecPatternGenerator::Compute1DPatternMaskRate() {
  // tail_factor is 1 only when the last axis needs at least one full vector plus a remainder.
  const int tail_factor = (params.non_zero_shape1 / params.dst_vec_max_len > 0 &&
                           params.non_zero_shape1 % params.dst_vec_max_len > 0)
                            ? 1
                            : 0;

  // Number of vectors needed to cover the last axis, rounded up.
  int shape1 = (params.non_zero_shape1 + params.dst_vec_max_len - 1) / params.dst_vec_max_len;
  int repeat_num = shape1 * params.non_zero_shape2 * params.non_zero_shape3;
  float repeat_latency = std::max((shape1 - 1) / MAX_REPEAT, 0) * params.non_zero_shape2 *
                         params.non_zero_shape3 * repeat_latency_coef;

  // An offset penalty is charged per outer iteration, but only when there is more than one.
  const auto outer_loops = params.non_zero_shape2 * params.non_zero_shape3;
  float offset_latency = outer_loops > 1 ? outer_loops * offset_latency_coef : 0;
  float split_latency = tail_factor * split_latency_coef;

  float rate1d = static_cast<float>(params.all_points) / params.dst_vec_max_len /
                 (repeat_num + repeat_latency + offset_latency + split_latency);
  return rate1d;
}
/// Build body_args for the 2D-repeat pattern.
///
/// Exchanges axes -2 and -3 of the cached source info first, then creates a fresh body_args with unit m0
/// strides, block-granular m1 strides taken from axis -2 and a repeat equal to the dst extent of axis -2.
/// tail_args is not touched.
/// \return the last two destination variables, as returned by GetRange(params.dst_var, -2, 2)
Array<Var> SingleVecPatternGenerator::Get2DRepeatPattern() {
  GetShapeInfoAndSwap(params.src_var, params.src_shape, params.src_strides, -2, -3);
  int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));

  body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
  auto body = body_args.GetNode();
  body->body_num_ = 1;
  body->dst_stride_m0_ = 1;
  body->src_stride_m0_list_ = {1};
  body->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
  body->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)};
  body->repeat_ = GetItem(params.dst_shape, -2);

  // The mask covers the last axis rounded up with CeilTo() to the dst block size, for a single block.
  int data_len = CeilTo(last_dim_shape, params.dst_block_size);
  body->vec_mask_ = GetVecMask(data_len, 1, dst_info->dtype_);

  return GetRange(params.dst_var, -2, 2);
}
/// Build body_args for the 3Ds pattern.
///
/// Exchanges axes -2 and -3 of the cached source info first. Two layouts exist:
///   - dst extent of axis -2 within FULL_BLOCK_NUM: that extent is the block count and axis -3 is the repeat
///     axis; when its extent is not below MAX_STRIDE_M1 the repeat is set to 1 and only two variables are
///     eliminated.
///   - otherwise: FULL_BLOCK_NUM blocks per repeat and axis -2 / FULL_BLOCK_NUM repeats; when that count is
///     not below MAX_STRIDE_M1 the function falls back to Get1DPattern().
/// \return the destination variables eliminated by the chosen layout
Array<Var> SingleVecPatternGenerator::Get3DsPattern() {
  GetShapeInfoAndSwap(params.src_var, params.src_shape, params.src_strides, -2, -3);
  int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));

  body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
  auto body = body_args.GetNode();
  body->body_num_ = 1;

  Expr dst_stride_m0 = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
  Expr src_stride_m0 = truncdiv(GetItem(params.src_strides, -2), params.src_block_size);
  body->dst_stride_m0_ = dst_stride_m0;
  body->src_stride_m0_list_ = {src_stride_m0};

  int data_len = CeilTo(last_dim_shape, params.mask_block_size);

  if (GetIntConst(GetItem(params.dst_shape, -2)) > FULL_BLOCK_NUM) {
    // Axis -2 is too long for one instruction: a full set of blocks per repeat.
    int block_num = FULL_BLOCK_NUM;
    body->dst_stride_m1_ = dst_stride_m0 * block_num;
    body->src_stride_m1_list_ = {src_stride_m0 * block_num};
    auto repeat = truncdiv(GetItem(params.dst_shape, -2), FULL_BLOCK_NUM);
    body->vec_mask_ = GetVecMask(data_len, block_num, dst_info->dtype_);
    if (GetIntConst(repeat) < MAX_STRIDE_M1) {
      body->repeat_ = repeat;
      return GetRange(params.dst_var, -2, 2);
    }
    return Get1DPattern();
  }

  // Axis -2 fits: it provides the blocks and axis -3 provides the repeats.
  int block_num = GetIntConst(GetItem(params.dst_shape, -2));
  body->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -3), params.dst_block_size);
  body->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -3), params.src_block_size)};
  body->vec_mask_ = GetVecMask(data_len, block_num, dst_info->dtype_);
  auto repeat = GetItem(params.dst_shape, -3);
  if (GetIntConst(repeat) < MAX_STRIDE_M1) {
    body->repeat_ = repeat;
    return GetRange(params.dst_var, -3, 3);
  }
  body->repeat_ = 1;
  return GetRange(params.dst_var, -2, 2);
}
/// Build body_args for the 3D pattern.
///
/// When the extent of axis -2 exceeds FULL_BLOCK_NUM, axes -2 and -3 are exchanged so that the original axis
/// -2 becomes the repeat axis; if axis -3 also exceeds FULL_BLOCK_NUM it is first split by FULL_BLOCK_NUM,
/// which updates dst_info, src_info, for_info and the cached params. tail_args is not touched.
/// \return the last three destination variables, as returned by GetRange(params.dst_var, -3, 3)
Array<Var> SingleVecPatternGenerator::Get3DPattern() {
  if (GetIntConst(GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape, -2))) >
      FULL_BLOCK_NUM) {
    // split shape[-3]
    if (GetIntConst(GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape, -3))) >
        FULL_BLOCK_NUM) {
      StmtInfoList store_infos = {dst_info, src_info};
      SplitAxis(store_infos, for_info, GetItem(params.dst_var, -3), FULL_BLOCK_NUM);
      FillEmptyVar(store_infos);
      dst_info = store_infos[0];
      src_info = store_infos[1];

      // Re-read the cached views, since the split changed the infos.
      params.dst_var = dst_info->var_;
      params.dst_shape = dst_info->shape_;
      params.dst_strides = dst_info->strides_;
      params.src_var = src_info->var_;
      params.src_shape = src_info->shape_;
      params.src_strides = src_info->strides_;
    }
    // consider original shape[-2] as repeat axis
    GetShapeInfoAndSwap(params.dst_var, params.dst_shape, params.dst_strides, -2, -3);
    GetShapeInfoAndSwap(params.src_var, params.src_shape, params.src_strides, -2, -3);
  }

  int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));

  body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
  auto body = body_args.GetNode();
  body->body_num_ = 1;
  body->repeat_ =
    make_const(Int(32), GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape, -3)));
  // m0 strides come from axis -2, m1 strides from axis -3, both in units of blocks.
  body->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
  body->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -3), params.dst_block_size);
  body->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)};
  body->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -3), params.src_block_size)};
  body->block_offset_ = make_const(Int(32), params.block_offset);

  int data_len = CeilTo(last_dim_shape, params.mask_block_size);
  int data_num = GetInt32Const(GetItem(params.dst_shape, -2));
  body->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_, params.block_offset);

  return GetRange(params.dst_var, -3, 3);
}
/// Build body_args / tail_args for the 2D-block (partial 3D) pattern.
///
/// Axis -2 is cut into full groups of FULL_BLOCK_NUM rows (body, one repeat per group) and at most one partial
/// group (tail, a single repeat starting after the full groups). Either part is only created when it is
/// non-empty; an args member that is not created keeps whatever value it had before the call.
/// \return the last two destination variables, as returned by GetRange(params.dst_var, -2, 2)
Array<Var> SingleVecPatternGenerator::Get2DBlockPattern() {
  int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));
  int repeat_len = GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape, -2));
  int repeat_body = repeat_len / FULL_BLOCK_NUM;
  int repeat_tail = (repeat_len % FULL_BLOCK_NUM + FULL_BLOCK_NUM - 1) / FULL_BLOCK_NUM;
  int data_len = CeilTo(last_dim_shape, params.dst_block_size);

  if (repeat_body > 0) {
    body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto body = body_args.GetNode();
    body->body_num_ = 1;
    body->repeat_ = make_const(Int(32), repeat_body);

    auto dst_stride_m0 = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
    body->dst_stride_m0_ = dst_stride_m0;
    body->dst_stride_m1_ = dst_stride_m0 * (params.max_bits / params.src_bits);

    auto src_stride_m0 = truncdiv(GetItem(params.src_strides, -2), params.src_block_size);
    body->src_stride_m0_list_ = {src_stride_m0};
    body->src_stride_m1_list_ = {src_stride_m0 * (params.max_bits / params.dst_bits)};

    body->block_offset_ = make_const(Int(32), params.block_offset);
    int data_num = FULL_BLOCK_NUM;
    body->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_, params.block_offset);
  }

  if (repeat_tail > 0) {
    tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto tail = tail_args.GetNode();
    // The tail starts right after the rows consumed by the full groups.
    tail->dst_head_ = GetItem(params.dst_strides, -2) * repeat_body * FULL_BLOCK_NUM;
    tail->src_head_list_ = {GetItem(params.src_strides, -2) * repeat_body * FULL_BLOCK_NUM};
    tail->repeat_ = Expr(1);
    tail->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
    tail->dst_stride_m1_ = Expr(0);
    tail->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)};
    tail->src_stride_m1_list_ = {Expr(0)};
    tail->block_offset_ = make_const(Int(32), params.block_offset);
    int data_num = repeat_len % FULL_BLOCK_NUM;
    tail->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_, params.block_offset);
  }

  return GetRange(params.dst_var, -2, 2);
}
/// Build body_args / tail_args for the 2D pattern.
///
/// The last axis is cut into whole vectors of params.vec_max_len elements (body) and a remainder (tail);
/// axis -2 of dst supplies the repeat count of both. Either part is only created when it is non-empty; an
/// args member that is not created keeps whatever value it had before the call.
/// \return the last two destination variables, as returned by GetRange(params.dst_var, -2, 2)
Array<Var> SingleVecPatternGenerator::Get2DPattern() {
  const int data_num = 1;

  int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));
  // When the last dst stride equals the dst block size and the dst/src last strides pass IsTwoItemEqual, the
  // extent is scaled by the block size.
  if (GetInt32Const(GetItem(dst_info->strides_, -1)) == params.dst_block_size &&
      IsTwoItemEqual(dst_info->strides_, src_info->strides_, -1, true)) {
    last_dim_shape *= params.dst_block_size;
  }

  int body_len = FloorTo(last_dim_shape, params.vec_max_len);
  int tail_len = last_dim_shape % params.vec_max_len;

  if (body_len > 0) {
    body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto body = body_args.GetNode();
    body->body_num_ = body_len / params.vec_max_len;
    body->body_offset_ = params.vec_max_len;
    body->repeat_ = GetItem(params.dst_shape, -2);
    body->dst_stride_m0_ = Expr(1);
    body->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
    body->src_stride_m0_list_ = {Expr(1)};
    body->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)};
    body->block_offset_ = make_const(Int(32), params.block_offset);
    body->vec_mask_ = GetVecMask(params.vec_max_len, data_num, data_type, params.block_offset);
  }

  // get tail params
  if (tail_len > 0) {
    tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    auto tail = tail_args.GetNode();
    // The remainder starts right after the elements covered by the body.
    tail->dst_head_ = Expr(body_len);
    tail->src_head_list_ = {Expr(body_len)};
    tail->repeat_ = GetItem(params.dst_shape, -2);
    tail->dst_stride_m0_ = Expr(1);
    tail->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
    tail->src_stride_m0_list_ = {Expr(1)};
    tail->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)};
    tail->block_offset_ = make_const(Int(32), params.block_offset);
    int data_len = CeilTo(tail_len, params.mask_block_size);
    tail->vec_mask_ = GetVecMask(data_len, data_num, data_type, params.block_offset);
  }

  return GetRange(params.dst_var, -2, 2);
}
/// Build body_args / tail_args for the 1D pattern and decide which loop variables it eliminates.
///
/// Steps:
///   1. Determine the usable extent of the last axis (1 when dst and src do not share that axis).
///   2. Cut it into whole vectors (body) and a remainder (tail). In scalar mode a "vector" is FULL_BLOCK_NUM
///      elements. Either part is only created when it is non-empty.
///   3. In "elewise" mode, when the second-to-last loop can be folded into the repeat count, fold it and
///      eliminate the last two variables.
///   4. Otherwise eliminate the last variable when the final condition holds, or nothing at all.
/// \return the destination variables covered by the emitted instruction (possibly empty)
Array<Var> SingleVecPatternGenerator::Get1DPattern() {
  int last_dim_shape;
  // set to true only when dst and src share the last axis and have the same bit width
  bool linear_mode = false;
  if ((params.dst_shape.empty() && params.src_shape.empty()) ||
      GetIntConst(GetItem(params.dst_shape, -1)) == 0) {
    last_dim_shape = 1;
  } else if (!IsTwoItemEqual(params.dst_var, params.src_var, -1)) {
    // dst and src do not iterate the last axis with the same variable
    last_dim_shape = 1;
  } else {
    last_dim_shape = GetLastDimShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));
    linear_mode = params.dst_bits == params.src_bits;
  }

  bool is_scalar_mode = IsScalarMode({dst_info, src_info});
  // scalar mode with differing bit widths handles one element at a time
  if (is_scalar_mode && params.dst_bits != params.src_bits) {
    last_dim_shape = 1;
  }

  // local vec_max_len deliberately differs from params.vec_max_len in scalar mode
  int vec_max_len = is_scalar_mode ? FULL_BLOCK_NUM : params.vec_max_len;
  int body_len = FloorTo(last_dim_shape, vec_max_len);
  int tail_len = last_dim_shape % vec_max_len;

  // block-granular strides of the last axis are only used in scalar + linear mode; otherwise the stride is 1
  auto dst_stride_m0 =
    is_scalar_mode && linear_mode ? truncdiv(GetItem(params.dst_strides, -1), params.dst_block_size) : Expr(1);
  auto src_stride_m0 =
    is_scalar_mode && linear_mode ? truncdiv(GetItem(params.src_strides, -1), params.src_block_size) : Expr(1);

  if (body_len > 0) {
    body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    body_args.GetNode()->body_num_ = 1;
    body_args.GetNode()->body_offset_ = vec_max_len;
    body_args.GetNode()->repeat_ = Expr(body_len / vec_max_len);
    body_args.GetNode()->dst_stride_m0_ = dst_stride_m0;
    auto dst_block_num = is_scalar_mode ? FULL_BLOCK_NUM : (params.max_bits / params.src_bits);
    body_args.GetNode()->dst_stride_m1_ = dst_stride_m0 * dst_block_num;
    body_args.GetNode()->src_stride_m0_list_ = {src_stride_m0};
    auto src_block_num = is_scalar_mode ? FULL_BLOCK_NUM : (params.max_bits / params.dst_bits);
    body_args.GetNode()->src_stride_m1_list_ = {src_stride_m0 * src_block_num};
    body_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset);
    // in cast case, data_num should be 1 because dst and src bit is not equal
    int data_len = is_scalar_mode ? 1 : vec_max_len;
    int data_num = is_scalar_mode ? FULL_BLOCK_NUM : 1;
    body_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, data_type, params.block_offset);
  }

  // get tail params
  if (tail_len > 0) {
    tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
    tail_args.GetNode()->body_offset_ = vec_max_len;
    tail_args.GetNode()->body_num_ = 1;
    // the tail starts after the body; in scalar mode the start is scaled by (stride_m0 * block size)
    tail_args.GetNode()->dst_head_ =
      Expr(body_len * (is_scalar_mode ? dst_stride_m0 * params.dst_block_size : Expr(1)));
    tail_args.GetNode()->src_head_list_ = {
      Expr(body_len * (is_scalar_mode ? src_stride_m0 * params.src_block_size : Expr(1)))};
    tail_args.GetNode()->repeat_ = Expr(1);
    tail_args.GetNode()->dst_stride_m0_ = dst_stride_m0;
    tail_args.GetNode()->dst_stride_m1_ = Expr(0);
    tail_args.GetNode()->src_stride_m0_list_ = {src_stride_m0};
    tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)};
    tail_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset);
    int data_len = is_scalar_mode && linear_mode ? 1 : CeilTo(tail_len, params.mask_block_size);
    int data_num = is_scalar_mode && linear_mode ? tail_len : 1;
    // data_num is clamped to at least 1
    data_num = data_num == 0 ? 1 : data_num;
    tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, data_type, params.block_offset);
  }

  // compute offset for cce instructions
  Array<Var> elim_var = {};
  if (mode == "elewise" && params.dst_var.size() >= 2 && params.dst_strides.size() >= 2 &&
      for_info.ops_.size() >= 2 && last_dim_shape <= vec_max_len &&
      last_dim_shape >= vec_max_len - params.dst_block_size &&
      GetIntConst(GetItem(params.dst_strides, -2)) == last_dim_shape) {
    // in this case we can merge second last for extent to repeat
    size_t index = 0;
    bool suc = GetIndexOfElement(for_info.vars_, GetItem(params.dst_var, -2), index);
    CHECK(suc);
    auto latest_for = GetItem(for_info.ops_, index).as<For>();
    // there should not be if_op between for loop and compute stmt
    if (latest_for && !latest_for->body->IsInstance<IfThenElse>()) {
      if (!params.dst_var.empty() && (!is_scalar_mode || last_dim_shape != 1)) {
        if (body_args.defined()) {
          // last_dim_shape = vec_max_len
          body_args.GetNode()->repeat_ = body_args->repeat_ * latest_for->extent;
        } else if (tail_args.defined()) {
          // last_dim_shape < vec_max_len
          tail_args.GetNode()->repeat_ = tail_args->repeat_ * latest_for->extent;
        }
        return GetRange(params.dst_var, -2, 2);
      }
    }
  }

  // eliminate only the last variable when dst has a positive last stride and src shares (or lacks) that axis
  if (!params.dst_var.empty() && (!is_scalar_mode || last_dim_shape != 1 || linear_mode) &&
      GetIntConst(GetItem(params.dst_strides, -1)) > 0 &&
      (params.src_var.empty() || IsTwoItemEqual(params.dst_var, params.src_var, -1))) {
    elim_var = GetRange(params.dst_var, -1, 1);
  }
  return elim_var;
}
/// Pack the outcome of pattern selection into a PatternResult.
///
/// Publishes body_args / tail_args on arg_info, computes the instruction offsets of dst and src for the
/// eliminated variables, removes those variables from for_info and cleans zero strides from both store infos.
/// \param elim_var loop variables covered by the emitted instruction
/// \return result holding the updated dst/src infos (one entry each), for_info and arg_info
PatternResult SingleVecPatternGenerator::GenResult(const Array<Var> &elim_var) {
  auto arg_node = arg_info.GetNode();
  arg_node->body_arg_info_ = body_args;
  arg_node->tail_arg_info_ = tail_args;

  dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, elim_var);
  src_info.GetNode()->insn_offset_ = GetInsnOffset(src_info, elim_var);
  CleanForInfoVars(for_info, elim_var);

  // CleanZeroStrides works on a list, so round-trip both infos through one.
  StmtInfoList store_infos = {dst_info, src_info};
  CleanZeroStrides(store_infos);
  dst_info = store_infos[0];
  src_info = store_infos[1];

  PatternResult packed;
  packed.dst_info_list = {dst_info};
  packed.src_info_list = {src_info};
  packed.for_info = for_info;
  packed.arg_info = arg_info;
  return packed;
}
}
// namespace akg
src/pass/emit_insn.cc
浏览文件 @
0871623a
...
...
@@ -21,6 +21,7 @@
#include "pass/ir_util.h"
#include "poly/poly_util.h"
#include "emit_insn/insn_emitter.h"
#include "emit_insn/ir_transform.h"
namespace
akg
{
namespace
ir
{
...
...
@@ -475,6 +476,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma
}
stmt
=
UnalignedMad
().
Mutate
(
stmt
);
stmt
=
RegCondition
().
Mutate
(
stmt
);
stmt
=
ForVarUnique
().
Mutate
(
stmt
);
return
stmt
;
}
}
// namespace ir
...
...
src/pass/multi_last_axis_reduction.cc
浏览文件 @
0871623a
...
...
@@ -343,8 +343,12 @@ class BroadcastCalculate : public IRMutator {
};
/// Apply the MultiLastAxisReduction and BroadcastCalculate mutators to `stmt`.
///
/// In the non-dynamic case, if the two mutators changed the statement at all,
/// the result is additionally passed through MergeLoops.
///
/// @param stmt        Statement to rewrite.
/// @param is_dynamic  Forwarded to BroadcastCalculate; when true the
///                    MergeLoops step is never applied.
/// @return The rewritten statement.
Stmt MultiLastAxisReductions(Stmt stmt, bool is_dynamic = false) {
  // Keep the input around so we can tell whether anything was rewritten.
  const Stmt original = stmt;

  Stmt rewritten = MultiLastAxisReduction().Mutate(stmt);
  rewritten = BroadcastCalculate(is_dynamic).Mutate(rewritten);

  // Nothing to merge for dynamic shapes or when the mutators were a no-op.
  if (is_dynamic || Equal(original, rewritten)) {
    return rewritten;
  }
  return MergeLoops(rewritten);
}
}
// namespace ir
...
...
src/pass/split_tail_block.cc
浏览文件 @
0871623a
...
...
@@ -21,7 +21,7 @@
#include <algorithm>
#include "emit_insn/insn_info.h"
#include "emit_insn/insn_pattern.h"
#include "emit_insn/insn_args_calculator.h"
namespace
akg
{
namespace
ir
{
...
...
@@ -48,85 +48,63 @@ class TailSpliter : public IRMutator {
if
(
src_info_list
.
empty
())
{
src_info_list
=
{
dst_info
.
Copy
()};
}
auto
get_info_list
=
[](
const
StmtStoreInfo
&
dst_info
,
const
Array
<
StmtStoreInfo
>
&
src_info_list
)
{
Array
<
StmtStoreInfo
>
res
;
res
.
push_back
(
dst_info
.
Copy
());
for
(
auto
it
:
src_info_list
)
{
res
.
push_back
(
it
.
Copy
());
}
return
res
;
};
auto
info_list
=
get_info_list
(
dst_info
,
src_info_list
);
FillEmptyVar
(
info_list
);
auto
axis_list
=
GetAixsList
(
for_info
,
info_list
);
auto
get_last_axis_it
=
[](
const
std
::
list
<
InsnAxis
>
&
axis_list
)
{
for
(
auto
it
=
axis_list
.
begin
();
it
!=
axis_list
.
end
();
it
++
)
{
auto
stride_list
=
it
->
stride_list
;
if
(
!
(
std
::
any_of
(
stride_list
.
begin
(),
stride_list
.
end
(),
[](
int
stride
)
{
return
stride
>
1
;
})
||
std
::
all_of
(
stride_list
.
begin
(),
stride_list
.
end
(),
[](
int
stride
)
{
return
stride
==
0
;
})))
{
return
it
;
}
}
return
axis_list
.
end
();
};
auto
last_axis_it
=
get_last_axis_it
(
axis_list
);
if
(
last_axis_it
==
axis_list
.
end
())
{
return
s
;
}
auto
last_axis
=
*
last_axis_it
;
auto
last_axis_shape
=
last_axis
.
extent
;
auto
info_list
=
GetInfoList
(
dst_info
,
src_info_list
);
FillEmptyVar
(
info_list
);
int
dst_block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
int
src_block_size
=
GetUbBlkSize
(
src_info_list
[
0
]
->
dtype_
);
int
block_size
=
dst_block_size
>
src_block_size
?
dst_block_size
:
src_block_size
;
int
block_size
=
dst_block_size
<
src_block_size
?
dst_block_size
:
src_block_size
;
int
cast_block_size
=
dst_block_size
>
src_block_size
?
dst_block_size
:
src_block_size
;
int
vec_max_len
=
block_size
*
FULL_BLOCK_NUM
;
if
(
last_axis_shape
>
vec_max_len
&&
last_axis_shape
%
vec_max_len
!=
0
)
{
return
Block
::
make
(
TailMake
(
s
,
last_axis
,
vec_max_len
,
false
),
TailMake
(
s
,
last_axis
,
vec_max_len
,
true
));
}
if
(
last_axis_shape
<
vec_max_len
*
tail_rate_
&&
last_axis_shape
>
block_size
&&
last_axis_shape
%
block_size
!=
0
&&
axis_list
.
size
()
>
1
)
{
return
Block
::
make
(
TailMake
(
s
,
last_axis
,
block_size
,
false
),
TailMake
(
s
,
last_axis
,
block_size
,
true
));
}
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
std
::
list
<
InsnAxis
>
GetAixsList
(
const
StmtInfo
&
for_info
,
const
Array
<
StmtStoreInfo
>
&
info_list
)
{
std
::
list
<
InsnAxis
>
axis_list
;
auto
GetStrideByAxis
=
[](
const
Array
<
Var
>
&
vars
,
const
Array
<
Expr
>
&
strides
,
Var
obj_var
)
{
int
index
=
0
;
for
(
auto
var_it
:
vars
)
{
if
(
Equal
(
var_it
,
obj_var
))
{
return
strides
[
index
];
auto
args_calculator
=
InsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
,
""
);
auto
vec_axis_it
=
args_calculator
.
GetVecAxisIt
();
bool
cast
=
dst_block_size
!=
src_block_size
;
if
(
args_calculator
.
IsValid
(
vec_axis_it
))
{
auto
vec_axis
=
*
vec_axis_it
;
auto
vec_axis_shape
=
vec_axis
.
extent
;
if
(
vec_axis_shape
>=
vec_max_len
)
{
if
(
vec_axis_shape
%
vec_max_len
!=
0
)
{
return
TailBlock
(
s
,
vec_axis
,
vec_max_len
);
}
}
else
{
if
(
vec_axis_shape
<
vec_max_len
*
tail_rate_
&&
vec_axis_shape
>
cast_block_size
&&
vec_axis_shape
%
cast_block_size
!=
0
&&
args_calculator
.
axis_list_
.
size
()
>
1
)
{
return
TailBlock
(
s
,
vec_axis
,
cast_block_size
);
}
}
index
++
;
}
return
Expr
(
0
);
};
for
(
auto
it
:
for_info
.
ops_
)
{
InsnAxis
axis
;
auto
for_stmt
=
it
.
as
<
For
>
();
CHECK
(
for_stmt
);
axis
.
var
=
for_stmt
->
loop_var
;
axis
.
extent
=
GetInt32Const
(
for_stmt
->
extent
);
axis
.
min
=
GetInt32Const
(
for_stmt
->
min
);
int
index
=
0
;
for
(
auto
it
:
info_list
)
{
auto
stride
=
GetInt32Const
(
GetStrideByAxis
(
it
->
var_
,
it
->
strides_
,
axis
.
var
));
axis
.
stride_list
.
push_back
(
stride
);
if
(
index
==
0
)
{
axis
.
dst_stride
=
stride
;
}
else
{
axis
.
src_stride_list
.
push_back
(
stride
);
if
(
!
cast
&&
(
!
args_calculator
.
IsValid
(
vec_axis_it
)
||
vec_axis_it
->
extent
<=
cast_block_size
*
tail_rate_
))
{
auto
get_block_axis
=
[
&
](
std
::
list
<
InsnAxis
>
&
axis_list
)
{
InsnAxis
block_axis
;
block_axis
.
is_valid
=
false
;
std
::
vector
<
InsnAxis
>
temp_axis_set
;
auto
block_stride_lambda
=
[
&
](
int
stride
)
{
return
stride
%
block_size
==
0
&&
stride
/
block_size
<=
4
;
};
for
(
auto
axis
:
axis_list
)
{
if
(
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
block_stride_lambda
)
&&
axis
.
dst_stride
!=
0
&&
axis
.
extent
!=
0
&&
axis
.
extent
>
FULL_BLOCK_NUM
&&
axis
.
extent
%
FULL_BLOCK_NUM
!=
0
)
{
temp_axis_set
.
push_back
(
axis
);
}
}
if
(
!
temp_axis_set
.
empty
())
{
return
temp_axis_set
[
0
];
}
else
{
return
block_axis
;
}
};
auto
block_axis
=
get_block_axis
(
args_calculator
.
axis_list_
);
if
(
block_axis
.
IsValid
())
{
return
TailBlock
(
s
,
block_axis
,
FULL_BLOCK_NUM
);
}
index
++
;
}
axis_list
.
push_back
(
axis
)
;
return
s
;
}
return
axis_list
;
return
IRMutator
::
Mutate_
(
op
,
s
)
;
}
/// Build a two-statement Block from `s`: TailMake with is_tail == false
/// followed by TailMake with is_tail == true.
///
/// @param s          Statement to split.
/// @param tail_axis  Axis along which the split is made.
/// @param body_size  Granularity handed to TailMake for both parts.
/// @return Block::make(body part, tail part).
Stmt TailBlock(const Stmt &s, const InsnAxis &tail_axis, int body_size) {
  Stmt body_part = TailMake(s, tail_axis, body_size, false);
  Stmt tail_part = TailMake(s, tail_axis, body_size, true);
  return Block::make(body_part, tail_part);
}
Stmt
TailMake
(
const
Stmt
&
s
,
const
InsnAxis
&
tail_axis
,
int
body_size
,
bool
is_tail
)
{
if
(
auto
attr_stmt
=
s
.
as
<
AttrStmt
>
())
{
return
AttrStmt
::
make
(
attr_stmt
->
node
,
attr_stmt
->
attr_key
,
attr_stmt
->
value
,
...
...
@@ -145,8 +123,7 @@ class TailSpliter : public IRMutator {
}
return
For
::
make
(
for_stmt
->
loop_var
,
for_stmt
->
min
,
for_stmt
->
extent
,
for_stmt
->
for_type
,
for_stmt
->
device_api
,
TailMake
(
for_stmt
->
body
,
tail_axis
,
body_size
,
is_tail
));
}
}
if
(
s
.
as
<
Store
>
()
&&
is_tail
)
{
return
substitute
(
tail_axis
.
var
,
Add
::
make
(
Expr
(
tail_axis
.
extent
/
body_size
*
body_size
),
tail_axis
.
var
),
s
);
}
...
...
@@ -156,6 +133,20 @@ class TailSpliter : public IRMutator {
private:
const
float
tail_rate_
{
0.6
};
const
std
::
set
<
std
::
string
>
include_intrin_list_
=
{
// binary vec
"vec_binary_add"
,
"vec_binary_sub"
,
"vec_binary_mul"
,
"vec_binary_min"
,
"vec_binary_max"
,
"vec_binary_div"
,
"vec_binary_and"
,
"vec_binary_or"
,
"vec_binary_vmadd"
,
"vec_binary_vmaddrelu"
,
"vec_binary_vmla"
,
// single vec
"vec_single_fabs"
,
"vec_single_log"
,
"vec_single_exp"
,
...
...
@@ -165,20 +156,28 @@ class TailSpliter : public IRMutator {
"vec_single_rsqrt"
,
"vec_single_relu"
,
"vec_single_not"
,
// vector_scalar
"vec_single_muls"
,
"vec_single_adds"
,
// Mov
"broadcast"
,
"mask_broadcast"
,
// vector_cast
"vec_single_cast"
,
"vec_single_floor"
,
"vec_single_round"
,
"vec_single_ceil"
,
"vec_single_trunc"
,
// scalar case
"vector_dup"
,
"vmuls"
,
"vadds"
,
"vaxpy"
,
};
};
/// Run a single TailSpliter pass over `stmt` and return the mutated statement.
Stmt SplitTail(Stmt stmt) {
  Stmt split = TailSpliter().Mutate(stmt);
  return split;
}
/// Run the TailSpliter mutator over `stmt` in two chained rounds.
///
/// The second round consumes the output of the first, so a statement that was
/// split in round one gets a chance to be split again in round two.
/// Previously the second round was applied to the untouched input and the
/// first round's result was silently discarded, which made the two-round
/// structure equivalent to a single pass.
///
/// @param stmt  Statement to process.
/// @return The statement after both rounds of tail splitting.
Stmt SplitTail(Stmt stmt) {
  auto tail_spliter = TailSpliter();
  const auto first_round = tail_spliter.Mutate(stmt);
  // Feed the first-round result (not the original stmt) into the second pass.
  return tail_spliter.Mutate(first_round);
}
}
// namespace ir
}
// namespace akg
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录