PaddlePaddle / Paddle
Commit 3559252a (unverified)
Authored on Jul 13, 2023 by BiynXu; committed by GitHub on Jul 13, 2023.
[CINN] comb the op lowering code (#54982)
* [CINN] comb the op lowering code
* [CINN] format code of OpLower
Parent: 27cc0df5
Showing 9 changed files with 466 additions and 571 deletions (+466 −571).
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc         +3   −2
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc  +59  −62
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc              +4   −6
paddle/cinn/auto_schedule/task/tune_task.cc                                      +2   −1
paddle/cinn/auto_schedule/tests/performance_comparison_test.cc                   +3   −1
paddle/cinn/hlir/framework/op_lowering.cc                                        +272 −462
paddle/cinn/hlir/framework/op_lowering.h                                         +114 −28
paddle/cinn/hlir/framework/op_lowering_util.cc                                   +6   −6
paddle/cinn/hlir/framework/op_lowering_util.h                                    +3   −3
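Taken together, the hunks below are one mechanical migration: the two public entry points Lower(group) and LowerWithoutSchedule(group) are folded into a single Lower(group, apply_op_schedule, apply_group_schedule). A minimal compilable sketch of the call-site change; LoweredFunc, Group and the stub OpLowerer are stand-ins, not the real CINN classes:

#include <memory>
#include <vector>

// Sketch only: stand-in types, not the real CINN classes.
struct LoweredFunc {};
struct Group {};
using GroupPtr = std::shared_ptr<Group>;

struct OpLowerer {
  std::vector<LoweredFunc> Lower(const GroupPtr& /*group*/,
                                 bool /*apply_op_schedule*/ = true,
                                 bool /*apply_group_schedule*/ = true) {
    return {};
  }
};

int main() {
  OpLowerer op_lowerer;
  GroupPtr group = std::make_shared<Group>();

  // Before: op_lowerer.LowerWithoutSchedule(group);
  // After: the same intent, expressed through the unified entry point.
  std::vector<LoweredFunc> funcs =
      op_lowerer.Lower(group,
                       /*apply_op_schedule=*/false,
                       /*apply_group_schedule=*/false);
  (void)funcs;
  return 0;
}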
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc
@@ -71,7 +71,6 @@ TEST(AutoInline, SingleLoopInline) {
                            nullptr,
                            target,
                            true);
   VLOG(6) << "Expr after lowering:";
   VLOG(6) << funcs[0]->body;
@@ -170,7 +169,9 @@ TEST(AutoInline, AddReluInline) {
   EXPECT_EQ(graph->fusion_groups.size(), 1UL);
   std::vector<ir::LoweredFunc> funcs =
-      op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
+      op_lowerer->Lower(graph->fusion_groups[0],
+                        /*apply_op_schedule = */ false,
+                        /*apply_group_schedule=*/false);
   VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
@@ -388,9 +388,9 @@ TEST_F(TestMultiLevelTiling, ReduceSum) {
 TEST_F(TestMultiLevelTiling, Pool2d) {
   default_input_names = {"input"};
-  default_output_names = {"var_0"};
-  std::vector<int32_t> input_shape{2, 8, 16, 16};
-  std::vector<int32_t> output_shape{2, 8, 8, 8};
+  default_output_names = {"var_0", "pad_temp_0"};
+  std::vector<std::vector<int32_t>> input_shapes{{2, 8, 16, 16}};
+  std::vector<std::vector<int32_t>> output_shapes{{2, 8, 8, 8}, {2, 8, 18, 18}};
   std::string pooling_type = "max";
   std::vector<int> ksize{3, 3};
   std::vector<int> strides{2, 2};
@@ -402,7 +402,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
   bool adaptive = false;
   std::string padding_algorithm = "EXPLICIT";
   frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build(
-      {{"input", input_shape}},
+      {{"input", input_shapes[0]}},
       {{"pool_type", pooling_type},
        {"kernel_size", ksize},
        {"stride_size", strides},
@@ -440,85 +440,82 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
@@ -526,21 +523,21 @@ Expr 1 {
@@ -553,7 +550,7 @@ Expr 1 {
[Expected-IR string updated wholesale across these three hunks; the
side-by-side rendering is too garbled to attribute line by line. What is
recoverable: the dump still contains the pad_temp_0 padding loop nest and,
under the blockIdx.x/threadIdx.x bindings, the var_0__reduce_init,
pad_temp_0_shared_temp_buffer, var_0_local_temp_buffer and var_0 schedule
blocks; each ScheduleBlock statement body gains an explicit brace level; and
the dump now closes with "} // end Expr 0" where it previously closed with
"} // end Expr 1".]
 )ROC";
   ASSERT_EQ(ir, expected_ir);
@@ -569,8 +566,8 @@ Expr 1 {
           pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
       default_input_names,
       default_output_names,
-      {input_shape},
-      {output_shape},
+      input_shapes,
+      output_shapes,
       target_);
 }
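The Pool2d expectations above now track two outputs, var_0 and the padding buffer pad_temp_0. A quick arithmetic check of the new shapes, assuming symmetric spatial padding of 1 (inferred from the {2, 8, 18, 18} shape; the diff does not state the padding directly):

#include <cassert>

int main() {
  const int in = 16, pad = 1, ksize = 3, stride = 2;

  // pad_temp_0 holds the padded input: 16 + 2 * 1 = 18 per spatial dim,
  // matching output_shapes' second entry {2, 8, 18, 18}.
  const int padded = in + 2 * pad;
  assert(padded == 18);

  // var_0 holds the pooling result: (18 - 3) / 2 + 1 = 8 per spatial dim,
  // matching output_shapes' first entry {2, 8, 8, 8}.
  const int out = (padded - ksize) / stride + 1;
  assert(out == 8);
  return 0;
}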
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc
@@ -63,12 +63,10 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
           absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
           "infershape");
   hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
-  if (apply_manual_schedule) {
-    lowered_funcs_ =
-        op_lowerer.Lower(graph->fusion_groups.front());
-  } else {
-    lowered_funcs_ =
-        op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
-  }
+  lowered_funcs_ =
+      op_lowerer.Lower(graph->fusion_groups.front(),
+                       /*apply_op_schedule = */ apply_manual_schedule,
+                       /*apply_group_schedule = */ apply_manual_schedule);
   CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
   std::vector<Expr> bodys;
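The if/else removed above existed only to pick between two near-identical entry points; with the unified signature, the caller forwards one boolean to both schedule levels. A compilable sketch of the pattern (stand-in types, not the real CINN classes):

#include <vector>

// Sketch only: stand-in types, not the real CINN classes.
struct LoweredFunc {};

struct OpLowerer {
  std::vector<LoweredFunc> Lower(bool /*apply_op_schedule*/,
                                 bool /*apply_group_schedule*/) {
    return std::vector<LoweredFunc>(1);  // pretend one func was lowered
  }
};

// Before: the caller branched between two near-identical entry points:
//   if (apply_manual_schedule) { funcs = lowerer.Lower(g); }
//   else { funcs = lowerer.LowerWithoutSchedule(g); }
// After: one call; the flag drives both schedule levels, as in
// MakeIRSchedule above.
std::vector<LoweredFunc> MakeFuncs(OpLowerer& lowerer,
                                   bool apply_manual_schedule) {
  return lowerer.Lower(/*apply_op_schedule=*/apply_manual_schedule,
                       /*apply_group_schedule=*/apply_manual_schedule);
}

int main() {
  OpLowerer lowerer;
  std::vector<LoweredFunc> funcs = MakeFuncs(lowerer, true);
  return funcs.empty() ? 1 : 0;
}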
paddle/cinn/auto_schedule/task/tune_task.cc
@@ -39,7 +39,8 @@ void TuneTask::Initialize(
   op_lowerer = lower_handler;
   // Set lowered_funcs and analyze output names.
-  this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph);
+  this->lowered_funcs = op_lowerer->Lower(
+      subgraph, /*apply_op_schedule = */ false, /*apply_group_schedule=*/false);
   this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs);
   this->serialized_key = SerializeToString(shape_dict, dtype_dict);
 }
paddle/cinn/auto_schedule/tests/performance_comparison_test.cc
@@ -157,7 +157,9 @@ class PerformanceTester : public ::testing::Test {
     for (auto group : graph->fusion_groups) {
       compile_options.lowered_funcs.push_back(
-          op_lowerer->LowerWithoutSchedule(group));
+          op_lowerer->Lower(group,
+                            /*apply_op_schedule = */ false,
+                            /*apply_group_schedule=*/false));
     }
     VLOG(3) << "===========================No Schedule LoweredFunc "
paddle/cinn/hlir/framework/op_lowering.cc
Diff collapsed (+272 −462; not expanded here).
paddle/cinn/hlir/framework/op_lowering.h
@@ -39,46 +39,132 @@ using GroupPtr = std::shared_ptr<Graph::Group>;
 using common::Target;
 
 class OpLowerer;
 
-typedef std::vector<Expr> (OpLowerer::*IRComputeFunction)(
-    poly::StageMap&,
-    std::vector<ir::Tensor>&,
-    std::unordered_map<std::string, ir::Tensor>&,
-    const GroupPtr&,
-    const GroupPtr&,
-    bool);
+typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*);
 
 class OpLowerer {
  public:
   OpLowerer(const absl::flat_hash_map<std::string, Type>&,
             const absl::flat_hash_map<std::string, shape_t>&,
             const Target&);
-  std::vector<ir::LoweredFunc> Lower(GroupPtr& group);  // NOLINT
-  std::vector<ir::LoweredFunc> LowerWithoutSchedule(GroupPtr& group);  // NOLINT
+
+  /**
+   * @brief Lower a group to CINN IR.
+   * @param group The group to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param apply_group_schedule Whether to schedule at group level.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> Lower(const GroupPtr& group,
+                                     bool apply_op_schedule = true,
+                                     bool apply_group_schedule = true);
 
  private:
-  std::vector<ir::LoweredFunc> IRLowerOp(IRComputeFunction, GroupPtr&);
-  std::vector<ir::LoweredFunc> IRLowerNonFusibleOp(GroupPtr&, bool);
-  std::vector<ir::LoweredFunc> IRLowerOpWithoutSchedule(IRComputeFunction,
-                                                        GroupPtr&);
-
-#define DEFINE_IR_COMPUTE(type)                                  \
-  std::vector<Expr> IR##type##Compute(                           \
-      poly::StageMap& stages,                                    \
-      std::vector<ir::Tensor>& func_args,                        \
-      std::unordered_map<std::string, ir::Tensor>& tensor_map,   \
-      const GroupPtr& group,                                     \
-      const GroupPtr& sub_group,                                 \
-      bool apply_impl_schedule = false);
-
-  // compute and schedule
-  DEFINE_IR_COMPUTE(Elementwise);
-  DEFINE_IR_COMPUTE(Reduce);
-  DEFINE_IR_COMPUTE(OutEWiseFusable);
-
-  void IRSchedule(
+  /**
+   * @brief Lower a group to CINN IR.
+   * @param group The group to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param apply_group_schedule Whether to schedule at group level.
+   * @param schedule_determine_func Function used to determine which Ops to
+   * schedule.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> LowerGroup(
+      const GroupPtr& group,
+      bool apply_op_schedule,
+      bool apply_group_schedule,
+      ScheduleDetermineFunction schedule_determine_func);
+
+  /**
+   * @brief Lower a group composed of CustomCall Op.
+   * @param group The group to be lowered.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> LowerCustomCall(const GroupPtr& group);
+
+  /**
+   * @brief Post processing, including preparing function args and temporary
+   * variables, applying low-level optimization passes, etc.
+   * @param group The group to be lowered.
+   * @param tensor_map All tensors used for calculating the group.
+   * @param done_op_schedule Mark whether the Op level schedule has been
+   * applied.
+   * @param ir_sch The IRSchedule object of group.
+   * @param group_func_arg_tensors Tensors used as the group function arguments.
+   * @return The lowered funcs after the post processing.
+   */
+  std::vector<ir::LoweredFunc> PostProcess(
+      const GroupPtr& group,
+      const std::unordered_map<std::string, ir::Tensor>& tensor_map,
+      bool done_op_schedule,
+      ir::IRSchedule* ir_sch,
+      std::vector<ir::Tensor>* group_func_arg_tensors);
+
+  /**
+   * @brief Lower an Op set to CINN IR.
+   * Compute, Lower and optional Schedule will be performed one by one
+   * for each Op.
+   * @param nodes The Op nodes to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param schedule_determine_func Function used to determine which Ops to
+   * schedule.
+   * @param group_func_arg_tensors Tensors used as the group function arguments.
+   * @param tensor_map All tensors used for calculating the group.
+   * @return The lowered func bodies of Op set.
+   */
+  std::vector<ir::Expr> LowerOps(
+      const std::vector<Node*>& nodes,
+      bool apply_op_schedule,
+      ScheduleDetermineFunction schedule_determine_func,
+      std::vector<ir::Tensor>* group_func_arg_tensors,
+      std::unordered_map<std::string, ir::Tensor>* tensor_map);
+
+  /**
+   * @brief Lower an Op to CINN IR. The Compute and Lower processes will be
+   * called sequentially.
+   * @param op_impl The Op implementation defining Compute and Schedule.
+   * @param node The Op node to be lowered.
+   * @param tensor_map All tensors used for calculating the group.
+   * @param op_func_arg_tensors Tensors used as the Op function arguments.
+   * @return The lowered func of the Op node.
+   */
+  std::vector<ir::LoweredFunc> DoOpLower(
+      std::shared_ptr<hlir::framework::OpImpl> op_impl,
+      Node* node,
+      std::unordered_map<std::string, ir::Tensor>* tensor_map,
+      std::vector<ir::Tensor>* op_func_arg_tensors);
+
+  /**
+   * @brief Apply schedule on an Op.
+   * @param op_impl The Op implementation defining Compute and Schedule.
+   * @param op_func_arg_tensors Tensors used as the Op function arguments.
+   * @param lowered_funcs The lowered funcs of an Op to be scheduled.
+   * @return The lowered func body after schedule of the Op.
+   */
+  ir::Expr DoOpSchedule(std::shared_ptr<hlir::framework::OpImpl> op_impl,
+                        const std::vector<ir::Tensor>& op_func_arg_tensors,
+                        const std::vector<ir::LoweredFunc>& lowered_funcs);
+
+  /**
+   * @brief Apply schedule on a group.
+   * @param ir_sch The IRSchedule containing the entire group's lowered func
+   * bodies.
+   * @param group The group to be scheduled.
+   * @param tensor_map All tensors used for calculating the group.
+   * @return The lowered func body after schedule of the group.
+   */
+  ir::Expr DoGroupSchedule(
       ir::IRSchedule& ir_sch,  // NOLINT
       const GroupPtr& group,
       const std::unordered_map<std::string, ir::Tensor>& tensor_map);
 
+  // Functions used to determine which Ops to schedule at op level, define a
+  // policy for each type of group.
+  inline bool ReduceScheduleDetermineFunction(Node* node);
+  inline bool ElementwiseScheduleDetermineFunction(Node* node);
+  inline bool NonFusibleScheduleDetermineFunction(Node* node);
+
+ private:
   Target target_;
   const absl::flat_hash_map<std::string, Type>& type_dict_;
   const absl::flat_hash_map<std::string, shape_t>& shape_dict_;
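The new header reads as a linear pipeline. A control-flow sketch of what the declarations imply; the real logic lives in op_lowering.cc (collapsed above), so the dispatch below, the is_custom_call field, and all stub types are illustrative assumptions:

#include <vector>

// Sketch only: control flow implied by the declarations above, with
// stand-in types rather than the real CINN classes.
struct Expr {};
struct LoweredFunc {};
struct Group {
  bool is_custom_call = false;  // assumed flag, for illustration
};

struct OpLowererSketch {
  using ScheduleDetermineFunction = bool (OpLowererSketch::*)(int node);

  std::vector<LoweredFunc> Lower(const Group& group,
                                 bool apply_op_schedule = true,
                                 bool apply_group_schedule = true) {
    // CustomCall groups take a dedicated path; everything else goes through
    // the common pipeline with a per-group scheduling policy.
    if (group.is_custom_call) {
      return LowerCustomCall(group);
    }
    return LowerGroup(group, apply_op_schedule, apply_group_schedule,
                      &OpLowererSketch::AlwaysSchedule);
  }

  std::vector<LoweredFunc> LowerGroup(const Group& group,
                                      bool apply_op_schedule,
                                      bool apply_group_schedule,
                                      ScheduleDetermineFunction determine) {
    // 1) LowerOps: Compute + Lower (and DoOpSchedule where `determine`
    //    says so) for each Op in the group.
    std::vector<Expr> func_bodies = LowerOps(group, apply_op_schedule, determine);
    // 2) DoGroupSchedule: schedule the collected bodies as one unit.
    if (apply_group_schedule) {
      DoGroupSchedule(&func_bodies);
    }
    // 3) PostProcess: prepare function args and temp buffers, run
    //    low-level optimization passes.
    return PostProcess(group, func_bodies);
  }

  std::vector<LoweredFunc> LowerCustomCall(const Group&) { return {}; }
  std::vector<Expr> LowerOps(const Group&, bool, ScheduleDetermineFunction) {
    return {};
  }
  void DoGroupSchedule(std::vector<Expr>*) {}
  std::vector<LoweredFunc> PostProcess(const Group&, const std::vector<Expr>&) {
    return {};
  }
  bool AlwaysSchedule(int) { return true; }
};

int main() {
  OpLowererSketch lowerer;
  Group group;
  std::vector<LoweredFunc> funcs = lowerer.Lower(group);
  return funcs.empty() ? 0 : 1;
}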
paddle/cinn/hlir/framework/op_lowering_util.cc
@@ -92,19 +92,19 @@ ir::Tensor GetTensor(
 std::vector<ir::Tensor> CollectInputTensor(
     const Node* node,
-    std::vector<ir::Tensor>& func_args,                       // NOLINT
-    std::unordered_map<std::string, ir::Tensor>& tensor_map,  // NOLINT
     const absl::flat_hash_map<std::string, Type>& type_dict,
-    const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
+    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
+    std::vector<ir::Tensor>* func_args,
+    std::unordered_map<std::string, ir::Tensor>* tensor_map) {
   std::vector<ir::Tensor> tensors;
   // get all input nodes
   for (auto& node_data : GetInputNodeData(node)) {
     CHECK(node_data);
     auto tensor = GetTensor(node_data, type_dict, shape_dict);
-    if (!tensor_map.count(node_data->id())) {
-      tensor_map[node_data->id()] = tensor;
+    if (!tensor_map->count(node_data->id())) {
+      (*tensor_map)[node_data->id()] = tensor;
       // record func input args
-      func_args.push_back(tensor);
+      func_args->push_back(tensor);
     }
     tensors.push_back(tensor);
   }
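The signature change above swaps mutable reference out-parameters for pointers, so mutation is visible at the call site and const inputs come first. A minimal sketch of the pattern with a simplified, hypothetical CollectInputs (Tensor and the body are stand-ins for the real code):

#include <string>
#include <unordered_map>
#include <vector>

// Sketch only: simplified stand-in for CollectInputTensor's new shape.
struct Tensor {
  std::string id;
};

std::vector<Tensor> CollectInputs(
    const std::vector<Tensor>& inputs,
    std::vector<Tensor>* func_args,                        // out-param
    std::unordered_map<std::string, Tensor>* tensor_map) {  // out-param
  std::vector<Tensor> tensors;
  for (const auto& t : inputs) {
    if (!tensor_map->count(t.id)) {
      (*tensor_map)[t.id] = t;  // was: tensor_map[t.id] = t;
      func_args->push_back(t);  // was: func_args.push_back(t);
    }
    tensors.push_back(t);
  }
  return tensors;
}

int main() {
  std::vector<Tensor> func_args;
  std::unordered_map<std::string, Tensor> tensor_map;
  // The `&` at the call site now flags which arguments get mutated.
  CollectInputs({{"a"}, {"b"}, {"a"}}, &func_args, &tensor_map);
  return func_args.size() == 2 ? 0 : 1;  // duplicate "a" recorded once
}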
paddle/cinn/hlir/framework/op_lowering_util.h
@@ -31,10 +31,10 @@ ir::Tensor GetTensor(
 std::vector<ir::Tensor> CollectInputTensor(
     const Node* node,
-    std::vector<ir::Tensor>& func_args,                       // NOLINT
-    std::unordered_map<std::string, ir::Tensor>& tensor_map,  // NOLINT
     const absl::flat_hash_map<std::string, Type>& type_dict,
-    const absl::flat_hash_map<std::string, shape_t>& shape_dict);
+    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
+    std::vector<ir::Tensor>* func_args,
+    std::unordered_map<std::string, ir::Tensor>* tensor_map);
 
 std::unordered_map<Node*, Node*> BuildVirtualConsumer(
     const GroupPtr& group,