Commit 70183c4b (unverified)
Authored by Huihuang Zheng on Jul 17, 2023; committed via GitHub on Jul 17, 2023.
Remove Old Schedules in Ops (#55391)
Remove old schedules.
Parent: db1f2c42
Showing 15 changed files with 729 additions and 1519 deletions (+729, -1519).
paddle/cinn/hlir/op/broadcast.cc                    +35   -43
paddle/cinn/hlir/op/contrib/gather_nd.cc            +27   -40
paddle/cinn/hlir/op/contrib/logical_right_shift.cc  +2    -6
paddle/cinn/hlir/op/contrib/lookup_table.cc         +3    -5
paddle/cinn/hlir/op/contrib/one_hot.cc              +2    -6
paddle/cinn/hlir/op/contrib/reciprocal.cc           +5    -11
paddle/cinn/hlir/op/contrib/resize.cc               +2    -5
paddle/cinn/hlir/op/contrib/sort.cc                 +67   -90
paddle/cinn/hlir/op/elementwise.cc                  +27   -57
paddle/cinn/hlir/op/nn.cc                           +268  -647
paddle/cinn/hlir/op/op_broadcast_test.cc            +34   -89
paddle/cinn/hlir/op/op_util.cc                      +38   -80
paddle/cinn/hlir/op/reduction.cc                    +117  -194
paddle/cinn/hlir/op/transform.cc                    +90   -212
paddle/cinn/hlir/op/transform_test.cc               +12   -34
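The same refactoring pattern repeats in every file below: each op's compute or schedule lambda used to branch on the legacy FLAGS_cinn_ir_schedule flag, and this commit keeps only the flag-on (IR schedule) branch and deletes the fallback. The following minimal standalone sketch (hypothetical names, not CINN's actual API) illustrates the before/after shape of that refactor, assuming the packed argument at index 2 carries the output tensor name.

    // Sketch of the flag-removal pattern applied throughout this commit.
    #include <cassert>
    #include <string>
    #include <vector>

    namespace before {
    // Old shape: the output name is either taken from the packed arguments
    // (IR-schedule path) or falls back to a generated name, depending on a flag.
    inline std::string OutputName(const std::vector<std::string>& pack_args,
                                  bool flag_ir_schedule) {
      std::string tensor_name = "op_Out";  // UniqName-style fallback
      if (flag_ir_schedule) {
        assert(pack_args.size() >= 3);     // CHECK_GE analogue
        tensor_name = pack_args[2];        // trailing packed arg is the name
      }
      return tensor_name;
    }
    }  // namespace before

    namespace after {
    // New shape: the flag and the fallback are gone; the packed name is required.
    inline std::string OutputName(const std::vector<std::string>& pack_args) {
      assert(pack_args.size() >= 3);
      return pack_args[2];
    }
    }  // namespace after

    int main() {
      std::vector<std::string> pack_args{"A", "B", "elementwise_add_Out"};
      assert(before::OutputName(pack_args, /*flag_ir_schedule=*/true) ==
             after::OutputName(pack_args));
      return 0;
    }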
paddle/cinn/hlir/op/broadcast.cc

@@ -60,37 +60,34 @@ std::shared_ptr<OpStrategy> StrategyForBroadcast(
     const ir::Tensor &B,
     const std::string &output_name,
     const Expr &axis)) {
   framework::CINNCompute binary_compute(
       [=](lang::Args args, lang::RetValue *ret) {
         CHECK(!args.empty()) << "The input argument of " << op_name
                              << " compute is empty! Please check.";
         CINNValuePack pack_args = args[0];
         CHECK_GE(pack_args.size(), 2U)
             << "at least 2 input tensors for " << op_name << " compute";
-        std::string tensor_name = UniqName(op_name + "_Out");
-        if (FLAGS_cinn_ir_schedule) {
-          CHECK_GE(pack_args.size(), 3U) << op_name << " 's input is not enough!";
-          CHECK(pack_args[2].is_string());
-          tensor_name = pack_args[2].operator std::string();
-        }
+        CHECK_GE(pack_args.size(), 3U) << op_name << " 's input is not enough!";
+        CHECK(pack_args[2].is_string());
+        std::string tensor_name = pack_args[2].operator std::string();
         Expr A_expr = pack_args[0];
         Expr B_expr = pack_args[1];
         CHECK(A_expr.as_tensor());
         CHECK(B_expr.as_tensor());
         ir::Tensor A = A_expr.as_tensor_ref();
         ir::Tensor B = B_expr.as_tensor_ref();
         Expr axis;
         bool trans_a;
         for (auto &iter : attrs.attr_store) {
           if (iter.first == "axis") {
             axis = Expr(absl::get<int>(iter.second));
             break;
           }
         }
         auto out = pe_func(A, B, tensor_name, axis);
         auto stages = CreateStages({A, B, out});
         *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
       });
   auto strategy = std::make_shared<framework::OpStrategy>();
   strategy->AddImpl(binary_compute,
...
@@ -198,12 +195,10 @@ std::shared_ptr<OpStrategy> StrategyForBroadcastTo(
     CINNValuePack pack_args = args[0];
     CHECK(!pack_args.empty())
         << "The input tensors of broadcast_to compute is empty! Please check.";
-    std::string tensor_name = UniqName("broadcast_to_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_GE(pack_args.size(), 2U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_GE(pack_args.size(), 2U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A_expr = pack_args[0];
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();
...
@@ -323,12 +318,9 @@ std::shared_ptr<OpStrategy> StrategyForIsClose(
     CINNValuePack pack_args = args[0];
     int input_size = pack_args.size();
-    std::string tensor_name = UniqName("IsClose_output");
-    if (FLAGS_cinn_ir_schedule) {
-      // the last pack argument is the output tensor name
-      tensor_name = pack_args.back().operator std::string();
-      --input_size;
-    }
+    // the last pack argument is the output tensor name
+    std::string tensor_name = pack_args.back().operator std::string();
+    --input_size;
     CHECK_EQ(input_size, 2)
         << "The input number of isclose should be 2, but here " << input_size
         << "! Please check.";
...
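All three broadcast strategies above now rely unconditionally on the convention that the trailing entry of the argument pack is a string naming the output tensor, checked with is_string() before conversion. The hedged sketch below reproduces just that convention with standard C++; CINNValuePack and CINNValue are CINN types, and std::variant merely stands in for them here.

    // Sketch of the "last pack argument is the output tensor name" convention.
    #include <cassert>
    #include <string>
    #include <variant>
    #include <vector>

    struct FakeTensor { std::string id; };                   // stands in for a tensor value
    using PackValue = std::variant<FakeTensor, std::string>; // stands in for CINNValue
    using Pack = std::vector<PackValue>;                     // stands in for CINNValuePack

    // Mirrors pack_args.back().is_string() / operator std::string() usage.
    std::string TrailingOutputName(const Pack& pack, std::size_t min_size) {
      assert(pack.size() >= min_size);                          // CHECK_GE analogue
      assert(std::holds_alternative<std::string>(pack.back())); // CHECK(is_string())
      return std::get<std::string>(pack.back());
    }

    int main() {
      Pack pack{FakeTensor{"A"}, FakeTensor{"B"}, std::string("broadcast_Out")};
      assert(TrailingOutputName(pack, 3) == "broadcast_Out");
      return 0;
    }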
paddle/cinn/hlir/op/contrib/gather_nd.cc

@@ -114,11 +114,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
     VLOG(3) << "x shape: " << utils::Join(tensor_x->shape, ", ")
             << ", index shape: " << utils::Join(tensor_index->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("GatherNd_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    std::string tensor_name = pack_args[2].operator std::string();
     ir::Tensor out = GatherNd(tensor_x, tensor_index, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);
...
@@ -131,44 +128,34 @@ std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
   framework::CINNSchedule gather_nd_schedule([=](lang::Args args,
                                                  lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
     CHECK(!args.empty()) << "The input argument of gather_nd_schedule is "
                             "empty! Please check.\n";
     common::CINNValuePack arg_pack = args[0];
     std::vector<Expr> vec_ast;
     for (int i = 0; i < arg_pack.size(); i++) {
       if (arg_pack[i].is_expr()) {
         Expr temp = arg_pack[i];
         vec_ast.emplace_back(temp);
       }
     }
     CHECK(!vec_ast.empty());
     ir::ModuleExpr mod_expr(vec_ast);
     ir::IRSchedule ir_sch(mod_expr);
     ir_sch.MergeExprs();
     int64_t prod_size = std::accumulate(output_shapes[0].begin(),
                                         output_shapes[0].end(),
                                         1,
                                         std::multiplies<int>());
     if (prod_size > 1) {
       if (target.arch == Target::Arch::NVGPU) {
         pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
       } else if (target.arch == Target::Arch::X86) {
         pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
       }
     }
     std::vector<common::CINNValue> res{
         common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of gather_nd_schedule is "
-                              "empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
...
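The gather_nd schedule above only applies an injective schedule when the output holds more than one element; prod_size is the product of the output extents. The standalone sketch below reproduces just that size computation (output_shapes[0] is a vector of dimension extents in the real code, which uses int accumulation; int64_t is used here to avoid overflow in the sketch).

    // Reproduces the prod_size guard used before applying an injective schedule.
    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    int64_t ProdSize(const std::vector<int>& shape) {
      // Product of all extents, starting from 1, so a scalar (empty) shape
      // yields 1 and scheduling is skipped.
      return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                             std::multiplies<int64_t>());
    }

    int main() {
      assert(ProdSize({}) == 1);          // scalar output: schedule skipped
      assert(ProdSize({1, 1}) == 1);      // all-ones shape: also skipped
      assert(ProdSize({4, 8, 2}) == 64);  // prod_size > 1: schedule applied
      return 0;
    }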
paddle/cinn/hlir/op/contrib/logical_right_shift.cc

@@ -105,12 +105,8 @@ std::shared_ptr<OpStrategy> StrategyForLogicalRightShift(
     ir::Tensor A = A_expr.as_tensor_ref();
     ir::Tensor B = B_expr.as_tensor_ref();
-    std::string tensor_name = UniqName("T_LogicalRightShift_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    std::string tensor_name = pack_args[2].operator std::string();
     auto out = LogicalRightShift(A, B, target, tensor_name);
     auto stages = CreateStages({out});
...
paddle/cinn/hlir/op/contrib/lookup_table.cc

@@ -106,11 +106,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForLookupTable(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", B shape: " << utils::Join(tensor_B->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("LookupTable_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    std::string tensor_name = pack_args[2].operator std::string();
     ir::Tensor out = LookupTable(tensor_A, tensor_B, padding_idx, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);
...
paddle/cinn/hlir/op/contrib/one_hot.cc  (file mode 100755 → 100644)

@@ -194,12 +194,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForOneHot(
     ir::Tensor on_value = on_value_expr.as_tensor_ref();
     ir::Tensor off_value = off_value_expr.as_tensor_ref();
-    std::string tensor_name = common::UniqName("T_OneHot_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 4U);
-      tensor_name = pack_args[3].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 4U);
+    std::string tensor_name = pack_args[3].operator std::string();
     ir::Tensor out = OneHot(indices,
                             on_value,
...
paddle/cinn/hlir/op/contrib/reciprocal.cc

@@ -94,13 +94,9 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
     CHECK(!pack_args.empty())
         << "at least one input tensor for " << op_name << " compute\n";
-    std::string tensor_name = UniqName("Reciprocal_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A = pack_args[0];
     CHECK(A.as_tensor());
...
@@ -110,10 +106,8 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = Reciprocal(tensor_A, tensor_name);
     std::vector<CINNValue> res;
...
paddle/cinn/hlir/op/contrib/resize.cc

@@ -207,12 +207,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForResize(
     auto tensor_A = A.as_tensor_ref();
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = common::UniqName("T_Resize_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = Resize(tensor_A, target, out_shape, mode, tensor_name);
...
paddle/cinn/hlir/op/contrib/sort.cc

@@ -178,12 +178,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    auto tensor_name = UniqName("Sort_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     std::vector<ir::Tensor> out =
         Sort(tensor_A, target, stages, axis, is_ascend, tensor_name);
     stages->InsertLazily(out[0]);
...
@@ -195,48 +192,40 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(
     *ret = CINNValuePack{res};
   });
   framework::CINNSchedule sort_schedule([=](lang::Args args,
                                             lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
     CHECK(!args.empty())
         << "The input argument of sort_schedule is empty! Please check.\n";
     common::CINNValuePack arg_pack = args[0];
     std::vector<Expr> vec_ast;
     for (int i = 0; i < arg_pack.size(); i++) {
       if (arg_pack[i].is_expr()) {
         Expr temp = arg_pack[i];
         vec_ast.emplace_back(temp);
       }
     }
     CHECK(!vec_ast.empty());
     ir::ModuleExpr mod_expr(vec_ast);
     ir::IRSchedule ir_sch(mod_expr);
     ir_sch.MergeExprs();
     auto blocks = ir_sch.GetAllBlocks();
     // TODO(Shixiaowei02): remove external calls, do not use local
     // variables, because the size will exceed the limit.
     ir_sch.SetBuffer(blocks[0], "local");
     ir_sch.SetBuffer(blocks[1], "local");
     int64_t prod_size = std::accumulate(output_shapes[0].begin(),
                                         output_shapes[0].end(),
                                         1,
                                         std::multiplies<int>());
     if (prod_size > 1 && target.arch == Target::Arch::X86) {
       pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
     }
     std::vector<common::CINNValue> res{
         common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty())
-          << "The input argument of sort_schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
   strategy->AddImpl(sort_compute, sort_schedule, "strategy.sort", 1);
...
@@ -271,12 +260,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    auto tensor_name = UniqName("ArgSort_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     auto out = ArgSort(tensor_A, target, stages, axis, is_ascend, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out.at(0));
...
@@ -291,45 +277,36 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(
   framework::CINNSchedule argsort_schedule([=](lang::Args args,
                                                lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
     CHECK(!args.empty())
         << "The input argument of argsort_schedule is empty! Please check.\n";
     common::CINNValuePack arg_pack = args[0];
     std::vector<Expr> vec_ast;
     for (int i = 0; i < arg_pack.size(); i++) {
       if (arg_pack[i].is_expr()) {
         Expr temp = arg_pack[i];
         vec_ast.emplace_back(temp);
       }
     }
     CHECK(!vec_ast.empty());
     ir::ModuleExpr mod_expr(vec_ast);
     ir::IRSchedule ir_sch(mod_expr);
     ir_sch.MergeExprs();
     auto blocks = ir_sch.GetAllBlocks();
     // TODO(Shixiaowei02): remove external calls, do not use local variables,
     // because the size will exceed the limit.
     // TODO(lanxianghit): There is a bug, setting buffer to "local" here will
     // cause the var declared twice at CodeGen. ir_sch.SetBuffer(blocks[0],
     // "local");
     int64_t prod_size = std::accumulate(output_shapes[0].begin(),
                                         output_shapes[0].end(),
                                         1,
                                         std::multiplies<int>());
     if (prod_size > 1 && target.arch == Target::Arch::X86) {
       pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
     }
     std::vector<common::CINNValue> res{
         common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty())
-          << "The input argument of argsort_schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
...
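The first step of every rewritten schedule lambda above (sort, argsort, and the others that follow) is the same: walk the argument pack and keep only the expression entries, which together form the module that the IRSchedule is built from. The hedged sketch below mirrors that filtering step in standard C++; the real types (CINNValue, Expr, ir::IRSchedule) belong to CINN, and std::variant with strings stands in for them.

    // Sketch of collecting expression entries from an argument pack.
    #include <cassert>
    #include <string>
    #include <variant>
    #include <vector>

    using FakeExpr = std::string;                  // stands in for Expr
    struct FakeTensor { std::string id; };         // stands in for a tensor value
    using PackValue = std::variant<FakeExpr, FakeTensor>;

    std::vector<FakeExpr> CollectExprs(const std::vector<PackValue>& arg_pack) {
      std::vector<FakeExpr> vec_ast;
      for (const auto& v : arg_pack) {                   // for (i = 0; i < size; i++)
        if (std::holds_alternative<FakeExpr>(v)) {       // arg_pack[i].is_expr()
          vec_ast.emplace_back(std::get<FakeExpr>(v));   // vec_ast.emplace_back(temp)
        }
      }
      return vec_ast;
    }

    int main() {
      std::vector<PackValue> pack{FakeExpr{"sort_body"}, FakeTensor{"Sort_out"}};
      auto vec_ast = CollectExprs(pack);
      assert(!vec_ast.empty());  // CHECK(!vec_ast.empty()) analogue
      return 0;
    }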
paddle/cinn/hlir/op/elementwise.cc

@@ -67,12 +67,9 @@ std::shared_ptr<OpStrategy> StrategyForElementwise(
     CINNValuePack pack_args = args[0];
     CHECK_GE(pack_args.size(), 1U)
         << "1 input tensor for " << op_name << " compute";
-    std::string tensor_name = UniqName(op_name + "_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A_expr = pack_args[0];
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();
...
@@ -158,12 +155,9 @@ std::shared_ptr<OpStrategy> StrategyForScale(
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();
     ir::Tensor out;
-    std::string tensor_name = UniqName("Scale_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     if (bias_after_scale) {
       out = Compute(
...
@@ -242,12 +236,9 @@ std::shared_ptr<OpStrategy> StrategyForConstScalar(
     auto scalar = GetScalarExpr(attrs.attr_store.at("value"));
     auto scalar_type = out_type.at(0);
     CINNValuePack pack_args = args[0];
-    std::string tensor_name = UniqName("const_scalar_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 1U);
-      CHECK(pack_args[0].is_string());
-      tensor_name = pack_args[0].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 1U);
+    CHECK(pack_args[0].is_string());
+    std::string tensor_name = pack_args[0].operator std::string();
     auto out = lang::Compute(
         {Expr(1)},
...
@@ -371,12 +362,9 @@ std::shared_ptr<OpStrategy> StrategyForFillConstant(
     }
     CINNValuePack arg_pack = args[0];
-    std::string tensor_name = UniqName("fill_constant_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 1U);
-      CHECK(arg_pack[0].is_string());
-      tensor_name = arg_pack[0].operator std::string();
-    }
+    CHECK_EQ(arg_pack.size(), 1U);
+    CHECK(arg_pack[0].is_string());
+    std::string tensor_name = arg_pack[0].operator std::string();
     CHECK(!shape.empty()) << "shape attr is empty!";
     auto shape_exprs = ToCinnExprs(shape);
     auto out = lang::Compute(
...
@@ -458,12 +446,9 @@ std::shared_ptr<OpStrategy> StrategyForAssignValue(
     const auto &value = attrs.attr_store.at("values");
     CINNValuePack arg_pack = args[0];
-    std::string tensor_name = UniqName("T_assign_value_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 1U);
-      CHECK(arg_pack[0].is_string());
-      tensor_name = arg_pack[0].operator std::string();
-    }
+    CHECK_EQ(arg_pack.size(), 1U);
+    CHECK(arg_pack[0].is_string());
+    std::string tensor_name = arg_pack[0].operator std::string();
     absl::optional<ir::Tensor> out;
 #define EXPAND_VALUE_TO_TENSOR(TYPE) \
...
@@ -649,11 +634,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForSqueeze(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Squeeze_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Squeeze(tensor_A, axes, tensor_name);
     std::vector<CINNValue> res;
...
@@ -729,12 +711,9 @@ std::shared_ptr<OpStrategy> StrategyForExpandDims(
     Expr x = input_args[0];
     CHECK(x.as_tensor());
-    std::string tensor_name = UniqName("expand_dims_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(input_args.size(), 2U);
-      CHECK(input_args[1].is_string());
-      tensor_name = input_args[1].operator std::string();
-    }
+    CHECK_EQ(input_args.size(), 2U);
+    CHECK(input_args[1].is_string());
+    std::string tensor_name = input_args[1].operator std::string();
     auto out = pe::ExpandDims(
         x.as_tensor_ref(), axes, output_shapes[0], tensor_name);
...
@@ -809,12 +788,9 @@ std::shared_ptr<OpStrategy> StrategyForReshape(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Reshape_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Reshape(tensor_A, output_shapes[0], tensor_name);
     std::vector<CINNValue> res;
...
@@ -901,11 +877,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForCast(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Cast_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Cast(tensor_A, out_type[0], tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);
...
@@ -953,11 +926,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForArange(
         << "The input argument of arange compute is empty! Please check.\n";
     CINNValuePack pack_args = args[0];
-    std::string tensor_name = common::UniqName("T_Arange_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 1U);
-      tensor_name = pack_args[0].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 1U);
+    std::string tensor_name = pack_args[0].operator std::string();
     auto out = pe::Arange(start, stop, step, dtype, tensor_name);
     std::vector<common::CINNValue> res;
...
paddle/cinn/hlir/op/nn.cc (+268, -647)
(This diff is collapsed in the source page and is not shown here.)
paddle/cinn/hlir/op/op_broadcast_test.cc

@@ -59,37 +59,19 @@ TEST(Operator, Operator_ElementWise_Add_Test0) {
   std::string func_name = "add1";
   Module::Builder builder("module0", target);
-  if (FLAGS_cinn_ir_schedule) {
   std::string out_name = "C";
   common::CINNValuePack cinn_input =
       common::CINNValuePack{{common::CINNValue(A),
                              common::CINNValue(B),
                              common::CINNValue(out_name)}};
   std::vector<std::string> input_output_names{"A", "B", out_name};
   auto funcs = framework::GetFuncFromImpl(
       impl, cinn_input, inputs, input_output_names, func_name, target);
   for (auto func : funcs) {
     LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n"
               << func;
     builder.AddFunction(func);
   }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    ASSERT_EQ(rets.size(), 2UL);
-    rets = impl->fschedule(rets);
-    ASSERT_EQ(rets.size(), 2UL);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      inputs.push_back(temp.as_tensor_ref());
-    }
-    auto func = Lower("fn_" + func_name, rets.back(), inputs);
-    LOG(INFO) << "Test Strategy Codegen:\n" << func;
-  }
...
@@ -160,37 +142,20 @@ TEST(Operator, Operator_ElementWise_Add_Test1) {
   std::string func_name = "add2";
   Module::Builder builder("module", target);
-  if (FLAGS_cinn_ir_schedule) {
   std::string out_name = "C";
   common::CINNValuePack cinn_input =
       common::CINNValuePack{{common::CINNValue(A),
                              common::CINNValue(B),
                              common::CINNValue(out_name)}};
   std::vector<std::string> input_output_names{"A", "B", out_name};
   auto funcs = framework::GetFuncFromImpl(
       impl, cinn_input, inputs, input_output_names, func_name, target);
   for (auto func : funcs) {
     builder.AddFunction(func);
     LOG(INFO) << "Test Operator_ElementWise_Add_Test1's Strategy, func is :\n"
               << func;
   }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    ASSERT_EQ(rets.size(), 2UL);
-    rets = impl->fschedule(rets);
-    ASSERT_EQ(rets.size(), 2UL);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      inputs.push_back(temp.as_tensor_ref());
-    }
-    auto func = Lower("fn_" + func_name, rets.back(), inputs);
-    LOG(INFO) << "Test Strategy Codegen:\n" << func;
-  }
   backends::CodeGenCUDA_Dev codegen(target);
...
@@ -225,33 +190,15 @@ TEST(Operator, Operator_BroadcastTo) {
   std::string func_name = "broadcast_to";
-  if (FLAGS_cinn_ir_schedule) {
   std::string out_name = "C";
   common::CINNValuePack cinn_input = common::CINNValuePack{
       {common::CINNValue(B), common::CINNValue(out_name)}};
   std::vector<std::string> input_output_names{"B", out_name};
   auto funcs = framework::GetFuncFromImpl(
       impl, cinn_input, inputs, input_output_names, func_name, target);
   for (auto func : funcs) {
     LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
   }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(B)}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    ASSERT_EQ(rets.size(), 2UL);
-    rets = impl->fschedule(rets);
-    ASSERT_EQ(rets.size(), 2UL);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      inputs.push_back(temp.as_tensor_ref());
-    }
-    auto func = Lower("func" + func_name, rets.back(), inputs);
-    for (auto func : funcs) {
-      LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
-    }
-  }
...
@@ -260,9 +207,7 @@ common::CINNValuePack GetComputeResult(
     const std::shared_ptr<OpImpl> &impl,
     std::vector<common::CINNValue> &cinn_inputs,  // NOLINT
     const std::string &output_name = "") {
-  if (FLAGS_cinn_ir_schedule) {
-    cinn_inputs.emplace_back(output_name);
-  }
+  cinn_inputs.emplace_back(output_name);
   return impl->fcompute(common::CINNValuePack{cinn_inputs});
 }
...
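After this change, the GetComputeResult test helper always appends the output name to the packed compute inputs rather than only when the IR-schedule flag was on. A minimal sketch of that simplified helper shape, with illustrative names in standard C++ (the real helper forwards a CINNValuePack to impl->fcompute):

    // Sketch of the simplified test helper: output name is always appended.
    #include <cassert>
    #include <string>
    #include <vector>

    std::vector<std::string> BuildComputeInputs(std::vector<std::string> cinn_inputs,
                                                const std::string& output_name = "") {
      cinn_inputs.emplace_back(output_name);  // unconditional after this commit
      return cinn_inputs;
    }

    int main() {
      auto inputs = BuildComputeInputs({"A", "B"}, "C");
      assert(inputs.size() == 3 && inputs.back() == "C");
      return 0;
    }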
paddle/cinn/hlir/op/op_util.cc

@@ -21,8 +21,6 @@
 #include "paddle/cinn/hlir/pe/schedule.h"
 #include "paddle/cinn/ir/ir_schedule.h"
-DECLARE_bool(cinn_ir_schedule);
-
 namespace cinn {
 namespace hlir {
...
@@ -31,44 +29,24 @@ CINNSchedule GetElementwiseScheduleFunc(
     const Target &target,
     bool vectorizable) {
   return CINNSchedule([=](lang::Args args, lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
     CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
                             "empty! Please check.\n";
     common::CINNValuePack arg_pack = args[0];
     std::vector<Expr> vec_ast;
     for (int i = 0; i < arg_pack.size(); i++) {
       if (arg_pack[i].is_expr()) {
         Expr temp = arg_pack[i];
         vec_ast.emplace_back(temp);
       }
     }
     CHECK(!vec_ast.empty());
     ir::ModuleExpr mod_expr(vec_ast);
     ir::IRSchedule ir_sch(mod_expr);
     ir_sch.MergeExprs();
     pe::IRElementwiseSchedule(ir_sch, output_shapes.front(), target);
     std::vector<common::CINNValue> res{
         common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
-                              "empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      poly::StageMap stages = arg_pack[1];
-      CHECK(out.as_tensor());
-      CHECK_EQ(arg_pack.size(), 2UL);
-      if (target.arch == Target::Arch::NVGPU) {
-        pe::CudaScheduleInjective(
-            stages[out.as_tensor_ref()], output_shapes.front(), target);
-      } else if (target.arch == Target::Arch::X86) {
-        pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()],
-                                 output_shapes.front(),
-                                 target,
-                                 vectorizable);
-      }
-      *ret = arg_pack;
-    }
   });
 }
...
@@ -77,50 +55,30 @@ CINNSchedule GetInjectiveScheduleFunc(
     const Target &target,
     bool vectorizable) {
   return CINNSchedule([=](lang::Args args, lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
     CHECK(!args.empty()) << "The input argument of InjectiveSchedule is "
                             "empty! Please check.\n";
     common::CINNValuePack arg_pack = args[0];
     std::vector<Expr> vec_ast;
     for (int i = 0; i < arg_pack.size(); i++) {
       if (arg_pack[i].is_expr()) {
         Expr temp = arg_pack[i];
         vec_ast.emplace_back(temp);
       }
     }
     CHECK(!vec_ast.empty());
     ir::ModuleExpr mod_expr(vec_ast);
     ir::IRSchedule ir_sch(mod_expr);
     ir_sch.MergeExprs();
     pe::IRInjectiveSchedule(ir_sch, output_shapes.front(), target);
     /*if (target.arch == Target::Arch::NVGPU) {
       pe::IRInjectiveSchedule(ir_sch, output_shapes.front(), target);
     } else if (target.arch == Target::Arch::X86) {
       pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target,
     vectorizable);
     }*/
     std::vector<common::CINNValue> res{
         common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
     *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of InjectiveSchedule is "
-                              "empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      poly::StageMap stages = arg_pack[1];
-      CHECK(out.as_tensor());
-      CHECK_EQ(arg_pack.size(), 2UL);
-      if (target.arch == Target::Arch::NVGPU) {
-        pe::CudaScheduleInjective(
-            stages[out.as_tensor_ref()], output_shapes.front(), target);
-      } else if (target.arch == Target::Arch::X86) {
-        pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()],
-                                 output_shapes.front(),
-                                 target,
-                                 vectorizable);
-      }
-      *ret = arg_pack;
-    }
   });
 }
...
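Note that the rewritten GetInjectiveScheduleFunc now calls pe::IRInjectiveSchedule for every target and leaves the per-architecture dispatch only as a comment. The hedged sketch below shows the dispatch shape that comment preserves, using placeholder schedule functions rather than CINN's pe:: entry points.

    // Sketch of a per-architecture schedule dispatch (placeholder functions).
    #include <iostream>

    enum class Arch { NVGPU, X86 };

    void ScheduleInjectiveGPU() { std::cout << "GPU injective schedule\n"; }
    void ScheduleInjectiveCPU(bool vectorizable) {
      std::cout << "CPU injective schedule, vectorizable=" << vectorizable << "\n";
    }

    // Mirrors the commented-out branch: pick a schedule per target architecture.
    void ApplyInjectiveSchedule(Arch arch, bool vectorizable) {
      switch (arch) {
        case Arch::NVGPU: ScheduleInjectiveGPU(); break;
        case Arch::X86:   ScheduleInjectiveCPU(vectorizable); break;
      }
    }

    int main() {
      ApplyInjectiveSchedule(Arch::NVGPU, /*vectorizable=*/true);
      ApplyInjectiveSchedule(Arch::X86, /*vectorizable=*/true);
      return 0;
    }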
paddle/cinn/hlir/op/reduction.cc

@@ -29,8 +29,6 @@
 #include "paddle/cinn/ir/ir_schedule.h"
 #include "paddle/cinn/optim/ir_simplify.h"
-DECLARE_bool(cinn_ir_schedule);
-
 namespace cinn {
 namespace hlir {
 namespace op {
...
@@ -115,16 +113,10 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
     CHECK(!args.empty()) << "The input argument of " << op_name
                          << " compute is empty! Please check.";
     CINNValuePack arg_packs = args[0];
-    std::string tensor_name = UniqName(op_name + "_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_packs.size(), 2U)
-          << "There should be 2 input args for " << op_name << " compute";
-      CHECK(arg_packs[1].is_string());
-      tensor_name = arg_packs[1].operator std::string();
-    } else {
-      CHECK_EQ(arg_packs.size(), 1U)
-          << "There should be 1 input args for " << op_name << " compute";
-    }
+    CHECK_EQ(arg_packs.size(), 2U)
+        << "There should be 2 input args for " << op_name << " compute";
+    CHECK(arg_packs[1].is_string());
+    std::string tensor_name = arg_packs[1].operator std::string();
     Expr x_expr = arg_packs[0];
     CHECK(x_expr.as_tensor());
     ir::Tensor x = x_expr.as_tensor_ref();
...
@@ -175,206 +167,137 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
                                              lang::RetValue *ret) {
     CHECK(!args.empty()) << "The input argument of " << op_name
                          << " schedule is empty! Please check.";
     CINNValuePack arg_pack = args[0];
-    if (FLAGS_cinn_ir_schedule) {
     CHECK_GE(arg_pack.size(), 2UL);
     CHECK_LE(arg_pack.size(), 8UL);
     std::vector<Expr> vec_ast;
     std::vector<Expr> vec_tensor;
     for (int i = 0; i < arg_pack.size(); i++) {
       if (arg_pack[i].is_expr()) {
         Expr temp = arg_pack[i];
         // TODO(zhhsplendid): old reducetion schedule assumes all length-1
         // for loops are simplified, but it is not after we add length-1
         // back. Reduction schedule is complex and we haven't changed it to
         // support the length-1 for loop yet. So we simplify here. The todo
         // is that remove SimplifyForLoops below and change reduction schedule
         optim::SimplifyForLoops(&temp);
         vec_ast.emplace_back(temp);
       } else if (arg_pack[i].is_tensor()) {
         Expr temp = arg_pack[i];
         vec_tensor.emplace_back(temp);
       }
     }
     CHECK(!vec_ast.empty());
     ir::ModuleExpr mod_expr(vec_ast);
     ir::IRSchedule ir_sch(mod_expr);
     ir_sch.MergeExprs();
     if (target.arch == Target::Arch::NVGPU) {
       if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) {
         if (arg_pack.size() == 4) {
           CHECK_EQ(vec_tensor.size(), 2);
           Expr out = vec_tensor[0];
           Expr tmp_out = vec_tensor[1];
           VLOG(3) << "Do IRCudaScheduleBlockReduceInternal Schedule!";
           pe::IRCudaScheduleBlockReduceInternal(
               ir_sch, tmp_out.as_tensor_ref(), out.as_tensor_ref(), target);
           std::vector<CINNValue> res{
               CINNValue(ir_sch.GetModule().GetExprs().at(0))};
           *ret = CINNValuePack{res};
         } else if (arg_pack.size() == 6) {
           CHECK_EQ(vec_tensor.size(), 3);
           Expr out = vec_tensor[0];
           Expr tmp_out = vec_tensor[1];
           Expr reduce_tmp_out = vec_tensor[2];
           VLOG(3) << "Do IRCudaScheduleBlockReduce Schedule!";
           pe::IRCudaScheduleBlockReduce(ir_sch,
                                         reduce_tmp_out.as_tensor_ref(),
                                         tmp_out.as_tensor_ref(),
                                         out.as_tensor_ref(),
                                         target);
           std::vector<CINNValue> res{
               CINNValue(ir_sch.GetModule().GetExprs().at(0))};
           *ret = CINNValuePack{res};
         } else if (arg_pack.size() == 7) {
           CHECK_EQ(vec_tensor.size(), 4);
           Expr out = vec_tensor[0];
           Expr tmp_out = vec_tensor[1];
           Expr reduce_tmp_out = vec_tensor[2];
           Expr reshape = vec_tensor[3];
           VLOG(3) << "Do IRCudaTwoStepReduceSchedule Schedule!";
           pe::IRCudaTwoStepReduceSchedule(ir_sch,
                                           reshape.as_tensor_ref(),
                                           reduce_tmp_out.as_tensor_ref(),
                                           tmp_out.as_tensor_ref(),
                                           out.as_tensor_ref(),
                                           common::DefaultNVGPUTarget());
           std::vector<CINNValue> res{
               CINNValue(ir_sch.GetModule().GetExprs().at(0))};
           *ret = CINNValuePack{res};
         } else if (arg_pack.size() == 5) {
           CHECK_EQ(vec_tensor.size(), 3);
           Expr out = vec_tensor[0];
           Expr tmp_out = vec_tensor[1];
           Expr reduce_tmp_out = vec_tensor[2];
           VLOG(3) << "Do IRCudaScheduleBlockReduce Schedule!";
           pe::IRCudaScheduleBlockReduce(ir_sch,
                                         reduce_tmp_out.as_tensor_ref(),
                                         tmp_out.as_tensor_ref(),
                                         out.as_tensor_ref(),
                                         common::DefaultNVGPUTarget());
           std::vector<CINNValue> res{
               CINNValue(ir_sch.GetModule().GetExprs().at(0))};
           *ret = CINNValuePack{res};
         } else {
           LOG(FATAL) << "Unkown Reduce Type!";
         }
       } else {
         if (arg_pack.size() == 2) {
           CHECK_EQ(vec_tensor.size(), 1);
           Expr reduce_out = vec_tensor[0];
           VLOG(3) << "Do IRCudaScheduleReduce Schedule!";
           pe::IRCudaScheduleReduce(
               ir_sch,
               reduce_out.as_tensor_ref(),
               inputs[0]->shape.size() - reduce_axes.back() - 1,
               target);
           std::vector<CINNValue> res{
               CINNValue(ir_sch.GetModule().GetExprs().at(0))};
           *ret = CINNValuePack{res};
         } else if (arg_pack.size() == 6) {
           CHECK_EQ(vec_tensor.size(), 3);
           Expr reduce_out = vec_tensor[0];
           Expr reduce_internal = vec_tensor[1];
           Expr reduce_reshape = vec_tensor[2];
           VLOG(3) << "Do IRCudaScheduleBlockShuffleReduce Schedule!";
           pe::IRCudaScheduleBlockShuffleReduce(ir_sch,
                                                reduce_reshape.as_tensor_ref(),
                                                reduce_internal.as_tensor_ref(),
                                                reduce_out.as_tensor_ref(),
                                                target);
           std::vector<CINNValue> res{
               CINNValue(ir_sch.GetModule().GetExprs().at(0))};
           *ret = CINNValuePack{res};
         } else {
           LOG(FATAL) << "Unkown Reduce Type!";
         }
       }
     } else {
       std::vector<CINNValue> res{
           CINNValue(ir_sch.GetModule().GetExprs().at(0))};
       *ret = CINNValuePack{res};
     }
-    } else {
-      CHECK_GE(arg_pack.size(), 2UL);
-      CHECK_LE(arg_pack.size(), 5UL);
-      if (target.arch == Target::Arch::NVGPU) {
-        if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) {
-          if (arg_pack.size() == 3) {
-            Expr out = arg_pack[0];
-            Expr tmp_out = arg_pack[1];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaBlockReduceInternalSchedule Schedule!";
-            pe::CudaBlockReduceInternalSchedule(stages,
-                                                tmp_out.as_tensor_ref(),
-                                                out.as_tensor_ref(),
-                                                common::DefaultNVGPUTarget());
-          } else if (arg_pack.size() == 4) {
-            Expr out = arg_pack[0];
-            Expr tmp_out = arg_pack[1];
-            Expr reduce_tmp_out = arg_pack[2];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaBlockReduceSchedule Schedule!";
-            pe::CudaBlockReduceSchedule(stages,
-                ...
-          } else {
-            Expr out = arg_pack[0];
-            Expr tmp_out = arg_pack[1];
-            Expr reduce_tmp_out = arg_pack[2];
-            Expr reshape = arg_pack[3];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaTwoStepReduceSchedule Schedule!";
-            pe::CudaTwoStepReduceSchedule(stages,
-                                          reshape.as_tensor_ref(),
-                                          reduce_tmp_out.as_tensor_ref(),
-                                          tmp_out.as_tensor_ref(),
-                                          out.as_tensor_ref(),
-                                          common::DefaultNVGPUTarget());
-          }
-        } else {
-          if (arg_pack.size() == 2) {
-            Expr reduce_out = arg_pack[0];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaReduceSchedule Schedule!";
-            pe::CudaReduceSchedule(
-                stages,
-                reduce_out.as_tensor_ref(),
-                inputs[0]->shape.size() - reduce_axes.back() - 1,
-                target);
-          } else {
-            CHECK_EQ(arg_pack.size(), 4) << "args is not equal 4!";
-            Expr reduce_reshape = arg_pack[2];
-            Expr reduce_internal = arg_pack[1];
-            Expr reduce_out = arg_pack[0];
-            poly::StageMap stages = arg_pack.back();
-            VLOG(3) << "Do CudaBlockShuffleReduceSchedule Schedule!";
-            pe::CudaBlockShuffleReduceSchedule(stages,
-                ...
-          }
-        }
-      }
-      *ret = arg_pack;
-    }
   });
...
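In the rewritten reduction schedule above, the choice of CUDA schedule is keyed on how many values the argument pack carries (lowered ASTs plus intermediate tensors). The hedged sketch below mirrors only the branches visible in this diff for the "reduce includes the last dimension" case; the real code calls pe::IRCudaSchedule* routines rather than returning labels, and the exact sizes are taken from my reading of the hunk.

    // Sketch of the size-keyed dispatch in the rewritten reduce schedule.
    #include <cassert>
    #include <cstddef>
    #include <stdexcept>
    #include <string>

    std::string PickReduceSchedule(std::size_t arg_pack_size) {
      switch (arg_pack_size) {
        case 4: return "IRCudaScheduleBlockReduceInternal";  // out + tmp_out
        case 5:                                              // out + tmp_out + reduce_tmp_out
        case 6: return "IRCudaScheduleBlockReduce";
        case 7: return "IRCudaTwoStepReduceSchedule";        // ... + reshape
        default: throw std::runtime_error("Unkown Reduce Type!");  // LOG(FATAL) analogue
      }
    }

    int main() {
      assert(PickReduceSchedule(4) == "IRCudaScheduleBlockReduceInternal");
      assert(PickReduceSchedule(7) == "IRCudaTwoStepReduceSchedule");
      return 0;
    }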
paddle/cinn/hlir/op/transform.cc
浏览文件 @
70183c4b
...
...
@@ -73,12 +73,9 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(
CHECK
(
A
.
as_tensor
());
CHECK
(
B
.
as_tensor
());
std
::
string
tensor_name
=
UniqName
(
"MatMul"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK_GE
(
pack_args
.
size
(),
3
);
CHECK
(
pack_args
[
2
].
is_string
());
tensor_name
=
pack_args
[
2
].
operator
std
::
string
();
}
CHECK_GE
(
pack_args
.
size
(),
3
);
CHECK
(
pack_args
[
2
].
is_string
());
std
::
string
tensor_name
=
pack_args
[
2
].
operator
std
::
string
();
auto
tensor_A
=
A
.
as_tensor_ref
();
auto
tensor_B
=
B
.
as_tensor_ref
();
...
...
@@ -130,32 +127,9 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(
CHECK
(
!
args
.
empty
())
<<
"The input argument of matmul schedule is empty! Please check.
\n
"
;
CINNValuePack
arg_pack
=
args
[
0
];
if
(
FLAGS_cinn_ir_schedule
)
{
std
::
vector
<
CINNValue
>
results
=
pe
::
IRCudaScheduleMatMul
(
arg_pack
,
output_shape
,
target
);
*
ret
=
CINNValuePack
({
results
});
}
else
{
CHECK
(
arg_pack
.
size
()
==
2UL
||
arg_pack
.
size
()
==
3UL
);
poly
::
StageMap
stages
=
arg_pack
.
back
();
if
(
target
.
arch
==
Target
::
Arch
::
NVGPU
)
{
Expr
out
=
arg_pack
[
0
];
CHECK
(
out
.
as_tensor
());
pe
::
MatmulScheduleCUDA
(
stages
,
out
.
as_tensor_ref
(),
target
);
}
else
if
(
target
.
arch
==
Target
::
Arch
::
X86
)
{
#ifdef CINN_WITH_MKL_CBLAS
CHECK_EQ
(
arg_pack
.
size
(),
3UL
);
#else
CHECK_EQ
(
arg_pack
.
size
(),
3UL
);
Expr
out
=
arg_pack
[
0
];
Expr
packedB
=
arg_pack
[
1
];
CHECK
(
packedB
.
as_tensor
());
CHECK
(
out
.
as_tensor
());
pe
::
MatmulScheduleCPU
(
stages
,
out
.
as_tensor_ref
(),
packedB
.
as_tensor_ref
(),
target
);
#endif
}
*
ret
=
arg_pack
;
}
std
::
vector
<
CINNValue
>
results
=
pe
::
IRCudaScheduleMatMul
(
arg_pack
,
output_shape
,
target
);
*
ret
=
CINNValuePack
({
results
});
});
auto
strategy
=
std
::
make_shared
<
framework
::
OpStrategy
>
();
...
...
@@ -262,16 +236,10 @@ std::shared_ptr<OpStrategy> StrategyForSplit(
ir
::
Tensor
A
=
A_expr
.
as_tensor_ref
();
std
::
vector
<
std
::
string
>
tensor_names
;
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK_EQ
(
pack_args
.
size
(),
output_shapes
.
size
()
+
1
);
for
(
int
idx
=
1
;
idx
<
pack_args
.
size
();
++
idx
)
{
CHECK
(
pack_args
[
idx
].
is_string
());
tensor_names
.
push_back
(
pack_args
[
idx
].
operator
std
::
string
());
}
}
else
{
for
(
int
idx
=
0
;
idx
<
output_shapes
.
size
();
++
idx
)
{
tensor_names
.
push_back
(
UniqName
(
"T_Split_Out"
));
}
CHECK_EQ
(
pack_args
.
size
(),
output_shapes
.
size
()
+
1
);
for
(
int
idx
=
1
;
idx
<
pack_args
.
size
();
++
idx
)
{
CHECK
(
pack_args
[
idx
].
is_string
());
tensor_names
.
push_back
(
pack_args
[
idx
].
operator
std
::
string
());
}
auto
out
=
pe
::
Split
(
A
,
axis
,
output_shapes
,
tensor_names
);
...
...
@@ -285,38 +253,27 @@ std::shared_ptr<OpStrategy> StrategyForSplit(
*
ret
=
CINNValuePack
{
res
};
});
framework
::
CINNSchedule
split_schedule
(
[
=
](
lang
::
Args
args
,
lang
::
RetValue
*
ret
)
{
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK
(
!
args
.
empty
())
<<
"The input argument of split schedule is empty! Please check."
;
CINNValuePack
arg_pack
=
args
[
0
]
;
std
::
vector
<
Expr
>
vec_ast
;
for
(
int
i
=
0
;
i
<
arg_pack
.
size
();
i
++
)
{
if
(
arg_pack
[
i
].
is_expr
())
{
Expr
temp
=
arg_pack
[
i
]
;
vec_ast
.
emplace_back
(
temp
);
framework
::
CINNSchedule
split_schedule
(
[
=
](
lang
::
Args
args
,
lang
::
RetValue
*
ret
)
{
CHECK
(
!
args
.
empty
())
<<
"The input argument of split schedule is empty! Please check."
;
CINNValuePack
arg_pack
=
args
[
0
]
;
std
::
vector
<
Expr
>
vec_ast
;
for
(
int
i
=
0
;
i
<
arg_pack
.
size
();
i
++
)
{
if
(
arg_pack
[
i
].
is_expr
()
)
{
Expr
temp
=
arg_pack
[
i
];
vec_ast
.
emplace_back
(
temp
)
;
}
}
}
CHECK
(
!
vec_ast
.
empty
());
ir
::
ModuleExpr
mod_expr
(
vec_ast
);
ir
::
IRSchedule
ir_sch
(
mod_expr
);
ir_sch
.
MergeExprs
();
pe
::
IRCudaSplitSchedule
(
ir_sch
,
output_shapes
,
axis
,
target
);
std
::
vector
<
CINNValue
>
res
{
CINNValue
(
ir_sch
.
GetModule
().
GetExprs
().
at
(
0
))};
*
ret
=
CINNValuePack
{
res
};
}
else
{
CHECK
(
!
args
.
empty
())
<<
"The input arguments of split schedule is empty! Please check."
;
CINNValuePack
arg_pack
=
args
[
0
];
CHECK_GE
(
arg_pack
.
size
(),
2UL
)
<<
"The input tensor's size of split schedule is "
<<
arg_pack
.
size
()
<<
"and it should be greater equal to 2! Please check."
;
pe
::
CudaSplitSchedule
(
&
arg_pack
,
output_shapes
,
axis
,
target
);
*
ret
=
arg_pack
;
}
});
CHECK
(
!
vec_ast
.
empty
());
ir
::
ModuleExpr
mod_expr
(
vec_ast
);
ir
::
IRSchedule
ir_sch
(
mod_expr
);
ir_sch
.
MergeExprs
();
pe
::
IRCudaSplitSchedule
(
ir_sch
,
output_shapes
,
axis
,
target
);
std
::
vector
<
CINNValue
>
res
{
CINNValue
(
ir_sch
.
GetModule
().
GetExprs
().
at
(
0
))};
*
ret
=
CINNValuePack
{
res
};
});
auto
strategy
=
std
::
make_shared
<
framework
::
OpStrategy
>
();
strategy
->
AddImpl
(
split_compute
,
split_schedule
,
"strategy.split.x86"
,
1
);
...
...
@@ -468,8 +425,7 @@ std::shared_ptr<OpStrategy> StrategyForConcat(
CHECK
(
!
out_type
.
empty
())
<<
"Output type of Concat is empty! Please check.
\n
"
;
CINNValuePack
pack_args
=
args
[
0
];
int
input_size
=
FLAGS_cinn_ir_schedule
?
pack_args
.
size
()
-
1
:
pack_args
.
size
();
int
input_size
=
pack_args
.
size
()
-
1
;
CHECK_GE
(
input_size
,
1UL
)
<<
"at least 2 input tensors for Concat compute
\n
"
;
CHECK
(
!
output_shapes
.
empty
());
...
...
@@ -485,11 +441,8 @@ std::shared_ptr<OpStrategy> StrategyForConcat(
input_tensors
.
push_back
(
tensor
.
as_tensor_ref
());
}
std
::
string
tensor_name
=
UniqName
(
"Concat_output"
);
if
(
FLAGS_cinn_ir_schedule
)
{
CHECK
(
pack_args
[
input_size
].
is_string
());
tensor_name
=
pack_args
[
input_size
].
operator
std
::
string
();
}
CHECK
(
pack_args
[
input_size
].
is_string
());
std
::
string
tensor_name
=
pack_args
[
input_size
].
operator
std
::
string
();
auto
stages
=
CreateStages
(
input_tensors
);
auto
out
=
pe
::
Concat
(
input_tensors
,
axis
,
tensor_name
);
...
...
@@ -612,11 +565,8 @@ std::shared_ptr<OpStrategy> StrategyForMul(
     auto new_B = B_tensor->Reshape(new_shape_B_e, stages);
     std::vector<ir::Tensor> out;
-    std::string tensor_name = UniqName("Mul_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK(pack_args.back().is_string());
-      tensor_name = pack_args.back().operator std::string();
-    }
+    CHECK(pack_args.back().is_string());
+    std::string tensor_name = pack_args.back().operator std::string();
     if (target.arch == Target::Arch::X86) {
 #ifdef CINN_WITH_MKL_CBLAS
...
...
@@ -647,32 +597,9 @@ std::shared_ptr<OpStrategy> StrategyForMul(
     CHECK(!args.empty())
         << "The input argument of matmul schedule is empty! Please check.\n";
     CINNValuePack arg_pack = args[0];
-    if (FLAGS_cinn_ir_schedule) {
-      std::vector<CINNValue> results =
-          pe::IRCudaScheduleMatMul(arg_pack, output_shape, target);
-      *ret = CINNValuePack({results});
-    } else {
-      CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL);
-      poly::StageMap stages = arg_pack.back();
-      if (target.arch == Target::Arch::NVGPU) {
-        Expr out = arg_pack[0];
-        CHECK(out.as_tensor());
-        pe::MatmulScheduleCUDA(stages, out.as_tensor_ref(), target);
-      } else if (target.arch == Target::Arch::X86) {
-#ifdef CINN_WITH_MKL_CBLAS
-        CHECK_EQ(arg_pack.size(), 3UL);
-#else
-        CHECK_EQ(arg_pack.size(), 3UL);
-        Expr out = arg_pack[0];
-        Expr packedB = arg_pack[1];
-        CHECK(packedB.as_tensor());
-        CHECK(out.as_tensor());
-        pe::MatmulScheduleCPU(
-            stages, out.as_tensor_ref(), packedB.as_tensor_ref(), target);
-#endif
-      }
-      *ret = arg_pack;
-    }
+    std::vector<CINNValue> results =
+        pe::IRCudaScheduleMatMul(arg_pack, output_shape, target);
+    *ret = CINNValuePack({results});
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
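With the legacy poly::StageMap paths (MatmulScheduleCUDA, MatmulScheduleCPU) gone, the matmul strategy registers exactly one schedule that always goes through pe::IRCudaScheduleMatMul. A rough sketch of that single-implementation registration shape, using hypothetical stand-in types rather than CINN's real OpStrategy:

#include <functional>
#include <iostream>
#include <map>
#include <string>

using Schedule = std::function<std::string(const std::string&)>;

// Toy strategy holder: one named schedule implementation per op.
struct OpStrategy {
  std::map<std::string, Schedule> impls;
  void AddImpl(const std::string& name, Schedule s) { impls[name] = std::move(s); }
};

int main() {
  OpStrategy strategy;
  strategy.AddImpl("strategy.mul.x86", [](const std::string& op) {
    return "IRCudaScheduleMatMul(" + op + ")";  // the only path left
  });
  std::cout << strategy.impls["strategy.mul.x86"]("mul") << "\n";
  return 0;
}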
...
...
@@ -780,12 +707,9 @@ std::shared_ptr<OpStrategy> StrategyForCublasGemm(
         // dummy gemm computation, which will be replaced by
         // cinn_gpu_cublas_gemm in the GemmRewriter pass.
-        std::string tensor_name = UniqName("cublas_gemm_output");
-        if (FLAGS_cinn_ir_schedule) {
-          CHECK_EQ(input_args.size(), 4);
-          CHECK(input_args[3].is_string());
-          tensor_name = input_args[3].operator std::string();
-        }
+        CHECK_EQ(input_args.size(), 4);
+        CHECK(input_args[3].is_string());
+        std::string tensor_name = input_args[3].operator std::string();
         auto out = pe::Identity(bias_tensor, tensor_name).front();
         auto stages = CreateStages(
             {lhs.as_tensor_ref(), rhs.as_tensor_ref(), bias_tensor});
...
...
@@ -849,12 +773,9 @@ std::shared_ptr<OpStrategy> StrategyForLayoutTransform(
     Expr A = input_args[0];
     CHECK(A.as_tensor());
-    std::string tensor_name = UniqName("layout_transform_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(input_args.size(), 2);
-      CHECK(input_args[1].is_string());
-      tensor_name = input_args[1].operator std::string();
-    }
+    CHECK_EQ(input_args.size(), 2);
+    CHECK(input_args[1].is_string());
+    std::string tensor_name = input_args[1].operator std::string();
     auto out = pe::LayoutTransform(
         A.as_tensor_ref(), src_layout, dst_layout, tensor_name);
...
...
@@ -865,53 +786,31 @@ std::shared_ptr<OpStrategy> StrategyForLayoutTransform(
     *ret = CINNValuePack{res};
   });
-  framework::CINNSchedule layout_transform_schedule(
-      [=](lang::Args args, lang::RetValue *ret) {
-        if (FLAGS_cinn_ir_schedule) {
-          CHECK(!args.empty()) << "The input argument of CublasGemm schedule "
-                                  "is empty! Please check.";
-          CINNValuePack arg_pack = args[0];
-          std::vector<Expr> vec_ast;
-          for (int i = 0; i < arg_pack.size(); i++) {
-            if (arg_pack[i].is_expr()) {
-              Expr temp = arg_pack[i];
-              vec_ast.emplace_back(temp);
-            }
-          }
-          CHECK(!vec_ast.empty());
-          ir::ModuleExpr mod_expr(vec_ast);
-          ir::IRSchedule ir_sch(mod_expr);
-          ir_sch.MergeExprs();
-          if (target.arch == Target::Arch::X86) {
-            pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target);
-          } else {
-            CINN_NOT_IMPLEMENTED
-          }
-          std::vector<CINNValue> res{
-              CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-          *ret = CINNValuePack{res};
-        } else {
-          CHECK(!args.empty()) << "The input argument of layout_transform "
-                                  "schedule is empty! Please check.\n";
-          CINNValuePack arg_pack = args[0];
-          CHECK_EQ(arg_pack.size(), 2UL);
-          Expr out = arg_pack[0];
-          poly::StageMap stages = arg_pack[1];
-          CHECK(out.as_tensor());
-          auto tensor_out = out.as_tensor_ref();
-          std::vector<int> out_shape;
-          for (auto shape : tensor_out->shape) {
-            out_shape.push_back(shape.as_int32());
-          }
-          if (target.arch == Target::Arch::X86) {
-            pe::ScheduleInjectiveCPU(stages[tensor_out], out_shape, target);
-          } else {
-            CINN_NOT_IMPLEMENTED
-          }
-          *ret = arg_pack;
-        }
-      });
+  framework::CINNSchedule layout_transform_schedule([=](lang::Args args,
+                                                        lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input argument of CublasGemm schedule "
+                            "is empty! Please check.";
+    CINNValuePack arg_pack = args[0];
+    std::vector<Expr> vec_ast;
+    for (int i = 0; i < arg_pack.size(); i++) {
+      if (arg_pack[i].is_expr()) {
+        Expr temp = arg_pack[i];
+        vec_ast.emplace_back(temp);
+      }
+    }
+    CHECK(!vec_ast.empty());
+    ir::ModuleExpr mod_expr(vec_ast);
+    ir::IRSchedule ir_sch(mod_expr);
+    ir_sch.MergeExprs();
+    if (target.arch == Target::Arch::X86) {
+      pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target);
+    } else {
+      CINN_NOT_IMPLEMENTED
+    }
+    std::vector<CINNValue> res{
+        CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+    *ret = CINNValuePack{res};
+  });
   auto strategy = std::make_shared<framework::OpStrategy>();
   CHECK(out_type.size())
...
...
@@ -996,12 +895,9 @@ std::shared_ptr<OpStrategy> StrategyForReverse(
     Expr A = input_args[0];
     CHECK(A.as_tensor());
-    std::string tensor_name = UniqName("Reverse_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(input_args.size(), 2);
-      CHECK(input_args[1].is_string());
-      tensor_name = input_args[1].operator std::string();
-    }
+    CHECK_EQ(input_args.size(), 2);
+    CHECK(input_args[1].is_string());
+    std::string tensor_name = input_args[1].operator std::string();
     auto out = pe::Reverse(A.as_tensor_ref(), axis, tensor_name);
     auto stages = CreateStages({A.as_tensor_ref(), out});
...
...
@@ -1113,12 +1009,9 @@ std::shared_ptr<OpStrategy> StrategyForTranspose(
         << "at least one input tensor for transpose compute\n";
     Expr A = input_args[0];
     CHECK(A.as_tensor());
-    std::string tensor_name = UniqName("Transpose_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(input_args.size(), 2);
-      CHECK(input_args[1].is_string());
-      tensor_name = input_args[1].operator std::string();
-    }
+    CHECK_EQ(input_args.size(), 2);
+    CHECK(input_args[1].is_string());
+    std::string tensor_name = input_args[1].operator std::string();
     auto out = pe::Transpose(A.as_tensor_ref(), axis, tensor_name);
     auto stages = CreateStages({out});
...
...
@@ -1236,12 +1129,9 @@ std::shared_ptr<OpStrategy> StrategyForGather(
     Expr index = input_args[1];
     CHECK(index.as_tensor());
-    std::string tensor_name = UniqName("gather_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(input_args.size(), 3U);
-      CHECK(input_args[2].is_string());
-      tensor_name = input_args[2].operator std::string();
-    }
+    CHECK_EQ(input_args.size(), 3U);
+    CHECK(input_args[2].is_string());
+    std::string tensor_name = input_args[2].operator std::string();
     auto out = pe::Gather(x.as_tensor_ref(),
                           index.as_tensor_ref(),
...
...
@@ -1335,12 +1225,9 @@ std::shared_ptr<OpStrategy> StrategyForScatterAssign(
     auto stages = CreateStages({tensor_input, tensor_updates, tensor_index});
-    std::string tensor_name = UniqName("scatter_assign_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 4U);
-      CHECK(arg_pack[3].is_string());
-      tensor_name = arg_pack[3].operator std::string();
-    }
+    CHECK_EQ(arg_pack.size(), 4U);
+    CHECK(arg_pack[3].is_string());
+    std::string tensor_name = arg_pack[3].operator std::string();
     auto out = pe::ScatterAssign(
         tensor_input, tensor_updates, tensor_index, target, axis, tensor_name);
...
...
@@ -1462,12 +1349,9 @@ std::shared_ptr<OpStrategy> StrategyForScatterAdd(
     auto stages = CreateStages({tensor_input, tensor_updates, tensor_index});
-    std::string tensor_name = UniqName("scatter_add_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 4U);
-      CHECK(arg_pack[3].is_string());
-      tensor_name = arg_pack[3].operator std::string();
-    }
+    CHECK_EQ(arg_pack.size(), 4U);
+    CHECK(arg_pack[3].is_string());
+    std::string tensor_name = arg_pack[3].operator std::string();
     auto out = pe::ScatterAdd(
         tensor_input, tensor_updates, tensor_index, target, axis, tensor_name);
...
...
@@ -1617,12 +1501,9 @@ std::shared_ptr<OpStrategy> StrategyForSlice(
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();
-    std::string tensor_name = UniqName("Slice_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 2U);
-      CHECK(arg_pack[1].is_string());
-      tensor_name = arg_pack[1].operator std::string();
-    }
+    CHECK_EQ(arg_pack.size(), 2U);
+    CHECK(arg_pack[1].is_string());
+    std::string tensor_name = arg_pack[1].operator std::string();
     auto out = pe::Slice(
         A, starts, axes, strides, decrease_axis, output_shape, tensor_name);
...
...
@@ -1854,12 +1735,9 @@ std::shared_ptr<OpStrategy> StrategyForSliceAssign(
     Expr assign = arg_pack[1];
     CHECK(assign.as_tensor());
-    std::string tensor_name = UniqName("slice_assign_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 3U);
-      CHECK(arg_pack[2].is_string());
-      tensor_name = arg_pack[2].operator std::string();
-    }
+    CHECK_EQ(arg_pack.size(), 3U);
+    CHECK(arg_pack[2].is_string());
+    std::string tensor_name = arg_pack[2].operator std::string();
     auto out = pe::SliceAssign(input.as_tensor_ref(),
                                assign.as_tensor_ref(),
...
...
paddle/cinn/hlir/op/transform_test.cc
View file @ 70183c4b
...
...
@@ -86,40 +86,18 @@ TEST(SliceAssign, SliceAssign_Op) {
   std::string func_name = "slice_assign";
-  if (FLAGS_cinn_ir_schedule) {
-    std::string out_name = "output";
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(input.tensor()),
-                               common::CINNValue(assign.tensor()),
-                               common::CINNValue(out_name)}};
-    std::vector<std::string> input_output_names{"input", "assign", out_name};
-    auto funcs = framework::GetFuncFromImpl(
-        impl, cinn_input, inputs, input_output_names, func_name, target);
-    for (auto func : funcs) {
-      LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
-    }
-  } else {
-    common::CINNValuePack cinn_input =
-        common::CINNValuePack{{common::CINNValue(input.tensor()),
-                               common::CINNValue(assign.tensor())}};
-    common::CINNValuePack rets = impl->fcompute(cinn_input);
-    rets = impl->fschedule(rets);
-    // the last element is a StageMap
-    for (int i = 0; i < rets->size() - 1; i++) {
-      Expr temp = rets[i];
-      if (!temp.as_tensor_ref()->buffer.defined()) {
-        inputs.push_back(temp.as_tensor_ref());
-      }
-    }
-    auto func = lang::LowerVec(
-        "slice_assign", rets.back(), inputs, {}, {}, nullptr, target);
-    for (auto& f : func) {
-      LOG(INFO) << "Test Strategy Codegen:\n" << f;
-    }
-  }
+  std::string out_name = "output";
+  common::CINNValuePack cinn_input =
+      common::CINNValuePack{{common::CINNValue(input.tensor()),
+                             common::CINNValue(assign.tensor()),
+                             common::CINNValue(out_name)}};
+  std::vector<std::string> input_output_names{"input", "assign", out_name};
+  auto funcs = framework::GetFuncFromImpl(
+      impl, cinn_input, inputs, input_output_names, func_name, target);
+  for (auto func : funcs) {
+    LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
+  }
 }
...
...