Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
24d07b73
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
24d07b73
编写于
7月 06, 2022
作者:
Z
zyfncg
提交者:
GitHub
7月 06, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
generate map of extra attrs for ops (#44106)
上级
07b68eb3
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
137 addition
and
0 deletion
+137
-0
.gitignore
.gitignore
+1
-0
paddle/phi/api/lib/CMakeLists.txt
paddle/phi/api/lib/CMakeLists.txt
+15
-0
paddle/phi/api/yaml/api_compat.yaml
paddle/phi/api/yaml/api_compat.yaml
+9
-0
paddle/phi/api/yaml/generator/generate_op.py
paddle/phi/api/yaml/generator/generate_op.py
+2
-0
paddle/phi/api/yaml/generator/ops_extra_info_gen.py
paddle/phi/api/yaml/generator/ops_extra_info_gen.py
+110
-0
未找到文件。
.gitignore
浏览文件 @
24d07b73
...
...
@@ -5,6 +5,7 @@ paddle/fluid/API_PR.spec
paddle/fluid/eager/api/generated/*
paddle/fluid/op_use_default_grad_maker_DEV.spec
paddle/fluid/op_use_default_grad_maker_PR.spec
paddle/fluid/operators/ops_extra_info.h
paddle/phi/api/backward/backward_api.h
paddle/phi/api/backward/sparse_bw_api.h
paddle/phi/api/include/api.h
...
...
paddle/phi/api/lib/CMakeLists.txt
浏览文件 @
24d07b73
...
...
@@ -94,6 +94,14 @@ set(wrapped_infermeta_header_file
set
(
wrapped_infermeta_source_file
${
CMAKE_SOURCE_DIR
}
/paddle/phi/infermeta/generated.cc
)
# op extra info file
set
(
ops_extra_info_gen_file
${
CMAKE_SOURCE_DIR
}
/paddle/phi/api/yaml/generator/ops_extra_info_gen.py
)
set
(
api_compat_yaml_file
${
CMAKE_SOURCE_DIR
}
/paddle/phi/api/yaml/api_compat.yaml
)
set
(
ops_extra_info_file
${
CMAKE_SOURCE_DIR
}
/paddle/fluid/operators/ops_extra_info.h
)
if
(
NOT PYTHONINTERP_FOUND
)
find_package
(
PythonInterp REQUIRED
)
endif
()
...
...
@@ -211,6 +219,13 @@ else()
message
(
"remove
${
generated_argument_mapping_path
}
"
)
endif
()
# generate ops extra info
execute_process
(
COMMAND
${
PYTHON_EXECUTABLE
}
${
ops_extra_info_gen_file
}
--api_compat_yaml_path
${
api_compat_yaml_file
}
--ops_extra_info_path
${
ops_extra_info_file
}
)
message
(
"generate
${
ops_extra_info_file
}
"
)
# generate forward api
add_custom_command
(
OUTPUT
${
api_header_file
}
${
api_source_file
}
...
...
paddle/phi/api/yaml/api_compat.yaml
浏览文件 @
24d07b73
...
...
@@ -23,3 +23,12 @@
x
:
Input
outputs
:
out
:
Out
-
api
:
conv2d
extra
:
attrs
:
[
bool use_cudnn = false
,
bool fuse_relu_before_depthwise_conv = false
,
bool use_mkldnn = false
,
bool use_quantizer = false
,
str mkldnn_data_type = "float32"
,
bool fuse_relu = false
,
str fuse_activation = ""
,
bool fuse_alpha = false
,
bool fuse_beta = false
,
bool use_addto = false
,
bool fuse_residual_connection = false
,
float Scale_in = 1.0f
,
float Scale_out = 1.0f
,
float Scale_in_eltwise = 1.0f
,
'
float[]
Scale_weights
=
{1.0f}'
,
bool force_fp32_output = false
,
int workspace_size_MB = 512
,
bool exhaustive_search = false
]
paddle/phi/api/yaml/generator/generate_op.py
浏览文件 @
24d07b73
...
...
@@ -76,6 +76,8 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path,
api_args_map
=
yaml
.
safe_load
(
f
)
# replace args name for OpMaker
for
api_args
in
api_args_map
:
if
api_args
[
'api'
]
not
in
forward_api_dict
:
continue
forward_api_item
=
forward_api_dict
[
api_args
[
'api'
]]
has_backward
=
True
if
forward_api_item
[
'backward'
]
else
False
if
has_backward
:
...
...
paddle/phi/api/yaml/generator/ops_extra_info_gen.py
0 → 100644
浏览文件 @
24d07b73
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
yaml
import
re
import
argparse
def
map_code_template
(
attrs_str
):
return
f
"""
#include "paddle/fluid/framework/attribute.h"
namespace paddle {{
const static std::unordered_map<std::string, paddle::framework::AttributeMap> extra_attrs_map = {{
{
attrs_str
}
}};
}} // namespace paddle
"""
ATTR_TYPE_STRING_MAP
=
{
'bool'
:
'bool'
,
'int'
:
'int'
,
'int64_t'
:
'int64_t'
,
'float'
:
'float'
,
'double'
:
'double'
,
'str'
:
'std::string'
,
'int[]'
:
'std::vector<int>'
,
'int64_t[]'
:
'std::vector<int64_t>'
,
'float[]'
:
'std::vector<float>'
,
'double[]'
:
'std::vector<double>'
,
'str[]'
:
'std::vector<std::string>'
}
def
parse_attr
(
attr_str
):
result
=
re
.
search
(
r
"(?P<attr_type>[a-z[\]]+)\s+(?P<name>[a-zA-Z0-9_]+)\s*=\s*(?P<default_val>\S+)"
,
attr_str
)
return
ATTR_TYPE_STRING_MAP
[
result
.
group
(
'attr_type'
)],
result
.
group
(
'name'
),
result
.
group
(
'default_val'
)
def
generate_extra_info
(
api_compat_yaml_path
,
ops_extra_info_path
):
compat_apis
=
[]
with
open
(
api_compat_yaml_path
,
'rt'
)
as
f
:
compat_apis
=
yaml
.
safe_load
(
f
)
extra_map_str_list
=
[]
for
api_compat_args
in
compat_apis
:
if
'extra'
in
api_compat_args
:
extra_args_map
=
api_compat_args
[
'extra'
]
# TODO(chenweihang): add inputs and outputs
if
'attrs'
in
extra_args_map
:
attr_map_list
=
[]
for
attr
in
extra_args_map
[
'attrs'
]:
attr_type
,
attr_name
,
default_val
=
parse_attr
(
attr
)
if
attr_type
.
startswith
(
"std::vector"
):
attr_map_list
.
append
(
f
"{{
\"
{
attr_name
}
\"
,
{
attr_type
}{
default_val
}
}}"
)
else
:
attr_map_list
.
append
(
f
"{{
\"
{
attr_name
}
\"
,
{
attr_type
}
{{
{
default_val
}
}}}}"
)
api_extra_attr_map
=
", "
.
join
(
attr_map_list
)
extra_map_str_list
.
append
(
f
"{{
\"
{
api_compat_args
[
'api'
]
}
\"
, {{
{
api_extra_attr_map
}
}}}}"
)
ops_extra_info_file
=
open
(
ops_extra_info_path
,
'w'
)
ops_extra_info_file
.
write
(
map_code_template
(
",
\n
"
.
join
(
extra_map_str_list
)))
ops_extra_info_file
.
close
()
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Generate PaddlePaddle Extra Param Info for Op'
)
parser
.
add_argument
(
'--api_compat_yaml_path'
,
help
=
'path to api compat yaml file'
,
default
=
'paddle/phi/api/yaml/api_compat.yaml'
)
parser
.
add_argument
(
'--ops_extra_info_path'
,
help
=
'output of generated extra_prama_info code file'
,
default
=
'paddle/fluid/operators/ops_extra_info.h'
)
options
=
parser
.
parse_args
()
api_compat_yaml_path
=
options
.
api_compat_yaml_path
ops_extra_info_path
=
options
.
ops_extra_info_path
generate_extra_info
(
api_compat_yaml_path
,
ops_extra_info_path
)
if
__name__
==
'__main__'
:
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录