PaddlePaddle / Paddle · commit 70183c4b
Unverified commit 70183c4b, authored Jul 17, 2023 by Huihuang Zheng; committed via GitHub on Jul 17, 2023.
Remove Old Schedules in Ops (#55391)
Remove old schedules.
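
The change applies one mechanical cleanup across all fifteen files: the cinn_ir_schedule gflag is gone, so every `if (FLAGS_cinn_ir_schedule)` guard is dropped, the IR-schedule branch becomes the only code path, and the legacy StageMap-based else branches are deleted outright. A minimal sketch of the recurring compute-side pattern, using placeholder names assembled from the hunks below rather than copied from any single file:

    // Before: the output tensor name falls back to a generated name unless
    // the flag selects the IR-schedule calling convention.
    std::string tensor_name = UniqName(op_name + "_Out");
    if (FLAGS_cinn_ir_schedule) {
      CHECK(pack_args.back().is_string());
      tensor_name = pack_args.back().operator std::string();
    }

    // After: the IR-schedule convention is assumed unconditionally, so the
    // output name must always arrive as the last pack argument.
    CHECK(pack_args.back().is_string());
    std::string tensor_name = pack_args.back().operator std::string();

On the schedule side, the same deletion removes each else branch that dispatched to the old poly::StageMap schedules (CudaScheduleInjective, MatmulScheduleCUDA, and so on), leaving only the ir::IRSchedule-based path.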
Parent: db1f2c42
Showing 15 changed files with 729 additions and 1,519 deletions (+729 −1519).
paddle/cinn/hlir/op/broadcast.cc                     +35   −43
paddle/cinn/hlir/op/contrib/gather_nd.cc             +27   −40
paddle/cinn/hlir/op/contrib/logical_right_shift.cc   +2    −6
paddle/cinn/hlir/op/contrib/lookup_table.cc          +3    −5
paddle/cinn/hlir/op/contrib/one_hot.cc               +2    −6
paddle/cinn/hlir/op/contrib/reciprocal.cc            +5    −11
paddle/cinn/hlir/op/contrib/resize.cc                +2    −5
paddle/cinn/hlir/op/contrib/sort.cc                  +67   −90
paddle/cinn/hlir/op/elementwise.cc                   +27   −57
paddle/cinn/hlir/op/nn.cc                            +268  −647
paddle/cinn/hlir/op/op_broadcast_test.cc             +34   −89
paddle/cinn/hlir/op/op_util.cc                       +38   −80
paddle/cinn/hlir/op/reduction.cc                     +117  −194
paddle/cinn/hlir/op/transform.cc                     +90   −212
paddle/cinn/hlir/op/transform_test.cc                +12   −34
paddle/cinn/hlir/op/broadcast.cc

@@ -60,19 +60,16 @@ std::shared_ptr<OpStrategy> StrategyForBroadcast(
                         const ir::Tensor &B,
                         const std::string &output_name,
                         const Expr &axis)) {
-  framework::CINNCompute binary_compute([=](lang::Args args,
-                                            lang::RetValue *ret) {
+  framework::CINNCompute binary_compute(
+      [=](lang::Args args, lang::RetValue *ret) {
     CHECK(!args.empty()) << "The input argument of " << op_name
                          << " compute is empty! Please check.";
     CINNValuePack pack_args = args[0];
     CHECK_GE(pack_args.size(), 2U)
         << "at least 2 input tensors for " << op_name << " compute";
-    std::string tensor_name = UniqName(op_name + "_Out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_GE(pack_args.size(), 3U) << op_name << " 's input is not enough!";
     CHECK(pack_args[2].is_string());
-      tensor_name = pack_args[2].operator std::string();
-    }
+    std::string tensor_name = pack_args[2].operator std::string();
     Expr A_expr = pack_args[0];
     Expr B_expr = pack_args[1];
     CHECK(A_expr.as_tensor());

@@ -198,12 +195,10 @@ std::shared_ptr<OpStrategy> StrategyForBroadcastTo(
     CINNValuePack pack_args = args[0];
     CHECK(!pack_args.empty())
         << "The input tensors of broadcast_to compute is empty! Please check.";
-    std::string tensor_name = UniqName("broadcast_to_Out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_GE(pack_args.size(), 2U);
     CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A_expr = pack_args[0];
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();

@@ -323,12 +318,9 @@ std::shared_ptr<OpStrategy> StrategyForIsClose(
     CINNValuePack pack_args = args[0];
     int input_size = pack_args.size();
-    std::string tensor_name = UniqName("IsClose_output");
-    if (FLAGS_cinn_ir_schedule) {
     // the last pack argument is the output tensor name
-      tensor_name = pack_args.back().operator std::string();
+    std::string tensor_name = pack_args.back().operator std::string();
     --input_size;
-    }
     CHECK_EQ(input_size, 2)
         << "The input number of isclose should be 2, but here "
         << input_size << "! Please check.";
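
A note on the call form that recurs throughout these hunks: the pack values expose a conversion operator to std::string, and the code invokes it explicitly as `value.operator std::string()` rather than relying on an implicit conversion. A self-contained illustration of that C++ idiom (the `Value` type below is a stand-in for illustration, not the real CINNValue):

    #include <iostream>
    #include <string>

    // Stand-in for a variant-like value type with a string view of itself.
    struct Value {
      std::string payload;
      // Conversion operator: lets a Value be turned into a std::string.
      operator std::string() const { return payload; }
    };

    int main() {
      Value v{"broadcast_Out"};
      // Explicit call of the conversion operator, as in the diffs above.
      std::string name = v.operator std::string();
      std::cout << name << "\n";  // prints: broadcast_Out
      return 0;
    }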
paddle/cinn/hlir/op/contrib/gather_nd.cc

@@ -114,11 +114,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
     VLOG(3) << "x shape: " << utils::Join(tensor_x->shape, ", ")
             << ", index shape: " << utils::Join(tensor_index->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("GatherNd_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    std::string tensor_name = pack_args[2].operator std::string();
     ir::Tensor out = GatherNd(tensor_x, tensor_index, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);

@@ -131,7 +128,6 @@ std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
   framework::CINNSchedule gather_nd_schedule([=](lang::Args args,
                                                  lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
     CHECK(!args.empty()) << "The input argument of gather_nd_schedule is "
                             "empty! Please check.\n";
     common::CINNValuePack arg_pack = args[0];

@@ -154,21 +150,12 @@ std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
     if (target.arch == Target::Arch::NVGPU) {
       pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
     } else if (target.arch == Target::Arch::X86) {
       pe::IRScheduleInjectiveCPU(
           ir_sch, output_shapes.front(), target, true);
     }
     }
     std::vector<common::CINNValue> res{
         common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of gather_nd_schedule is "
-                              "empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
paddle/cinn/hlir/op/contrib/logical_right_shift.cc

@@ -105,12 +105,8 @@ std::shared_ptr<OpStrategy> StrategyForLogicalRightShift(
     ir::Tensor A = A_expr.as_tensor_ref();
     ir::Tensor B = B_expr.as_tensor_ref();
-    std::string tensor_name = UniqName("T_LogicalRightShift_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    std::string tensor_name = pack_args[2].operator std::string();
     auto out = LogicalRightShift(A, B, target, tensor_name);
     auto stages = CreateStages({out});
paddle/cinn/hlir/op/contrib/lookup_table.cc

@@ -106,11 +106,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForLookupTable(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", B shape: " << utils::Join(tensor_B->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("LookupTable_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    std::string tensor_name = pack_args[2].operator std::string();
     ir::Tensor out = LookupTable(tensor_A, tensor_B, padding_idx, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);
paddle/cinn/hlir/op/contrib/one_hot.cc (file mode 100755 → 100644)

@@ -194,12 +194,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForOneHot(
     ir::Tensor on_value = on_value_expr.as_tensor_ref();
     ir::Tensor off_value = off_value_expr.as_tensor_ref();
-    std::string tensor_name = common::UniqName("T_OneHot_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 4U);
-      tensor_name = pack_args[3].operator std::string();
-    }
+    std::string tensor_name = pack_args[3].operator std::string();
     ir::Tensor out = OneHot(indices,
                             on_value,
paddle/cinn/hlir/op/contrib/reciprocal.cc

@@ -94,13 +94,9 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
     CHECK(!pack_args.empty())
         << "at least one input tensor for " << op_name << " compute\n";
-    std::string tensor_name = UniqName("Reciprocal_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 2);
     CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A = pack_args[0];
     CHECK(A.as_tensor());

@@ -110,10 +106,8 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 2U);
     tensor_name = pack_args[1].operator std::string();
-    }
     ir::Tensor out = Reciprocal(tensor_A, tensor_name);
     std::vector<CINNValue> res;
paddle/cinn/hlir/op/contrib/resize.cc

@@ -207,12 +207,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForResize(
     auto tensor_A = A.as_tensor_ref();
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = common::UniqName("T_Resize_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = Resize(tensor_A, target, out_shape, mode, tensor_name);
paddle/cinn/hlir/op/contrib/sort.cc

@@ -178,12 +178,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    auto tensor_name = UniqName("Sort_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 2U);
     CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    std::string tensor_name = pack_args[1].operator std::string();
     std::vector<ir::Tensor> out =
         Sort(tensor_A, target, stages, axis, is_ascend, tensor_name);
     stages->InsertLazily(out[0]);

@@ -195,9 +192,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(
     *ret = CINNValuePack{res};
   });

-  framework::CINNSchedule sort_schedule([=](lang::Args args,
-                                            lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
+  framework::CINNSchedule sort_schedule(
+      [=](lang::Args args, lang::RetValue *ret) {
     CHECK(!args.empty())
         << "The input argument of sort_schedule is empty! Please check.\n";
     common::CINNValuePack arg_pack = args[0];

@@ -213,8 +209,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(
     ir::IRSchedule ir_sch(mod_expr);
     ir_sch.MergeExprs();
     auto blocks = ir_sch.GetAllBlocks();
-    // TODO(Shixiaowei02): remove external calls, do not use local variables,
-    // because the size will exceed the limit.
+    // TODO(Shixiaowei02): remove external calls, do not use local
+    // variables, because the size will exceed the limit.
     ir_sch.SetBuffer(blocks[0], "local");
     ir_sch.SetBuffer(blocks[1], "local");

@@ -223,19 +219,12 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(
                                       1,
                                       std::multiplies<int>());
     if (prod_size > 1 && target.arch == Target::Arch::X86) {
       pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
     }
     std::vector<common::CINNValue> res{
         common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty())
-          << "The input argument of sort_schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
   });
   auto strategy = std::make_shared<framework::OpStrategy>();

@@ -271,12 +260,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    auto tensor_name = UniqName("ArgSort_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 3U);
     CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    std::string tensor_name = pack_args[1].operator std::string();
     auto out = ArgSort(tensor_A, target, stages, axis, is_ascend, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out.at(0));

@@ -291,7 +277,6 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(
   framework::CINNSchedule argsort_schedule([=](lang::Args args,
                                                lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
     CHECK(!args.empty())
         << "The input argument of argsort_schedule is empty! Please check.\n";
     common::CINNValuePack arg_pack = args[0];

@@ -322,14 +307,6 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(
     std::vector<common::CINNValue> res{
         common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty())
-          << "The input argument of argsort_schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
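
The prod_size computation visible at the top of the @@ -223 hunk is the usual idiom for "number of elements in a shape": folding the extents with multiplication. In isolation (the shape values below are illustrative only):

    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      std::vector<int> output_shape = {2, 3, 4};
      // Fold the shape with multiplication, starting from 1.
      int prod_size = std::accumulate(
          output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
      std::cout << prod_size << "\n";  // prints: 24
      return 0;
    }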
paddle/cinn/hlir/op/elementwise.cc

@@ -67,12 +67,9 @@ std::shared_ptr<OpStrategy> StrategyForElementwise(
     CINNValuePack pack_args = args[0];
     CHECK_GE(pack_args.size(), 1U)
         << "1 input tensor for " << op_name << " compute";
-    std::string tensor_name = UniqName(op_name + "_Out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 2U);
     CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A_expr = pack_args[0];
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();

@@ -158,12 +155,9 @@ std::shared_ptr<OpStrategy> StrategyForScale(
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();
     ir::Tensor out;
-    std::string tensor_name = UniqName("Scale_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 2);
     CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    std::string tensor_name = pack_args[1].operator std::string();
     if (bias_after_scale) {
       out = Compute(

@@ -242,12 +236,9 @@ std::shared_ptr<OpStrategy> StrategyForConstScalar(
     auto scalar = GetScalarExpr(attrs.attr_store.at("value"));
     auto scalar_type = out_type.at(0);
     CINNValuePack pack_args = args[0];
-    std::string tensor_name = UniqName("const_scalar_Out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 1U);
     CHECK(pack_args[0].is_string());
-      tensor_name = pack_args[0].operator std::string();
-    }
+    std::string tensor_name = pack_args[0].operator std::string();
     auto out = lang::Compute(
         {Expr(1)},

@@ -371,12 +362,9 @@ std::shared_ptr<OpStrategy> StrategyForFillConstant(
     }
     CINNValuePack arg_pack = args[0];
-    std::string tensor_name = UniqName("fill_constant_Out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(arg_pack.size(), 1U);
     CHECK(arg_pack[0].is_string());
-      tensor_name = arg_pack[0].operator std::string();
-    }
+    std::string tensor_name = arg_pack[0].operator std::string();
     CHECK(!shape.empty()) << "shape attr is empty!";
     auto shape_exprs = ToCinnExprs(shape);
     auto out = lang::Compute(

@@ -458,12 +446,9 @@ std::shared_ptr<OpStrategy> StrategyForAssignValue(
     const auto &value = attrs.attr_store.at("values");
     CINNValuePack arg_pack = args[0];
-    std::string tensor_name = UniqName("T_assign_value_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(arg_pack.size(), 1U);
     CHECK(arg_pack[0].is_string());
-      tensor_name = arg_pack[0].operator std::string();
-    }
+    std::string tensor_name = arg_pack[0].operator std::string();
     absl::optional<ir::Tensor> out;
 #define EXPAND_VALUE_TO_TENSOR(TYPE) \

@@ -649,11 +634,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForSqueeze(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Squeeze_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Squeeze(tensor_A, axes, tensor_name);
     std::vector<CINNValue> res;

@@ -729,12 +711,9 @@ std::shared_ptr<OpStrategy> StrategyForExpandDims(
     Expr x = input_args[0];
     CHECK(x.as_tensor());
-    std::string tensor_name = UniqName("expand_dims_output");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(input_args.size(), 2U);
     CHECK(input_args[1].is_string());
-      tensor_name = input_args[1].operator std::string();
-    }
+    std::string tensor_name = input_args[1].operator std::string();
     auto out =
         pe::ExpandDims(x.as_tensor_ref(), axes, output_shapes[0], tensor_name);

@@ -809,12 +788,9 @@ std::shared_ptr<OpStrategy> StrategyForReshape(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Reshape_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 2);
     CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Reshape(tensor_A, output_shapes[0], tensor_name);
     std::vector<CINNValue> res;

@@ -901,11 +877,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForCast(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Cast_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Cast(tensor_A, out_type[0], tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);

@@ -953,11 +926,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForArange(
         << "The input argument of arange compute is empty! Please check.\n";
     CINNValuePack pack_args = args[0];
-    std::string tensor_name = common::UniqName("T_Arange_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(pack_args.size(), 1U);
-      tensor_name = pack_args[0].operator std::string();
-    }
+    std::string tensor_name = pack_args[0].operator std::string();
     auto out = pe::Arange(start, stop, step, dtype, tensor_name);
     std::vector<common::CINNValue> res;
paddle/cinn/hlir/op/nn.cc

(diff collapsed in the original page because of its size: +268 −647)
paddle/cinn/hlir/op/op_broadcast_test.cc

@@ -59,7 +59,6 @@ TEST(Operator, Operator_ElementWise_Add_Test0) {
   std::string func_name = "add1";
   Module::Builder builder("module0", target);
-  if (FLAGS_cinn_ir_schedule) {
   std::string out_name = "C";
   common::CINNValuePack cinn_input =
       common::CINNValuePack{{common::CINNValue(A),

@@ -76,23 +75,6 @@ TEST(Operator, Operator_ElementWise_Add_Test0) {
     builder.AddFunction(func);
   }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    ASSERT_EQ(rets.size(), 2UL);
-    rets = impl->fschedule(rets);
-    ASSERT_EQ(rets.size(), 2UL);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      inputs.push_back(temp.as_tensor_ref());
-    }
-    auto func = Lower("fn_" + func_name, rets.back(), inputs);
-    LOG(INFO) << "Test Strategy Codegen:\n" << func;
-    builder.AddFunction(func);
-  }
   auto jit = backends::ExecutionEngine::Create({});
   auto module = builder.Build();
   jit->Link(module);

@@ -160,7 +142,6 @@ TEST(Operator, Operator_ElementWise_Add_Test1) {
   std::string func_name = "add2";
   Module::Builder builder("module", target);
-  if (FLAGS_cinn_ir_schedule) {
   std::string out_name = "C";
   common::CINNValuePack cinn_input =
       common::CINNValuePack{{common::CINNValue(A),

@@ -176,22 +157,6 @@ TEST(Operator, Operator_ElementWise_Add_Test1) {
     LOG(INFO) << "Test Operator_ElementWise_Add_Test1's Strategy, func is :\n"
               << func;
   }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    ASSERT_EQ(rets.size(), 2UL);
-    rets = impl->fschedule(rets);
-    ASSERT_EQ(rets.size(), 2UL);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      inputs.push_back(temp.as_tensor_ref());
-    }
-    auto func = Lower("fn_" + func_name, rets.back(), inputs);
-    LOG(INFO) << "Test Strategy Codegen:\n" << func;
-    builder.AddFunction(func);
-  }
   backends::CodeGenCUDA_Dev codegen(target);

@@ -225,7 +190,6 @@ TEST(Operator, Operator_BroadcastTo) {
   std::string func_name = "broadcast_to";
-  if (FLAGS_cinn_ir_schedule) {
   std::string out_name = "C";
   common::CINNValuePack cinn_input = common::CINNValuePack{
       {common::CINNValue(B), common::CINNValue(out_name)}};

@@ -237,32 +201,13 @@ TEST(Operator, Operator_BroadcastTo) {
   for (auto func : funcs) {
     LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
   }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(B)}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    ASSERT_EQ(rets.size(), 2UL);
-    rets = impl->fschedule(rets);
-    ASSERT_EQ(rets.size(), 2UL);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      inputs.push_back(temp.as_tensor_ref());
-    }
-    auto func = Lower("func" + func_name, rets.back(), inputs);
-    LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
-  }
 }

 common::CINNValuePack GetComputeResult(
     const std::shared_ptr<OpImpl> &impl,
     std::vector<common::CINNValue> &cinn_inputs,  // NOLINT
     const std::string &output_name = "") {
-  if (FLAGS_cinn_ir_schedule) {
   cinn_inputs.emplace_back(output_name);
-  }
   return impl->fcompute(common::CINNValuePack{cinn_inputs});
 }
paddle/cinn/hlir/op/op_util.cc

@@ -21,8 +21,6 @@
 #include "paddle/cinn/hlir/pe/schedule.h"
 #include "paddle/cinn/ir/ir_schedule.h"

-DECLARE_bool(cinn_ir_schedule);
-
 namespace cinn {
 namespace hlir {

@@ -31,7 +29,6 @@ CINNSchedule GetElementwiseScheduleFunc(
     const Target &target,
     bool vectorizable) {
   return CINNSchedule([=](lang::Args args, lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
     CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
                             "empty! Please check.\n";
     common::CINNValuePack arg_pack = args[0];

@@ -50,25 +47,6 @@ CINNSchedule GetElementwiseScheduleFunc(
     std::vector<common::CINNValue> res{
         common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
-                              "empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      poly::StageMap stages = arg_pack[1];
-      CHECK(out.as_tensor());
-      CHECK_EQ(arg_pack.size(), 2UL);
-      if (target.arch == Target::Arch::NVGPU) {
-        pe::CudaScheduleInjective(
-            stages[out.as_tensor_ref()], output_shapes.front(), target);
-      } else if (target.arch == Target::Arch::X86) {
-        pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()],
-                                 output_shapes.front(),
-                                 target,
-                                 vectorizable);
-      }
-      *ret = arg_pack;
-    }
   });
 }

@@ -77,7 +55,6 @@ CINNSchedule GetInjectiveScheduleFunc(
     const Target &target,
     bool vectorizable) {
   return CINNSchedule([=](lang::Args args, lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
     CHECK(!args.empty()) << "The input argument of InjectiveSchedule is "
                             "empty! Please check.\n";
     common::CINNValuePack arg_pack = args[0];

@@ -102,25 +79,6 @@ CINNSchedule GetInjectiveScheduleFunc(
     std::vector<common::CINNValue> res{
         common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of InjectiveSchedule is "
-                              "empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      poly::StageMap stages = arg_pack[1];
-      CHECK(out.as_tensor());
-      CHECK_EQ(arg_pack.size(), 2UL);
-      if (target.arch == Target::Arch::NVGPU) {
-        pe::CudaScheduleInjective(
-            stages[out.as_tensor_ref()], output_shapes.front(), target);
-      } else if (target.arch == Target::Arch::X86) {
-        pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()],
-                                 output_shapes.front(),
-                                 target,
-                                 vectorizable);
-      }
-      *ret = arg_pack;
-    }
   });
 }
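
For orientation, the strategies in these files all follow the same wiring: a compute functor and a schedule functor are built as lambdas and registered on an OpStrategy. A toy, self-contained model of that wiring (every type below is a simplified stand-in for framework::CINNCompute, framework::CINNSchedule and framework::OpStrategy, not the real classes):

    #include <functional>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    using Args = std::vector<std::string>;
    using RetValue = std::string;
    using Functor = std::function<void(Args, RetValue *)>;

    // Minimal stand-in for OpStrategy: holds one compute/schedule pair.
    struct OpStrategy {
      Functor fcompute, fschedule;
      void AddImpl(Functor compute, Functor schedule) {
        fcompute = std::move(compute);
        fschedule = std::move(schedule);
      }
    };

    int main() {
      // Compute stage: the last argument carries the output tensor name,
      // mirroring the post-cleanup calling convention in the diffs.
      Functor compute = [](Args args, RetValue *ret) {
        *ret = "computed:" + args.back();
      };
      // Schedule stage: rewrites the compute result in place.
      Functor schedule = [](Args args, RetValue *ret) {
        *ret = "scheduled:" + args.front();
      };

      auto strategy = std::make_shared<OpStrategy>();
      strategy->AddImpl(compute, schedule);

      RetValue r;
      strategy->fcompute({"A", "B", "broadcast_Out"}, &r);
      strategy->fschedule({r}, &r);
      std::cout << r << "\n";  // prints: scheduled:computed:broadcast_Out
      return 0;
    }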
paddle/cinn/hlir/op/reduction.cc

@@ -29,8 +29,6 @@
 #include "paddle/cinn/ir/ir_schedule.h"
 #include "paddle/cinn/optim/ir_simplify.h"

-DECLARE_bool(cinn_ir_schedule);
-
 namespace cinn {
 namespace hlir {
 namespace op {

@@ -115,16 +113,10 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
     CHECK(!args.empty()) << "The input argument of " << op_name
                          << " compute is empty! Please check.";
     CINNValuePack arg_packs = args[0];
-    std::string tensor_name = UniqName(op_name + "_out");
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_EQ(arg_packs.size(), 2U)
         << "There should be 2 input args for " << op_name << " compute";
     CHECK(arg_packs[1].is_string());
-      tensor_name = arg_packs[1].operator std::string();
-    } else {
-      CHECK_EQ(arg_packs.size(), 1U)
-          << "There should be 1 input args for " << op_name << " compute";
-    }
+    std::string tensor_name = arg_packs[1].operator std::string();
     Expr x_expr = arg_packs[0];
     CHECK(x_expr.as_tensor());
     ir::Tensor x = x_expr.as_tensor_ref();

@@ -175,12 +167,10 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
                                              lang::RetValue *ret) {
     CHECK(!args.empty()) << "The input argument of " << op_name
                          << " schedule is empty! Please check.";
-    CINNValuePack arg_pack = args[0];
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_GE(arg_pack.size(), 2UL);
-      CHECK_LE(arg_pack.size(), 8UL);
+    CINNValuePack arg_pack = args[0];
+    CHECK_GE(arg_pack.size(), 2UL);
+    CHECK_LE(arg_pack.size(), 8UL);
     std::vector<Expr> vec_ast;
     std::vector<Expr> vec_tensor;
     for (int i = 0; i < arg_pack.size(); i++) {

@@ -291,8 +281,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
           Expr reduce_reshape = vec_tensor[2];
           VLOG(3) << "Do IRCudaScheduleBlockShuffleReduce Schedule!";
           pe::IRCudaScheduleBlockShuffleReduce(
               ir_sch,
               reduce_reshape.as_tensor_ref(),
               reduce_internal.as_tensor_ref(),
               reduce_out.as_tensor_ref(),

@@ -310,72 +299,6 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
         CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = CINNValuePack{res};
     }
-    } else {
-      CHECK_GE(arg_pack.size(), 2UL);
-      CHECK_LE(arg_pack.size(), 5UL);
-      if (target.arch == Target::Arch::NVGPU) {
-        if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) {
-          if (arg_pack.size() == 3) {
-            Expr out = arg_pack[0];
-            Expr tmp_out = arg_pack[1];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaBlockReduceInternalSchedule Schedule!";
-            pe::CudaBlockReduceInternalSchedule(stages,
-                                                tmp_out.as_tensor_ref(),
-                                                out.as_tensor_ref(),
-                                                common::DefaultNVGPUTarget());
-          } else if (arg_pack.size() == 4) {
-            Expr out = arg_pack[0];
-            Expr tmp_out = arg_pack[1];
-            Expr reduce_tmp_out = arg_pack[2];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaBlockReduceSchedule Schedule!";
-            pe::CudaBlockReduceSchedule(stages,
-                                        reduce_tmp_out.as_tensor_ref(),
-                                        tmp_out.as_tensor_ref(),
-                                        out.as_tensor_ref(),
-                                        common::DefaultNVGPUTarget());
-          } else {
-            Expr out = arg_pack[0];
-            Expr tmp_out = arg_pack[1];
-            Expr reduce_tmp_out = arg_pack[2];
-            Expr reshape = arg_pack[3];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaTwoStepReduceSchedule Schedule!";
-            pe::CudaTwoStepReduceSchedule(stages,
-                                          reshape.as_tensor_ref(),
-                                          reduce_tmp_out.as_tensor_ref(),
-                                          tmp_out.as_tensor_ref(),
-                                          out.as_tensor_ref(),
-                                          common::DefaultNVGPUTarget());
-          }
-        } else {
-          if (arg_pack.size() == 2) {
-            Expr reduce_out = arg_pack[0];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaReduceSchedule Schedule!";
-            pe::CudaReduceSchedule(
-                stages,
-                reduce_out.as_tensor_ref(),
-                inputs[0]->shape.size() - reduce_axes.back() - 1,
-                target);
-          } else {
-            CHECK_EQ(arg_pack.size(), 4) << "args is not equal 4!";
-            Expr reduce_reshape = arg_pack[2];
-            Expr reduce_internal = arg_pack[1];
-            Expr reduce_out = arg_pack[0];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaBlockShuffleReduceSchedule Schedule!";
-            pe::CudaBlockShuffleReduceSchedule(stages,
-                                               reduce_reshape.as_tensor_ref(),
-                                               reduce_internal.as_tensor_ref(),
-                                               reduce_out.as_tensor_ref(),
-                                               target);
-          }
-        }
-      }
-      *ret = arg_pack;
-    }
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
paddle/cinn/hlir/op/transform.cc
浏览文件 @
70183c4b
...
@@ -73,12 +73,9 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(
...
@@ -73,12 +73,9 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(
CHECK
(
A
.
as_tensor
());
CHECK
(
A
.
as_tensor
());
CHECK
(
B
.
as_tensor
());
CHECK
(
B
.
as_tensor
());
std
::
string
tensor_name
=
UniqName
(
"MatMul"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK_GE
(
pack_args
.
size
(),
3
);
CHECK_GE
(
pack_args
.
size
(),
3
);
CHECK
(
pack_args
[
2
].
is_string
());
CHECK
(
pack_args
[
2
].
is_string
());
tensor_name
=
pack_args
[
2
].
operator
std
::
string
();
std
::
string
tensor_name
=
pack_args
[
2
].
operator
std
::
string
();
}
auto
tensor_A
=
A
.
as_tensor_ref
();
auto
tensor_A
=
A
.
as_tensor_ref
();
auto
tensor_B
=
B
.
as_tensor_ref
();
auto
tensor_B
=
B
.
as_tensor_ref
();
...
@@ -130,32 +127,9 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(
...
@@ -130,32 +127,9 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(
CHECK
(
!
args
.
empty
())
CHECK
(
!
args
.
empty
())
<<
"The input argument of matmul schedule is empty! Please check.
\n
"
;
<<
"The input argument of matmul schedule is empty! Please check.
\n
"
;
CINNValuePack
arg_pack
=
args
[
0
];
CINNValuePack
arg_pack
=
args
[
0
];
if
(
FLAGS_cinn_ir_schedule
)
{
std
::
vector
<
CINNValue
>
results
=
std
::
vector
<
CINNValue
>
results
=
pe
::
IRCudaScheduleMatMul
(
arg_pack
,
output_shape
,
target
);
pe
::
IRCudaScheduleMatMul
(
arg_pack
,
output_shape
,
target
);
*
ret
=
CINNValuePack
({
results
});
*
ret
=
CINNValuePack
({
results
});
}
else
{
CHECK
(
arg_pack
.
size
()
==
2UL
||
arg_pack
.
size
()
==
3UL
);
poly
::
StageMap
stages
=
arg_pack
.
back
();
if
(
target
.
arch
==
Target
::
Arch
::
NVGPU
)
{
Expr
out
=
arg_pack
[
0
];
CHECK
(
out
.
as_tensor
());
pe
::
MatmulScheduleCUDA
(
stages
,
out
.
as_tensor_ref
(),
target
);
}
else
if
(
target
.
arch
==
Target
::
Arch
::
X86
)
{
#ifdef CINN_WITH_MKL_CBLAS
CHECK_EQ
(
arg_pack
.
size
(),
3UL
);
#else
CHECK_EQ
(
arg_pack
.
size
(),
3UL
);
Expr
out
=
arg_pack
[
0
];
Expr
packedB
=
arg_pack
[
1
];
CHECK
(
packedB
.
as_tensor
());
CHECK
(
out
.
as_tensor
());
pe
::
MatmulScheduleCPU
(
stages
,
out
.
as_tensor_ref
(),
packedB
.
as_tensor_ref
(),
target
);
#endif
}
*
ret
=
arg_pack
;
}
});
});
auto
strategy
=
std
::
make_shared
<
framework
::
OpStrategy
>
();
auto
strategy
=
std
::
make_shared
<
framework
::
OpStrategy
>
();
...
@@ -262,17 +236,11 @@ std::shared_ptr<OpStrategy> StrategyForSplit(
...
@@ -262,17 +236,11 @@ std::shared_ptr<OpStrategy> StrategyForSplit(
ir
::
Tensor
A
=
A_expr
.
as_tensor_ref
();
ir
::
Tensor
A
=
A_expr
.
as_tensor_ref
();
std
::
vector
<
std
::
string
>
tensor_names
;
std
::
vector
<
std
::
string
>
tensor_names
;
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK_EQ
(
pack_args
.
size
(),
output_shapes
.
size
()
+
1
);
CHECK_EQ
(
pack_args
.
size
(),
output_shapes
.
size
()
+
1
);
for
(
int
idx
=
1
;
idx
<
pack_args
.
size
();
++
idx
)
{
for
(
int
idx
=
1
;
idx
<
pack_args
.
size
();
++
idx
)
{
CHECK
(
pack_args
[
idx
].
is_string
());
CHECK
(
pack_args
[
idx
].
is_string
());
tensor_names
.
push_back
(
pack_args
[
idx
].
operator
std
::
string
());
tensor_names
.
push_back
(
pack_args
[
idx
].
operator
std
::
string
());
}
}
}
else
{
for
(
int
idx
=
0
;
idx
<
output_shapes
.
size
();
++
idx
)
{
tensor_names
.
push_back
(
UniqName
(
"T_Split_Out"
));
}
}
auto
out
=
pe
::
Split
(
A
,
axis
,
output_shapes
,
tensor_names
);
auto
out
=
pe
::
Split
(
A
,
axis
,
output_shapes
,
tensor_names
);
auto
stages
=
CreateStages
(
out
);
auto
stages
=
CreateStages
(
out
);
...
@@ -285,9 +253,8 @@ std::shared_ptr<OpStrategy> StrategyForSplit(
...
@@ -285,9 +253,8 @@ std::shared_ptr<OpStrategy> StrategyForSplit(
*
ret
=
CINNValuePack
{
res
};
*
ret
=
CINNValuePack
{
res
};
});
});
framework
::
CINNSchedule
split_schedule
([
=
](
lang
::
Args
args
,
framework
::
CINNSchedule
split_schedule
(
lang
::
RetValue
*
ret
)
{
[
=
](
lang
::
Args
args
,
lang
::
RetValue
*
ret
)
{
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK
(
!
args
.
empty
())
CHECK
(
!
args
.
empty
())
<<
"The input argument of split schedule is empty! Please check."
;
<<
"The input argument of split schedule is empty! Please check."
;
CINNValuePack
arg_pack
=
args
[
0
];
CINNValuePack
arg_pack
=
args
[
0
];
...
@@ -306,16 +273,6 @@ std::shared_ptr<OpStrategy> StrategyForSplit(
...
@@ -306,16 +273,6 @@ std::shared_ptr<OpStrategy> StrategyForSplit(
std
::
vector
<
CINNValue
>
res
{
std
::
vector
<
CINNValue
>
res
{
CINNValue
(
ir_sch
.
GetModule
().
GetExprs
().
at
(
0
))};
CINNValue
(
ir_sch
.
GetModule
().
GetExprs
().
at
(
0
))};
*
ret
=
CINNValuePack
{
res
};
*
ret
=
CINNValuePack
{
res
};
}
else
{
CHECK
(
!
args
.
empty
())
<<
"The input arguments of split schedule is empty! Please check."
;
CINNValuePack
arg_pack
=
args
[
0
];
CHECK_GE
(
arg_pack
.
size
(),
2UL
)
<<
"The input tensor's size of split schedule is "
<<
arg_pack
.
size
()
<<
"and it should be greater equal to 2! Please check."
;
pe
::
CudaSplitSchedule
(
&
arg_pack
,
output_shapes
,
axis
,
target
);
*
ret
=
arg_pack
;
}
});
});
auto
strategy
=
std
::
make_shared
<
framework
::
OpStrategy
>
();
auto
strategy
=
std
::
make_shared
<
framework
::
OpStrategy
>
();
...
@@ -468,8 +425,7 @@ std::shared_ptr<OpStrategy> StrategyForConcat(
...
@@ -468,8 +425,7 @@ std::shared_ptr<OpStrategy> StrategyForConcat(
CHECK
(
!
out_type
.
empty
())
CHECK
(
!
out_type
.
empty
())
<<
"Output type of Concat is empty! Please check.
\n
"
;
<<
"Output type of Concat is empty! Please check.
\n
"
;
CINNValuePack
pack_args
=
args
[
0
];
CINNValuePack
pack_args
=
args
[
0
];
int
input_size
=
int
input_size
=
pack_args
.
size
()
-
1
;
FLAGS_cinn_ir_schedule
?
pack_args
.
size
()
-
1
:
pack_args
.
size
();
CHECK_GE
(
input_size
,
1UL
)
CHECK_GE
(
input_size
,
1UL
)
<<
"at least 2 input tensors for Concat compute
\n
"
;
<<
"at least 2 input tensors for Concat compute
\n
"
;
CHECK
(
!
output_shapes
.
empty
());
CHECK
(
!
output_shapes
.
empty
());
...
@@ -485,11 +441,8 @@ std::shared_ptr<OpStrategy> StrategyForConcat(
...
@@ -485,11 +441,8 @@ std::shared_ptr<OpStrategy> StrategyForConcat(
input_tensors
.
push_back
(
tensor
.
as_tensor_ref
());
input_tensors
.
push_back
(
tensor
.
as_tensor_ref
());
}
}
std
::
string
tensor_name
=
UniqName
(
"Concat_output"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK
(
pack_args
[
input_size
].
is_string
());
CHECK
(
pack_args
[
input_size
].
is_string
());
tensor_name
=
pack_args
[
input_size
].
operator
std
::
string
();
std
::
string
tensor_name
=
pack_args
[
input_size
].
operator
std
::
string
();
}
auto
stages
=
CreateStages
(
input_tensors
);
auto
stages
=
CreateStages
(
input_tensors
);
auto
out
=
pe
::
Concat
(
input_tensors
,
axis
,
tensor_name
);
auto
out
=
pe
::
Concat
(
input_tensors
,
axis
,
tensor_name
);
...
@@ -612,11 +565,8 @@ std::shared_ptr<OpStrategy> StrategyForMul(
...
@@ -612,11 +565,8 @@ std::shared_ptr<OpStrategy> StrategyForMul(
auto
new_B
=
B_tensor
->
Reshape
(
new_shape_B_e
,
stages
);
auto
new_B
=
B_tensor
->
Reshape
(
new_shape_B_e
,
stages
);
std
::
vector
<
ir
::
Tensor
>
out
;
std
::
vector
<
ir
::
Tensor
>
out
;
std
::
string
tensor_name
=
UniqName
(
"Mul_output"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK
(
pack_args
.
back
().
is_string
());
CHECK
(
pack_args
.
back
().
is_string
());
tensor_name
=
pack_args
.
back
().
operator
std
::
string
();
std
::
string
tensor_name
=
pack_args
.
back
().
operator
std
::
string
();
}
if
(
target
.
arch
==
Target
::
Arch
::
X86
)
{
if
(
target
.
arch
==
Target
::
Arch
::
X86
)
{
#ifdef CINN_WITH_MKL_CBLAS
#ifdef CINN_WITH_MKL_CBLAS
...
@@ -647,32 +597,9 @@ std::shared_ptr<OpStrategy> StrategyForMul(
...
@@ -647,32 +597,9 @@ std::shared_ptr<OpStrategy> StrategyForMul(
CHECK
(
!
args
.
empty
())
CHECK
(
!
args
.
empty
())
<<
"The input argument of matmul schedule is empty! Please check.
\n
"
;
<<
"The input argument of matmul schedule is empty! Please check.
\n
"
;
CINNValuePack
arg_pack
=
args
[
0
];
CINNValuePack
arg_pack
=
args
[
0
];
if
(
FLAGS_cinn_ir_schedule
)
{
std
::
vector
<
CINNValue
>
results
=
std
::
vector
<
CINNValue
>
results
=
pe
::
IRCudaScheduleMatMul
(
arg_pack
,
output_shape
,
target
);
pe
::
IRCudaScheduleMatMul
(
arg_pack
,
output_shape
,
target
);
*
ret
=
CINNValuePack
({
results
});
*
ret
=
CINNValuePack
({
results
});
}
else
{
CHECK
(
arg_pack
.
size
()
==
2UL
||
arg_pack
.
size
()
==
3UL
);
poly
::
StageMap
stages
=
arg_pack
.
back
();
if
(
target
.
arch
==
Target
::
Arch
::
NVGPU
)
{
Expr
out
=
arg_pack
[
0
];
CHECK
(
out
.
as_tensor
());
pe
::
MatmulScheduleCUDA
(
stages
,
out
.
as_tensor_ref
(),
target
);
}
else
if
(
target
.
arch
==
Target
::
Arch
::
X86
)
{
#ifdef CINN_WITH_MKL_CBLAS
CHECK_EQ
(
arg_pack
.
size
(),
3UL
);
#else
CHECK_EQ
(
arg_pack
.
size
(),
3UL
);
Expr
out
=
arg_pack
[
0
];
Expr
packedB
=
arg_pack
[
1
];
CHECK
(
packedB
.
as_tensor
());
CHECK
(
out
.
as_tensor
());
pe
::
MatmulScheduleCPU
(
stages
,
out
.
as_tensor_ref
(),
packedB
.
as_tensor_ref
(),
target
);
#endif
}
*
ret
=
arg_pack
;
}
});
});
auto
strategy
=
std
::
make_shared
<
framework
::
OpStrategy
>
();
auto
strategy
=
std
::
make_shared
<
framework
::
OpStrategy
>
();
...
@@ -780,12 +707,9 @@ std::shared_ptr<OpStrategy> StrategyForCublasGemm(
...
@@ -780,12 +707,9 @@ std::shared_ptr<OpStrategy> StrategyForCublasGemm(
// dummy gemm computation, which will be replaced by
// dummy gemm computation, which will be replaced by
// cinn_gpu_cublas_gemm in the GemmRewriter pass.
// cinn_gpu_cublas_gemm in the GemmRewriter pass.
std
::
string
tensor_name
=
UniqName
(
"cublas_gemm_output"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK_EQ
(
input_args
.
size
(),
4
);
CHECK_EQ
(
input_args
.
size
(),
4
);
CHECK
(
input_args
[
3
].
is_string
());
CHECK
(
input_args
[
3
].
is_string
());
tensor_name
=
input_args
[
3
].
operator
std
::
string
();
std
::
string
tensor_name
=
input_args
[
3
].
operator
std
::
string
();
}
auto
out
=
pe
::
Identity
(
bias_tensor
,
tensor_name
).
front
();
auto
out
=
pe
::
Identity
(
bias_tensor
,
tensor_name
).
front
();
auto
stages
=
CreateStages
(
auto
stages
=
CreateStages
(
{
lhs
.
as_tensor_ref
(),
rhs
.
as_tensor_ref
(),
bias_tensor
});
{
lhs
.
as_tensor_ref
(),
rhs
.
as_tensor_ref
(),
bias_tensor
});
...
@@ -849,12 +773,9 @@ std::shared_ptr<OpStrategy> StrategyForLayoutTransform(
     Expr A = input_args[0];
     CHECK(A.as_tensor());
-    std::string tensor_name = UniqName("layout_transform_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(input_args.size(), 2);
-      CHECK(input_args[1].is_string());
-      tensor_name = input_args[1].operator std::string();
-    }
+    CHECK_EQ(input_args.size(), 2);
+    CHECK(input_args[1].is_string());
+    std::string tensor_name = input_args[1].operator std::string();
     auto out = pe::LayoutTransform(
         A.as_tensor_ref(), src_layout, dst_layout, tensor_name);
...
@@ -865,9 +786,8 @@ std::shared_ptr<OpStrategy> StrategyForLayoutTransform(
     *ret = CINNValuePack{res};
   });
-  framework::CINNSchedule layout_transform_schedule(
-      [=](lang::Args args, lang::RetValue *ret) {
-        if (FLAGS_cinn_ir_schedule) {
-          CHECK(!args.empty()) << "The input argument of CublasGemm schedule "
-                                  "is empty! Please check.";
-          CINNValuePack arg_pack = args[0];
+  framework::CINNSchedule layout_transform_schedule([=](lang::Args args,
+                                                        lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input argument of CublasGemm schedule "
+                            "is empty! Please check.";
+    CINNValuePack arg_pack = args[0];
...
@@ -888,29 +808,8 @@ std::shared_ptr<OpStrategy> StrategyForLayoutTransform(
     } else {
       CINN_NOT_IMPLEMENTED
     }
     std::vector<CINNValue> res{
         CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of layout_transform "
-                              "schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      CHECK_EQ(arg_pack.size(), 2UL);
-      Expr out = arg_pack[0];
-      poly::StageMap stages = arg_pack[1];
-      CHECK(out.as_tensor());
-      auto tensor_out = out.as_tensor_ref();
-      std::vector<int> out_shape;
-      for (auto shape : tensor_out->shape) {
-        out_shape.push_back(shape.as_int32());
-      }
-      if (target.arch == Target::Arch::X86) {
-        pe::ScheduleInjectiveCPU(stages[tensor_out], out_shape, target);
-      } else {
-        CINN_NOT_IMPLEMENTED
-      }
-      *ret = arg_pack;
-    }
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
...
@@ -996,12 +895,9 @@ std::shared_ptr<OpStrategy> StrategyForReverse(
     Expr A = input_args[0];
     CHECK(A.as_tensor());
-    std::string tensor_name = UniqName("Reverse_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(input_args.size(), 2);
-      CHECK(input_args[1].is_string());
-      tensor_name = input_args[1].operator std::string();
-    }
+    CHECK_EQ(input_args.size(), 2);
+    CHECK(input_args[1].is_string());
+    std::string tensor_name = input_args[1].operator std::string();
     auto out = pe::Reverse(A.as_tensor_ref(), axis, tensor_name);
     auto stages = CreateStages({A.as_tensor_ref(), out});
...
@@ -1113,12 +1009,9 @@ std::shared_ptr<OpStrategy> StrategyForTranspose(
         << "at least one input tensor for transpose compute\n";
     Expr A = input_args[0];
     CHECK(A.as_tensor());
-    std::string tensor_name = UniqName("Transpose_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(input_args.size(), 2);
-      CHECK(input_args[1].is_string());
-      tensor_name = input_args[1].operator std::string();
-    }
+    CHECK_EQ(input_args.size(), 2);
+    CHECK(input_args[1].is_string());
+    std::string tensor_name = input_args[1].operator std::string();
     auto out = pe::Transpose(A.as_tensor_ref(), axis, tensor_name);
     auto stages = CreateStages({out});
@@ -1236,12 +1129,9 @@ std::shared_ptr<OpStrategy> StrategyForGather(
...
@@ -1236,12 +1129,9 @@ std::shared_ptr<OpStrategy> StrategyForGather(
Expr
index
=
input_args
[
1
];
Expr
index
=
input_args
[
1
];
CHECK
(
index
.
as_tensor
());
CHECK
(
index
.
as_tensor
());
std
::
string
tensor_name
=
UniqName
(
"gather_output"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK_EQ
(
input_args
.
size
(),
3U
);
CHECK_EQ
(
input_args
.
size
(),
3U
);
CHECK
(
input_args
[
2
].
is_string
());
CHECK
(
input_args
[
2
].
is_string
());
tensor_name
=
input_args
[
2
].
operator
std
::
string
();
std
::
string
tensor_name
=
input_args
[
2
].
operator
std
::
string
();
}
auto
out
=
pe
::
Gather
(
x
.
as_tensor_ref
(),
auto
out
=
pe
::
Gather
(
x
.
as_tensor_ref
(),
index
.
as_tensor_ref
(),
index
.
as_tensor_ref
(),
...
@@ -1335,12 +1225,9 @@ std::shared_ptr<OpStrategy> StrategyForScatterAssign(
...
@@ -1335,12 +1225,9 @@ std::shared_ptr<OpStrategy> StrategyForScatterAssign(
auto
stages
=
CreateStages
({
tensor_input
,
tensor_updates
,
tensor_index
});
auto
stages
=
CreateStages
({
tensor_input
,
tensor_updates
,
tensor_index
});
std
::
string
tensor_name
=
UniqName
(
"scatter_assign_output"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK_EQ
(
arg_pack
.
size
(),
4U
);
CHECK_EQ
(
arg_pack
.
size
(),
4U
);
CHECK
(
arg_pack
[
3
].
is_string
());
CHECK
(
arg_pack
[
3
].
is_string
());
tensor_name
=
arg_pack
[
3
].
operator
std
::
string
();
std
::
string
tensor_name
=
arg_pack
[
3
].
operator
std
::
string
();
}
auto
out
=
pe
::
ScatterAssign
(
auto
out
=
pe
::
ScatterAssign
(
tensor_input
,
tensor_updates
,
tensor_index
,
target
,
axis
,
tensor_name
);
tensor_input
,
tensor_updates
,
tensor_index
,
target
,
axis
,
tensor_name
);
...
@@ -1462,12 +1349,9 @@ std::shared_ptr<OpStrategy> StrategyForScatterAdd(
...
@@ -1462,12 +1349,9 @@ std::shared_ptr<OpStrategy> StrategyForScatterAdd(
auto
stages
=
CreateStages
({
tensor_input
,
tensor_updates
,
tensor_index
});
auto
stages
=
CreateStages
({
tensor_input
,
tensor_updates
,
tensor_index
});
std
::
string
tensor_name
=
UniqName
(
"scatter_add_output"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK_EQ
(
arg_pack
.
size
(),
4U
);
CHECK_EQ
(
arg_pack
.
size
(),
4U
);
CHECK
(
arg_pack
[
3
].
is_string
());
CHECK
(
arg_pack
[
3
].
is_string
());
tensor_name
=
arg_pack
[
3
].
operator
std
::
string
();
std
::
string
tensor_name
=
arg_pack
[
3
].
operator
std
::
string
();
}
auto
out
=
pe
::
ScatterAdd
(
auto
out
=
pe
::
ScatterAdd
(
tensor_input
,
tensor_updates
,
tensor_index
,
target
,
axis
,
tensor_name
);
tensor_input
,
tensor_updates
,
tensor_index
,
target
,
axis
,
tensor_name
);
...
@@ -1617,12 +1501,9 @@ std::shared_ptr<OpStrategy> StrategyForSlice(
...
@@ -1617,12 +1501,9 @@ std::shared_ptr<OpStrategy> StrategyForSlice(
CHECK
(
A_expr
.
as_tensor
());
CHECK
(
A_expr
.
as_tensor
());
ir
::
Tensor
A
=
A_expr
.
as_tensor_ref
();
ir
::
Tensor
A
=
A_expr
.
as_tensor_ref
();
std
::
string
tensor_name
=
UniqName
(
"Slice_output"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK_EQ
(
arg_pack
.
size
(),
2U
);
CHECK_EQ
(
arg_pack
.
size
(),
2U
);
CHECK
(
arg_pack
[
1
].
is_string
());
CHECK
(
arg_pack
[
1
].
is_string
());
tensor_name
=
arg_pack
[
1
].
operator
std
::
string
();
std
::
string
tensor_name
=
arg_pack
[
1
].
operator
std
::
string
();
}
auto
out
=
pe
::
Slice
(
auto
out
=
pe
::
Slice
(
A
,
starts
,
axes
,
strides
,
decrease_axis
,
output_shape
,
tensor_name
);
A
,
starts
,
axes
,
strides
,
decrease_axis
,
output_shape
,
tensor_name
);
...
@@ -1854,12 +1735,9 @@ std::shared_ptr<OpStrategy> StrategyForSliceAssign(
...
@@ -1854,12 +1735,9 @@ std::shared_ptr<OpStrategy> StrategyForSliceAssign(
Expr
assign
=
arg_pack
[
1
];
Expr
assign
=
arg_pack
[
1
];
CHECK
(
assign
.
as_tensor
());
CHECK
(
assign
.
as_tensor
());
std
::
string
tensor_name
=
UniqName
(
"slice_assign_output"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK_EQ
(
arg_pack
.
size
(),
3U
);
CHECK_EQ
(
arg_pack
.
size
(),
3U
);
CHECK
(
arg_pack
[
2
].
is_string
());
CHECK
(
arg_pack
[
2
].
is_string
());
tensor_name
=
arg_pack
[
2
].
operator
std
::
string
();
std
::
string
tensor_name
=
arg_pack
[
2
].
operator
std
::
string
();
}
auto
out
=
pe
::
SliceAssign
(
input
.
as_tensor_ref
(),
auto
out
=
pe
::
SliceAssign
(
input
.
as_tensor_ref
(),
assign
.
as_tensor_ref
(),
assign
.
as_tensor_ref
(),
...
...
paddle/cinn/hlir/op/transform_test.cc
浏览文件 @
70183c4b
...
@@ -86,7 +86,6 @@ TEST(SliceAssign, SliceAssign_Op) {
   std::string func_name = "slice_assign";
-  if (FLAGS_cinn_ir_schedule) {
-    std::string out_name = "output";
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(input.tensor()),
+  std::string out_name = "output";
+  common::CINNValuePack cinn_input =
+      common::CINNValuePack{{common::CINNValue(input.tensor()),
...
@@ -100,27 +99,6 @@ TEST(SliceAssign, SliceAssign_Op) {
   for (auto func : funcs) {
     LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
   }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(input.tensor()),
-                               common::CINNValue(assign.tensor())}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    rets = impl->fschedule(rets);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      if (!temp.as_tensor_ref()->buffer.defined()) {
-        inputs.push_back(temp.as_tensor_ref());
-      }
-    }
-    auto func = lang::LowerVec(
-        "slice_assign", rets.back(), inputs, {}, {}, nullptr, target);
-    for (auto& f : func) {
-      LOG(INFO) << "Test Strategy Codegen:\n" << f;
-    }
-  }
 }
 }  // namespace framework
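With the legacy branch removed, the test exercises only the path that names its output explicitly. The first test hunk is truncated after `input.tensor()`, so the following reconstruction of the surviving pack is partly an assumption: the `assign.tensor()` and trailing `out_name` entries are inferred from the deleted branch and from the last-argument-is-the-output-name convention seen throughout this commit.

// Hedged reconstruction; the last two pack entries are assumptions.
std::string out_name = "output";
common::CINNValuePack cinn_input =
    common::CINNValuePack{{common::CINNValue(input.tensor()),
                           common::CINNValue(assign.tensor()),
                           common::CINNValue(out_name)}};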
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录