PaddlePaddle / Paddle

Commit 70183c4b (unverified)
Authored Jul 17, 2023 by Huihuang Zheng; committed via GitHub on Jul 17, 2023.

Remove Old Schedules in Ops (#55391)

Remove old schedules.

Parent commit: db1f2c42
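Every hunk below follows the same pattern: the FLAGS_cinn_ir_schedule guard is dropped together with the legacy non-IR schedule in its else branch, so the IR-schedule path runs unconditionally and the output tensor name is always taken from the trailing string argument of the incoming CINNValuePack. A minimal before/after sketch of that pattern, assuming a pack_args value pack whose last element carries the output name (identifiers are illustrative, not the literal code of any single op):

    // Before: the name comes from UniqName() unless the gflag selects the IR path.
    std::string tensor_name = UniqName(op_name + "_Out");
    if (FLAGS_cinn_ir_schedule) {
      CHECK(pack_args.back().is_string());
      tensor_name = pack_args.back().operator std::string();
    }

    // After: the IR-schedule path is the only path, so the name is mandatory.
    CHECK(pack_args.back().is_string());
    std::string tensor_name = pack_args.back().operator std::string();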
Showing 15 changed files with 729 additions and 1519 deletions (+729, -1519).
paddle/cinn/hlir/op/broadcast.cc                      +35   -43
paddle/cinn/hlir/op/contrib/gather_nd.cc              +27   -40
paddle/cinn/hlir/op/contrib/logical_right_shift.cc    +2    -6
paddle/cinn/hlir/op/contrib/lookup_table.cc           +3    -5
paddle/cinn/hlir/op/contrib/one_hot.cc                +2    -6
paddle/cinn/hlir/op/contrib/reciprocal.cc             +5    -11
paddle/cinn/hlir/op/contrib/resize.cc                 +2    -5
paddle/cinn/hlir/op/contrib/sort.cc                   +67   -90
paddle/cinn/hlir/op/elementwise.cc                    +27   -57
paddle/cinn/hlir/op/nn.cc                             +268  -647
paddle/cinn/hlir/op/op_broadcast_test.cc              +34   -89
paddle/cinn/hlir/op/op_util.cc                        +38   -80
paddle/cinn/hlir/op/reduction.cc                      +117  -194
paddle/cinn/hlir/op/transform.cc                      +90   -212
paddle/cinn/hlir/op/transform_test.cc                 +12   -34
paddle/cinn/hlir/op/broadcast.cc
@@ -60,37 +60,34 @@ std::shared_ptr<OpStrategy> StrategyForBroadcast(
     const ir::Tensor &B,
     const std::string &output_name,
     const Expr &axis)) {
-  framework::CINNCompute binary_compute([=](lang::Args args,
-                                            lang::RetValue *ret) {
+  framework::CINNCompute binary_compute(
+      [=](lang::Args args, lang::RetValue *ret) {
        CHECK(!args.empty()) << "The input argument of " << op_name
                             << " compute is empty! Please check.";
        CINNValuePack pack_args = args[0];
        CHECK_GE(pack_args.size(), 2U)
            << "at least 2 input tensors for " << op_name << " compute";
-       std::string tensor_name = UniqName(op_name + "_Out");
-       if (FLAGS_cinn_ir_schedule) {
-         CHECK_GE(pack_args.size(), 3U) << op_name << " 's input is not enough!";
-         CHECK(pack_args[2].is_string());
-         tensor_name = pack_args[2].operator std::string();
-       }
+       CHECK_GE(pack_args.size(), 3U) << op_name << " 's input is not enough!";
+       CHECK(pack_args[2].is_string());
+       std::string tensor_name = pack_args[2].operator std::string();
        Expr A_expr = pack_args[0];
        Expr B_expr = pack_args[1];
        CHECK(A_expr.as_tensor());
        CHECK(B_expr.as_tensor());
        ir::Tensor A = A_expr.as_tensor_ref();
        ir::Tensor B = B_expr.as_tensor_ref();
        Expr axis;
        bool trans_a;
        for (auto &iter : attrs.attr_store) {
          if (iter.first == "axis") {
            axis = Expr(absl::get<int>(iter.second));
            break;
          }
        }
        auto out = pe_func(A, B, tensor_name, axis);
        auto stages = CreateStages({A, B, out});
        *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
      });

   auto strategy = std::make_shared<framework::OpStrategy>();
   strategy->AddImpl(binary_compute,

@@ -198,12 +195,10 @@ std::shared_ptr<OpStrategy> StrategyForBroadcastTo(
     CINNValuePack pack_args = args[0];
     CHECK(!pack_args.empty())
         << "The input tensors of broadcast_to compute is empty! Please check.";
-    std::string tensor_name = UniqName("broadcast_to_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_GE(pack_args.size(), 2U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_GE(pack_args.size(), 2U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A_expr = pack_args[0];
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();

@@ -323,12 +318,9 @@ std::shared_ptr<OpStrategy> StrategyForIsClose(
     CINNValuePack pack_args = args[0];
     int input_size = pack_args.size();
-    std::string tensor_name = UniqName("IsClose_output");
-    if (FLAGS_cinn_ir_schedule) {
-      // the last pack argument is the output tensor name
-      tensor_name = pack_args.back().operator std::string();
-      --input_size;
-    }
+    // the last pack argument is the output tensor name
+    std::string tensor_name = pack_args.back().operator std::string();
+    --input_size;
     CHECK_EQ(input_size, 2)
         << "The input number of isclose should be 2, but here "
         << input_size << "! Please check.";
...
paddle/cinn/hlir/op/contrib/gather_nd.cc
@@ -114,11 +114,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
     VLOG(3) << "x shape: " << utils::Join(tensor_x->shape, ", ")
             << ", index shape: " << utils::Join(tensor_index->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("GatherNd_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    std::string tensor_name = pack_args[2].operator std::string();
     ir::Tensor out = GatherNd(tensor_x, tensor_index, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);

@@ -131,44 +128,34 @@ std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
   framework::CINNSchedule gather_nd_schedule([=](lang::Args args,
                                                  lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
       CHECK(!args.empty()) << "The input argument of gather_nd_schedule is "
                               "empty! Please check.\n";
       common::CINNValuePack arg_pack = args[0];
       std::vector<Expr> vec_ast;
       for (int i = 0; i < arg_pack.size(); i++) {
         if (arg_pack[i].is_expr()) {
           Expr temp = arg_pack[i];
           vec_ast.emplace_back(temp);
         }
       }
       CHECK(!vec_ast.empty());
       ir::ModuleExpr mod_expr(vec_ast);
       ir::IRSchedule ir_sch(mod_expr);
       ir_sch.MergeExprs();
       int64_t prod_size = std::accumulate(output_shapes[0].begin(),
                                           output_shapes[0].end(),
                                           1,
                                           std::multiplies<int>());
       if (prod_size > 1) {
         if (target.arch == Target::Arch::NVGPU) {
           pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
         } else if (target.arch == Target::Arch::X86) {
           pe::IRScheduleInjectiveCPU(
               ir_sch, output_shapes.front(), target, true);
         }
       }
       std::vector<common::CINNValue> res{
           common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
       *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of gather_nd_schedule is "
-                              "empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
   });

   auto strategy = std::make_shared<framework::OpStrategy>();
...
paddle/cinn/hlir/op/contrib/logical_right_shift.cc
@@ -105,12 +105,8 @@ std::shared_ptr<OpStrategy> StrategyForLogicalRightShift(
     ir::Tensor A = A_expr.as_tensor_ref();
     ir::Tensor B = B_expr.as_tensor_ref();
-    std::string tensor_name = UniqName("T_LogicalRightShift_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    std::string tensor_name = pack_args[2].operator std::string();
     auto out = LogicalRightShift(A, B, target, tensor_name);
     auto stages = CreateStages({out});
...
paddle/cinn/hlir/op/contrib/lookup_table.cc
@@ -106,11 +106,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForLookupTable(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", B shape: " << utils::Join(tensor_B->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("LookupTable_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    std::string tensor_name = pack_args[2].operator std::string();
     ir::Tensor out = LookupTable(tensor_A, tensor_B, padding_idx, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);
...
paddle/cinn/hlir/op/contrib/one_hot.cc
(file mode changed: 100755 → 100644)

@@ -194,12 +194,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForOneHot(
     ir::Tensor on_value = on_value_expr.as_tensor_ref();
     ir::Tensor off_value = off_value_expr.as_tensor_ref();
-    std::string tensor_name = common::UniqName("T_OneHot_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 4U);
-      tensor_name = pack_args[3].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 4U);
+    std::string tensor_name = pack_args[3].operator std::string();
     ir::Tensor out = OneHot(indices,
                             on_value,
...
paddle/cinn/hlir/op/contrib/reciprocal.cc
@@ -94,13 +94,9 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
     CHECK(!pack_args.empty())
         << "at least one input tensor for " << op_name << " compute\n";
-    std::string tensor_name = UniqName("Reciprocal_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A = pack_args[0];
     CHECK(A.as_tensor());

@@ -110,10 +106,8 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = Reciprocal(tensor_A, tensor_name);
     std::vector<CINNValue> res;
...
paddle/cinn/hlir/op/contrib/resize.cc
@@ -207,12 +207,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForResize(
     auto tensor_A = A.as_tensor_ref();
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = common::UniqName("T_Resize_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = Resize(tensor_A, target, out_shape, mode, tensor_name);
...
paddle/cinn/hlir/op/contrib/sort.cc
@@ -178,12 +178,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    auto tensor_name = UniqName("Sort_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     std::vector<ir::Tensor> out =
         Sort(tensor_A, target, stages, axis, is_ascend, tensor_name);
     stages->InsertLazily(out[0]);

@@ -195,48 +192,40 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(
     *ret = CINNValuePack{res};
   });

-  framework::CINNSchedule sort_schedule([=](lang::Args args,
-                                            lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
+  framework::CINNSchedule sort_schedule(
+      [=](lang::Args args, lang::RetValue *ret) {
        CHECK(!args.empty())
            << "The input argument of sort_schedule is empty! Please check.\n";
        common::CINNValuePack arg_pack = args[0];
        std::vector<Expr> vec_ast;
        for (int i = 0; i < arg_pack.size(); i++) {
          if (arg_pack[i].is_expr()) {
            Expr temp = arg_pack[i];
            vec_ast.emplace_back(temp);
          }
        }
        CHECK(!vec_ast.empty());
        ir::ModuleExpr mod_expr(vec_ast);
        ir::IRSchedule ir_sch(mod_expr);
        ir_sch.MergeExprs();
        auto blocks = ir_sch.GetAllBlocks();
-       // TODO(Shixiaowei02): remove external calls, do not use local
-       // variables, because the size will exceed the limit.
+       // TODO(Shixiaowei02): remove external calls, do not use local variables,
+       // because the size will exceed the limit.
        ir_sch.SetBuffer(blocks[0], "local");
        ir_sch.SetBuffer(blocks[1], "local");
        int64_t prod_size = std::accumulate(output_shapes[0].begin(),
                                            output_shapes[0].end(),
                                            1,
                                            std::multiplies<int>());
        if (prod_size > 1 && target.arch == Target::Arch::X86) {
          pe::IRScheduleInjectiveCPU(
              ir_sch, output_shapes.front(), target, true);
        }
        std::vector<common::CINNValue> res{
            common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
        *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty())
-          << "The input argument of sort_schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
      });

   auto strategy = std::make_shared<framework::OpStrategy>();
   strategy->AddImpl(sort_compute, sort_schedule, "strategy.sort", 1);

@@ -271,12 +260,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    auto tensor_name = UniqName("ArgSort_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     auto out = ArgSort(tensor_A, target, stages, axis, is_ascend, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out.at(0));

@@ -291,45 +277,36 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(
   framework::CINNSchedule argsort_schedule([=](lang::Args args,
                                                lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
       CHECK(!args.empty())
           << "The input argument of argsort_schedule is empty! Please check.\n";
       common::CINNValuePack arg_pack = args[0];
       std::vector<Expr> vec_ast;
       for (int i = 0; i < arg_pack.size(); i++) {
         if (arg_pack[i].is_expr()) {
           Expr temp = arg_pack[i];
           vec_ast.emplace_back(temp);
         }
       }
       CHECK(!vec_ast.empty());
       ir::ModuleExpr mod_expr(vec_ast);
       ir::IRSchedule ir_sch(mod_expr);
       ir_sch.MergeExprs();
       auto blocks = ir_sch.GetAllBlocks();
       // TODO(Shixiaowei02): remove external calls, do not use local variables,
       // because the size will exceed the limit.
       // TODO(lanxianghit): There is a bug, setting buffer to "local" here will
       // cause the var declared twice at CodeGen. ir_sch.SetBuffer(blocks[0],
       // "local");
       int64_t prod_size = std::accumulate(output_shapes[0].begin(),
                                           output_shapes[0].end(),
                                           1,
                                           std::multiplies<int>());
       if (prod_size > 1 && target.arch == Target::Arch::X86) {
         pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
       }
       std::vector<common::CINNValue> res{
           common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
       *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty())
-          << "The input argument of argsort_schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
   });

   auto strategy = std::make_shared<framework::OpStrategy>();
...
paddle/cinn/hlir/op/elementwise.cc
@@ -67,12 +67,9 @@ std::shared_ptr<OpStrategy> StrategyForElementwise(
     CINNValuePack pack_args = args[0];
     CHECK_GE(pack_args.size(), 1U)
         << "1 input tensor for " << op_name << " compute";
-    std::string tensor_name = UniqName(op_name + "_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A_expr = pack_args[0];
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();

@@ -158,12 +155,9 @@ std::shared_ptr<OpStrategy> StrategyForScale(
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();
     ir::Tensor out;
-    std::string tensor_name = UniqName("Scale_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     if (bias_after_scale) {
       out = Compute(

@@ -242,12 +236,9 @@ std::shared_ptr<OpStrategy> StrategyForConstScalar(
     auto scalar = GetScalarExpr(attrs.attr_store.at("value"));
     auto scalar_type = out_type.at(0);
     CINNValuePack pack_args = args[0];
-    std::string tensor_name = UniqName("const_scalar_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 1U);
-      CHECK(pack_args[0].is_string());
-      tensor_name = pack_args[0].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 1U);
+    CHECK(pack_args[0].is_string());
+    std::string tensor_name = pack_args[0].operator std::string();
     auto out = lang::Compute(
         {Expr(1)},

@@ -371,12 +362,9 @@ std::shared_ptr<OpStrategy> StrategyForFillConstant(
     }
     CINNValuePack arg_pack = args[0];
-    std::string tensor_name = UniqName("fill_constant_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 1U);
-      CHECK(arg_pack[0].is_string());
-      tensor_name = arg_pack[0].operator std::string();
-    }
+    CHECK_EQ(arg_pack.size(), 1U);
+    CHECK(arg_pack[0].is_string());
+    std::string tensor_name = arg_pack[0].operator std::string();
     CHECK(!shape.empty()) << "shape attr is empty!";
     auto shape_exprs = ToCinnExprs(shape);
     auto out = lang::Compute(

@@ -458,12 +446,9 @@ std::shared_ptr<OpStrategy> StrategyForAssignValue(
     const auto &value = attrs.attr_store.at("values");
     CINNValuePack arg_pack = args[0];
-    std::string tensor_name = UniqName("T_assign_value_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 1U);
-      CHECK(arg_pack[0].is_string());
-      tensor_name = arg_pack[0].operator std::string();
-    }
+    CHECK_EQ(arg_pack.size(), 1U);
+    CHECK(arg_pack[0].is_string());
+    std::string tensor_name = arg_pack[0].operator std::string();
     absl::optional<ir::Tensor> out;
 #define EXPAND_VALUE_TO_TENSOR(TYPE) \

@@ -649,11 +634,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForSqueeze(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Squeeze_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Squeeze(tensor_A, axes, tensor_name);
     std::vector<CINNValue> res;

@@ -729,12 +711,9 @@ std::shared_ptr<OpStrategy> StrategyForExpandDims(
     Expr x = input_args[0];
     CHECK(x.as_tensor());
-    std::string tensor_name = UniqName("expand_dims_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(input_args.size(), 2U);
-      CHECK(input_args[1].is_string());
-      tensor_name = input_args[1].operator std::string();
-    }
+    CHECK_EQ(input_args.size(), 2U);
+    CHECK(input_args[1].is_string());
+    std::string tensor_name = input_args[1].operator std::string();
     auto out =
         pe::ExpandDims(x.as_tensor_ref(), axes, output_shapes[0], tensor_name);

@@ -809,12 +788,9 @@ std::shared_ptr<OpStrategy> StrategyForReshape(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Reshape_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Reshape(tensor_A, output_shapes[0], tensor_name);
     std::vector<CINNValue> res;

@@ -901,11 +877,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForCast(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Cast_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Cast(tensor_A, out_type[0], tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);

@@ -953,11 +926,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForArange(
         << "The input argument of arange compute is empty! Please check.\n";
     CINNValuePack pack_args = args[0];
-    std::string tensor_name = common::UniqName("T_Arange_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 1U);
-      tensor_name = pack_args[0].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 1U);
+    std::string tensor_name = pack_args[0].operator std::string();
     auto out = pe::Arange(start, stop, step, dtype, tensor_name);
     std::vector<common::CINNValue> res;
...
paddle/cinn/hlir/op/nn.cc
(This diff is collapsed and not shown.)
paddle/cinn/hlir/op/op_broadcast_test.cc
@@ -59,37 +59,19 @@ TEST(Operator, Operator_ElementWise_Add_Test0) {
   std::string func_name = "add1";
   Module::Builder builder("module0", target);
-  if (FLAGS_cinn_ir_schedule) {
     std::string out_name = "C";
     common::CINNValuePack cinn_input =
         common::CINNValuePack{{common::CINNValue(A),
                                common::CINNValue(B),
                                common::CINNValue(out_name)}};
     std::vector<std::string> input_output_names{"A", "B", out_name};

     auto funcs = framework::GetFuncFromImpl(
         impl, cinn_input, inputs, input_output_names, func_name, target);

     for (auto func : funcs) {
       LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n"
                 << func;
       builder.AddFunction(func);
     }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    ASSERT_EQ(rets.size(), 2UL);
-    rets = impl->fschedule(rets);
-    ASSERT_EQ(rets.size(), 2UL);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      inputs.push_back(temp.as_tensor_ref());
-    }
-    auto func = Lower("fn_" + func_name, rets.back(), inputs);
-    LOG(INFO) << "Test Strategy Codegen:\n" << func;
-    builder.AddFunction(func);
-  }

@@ -160,37 +142,20 @@ TEST(Operator, Operator_ElementWise_Add_Test1) {
   std::string func_name = "add2";
   Module::Builder builder("module", target);
-  if (FLAGS_cinn_ir_schedule) {
     std::string out_name = "C";
     common::CINNValuePack cinn_input =
         common::CINNValuePack{{common::CINNValue(A),
                                common::CINNValue(B),
                                common::CINNValue(out_name)}};
     std::vector<std::string> input_output_names{"A", "B", out_name};

     auto funcs = framework::GetFuncFromImpl(
         impl, cinn_input, inputs, input_output_names, func_name, target);

     for (auto func : funcs) {
       builder.AddFunction(func);
       LOG(INFO) << "Test Operator_ElementWise_Add_Test1's Strategy, func is :\n"
                 << func;
     }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    ASSERT_EQ(rets.size(), 2UL);
-    rets = impl->fschedule(rets);
-    ASSERT_EQ(rets.size(), 2UL);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      inputs.push_back(temp.as_tensor_ref());
-    }
-    auto func = Lower("fn_" + func_name, rets.back(), inputs);
-    LOG(INFO) << "Test Strategy Codegen:\n" << func;
-    builder.AddFunction(func);
-    LOG(INFO) << "Test Operator_ElementWise_Add_Test1's Strategy, func is :\n"
-              << func;
-  }

   backends::CodeGenCUDA_Dev codegen(target);

@@ -225,33 +190,15 @@ TEST(Operator, Operator_BroadcastTo) {
   std::string func_name = "broadcast_to";
-  if (FLAGS_cinn_ir_schedule) {
     std::string out_name = "C";
     common::CINNValuePack cinn_input = common::CINNValuePack{
         {common::CINNValue(B), common::CINNValue(out_name)}};
     std::vector<std::string> input_output_names{"B", out_name};

     auto funcs = framework::GetFuncFromImpl(
         impl, cinn_input, inputs, input_output_names, func_name, target);

     for (auto func : funcs) {
       LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
     }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(B)}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    ASSERT_EQ(rets.size(), 2UL);
-    rets = impl->fschedule(rets);
-    ASSERT_EQ(rets.size(), 2UL);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      inputs.push_back(temp.as_tensor_ref());
-    }
-    auto func = Lower("func" + func_name, rets.back(), inputs);
-    LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
-  }
 }

@@ -260,9 +207,7 @@ common::CINNValuePack GetComputeResult(
     const std::shared_ptr<OpImpl> &impl,
     std::vector<common::CINNValue> &cinn_inputs,  // NOLINT
     const std::string &output_name = "") {
-  if (FLAGS_cinn_ir_schedule) {
-    cinn_inputs.emplace_back(output_name);
-  }
+  cinn_inputs.emplace_back(output_name);
   return impl->fcompute(common::CINNValuePack{cinn_inputs});
 }
...
paddle/cinn/hlir/op/op_util.cc
@@ -21,8 +21,6 @@
 #include "paddle/cinn/hlir/pe/schedule.h"
 #include "paddle/cinn/ir/ir_schedule.h"

-DECLARE_bool(cinn_ir_schedule);
-
 namespace cinn {
 namespace hlir {

@@ -31,44 +29,24 @@ CINNSchedule GetElementwiseScheduleFunc(
     const Target &target,
     bool vectorizable) {
   return CINNSchedule([=](lang::Args args, lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
      CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
                              "empty! Please check.\n";
      common::CINNValuePack arg_pack = args[0];
      std::vector<Expr> vec_ast;
      for (int i = 0; i < arg_pack.size(); i++) {
        if (arg_pack[i].is_expr()) {
          Expr temp = arg_pack[i];
          vec_ast.emplace_back(temp);
        }
      }
      CHECK(!vec_ast.empty());
      ir::ModuleExpr mod_expr(vec_ast);
      ir::IRSchedule ir_sch(mod_expr);
      ir_sch.MergeExprs();
      pe::IRElementwiseSchedule(ir_sch, output_shapes.front(), target);
      std::vector<common::CINNValue> res{
          common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
      *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
-                              "empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      poly::StageMap stages = arg_pack[1];
-      CHECK(out.as_tensor());
-      CHECK_EQ(arg_pack.size(), 2UL);
-      if (target.arch == Target::Arch::NVGPU) {
-        pe::CudaScheduleInjective(
-            stages[out.as_tensor_ref()], output_shapes.front(), target);
-      } else if (target.arch == Target::Arch::X86) {
-        pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()],
-                                 output_shapes.front(),
-                                 target,
-                                 vectorizable);
-      }
-      *ret = arg_pack;
-    }
   });
 }

@@ -77,50 +55,30 @@ CINNSchedule GetInjectiveScheduleFunc(
     const Target &target,
     bool vectorizable) {
   return CINNSchedule([=](lang::Args args, lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
      CHECK(!args.empty()) << "The input argument of InjectiveSchedule is "
                              "empty! Please check.\n";
      common::CINNValuePack arg_pack = args[0];
      std::vector<Expr> vec_ast;
      for (int i = 0; i < arg_pack.size(); i++) {
        if (arg_pack[i].is_expr()) {
          Expr temp = arg_pack[i];
          vec_ast.emplace_back(temp);
        }
      }
      CHECK(!vec_ast.empty());
      ir::ModuleExpr mod_expr(vec_ast);
      ir::IRSchedule ir_sch(mod_expr);
      ir_sch.MergeExprs();
      pe::IRInjectiveSchedule(ir_sch, output_shapes.front(), target);
      /*if (target.arch == Target::Arch::NVGPU) {
        pe::IRInjectiveSchedule(ir_sch, output_shapes.front(), target);
      } else if (target.arch == Target::Arch::X86) {
        pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target,
      vectorizable);
      }*/
      std::vector<common::CINNValue> res{
          common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
      *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of InjectiveSchedule is "
-                              "empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      poly::StageMap stages = arg_pack[1];
-      CHECK(out.as_tensor());
-      CHECK_EQ(arg_pack.size(), 2UL);
-      if (target.arch == Target::Arch::NVGPU) {
-        pe::CudaScheduleInjective(
-            stages[out.as_tensor_ref()], output_shapes.front(), target);
-      } else if (target.arch == Target::Arch::X86) {
-        pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()],
-                                 output_shapes.front(),
-                                 target,
-                                 vectorizable);
-      }
-      *ret = arg_pack;
-    }
   });
 }
...
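With the flag and the old schedules gone, an op strategy in these files always pairs its compute lambda with the IR-based schedule helpers above. A rough sketch of how such a strategy is assembled, assuming a hypothetical unary op "foo" (the names StrategyForFoo, foo_compute, and "strategy.foo.x86" are illustrative, pe::Cast is used only as a stand-in compute, and GetInjectiveScheduleFunc is assumed to default its vectorizable parameter):

    std::shared_ptr<framework::OpStrategy> StrategyForFoo(
        const framework::NodeAttr &attrs,
        const std::vector<ir::Tensor> &inputs,
        const std::vector<Type> &out_type,
        const std::vector<std::vector<int>> &output_shapes,
        const Target &target) {
      framework::CINNCompute foo_compute(
          [=](lang::Args args, lang::RetValue *ret) {
            CINNValuePack pack_args = args[0];
            // The trailing pack argument now always carries the output name.
            CHECK(pack_args.back().is_string());
            std::string tensor_name = pack_args.back().operator std::string();
            Expr A_expr = pack_args[0];
            CHECK(A_expr.as_tensor());
            ir::Tensor A = A_expr.as_tensor_ref();
            // Stand-in compute; a real op would call its own pe:: builder here.
            ir::Tensor out = pe::Cast(A, out_type[0], tensor_name);
            auto stages = CreateStages({A, out});
            *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
          });

      auto strategy = std::make_shared<framework::OpStrategy>();
      // No FLAGS_cinn_ir_schedule branch any more: the IR-based schedule helper
      // is wired in unconditionally.
      strategy->AddImpl(foo_compute,
                        GetInjectiveScheduleFunc(output_shapes, target),
                        "strategy.foo.x86",
                        1);
      return strategy;
    }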
paddle/cinn/hlir/op/reduction.cc
@@ -29,8 +29,6 @@
 #include "paddle/cinn/ir/ir_schedule.h"
 #include "paddle/cinn/optim/ir_simplify.h"

-DECLARE_bool(cinn_ir_schedule);
-
 namespace cinn {
 namespace hlir {
 namespace op {

@@ -115,16 +113,10 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
     CHECK(!args.empty()) << "The input argument of " << op_name
                          << " compute is empty! Please check.";
     CINNValuePack arg_packs = args[0];
-    std::string tensor_name = UniqName(op_name + "_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_packs.size(), 2U)
-          << "There should be 2 input args for " << op_name << " compute";
-      CHECK(arg_packs[1].is_string());
-      tensor_name = arg_packs[1].operator std::string();
-    } else {
-      CHECK_EQ(arg_packs.size(), 1U)
-          << "There should be 1 input args for " << op_name << " compute";
-    }
+    CHECK_EQ(arg_packs.size(), 2U)
+        << "There should be 2 input args for " << op_name << " compute";
+    CHECK(arg_packs[1].is_string());
+    std::string tensor_name = arg_packs[1].operator std::string();
     Expr x_expr = arg_packs[0];
     CHECK(x_expr.as_tensor());
     ir::Tensor x = x_expr.as_tensor_ref();

@@ -175,206 +167,137 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
                                                  lang::RetValue *ret) {
     CHECK(!args.empty()) << "The input argument of " << op_name
                          << " schedule is empty! Please check.";
     CINNValuePack arg_pack = args[0];
-    if (FLAGS_cinn_ir_schedule) {
       CHECK_GE(arg_pack.size(), 2UL);
       CHECK_LE(arg_pack.size(), 8UL);
       std::vector<Expr> vec_ast;
       std::vector<Expr> vec_tensor;
       for (int i = 0; i < arg_pack.size(); i++) {
         if (arg_pack[i].is_expr()) {
           Expr temp = arg_pack[i];
           // TODO(zhhsplendid): old reducetion schedule assumes all length-1
           // for loops are simplified, but it is not after we add length-1
           // back. Reduction schedule is complex and we haven't changed it to
           // support the length-1 for loop yet. So we simplify here. The todo
           // is that remove SimplifyForLoops below and change reduction schedule
           optim::SimplifyForLoops(&temp);
           vec_ast.emplace_back(temp);
         } else if (arg_pack[i].is_tensor()) {
           Expr temp = arg_pack[i];
           vec_tensor.emplace_back(temp);
         }
       }
       CHECK(!vec_ast.empty());
       ir::ModuleExpr mod_expr(vec_ast);
       ir::IRSchedule ir_sch(mod_expr);
       ir_sch.MergeExprs();
       if (target.arch == Target::Arch::NVGPU) {
         if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) {
           if (arg_pack.size() == 4) {
             CHECK_EQ(vec_tensor.size(), 2);
             Expr out = vec_tensor[0];
             Expr tmp_out = vec_tensor[1];
             VLOG(3) << "Do IRCudaScheduleBlockReduceInternal Schedule!";
             pe::IRCudaScheduleBlockReduceInternal(
                 ir_sch, tmp_out.as_tensor_ref(), out.as_tensor_ref(), target);
             std::vector<CINNValue> res{
                 CINNValue(ir_sch.GetModule().GetExprs().at(0))};
             *ret = CINNValuePack{res};
           } else if (arg_pack.size() == 6) {
             CHECK_EQ(vec_tensor.size(), 3);
             Expr out = vec_tensor[0];
             Expr tmp_out = vec_tensor[1];
             Expr reduce_tmp_out = vec_tensor[2];
             VLOG(3) << "Do IRCudaScheduleBlockReduce Schedule!";
             pe::IRCudaScheduleBlockReduce(ir_sch,
                                           reduce_tmp_out.as_tensor_ref(),
                                           tmp_out.as_tensor_ref(),
                                           out.as_tensor_ref(),
                                           target);
             std::vector<CINNValue> res{
                 CINNValue(ir_sch.GetModule().GetExprs().at(0))};
             *ret = CINNValuePack{res};
           } else if (arg_pack.size() == 7) {
             CHECK_EQ(vec_tensor.size(), 4);
             Expr out = vec_tensor[0];
             Expr tmp_out = vec_tensor[1];
             Expr reduce_tmp_out = vec_tensor[2];
             Expr reshape = vec_tensor[3];
             VLOG(3) << "Do IRCudaTwoStepReduceSchedule Schedule!";
             pe::IRCudaTwoStepReduceSchedule(ir_sch,
                                             reshape.as_tensor_ref(),
                                             reduce_tmp_out.as_tensor_ref(),
                                             tmp_out.as_tensor_ref(),
                                             out.as_tensor_ref(),
                                             common::DefaultNVGPUTarget());
             std::vector<CINNValue> res{
                 CINNValue(ir_sch.GetModule().GetExprs().at(0))};
             *ret = CINNValuePack{res};
           } else if (arg_pack.size() == 5) {
             CHECK_EQ(vec_tensor.size(), 3);
             Expr out = vec_tensor[0];
             Expr tmp_out = vec_tensor[1];
             Expr reduce_tmp_out = vec_tensor[2];
             VLOG(3) << "Do IRCudaScheduleBlockReduce Schedule!";
             pe::IRCudaScheduleBlockReduce(ir_sch,
                                           reduce_tmp_out.as_tensor_ref(),
                                           tmp_out.as_tensor_ref(),
                                           out.as_tensor_ref(),
                                           common::DefaultNVGPUTarget());
             std::vector<CINNValue> res{
                 CINNValue(ir_sch.GetModule().GetExprs().at(0))};
             *ret = CINNValuePack{res};
           } else {
             LOG(FATAL) << "Unkown Reduce Type!";
           }
         } else {
           if (arg_pack.size() == 2) {
             CHECK_EQ(vec_tensor.size(), 1);
             Expr reduce_out = vec_tensor[0];
             VLOG(3) << "Do IRCudaScheduleReduce Schedule!";
             pe::IRCudaScheduleReduce(
                 ir_sch,
                 reduce_out.as_tensor_ref(),
                 inputs[0]->shape.size() - reduce_axes.back() - 1,
                 target);
             std::vector<CINNValue> res{
                 CINNValue(ir_sch.GetModule().GetExprs().at(0))};
             *ret = CINNValuePack{res};
           } else if (arg_pack.size() == 6) {
             CHECK_EQ(vec_tensor.size(), 3);
             Expr reduce_out = vec_tensor[0];
             Expr reduce_internal = vec_tensor[1];
             Expr reduce_reshape = vec_tensor[2];
             VLOG(3) << "Do IRCudaScheduleBlockShuffleReduce Schedule!";
             pe::IRCudaScheduleBlockShuffleReduce(ir_sch,
                                                  reduce_reshape.as_tensor_ref(),
                                                  reduce_internal.as_tensor_ref(),
                                                  reduce_out.as_tensor_ref(),
                                                  target);
             std::vector<CINNValue> res{
                 CINNValue(ir_sch.GetModule().GetExprs().at(0))};
             *ret = CINNValuePack{res};
           } else {
             LOG(FATAL) << "Unkown Reduce Type!";
           }
         }
       }
-    } else {
-      CHECK_GE(arg_pack.size(), 2UL);
-      CHECK_LE(arg_pack.size(), 5UL);
-      if (target.arch == Target::Arch::NVGPU) {
-        if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) {
-          if (arg_pack.size() == 3) {
-            Expr out = arg_pack[0];
-            Expr tmp_out = arg_pack[1];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaBlockReduceInternalSchedule Schedule!";
-            pe::CudaBlockReduceInternalSchedule(stages,
-                                                tmp_out.as_tensor_ref(),
-                                                out.as_tensor_ref(),
-                                                common::DefaultNVGPUTarget());
-          } else if (arg_pack.size() == 4) {
-            Expr out = arg_pack[0];
-            Expr tmp_out = arg_pack[1];
-            Expr reduce_tmp_out = arg_pack[2];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaBlockReduceSchedule Schedule!";
-            pe::CudaBlockReduceSchedule(stages,
-                                        reduce_tmp_out.as_tensor_ref(),
-                                        tmp_out.as_tensor_ref(),
-                                        out.as_tensor_ref(),
-                                        common::DefaultNVGPUTarget());
-          } else {
-            Expr out = arg_pack[0];
-            Expr tmp_out = arg_pack[1];
-            Expr reduce_tmp_out = arg_pack[2];
-            Expr reshape = arg_pack[3];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaTwoStepReduceSchedule Schedule!";
-            pe::CudaTwoStepReduceSchedule(stages,
-                                          reshape.as_tensor_ref(),
-                                          reduce_tmp_out.as_tensor_ref(),
-                                          tmp_out.as_tensor_ref(),
-                                          out.as_tensor_ref(),
-                                          common::DefaultNVGPUTarget());
-          }
-        } else {
-          if (arg_pack.size() == 2) {
-            Expr reduce_out = arg_pack[0];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaReduceSchedule Schedule!";
-            pe::CudaReduceSchedule(
-                stages,
-                reduce_out.as_tensor_ref(),
-                inputs[0]->shape.size() - reduce_axes.back() - 1,
-                target);
-          } else {
-            CHECK_EQ(arg_pack.size(), 4) << "args is not equal 4!";
-            Expr reduce_reshape = arg_pack[2];
-            Expr reduce_internal = arg_pack[1];
-            Expr reduce_out = arg_pack[0];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaBlockShuffleReduceSchedule Schedule!";
-            pe::CudaBlockShuffleReduceSchedule(stages,
-                                               reduce_reshape.as_tensor_ref(),
-                                               reduce_internal.as_tensor_ref(),
-                                               reduce_out.as_tensor_ref(),
-                                               target);
-          }
-        }
-      }
-      *ret = arg_pack;
-    }
       else {
         std::vector<CINNValue> res{
             CINNValue(ir_sch ...
.
GetModule
().
GetExprs
().
at
(
0
))};
*
ret
=
CINNValuePack
{
res
};
}
}
});
});
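After this commit, every reduction schedule lambda runs the IR-schedule path unconditionally instead of guarding it behind FLAGS_cinn_ir_schedule. The block below is a minimal sketch of that shared skeleton, not code taken from the diff: "MyOpSchedule" is a hypothetical name, and the pe:: pass in the middle is left as a comment because each op picks its own (IRCudaScheduleReduce, IRCudaScheduleBlockReduce, and so on, as shown above); the CINN calls that are used appear verbatim in the hunks above.

// Minimal sketch of the schedule skeleton kept by this commit (assumptions:
// the lambda receives the lowered bodies in arg_pack, as in the hunks above).
framework::CINNSchedule MyOpSchedule([=](lang::Args args, lang::RetValue *ret) {
  CHECK(!args.empty()) << "The input argument of schedule is empty! Please check.";
  CINNValuePack arg_pack = args[0];
  std::vector<Expr> vec_ast;  // lowered function bodies passed by the compute step
  for (int i = 0; i < arg_pack.size(); i++) {
    if (arg_pack[i].is_expr()) {
      Expr temp = arg_pack[i];
      vec_ast.emplace_back(temp);
    }
  }
  CHECK(!vec_ast.empty());
  ir::ModuleExpr mod_expr(vec_ast);  // wrap the ASTs into a module expression
  ir::IRSchedule ir_sch(mod_expr);   // schedule on the new IR, no poly::StageMap
  ir_sch.MergeExprs();
  // ... apply the op-specific pass here, e.g. pe::IRCudaScheduleReduce(ir_sch, ...) ...
  std::vector<CINNValue> res{CINNValue(ir_sch.GetModule().GetExprs().at(0))};
  *ret = CINNValuePack{res};
});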
...

paddle/cinn/hlir/op/transform.cc View file @ 70183c4b
...
@@ -73,12 +73,9 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(
   CHECK(A.as_tensor());
   CHECK(B.as_tensor());
-  std::string tensor_name = UniqName("MatMul");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_GE(pack_args.size(), 3);
-    CHECK(pack_args[2].is_string());
-    tensor_name = pack_args[2].operator std::string();
-  }
+  CHECK_GE(pack_args.size(), 3);
+  CHECK(pack_args[2].is_string());
+  std::string tensor_name = pack_args[2].operator std::string();
   auto tensor_A = A.as_tensor_ref();
   auto tensor_B = B.as_tensor_ref();
...
@@ -130,32 +127,9 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(
   CHECK(!args.empty())
       << "The input argument of matmul schedule is empty! Please check.\n";
   CINNValuePack arg_pack = args[0];
-  if (FLAGS_cinn_ir_schedule) {
-    std::vector<CINNValue> results =
-        pe::IRCudaScheduleMatMul(arg_pack, output_shape, target);
-    *ret = CINNValuePack({results});
-  } else {
-    CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL);
-    poly::StageMap stages = arg_pack.back();
-    if (target.arch == Target::Arch::NVGPU) {
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      pe::MatmulScheduleCUDA(stages, out.as_tensor_ref(), target);
-    } else if (target.arch == Target::Arch::X86) {
-#ifdef CINN_WITH_MKL_CBLAS
-      CHECK_EQ(arg_pack.size(), 3UL);
-#else
-      CHECK_EQ(arg_pack.size(), 3UL);
-      Expr out = arg_pack[0];
-      Expr packedB = arg_pack[1];
-      CHECK(packedB.as_tensor());
-      CHECK(out.as_tensor());
-      pe::MatmulScheduleCPU(stages, out.as_tensor_ref(), packedB.as_tensor_ref(), target);
-#endif
-    }
-    *ret = arg_pack;
-  }
+  std::vector<CINNValue> results =
+      pe::IRCudaScheduleMatMul(arg_pack, output_shape, target);
+  *ret = CINNValuePack({results});
 });
 auto strategy = std::make_shared<framework::OpStrategy>();
...
@@ -262,16 +236,10 @@ std::shared_ptr<OpStrategy> StrategyForSplit(
   ir::Tensor A = A_expr.as_tensor_ref();
   std::vector<std::string> tensor_names;
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_EQ(pack_args.size(), output_shapes.size() + 1);
-    for (int idx = 1; idx < pack_args.size(); ++idx) {
-      CHECK(pack_args[idx].is_string());
-      tensor_names.push_back(pack_args[idx].operator std::string());
-    }
-  } else {
-    for (int idx = 0; idx < output_shapes.size(); ++idx) {
-      tensor_names.push_back(UniqName("T_Split_Out"));
-    }
+  CHECK_EQ(pack_args.size(), output_shapes.size() + 1);
+  for (int idx = 1; idx < pack_args.size(); ++idx) {
+    CHECK(pack_args[idx].is_string());
+    tensor_names.push_back(pack_args[idx].operator std::string());
   }
   auto out = pe::Split(A, axis, output_shapes, tensor_names);
...
@@ -285,38 +253,27 @@ std::shared_ptr<OpStrategy> StrategyForSplit(
     *ret = CINNValuePack{res};
   });

   framework::CINNSchedule split_schedule([=](lang::Args args,
                                              lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
       CHECK(!args.empty())
           << "The input argument of split schedule is empty! Please check.";
       CINNValuePack arg_pack = args[0];
       std::vector<Expr> vec_ast;
       for (int i = 0; i < arg_pack.size(); i++) {
         if (arg_pack[i].is_expr()) {
           Expr temp = arg_pack[i];
           vec_ast.emplace_back(temp);
         }
       }
       CHECK(!vec_ast.empty());
       ir::ModuleExpr mod_expr(vec_ast);
       ir::IRSchedule ir_sch(mod_expr);
       ir_sch.MergeExprs();
       pe::IRCudaSplitSchedule(ir_sch, output_shapes, axis, target);
       std::vector<CINNValue> res{
           CINNValue(ir_sch.GetModule().GetExprs().at(0))};
       *ret = CINNValuePack{res};
-    } else {
-      CHECK(!args.empty())
-          << "The input arguments of split schedule is empty! Please check.";
-      CINNValuePack arg_pack = args[0];
-      CHECK_GE(arg_pack.size(), 2UL)
-          << "The input tensor's size of split schedule is " << arg_pack.size()
-          << "and it should be greater equal to 2! Please check.";
-      pe::CudaSplitSchedule(&arg_pack, output_shapes, axis, target);
-      *ret = arg_pack;
-    }
   });

   auto strategy = std::make_shared<framework::OpStrategy>();
   strategy->AddImpl(split_compute, split_schedule, "strategy.split.x86", 1);
...
@@ -468,8 +425,7 @@ std::shared_ptr<OpStrategy> StrategyForConcat(
   CHECK(!out_type.empty())
       << "Output type of Concat is empty! Please check.\n";
   CINNValuePack pack_args = args[0];
-  int input_size =
-      FLAGS_cinn_ir_schedule ? pack_args.size() - 1 : pack_args.size();
+  int input_size = pack_args.size() - 1;
   CHECK_GE(input_size, 1UL)
       << "at least 2 input tensors for Concat compute\n";
   CHECK(!output_shapes.empty());
...
@@ -485,11 +441,8 @@ std::shared_ptr<OpStrategy> StrategyForConcat(
     input_tensors.push_back(tensor.as_tensor_ref());
   }
-  std::string tensor_name = UniqName("Concat_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK(pack_args[input_size].is_string());
-    tensor_name = pack_args[input_size].operator std::string();
-  }
+  CHECK(pack_args[input_size].is_string());
+  std::string tensor_name = pack_args[input_size].operator std::string();
   auto stages = CreateStages(input_tensors);
   auto out = pe::Concat(input_tensors, axis, tensor_name);
...
@@ -612,11 +565,8 @@ std::shared_ptr<OpStrategy> StrategyForMul(
   auto new_B = B_tensor->Reshape(new_shape_B_e, stages);
   std::vector<ir::Tensor> out;
-  std::string tensor_name = UniqName("Mul_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK(pack_args.back().is_string());
-    tensor_name = pack_args.back().operator std::string();
-  }
+  CHECK(pack_args.back().is_string());
+  std::string tensor_name = pack_args.back().operator std::string();
   if (target.arch == Target::Arch::X86) {
 #ifdef CINN_WITH_MKL_CBLAS
...
@@ -647,32 +597,9 @@ std::shared_ptr<OpStrategy> StrategyForMul(
   CHECK(!args.empty())
       << "The input argument of matmul schedule is empty! Please check.\n";
   CINNValuePack arg_pack = args[0];
-  if (FLAGS_cinn_ir_schedule) {
-    std::vector<CINNValue> results =
-        pe::IRCudaScheduleMatMul(arg_pack, output_shape, target);
-    *ret = CINNValuePack({results});
-  } else {
-    CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL);
-    poly::StageMap stages = arg_pack.back();
-    if (target.arch == Target::Arch::NVGPU) {
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      pe::MatmulScheduleCUDA(stages, out.as_tensor_ref(), target);
-    } else if (target.arch == Target::Arch::X86) {
-#ifdef CINN_WITH_MKL_CBLAS
-      CHECK_EQ(arg_pack.size(), 3UL);
-#else
-      CHECK_EQ(arg_pack.size(), 3UL);
-      Expr out = arg_pack[0];
-      Expr packedB = arg_pack[1];
-      CHECK(packedB.as_tensor());
-      CHECK(out.as_tensor());
-      pe::MatmulScheduleCPU(stages, out.as_tensor_ref(), packedB.as_tensor_ref(), target);
-#endif
-    }
-    *ret = arg_pack;
-  }
+  std::vector<CINNValue> results =
+      pe::IRCudaScheduleMatMul(arg_pack, output_shape, target);
+  *ret = CINNValuePack({results});
 });
 auto strategy = std::make_shared<framework::OpStrategy>();
...
@@ -780,12 +707,9 @@ std::shared_ptr<OpStrategy> StrategyForCublasGemm(
   // dummy gemm computation, which will be replaced by
   // cinn_gpu_cublas_gemm in the GemmRewriter pass.
-  std::string tensor_name = UniqName("cublas_gemm_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_EQ(input_args.size(), 4);
-    CHECK(input_args[3].is_string());
-    tensor_name = input_args[3].operator std::string();
-  }
+  CHECK_EQ(input_args.size(), 4);
+  CHECK(input_args[3].is_string());
+  std::string tensor_name = input_args[3].operator std::string();
   auto out = pe::Identity(bias_tensor, tensor_name).front();
   auto stages = CreateStages(
       {lhs.as_tensor_ref(), rhs.as_tensor_ref(), bias_tensor});
...
@@ -849,12 +773,9 @@ std::shared_ptr<OpStrategy> StrategyForLayoutTransform(
   Expr A = input_args[0];
   CHECK(A.as_tensor());
-  std::string tensor_name = UniqName("layout_transform_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_EQ(input_args.size(), 2);
-    CHECK(input_args[1].is_string());
-    tensor_name = input_args[1].operator std::string();
-  }
+  CHECK_EQ(input_args.size(), 2);
+  CHECK(input_args[1].is_string());
+  std::string tensor_name = input_args[1].operator std::string();
   auto out = pe::LayoutTransform(
       A.as_tensor_ref(), src_layout, dst_layout, tensor_name);
...
@@ -865,53 +786,31 @@ std::shared_ptr<OpStrategy> StrategyForLayoutTransform(
     *ret = CINNValuePack{res};
   });

   framework::CINNSchedule layout_transform_schedule([=](lang::Args args,
                                                         lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
       CHECK(!args.empty()) << "The input argument of CublasGemm schedule "
                               "is empty! Please check.";
       CINNValuePack arg_pack = args[0];
       std::vector<Expr> vec_ast;
       for (int i = 0; i < arg_pack.size(); i++) {
         if (arg_pack[i].is_expr()) {
           Expr temp = arg_pack[i];
           vec_ast.emplace_back(temp);
         }
       }
       CHECK(!vec_ast.empty());
       ir::ModuleExpr mod_expr(vec_ast);
       ir::IRSchedule ir_sch(mod_expr);
       ir_sch.MergeExprs();
       if (target.arch == Target::Arch::X86) {
         pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target);
       } else {
         CINN_NOT_IMPLEMENTED
       }
       std::vector<CINNValue> res{
           CINNValue(ir_sch.GetModule().GetExprs().at(0))};
       *ret = CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of layout_transform "
-                              "schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      CHECK_EQ(arg_pack.size(), 2UL);
-      Expr out = arg_pack[0];
-      poly::StageMap stages = arg_pack[1];
-      CHECK(out.as_tensor());
-      auto tensor_out = out.as_tensor_ref();
-      std::vector<int> out_shape;
-      for (auto shape : tensor_out->shape) {
-        out_shape.push_back(shape.as_int32());
-      }
-      if (target.arch == Target::Arch::X86) {
-        pe::ScheduleInjectiveCPU(stages[tensor_out], out_shape, target);
-      } else {
-        CINN_NOT_IMPLEMENTED
-      }
-      *ret = arg_pack;
-    }
   });

   auto strategy = std::make_shared<framework::OpStrategy>();
   CHECK(out_type.size())
...
@@ -996,12 +895,9 @@ std::shared_ptr<OpStrategy> StrategyForReverse(
   Expr A = input_args[0];
   CHECK(A.as_tensor());
-  std::string tensor_name = UniqName("Reverse_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_EQ(input_args.size(), 2);
-    CHECK(input_args[1].is_string());
-    tensor_name = input_args[1].operator std::string();
-  }
+  CHECK_EQ(input_args.size(), 2);
+  CHECK(input_args[1].is_string());
+  std::string tensor_name = input_args[1].operator std::string();
   auto out = pe::Reverse(A.as_tensor_ref(), axis, tensor_name);
   auto stages = CreateStages({A.as_tensor_ref(), out});
...
@@ -1113,12 +1009,9 @@ std::shared_ptr<OpStrategy> StrategyForTranspose(
       << "at least one input tensor for transpose compute\n";
   Expr A = input_args[0];
   CHECK(A.as_tensor());
-  std::string tensor_name = UniqName("Transpose_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_EQ(input_args.size(), 2);
-    CHECK(input_args[1].is_string());
-    tensor_name = input_args[1].operator std::string();
-  }
+  CHECK_EQ(input_args.size(), 2);
+  CHECK(input_args[1].is_string());
+  std::string tensor_name = input_args[1].operator std::string();
   auto out = pe::Transpose(A.as_tensor_ref(), axis, tensor_name);
   auto stages = CreateStages({out});
...
@@ -1236,12 +1129,9 @@ std::shared_ptr<OpStrategy> StrategyForGather(
   Expr index = input_args[1];
   CHECK(index.as_tensor());
-  std::string tensor_name = UniqName("gather_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_EQ(input_args.size(), 3U);
-    CHECK(input_args[2].is_string());
-    tensor_name = input_args[2].operator std::string();
-  }
+  CHECK_EQ(input_args.size(), 3U);
+  CHECK(input_args[2].is_string());
+  std::string tensor_name = input_args[2].operator std::string();
   auto out = pe::Gather(x.as_tensor_ref(),
                         index.as_tensor_ref(),
...
@@ -1335,12 +1225,9 @@ std::shared_ptr<OpStrategy> StrategyForScatterAssign(
   auto stages = CreateStages({tensor_input, tensor_updates, tensor_index});
-  std::string tensor_name = UniqName("scatter_assign_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_EQ(arg_pack.size(), 4U);
-    CHECK(arg_pack[3].is_string());
-    tensor_name = arg_pack[3].operator std::string();
-  }
+  CHECK_EQ(arg_pack.size(), 4U);
+  CHECK(arg_pack[3].is_string());
+  std::string tensor_name = arg_pack[3].operator std::string();
   auto out = pe::ScatterAssign(
       tensor_input, tensor_updates, tensor_index, target, axis, tensor_name);
...
@@ -1462,12 +1349,9 @@ std::shared_ptr<OpStrategy> StrategyForScatterAdd(
   auto stages = CreateStages({tensor_input, tensor_updates, tensor_index});
-  std::string tensor_name = UniqName("scatter_add_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_EQ(arg_pack.size(), 4U);
-    CHECK(arg_pack[3].is_string());
-    tensor_name = arg_pack[3].operator std::string();
-  }
+  CHECK_EQ(arg_pack.size(), 4U);
+  CHECK(arg_pack[3].is_string());
+  std::string tensor_name = arg_pack[3].operator std::string();
   auto out = pe::ScatterAdd(
       tensor_input, tensor_updates, tensor_index, target, axis, tensor_name);
...
@@ -1617,12 +1501,9 @@ std::shared_ptr<OpStrategy> StrategyForSlice(
   CHECK(A_expr.as_tensor());
   ir::Tensor A = A_expr.as_tensor_ref();
-  std::string tensor_name = UniqName("Slice_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_EQ(arg_pack.size(), 2U);
-    CHECK(arg_pack[1].is_string());
-    tensor_name = arg_pack[1].operator std::string();
-  }
+  CHECK_EQ(arg_pack.size(), 2U);
+  CHECK(arg_pack[1].is_string());
+  std::string tensor_name = arg_pack[1].operator std::string();
   auto out = pe::Slice(
       A, starts, axes, strides, decrease_axis, output_shape, tensor_name);
...
@@ -1854,12 +1735,9 @@ std::shared_ptr<OpStrategy> StrategyForSliceAssign(
   Expr assign = arg_pack[1];
   CHECK(assign.as_tensor());
-  std::string tensor_name = UniqName("slice_assign_output");
-  if (FLAGS_cinn_ir_schedule) {
-    CHECK_EQ(arg_pack.size(), 3U);
-    CHECK(arg_pack[2].is_string());
-    tensor_name = arg_pack[2].operator std::string();
-  }
+  CHECK_EQ(arg_pack.size(), 3U);
+  CHECK(arg_pack[2].is_string());
+  std::string tensor_name = arg_pack[2].operator std::string();
   auto out = pe::SliceAssign(input.as_tensor_ref(),
                              assign.as_tensor_ref(),
...
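The compute-side hunks above all make the same change: the output tensor name is no longer generated with UniqName behind FLAGS_cinn_ir_schedule, it must always arrive as the trailing string element of the argument pack. The block below is a minimal sketch of that contract, not code from the diff: "MyOpCompute" and the single-input layout are hypothetical, while the argument-pack accessors (is_string(), operator std::string(), as_tensor()) are the ones used verbatim in the hunks above.

// Minimal sketch of the compute-side contract (assumption: one input tensor,
// output name appended as the last pack element, as in the hunks above).
framework::CINNCompute MyOpCompute([=](lang::Args args, lang::RetValue *ret) {
  CHECK(!args.empty()) << "The input argument of compute is empty! Please check.";
  CINNValuePack pack_args = args[0];
  CHECK_GE(pack_args.size(), 2);              // inputs first, then the output name
  CHECK(pack_args.back().is_string());        // no UniqName fallback any more
  std::string tensor_name = pack_args.back().operator std::string();
  Expr A = pack_args[0];
  CHECK(A.as_tensor());
  // ... build the output tensor with the op's pe:: helper, e.g.
  // auto out = pe::Reverse(A.as_tensor_ref(), axis, tensor_name); ...
});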
paddle/cinn/hlir/op/transform_test.cc View file @ 70183c4b
...
@@ -86,40 +86,18 @@ TEST(SliceAssign, SliceAssign_Op) {
   std::string func_name = "slice_assign";
-  if (FLAGS_cinn_ir_schedule) {
     std::string out_name = "output";
     common::CINNValuePack cinn_input =
         common::CINNValuePack{{common::CINNValue(input.tensor()),
                                common::CINNValue(assign.tensor()),
                                common::CINNValue(out_name)}};
     std::vector<std::string> input_output_names{"input", "assign", out_name};

     auto funcs = framework::GetFuncFromImpl(
         impl, cinn_input, inputs, input_output_names, func_name, target);

     for (auto func : funcs) {
       LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
     }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(input.tensor()),
-                               common::CINNValue(assign.tensor())}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    rets = impl->fschedule(rets);
-
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      if (!temp.as_tensor_ref()->buffer.defined()) {
-        inputs.push_back(temp.as_tensor_ref());
-      }
-    }
-
-    auto func = lang::LowerVec(
-        "slice_assign", rets.back(), inputs, {}, {}, nullptr, target);
-    for (auto& f : func) {
-      LOG(INFO) << "Test Strategy Codegen:\n" << f;
-    }
-  }
 }
...
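The test change mirrors the compute-side contract: a caller now always packs the output name into the CINNValuePack and goes through GetFuncFromImpl rather than lowering the StageMap path by hand. A minimal sketch of that caller side follows; the tensor and name literals are placeholders taken from the test above, and only calls shown in this diff are used.

// Minimal sketch of driving an op strategy after this commit (assumptions:
// "input"/"assign" placeholders and impl come from a test like the one above).
std::string out_name = "output";
common::CINNValuePack cinn_input =
    common::CINNValuePack{{common::CINNValue(input.tensor()),
                           common::CINNValue(assign.tensor()),
                           common::CINNValue(out_name)}};  // output name is packed too
std::vector<std::string> input_output_names{"input", "assign", out_name};
auto funcs = framework::GetFuncFromImpl(
    impl, cinn_input, inputs, input_output_names, func_name, target);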