Commit 3559252a (unverified)
Authored by BiynXu on Jul 13, 2023; committed via GitHub on Jul 13, 2023.
[CINN] comb the op lowering code (#54982)

* [CINN] comb the op lowering code
* [CINN] format code of OpLower
Parent: 27cc0df5
Showing 9 changed files with 466 additions and 571 deletions (+466, -571).
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc          +3   -2
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc   +59  -62
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc               +4   -6
paddle/cinn/auto_schedule/task/tune_task.cc                                       +2   -1
paddle/cinn/auto_schedule/tests/performance_comparison_test.cc                    +3   -1
paddle/cinn/hlir/framework/op_lowering.cc                                         +272 -462
paddle/cinn/hlir/framework/op_lowering.h                                          +114 -28
paddle/cinn/hlir/framework/op_lowering_util.cc                                    +6   -6
paddle/cinn/hlir/framework/op_lowering_util.h                                     +3   -3
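The central API change in this commit, taken from the op_lowering.h hunk further down, is that the two public entry points Lower(GroupPtr&) and LowerWithoutSchedule(GroupPtr&) collapse into a single parameterized one:

  std::vector<ir::LoweredFunc> Lower(const GroupPtr& group,
                                     bool apply_op_schedule = true,
                                     bool apply_group_schedule = true);

Callers that previously used LowerWithoutSchedule now pass false for both flags, as the test and tuning-task updates below show.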
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc

@@ -71,7 +71,6 @@ TEST(AutoInline, SingleLoopInline) {
                          nullptr,
                          target,
                          true);
   VLOG(6) << "Expr after lowering:";
   VLOG(6) << funcs[0]->body;
@@ -170,7 +169,9 @@ TEST(AutoInline, AddReluInline) {
   EXPECT_EQ(graph->fusion_groups.size(), 1UL);
   std::vector<ir::LoweredFunc> funcs =
-      op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
+      op_lowerer->Lower(graph->fusion_groups[0],
+                        /*apply_op_schedule = */ false,
+                        /*apply_group_schedule=*/ false);
   VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc

@@ -388,9 +388,9 @@ TEST_F(TestMultiLevelTiling, ReduceSum) {
 TEST_F(TestMultiLevelTiling, Pool2d) {
   default_input_names = {"input"};
-  default_output_names = {"var_0"};
-  std::vector<int32_t> input_shape{2, 8, 16, 16};
-  std::vector<int32_t> output_shape{2, 8, 8, 8};
+  default_output_names = {"var_0", "pad_temp_0"};
+  std::vector<std::vector<int32_t>> input_shapes{{2, 8, 16, 16}};
+  std::vector<std::vector<int32_t>> output_shapes{{2, 8, 8, 8}, {2, 8, 18, 18}};
   std::string pooling_type = "max";
   std::vector<int> ksize{3, 3};
   std::vector<int> strides{2, 2};
@@ -402,7 +402,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
   bool adaptive = false;
   std::string padding_algorithm = "EXPLICIT";
   frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build(
-      {{"input", input_shape}},
+      {{"input", input_shapes[0]}},
       {{"pool_type", pooling_type},
        {"kernel_size", ksize},
        {"stride_size", strides},
@@ -440,85 +440,82 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
@@ -526,21 +523,21 @@ Expr 1 {
@@ -553,7 +550,7 @@ Expr 1 {
 (These three hunks update the expected_ir raw-string literal so it matches the new lowering output: the pad_temp_0 padding loops, the var_0 reduction loops, and the ScheduleBlock annotations such as attrs(compute_at_extra_var: ...), read_buffers(...) and write_buffers(...) are re-nested and re-indented, and the trailing "// end Expr" marker changes from "Expr 1" to "Expr 0". The full multi-line literal is omitted here for readability; it ends with:)
 )ROC";
   ASSERT_EQ(ir, expected_ir);
@@ -569,8 +566,8 @@ Expr 1 {
           pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
       default_input_names,
       default_output_names,
-      {input_shape},
-      {output_shape},
+      input_shapes,
+      output_shapes,
       target_);
 }
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc

@@ -63,12 +63,10 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
           absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
           "infershape");
   hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
-  if (apply_manual_schedule) {
-    lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front());
-  } else {
-    lowered_funcs_ =
-        op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
-  }
+  lowered_funcs_ =
+      op_lowerer.Lower(graph->fusion_groups.front(),
+                       /*apply_op_schedule = */ apply_manual_schedule,
+                       /*apply_group_schedule = */ apply_manual_schedule);
   CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
   std::vector<Expr> bodys;
paddle/cinn/auto_schedule/task/tune_task.cc

@@ -39,7 +39,8 @@ void TuneTask::Initialize(
   op_lowerer = lower_handler;
   // Set lowered_funcs and analyze output names.
-  this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph);
+  this->lowered_funcs = op_lowerer->Lower(
+      subgraph, /*apply_op_schedule = */ false, /*apply_group_schedule=*/ false);
   this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs);
   this->serialized_key = SerializeToString(shape_dict, dtype_dict);
 }
paddle/cinn/auto_schedule/tests/performance_comparison_test.cc

@@ -157,7 +157,9 @@ class PerformanceTester : public ::testing::Test {
     for (auto group : graph->fusion_groups) {
-      compile_options.lowered_funcs.push_back(
-          op_lowerer->LowerWithoutSchedule(group));
+      compile_options.lowered_funcs.push_back(
+          op_lowerer->Lower(group,
+                            /*apply_op_schedule = */ false,
+                            /*apply_group_schedule=*/ false));
     }
     VLOG(3) << "===========================No Schedule LoweredFunc "
paddle/cinn/hlir/framework/op_lowering.cc

@@ -45,7 +45,9 @@ OpLowerer::OpLowerer(
     const Target& target)
     : type_dict_(type_dict), shape_dict_(shape_dict), target_(target) {}

-std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) {  // NOLINT
+std::vector<ir::LoweredFunc> OpLowerer::Lower(const GroupPtr& group,
+                                              bool apply_op_schedule,
+                                              bool apply_group_schedule) {
   VLOG(3) << "Lowering Group : " << group->group_id
           << " , Op Pattern : " << group->op_pattern_kind;
   group->input_names.clear();
@@ -55,13 +57,22 @@ std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) {  // NOLINT
     case framework::kElementWise:
     case framework::kBroadcast:
     case framework::kInjective:
-      return IRLowerOp(&OpLowerer::IRElementwiseCompute, group);
+      return LowerGroup(group,
+                        apply_op_schedule,
+                        apply_group_schedule,
+                        &OpLowerer::ElementwiseScheduleDetermineFunction);
     case framework::kReduction:
-      return IRLowerOp(&OpLowerer::IRReduceCompute, group);
+      return LowerGroup(group,
+                        apply_op_schedule,
+                        apply_group_schedule,
+                        &OpLowerer::ReduceScheduleDetermineFunction);
     case framework::kOutFusible:
       LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
     case framework::kNonFusible:
-      return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ true);
+      return LowerGroup(group,
+                        apply_op_schedule,
+                        apply_group_schedule,
+                        &OpLowerer::NonFusibleScheduleDetermineFunction);
     default:
       LOG(FATAL) << "Group Pattern Kind Is Unknown!";
   }
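The dispatch above now hands every fusible pattern to a single LowerGroup path and only varies the ScheduleDetermineFunction, a pointer to a member predicate that later decides, per node, whether op-level scheduling applies. The following standalone sketch illustrates that member-function-pointer pattern with toy names (OpKind, Lowerer, kind_ and so on are invented for the example and are not CINN code):

  #include <iostream>
  #include <vector>

  // Toy stand-ins for CINN's Node / OpPatternKind; names are illustrative only.
  enum class OpKind { kElementWise, kReduction };
  struct Node { OpKind kind; };

  class Lowerer {
   public:
    // Member-function-pointer type, analogous to ScheduleDetermineFunction.
    using ScheduleDetermineFn = bool (Lowerer::*)(Node*);

    void Lower(std::vector<Node>& nodes, OpKind group_kind) {
      // The dispatch picks the predicate once per group, like OpLowerer::Lower().
      ScheduleDetermineFn fn = (group_kind == OpKind::kReduction)
                                   ? &Lowerer::ReduceDetermine
                                   : &Lowerer::ElementwiseDetermine;
      LowerGroup(nodes, fn);
    }

   private:
    // Predicates decide, per node, whether the op-level schedule should run.
    bool ElementwiseDetermine(Node*) { return true; }
    bool ReduceDetermine(Node* node) { return node->kind == OpKind::kReduction; }

    void LowerGroup(std::vector<Node>& nodes, ScheduleDetermineFn fn) {
      for (auto& node : nodes) {
        // Same call syntax as (this->*schedule_determine_func)(node) in the diff.
        bool do_schedule = (this->*fn)(&node);
        std::cout << (do_schedule ? "schedule op\n" : "lower op only\n");
      }
    }
  };

  int main() {
    std::vector<Node> nodes{{OpKind::kElementWise}, {OpKind::kReduction}};
    Lowerer{}.Lower(nodes, OpKind::kReduction);
  }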
@@ -70,532 +81,329 @@ std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) {  // NOLINT
(This hunk rewrites the body of op_lowering.cc. The old entry points and per-pattern compute paths - LowerWithoutSchedule, IRLowerOp, IRLowerOpWithoutSchedule, IRLowerNonFusibleOp, the IRElementwiseCompute/IRReduceCompute functions and the monolithic IRSchedule - are removed and replaced by a smaller set of private helpers. The new-side code, reconstructed from the diff, is:)

 bool OpLowerer::ElementwiseScheduleDetermineFunction(Node* node) { return true; }

 bool OpLowerer::ReduceScheduleDetermineFunction(Node* node) {
   auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
   return op_pattern_dict[node->op()] == framework::kReduction;
 }

 bool OpLowerer::NonFusibleScheduleDetermineFunction(Node* node) { return true; }

 std::vector<ir::LoweredFunc> OpLowerer::LowerGroup(
     const GroupPtr& group,
     bool apply_op_schedule,
     bool apply_group_schedule,
     ScheduleDetermineFunction schedule_determine_func) {
   // 1.Do compute, lower and schedule for each op.
   VLOG(3) << "group->fused_sub_groups.size() is : "
           << group->fused_sub_groups.size();
   std::vector<Node*> nodes = group->CollectNodes();
   if (nodes.size() == 1 && nodes[0]->op()->name == "custom_call") {
     return LowerCustomCall(group);
   }
   std::vector<ir::Tensor> group_func_arg_tensors;
   std::unordered_map<std::string, ir::Tensor> tensor_map;
   bool do_op_schedule = apply_group_schedule || apply_op_schedule;
   std::vector<ir::Expr> func_bodies = LowerOps(nodes,
                                                do_op_schedule,
                                                schedule_determine_func,
                                                &group_func_arg_tensors,
                                                &tensor_map);

   // 2.Do group schedule.
   ir::ModuleExpr mod_expr(func_bodies);
   ir::IRSchedule ir_sch(mod_expr);
   ir_sch.MergeExprs();
   VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
   if (apply_group_schedule) {
     DoGroupSchedule(ir_sch, group, tensor_map);
     VLOG(3) << "After group schedule, ir is: \n"
             << ir_sch.GetModule().GetExprs().at(0);
   }

   // 3.Do post-processing,
   // including preparing function args and temporary variables,
   // applying low-level optimization passes, etc.
   return PostProcess(
       group, tensor_map, do_op_schedule, &ir_sch, &group_func_arg_tensors);
 }

 std::vector<ir::LoweredFunc> OpLowerer::LowerCustomCall(const GroupPtr& group) {
   std::vector<Node*> nodes = group->CollectNodes();
   CHECK_EQ(nodes.size(), 1);
   Node* node = nodes[0];
   std::vector<ir::Tensor> op_func_arg_tensors;
   std::unordered_map<std::string, ir::Tensor> tensor_map;
   for (auto& node_data : GetInputNodeData(node)) {
     CHECK(node_data);
     ir::Tensor tensor;
     if (!tensor_map.count(node_data->id())) {
       tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_);
       // record tensor.
       tensor_map[node_data->id()] = tensor;
       // input name.
       group->input_names.push_back(node_data->id());
     } else {
       tensor = tensor_map[node_data->id()];
     }
     op_func_arg_tensors.push_back(tensor);
   }
   (The remainder of LowerCustomCall collects out_types and out_shapes from GetAllNodeData(node), selects the op's CINNStrategy impl, resolves the external API either from the "custom_call" attribute or from ExternalApiRegistry::Global()->GetExternalApi(node, target_), calls fcompute once with {group->GetFuncName(), external_api}, checks the pack size is 1UL, resets group->input_names from node->inlinks_in_order() because extern api input args can't be de-duplicated, and returns the single lowered func taken from the returned pack.)

 std::vector<ir::LoweredFunc> OpLowerer::PostProcess(
     const GroupPtr& group,
     const std::unordered_map<std::string, ir::Tensor>& tensor_map,
     bool done_op_schedule,
     ir::IRSchedule* ir_sch,
     std::vector<ir::Tensor>* group_func_arg_tensors) {
   // 1.Prepare function args
   group->input_names.clear();
   std::vector<ir::Argument> group_func_args;
   std::unordered_set<std::string> arg_name_set;
   for (auto& arg_tensor : *group_func_arg_tensors) {
     // input node data name.
     group->input_names.push_back(arg_tensor->name);
     // input args
     group_func_args.emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput);
     arg_name_set.insert(arg_tensor->buffer->name);
   }

   group->output_names.clear();
   for (auto& node : group->output_nodes) {
     // collect all output tensor.
     for (auto node_data : GetAllNodeData(node)) {
       std::string output_node_data_name = node_data->id();
       group->output_names.push_back(output_node_data_name);
       // CHECK(tensor_map.count(output_node_data_name)) << "Can't find output
       // tensor " << output_node_data_name;
       if (tensor_map.count(output_node_data_name) == 0) {
         continue;
       }
       auto tensor = tensor_map.at(output_node_data_name);
       if (arg_name_set.count(tensor->buffer->name) != 0) {
         continue;
       }
       // output arg tensors
       group_func_arg_tensors->push_back(tensor);
       // output args
       group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
       arg_name_set.insert(tensor->buffer->name);
     }
   }

   if (!done_op_schedule) {
     std::unordered_set<std::string> args_set;
     for (auto arg : group_func_args) {
       args_set.insert(arg.name());
     }
     for (auto& tensor_pair : tensor_map) {
       if (args_set.count("_" + tensor_pair.second->name)) {
         continue;
       }
       group_func_arg_tensors->push_back(tensor_pair.second);
       // use the underlying tensor name to be consistent with the argument name
       // in the lowered function
       group->output_names.push_back(tensor_pair.second->name);
       group_func_args.emplace_back(tensor_pair.second->buffer,
                                    ir::Argument::IO::kOutput);
     }
   }

   auto func_body = ir_sch->GetModule().GetExprs().at(0);
 #ifdef CINN_WITH_CUDA
   optim::OptimizeExprGPU(&(func_body));
 #endif
   // 2.Prepare temp buffers
   poly::StageMap stages;
   auto temp_buffers =
       lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body);
   // 3.Building LoweredFunc
   auto func = ir::_LoweredFunc_::Make(group->GetFuncName(),
                                       group_func_args,
                                       ir_sch->GetModule().GetExprs().at(0),
                                       temp_buffers);
   if (!done_op_schedule) {
     func->PrepareBufferCastExprs();
   }
   // 4.Apply low level pass
   func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
   return {func};
 }

 std::vector<ir::Expr> OpLowerer::LowerOps(
     const std::vector<Node*>& nodes,
     bool apply_op_schedule,
     ScheduleDetermineFunction schedule_determine_func,
     std::vector<ir::Tensor>* group_func_arg_tensors,
     std::unordered_map<std::string, ir::Tensor>* tensor_map) {
   auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
   std::vector<Expr> func_bodies;
   for (Node* node : nodes) {
     // 1.Select Op impl
     std::vector<Type> out_types;
     std::vector<std::vector<int>> out_shapes;
     std::vector<NodeData*> node_datas = GetAllNodeData(node);
     for (const auto& node_data : node_datas) {
       out_types.push_back(this->type_dict_.at(node_data->id()));
       out_shapes.push_back(this->shape_dict_.at(node_data->id()));
     }
     std::vector<ir::Tensor> op_func_arg_tensors =
         std::move(CollectInputTensor(node,
                                      this->type_dict_,
                                      this->shape_dict_,
                                      group_func_arg_tensors,
                                      tensor_map));
     auto op_impl = OpStrategy::SelectImpl(strategy[node->op()](node->attrs,
                                                                op_func_arg_tensors,
                                                                out_types,
                                                                out_shapes,
                                                                this->target_));

     // 2.Perform the lower process of Op
     std::vector<ir::LoweredFunc> funcs =
         DoOpLower(op_impl, node, tensor_map, &op_func_arg_tensors);

     if (apply_op_schedule && (this->*schedule_determine_func)(node)) {
       // 3.Perform the schedule of Op
       func_bodies.push_back(DoOpSchedule(op_impl, op_func_arg_tensors, funcs));
     } else {
       for (const ir::LoweredFunc& func : funcs) {
         func_bodies.push_back(func->body);
       }
     }
   }
   return func_bodies;
 }

 std::vector<ir::LoweredFunc> OpLowerer::DoOpLower(
     std::shared_ptr<hlir::framework::OpImpl> op_impl,
     Node* node,
     std::unordered_map<std::string, ir::Tensor>* tensor_map,
     std::vector<ir::Tensor>* op_func_arg_tensors) {
   VLOG(4) << "Do lower with Compute, op: " << node->op()->name;
   std::vector<common::CINNValue> cinn_inputs;
   for (const ir::Tensor& tensor : *op_func_arg_tensors) {
     cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor)));
   }
   // set tensor name = node data name
   auto node_datas = GetAllNodeData(node);
   for (const NodeData* node_data : node_datas) {
     VLOG(3) << "cinn_inputs.push_back " << node_data->id();
     cinn_inputs.push_back(common::CINNValue(node_data->id()));
   }

   // 1.Do compute
   common::CINNValuePack pack =
       op_impl->fcompute(common::CINNValuePack{cinn_inputs});

   poly::StageMap tmp_stages = pack.back();
   std::string post = "";
   for (int idx = 0; idx < pack.size() - 1; ++idx) {
     Expr expr = pack[idx];
     // Insert the output tensor defined by Compute into the tensor_map
     if (pack.size() - 1 > node_datas.size()) {
       // Some nodes may output multiple temp tensors in their Compute
       // definition, but only one output node_data in the graph, and we use id +
       // "_0"/"_1" as key.
       (*tensor_map)[node_datas[0]->id() + post] = expr.as_tensor_ref();
       post = "_" + std::to_string(idx);
     } else {
       // If the number of output tensors defined by Compute is less equal than
       // the output node_data on the graph, then there is a one-to-one
       // correspondence, and the redundant output node_data contact empty.
       (*tensor_map)[node_datas[idx]->id()] = expr.as_tensor_ref();
     }

     // Insert output tensors into function arg
     if (!expr.as_tensor_ref()->buffer.defined() ||
         this->target_ != common::DefaultNVGPUTarget()) {
       op_func_arg_tensors->push_back(expr.as_tensor_ref());
       expr.as_tensor_ref()->WithBuffer();
     }
   }

   // 2.Do lower
   std::vector<ir::LoweredFunc> funcs = lang::LowerVec("fn_" + node->id(),
                                                       tmp_stages,
                                                       *op_func_arg_tensors,
                                                       {},
                                                       {},
                                                       nullptr,
                                                       this->target_,
                                                       true);
   VLOG(4) << "Lower op: " << node->op()->name << ", get " << funcs.size()
           << " LoweredFunc:\n";

   op_func_arg_tensors->clear();
   for (int idx = 0; idx < pack.size() - 1; ++idx) {
     CHECK(pack[idx].is_tensor());
     op_func_arg_tensors->push_back(
         pack[idx].operator ir::Expr().as_tensor_ref());
   }
   return funcs;
 }

 ir::Expr OpLowerer::DoOpSchedule(
     std::shared_ptr<hlir::framework::OpImpl> op_impl,
     const std::vector<ir::Tensor>& op_func_arg_tensors,
     const std::vector<ir::LoweredFunc>& lowered_funcs) {
   VLOG(4) << "Do op schedule";
   std::vector<common::CINNValue> schedule_inputs;
   // 1.Collect tensors
   for (const ir::Tensor& op_func_arg_tensor : op_func_arg_tensors) {
     schedule_inputs.push_back(common::CINNValue(op_func_arg_tensor));
   }
   // 2.Collect bodies to be scheduled
   for (const ir::LoweredFunc& func : lowered_funcs) {
     schedule_inputs.push_back(common::CINNValue(func->body));
   }
   // 3.Do schedule on AST
   common::CINNValuePack expr_pack =
       op_impl->fschedule(common::CINNValuePack{schedule_inputs});
   VLOG(4) << "After op schedule: " << expr_pack[0].operator ir::Expr();
   return expr_pack[0].operator ir::Expr();
 }

 // group schedule
-void OpLowerer::IRSchedule(
+ir::Expr OpLowerer::DoGroupSchedule(
     ir::IRSchedule& ir_sch,
     const GroupPtr& group,
     const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
@@ -698,6 +506,7 @@ void OpLowerer::IRSchedule(
           << ", ir is:\n"
           << ir_sch.GetModule().GetExprs().at(0);
   // if node is horizontal with reduce or node is reduce, loop assign
   // master.
   auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
   if (op_pattern_dict[node->op()] == framework::kElementWise) {
@@ -788,6 +597,7 @@ void OpLowerer::IRSchedule(
       ir_sch, group, nodes_inline, nodes_set, this->shape_dict_, tensor_map);
   VLOG(4) << "After IRSchedule, ir is: \n"
           << ir_sch.GetModule().GetExprs().at(0);
+  return ir_sch.GetModule().GetExprs().at(0);
 }

 }  // namespace framework
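To make the control flow of the new path concrete, here is a minimal, self-contained sketch of the same staging: per-op lowering, optional per-op schedule gated by a predicate, optional group-level schedule, then post-processing. Every type and name below (ToyOp, LowerGroupSketch and so on) is a toy stand-in invented for the example, not a CINN API; only the flag logic mirrors the diff, in particular do_op_schedule = apply_group_schedule || apply_op_schedule.

  #include <functional>
  #include <iostream>
  #include <string>
  #include <vector>

  // Toy pipeline mirroring LowerGroup's staging; all names are illustrative.
  struct ToyOp { std::string name; bool is_reduce; };

  std::vector<std::string> LowerGroupSketch(const std::vector<ToyOp>& ops,
                                            bool apply_op_schedule,
                                            bool apply_group_schedule,
                                            std::function<bool(const ToyOp&)> determine) {
    // Mirrors `bool do_op_schedule = apply_group_schedule || apply_op_schedule;`.
    bool do_op_schedule = apply_group_schedule || apply_op_schedule;

    // 1. Lower each op, optionally scheduling it when the predicate says so.
    std::vector<std::string> bodies;
    for (const auto& op : ops) {
      std::string body = "lowered(" + op.name + ")";
      if (do_op_schedule && determine(op)) body = "op_scheduled(" + body + ")";
      bodies.push_back(body);
    }

    // 2. Merge the bodies and optionally apply the group-level schedule.
    std::string merged;
    for (const auto& b : bodies) merged += b + ";";
    if (apply_group_schedule) merged = "group_scheduled(" + merged + ")";

    // 3. Post-process into the final "function" (argument/buffer handling elided).
    return {"func{" + merged + "}"};
  }

  int main() {
    std::vector<ToyOp> ops{{"add", false}, {"reduce_sum", true}};
    // Equivalent in spirit to Lower(group, /*apply_op_schedule=*/false,
    //                               /*apply_group_schedule=*/false).
    for (const auto& f : LowerGroupSketch(
             ops, false, false, [](const ToyOp& op) { return op.is_reduce; })) {
      std::cout << f << "\n";
    }
  }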
paddle/cinn/hlir/framework/op_lowering.h

@@ -39,46 +39,132 @@ using GroupPtr = std::shared_ptr<Graph::Group>;
 using common::Target;

 class OpLowerer;

-typedef std::vector<Expr> (OpLowerer::*IRComputeFunction)(
-    poly::StageMap&,
-    std::vector<ir::Tensor>&,
-    std::unordered_map<std::string, ir::Tensor>&,
-    const GroupPtr&,
-    const GroupPtr&,
-    bool);
+typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*);

 class OpLowerer {
  public:
   OpLowerer(const absl::flat_hash_map<std::string, Type>&,
             const absl::flat_hash_map<std::string, shape_t>&,
             const Target&);

-  std::vector<ir::LoweredFunc> Lower(GroupPtr& group);  // NOLINT
-  std::vector<ir::LoweredFunc> LowerWithoutSchedule(GroupPtr& group);  // NOLINT
+  /**
+   * @brief Lower a group to CINN IR.
+   * @param group The group to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param apply_group_schedule Whether to schedule at group level.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> Lower(const GroupPtr& group,
+                                     bool apply_op_schedule = true,
+                                     bool apply_group_schedule = true);

  private:
-  std::vector<ir::LoweredFunc> IRLowerOp(IRComputeFunction, GroupPtr&);
-  std::vector<ir::LoweredFunc> IRLowerNonFusibleOp(GroupPtr&, bool);
-  std::vector<ir::LoweredFunc> IRLowerOpWithoutSchedule(IRComputeFunction,
-                                                        GroupPtr&);
-#define DEFINE_IR_COMPUTE(type)                                 \
-  std::vector<Expr> IR##type##Compute(                          \
-      poly::StageMap& stages,                                   \
-      std::vector<ir::Tensor>& func_args,                       \
-      std::unordered_map<std::string, ir::Tensor>& tensor_map,  \
-      const GroupPtr& group,                                    \
-      const GroupPtr& sub_group,                                \
-      bool apply_impl_schedule = false);
-
-  // compute and schedule
-  DEFINE_IR_COMPUTE(Elementwise);
-  DEFINE_IR_COMPUTE(Reduce);
-  DEFINE_IR_COMPUTE(OutEWiseFusable);
-
-  void IRSchedule(
+  /**
+   * @brief Lower a group to CINN IR.
+   * @param group The group to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param apply_group_schedule Whether to schedule at group level.
+   * @param schedule_determine_func Function used to determine which Ops to
+   * schedule.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> LowerGroup(
+      const GroupPtr& group,
+      bool apply_op_schedule,
+      bool apply_group_schedule,
+      ScheduleDetermineFunction schedule_determine_func);
+
+  /**
+   * @brief Lower a group composed of CustomCall Op.
+   * @param group The group to be lowered.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> LowerCustomCall(const GroupPtr& group);
+
+  /**
+   * @brief Post processing, including preparing function args and temporary
+   * variables, applying low-level optimization passes, etc.
+   * @param group The group to be lowered.
+   * @param tensor_map All tensors used for calculating the group.
+   * @param done_op_schedule Mark whether the Op level schedule has been
+   * applied.
+   * @param ir_sch The IRSchedule object of group.
+   * @param group_func_arg_tensors Tensors used as the group function arguments.
+   * @return The lowered funcs after the post processing.
+   */
+  std::vector<ir::LoweredFunc> PostProcess(
+      const GroupPtr& group,
+      const std::unordered_map<std::string, ir::Tensor>& tensor_map,
+      bool done_op_schedule,
+      ir::IRSchedule* ir_sch,
+      std::vector<ir::Tensor>* group_func_arg_tensors);
+
+  /**
+   * @brief Lower an Op set to CINN IR.
+   * Compute, Lower and optional Schedule will be performed one by one
+   * for each Op.
+   * @param nodes The Op nodes to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param schedule_determine_func Function used to determine which Ops to
+   * schedule.
+   * @param group_func_arg_tensors Tensors used as the group function arguments.
+   * @param tensor_map All tensors used for calculating the group.
+   * @return The lowered func bodies of Op set.
+   */
+  std::vector<ir::Expr> LowerOps(
+      const std::vector<Node*>& nodes,
+      bool apply_op_schedule,
+      ScheduleDetermineFunction schedule_determine_func,
+      std::vector<ir::Tensor>* group_func_arg_tensors,
+      std::unordered_map<std::string, ir::Tensor>* tensor_map);
+
+  /**
+   * @brief Lower an Op to CINN IR. The Compute and Lower processes will be
+   * called sequentially.
+   * @param op_impl The Op implementation defining Compute and Schedule.
+   * @param node The Op node to be lowered.
+   * @param tensor_map All tensors used for calculating the group.
+   * @param op_func_arg_tensors Tensors used as the Op function arguments.
+   * @return The lowered func of the Op node.
+   */
+  std::vector<ir::LoweredFunc> DoOpLower(
+      std::shared_ptr<hlir::framework::OpImpl> op_impl,
+      Node* node,
+      std::unordered_map<std::string, ir::Tensor>* tensor_map,
+      std::vector<ir::Tensor>* op_func_arg_tensors);
+
+  /**
+   * @brief Apply schedule on an Op.
+   * @param op_impl The Op implementation defining Compute and Schedule.
+   * @param op_func_arg_tensors Tensors used as the Op function arguments.
+   * @param lowered_funcs The lowered funcs of an Op to be scheduled.
+   * @return The lowered func body after schedule of the Op.
+   */
+  ir::Expr DoOpSchedule(std::shared_ptr<hlir::framework::OpImpl> op_impl,
+                        const std::vector<ir::Tensor>& op_func_arg_tensors,
+                        const std::vector<ir::LoweredFunc>& lowered_funcs);
+
+  /**
+   * @brief Apply schedule on a group.
+   * @param ir_sch The IRSchedule containing the entire group's lowered func
+   * bodies.
+   * @param group The group to be scheduled.
+   * @param tensor_map All tensors used for calculating the group.
+   * @return The lowered func body after schedule of the group.
+   */
+  ir::Expr DoGroupSchedule(
       ir::IRSchedule& ir_sch,  // NOLINT
       const GroupPtr& group,
       const std::unordered_map<std::string, ir::Tensor>& tensor_map);
+
+  // Functions used to determine which Ops to schedule at op level, define a
+  // policy for each type of group.
+  inline bool ReduceScheduleDetermineFunction(Node* node);
+  inline bool ElementwiseScheduleDetermineFunction(Node* node);
+  inline bool NonFusibleScheduleDetermineFunction(Node* node);

+ private:
   Target target_;
   const absl::flat_hash_map<std::string, Type>& type_dict_;
   const absl::flat_hash_map<std::string, shape_t>& shape_dict_;
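For callers, a minimal usage sketch of the new interface (assuming a fused graph, dtype_dict, shape_dict and target already exist, as in the updated test_helper.cc above; this is illustrative, not additional code from the commit):

  // Fully scheduled lowering; the default arguments keep the old Lower(group) behaviour.
  hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target);
  auto scheduled_funcs = op_lowerer.Lower(graph->fusion_groups.front());

  // Unscheduled lowering, replacing the removed LowerWithoutSchedule().
  auto raw_funcs = op_lowerer.Lower(graph->fusion_groups.front(),
                                    /*apply_op_schedule = */ false,
                                    /*apply_group_schedule=*/ false);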
paddle/cinn/hlir/framework/op_lowering_util.cc

@@ -92,19 +92,19 @@ ir::Tensor GetTensor(
 std::vector<ir::Tensor> CollectInputTensor(
     const Node* node,
-    std::vector<ir::Tensor>& func_args,                        // NOLINT
-    std::unordered_map<std::string, ir::Tensor>& tensor_map,   // NOLINT
     const absl::flat_hash_map<std::string, Type>& type_dict,
-    const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
+    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
+    std::vector<ir::Tensor>* func_args,
+    std::unordered_map<std::string, ir::Tensor>* tensor_map) {
   std::vector<ir::Tensor> tensors;
   // get all input nodes
   for (auto& node_data : GetInputNodeData(node)) {
     CHECK(node_data);
     auto tensor = GetTensor(node_data, type_dict, shape_dict);
-    if (!tensor_map.count(node_data->id())) {
-      tensor_map[node_data->id()] = tensor;
+    if (!tensor_map->count(node_data->id())) {
+      (*tensor_map)[node_data->id()] = tensor;
       // record func input args
-      func_args.push_back(tensor);
+      func_args->push_back(tensor);
     }
     tensors.push_back(tensor);
   }
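CollectInputTensor's output parameters change from mutable references (suppressed with NOLINT) to pointers, presumably so that mutation is visible at the call site; the commit message does not state the motivation. A hedged call-site sketch of the adjustment, with the container names assumed:

  // Before: func_args and tensor_map were taken by non-const reference.
  //   auto tensors = CollectInputTensor(node, func_args, tensor_map, type_dict, shape_dict);
  // After: dictionaries first, then explicit pointer out-parameters,
  // matching the call made from OpLowerer::LowerOps in op_lowering.cc above.
  auto tensors = CollectInputTensor(
      node, type_dict, shape_dict, &func_args, &tensor_map);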
paddle/cinn/hlir/framework/op_lowering_util.h

@@ -31,10 +31,10 @@ ir::Tensor GetTensor(
 std::vector<ir::Tensor> CollectInputTensor(
     const Node* node,
-    std::vector<ir::Tensor>& func_args,                        // NOLINT
-    std::unordered_map<std::string, ir::Tensor>& tensor_map,   // NOLINT
     const absl::flat_hash_map<std::string, Type>& type_dict,
-    const absl::flat_hash_map<std::string, shape_t>& shape_dict);
+    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
+    std::vector<ir::Tensor>* func_args,
+    std::unordered_map<std::string, ir::Tensor>* tensor_map);

 std::unordered_map<Node*, Node*> BuildVirtualConsumer(
     const GroupPtr& group,