Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5b97278e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5b97278e
编写于
6月 12, 2023
作者:
Z
zhangbo9674
提交者:
GitHub
6月 12, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add inplace view info to OpYamlInfoInterface (#54551)
上级
a56eba3a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
54 addition
and
4 deletion
+54
-4
paddle/fluid/ir/dialect/op_gen.py
paddle/fluid/ir/dialect/op_gen.py
+30
-2
paddle/fluid/ir/dialect/utils.h
paddle/fluid/ir/dialect/utils.h
+8
-2
paddle/fluid/operators/generator/parse_utils.py
paddle/fluid/operators/generator/parse_utils.py
+16
-0
未找到文件。
paddle/fluid/ir/dialect/op_gen.py
浏览文件 @
5b97278e
...
@@ -110,7 +110,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
...
@@ -110,7 +110,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}});
paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}
, {{{inplace}}}, {{{view}}}
);
return std::make_tuple(inputs, attributes, outputs, run_time_info);
return std::make_tuple(inputs, attributes, outputs, run_time_info);
}}
}}
"""
"""
...
@@ -386,6 +386,10 @@ class OpInfoParser:
...
@@ -386,6 +386,10 @@ class OpInfoParser:
else
:
else
:
self
.
infer_shape_func
=
None
self
.
infer_shape_func
=
None
# parse inplace && view
self
.
inplace_map
=
self
.
parse_op_inplace_info
()
self
.
view_map
=
self
.
parse_op_view_info
()
def
cross_check
(
self
,
name_list
,
type_list
,
optional_list
=
None
):
def
cross_check
(
self
,
name_list
,
type_list
,
optional_list
=
None
):
assert
len
(
name_list
)
==
len
(
assert
len
(
name_list
)
==
len
(
type_list
type_list
...
@@ -396,7 +400,9 @@ class OpInfoParser:
...
@@ -396,7 +400,9 @@ class OpInfoParser:
),
"type list size != optional list size."
),
"type list size != optional list size."
def
parse_op_phi_name
(
self
):
def
parse_op_phi_name
(
self
):
if
self
.
parse_op_inplace_info
()
is
None
:
if
(
self
.
parse_op_inplace_info
()
is
None
)
and
(
self
.
parse_op_view_info
()
is
None
):
return
[
self
.
op_yaml_item
[
'name'
]]
return
[
self
.
op_yaml_item
[
'name'
]]
else
:
else
:
if
self
.
op_yaml_item
[
'name'
][
-
1
]
==
"_"
:
if
self
.
op_yaml_item
[
'name'
][
-
1
]
==
"_"
:
...
@@ -412,6 +418,11 @@ class OpInfoParser:
...
@@ -412,6 +418,11 @@ class OpInfoParser:
return
self
.
op_yaml_item
[
'inplace'
]
return
self
.
op_yaml_item
[
'inplace'
]
return
None
return
None
def
parse_op_view_info
(
self
):
if
'view'
in
self
.
op_yaml_item
:
return
self
.
op_yaml_item
[
'view'
]
return
None
def
parse_mutable_attribute
(
self
):
def
parse_mutable_attribute
(
self
):
"""
"""
{'axis': 'paddle::dialect::ScalarAttribute', 'rotl': 'paddle::dialect::IntArrayAttribute'}
{'axis': 'paddle::dialect::ScalarAttribute', 'rotl': 'paddle::dialect::IntArrayAttribute'}
...
@@ -1256,6 +1267,8 @@ def OpGenerator(
...
@@ -1256,6 +1267,8 @@ def OpGenerator(
# others
# others
op_infer_meta_map
=
op_info
.
infer_meta_map
op_infer_meta_map
=
op_info
.
infer_meta_map
op_kernel_map
=
op_info
.
kernel_map
op_kernel_map
=
op_info
.
kernel_map
op_inplace_map
=
op_info
.
inplace_map
op_view_map
=
op_info
.
view_map
op_interfaces
=
[
"OpYamlInfoInterface"
]
op_interfaces
=
[
"OpYamlInfoInterface"
]
op_traits
=
[]
op_traits
=
[]
...
@@ -1472,12 +1485,25 @@ def OpGenerator(
...
@@ -1472,12 +1485,25 @@ def OpGenerator(
if
op_infer_meta_map
is
not
None
:
if
op_infer_meta_map
is
not
None
:
infer_meta_func_str
=
op_infer_meta_map
[
'func'
]
infer_meta_func_str
=
op_infer_meta_map
[
'func'
]
infer_meta_param_str
=
'", "'
.
join
(
op_infer_meta_map
[
'param'
])
infer_meta_param_str
=
'", "'
.
join
(
op_infer_meta_map
[
'param'
])
kernel_func_str
=
""
kernel_func_str
=
""
kernel_param_str
=
""
kernel_param_str
=
""
if
op_kernel_map
is
not
None
:
if
op_kernel_map
is
not
None
:
kernel_func_str
=
'", "'
.
join
(
op_kernel_map
[
'func'
])
kernel_func_str
=
'", "'
.
join
(
op_kernel_map
[
'func'
])
kernel_param_str
=
'", "'
.
join
(
op_kernel_map
[
'param'
])
kernel_param_str
=
'", "'
.
join
(
op_kernel_map
[
'param'
])
inplace_str
=
""
view_str
=
""
if
op_name
[
-
1
]
==
"_"
:
if
op_inplace_map
is
not
None
:
for
key
,
value
in
op_inplace_map
.
items
():
inplace_str
+=
'{"'
+
key
+
'", "'
+
value
+
'"},'
inplace_str
=
inplace_str
[:
-
1
]
if
op_view_map
is
not
None
:
for
key
,
value
in
op_view_map
.
items
():
view_str
+=
'{"'
+
key
+
'", "'
+
value
+
'"},'
view_str
=
view_str
[:
-
1
]
op_info_func_str
=
OP_INFO_TEMPLATE
.
format
(
op_info_func_str
=
OP_INFO_TEMPLATE
.
format
(
op_name
=
op_class_name
,
op_name
=
op_class_name
,
inputs
=
inputs_info_str
,
inputs
=
inputs_info_str
,
...
@@ -1487,6 +1513,8 @@ def OpGenerator(
...
@@ -1487,6 +1513,8 @@ def OpGenerator(
infer_meta_param
=
infer_meta_param_str
,
infer_meta_param
=
infer_meta_param_str
,
kernel_func
=
kernel_func_str
,
kernel_func
=
kernel_func_str
,
kernel_param
=
kernel_param_str
,
kernel_param
=
kernel_param_str
,
inplace
=
inplace_str
,
view
=
view_str
,
)
)
# =================================== #
# =================================== #
...
...
paddle/fluid/ir/dialect/utils.h
浏览文件 @
5b97278e
...
@@ -144,14 +144,20 @@ struct OpRunTimeInfo {
...
@@ -144,14 +144,20 @@ struct OpRunTimeInfo {
std
::
vector
<
std
::
string
>
infer_meta_param
;
std
::
vector
<
std
::
string
>
infer_meta_param
;
std
::
vector
<
std
::
string
>
kernel_func
;
std
::
vector
<
std
::
string
>
kernel_func
;
std
::
vector
<
std
::
string
>
kernel_param
;
std
::
vector
<
std
::
string
>
kernel_param
;
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
inplace
;
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
view
;
OpRunTimeInfo
(
std
::
string
infer_meta_func
,
OpRunTimeInfo
(
std
::
string
infer_meta_func
,
std
::
vector
<
std
::
string
>
infer_meta_param
,
std
::
vector
<
std
::
string
>
infer_meta_param
,
std
::
vector
<
std
::
string
>
kernel_func
,
std
::
vector
<
std
::
string
>
kernel_func
,
std
::
vector
<
std
::
string
>
kernel_param
)
std
::
vector
<
std
::
string
>
kernel_param
,
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
inplace
,
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
view
)
:
infer_meta_func
(
infer_meta_func
),
:
infer_meta_func
(
infer_meta_func
),
infer_meta_param
(
infer_meta_param
),
infer_meta_param
(
infer_meta_param
),
kernel_func
(
kernel_func
),
kernel_func
(
kernel_func
),
kernel_param
(
kernel_param
)
{}
kernel_param
(
kernel_param
),
inplace
(
inplace
),
view
(
view
)
{}
};
};
}
// namespace dialect
}
// namespace dialect
...
...
paddle/fluid/operators/generator/parse_utils.py
浏览文件 @
5b97278e
...
@@ -259,12 +259,22 @@ def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
...
@@ -259,12 +259,22 @@ def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
return
kernel
return
kernel
def
delete_bracket
(
name
:
str
):
if
name
[
0
]
==
"("
:
name
=
name
.
lstrip
(
"("
)
if
name
[
-
1
]
==
")"
:
name
=
name
.
rstrip
(
")"
)
return
name
def
parse_inplace
(
op_name
:
str
,
inplace_cfg
:
str
)
->
Dict
[
str
,
str
]:
def
parse_inplace
(
op_name
:
str
,
inplace_cfg
:
str
)
->
Dict
[
str
,
str
]:
inplace_map
=
{}
inplace_map
=
{}
inplace_cfg
=
inplace_cfg
.
lstrip
(
"("
).
rstrip
(
")"
)
inplace_cfg
=
inplace_cfg
.
lstrip
(
"("
).
rstrip
(
")"
)
pairs
=
parse_plain_list
(
inplace_cfg
)
pairs
=
parse_plain_list
(
inplace_cfg
)
for
pair
in
pairs
:
for
pair
in
pairs
:
in_name
,
out_name
=
parse_plain_list
(
pair
,
sep
=
"->"
)
in_name
,
out_name
=
parse_plain_list
(
pair
,
sep
=
"->"
)
in_name
=
delete_bracket
(
in_name
)
out_name
=
delete_bracket
(
out_name
)
inplace_map
[
out_name
]
=
in_name
inplace_map
[
out_name
]
=
in_name
return
inplace_map
return
inplace_map
...
@@ -521,11 +531,17 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
...
@@ -521,11 +531,17 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
inplace_pairs
=
parse_inplace
(
op_name
,
op_entry
[
"inplace"
])
inplace_pairs
=
parse_inplace
(
op_name
,
op_entry
[
"inplace"
])
else
:
else
:
inplace_pairs
=
None
inplace_pairs
=
None
# view
if
"view"
in
op_entry
:
view_pairs
=
parse_inplace
(
op_name
,
op_entry
[
"view"
])
else
:
view_pairs
=
None
op
.
update
(
op
.
update
(
{
{
"infer_meta"
:
infer_meta
,
"infer_meta"
:
infer_meta
,
"kernel"
:
kernel
,
"kernel"
:
kernel
,
"inplace"
:
inplace_pairs
,
"inplace"
:
inplace_pairs
,
"view"
:
view_pairs
,
}
}
)
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录