Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8b17207c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8b17207c
编写于
8月 23, 2023
作者:
W
WangZhen
提交者:
GitHub
8月 23, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Gen all Apis (#56526)
上级
8fe86ebb
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
72 addition
and
76 deletion
+72
-76
paddle/fluid/ir/dialect/op_generator/api_gen.py
paddle/fluid/ir/dialect/op_generator/api_gen.py
+70
-56
paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc
paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc
+1
-14
paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h
paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h
+1
-6
未找到文件。
paddle/fluid/ir/dialect/op_generator/api_gen.py
浏览文件 @
8b17207c
...
@@ -64,7 +64,7 @@ API_IMPL_TEMPLATE = """
...
@@ -64,7 +64,7 @@ API_IMPL_TEMPLATE = """
{ret_type} {api_name}({args}){{
{ret_type} {api_name}({args}){{
{in_combine}
{in_combine}
{compute_op}
{compute_op}
{out_s
lice
}
{out_s
plit
}
{return_result}
{return_result}
}}
}}
...
@@ -73,34 +73,15 @@ API_IMPL_TEMPLATE = """
...
@@ -73,34 +73,15 @@ API_IMPL_TEMPLATE = """
COMBINE_OP_TEMPLATE
=
"""
COMBINE_OP_TEMPLATE
=
"""
auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>({in_name});"""
auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>({in_name});"""
S
LICE
_OP_TEMPLATE
=
"""
S
PLIT
_OP_TEMPLATE
=
"""
auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::S
lice
Op>({in_name});"""
auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::S
plit
Op>({in_name});"""
COMPUTE_OP_TEMPLATE
=
"""
COMPUTE_OP_TEMPLATE
=
"""
paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::{op_class_name}>({args});"""
paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::{op_class_name}>({args});"""
API_LIST
=
[
'add_n'
,
'mean'
,
'sum'
,
'divide'
,
'full'
,
'tanh_grad'
,
'mean_grad'
,
'concat'
,
'add'
,
'multiply'
,
'elementwise_pow'
,
'scale'
,
'reshape'
,
'expand'
,
'tile'
,
'add_grad'
,
'divide_grad'
,
'sum_grad'
,
]
OP_RESULT
=
'ir::OpResult'
OP_RESULT
=
'ir::OpResult'
VECTOR_TYPE
=
'ir::VectorType'
VECTOR_TYPE
=
'ir::VectorType'
PD_MANUAL_OP_LIST
=
[
'add_n'
]
def
get_op_class_name
(
op_name
):
def
get_op_class_name
(
op_name
):
...
@@ -142,56 +123,70 @@ class CodeGen:
...
@@ -142,56 +123,70 @@ class CodeGen:
ret
.
append
(
f
'
{
self
.
_type_map
[
type
]
}
{
name
}
'
)
ret
.
append
(
f
'
{
self
.
_type_map
[
type
]
}
{
name
}
'
)
return
', '
.
join
(
ret
)
return
', '
.
join
(
ret
)
def
_gen_api_attrs
(
self
,
op_info
,
with_default
):
def
_gen_api_attrs
(
self
,
op_info
,
with_default
,
is_mutable_attr
):
name_list
=
op_info
.
attribute_name_list
name_list
=
op_info
.
attribute_name_list
type_list
=
op_info
.
attribute_build_arg_type_list
type_list
=
op_info
.
attribute_build_arg_type_list
default_value_list
=
op_info
.
attribute_default_value_list
default_value_list
=
op_info
.
attribute_default_value_list
mutable_name_list
=
op_info
.
mutable_attribute_name_list
assert
len
(
name_list
)
==
len
(
type_list
)
==
len
(
default_value_list
)
assert
len
(
name_list
)
==
len
(
type_list
)
==
len
(
default_value_list
)
ret
=
[]
no_mutable_attr
=
[]
mutable_attr
=
[]
for
name
,
type
,
default_value
in
zip
(
for
name
,
type
,
default_value
in
zip
(
name_list
,
type_list
,
default_value_list
name_list
,
type_list
,
default_value_list
):
):
if
is_mutable_attr
and
name
in
mutable_name_list
:
mutable_attr
.
append
(
f
'
{
OP_RESULT
}
{
name
}
'
)
continue
if
with_default
and
default_value
is
not
None
:
if
with_default
and
default_value
is
not
None
:
if
type
in
[
'float'
,
'double'
]:
if
type
in
[
'float'
,
'double'
]:
default_value
=
default_value
.
strip
(
'"'
)
default_value
=
default_value
.
strip
(
'"'
)
ret
.
append
(
no_mutable_attr
.
append
(
'{type} {name} = {default_value}'
.
format
(
'{type} {name} = {default_value}'
.
format
(
type
=
type
,
name
=
name
,
default_value
=
default_value
type
=
type
,
name
=
name
,
default_value
=
default_value
)
)
)
)
else
:
else
:
ret
.
append
(
f
'
{
type
}
{
name
}
'
)
no_mutable_attr
.
append
(
f
'
{
type
}
{
name
}
'
)
return
', '
.
join
(
ret
)
return
', '
.
join
(
mutable_attr
+
no_mutable_attr
)
def
_gen_api_args
(
self
,
op_info
,
with_default_attr
):
def
_gen_api_args
(
self
,
op_info
,
with_default_attr
,
is_mutable_attr
):
inputs
=
self
.
_gen_api_inputs
(
op_info
)
inputs
=
self
.
_gen_api_inputs
(
op_info
)
attrs
=
self
.
_gen_api_attrs
(
op_info
,
with_default_attr
)
attrs
=
self
.
_gen_api_attrs
(
op_info
,
with_default_attr
,
is_mutable_attr
)
return
(
inputs
+
', '
+
attrs
).
strip
(
', '
)
return
(
inputs
+
', '
+
attrs
).
strip
(
', '
)
def
_gen_ret_type
(
self
,
op_info
):
def
_gen_ret_type
(
self
,
op_info
):
type_list
=
op_info
.
output_type_list
type_list
=
op_info
.
output_type_list
assert
len
(
type_list
)
>=
1
if
len
(
type_list
)
>
1
:
if
len
(
type_list
)
>
1
:
return
'std::tuple<{}>'
.
format
(
return
'std::tuple<{}>'
.
format
(
', '
.
join
([
self
.
_type_map
[
type
]
for
type
in
type_list
])
', '
.
join
([
self
.
_type_map
[
type
]
for
type
in
type_list
])
)
)
elif
len
(
type_list
)
==
1
:
elif
len
(
type_list
)
==
1
:
return
self
.
_type_map
[
type_list
[
0
]]
return
self
.
_type_map
[
type_list
[
0
]]
elif
len
(
type_list
)
==
0
:
return
'void'
def
_gen_one_declare
(
self
,
op_info
,
op_name
):
def
_gen_one_declare
(
self
,
op_info
,
op_name
,
is_mutable_attr
):
return
API_DECLARE_TEMPLATE
.
format
(
return
API_DECLARE_TEMPLATE
.
format
(
ret_type
=
self
.
_gen_ret_type
(
op_info
),
ret_type
=
self
.
_gen_ret_type
(
op_info
),
api_name
=
op_name
,
api_name
=
op_name
,
args
=
self
.
_gen_api_args
(
op_info
,
True
),
args
=
self
.
_gen_api_args
(
op_info
,
True
,
is_mutable_attr
),
)
)
def
_gen_h_file
(
self
,
op_info_items
,
namespaces
,
h_file_path
):
def
_gen_h_file
(
self
,
op_info_items
,
namespaces
,
h_file_path
):
declare_str
=
''
declare_str
=
''
for
op_info
in
op_info_items
:
for
op_info
in
op_info_items
:
for
op_name
in
op_info
.
op_phi_name
:
for
op_name
in
op_info
.
op_phi_name
:
if
op_name
not
in
API_LIST
:
# NOTE:When infer_meta_func is None, the Build() function generated in pd_op
# is wrong, so temporarily skip the automatic generation of these APIs
if
(
op_info
.
infer_meta_func
is
None
and
op_name
not
in
PD_MANUAL_OP_LIST
):
continue
continue
declare_str
+=
self
.
_gen_one_declare
(
op_info
,
op_name
)
declare_str
+=
self
.
_gen_one_declare
(
op_info
,
op_name
,
False
)
if
len
(
op_info
.
mutable_attribute_name_list
)
>
0
:
declare_str
+=
self
.
_gen_one_declare
(
op_info
,
op_name
,
True
)
body
=
declare_str
body
=
declare_str
for
namespace
in
reversed
(
namespaces
):
for
namespace
in
reversed
(
namespaces
):
body
=
NAMESPACE_TEMPLATE
.
format
(
namespace
=
namespace
,
body
=
body
)
body
=
NAMESPACE_TEMPLATE
.
format
(
namespace
=
namespace
,
body
=
body
)
...
@@ -218,9 +213,13 @@ class CodeGen:
...
@@ -218,9 +213,13 @@ class CodeGen:
combine_op_list
.
append
(
None
)
combine_op_list
.
append
(
None
)
return
combine_op
,
combine_op_list
return
combine_op
,
combine_op_list
def
_gen_compute_op_args
(
self
,
op_info
,
in_combine_op_list
):
def
_gen_compute_op_args
(
self
,
op_info
,
in_combine_op_list
,
is_mutable_attr
):
input_name_list
=
op_info
.
input_name_list
input_name_list
=
op_info
.
input_name_list
attribute_name_list
=
op_info
.
attribute_name_list
all_attr_list
=
op_info
.
attribute_name_list
no_mutable_attr_list
=
op_info
.
non_mutable_attribute_name_list
mutable_attr_list
=
op_info
.
mutable_attribute_name_list
assert
len
(
input_name_list
)
==
len
(
in_combine_op_list
)
assert
len
(
input_name_list
)
==
len
(
in_combine_op_list
)
ret
=
[]
ret
=
[]
for
input_name
,
combine_op
in
zip
(
input_name_list
,
in_combine_op_list
):
for
input_name
,
combine_op
in
zip
(
input_name_list
,
in_combine_op_list
):
...
@@ -228,61 +227,69 @@ class CodeGen:
...
@@ -228,61 +227,69 @@ class CodeGen:
ret
.
append
(
input_name
)
ret
.
append
(
input_name
)
else
:
else
:
ret
.
append
(
f
'
{
combine_op
}
.out()'
)
ret
.
append
(
f
'
{
combine_op
}
.out()'
)
ret
+=
list
(
attribute_name_list
)
if
is_mutable_attr
:
ret
+=
list
(
mutable_attr_list
+
no_mutable_attr_list
)
else
:
ret
+=
list
(
all_attr_list
)
return
', '
.
join
(
ret
)
return
', '
.
join
(
ret
)
def
_gen_compute_op
(
self
,
op_info
,
op_name
,
in_combine_op_list
):
def
_gen_compute_op
(
self
,
op_info
,
op_name
,
in_combine_op_list
,
is_mutable_attr
):
op_class_name
=
to_pascal_case
(
op_name
)
+
'Op'
op_class_name
=
to_pascal_case
(
op_name
)
+
'Op'
op_inst_name
=
op_name
+
'_op'
op_inst_name
=
op_name
+
'_op'
return
(
return
(
COMPUTE_OP_TEMPLATE
.
format
(
COMPUTE_OP_TEMPLATE
.
format
(
op_class_name
=
op_class_name
,
op_class_name
=
op_class_name
,
op_inst_name
=
op_inst_name
,
op_inst_name
=
op_inst_name
,
args
=
self
.
_gen_compute_op_args
(
op_info
,
in_combine_op_list
),
args
=
self
.
_gen_compute_op_args
(
op_info
,
in_combine_op_list
,
is_mutable_attr
),
),
),
op_inst_name
,
op_inst_name
,
)
)
def
_gen_out_s
lice
_and_ret_list
(
self
,
op_info
,
op_inst_name
):
def
_gen_out_s
plit
_and_ret_list
(
self
,
op_info
,
op_inst_name
):
name_list
=
op_info
.
output_name_list
name_list
=
op_info
.
output_name_list
type_list
=
op_info
.
output_type_list
type_list
=
op_info
.
output_type_list
s
lice
_op_str
=
''
s
plit
_op_str
=
''
ret_list
=
[]
ret_list
=
[]
for
i
,
(
name
,
type
)
in
enumerate
(
zip
(
name_list
,
type_list
)):
for
i
,
(
name
,
type
)
in
enumerate
(
zip
(
name_list
,
type_list
)):
if
VECTOR_TYPE
in
type
:
if
VECTOR_TYPE
in
type
:
s
lice_op_name
=
f
'
{
name
}
_slice
_op'
s
plit_op_name
=
f
'
{
name
}
_split
_op'
s
lice_op_str
+=
SLICE
_OP_TEMPLATE
.
format
(
s
plit_op_str
+=
SPLIT
_OP_TEMPLATE
.
format
(
op_name
=
s
lice
_op_name
,
in_name
=
f
'
{
op_inst_name
}
.result(
{
i
}
)'
op_name
=
s
plit
_op_name
,
in_name
=
f
'
{
op_inst_name
}
.result(
{
i
}
)'
)
)
ret_list
.
append
(
f
'
{
s
lice
_op_name
}
.outputs()'
)
ret_list
.
append
(
f
'
{
s
plit
_op_name
}
.outputs()'
)
else
:
else
:
ret_list
.
append
(
f
'
{
op_inst_name
}
.result(
{
i
}
)'
)
ret_list
.
append
(
f
'
{
op_inst_name
}
.result(
{
i
}
)'
)
return
s
lice
_op_str
,
ret_list
return
s
plit
_op_str
,
ret_list
def
_gen_return_result
(
self
,
ret_list
):
def
_gen_return_result
(
self
,
ret_list
):
assert
len
(
ret_list
)
>=
1
if
len
(
ret_list
)
>
1
:
if
len
(
ret_list
)
>
1
:
return
'return std::make_tuple({});'
.
format
(
', '
.
join
(
ret_list
))
return
'return std::make_tuple({});'
.
format
(
', '
.
join
(
ret_list
))
el
se
:
el
if
len
(
ret_list
)
==
1
:
return
f
'return
{
ret_list
[
0
]
}
;'
return
f
'return
{
ret_list
[
0
]
}
;'
elif
len
(
ret_list
)
==
0
:
return
'return;'
def
_gen_one_impl
(
self
,
op_info
,
op_name
):
def
_gen_one_impl
(
self
,
op_info
,
op_name
,
is_mutable_attr
):
in_combine
,
in_combine_op_list
=
self
.
_gen_in_combine
(
op_info
)
in_combine
,
in_combine_op_list
=
self
.
_gen_in_combine
(
op_info
)
compute_op
,
op_inst_name
=
self
.
_gen_compute_op
(
compute_op
,
op_inst_name
=
self
.
_gen_compute_op
(
op_info
,
op_name
,
in_combine_op_list
op_info
,
op_name
,
in_combine_op_list
,
is_mutable_attr
)
)
out_s
lice
,
ret_list
=
self
.
_gen_out_slice
_and_ret_list
(
out_s
plit
,
ret_list
=
self
.
_gen_out_split
_and_ret_list
(
op_info
,
op_inst_name
op_info
,
op_inst_name
)
)
ret
=
API_IMPL_TEMPLATE
.
format
(
ret
=
API_IMPL_TEMPLATE
.
format
(
ret_type
=
self
.
_gen_ret_type
(
op_info
),
ret_type
=
self
.
_gen_ret_type
(
op_info
),
api_name
=
op_name
,
api_name
=
op_name
,
args
=
self
.
_gen_api_args
(
op_info
,
False
),
args
=
self
.
_gen_api_args
(
op_info
,
False
,
is_mutable_attr
),
in_combine
=
in_combine
,
in_combine
=
in_combine
,
compute_op
=
compute_op
,
compute_op
=
compute_op
,
out_s
lice
=
out_slice
,
out_s
plit
=
out_split
,
return_result
=
self
.
_gen_return_result
(
ret_list
),
return_result
=
self
.
_gen_return_result
(
ret_list
),
)
)
...
@@ -293,9 +300,16 @@ class CodeGen:
...
@@ -293,9 +300,16 @@ class CodeGen:
impl_str
=
''
impl_str
=
''
for
op_info
in
op_info_items
:
for
op_info
in
op_info_items
:
for
op_name
in
op_info
.
op_phi_name
:
for
op_name
in
op_info
.
op_phi_name
:
if
op_name
not
in
API_LIST
:
# NOTE:When infer_meta_func is None, the Build() function generated in pd_op
# is wrong, so temporarily skip the automatic generation of these APIs
if
(
op_info
.
infer_meta_func
is
None
and
op_name
not
in
PD_MANUAL_OP_LIST
):
continue
continue
impl_str
+=
self
.
_gen_one_impl
(
op_info
,
op_name
)
impl_str
+=
self
.
_gen_one_impl
(
op_info
,
op_name
,
False
)
if
len
(
op_info
.
mutable_attribute_name_list
)
>
0
:
impl_str
+=
self
.
_gen_one_impl
(
op_info
,
op_name
,
True
)
body
=
impl_str
body
=
impl_str
for
namespace
in
reversed
(
namespaces
):
for
namespace
in
reversed
(
namespaces
):
body
=
NAMESPACE_TEMPLATE
.
format
(
namespace
=
namespace
,
body
=
body
)
body
=
NAMESPACE_TEMPLATE
.
format
(
namespace
=
namespace
,
body
=
body
)
...
...
paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc
浏览文件 @
8b17207c
...
@@ -18,18 +18,5 @@
...
@@ -18,18 +18,5 @@
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
dialect
{
namespace
dialect
{}
// namespace dialect
std
::
vector
<
ir
::
OpResult
>
concat_grad
(
std
::
vector
<
ir
::
OpResult
>
x
,
ir
::
OpResult
out_grad
,
ir
::
OpResult
axis
)
{
auto
combine_op
=
APIBuilder
::
Instance
().
GetBuilder
()
->
Build
<
ir
::
CombineOp
>
(
x
);
paddle
::
dialect
::
ConcatGradOp
concat_grad_op
=
APIBuilder
::
Instance
().
GetBuilder
()
->
Build
<
paddle
::
dialect
::
ConcatGradOp
>
(
combine_op
.
out
(),
out_grad
,
axis
);
auto
split_op
=
APIBuilder
::
Instance
().
GetBuilder
()
->
Build
<
ir
::
SplitOp
>
(
concat_grad_op
.
result
(
0
));
return
split_op
.
outputs
();
}
}
// namespace dialect
}
// namespace paddle
}
// namespace paddle
paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h
浏览文件 @
8b17207c
...
@@ -21,10 +21,5 @@
...
@@ -21,10 +21,5 @@
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/place.h"
namespace
paddle
{
namespace
paddle
{
namespace
dialect
{
namespace
dialect
{}
// namespace dialect
std
::
vector
<
ir
::
OpResult
>
concat_grad
(
std
::
vector
<
ir
::
OpResult
>
x
,
ir
::
OpResult
out_grad
,
ir
::
OpResult
axis
);
}
// namespace dialect
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录