Unverified commit 3559252a
Authored on Jul 13, 2023 by BiynXu, committed via GitHub on Jul 13, 2023
[CINN] comb the op lowering code (#54982)
* [CINN] comb the op lowering code
* [CINN] format code of OpLower
Parent: 27cc0df5

Showing 9 changed files with 466 additions and 571 deletions (+466 −571)
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc (+3 -2)
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc (+59 -62)
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc (+4 -6)
paddle/cinn/auto_schedule/task/tune_task.cc (+2 -1)
paddle/cinn/auto_schedule/tests/performance_comparison_test.cc (+3 -1)
paddle/cinn/hlir/framework/op_lowering.cc (+272 -462)
paddle/cinn/hlir/framework/op_lowering.h (+114 -28)
paddle/cinn/hlir/framework/op_lowering_util.cc (+6 -6)
paddle/cinn/hlir/framework/op_lowering_util.h (+3 -3)
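Most of the call-site churn in the files above comes from one interface change: the former Lower(GroupPtr&) / LowerWithoutSchedule(GroupPtr&) pair is folded into a single Lower(const GroupPtr&, bool apply_op_schedule, bool apply_group_schedule). A minimal before/after sketch of a caller (illustrative only; `op_lowerer` and `group` stand in for whatever the real caller holds, as in the test updates below):

  // Before: two entry points, schedule behaviour baked into the name.
  //   std::vector<ir::LoweredFunc> funcs = op_lowerer->Lower(group);
  //   std::vector<ir::LoweredFunc> funcs = op_lowerer->LowerWithoutSchedule(group);

  // After: one entry point, schedule behaviour selected by flags
  // (both flags default to true).
  std::vector<ir::LoweredFunc> scheduled = op_lowerer->Lower(group);
  std::vector<ir::LoweredFunc> unscheduled =
      op_lowerer->Lower(group,
                        /*apply_op_schedule=*/false,
                        /*apply_group_schedule=*/false);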
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc
@@ -71,7 +71,6 @@ TEST(AutoInline, SingleLoopInline) {
                              nullptr,
                              target,
                              true);
   VLOG(6) << "Expr after lowering:";
   VLOG(6) << funcs[0]->body;
...
@@ -170,7 +169,9 @@ TEST(AutoInline, AddReluInline) {
   EXPECT_EQ(graph->fusion_groups.size(), 1UL);
   std::vector<ir::LoweredFunc> funcs =
-      op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
+      op_lowerer->Lower(graph->fusion_groups[0],
+                        /*apply_op_schedule = */ false,
+                        /*apply_group_schedule=*/ false);
   VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
@@ -388,9 +388,9 @@ TEST_F(TestMultiLevelTiling, ReduceSum) {
 TEST_F(TestMultiLevelTiling, Pool2d) {
   default_input_names = {"input"};
-  default_output_names = {"var_0"};
-  std::vector<int32_t> input_shape{2, 8, 16, 16};
-  std::vector<int32_t> output_shape{2, 8, 8, 8};
+  default_output_names = {"var_0", "pad_temp_0"};
+  std::vector<std::vector<int32_t>> input_shapes{{2, 8, 16, 16}};
+  std::vector<std::vector<int32_t>> output_shapes{{2, 8, 8, 8}, {2, 8, 18, 18}};
   std::string pooling_type = "max";
   std::vector<int> ksize{3, 3};
   std::vector<int> strides{2, 2};
...
@@ -402,7 +402,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
   bool adaptive = false;
   std::string padding_algorithm = "EXPLICIT";
   frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build(
-      {{"input", input_shape}},
+      {{"input", input_shapes[0]}},
       {{"pool_type", pooling_type},
        {"kernel_size", ksize},
        {"stride_size", strides},
...
@@ -439,6 +439,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
   std::string expected_ir = R"ROC(Expr 0 {
 {
   ScheduleBlock(root)
   {
     {
       serial for (i, 0, 2)
       {
...
@@ -451,6 +452,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
           ScheduleBlock(pad_temp_0)
           {
             i0, i1, i2, i3 = axis.bind(i, j, k, a)
             {
               pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
             }
           }
...
@@ -458,12 +460,6 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
         }
       }
     }
   }
-} // end Expr 0
-Expr 1 {
-{
   ScheduleBlock(root_0)
   {
     {
       thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16)
       {
...
@@ -552,8 +548,9 @@ Expr 1 {
         }
       }
     }
   }
 }
-} // end Expr 1
+} // end Expr 0
 )ROC";
   ASSERT_EQ(ir, expected_ir);
...
@@ -569,8 +566,8 @@ Expr 1 {
           pool2d_program,
           fixed_rand_seed,
           /* apply_manual_schedule*/ true))),
       default_input_names,
       default_output_names,
-      {input_shape},
-      {output_shape},
+      input_shapes,
+      output_shapes,
       target_);
 }
...
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc
@@ -63,12 +63,10 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
           absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
           "infershape");
   hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
-  if (apply_manual_schedule) {
-    lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front());
-  } else {
-    lowered_funcs_ =
-        op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
-  }
+  lowered_funcs_ =
+      op_lowerer.Lower(graph->fusion_groups.front(),
+                       /*apply_op_schedule = */ apply_manual_schedule,
+                       /*apply_group_schedule = */ apply_manual_schedule);
   CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
   std::vector<Expr> bodys;
...
paddle/cinn/auto_schedule/task/tune_task.cc
@@ -39,7 +39,8 @@ void TuneTask::Initialize(
   op_lowerer = lower_handler;
   // Set lowered_funcs and analyze output names.
-  this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph);
+  this->lowered_funcs = op_lowerer->Lower(
+      subgraph, /*apply_op_schedule = */ false, /*apply_group_schedule=*/false);
   this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs);
   this->serialized_key = SerializeToString(shape_dict, dtype_dict);
 }
...
paddle/cinn/auto_schedule/tests/performance_comparison_test.cc
@@ -157,7 +157,9 @@ class PerformanceTester : public ::testing::Test {
       for (auto group : graph->fusion_groups) {
         compile_options.lowered_funcs.push_back(
-            op_lowerer->LowerWithoutSchedule(group));
+            op_lowerer->Lower(group,
+                              /*apply_op_schedule = */ false,
+                              /*apply_group_schedule=*/ false));
       }
       VLOG(3) << "===========================No Schedule LoweredFunc "
...
paddle/cinn/hlir/framework/op_lowering.cc
@@ -45,7 +45,9 @@ OpLowerer::OpLowerer(
     const Target& target)
     : type_dict_(type_dict), shape_dict_(shape_dict), target_(target) {}

-std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) {  // NOLINT
+std::vector<ir::LoweredFunc> OpLowerer::Lower(const GroupPtr& group,
+                                              bool apply_op_schedule,
+                                              bool apply_group_schedule) {
   VLOG(3) << "Lowering Group : " << group->group_id
           << " , Op Pattern : " << group->op_pattern_kind;
   group->input_names.clear();
...
@@ -55,13 +57,22 @@ std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) {  // NOLINT
       case framework::kElementWise:
       case framework::kBroadcast:
       case framework::kInjective:
-        return IRLowerOp(&OpLowerer::IRElementwiseCompute, group);
+        return LowerGroup(group,
+                          apply_op_schedule,
+                          apply_group_schedule,
+                          &OpLowerer::ElementwiseScheduleDetermineFunction);
       case framework::kReduction:
-        return IRLowerOp(&OpLowerer::IRReduceCompute, group);
+        return LowerGroup(group,
+                          apply_op_schedule,
+                          apply_group_schedule,
+                          &OpLowerer::ReduceScheduleDetermineFunction);
       case framework::kOutFusible:
        LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
       case framework::kNonFusible:
-        return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ true);
+        return LowerGroup(group,
+                          apply_op_schedule,
+                          apply_group_schedule,
+                          &OpLowerer::NonFusibleScheduleDetermineFunction);
       default:
         LOG(FATAL) << "Group Pattern Kind Is Unknown!";
     }
...
@@ -70,532 +81,329 @@ std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) {  // NOLINT
   }
 }

-std::vector<ir::LoweredFunc> OpLowerer::LowerWithoutSchedule(GroupPtr& group) {
-  VLOG(3) << "Lowering Group : " << group->group_id
-          << " , Op Pattern : " << group->op_pattern_kind;
-  if (FLAGS_cinn_ir_schedule) {
-    switch (group->op_pattern_kind) {
-      case framework::kElementWise:
-      case framework::kBroadcast:
-      case framework::kInjective:
-        return IRLowerOpWithoutSchedule(&OpLowerer::IRElementwiseCompute, group);
-      case framework::kReduction:
-        return IRLowerOpWithoutSchedule(&OpLowerer::IRReduceCompute, group);
-      case framework::kOutFusible:
-        LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
-      case framework::kNonFusible:
-        return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ false);
-      default:
-        LOG(FATAL) << "Group Pattern Kind kNonFusible Is Not Implemented!";
-    }
-  } else {
-    LOG(FATAL) << "Previous IR Schedule Is Not Implemented!";
-  }
-}
+bool OpLowerer::ElementwiseScheduleDetermineFunction(Node* node) {
+  return true;
+}

-std::vector<ir::LoweredFunc> OpLowerer::IRLowerOp(IRComputeFunction compute,
-                                                  GroupPtr& group) {
-  poly::StageMap stages;
-  std::vector<ir::Tensor> arg_tensors;
-  std::unordered_map<std::string, ir::Tensor> tensor_map;
-  // do compute.
-  VLOG(3) << "group->fused_sub_groups.size() is : "
-          << group->fused_sub_groups.size();
-  std::vector<Expr> ast_exprs;
-  if (group->fused_sub_groups.size() == 0) {
-    ast_exprs = (this->*compute)(stages,
-                                 arg_tensors,
-                                 tensor_map,
-                                 group,
-                                 group,
-                                 /*apply_impl_schedule = */ true);
-  } else {
-    for (auto& sub_group : group->fused_sub_groups) {
-      auto exprs = (this->*compute)(stages,
-                                    arg_tensors,
-                                    tensor_map,
-                                    group,
-                                    sub_group,
-                                    /*apply_impl_schedule = */ true);
-      ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end());
-    }
-  }
-  ir::ModuleExpr mod_expr(ast_exprs);
-  ir::IRSchedule ir_sch(mod_expr);
-  ir_sch.MergeExprs();
-  Node* first = nullptr;
-  Node* second = nullptr;
-  VLOG(3) << "Before IRLowerOp schedule, ir is: \n"
-          << ir_sch.GetModule().GetExprs().at(0);
-  // do schedule.
-  IRSchedule(ir_sch, group, tensor_map);
-  VLOG(3) << "After IRLowerOp schedule, ir is: \n"
-          << ir_sch.GetModule().GetExprs().at(0);
-  // function args
-  group->input_names.clear();
-  std::vector<ir::Argument> func_args;
-  for (auto& args : arg_tensors) {
-    // input node data name.
-    group->input_names.push_back(args->name);
-    // input args
-    func_args.emplace_back(args->buffer, ir::Argument::IO::kInput);
-  }
-  group->output_names.clear();
-  for (auto& node : group->output_nodes) {
-    // output node data name.
-    for (auto node_data : GetAllNodeData(node)) {
-      group->output_names.push_back(node_data->id());
-    }
-    // collect all output tensor.
-    std::string post = "";
-    std::string prefix = GetNodeData(node)->id();
-    for (int idx = 0; idx < 1; ++idx) {
-      CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix;
-      if (!tensor_map.count(prefix + post)) {
-        break;
-      }
-      auto tensor = tensor_map[prefix + post];
-      arg_tensors.push_back(tensor);
-      // output args
-      func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
-      // update post
-      post = "_" + std::to_string(idx);
-    }
-  }
-  auto func_body = ir_sch.GetModule().GetExprs().at(0);
-#ifdef CINN_WITH_CUDA
-  optim::OptimizeExprGPU(&(func_body));
-#endif
-  auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body);
-  auto func = ir::_LoweredFunc_::Make(group->GetFuncName(),
-                                      func_args,
-                                      ir_sch.GetModule().GetExprs().at(0),
-                                      temp_buffers);
-  func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
-  return {func};
-}
+bool OpLowerer::ReduceScheduleDetermineFunction(Node* node) {
+  auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
+  return op_pattern_dict[node->op()] == framework::kReduction;
+}

+bool OpLowerer::NonFusibleScheduleDetermineFunction(Node* node) { return true; }

+std::vector<ir::LoweredFunc> OpLowerer::LowerGroup(
+    const GroupPtr& group,
+    bool apply_op_schedule,
+    bool apply_group_schedule,
+    ScheduleDetermineFunction schedule_determine_func) {
+  // 1.Do compute, lower and schedule for each op.
+  std::vector<Node*> nodes = group->CollectNodes();
+  if (nodes.size() == 1 && nodes[0]->op()->name == "custom_call") {
+    return LowerCustomCall(group);
+  }
+  std::vector<ir::Tensor> group_func_arg_tensors;
+  std::unordered_map<std::string, ir::Tensor> tensor_map;
+  bool do_op_schedule = apply_group_schedule || apply_op_schedule;
+  std::vector<ir::Expr> func_bodies = LowerOps(nodes,
+                                               do_op_schedule,
+                                               schedule_determine_func,
+                                               &group_func_arg_tensors,
+                                               &tensor_map);
+
+  // 2.Do group schedule.
+  ir::ModuleExpr mod_expr(func_bodies);
+  ir::IRSchedule ir_sch(mod_expr);
+  ir_sch.MergeExprs();
+  VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
+  if (apply_group_schedule) {
+    DoGroupSchedule(ir_sch, group, tensor_map);
+    VLOG(3) << "After group schedule, ir is: \n"
+            << ir_sch.GetModule().GetExprs().at(0);
+  }
+
+  // 3.Do post-processing,
+  // including preparing function args and temporary variables,
+  // applying low-level optimization passes, etc.
+  return PostProcess(
+      group, tensor_map, do_op_schedule, &ir_sch, &group_func_arg_tensors);
+}

-std::vector<ir::LoweredFunc> OpLowerer::IRLowerOpWithoutSchedule(
-    IRComputeFunction compute, GroupPtr& group) {
-  poly::StageMap stages;
-  std::vector<ir::Tensor> arg_tensors;
-  std::unordered_map<std::string, ir::Tensor> tensor_map;
-  // do compute.
-  VLOG(3) << "group->fused_sub_groups.size() is : "
-          << group->fused_sub_groups.size();
-  std::vector<Expr> ast_exprs;
-  if (group->fused_sub_groups.size() == 0) {
-    ast_exprs = (this->*compute)(stages, arg_tensors, tensor_map,
-                                 group, group, /*apply_impl_schedule = */ false);
-  } else {
-    for (auto& sub_group : group->fused_sub_groups) {
-      auto exprs = (this->*compute)(stages, arg_tensors, tensor_map,
-                                    group, sub_group, /*apply_impl_schedule = */ false);
-      ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end());
-    }
-  }
-  ir::ModuleExpr mod_expr(ast_exprs);
-  ir::IRSchedule ir_sch(mod_expr);
-  ir_sch.MergeExprs();
-  VLOG(3) << "After IRLowerOp compute, ir is: \n"
-          << ir_sch.GetModule().GetExprs().at(0);
-  // function args
...
-  func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
-  return {func};
-}
+std::vector<ir::LoweredFunc> OpLowerer::LowerCustomCall(const GroupPtr& group) {
+  std::vector<Node*> nodes = group->CollectNodes();
+  CHECK_EQ(nodes.size(), 1);
+  Node* node = nodes[0];
+  std::vector<ir::Tensor> op_func_arg_tensors;
+  std::unordered_map<std::string, ir::Tensor> tensor_map;
+  for (auto& node_data : GetInputNodeData(node)) {
+    CHECK(node_data);
+    ir::Tensor tensor;
+    if (!tensor_map.count(node_data->id())) {
+      tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_);
+      // record tensor.
+      tensor_map[node_data->id()] = tensor;
+      // input name.
+      group->input_names.push_back(node_data->id());
+    } else {
+      tensor = tensor_map[node_data->id()];
+    }
+    op_func_arg_tensors.push_back(tensor);
+  }
+
+  std::vector<Type> out_types;
+  std::vector<std::vector<int>> out_shapes;
+  auto node_datas = GetAllNodeData(node);
+  for (auto node_data : node_datas) {
+    group->output_names.push_back(node_data->id());
+    out_types.push_back(this->type_dict_.at(node_data->id()));
+    out_shapes.push_back(this->shape_dict_.at(node_data->id()));
+  }
+  auto& cinn_strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
+  auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()](
+      node->attrs, op_func_arg_tensors, out_types, out_shapes, target_));
+  std::string external_api;
+  if (node->attrs.attr_store.count("custom_call")) {
+    external_api =
+        absl::get<std::string>(node->attrs.attr_store.at("custom_call"));
+  } else {
+    external_api = ExternalApiRegistry::Global()->GetExternalApi(node, target_);
+  }
+  std::vector<common::CINNValue> compute_args = {
+      common::CINNValue(group->GetFuncName()), common::CINNValue(external_api)};
+  common::CINNValuePack pack =
+      impl->fcompute(common::CINNValuePack{compute_args});
+  CHECK_EQ(pack.size(), 1UL);
+  // reset input names as extern api input args can't be remove duplicate.
+  group->input_names.clear();
+  for (auto& inode : node->inlinks_in_order()) {
+    group->input_names.push_back(inode->source()->as<NodeData>()->id());
+  }
+  return {pack[0].operator ir::Expr().as_lowered_func_ref()};
+}

+std::vector<ir::LoweredFunc> OpLowerer::PostProcess(
+    const GroupPtr& group,
+    const std::unordered_map<std::string, ir::Tensor>& tensor_map,
+    bool done_op_schedule,
+    ir::IRSchedule* ir_sch,
+    std::vector<ir::Tensor>* group_func_arg_tensors) {
+  // 1.Prepare function args
+  group->input_names.clear();
+  std::vector<ir::Argument> group_func_args;
+  std::unordered_set<std::string> arg_name_set;
+  for (auto& arg_tensor : *group_func_arg_tensors) {
+    // input node data name.
+    group->input_names.push_back(arg_tensor->name);
+    // input args
+    group_func_args.emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput);
+    arg_name_set.insert(arg_tensor->buffer->name);
+  }
+
+  group->output_names.clear();
+  for (auto& node : group->output_nodes) {
+    // collect all output tensor.
+    for (auto node_data : GetAllNodeData(node)) {
+      std::string output_node_data_name = node_data->id();
+      group->output_names.push_back(output_node_data_name);
+      // CHECK(tensor_map.count(output_node_data_name)) << "Can't find output
+      // tensor " << output_node_data_name;
+      if (tensor_map.count(output_node_data_name) == 0) {
+        continue;
+      }
+      auto tensor = tensor_map.at(output_node_data_name);
+      if (arg_name_set.count(tensor->buffer->name) != 0) {
+        continue;
+      }
+      // output arg tensors
+      group_func_arg_tensors->push_back(tensor);
+      // output args
+      group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
+      arg_name_set.insert(tensor->buffer->name);
+    }
+  }
+
+  if (!done_op_schedule) {
+    std::unordered_set<std::string> args_set;
+    for (auto arg : group_func_args) {
+      args_set.insert(arg.name());
+    }
+    for (auto& tensor_pair : tensor_map) {
+      if (args_set.count("_" + tensor_pair.second->name)) {
+        continue;
+      }
+      group_func_arg_tensors->push_back(tensor_pair.second);
+      // use the underlying tensor name to be consistent with the argument name
+      // in the lowered function
+      group->output_names.push_back(tensor_pair.second->name);
+      group_func_args.emplace_back(tensor_pair.second->buffer,
+                                   ir::Argument::IO::kOutput);
+    }
+  }
+
+  auto func_body = ir_sch->GetModule().GetExprs().at(0);
+#ifdef CINN_WITH_CUDA
+  optim::OptimizeExprGPU(&(func_body));
+#endif
+  // 2.Prepare temp buffers
+  poly::StageMap stages;
+  auto temp_buffers =
+      lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body);
+  // 3.Building LoweredFunc
+  auto func = ir::_LoweredFunc_::Make(group->GetFuncName(),
+                                      group_func_args,
+                                      ir_sch->GetModule().GetExprs().at(0),
+                                      temp_buffers);
+  if (!done_op_schedule) {
+    func->PrepareBufferCastExprs();
+  }
+  // 4.Apply low level pass
+  func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
+  return {func};
+}

-std::vector<Expr> OpLowerer::IRElementwiseCompute(
-    poly::StageMap& stages,
-    std::vector<ir::Tensor>& func_tensors,
-    std::unordered_map<std::string, ir::Tensor>& tensor_map,
-    const GroupPtr& group,
-    const GroupPtr& sub_group,
-    bool apply_impl_schedule) {
-  VLOG(2) << "ElementwiseCompute Group : " << sub_group->group_id;
-  auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
-  std::vector<Expr> ast_exprs;
-  for (auto& node : sub_group->nodes) {
-    VLOG(4) << "Lower op: " << node->op()->name;
-    auto node_data = GetNodeData(node);
-    CHECK_EQ(GetAllNodeData(node).size(), 1U);
-    std::vector<common::CINNValue> cinn_inputs;
-    std::vector<ir::Tensor> tensor_inputs =
-        std::move(CollectInputTensor(node, func_tensors, tensor_map,
-                                     this->type_dict_, this->shape_dict_));
-    for (auto& tensor : tensor_inputs) {
-      cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor)));
-    }
-    // set tensor name = node data name
-    cinn_inputs.push_back(common::CINNValue(node_data->id()));
-    auto impl = OpStrategy::SelectImpl(strategy[node->op()](
-        node->attrs, tensor_inputs, out_types, out_shapes, this->target_));
-    // do compute
-    common::CINNValuePack pack =
-        impl->fcompute(common::CINNValuePack{cinn_inputs});
-    CHECK_EQ(pack.size(), 2U);
-    Expr expr = pack[0];
-    poly::StageMap node_stages = pack.back();
-    tensor_inputs.push_back(expr.as_tensor_ref());
-    tensor_map[node_data->id()] = expr.as_tensor_ref();
-    auto func = lang::LowerVec("fn_" + node->id(), node_stages, tensor_inputs,
-                               {}, {}, nullptr, this->target_, true);
-    CHECK_EQ(func.size(), 1);
-    if (apply_impl_schedule) {
-      std::vector<common::CINNValue> schedule_inputs;
-      // collect tensor
-      for (int idx = 0; idx < pack.size() - 1; ++idx) {
-        CHECK(pack[idx].is_tensor());
-        schedule_inputs.push_back(common::CINNValue(pack[idx]));
-      }
-      for (auto& f : func) {
-        schedule_inputs.push_back(common::CINNValue(f->body));
-      }
-      // do ast tree schedule
-      common::CINNValuePack expr_pack =
-          impl->fschedule(common::CINNValuePack{schedule_inputs});
-      CHECK_EQ(expr_pack.size(), 1);
-      Expr ast_expr = expr_pack[0];
-      ast_exprs.push_back(ast_expr);
-    } else {
-      ast_exprs.push_back(func[0]->body);
-    }
-  }
-  return ast_exprs;
-}
+std::vector<ir::Expr> OpLowerer::LowerOps(
+    const std::vector<Node*>& nodes,
+    bool apply_op_schedule,
+    ScheduleDetermineFunction schedule_determine_func,
+    std::vector<ir::Tensor>* group_func_arg_tensors,
+    std::unordered_map<std::string, ir::Tensor>* tensor_map) {
+  auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
+  std::vector<Expr> func_bodies;
+  for (Node* node : nodes) {
+    // 1.Select Op impl
+    std::vector<Type> out_types;
+    std::vector<std::vector<int>> out_shapes;
+    std::vector<NodeData*> node_datas = GetAllNodeData(node);
+    for (const auto& node_data : node_datas) {
+      out_types.push_back(this->type_dict_.at(node_data->id()));
+      out_shapes.push_back(this->shape_dict_.at(node_data->id()));
+    }
+    std::vector<ir::Tensor> op_func_arg_tensors =
+        std::move(CollectInputTensor(node,
+                                     this->type_dict_,
+                                     this->shape_dict_,
+                                     group_func_arg_tensors,
+                                     tensor_map));
+    auto op_impl = OpStrategy::SelectImpl(strategy[node->op()](
+        node->attrs, op_func_arg_tensors, out_types, out_shapes, this->target_));
+
+    // 2.Perform the lower process of Op
+    std::vector<ir::LoweredFunc> funcs =
+        DoOpLower(op_impl, node, tensor_map, &op_func_arg_tensors);
+
+    if (apply_op_schedule && (this->*schedule_determine_func)(node)) {
+      // 3.Perform the schedule of Op
+      func_bodies.push_back(DoOpSchedule(op_impl, op_func_arg_tensors, funcs));
+    } else {
+      for (const ir::LoweredFunc& func : funcs) {
+        func_bodies.push_back(func->body);
+      }
+    }
+  }
+  return func_bodies;
+}

-std::vector<Expr> OpLowerer::IRReduceCompute(
-    poly::StageMap& stages,
-    std::vector<ir::Tensor>& func_args,
-    std::unordered_map<std::string, ir::Tensor>& tensor_map,
-    const GroupPtr& group,
-    const GroupPtr& sub_group,
-    bool apply_impl_schedule) {
-  VLOG(2) << "ReduceCompute Group : " << sub_group->group_id;
-  auto& cinn_strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
-  auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
-  std::vector<Expr> ast_exprs;
-  for (auto& node : sub_group->nodes) {
-    auto node_data = GetNodeData(node);
-    VLOG(3) << "In ReduceCompute, process node: " << node->id()
-            << " with op type: " << node->op()->name;
...
-    // node is kReduction
-    if (op_pattern_dict[node->op()] == framework::kReduction &&
-        apply_impl_schedule) {
-      std::vector<common::CINNValue> schedule_inputs;
-      // collect tensor
...
-      // do ast tree schedule
-      common::CINNValuePack expr_pack =
-          impl->fschedule(common::CINNValuePack{schedule_inputs});
-      // ast tree after schedule.
-      Expr ast_expr = expr_pack[0];
-      ast_exprs.push_back(ast_expr);
-    } else if (group->master_nodes.count(node)) {
-      // as master node should copy transform from reducer, left it to reduce
-      // schedule.
-      ast_exprs.push_back(func[0]->body);
-    } else {
-      ast_exprs.push_back(func[0]->body);
-    }
-  }
-  return ast_exprs;
-}
+std::vector<ir::LoweredFunc> OpLowerer::DoOpLower(
+    std::shared_ptr<hlir::framework::OpImpl> op_impl,
+    Node* node,
+    std::unordered_map<std::string, ir::Tensor>* tensor_map,
+    std::vector<ir::Tensor>* op_func_arg_tensors) {
+  VLOG(4) << "Do lower with Compute, op: " << node->op()->name;
+  std::vector<common::CINNValue> cinn_inputs;
+  for (const ir::Tensor& tensor : *op_func_arg_tensors) {
+    cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor)));
+  }
+  // set tensor name = node data name
+  std::vector<NodeData*> node_datas = GetAllNodeData(node);
+  for (const NodeData* node_data : node_datas) {
+    cinn_inputs.push_back(common::CINNValue(node_data->id()));
+  }
+
+  // 1.Do compute
+  common::CINNValuePack pack =
+      op_impl->fcompute(common::CINNValuePack{cinn_inputs});
+  CHECK_GE(pack.size(), 2UL);
+  CHECK_LE(pack.size(), 5UL);
+  poly::StageMap tmp_stages = pack.back();
+
+  std::string post = "";
+  for (int idx = 0; idx < pack.size() - 1; ++idx) {
+    Expr expr = pack[idx];
+    // Insert the output tensor defined by Compute into the tensor_map
+    if (pack.size() - 1 > node_datas.size()) {
+      // Some nodes may output multiple temp tensors in their Compute
+      // definition, but only one output node_data in the graph, and we use id +
+      // "_0"/"_1" as key.
+      (*tensor_map)[node_datas[0]->id() + post] = expr.as_tensor_ref();
+      post = "_" + std::to_string(idx);
+    } else {
+      // If the number of output tensors defined by Compute is less equal than
+      // the output node_data on the graph, then there is a one-to-one
+      // correspondence, and the redundant output node_data contact empty.
+      (*tensor_map)[node_datas[idx]->id()] = expr.as_tensor_ref();
+    }
+    // Insert output tensors into function arg
+    if (!expr.as_tensor_ref()->buffer.defined() ||
+        this->target_ != common::DefaultNVGPUTarget()) {
+      op_func_arg_tensors->push_back(expr.as_tensor_ref());
+      expr.as_tensor_ref()->WithBuffer();
+    }
+  }
+
+  // 2.Do lower
+  std::vector<ir::LoweredFunc> funcs = lang::LowerVec("fn_" + node->id(),
+                                                      tmp_stages,
+                                                      *op_func_arg_tensors,
+                                                      {},
+                                                      {},
+                                                      nullptr,
+                                                      this->target_,
+                                                      true);
+  VLOG(4) << "Lower op: " << node->op()->name << ", get " << funcs.size()
+          << " LoweredFunc:\n";
+
+  op_func_arg_tensors->clear();
+  for (int idx = 0; idx < pack.size() - 1; ++idx) {
+    CHECK(pack[idx].is_tensor());
+    op_func_arg_tensors->push_back(pack[idx].operator ir::Expr().as_tensor_ref());
+  }
+  return funcs;
+}

-std::vector<ir::LoweredFunc> OpLowerer::IRLowerNonFusibleOp(
-    GroupPtr& group, bool apply_impl_schedule) {
-  VLOG(3) << "LowerNonFusibleOp Group : " << group->group_id;
-  // get input tensor and output tensor
-  CHECK(group->nodes.size() || group->fused_sub_groups.size());
-  auto& cinn_strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
-  auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
-  auto node = group->fused_sub_groups.size()
-                  ? group->fused_sub_groups[0]->nodes.front()
-                  : group->nodes.front();
-  VLOG(3) << "GetOpFunc of op " << node->id();
-  std::vector<ir::Tensor> inputs;
-  std::vector<common::CINNValue> cinn_inputs;
-  std::vector<ir::Argument> args;
-  std::unordered_map<std::string, ir::Tensor> tensor_map;
...
-  // if node op is custom_call, apply custom_call compute.
-  if (node->op()->name == "custom_call") {
...
-    return {pack[0].operator ir::Expr().as_lowered_func_ref()};
-  }
-  common::CINNValuePack pack =
-      impl->fcompute(common::CINNValuePack{cinn_inputs});
...
-  poly::StageMap stages = pack.back();
-  auto func = lang::LowerVec(group->GetFuncName(), stages, inputs, {}, {},
-                             nullptr, this->target_, true);
-  if (apply_impl_schedule) {
...
-    return res;
-  } else {
-    for (auto& f : func) {
-#ifdef CINN_WITH_CUDA
-      optim::OptimizeExprGPU(&(f->body));
-#endif
-      f = optim::Optimize(Expr(f), target_, false).as_lowered_func_ref();
-    }
-    return func;
-  }
-}
+ir::Expr OpLowerer::DoOpSchedule(
+    std::shared_ptr<hlir::framework::OpImpl> op_impl,
+    const std::vector<ir::Tensor>& op_func_arg_tensors,
+    const std::vector<ir::LoweredFunc>& lowered_funcs) {
+  VLOG(4) << "Do op schedule";
+  std::vector<common::CINNValue> schedule_inputs;
+  // 1.Collect tensors
+  for (const ir::Tensor& op_func_arg_tensor : op_func_arg_tensors) {
+    schedule_inputs.push_back(common::CINNValue(op_func_arg_tensor));
+  }
+  // 2.Collect bodies to be scheduled
+  for (const ir::LoweredFunc& func : lowered_funcs) {
+    schedule_inputs.push_back(common::CINNValue(func->body));
+  }
+  // 3.Do schedule on AST
+  common::CINNValuePack expr_pack =
+      op_impl->fschedule(common::CINNValuePack{schedule_inputs});
+  VLOG(4) << "After op schedule: " << expr_pack[0].operator ir::Expr();
+  return expr_pack[0].operator ir::Expr();
+}

-// group schedule
-void OpLowerer::IRSchedule(
+ir::Expr OpLowerer::DoGroupSchedule(
     ir::IRSchedule& ir_sch,
     const GroupPtr& group,
     const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
...
@@ -698,6 +506,7 @@ void OpLowerer::IRSchedule(
           << ", ir is: \n"
           << ir_sch.GetModule().GetExprs().at(0);
   // if node is horizontal with reduce or node is reduce, loop assign
   // master.
   auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
   if (op_pattern_dict[node->op()] == framework::kElementWise) {
...
@@ -788,6 +597,7 @@ void OpLowerer::IRSchedule(
       ir_sch, group, nodes_inline, nodes_set, this->shape_dict_, tensor_map);
   VLOG(4) << "After IRSchedule, ir is: \n"
           << ir_sch.GetModule().GetExprs().at(0);
+  return ir_sch.GetModule().GetExprs().at(0);
 }
 }  // namespace framework
...
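The part of this rewrite that is easiest to miss in the diff is the ScheduleDetermineFunction mechanism: a pointer-to-member predicate, chosen per group pattern in OpLowerer::Lower, decides op by op whether the op-level schedule runs inside LowerOps. The following self-contained toy program is an illustration of that idiom only; the struct, its names, and the integer "op kinds" are invented for the sketch and are not part of the patch.

#include <iostream>
#include <vector>

struct ToyLowerer {
  // Same shape as OpLowerer::ScheduleDetermineFunction: a member predicate.
  typedef bool (ToyLowerer::*DeterminePred)(int op_kind);

  bool AlwaysSchedule(int) { return true; }
  bool OnlyReduce(int op_kind) { return op_kind == 2; /* 2 ~ "reduction" */ }

  void LowerOps(const std::vector<int>& op_kinds,
                bool apply_op_schedule,
                DeterminePred pred) {
    for (int kind : op_kinds) {
      // Mirrors: apply_op_schedule && (this->*schedule_determine_func)(node)
      bool do_schedule = apply_op_schedule && (this->*pred)(kind);
      std::cout << "op kind " << kind
                << (do_schedule ? ": schedule\n" : ": body only\n");
    }
  }
};

int main() {
  ToyLowerer lowerer;
  // A "reduce" group schedules only its reduction ops, analogous to
  // ReduceScheduleDetermineFunction in the real OpLowerer.
  lowerer.LowerOps({0, 2, 0}, /*apply_op_schedule=*/true,
                   &ToyLowerer::OnlyReduce);
  return 0;
}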
paddle/cinn/hlir/framework/op_lowering.h
@@ -39,46 +39,132 @@ using GroupPtr = std::shared_ptr<Graph::Group>;
 using common::Target;

 class OpLowerer;

-typedef std::vector<Expr> (OpLowerer::*IRComputeFunction)(
-    poly::StageMap&,
-    std::vector<ir::Tensor>&,
-    std::unordered_map<std::string, ir::Tensor>&,
-    const GroupPtr&,
-    const GroupPtr&,
-    bool);
+typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*);

 class OpLowerer {
  public:
   OpLowerer(const absl::flat_hash_map<std::string, Type>&,
             const absl::flat_hash_map<std::string, shape_t>&,
             const Target&);
-  std::vector<ir::LoweredFunc> Lower(GroupPtr& group);  // NOLINT
-  std::vector<ir::LoweredFunc> LowerWithoutSchedule(GroupPtr& group);  // NOLINT
+
+  /**
+   * @brief Lower a group to CINN IR.
+   * @param group The group to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param apply_group_schedule Whether to schedule at group level.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> Lower(const GroupPtr& group,
+                                     bool apply_op_schedule = true,
+                                     bool apply_group_schedule = true);

  private:
-  std::vector<ir::LoweredFunc> IRLowerOp(IRComputeFunction, GroupPtr&);
-  std::vector<ir::LoweredFunc> IRLowerNonFusibleOp(GroupPtr&, bool);
-  std::vector<ir::LoweredFunc> IRLowerOpWithoutSchedule(IRComputeFunction,
-                                                        GroupPtr&);
-#define DEFINE_IR_COMPUTE(type)                                   \
-  std::vector<Expr> IR##type##Compute(                            \
-      poly::StageMap& stages,                                     \
-      std::vector<ir::Tensor>& func_args,                         \
-      std::unordered_map<std::string, ir::Tensor>& tensor_map,    \
-      const GroupPtr& group,                                      \
-      const GroupPtr& sub_group,                                  \
-      bool apply_impl_schedule = false);
-
-  // compute and schedule
-  DEFINE_IR_COMPUTE(Elementwise);
-  DEFINE_IR_COMPUTE(Reduce);
-  DEFINE_IR_COMPUTE(OutEWiseFusable);
-
-  void IRSchedule(
+  /**
+   * @brief Lower a group to CINN IR.
+   * @param group The group to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param apply_group_schedule Whether to schedule at group level.
+   * @param schedule_determine_func Function used to determine which Ops to
+   * schedule.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> LowerGroup(
+      const GroupPtr& group,
+      bool apply_op_schedule,
+      bool apply_group_schedule,
+      ScheduleDetermineFunction schedule_determine_func);
+
+  /**
+   * @brief Lower a group composed of CustomCall Op.
+   * @param group The group to be lowered.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> LowerCustomCall(const GroupPtr& group);
+
+  /**
+   * @brief Post processing, including preparing function args and temporary
+   * variables, applying low-level optimization passes, etc.
+   * @param group The group to be lowered.
+   * @param tensor_map All tensors used for calculating the group.
+   * @param done_op_schedule Mark whether the Op level schedule has been
+   * applied.
+   * @param ir_sch The IRSchedule object of group.
+   * @param group_func_arg_tensors Tensors used as the group function arguments.
+   * @return The lowered funcs after the post processing.
+   */
+  std::vector<ir::LoweredFunc> PostProcess(
+      const GroupPtr& group,
+      const std::unordered_map<std::string, ir::Tensor>& tensor_map,
+      bool done_op_schedule,
+      ir::IRSchedule* ir_sch,
+      std::vector<ir::Tensor>* group_func_arg_tensors);
+
+  /**
+   * @brief Lower an Op set to CINN IR.
+   * Compute, Lower and optional Schedule will be performed one by one
+   * for each Op.
+   * @param nodes The Op nodes to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param schedule_determine_func Function used to determine which Ops to
+   * schedule.
+   * @param group_func_arg_tensors Tensors used as the group function arguments.
+   * @param tensor_map All tensors used for calculating the group.
+   * @return The lowered func bodies of Op set.
+   */
+  std::vector<ir::Expr> LowerOps(
+      const std::vector<Node*>& nodes,
+      bool apply_op_schedule,
+      ScheduleDetermineFunction schedule_determine_func,
+      std::vector<ir::Tensor>* group_func_arg_tensors,
+      std::unordered_map<std::string, ir::Tensor>* tensor_map);
+
+  /**
+   * @brief Lower an Op to CINN IR. The Compute and Lower processes will be
+   * called sequentially.
+   * @param op_impl The Op implementation defining Compute and Schedule.
+   * @param node The Op node to be lowered.
+   * @param tensor_map All tensors used for calculating the group.
+   * @param op_func_arg_tensors Tensors used as the Op function arguments.
+   * @return The lowered func of the Op node.
+   */
+  std::vector<ir::LoweredFunc> DoOpLower(
+      std::shared_ptr<hlir::framework::OpImpl> op_impl,
+      Node* node,
+      std::unordered_map<std::string, ir::Tensor>* tensor_map,
+      std::vector<ir::Tensor>* op_func_arg_tensors);
+
+  /**
+   * @brief Apply schedule on an Op.
+   * @param op_impl The Op implementation defining Compute and Schedule.
+   * @param op_func_arg_tensors Tensors used as the Op function arguments.
+   * @param lowered_funcs The lowered funcs of an Op to be scheduled.
+   * @return The lowered func body after schedule of the Op.
+   */
+  ir::Expr DoOpSchedule(std::shared_ptr<hlir::framework::OpImpl> op_impl,
+                        const std::vector<ir::Tensor>& op_func_arg_tensors,
+                        const std::vector<ir::LoweredFunc>& lowered_funcs);
+
+  /**
+   * @brief Apply schedule on a group.
+   * @param ir_sch The IRSchedule containing the entire group's lowered func
+   * bodies.
+   * @param group The group to be scheduled.
+   * @param tensor_map All tensors used for calculating the group.
+   * @return The lowered func body after schedule of the group.
+   */
+  ir::Expr DoGroupSchedule(
       ir::IRSchedule& ir_sch,  // NOLINT
       const GroupPtr& group,
       const std::unordered_map<std::string, ir::Tensor>& tensor_map);
+
+  // Functions used to determine which Ops to schedule at op level, define a
+  // policy for each type of group.
+  inline bool ReduceScheduleDetermineFunction(Node* node);
+  inline bool ElementwiseScheduleDetermineFunction(Node* node);
+  inline bool NonFusibleScheduleDetermineFunction(Node* node);

  private:
   Target target_;
   const absl::flat_hash_map<std::string, Type>& type_dict_;
   const absl::flat_hash_map<std::string, shape_t>& shape_dict_;
...
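A hedged usage sketch of the public interface declared above, assuming the type/shape dictionaries, target, and fusion group are obtained the same way as in test_helper.cc earlier in this diff (the variable names here are placeholders, not code from the patch):

  hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target);

  // Fully scheduled lowering: op- and group-level schedules default to true.
  std::vector<ir::LoweredFunc> funcs = op_lowerer.Lower(fusion_group);

  // Unscheduled lowering, e.g. as the baseline for auto-schedule tuning.
  std::vector<ir::LoweredFunc> raw_funcs =
      op_lowerer.Lower(fusion_group,
                       /*apply_op_schedule=*/false,
                       /*apply_group_schedule=*/false);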
paddle/cinn/hlir/framework/op_lowering_util.cc
@@ -92,19 +92,19 @@ ir::Tensor GetTensor(
 std::vector<ir::Tensor> CollectInputTensor(
     const Node* node,
-    std::vector<ir::Tensor>& func_args,  // NOLINT
-    std::unordered_map<std::string, ir::Tensor>& tensor_map,  // NOLINT
     const absl::flat_hash_map<std::string, Type>& type_dict,
-    const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
+    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
+    std::vector<ir::Tensor>* func_args,
+    std::unordered_map<std::string, ir::Tensor>* tensor_map) {
   std::vector<ir::Tensor> tensors;
   // get all input nodes
   for (auto& node_data : GetInputNodeData(node)) {
     CHECK(node_data);
     auto tensor = GetTensor(node_data, type_dict, shape_dict);
-    if (!tensor_map.count(node_data->id())) {
-      tensor_map[node_data->id()] = tensor;
+    if (!tensor_map->count(node_data->id())) {
+      (*tensor_map)[node_data->id()] = tensor;
       // record func input args
-      func_args.push_back(tensor);
+      func_args->push_back(tensor);
     }
     tensors.push_back(tensor);
   }
...
paddle/cinn/hlir/framework/op_lowering_util.h
@@ -31,10 +31,10 @@ ir::Tensor GetTensor(
 std::vector<ir::Tensor> CollectInputTensor(
     const Node* node,
-    std::vector<ir::Tensor>& func_args,  // NOLINT
-    std::unordered_map<std::string, ir::Tensor>& tensor_map,  // NOLINT
     const absl::flat_hash_map<std::string, Type>& type_dict,
-    const absl::flat_hash_map<std::string, shape_t>& shape_dict);
+    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
+    std::vector<ir::Tensor>* func_args,
+    std::unordered_map<std::string, ir::Tensor>* tensor_map);

 std::unordered_map<Node*, Node*> BuildVirtualConsumer(const GroupPtr& group,
...