Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
0871623a
A
akg
项目概览
MindSpore
/
akg
通知
58
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0871623a
编写于
8月 24, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
!110 rewrite insn pattern generator in EmitInsn
Merge pull request !110 from wYann/insn_reduction_final
上级
d237aa7d
aec205ff
变更
12
展开全部
显示空白变更内容
内联
并排
Showing
12 changed file
with
1432 addition
and
2512 deletion
+1432
-2512
src/emit_insn/insn_args_calculator.cc
src/emit_insn/insn_args_calculator.cc
+1089
-0
src/emit_insn/insn_args_calculator.h
src/emit_insn/insn_args_calculator.h
+199
-0
src/emit_insn/insn_binary_vec_pattern.cc
src/emit_insn/insn_binary_vec_pattern.cc
+0
-1335
src/emit_insn/insn_emitter.cc
src/emit_insn/insn_emitter.cc
+26
-30
src/emit_insn/insn_info.cc
src/emit_insn/insn_info.cc
+1
-4
src/emit_insn/insn_info.h
src/emit_insn/insn_info.h
+1
-9
src/emit_insn/insn_pattern.cc
src/emit_insn/insn_pattern.cc
+32
-42
src/emit_insn/insn_pattern.h
src/emit_insn/insn_pattern.h
+4
-215
src/emit_insn/insn_single_vec_pattern.cc
src/emit_insn/insn_single_vec_pattern.cc
+0
-802
src/pass/emit_insn.cc
src/pass/emit_insn.cc
+2
-0
src/pass/multi_last_axis_reduction.cc
src/pass/multi_last_axis_reduction.cc
+4
-0
src/pass/split_tail_block.cc
src/pass/split_tail_block.cc
+74
-75
未找到文件。
src/emit_insn/insn_args_calculator.cc
0 → 100644
浏览文件 @
0871623a
此差异已折叠。
点击以展开。
src/emit_insn/insn_args_calculator.h
0 → 100644
浏览文件 @
0871623a
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef EMIT_INSN_ARGS_CALCULATOR_H_
#define EMIT_INSN_ARGS_CALCULATOR_H_
namespace
akg
{
/// Plain bundle of integer/Expr arguments held by InsnArgsCalculator (as `arg_`).
/// NOTE(review): the meaning of each field is inferred from its name only
/// (m0/m1 strides, repeat count, block length/number, body/tail split) —
/// confirm against insn_args_calculator.cc before relying on it.
struct InsnArg {
  int dst_m0 = 1;
  int dst_m1 = 0;
  // One entry per source operand (presumably in src_info_list order — TODO confirm).
  std::vector<Expr> src_m0_list;
  std::vector<Expr> src_m1_list;
  int repeat = 1;
  int block_len = 1;
  int block_num = 1;
  int body_num = 1;
  int tail_len = 0;
  int dst_tail_offset = 0;
  std::vector<Expr> src_tail_offset_list;
};
/// Per-instruction metadata held by InsnArgsCalculator (as `meta_`):
/// block sizes, operand data types and a set of boolean mode flags.
/// All integers start at 0 and all flags start false.
struct Meta {
  int block_size = 0;
  int src_block_size = 0;
  int dst_block_size = 0;
  int block_offset = 0;
  // Fixed ratio; being a const member it also makes Meta non-assignable.
  const float vec_rate = 0.6;
  Type src_dtype;
  Type dst_dtype;
  Type dtype;
  bool cast = false;
  bool tail = false;
  bool scalar = false;
  // NOTE(review): spelled "liner" in the original; possibly meant "linear" — name kept for callers.
  bool liner = false;
  bool same_dst_src = false;
};
/// Result code returned by InsnArgsCalculator::SplitAxis.
/// Unscoped on purpose: the enumerators are visible directly in namespace scope.
enum SplitStat {
  SUCCESS = 0,
  NO_SPLIT = 1,
  TAIL = 2
};
/// Description of one loop axis: its variable, range and the strides it has
/// in the destination and source operands.
class InsnAxis {
 public:
  /// Creates an axis with zeroed range/strides that is marked valid.
  InsnAxis() = default;

  /// Creates an axis from a For statement and the store infos of the operands.
  /// (Definition lives in the .cc file and is not visible here.)
  InsnAxis(const For *for_stmt, const Array<StmtStoreInfo> &info_list);

  virtual ~InsnAxis() = default;

  /// NOTE(review): presumably reports the `is_valid` flag — confirm in the .cc file.
  bool IsValid();

  /// Debug output helper; `name` is an optional label.
  void Print(const std::string &name = "");

  int min = 0;
  int extent = 0;
  Var var;
  int dst_stride = 0;
  int src_stride = 0;
  // Strides of this axis in the source operands only.
  std::vector<int> src_stride_list;
  // Strides of this axis in every operand.
  // NOTE(review): presumably destination first, then sources — confirm in the .cc file.
  std::vector<int> stride_list;
  // Callers clear this to signal "no such axis".
  bool is_valid = true;

 private:
  Expr GetStrideByAxis(const Array<Var> &vars, const Array<Expr> &strides, Var obj_var);
};
/// Iterator into a std::list<InsnAxis> (the container type used for axis lists in this header).
using AxisIt = std::list<InsnAxis>::iterator;

/// Returns the list of InsnAxis for the given loop info and operand store infos.
/// NOTE(review): definition is not visible here; presumably one axis per loop in for_info — confirm.
std::list<InsnAxis> GetAxisList(const StmtInfo &for_info, const Array<StmtStoreInfo> &info_list);

/// Combines the destination info and the source infos into a single array.
/// NOTE(review): presumably destination first, followed by the sources — confirm in the .cc file.
Array<StmtStoreInfo> GetInfoList(const StmtStoreInfo &dst_info, const Array<StmtStoreInfo> &src_info_list);

/// NOTE(review): name suggests integer floor division of a by b; definition not visible — confirm.
int DivFloor(int a, int b);

/// Debug output helper for a whole axis list.
void Print(std::list<InsnAxis> &axis_list);
/// Base class that derives vector-instruction arguments (InsnArg) from the
/// destination/source store infos and the enclosing loop info.
/// Subclasses customise the axis-selection predicates and InsnReduction().
/// All member functions are defined in the .cc file, which is not visible here;
/// per-method notes below are therefore limited to what the declarations show.
class InsnArgsCalculator {
 public:
  InsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
                     const StmtInfo &for_info, const std::string &intrin_name);
  virtual ~InsnArgsCalculator() = default;

  PatternResult ExportResult();
  void CalAxis();
  void InitArg();

  // Predicate factories: each returns a callable that accepts or rejects an axis.
  // The three virtual ones may be replaced by subclasses.
  virtual std::function<bool(const InsnAxis &)> GetStrideLambda();
  virtual std::function<bool(const InsnAxis &)> GetM0LimitLambda();
  virtual std::function<bool(const InsnAxis &)> GetM1LimitLambda();
  std::function<bool(const InsnAxis &)> GetBlockStrideLimitLambda();

  // Axis lookup helpers operating on axis_list_.
  AxisIt GetAxisByLambda(const std::function<bool(const InsnAxis &)> &lambda);
  InsnAxis ExtractAxis(AxisIt &it);
  bool IsValid(AxisIt &it);
  AxisIt GetVecAxisIt();
  AxisIt GetBlockAxis();
  AxisIt GetRepeatAxisIt();
  InsnAxis GetRepeatAxis();

  // Setters for individual fields of the instruction argument.
  void SetArgMask(int len);
  void SetArgBlockNum(int data_num);
  void SetArgBlockLen(int data_len);
  void SetArgM0(int dst_m0, int lsrc_m0, int rsrc_m0);
  void SetArgM1(int dst_m1, int lsrc_m1, int rsrc_m1);
  void SetArgRepeat(int repeat);

  // Reduction steps; InsnReduction() is the overridable entry point.
  void BlockAxisReduction();
  void RepeatAxisReduction();
  void CastCaseReduction();
  virtual void InsnReduction();

  StmtInfo ExportForInfo();
  Expr GetOffset(int stride_index);
  InsnAxis GetInvalidAxis();
  SplitStat SplitAxis(int extent, InsnAxis &axis);

  // Public on purpose: read directly by callers outside this class.
  std::list<InsnAxis> axis_list_;

 protected:
  InsnArg arg_;
  Meta meta_;
  StmtInfoList dst_info_list_;
  StmtInfoList src_info_list_;
  StmtStoreInfo dst_info_;
  StmtInfo for_info_;
  const std::string intrin_name_;
  // NOTE(review): presumably the largest block stride an instruction can encode — confirm.
  const int max_block_stride_ = 4;
};
class
SingleVecInsnArgsCalculator
:
public
InsnArgsCalculator
{
public:
SingleVecInsnArgsCalculator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
=
""
);
virtual
~
SingleVecInsnArgsCalculator
()
override
=
default
;
PatternResult
GetInsnArgs
();
};
class
BinaryVecInsnArgsCalculator
:
public
InsnArgsCalculator
{
public:
BinaryVecInsnArgsCalculator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
mode
,
const
std
::
string
&
intrin_name
=
""
,
bool
expand_mask
=
true
);
virtual
~
BinaryVecInsnArgsCalculator
()
override
=
default
;
PatternResult
GetInsnArgs
();
std
::
function
<
bool
(
const
InsnAxis
&
)
>
GetM0LimitLambda
();
std
::
function
<
bool
(
const
InsnAxis
&
)
>
GetM1LimitLambda
();
void
InsnReduction
();
private:
std
::
string
mode_
;
bool
expand_mask_
;
InsnAxis
vec_axis_
;
};
class
LastAxisReduceInsnArgsCalculator
:
InsnArgsCalculator
{
public:
LastAxisReduceInsnArgsCalculator
(
const
StmtStoreInfo
&
dst_info
,
const
StmtStoreInfo
&
src_info
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
)
:
InsnArgsCalculator
({
dst_info
},
{
src_info
},
for_info
,
intrin_name
),
dst_info
(
dst_info
),
src_info
(
src_info
),
for_info
(
for_info
),
arg_info
(
ArgInfo
(
make_node
<
ArgInfoNode
>
())),
body_args
(
VectorArgInfo
()),
tail_args
(
VectorArgInfo
()),
intrin_name
(
intrin_name
)
{}
PatternResult
GetInsnArgs
();
~
LastAxisReduceInsnArgsCalculator
()
=
default
;
protected:
Array
<
Var
>
GetPattern
();
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
);
private:
void
CalcParams
();
struct
Params
{
Array
<
Var
>
src_var
;
int
block_size
=
0
;
int
vec_max_len
=
0
;
int
last_dim_shape
=
0
;
Expr
insn_offset_scale_factor
;
};
StmtStoreInfo
dst_info
;
StmtStoreInfo
src_info
;
StmtInfo
for_info
;
ArgInfo
arg_info
;
VectorArgInfo
body_args
;
VectorArgInfo
tail_args
;
Array
<
VectorArgInfo
>
mix_vec_arg_list
;
std
::
string
intrin_name
;
Params
params
;
};
/// Produces the BisectionInfoWrapper consumed by the bisection-reduction emitter.
/// \param dst_info_list  destination store infos
/// \param src_info_list  source store infos
/// \param for_info       enclosing loop info
/// \param if_info        enclosing condition info; taken by non-const reference
/// \param last_axis      NOTE(review): presumably selects last-axis bisection — confirm
/// \param postfix        NOTE(review): presumably a numeric suffix for generated names — confirm
BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list,
                                                        const StmtInfoList &src_info_list, const StmtInfo &for_info,
                                                        StmtInfo &if_info, bool last_axis, int postfix);

/// Computes the ArgInfo for a binary vector instruction.
/// The info lists, if_info and for_info are non-const references, so the
/// callee may update them. `enable_bisect` defaults to true.
ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
                             StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info,
                             bool enable_bisect = true);
}
// namespace akg
#endif
\ No newline at end of file
src/emit_insn/insn_binary_vec_pattern.cc
已删除
100644 → 0
浏览文件 @
d237aa7d
此差异已折叠。
点击以展开。
src/emit_insn/insn_emitter.cc
浏览文件 @
0871623a
...
...
@@ -35,7 +35,7 @@
#include "insn_info.h"
#include "insn_pattern.h"
#include "insn_emitter_multimask.h"
#include "insn_args_calculator.h"
namespace
akg
{
namespace
ir
{
/// Sort indexes
...
...
@@ -71,8 +71,7 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
Array
<
Expr
>
call_args
;
int
call_cnt
=
0
;
if
(
intrin_name
==
"vector_dup"
||
intrin_name
==
"vadds"
||
intrin_name
==
"vmuls"
||
intrin_name
==
"vaxpy"
)
{
if
(
intrin_name
==
"vector_dup"
||
intrin_name
==
"vadds"
||
intrin_name
==
"vmuls"
||
intrin_name
==
"vaxpy"
)
{
auto
GetCallInfo
=
[
&
intrin_name
,
&
call_args
,
&
call_cnt
](
const
NodeRef
&
op
)
{
if
(
op
.
as
<
Call
>
()
&&
op
.
as
<
Call
>
()
->
name
==
intrin_name
)
{
call_args
=
op
.
as
<
Call
>
()
->
args
;
...
...
@@ -82,8 +81,8 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
PostOrderVisit
(
op
,
GetCallInfo
);
CHECK_EQ
(
call_cnt
,
1
);
}
SingleType
insn_type
{
SingleType
::
SIMD
};
Expr
scalar_src
{};
SingleType
insn_type
{
SingleType
::
SIMD
};
Expr
scalar_src
{};
if
(
intrin_name
==
"vector_dup"
)
{
insn_type
=
SingleType
::
Vector_Dump
;
src_info_list
=
{};
...
...
@@ -93,10 +92,11 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
src_info_list
=
{
src_info_list
[
0
]};
scalar_src
=
call_args
[
1
];
}
// check is single vector broadcast reduce mode exist
SingleVecPatternGenerator
generator
=
SingleVecPatternGenerator
(
dst_info_list
,
src_info_list
,
for_info
);
auto
params
=
generator
.
GetInsnArgs
();
SingleVecInsnArgsCalculator
args_calculator
=
SingleVecInsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
,
intrin_name
);
PatternResult
params
=
args_calculator
.
GetInsnArgs
();
dst_info_list
=
params
.
dst_info_list
;
src_info_list
=
params
.
src_info_list
;
for_info
=
params
.
for_info
;
...
...
@@ -141,24 +141,17 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
if
(
src_info_list
[
0
]
->
var_
.
size
()
>
src_info_list
[
1
]
->
var_
.
size
())
{
src_info
=
src_info_list
[
0
];
}
const
int
vec_max_len
=
GetVecMaxLen
(
dst_info
->
dtype_
);
if
(
enable_bisect
&&
GetIntConst
(
GetItem
(
src_info
->
shape_
,
-
1
))
>
vec_max_len
)
{
CommentManager
::
GetInstance
().
AddComment
(
"Bisect_optimize"
,
"enabled"
);
auto
wrapper
=
SeparateComInfoToBisectionInfoList
(
dst_info_list
,
src_info_list
,
for_info
,
if_info
,
true
,
postfix
);
return
EmitCceBinaryVectorToBisectionReduction
(
wrapper
,
if_info
,
intrin_name
);
}
else
{
CommentManager
::
GetInstance
().
AddComment
(
"Pattern"
,
arg_info
.
GetPattern
());
ReduceLastAxisPatternGenerator
generator
=
ReduceLastAxisPatternGenerator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
auto
result
=
generator
.
GetInsnArgs
();
LastAxisReduceInsnArgsCalculator
args_calculator
=
LastAxisReduceInsnArgsCalculator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
PatternResult
result
=
args_calculator
.
GetInsnArgs
();
arg_info
=
result
.
arg_info
;
dst_info
=
result
.
dst_info_list
[
0
];
src_info
=
result
.
src_info_list
[
0
];
for_info
=
result
.
for_info
;
return
EmitCceBinaryVectorToReduceLastAxis
(
dst_info
,
src_info
,
if_info
,
for_info
,
arg_info
,
intrin_name
);
}
}
case
ARG_VECTOR_REDUCTION_BISECTION
:
{
CommentManager
::
GetInstance
().
AddComment
(
"Compute_type"
,
"reduction"
);
CommentManager
::
GetInstance
().
AddComment
(
"Bisect_optimize"
,
"enabled"
);
...
...
@@ -192,7 +185,7 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
return
FoldInsnWithForInfo
(
insn_list
,
if_info
,
for_info
,
stmt
);
}
}
}
}
// namespace ir
/// Function to emit scalar intrin
/// \param op - The input stmt to be emitted as intrin
...
...
@@ -984,8 +977,9 @@ Stmt BinaryDropoutEmitter(const Stmt &op) {
src1
.
GetNode
()
->
data_
=
mask
->
buffer_var
;
src1
.
GetNode
()
->
data_alignment_
=
GetInt32Const
(
mask
->
predicate
);
SingleVecPatternGenerator
generator
=
SingleVecPatternGenerator
(
dst_info_list
,
src_info_list
,
for_info
,
"elewise"
);
auto
params
=
generator
.
GetInsnArgs
();
SingleVecInsnArgsCalculator
args_calculator
=
SingleVecInsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
);
PatternResult
params
=
args_calculator
.
GetInsnArgs
();
dst_info_list
=
params
.
dst_info_list
;
src_info_list
=
params
.
src_info_list
;
for_info
=
params
.
for_info
;
...
...
@@ -1484,8 +1478,10 @@ Stmt BinaryArgOpEmitter(const Stmt &op, const std::string &intrin_name) {
if
(
src_info_list
[
0
]
->
var_
.
size
()
>
src_info_list
[
1
]
->
var_
.
size
())
{
src_info
=
src_info_list
[
0
];
}
ReduceLastAxisPatternGenerator
generator
=
ReduceLastAxisPatternGenerator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
auto
result
=
generator
.
GetInsnArgs
();
LastAxisReduceInsnArgsCalculator
args_calculator
=
LastAxisReduceInsnArgsCalculator
(
dst_info
,
src_info
,
for_info
,
intrin_name
);
PatternResult
result
=
args_calculator
.
GetInsnArgs
();
arg_info
=
result
.
arg_info
;
dst_info
=
result
.
dst_info_list
[
0
];
src_info
=
result
.
src_info_list
[
0
];
...
...
src/emit_insn/insn_info.cc
浏览文件 @
0871623a
...
...
@@ -104,10 +104,7 @@ StmtStoreInfo StmtStoreInfo::Copy() const {
StmtInfo
StmtInfo
::
Copy
()
const
{
auto
stmt_info
=
StmtInfo
();
stmt_info
.
ops_
=
ops_
;
for
(
auto
var
:
vars_
)
{
auto
new_var
=
Variable
::
make
(
var
->
type
,
var
->
name_hint
);
stmt_info
.
vars_
.
push_back
(
new_var
);
}
stmt_info
.
vars_
=
vars_
;
for
(
size_t
i
=
0
;
i
<
vars_
.
size
();
++
i
)
{
for
(
size_t
j
=
0
;
j
<
stmt_info
.
ops_
.
size
();
++
j
)
{
...
...
src/emit_insn/insn_info.h
浏览文件 @
0871623a
...
...
@@ -276,15 +276,7 @@ struct BisectionInfoWrapper {
Map
<
std
::
string
,
Expr
>
dma_arg_info_map_
;
};
struct
InsnAxis
{
int
min
{
0
};
int
extent
{
0
};
Var
var
;
int
dst_stride
{
0
};
int
src_stride
{
0
};
std
::
list
<
int
>
src_stride_list
;
std
::
list
<
int
>
stride_list
;
};
IterVar
GetCceAxis
();
...
...
src/emit_insn/insn_pattern.cc
浏览文件 @
0871623a
...
...
@@ -15,7 +15,6 @@
*/
#include "insn_pattern.h"
#include <tvm/runtime/packed_func.h>
#include <tvm/base.h>
#include <tvm/ir_pass.h>
...
...
@@ -200,28 +199,6 @@ ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_
return
arg_info
;
}
/// Get first non zero shape from input shapes
/// \param dst_shape
/// \param src0_shape
/// \param src1_shape
/// \return
int
PatternGenerator
::
GetNonZeroShape
(
const
Expr
&
dst_shape
,
const
Expr
&
src0_shape
,
const
Expr
&
src1_shape
)
{
int
shape
=
0
;
for
(
int
val
:
{
GetInt32Const
(
dst_shape
),
GetInt32Const
(
src0_shape
),
src1_shape
.
defined
()
?
GetInt32Const
(
src1_shape
)
:
0
})
{
if
(
val
==
0
)
{
continue
;
}
if
(
shape
!=
0
&&
val
!=
shape
)
{
LOG
(
FATAL
)
<<
"Error: same var has different shapes. "
<<
GetIntConst
(
dst_shape
)
<<
" "
<<
GetIntConst
(
src0_shape
);
}
shape
=
val
;
}
CHECK
(
shape
!=
0
)
<<
"Error: all shapes are equal to 0."
;
return
shape
;
}
/// In case
/// for (cc3) {
/// A[(cc3*16)] = (B[(cc3*16)] - C[(cc3*16)])
...
...
@@ -432,25 +409,6 @@ void CleanZeroStrides(Array<StmtStoreInfo> &info_list) {
}
}
/// Swap axis in Array
/// \param var
/// \param shape
/// \param strides
/// \param idx1
/// \param idx2
void
PatternGenerator
::
GetShapeInfoAndSwap
(
Array
<
Var
>
&
var
,
Array
<
Expr
>
&
shape
,
Array
<
Expr
>
&
strides
,
int
idx1
,
int
idx2
)
{
auto
tmp_var
=
GetItem
(
var
,
idx1
);
SetItem
(
var
,
idx1
,
GetItem
(
var
,
idx2
));
SetItem
(
var
,
idx2
,
tmp_var
);
auto
tmp_shape
=
GetItem
(
shape
,
idx1
);
SetItem
(
shape
,
idx1
,
GetItem
(
shape
,
idx2
));
SetItem
(
shape
,
idx2
,
tmp_shape
);
auto
tmp_stride
=
GetItem
(
strides
,
idx1
);
SetItem
(
strides
,
idx1
,
GetItem
(
strides
,
idx2
));
SetItem
(
strides
,
idx2
,
tmp_stride
);
}
/// Get insn args of load 2D intrin
/// \param intrin_name
/// \param dst_info_list
...
...
@@ -856,6 +814,38 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
return
arg_info_map
;
}
/// Replace com_info's var with new for loop's var
///
/// For every loop variable of `new_for_info` (position i):
///   1. each entry of `info->var_` whose name_hint equals the new variable's
///      name_hint is replaced by the new variable, and
///   2. `info->index_` is rewritten by substituting `old_for_info.vars_[i]`
///      with `new_for_info.vars_[i]`.
/// The two var lists are therefore paired by position.
///
/// \param info          store info updated in place (var_ and index_)
/// \param old_for_info  loop info whose variables currently appear in info->index_
/// \param new_for_info  loop info supplying the replacement variables
void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo &new_for_info) {
  // old_for_info.vars_ is indexed with positions of new_for_info.vars_ below;
  // fail loudly instead of reading past the end when the old list is shorter.
  CHECK_LE(new_for_info.vars_.size(), old_for_info.vars_.size())
    << "new_for_info has more loop vars than old_for_info";
  for (size_t i = 0; i < new_for_info.vars_.size(); ++i) {
    const auto new_var = new_for_info.vars_[i];
    // Matching is by name, so an entry is replaced even if it is a different
    // Var object that merely shares the name.
    for (size_t j = 0; j < info->var_.size(); ++j) {
      if (info->var_[j]->name_hint == new_var->name_hint) {
        SetItem(info.GetNode()->var_, static_cast<int>(j), new_var);
      }
    }
    info.GetNode()->index_ = substitute(old_for_info.vars_[i], new_var, info->index_);
  }
}
/// Classifies a binary vector computation and returns its mode name.
/// Checks are tried in this order and the first match wins:
///   "elewise", "broadcast", "reduce_last_axis", "reduce_bisection";
/// anything else is "reduction".
/// "reduce_bisection" is only considered when `enable_bisect` is set and
/// `intrin_name` is one of vadd / vsub / vmul / vmax.
std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
                             const std::string &intrin_name, bool enable_bisect) {
  if (IsElementwise(dst_info_list, src_info_list)) {
    return "elewise";
  }
  if (IsBroadcast(dst_info_list, src_info_list)) {
    return "broadcast";
  }
  if (IsLastAxisReduction(dst_info_list, src_info_list)) {
    return "reduce_last_axis";
  }

  // Intrinsics for which a bisection reduction may be chosen.
  const std::set<std::string> bisect_intrins = {"vadd", "vsub", "vmul", "vmax"};
  if (enable_bisect && bisect_intrins.count(intrin_name) != 0 &&
      IsBisectionReduction(dst_info_list, src_info_list)) {
    return "reduce_bisection";
  }
  return "reduction";
}
const
char
*
const
DummyLastVar
=
"cc_last"
;
TVM_REGISTER_API
(
"cce_util.GetVecMask"
).
set_body
([](
const
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
...
src/emit_insn/insn_pattern.h
浏览文件 @
0871623a
...
...
@@ -37,220 +37,12 @@ struct PatternResult {
StmtInfo
for_info
;
};
class
PatternGenerator
{
public:
PatternGenerator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfo
&
for_info
)
:
for_info
(
for_info
),
not_this_pattern
(
-
1.0
f
),
split_latency_coef
(
10.0
f
),
repeat_latency_coef
(
3.0
f
),
offset_latency_coef
(
0.1
f
)
{
CHECK
(
!
dst_info_list
.
empty
());
dst_info
=
dst_info_list
[
0
];
}
virtual
~
PatternGenerator
()
=
default
;
virtual
PatternResult
GetInsnArgs
()
=
0
;
protected:
int
GetNonZeroShape
(
const
Expr
&
dst_shape
,
const
Expr
&
src0_shape
,
const
Expr
&
src1_shape
=
Expr
());
void
GetShapeInfoAndSwap
(
Array
<
Var
>
&
var
,
Array
<
Expr
>
&
shape
,
Array
<
Expr
>
&
strides
,
int
idx1
,
int
idx2
);
virtual
float
Compute3DPatternMaskRate
()
{
return
not_this_pattern
;
}
virtual
float
Compute2DBlockPatternMaskRate
()
{
return
not_this_pattern
;
}
virtual
float
Compute2DPatternMaskRate
()
{
return
not_this_pattern
;
}
virtual
float
Compute1DPatternMaskRate
()
{
return
not_this_pattern
;
}
virtual
Array
<
Var
>
Get3DPattern
()
{
return
{};
}
virtual
Array
<
Var
>
Get2DBlockPattern
()
{
return
{};
}
virtual
Array
<
Var
>
Get2DPattern
()
{
return
{};
}
virtual
Array
<
Var
>
Get1DPattern
()
{
return
{};
}
virtual
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
)
=
0
;
StmtStoreInfo
dst_info
;
StmtInfo
for_info
;
const
float
not_this_pattern
;
const
float
split_latency_coef
;
const
float
repeat_latency_coef
;
const
float
offset_latency_coef
;
};
class
SingleVecPatternGenerator
:
public
PatternGenerator
{
public:
SingleVecPatternGenerator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
mode
=
"elewise"
)
:
PatternGenerator
(
dst_info_list
,
for_info
),
arg_info
(
ArgInfo
(
make_node
<
ArgInfoNode
>
())),
body_args
(
VectorArgInfo
()),
tail_args
(
VectorArgInfo
()),
mode
(
mode
)
{
if
(
src_info_list
.
empty
())
{
src_info
=
dst_info
.
Copy
();
}
else
{
CHECK
(
!
src_info_list
.
empty
());
src_info
=
src_info_list
[
0
];
}
}
~
SingleVecPatternGenerator
()
override
=
default
;
PatternResult
GetInsnArgs
()
final
;
protected:
float
Compute3DPatternMaskRate
()
final
;
float
Compute2DBlockPatternMaskRate
()
final
;
float
Compute2DPatternMaskRate
()
final
;
float
Compute1DPatternMaskRate
()
final
;
float
Compute3DsPatternMaskRate
();
float
Compute2DRepeatPatternMaskRate
();
Array
<
Var
>
Get3DPattern
()
final
;
Array
<
Var
>
Get2DBlockPattern
()
final
;
Array
<
Var
>
Get2DPattern
()
final
;
Array
<
Var
>
Get1DPattern
()
final
;
Array
<
Var
>
Get3DsPattern
();
Array
<
Var
>
Get2DRepeatPattern
();
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
)
final
;
private:
void
CalcParams
();
int
GetLastDimShape
(
const
Expr
&
dst_shape
,
const
Expr
&
src_shape
);
struct
Params
{
Array
<
Var
>
dst_var
;
Array
<
Var
>
src_var
;
Array
<
Expr
>
dst_shape
;
Array
<
Expr
>
src_shape
;
Array
<
Expr
>
dst_strides
;
Array
<
Expr
>
src_strides
;
int
non_zero_shape1
=
0
;
int
non_zero_shape2
=
0
;
int
non_zero_shape3
=
0
;
int
all_points
=
0
;
int
dst_block_size
=
0
;
int
src_block_size
=
0
;
int
mask_block_size
=
0
;
int
dst_bits
=
0
;
int
src_bits
=
0
;
int
max_bits
=
0
;
int
dst_vec_max_len
=
0
;
int
vec_max_len
=
0
;
int
block_offset
=
0
;
};
StmtStoreInfo
src_info
;
Params
params
;
ArgInfo
arg_info
;
VectorArgInfo
body_args
;
VectorArgInfo
tail_args
;
std
::
string
mode
;
Type
data_type
;
};
class
BinaryVecPatternGenerator
:
public
PatternGenerator
{
public:
BinaryVecPatternGenerator
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
mode
,
bool
expand_mask
=
true
)
:
PatternGenerator
(
dst_info_list
,
for_info
),
src_info_list
(
src_info_list
),
arg_info
(
ArgInfo
(
make_node
<
ArgInfoNode
>
())),
body_args
(
VectorArgInfo
()),
tail_args
(
VectorArgInfo
()),
empty_var
(
Var
(
""
)),
mode
(
mode
),
expand_mask
(
expand_mask
)
{}
~
BinaryVecPatternGenerator
()
override
=
default
;
PatternResult
GetInsnArgs
()
final
;
protected:
float
Compute3DPatternMaskRate
()
final
;
float
Compute2DBlockPatternMaskRate
()
final
;
float
Compute2DPatternMaskRate
()
final
;
float
Compute1DPatternMaskRate
()
final
;
Array
<
Var
>
Get3DPattern
()
final
;
Array
<
Var
>
Get2DBlockPattern
()
final
;
Array
<
Var
>
Get2DPattern
()
final
;
Array
<
Var
>
Get1DPattern
()
final
;
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
)
final
;
private:
void
CalcParams
();
bool
IsSamePatternComInfo
(
const
StmtStoreInfo
&
info_a
,
const
StmtStoreInfo
&
info_b
);
bool
IsNonZeroShapeEqual
(
const
Array
<
Expr
>
&
shape_list
);
void
AppendEmptyVar
(
StmtInfoList
&
info_list
);
struct
Params
{
Array
<
Var
>
dst_var
;
Array
<
Expr
>
dst_shape
;
Array
<
Expr
>
dst_strides
;
Array
<
Var
>
src_var0
;
Array
<
Expr
>
src_shape0
;
Array
<
Expr
>
src_strides0
;
Array
<
Var
>
src_var1
;
Array
<
Expr
>
src_shape1
;
Array
<
Expr
>
src_strides1
;
int
non_zero_shape1
=
0
;
int
non_zero_shape2
=
0
;
int
non_zero_shape3
=
0
;
int
all_points
=
0
;
int
block_size
=
0
;
int
last_dim_shape
=
0
;
int
vec_max_len
=
0
;
};
StmtInfoList
src_info_list
;
ArgInfo
arg_info
;
VectorArgInfo
body_args
;
VectorArgInfo
tail_args
;
Params
params
;
Var
empty_var
;
std
::
string
mode
;
bool
expand_mask
;
};
class
ReduceLastAxisPatternGenerator
:
public
PatternGenerator
{
public:
ReduceLastAxisPatternGenerator
(
const
StmtStoreInfo
&
dst_info
,
const
StmtStoreInfo
&
src_info
,
const
StmtInfo
&
for_info
,
const
std
::
string
&
intrin_name
)
:
PatternGenerator
({
dst_info
},
for_info
),
src_info
(
src_info
),
arg_info
(
ArgInfo
(
make_node
<
ArgInfoNode
>
())),
body_args
(
VectorArgInfo
()),
tail_args
(
VectorArgInfo
()),
intrin_name
(
intrin_name
)
{}
PatternResult
GetInsnArgs
()
final
;
~
ReduceLastAxisPatternGenerator
()
override
=
default
;
protected:
float
Compute2DBlockPatternMaskRate
()
final
;
Array
<
Var
>
Get2DBlockPattern
()
final
;
Array
<
Var
>
Get1DPattern
()
final
;
PatternResult
GenResult
(
const
Array
<
Var
>
&
elim_var
)
final
;
private:
void
CalcParams
();
struct
Params
{
Array
<
Var
>
src_var
;
int
block_size
=
0
;
int
vec_max_len
=
0
;
int
last_dim_shape
=
0
;
Expr
insn_offset_scale_factor
;
};
StmtStoreInfo
src_info
;
ArgInfo
arg_info
;
VectorArgInfo
body_args
;
VectorArgInfo
tail_args
;
Array
<
VectorArgInfo
>
mix_vec_arg_list
;
std
::
string
intrin_name
;
Params
params
;
};
std
::
string
GetSingleVecComputationInfo
(
const
Stmt
&
stmt
,
const
std
::
string
&
intrin_name
,
Array
<
StmtStoreInfo
>
&
dst_info_list
,
Array
<
StmtStoreInfo
>
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
need_compact
=
true
);
ArgInfo
GetBinaryVecInsnArgs
(
const
Stmt
&
stmt
,
std
::
string
intrin_name
,
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
if_info
,
StmtInfo
&
for_info
,
bool
enable_bisect
=
true
);
std
::
string
GetBinaryVecMode
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
std
::
string
&
intrin_name
,
bool
enable_bisect
=
true
);
ArgInfo
GetMultiVecInsnArgs
(
StmtInfoList
&
dst_info_list
,
StmtInfoList
&
src_info_list
,
StmtInfo
&
for_info
);
...
...
@@ -277,10 +69,7 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
const
StmtInfoList
&
src_info_list
,
StmtInfo
&
for_info
,
Map
<
std
::
string
,
Expr
>
&
ub_copy_pre
,
Map
<
std
::
string
,
Expr
>
&
ub_copy_post
);
BisectionInfoWrapper
SeparateComInfoToBisectionInfoList
(
const
StmtInfoList
&
dst_info_list
,
const
StmtInfoList
&
src_info_list
,
const
StmtInfo
&
for_info
,
StmtInfo
&
if_info
,
bool
last_axis
,
int
postfix
);
void
ReplaceVarWithNewForInfo
(
StmtStoreInfo
&
info
,
const
StmtInfo
&
old_for_info
,
const
StmtInfo
&
new_for_info
);
extern
const
char
*
const
DummyLastVar
;
}
// namespace akg
#endif // EMIT_INSN_INSN_PATTERN_H_
src/emit_insn/insn_single_vec_pattern.cc
已删除
100644 → 0
浏览文件 @
d237aa7d
此差异已折叠。
点击以展开。
src/pass/emit_insn.cc
浏览文件 @
0871623a
...
...
@@ -21,6 +21,7 @@
#include "pass/ir_util.h"
#include "poly/poly_util.h"
#include "emit_insn/insn_emitter.h"
#include "emit_insn/ir_transform.h"
namespace
akg
{
namespace
ir
{
...
...
@@ -475,6 +476,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma
}
stmt
=
UnalignedMad
().
Mutate
(
stmt
);
stmt
=
RegCondition
().
Mutate
(
stmt
);
stmt
=
ForVarUnique
().
Mutate
(
stmt
);
return
stmt
;
}
}
// namespace ir
...
...
src/pass/multi_last_axis_reduction.cc
浏览文件 @
0871623a
...
...
@@ -343,8 +343,12 @@ class BroadcastCalculate : public IRMutator {
};
/// Runs the MultiLastAxisReduction mutator followed by BroadcastCalculate.
/// In the non-dynamic case, MergeLoops is applied afterwards, but only when
/// the two mutators actually changed the statement.
/// \param stmt        statement to transform
/// \param is_dynamic  forwarded to BroadcastCalculate; when true MergeLoops is skipped
/// \return the transformed statement
Stmt MultiLastAxisReductions(Stmt stmt, bool is_dynamic = false) {
  const Stmt input = stmt;
  const Stmt reduced = MultiLastAxisReduction().Mutate(stmt);
  Stmt result = BroadcastCalculate(is_dynamic).Mutate(reduced);
  // Equal() is only evaluated for the non-dynamic case, as before.
  if (is_dynamic || Equal(input, result)) {
    return result;
  }
  return MergeLoops(result);
}
}
// namespace ir
...
...
src/pass/split_tail_block.cc
浏览文件 @
0871623a
...
...
@@ -21,7 +21,7 @@
#include <algorithm>
#include "emit_insn/insn_info.h"
#include "emit_insn/insn_pattern.h"
#include "emit_insn/insn_args_calculator.h"
namespace
akg
{
namespace
ir
{
...
...
@@ -48,85 +48,63 @@ class TailSpliter : public IRMutator {
if
(
src_info_list
.
empty
())
{
src_info_list
=
{
dst_info
.
Copy
()};
}
auto
get_info_list
=
[](
const
StmtStoreInfo
&
dst_info
,
const
Array
<
StmtStoreInfo
>
&
src_info_list
)
{
Array
<
StmtStoreInfo
>
res
;
res
.
push_back
(
dst_info
.
Copy
());
for
(
auto
it
:
src_info_list
)
{
res
.
push_back
(
it
.
Copy
());
}
return
res
;
};
auto
info_list
=
get_info_list
(
dst_info
,
src_info_list
);
FillEmptyVar
(
info_list
);
auto
axis_list
=
GetAixsList
(
for_info
,
info_list
);
auto
get_last_axis_it
=
[](
const
std
::
list
<
InsnAxis
>
&
axis_list
)
{
for
(
auto
it
=
axis_list
.
begin
();
it
!=
axis_list
.
end
();
it
++
)
{
auto
stride_list
=
it
->
stride_list
;
if
(
!
(
std
::
any_of
(
stride_list
.
begin
(),
stride_list
.
end
(),
[](
int
stride
)
{
return
stride
>
1
;
})
||
std
::
all_of
(
stride_list
.
begin
(),
stride_list
.
end
(),
[](
int
stride
)
{
return
stride
==
0
;
})))
{
return
it
;
}
}
return
axis_list
.
end
();
};
auto
last_axis_it
=
get_last_axis_it
(
axis_list
);
if
(
last_axis_it
==
axis_list
.
end
())
{
return
s
;
}
auto
last_axis
=
*
last_axis_it
;
auto
last_axis_shape
=
last_axis
.
extent
;
auto
info_list
=
GetInfoList
(
dst_info
,
src_info_list
);
FillEmptyVar
(
info_list
);
int
dst_block_size
=
GetUbBlkSize
(
dst_info
->
dtype_
);
int
src_block_size
=
GetUbBlkSize
(
src_info_list
[
0
]
->
dtype_
);
int
block_size
=
dst_block_size
>
src_block_size
?
dst_block_size
:
src_block_size
;
int
block_size
=
dst_block_size
<
src_block_size
?
dst_block_size
:
src_block_size
;
int
cast_block_size
=
dst_block_size
>
src_block_size
?
dst_block_size
:
src_block_size
;
int
vec_max_len
=
block_size
*
FULL_BLOCK_NUM
;
if
(
last_axis_shape
>
vec_max_len
&&
last_axis_shape
%
vec_max_len
!=
0
)
{
return
Block
::
make
(
TailMake
(
s
,
last_axis
,
vec_max_len
,
false
),
TailMake
(
s
,
last_axis
,
vec_max_len
,
true
));
auto
args_calculator
=
InsnArgsCalculator
(
dst_info_list
,
src_info_list
,
for_info
,
""
);
auto
vec_axis_it
=
args_calculator
.
GetVecAxisIt
();
bool
cast
=
dst_block_size
!=
src_block_size
;
if
(
args_calculator
.
IsValid
(
vec_axis_it
))
{
auto
vec_axis
=
*
vec_axis_it
;
auto
vec_axis_shape
=
vec_axis
.
extent
;
if
(
vec_axis_shape
>=
vec_max_len
)
{
if
(
vec_axis_shape
%
vec_max_len
!=
0
)
{
return
TailBlock
(
s
,
vec_axis
,
vec_max_len
);
}
if
(
last_axis_shape
<
vec_max_len
*
tail_rate_
&&
last_axis_shape
>
block_size
&&
last_axis_shape
%
block_size
!=
0
&&
axis_list
.
size
()
>
1
)
{
return
Block
::
make
(
TailMake
(
s
,
last_axis
,
block_size
,
false
),
TailMake
(
s
,
last_axis
,
block_size
,
true
));
}
else
{
if
(
vec_axis_shape
<
vec_max_len
*
tail_rate_
&&
vec_axis_shape
>
cast_block_size
&&
vec_axis_shape
%
cast_block_size
!=
0
&&
args_calculator
.
axis_list_
.
size
()
>
1
)
{
return
TailBlock
(
s
,
vec_axis
,
cast_block_size
);
}
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
std
::
list
<
InsnAxis
>
GetAixsList
(
const
StmtInfo
&
for_info
,
const
Array
<
StmtStoreInfo
>
&
info_list
)
{
std
::
list
<
InsnAxis
>
axis_list
;
auto
GetStrideByAxis
=
[](
const
Array
<
Var
>
&
vars
,
const
Array
<
Expr
>
&
strides
,
Var
obj_var
)
{
int
index
=
0
;
for
(
auto
var_it
:
vars
)
{
if
(
Equal
(
var_it
,
obj_var
))
{
return
strides
[
index
];
if
(
!
cast
&&
(
!
args_calculator
.
IsValid
(
vec_axis_it
)
||
vec_axis_it
->
extent
<=
cast_block_size
*
tail_rate_
))
{
auto
get_block_axis
=
[
&
](
std
::
list
<
InsnAxis
>
&
axis_list
)
{
InsnAxis
block_axis
;
block_axis
.
is_valid
=
false
;
std
::
vector
<
InsnAxis
>
temp_axis_set
;
auto
block_stride_lambda
=
[
&
](
int
stride
)
{
return
stride
%
block_size
==
0
&&
stride
/
block_size
<=
4
;
};
for
(
auto
axis
:
axis_list
)
{
if
(
std
::
all_of
(
axis
.
stride_list
.
begin
(),
axis
.
stride_list
.
end
(),
block_stride_lambda
)
&&
axis
.
dst_stride
!=
0
&&
axis
.
extent
!=
0
&&
axis
.
extent
>
FULL_BLOCK_NUM
&&
axis
.
extent
%
FULL_BLOCK_NUM
!=
0
)
{
temp_axis_set
.
push_back
(
axis
);
}
index
++
;
}
return
Expr
(
0
);
};
for
(
auto
it
:
for_info
.
ops_
)
{
InsnAxis
axis
;
auto
for_stmt
=
it
.
as
<
For
>
();
CHECK
(
for_stmt
);
axis
.
var
=
for_stmt
->
loop_var
;
axis
.
extent
=
GetInt32Const
(
for_stmt
->
extent
);
axis
.
min
=
GetInt32Const
(
for_stmt
->
min
);
int
index
=
0
;
for
(
auto
it
:
info_list
)
{
auto
stride
=
GetInt32Const
(
GetStrideByAxis
(
it
->
var_
,
it
->
strides_
,
axis
.
var
));
axis
.
stride_list
.
push_back
(
stride
);
if
(
index
==
0
)
{
axis
.
dst_stride
=
stride
;
if
(
!
temp_axis_set
.
empty
())
{
return
temp_axis_set
[
0
];
}
else
{
axis
.
src_stride_list
.
push_back
(
stride
)
;
return
block_axis
;
}
index
++
;
};
auto
block_axis
=
get_block_axis
(
args_calculator
.
axis_list_
);
if
(
block_axis
.
IsValid
())
{
return
TailBlock
(
s
,
block_axis
,
FULL_BLOCK_NUM
);
}
axis_list
.
push_back
(
axis
);
}
return
axis_list
;
return
s
;
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
/// Builds a Block made of two variants of `s` produced by TailMake:
/// first the one with is_tail == false, then the one with is_tail == true.
/// \param s          statement to split
/// \param tail_axis  axis along which the split happens
/// \param body_size  size forwarded to TailMake
Stmt TailBlock(const Stmt &s, const InsnAxis &tail_axis, int body_size) {
  const Stmt body_part = TailMake(s, tail_axis, body_size, false);
  const Stmt tail_part = TailMake(s, tail_axis, body_size, true);
  return Block::make(body_part, tail_part);
}
Stmt
TailMake
(
const
Stmt
&
s
,
const
InsnAxis
&
tail_axis
,
int
body_size
,
bool
is_tail
)
{
if
(
auto
attr_stmt
=
s
.
as
<
AttrStmt
>
())
{
return
AttrStmt
::
make
(
attr_stmt
->
node
,
attr_stmt
->
attr_key
,
attr_stmt
->
value
,
...
...
@@ -145,7 +123,6 @@ class TailSpliter : public IRMutator {
}
return
For
::
make
(
for_stmt
->
loop_var
,
for_stmt
->
min
,
for_stmt
->
extent
,
for_stmt
->
for_type
,
for_stmt
->
device_api
,
TailMake
(
for_stmt
->
body
,
tail_axis
,
body_size
,
is_tail
));
}
if
(
s
.
as
<
Store
>
()
&&
is_tail
)
{
return
substitute
(
tail_axis
.
var
,
Add
::
make
(
Expr
(
tail_axis
.
extent
/
body_size
*
body_size
),
tail_axis
.
var
),
s
);
...
...
@@ -156,6 +133,20 @@ class TailSpliter : public IRMutator {
private:
const
float
tail_rate_
{
0.6
};
const
std
::
set
<
std
::
string
>
include_intrin_list_
=
{
// binary vec
"vec_binary_add"
,
"vec_binary_sub"
,
"vec_binary_mul"
,
"vec_binary_min"
,
"vec_binary_max"
,
"vec_binary_div"
,
"vec_binary_and"
,
"vec_binary_or"
,
"vec_binary_vmadd"
,
"vec_binary_vmaddrelu"
,
"vec_binary_vmla"
,
// single vec
"vec_single_fabs"
,
"vec_single_log"
,
"vec_single_exp"
,
...
...
@@ -165,20 +156,28 @@ class TailSpliter : public IRMutator {
"vec_single_rsqrt"
,
"vec_single_relu"
,
"vec_single_not"
,
// vector_scalar
"vec_single_muls"
,
"vec_single_adds"
,
// Mov
"broadcast"
,
"mask_broadcast"
,
// vector_cast
"vec_single_cast"
,
"vec_single_floor"
,
"vec_single_round"
,
"vec_single_ceil"
,
"vec_single_trunc"
,
// scalar case
"vector_dup"
,
"vmuls"
,
"vadds"
,
"vaxpy"
,
};
};
Stmt
SplitTail
(
Stmt
stmt
)
{
return
TailSpliter
().
Mutate
(
stmt
);
}
/// Applies TailSpliter to `stmt` in two consecutive passes and returns the
/// result of the second pass.
///
/// The second pass consumes the output of the first one. Previously both
/// passes were run on the original `stmt`, so the first pass's result was
/// computed and then discarded.
/// NOTE(review): this assumes chaining the passes was the intent — confirm
/// against the emitted IR for kernels that have both kinds of tail.
Stmt SplitTail(Stmt stmt) {
  auto tail_spliter = TailSpliter();
  auto first_round = tail_spliter.Mutate(stmt);
  auto second_round = tail_spliter.Mutate(first_round);
  return second_round;
}
}
// namespace ir
}
// namespace akg
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录