Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
7b70b792
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7b70b792
编写于
2月 10, 2022
作者:
Z
zyfncg
提交者:
GitHub
2月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Pten】Refactor C++ API code-gen (#39408)
* refactor C++ API code-gen * fix windows problem of C++ API
上级
224bc511
变更
5
展开全部
隐藏空白更改
内联
并排
Showing
5 changed file
with
528 addition
and
579 deletion
+528
-579
paddle/pten/api/lib/CMakeLists.txt
paddle/pten/api/lib/CMakeLists.txt
+3
-3
python/paddle/utils/code_gen/api_base.py
python/paddle/utils/code_gen/api_base.py
+487
-0
python/paddle/utils/code_gen/api_gen.py
python/paddle/utils/code_gen/api_gen.py
+11
-86
python/paddle/utils/code_gen/backward_api_gen.py
python/paddle/utils/code_gen/backward_api_gen.py
+27
-124
python/paddle/utils/code_gen/gen_utils.py
python/paddle/utils/code_gen/gen_utils.py
+0
-366
未找到文件。
paddle/pten/api/lib/CMakeLists.txt
浏览文件 @
7b70b792
...
@@ -16,7 +16,7 @@ cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)
...
@@ -16,7 +16,7 @@ cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)
cc_library
(
op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor
)
cc_library
(
op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor
)
set
(
api_gen_
utils
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/gen_utils
.py
)
set
(
api_gen_
base
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/api_base
.py
)
# forward api file
# forward api file
set
(
api_gen_file
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/api_gen.py
)
set
(
api_gen_file
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/api_gen.py
)
...
@@ -49,7 +49,7 @@ add_custom_command(
...
@@ -49,7 +49,7 @@ add_custom_command(
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
api_header_file_tmp
}
${
api_header_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
api_header_file_tmp
}
${
api_header_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
api_source_file_tmp
}
${
api_source_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
api_source_file_tmp
}
${
api_source_file
}
COMMENT
"copy_if_different
${
api_header_file
}
${
api_source_file
}
"
COMMENT
"copy_if_different
${
api_header_file
}
${
api_source_file
}
"
DEPENDS
${
api_yaml_file
}
${
api_gen_file
}
${
api_gen_
utils
}
DEPENDS
${
api_yaml_file
}
${
api_gen_file
}
${
api_gen_
base
}
VERBATIM
)
VERBATIM
)
# generate backward api
# generate backward api
...
@@ -62,7 +62,7 @@ add_custom_command(
...
@@ -62,7 +62,7 @@ add_custom_command(
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
bw_api_header_file_tmp
}
${
bw_api_header_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
bw_api_header_file_tmp
}
${
bw_api_header_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
bw_api_source_file_tmp
}
${
bw_api_source_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
bw_api_source_file_tmp
}
${
bw_api_source_file
}
COMMENT
"copy_if_different
${
bw_api_header_file
}
${
bw_api_source_file
}
"
COMMENT
"copy_if_different
${
bw_api_header_file
}
${
bw_api_source_file
}
"
DEPENDS
${
bw_api_yaml_file
}
${
bw_api_gen_file
}
${
api_gen_
utils
}
DEPENDS
${
bw_api_yaml_file
}
${
bw_api_gen_file
}
${
api_gen_
base
}
VERBATIM
)
VERBATIM
)
cc_library
(
pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform
)
cc_library
(
pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform
)
...
...
python/paddle/utils/code_gen/api_base.py
0 → 100644
浏览文件 @
7b70b792
此差异已折叠。
点击以展开。
python/paddle/utils/code_gen/api_gen.py
浏览文件 @
7b70b792
...
@@ -16,64 +16,19 @@ import os
...
@@ -16,64 +16,19 @@ import os
import
yaml
import
yaml
import
argparse
import
argparse
import
gen_utils
from
api_base
import
BaseAPI
class
API
:
class
ForwardAPI
(
BaseAPI
)
:
prefix_tensor_name
=
'dense_'
prefix_tensor_name
=
'dense_'
def
__init__
(
self
,
api_item_yaml
):
def
__init__
(
self
,
api_item_yaml
):
self
.
api
=
api_item_yaml
[
'api'
]
super
(
ForwardAPI
,
self
).
__init__
(
api_item_yaml
)
# args:
# inputs:
def
get_return_type
(
self
,
out_type_list
):
# names : [], list of input names
return
out_type_list
[
0
]
if
len
(
# input_info : {input_name : type}
out_type_list
)
==
1
else
"std::tuple<"
+
","
.
join
(
# attrs:
out_type_list
)
+
">"
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
self
.
args
=
gen_utils
.
parse_args
(
self
.
api
,
api_item_yaml
[
'args'
])
self
.
out_type_list
,
_
=
gen_utils
.
parse_output
(
self
.
api
,
api_item_yaml
[
'output'
])
self
.
return_type
=
self
.
out_type_list
[
0
]
if
len
(
self
.
out_type_list
)
==
1
else
"std::tuple<"
+
","
.
join
(
self
.
out_type_list
)
+
">"
self
.
is_base_api
=
True
if
'invoke'
in
api_item_yaml
:
self
.
is_base_api
=
False
self
.
invoke
=
api_item_yaml
[
'invoke'
]
else
:
self
.
kernel
=
api_item_yaml
[
'kernel'
]
if
'backend'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'backend'
])
==
0
:
self
.
kernel
[
'backend'
]
=
None
if
'layout'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'layout'
])
==
0
:
self
.
kernel
[
'layout'
]
=
None
if
'data_type'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'data_type'
])
==
0
:
self
.
kernel
[
'data_type'
]
=
None
if
'param'
not
in
self
.
kernel
:
self
.
kernel
[
'param'
]
=
None
self
.
infer_meta
=
api_item_yaml
[
'infer_meta'
]
if
'param'
not
in
self
.
infer_meta
:
self
.
infer_meta
[
'param'
]
=
None
self
.
data_transform
=
{
'skip_transform'
:
[],
'support_trans_dtype'
:
[]
}
if
'data_transform'
in
api_item_yaml
:
if
'skip_transform'
in
api_item_yaml
[
'data_transform'
]:
self
.
data_transform
[
'skip_transform'
]
=
api_item_yaml
[
'data_transform'
][
'skip_transform'
]
if
'support_trans_dtype'
in
api_item_yaml
[
'data_transform'
]:
self
.
data_transform
[
'support_trans_dtype'
]
=
api_item_yaml
[
'data_transform'
][
'support_trans_dtype'
]
def
gene_api_declaration
(
self
):
return
f
"""
PADDLE_API
{
self
.
return_type
}
{
self
.
api
}
(
{
self
.
args
[
'args_declare'
]
}
);
"""
def
gene_output
(
self
,
output_type_list
):
def
gene_output
(
self
,
output_type_list
):
kernel_output
=
""
kernel_output
=
""
...
@@ -84,12 +39,12 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
...
@@ -84,12 +39,12 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
kernel_output
=
'dense_out'
kernel_output
=
'dense_out'
output_names
.
append
(
'dense_out'
)
output_names
.
append
(
'dense_out'
)
output_create
=
f
"""
output_create
=
f
"""
{
self
.
return_type
}
out;
{
self
.
outputs
[
'return_type'
]
}
out;
auto dense_out = SetKernelOutput(kernel_backend, &out);"""
auto dense_out = SetKernelOutput(kernel_backend, &out);"""
elif
len
(
output_type_list
)
>
1
:
elif
len
(
output_type_list
)
>
1
:
output_create
=
f
"""
output_create
=
f
"""
{
self
.
return_type
}
out;"""
{
self
.
outputs
[
'return_type'
]
}
out;"""
for
i
in
range
(
len
(
output_type_list
)):
for
i
in
range
(
len
(
output_type_list
)):
kernel_output
=
kernel_output
+
f
'dense_out_
{
i
}
, '
kernel_output
=
kernel_output
+
f
'dense_out_
{
i
}
, '
...
@@ -105,36 +60,6 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
...
@@ -105,36 +60,6 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
return
kernel_output
,
output_names
,
output_create
return
kernel_output
,
output_names
,
output_create
def
gene_api_code
(
self
):
if
self
.
is_base_api
:
input_tensors
,
kernel_args
,
kernel_signature
=
gen_utils
.
get_kernel_args
(
self
.
args
[
'inputs'
],
self
.
args
[
'attrs'
],
self
.
out_type_list
,
self
.
kernel
[
'param'
],
self
.
data_transform
)
outputs_args
,
output_names
,
output_create
=
self
.
gene_output
(
self
.
out_type_list
)
return
f
"""
PADDLE_API
{
self
.
return_type
}
{
self
.
api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
{
gen_utils
.
gene_kernel_select
(
self
.
api
,
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
kernel
)
}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{
input_tensors
}
{
output_create
}
{
gen_utils
.
gene_infer_meta
(
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
][
'names'
],
output_names
,
self
.
infer_meta
)
}
using kernel_signature =
{
kernel_signature
}
;
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(
{
kernel_args
}
,
{
outputs_args
}
);
return out;
}}
"""
else
:
return
f
"""
PADDLE_API
{
self
.
return_type
}
{
self
.
api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
return
{
self
.
invoke
}
;
}}
"""
def
header_include
():
def
header_include
():
return
"""
return
"""
...
@@ -203,7 +128,7 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
...
@@ -203,7 +128,7 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
source_file
.
write
(
namespace
[
0
])
source_file
.
write
(
namespace
[
0
])
for
api
in
apis
:
for
api
in
apis
:
api_code
=
API
(
api
)
api_code
=
Forward
API
(
api
)
print
(
api_code
.
gene_api_declaration
())
print
(
api_code
.
gene_api_declaration
())
header_file
.
write
(
api_code
.
gene_api_declaration
())
header_file
.
write
(
api_code
.
gene_api_declaration
())
source_file
.
write
(
api_code
.
gene_api_code
())
source_file
.
write
(
api_code
.
gene_api_code
())
...
...
python/paddle/utils/code_gen/backward_api_gen.py
浏览文件 @
7b70b792
...
@@ -17,52 +17,16 @@ import yaml
...
@@ -17,52 +17,16 @@ import yaml
import
argparse
import
argparse
import
re
import
re
import
gen_utils
from
api_base
import
BaseAPI
class
BackwardAPI
:
class
BackwardAPI
(
BaseAPI
)
:
def
__init__
(
self
,
backward_item_yaml
):
def
__init__
(
self
,
backward_item_yaml
):
self
.
backward_api
=
backward_item_yaml
[
'backward_api'
]
super
(
BackwardAPI
,
self
).
__init__
(
backward_item_yaml
)
self
.
args
,
self
.
output_type_list
,
self
.
return_comment
=
self
.
parse_and_check_args
(
self
.
check_args
(
backward_item_yaml
[
'forward'
])
backward_item_yaml
[
'forward'
],
backward_item_yaml
[
'args'
],
backward_item_yaml
[
'output'
])
def
get_api_name
(
self
,
api_item_yaml
):
self
.
return_type
=
self
.
output_type_list
[
0
]
if
len
(
return
api_item_yaml
[
'backward_api'
]
self
.
output_type_list
)
==
1
else
"std::vector<std::vector<Tensor>>"
self
.
is_base_api
=
True
if
'invoke'
in
backward_item_yaml
:
self
.
is_base_api
=
False
self
.
invoke
=
backward_item_yaml
[
'invoke'
]
else
:
self
.
kernel
=
backward_item_yaml
[
'kernel'
]
if
'backend'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'backend'
])
==
0
:
self
.
kernel
[
'backend'
]
=
None
if
'layout'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'layout'
])
==
0
:
self
.
kernel
[
'layout'
]
=
None
if
'data_type'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'data_type'
])
==
0
:
self
.
kernel
[
'data_type'
]
=
None
if
'param'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'param'
])
==
0
:
self
.
kernel
[
'param'
]
=
None
self
.
infer_meta
=
backward_item_yaml
[
'infer_meta'
]
if
'param'
not
in
self
.
infer_meta
or
len
(
self
.
infer_meta
[
'param'
])
==
0
:
self
.
infer_meta
[
'param'
]
=
None
self
.
data_transform
=
{
'skip_transform'
:
[],
'support_trans_dtype'
:
[]
}
if
'data_transform'
in
backward_item_yaml
:
if
'skip_transform'
in
backward_item_yaml
[
'data_transform'
]:
self
.
data_transform
[
'skip_transform'
]
=
backward_item_yaml
[
'data_transform'
][
'skip_transform'
]
if
'support_trans_dtype'
in
backward_item_yaml
[
'data_transform'
]:
self
.
data_transform
[
'support_trans_dtype'
]
=
backward_item_yaml
[
'data_transform'
][
'support_trans_dtype'
]
def
parse_forward_config
(
self
,
forward_config
):
def
parse_forward_config
(
self
,
forward_config
):
# api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
# api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
...
@@ -71,51 +35,39 @@ class BackwardAPI:
...
@@ -71,51 +35,39 @@ class BackwardAPI:
forward_config
)
forward_config
)
api
=
result
.
group
(
'api'
)
api
=
result
.
group
(
'api'
)
outputs
=
[
item
.
strip
()
for
item
in
result
.
group
(
'outputs'
).
split
(
','
)]
outputs
=
[
item
.
strip
()
for
item
in
result
.
group
(
'outputs'
).
split
(
','
)]
forward_args
=
gen_utils
.
parse_args
(
api
,
result
.
group
(
'args'
))
fw_inputs
,
fw_attrs
,
_
,
=
self
.
parse_input_and_attr
(
api
,
result
.
group
(
'args'
))
return
api
,
f
orward_args
[
'inputs'
],
forward_args
[
'attrs'
]
,
outputs
return
api
,
f
w_inputs
,
fw_attrs
,
outputs
def
parse_and_check_args
(
self
,
forward_config
,
args_config
,
output
_config
):
def
check_args
(
self
,
forward
_config
):
# parse the forward and backward config
# parse the forward and backward config
_
,
fw_inputs
,
fw_attrs
,
fw_outputs
=
self
.
parse_forward_config
(
_
,
fw_inputs
,
fw_attrs
,
fw_outputs
=
self
.
parse_forward_config
(
forward_config
)
forward_config
)
bw_args
=
gen_utils
.
parse_args
(
self
.
backward_api
,
args_config
)
# check the inputs of backward
# check the inputs of backward
for
input
in
bw_args
[
'inputs'
]
[
'names'
]:
for
input
in
self
.
inputs
[
'names'
]:
if
input
not
in
fw_inputs
and
input
not
in
fw_outputs
:
if
input
not
in
fw_inputs
and
input
not
in
fw_outputs
:
if
input
.
endswith
(
'_grad'
):
if
input
.
endswith
(
'_grad'
):
original_name
=
input
[:
-
5
]
original_name
=
input
[:
-
5
]
assert
original_name
in
fw_outputs
,
\
assert
original_name
in
fw_outputs
,
\
f
"
{
self
.
backward_
api
}
: Input Tensor error: the input tensor(
{
input
}
) of backward should be an input or output or grad of output in forward api.
\
f
"
{
self
.
api
}
: Input Tensor error: the input tensor(
{
input
}
) of backward should be an input or output or grad of output in forward api.
\
Please check the forward of
{
self
.
backward_
api
}
in yaml."
Please check the forward of
{
self
.
api
}
in yaml."
# check the attributes of backward
# check the attributes of backward
for
attr
in
bw_args
[
'attrs'
]
[
'names'
]:
for
attr
in
self
.
attrs
[
'names'
]:
assert
attr
in
fw_attrs
[
'names'
]
and
bw_args
[
'attrs'
]
[
'attr_info'
][
attr
][
0
]
==
fw_attrs
[
'attr_info'
][
attr
][
0
],
\
assert
attr
in
fw_attrs
[
'names'
]
and
self
.
attrs
[
'attr_info'
][
attr
][
0
]
==
fw_attrs
[
'attr_info'
][
attr
][
0
],
\
f
"
{
self
.
backward_
api
}
: Attribute error: The attribute(
{
attr
}
) of backward isn't consistent with forward api.
\
f
"
{
self
.
api
}
: Attribute error: The attribute(
{
attr
}
) of backward isn't consistent with forward api.
\
Please check the args of
{
self
.
backward_
api
}
in yaml."
Please check the args of
{
self
.
api
}
in yaml."
# check the output of backward
# check the output of backward
out_type_list
,
return_comment
=
gen_utils
.
parse_output
(
assert
len
(
self
.
outputs
[
'types'
])
<=
len
(
fw_inputs
[
'names'
]),
\
self
.
backward_api
,
output_config
)
f
"
{
self
.
api
}
: Output error: The number of ouputs should be less then the number of inputs of forward api.
\
assert
len
(
out_type_list
)
<=
len
(
fw_inputs
[
'names'
]),
\
Please check the output of
{
self
.
api
}
in yaml."
f
"
{
self
.
backward_api
}
: Output error: The number of ouputs should be less then the number of inputs of forward api.
\
Please check the output of
{
self
.
backward_api
}
in yaml."
return
bw_args
,
out_type_list
,
return_comment
def
gene_api_declaration
(
self
):
if
self
.
return_comment
:
return
f
"""
//
{
self
.
return_comment
}
{
self
.
return_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
'args_declare'
]
}
);
"""
else
:
def
get_return_type
(
self
,
out_type_list
):
return
f
"""
return
out_type_list
[
0
]
if
len
(
{
self
.
return_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
'args_declare'
]
}
);
out_type_list
)
==
1
else
"std::vector<std::vector<Tensor>>"
"""
def
gene_output
(
self
,
output_type_list
):
def
gene_output
(
self
,
output_type_list
):
kernel_output
=
""
kernel_output
=
""
...
@@ -126,12 +78,12 @@ class BackwardAPI:
...
@@ -126,12 +78,12 @@ class BackwardAPI:
kernel_output
=
'dense_out'
kernel_output
=
'dense_out'
output_names
.
append
(
'dense_out'
)
output_names
.
append
(
'dense_out'
)
output_create
=
f
"""
output_create
=
f
"""
{
self
.
return_type
}
out;
{
self
.
outputs
[
'return_type'
]
}
out;
auto dense_out = SetKernelOutput(kernel_backend, &out);"""
auto dense_out = SetKernelOutput(kernel_backend, &out);"""
elif
len
(
output_type_list
)
>
1
:
elif
len
(
output_type_list
)
>
1
:
output_create
=
f
"""
output_create
=
f
"""
{
self
.
return_type
}
out(
{
len
(
output_type_list
)
}
);"""
{
self
.
outputs
[
'return_type'
]
}
out(
{
len
(
output_type_list
)
}
);"""
for
i
,
out_type_item
in
enumerate
(
output_type_list
):
for
i
,
out_type_item
in
enumerate
(
output_type_list
):
kernel_output
=
kernel_output
+
f
'dense_out_
{
i
}
, '
kernel_output
=
kernel_output
+
f
'dense_out_
{
i
}
, '
...
@@ -150,58 +102,10 @@ class BackwardAPI:
...
@@ -150,58 +102,10 @@ class BackwardAPI:
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"{} : Output error: the output should not be empty."
.
format
(
"{} : Output error: the output should not be empty."
.
format
(
self
.
backward_
api
))
self
.
api
))
return
kernel_output
,
output_names
,
output_create
return
kernel_output
,
output_names
,
output_create
def
gene_api_code
(
self
):
if
self
.
is_base_api
:
input_tensors
,
kernel_args
,
kernel_signature
=
gen_utils
.
get_kernel_args
(
self
.
args
[
'inputs'
],
self
.
args
[
'attrs'
],
self
.
output_type_list
,
self
.
kernel
[
'param'
],
self
.
data_transform
)
outputs_args
,
output_names
,
output_create
=
self
.
gene_output
(
self
.
output_type_list
)
return
f
"""
//
{
self
.
return_comment
}
{
self
.
return_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
{
gen_utils
.
gene_kernel_select
(
self
.
backward_api
,
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
kernel
)
}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{
input_tensors
}
{
output_create
}
{
gen_utils
.
gene_infer_meta
(
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
][
'names'
],
output_names
,
self
.
infer_meta
)
}
using kernel_signature =
{
kernel_signature
}
;
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(
{
kernel_args
}
,
{
outputs_args
}
);
return out;
}}
"""
else
:
inveke_func_name
=
self
.
invoke
.
split
(
'('
)[
0
].
strip
()
if
inveke_func_name
in
self
.
args
[
'attrs'
][
'names'
]:
# Adjust the param whose name is same with api invoked.
pattern
=
'\W'
+
inveke_func_name
+
'[^A-Za-z0-9_(]'
def
adjust_name
(
matched
):
matched_str
=
matched
.
group
()
return
matched_str
[
0
:
-
1
]
+
'_val'
+
matched_str
[
-
1
]
invoke_code
=
re
.
sub
(
pattern
,
adjust_name
,
self
.
invoke
)
params_code
=
re
.
sub
(
pattern
,
adjust_name
,
self
.
args
[
"args_define"
])
else
:
invoke_code
=
self
.
invoke
params_code
=
self
.
args
[
"args_define"
]
return
f
"""
//
{
self
.
return_comment
}
{
self
.
return_type
}
{
self
.
backward_api
}
(
{
params_code
}
) {{
return
{
invoke_code
}
;
}}
"""
def
header_include
():
def
header_include
():
return
"""
return
"""
...
@@ -263,7 +167,6 @@ def generate_backward_api(backward_yaml_path, header_file_path,
...
@@ -263,7 +167,6 @@ def generate_backward_api(backward_yaml_path, header_file_path,
for
bw_api
in
bw_apis
:
for
bw_api
in
bw_apis
:
bw_api
=
BackwardAPI
(
bw_api
)
bw_api
=
BackwardAPI
(
bw_api
)
# print(api_code.gene_api_declaration())
header_file
.
write
(
bw_api
.
gene_api_declaration
())
header_file
.
write
(
bw_api
.
gene_api_declaration
())
source_file
.
write
(
bw_api
.
gene_api_code
())
source_file
.
write
(
bw_api
.
gene_api_code
())
...
...
python/paddle/utils/code_gen/gen_utils.py
已删除
100644 → 0
浏览文件 @
224bc511
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
re
PREFIX_TENSOR_NAME
=
'dense_'
PREFIX_META_TENSOR_NAME
=
'meta_'
def
parse_args
(
api_name
,
args_str
):
"""
Returns:
{ inputs : {
names : [] // list of input names
input_info : { input_name : type }
}
attrs: {
names : [] // list of attribute names
attr_info : { attr_name : (type, default_value)}
}
args_declare : "str" // str of funtion params with default value. Example: (..., bool flag=false)
args_define : "str" // str of funtion params without default value. Example: (..., bool flag)
}
"""
inputs
=
{
'names'
:
[],
'input_info'
:
{}}
attrs
=
{
'names'
:
[],
'attr_info'
:
{}}
args_str
=
args_str
.
strip
()
assert
args_str
.
startswith
(
'('
)
and
args_str
.
endswith
(
')'
),
\
f
"Args declaration should start with '(' and end with ')', please check the args of
{
api_name
}
in yaml."
args_str
=
args_str
[
1
:
-
1
]
args_list
=
args_str
.
split
(
','
)
input_types
=
[
'const Tensor&'
,
'const Tensor &'
,
'const std::vector<Tensor>&'
,
'const std::vector<Tensor> &'
]
attr_types
=
[
'const Scalar&'
,
'const Scalar &'
,
'const ScalarArray&'
,
'const ScalarArray &'
,
\
'int'
,
'int32_t'
,
'int64_t'
,
'size_t'
,
'float'
,
'double'
,
'bool'
,
\
'const std::vector<int64_t>&'
,
'Backend'
,
'DataLayout'
,
'DataType'
]
args_declare_str
=
""
args_define_str
=
""
for
item
in
args_list
:
item
=
item
.
strip
()
# match the input tensor
has_input
=
False
for
in_type
in
input_types
:
if
item
.
startswith
(
in_type
):
input_name
=
item
[
len
(
in_type
):].
strip
()
assert
len
(
input_name
)
>
0
,
\
f
"The input tensor name should not be empty. Please check the args of
{
api_name
}
in yaml."
assert
len
(
attrs
[
'names'
])
==
0
,
\
f
"The input Tensor should appear before attributes. please check the position of
{
api_name
}
:input(
{
input_name
}
) in yaml"
inputs
[
'names'
].
append
(
input_name
)
inputs
[
'input_info'
][
input_name
]
=
in_type
args_declare_str
=
args_declare_str
+
in_type
+
' '
+
input_name
+
', '
args_define_str
=
args_define_str
+
in_type
+
' '
+
input_name
+
', '
has_input
=
True
break
if
has_input
:
continue
# match the attribute
for
attr_type
in
attr_types
:
if
item
.
startswith
(
attr_type
):
attr_name
=
item
[
len
(
attr_type
):].
strip
()
assert
len
(
attr_name
)
>
0
,
\
f
"The attribute name should not be empty. Please check the args of
{
api_name
}
in yaml."
default_value
=
None
if
'='
in
attr_name
:
attr_infos
=
attr_name
.
split
(
'='
)
attr_name
=
attr_infos
[
0
].
strip
()
default_value
=
attr_infos
[
1
].
strip
()
default_value_str
=
""
if
default_value
is
None
else
'='
+
default_value
args_declare_str
=
args_declare_str
+
attr_type
+
' '
+
attr_name
+
default_value_str
+
', '
args_define_str
=
args_define_str
+
attr_type
+
' '
+
attr_name
+
', '
attrs
[
'names'
].
append
(
attr_name
)
attrs
[
'attr_info'
][
attr_name
]
=
(
attr_type
,
default_value
)
break
args
=
{
'inputs'
:
inputs
,
'attrs'
:
attrs
,
'args_declare'
:
args_declare_str
[:
-
2
],
'args_define'
:
args_define_str
[:
-
2
]
}
return
args
def
parse_output
(
api_name
,
output_config
):
def
parse_output_item
(
output_item
):
alllowd_output_types
=
[
'Tensor'
,
'std::vector<Tensor>'
]
if
re
.
search
(
r
'\(\w*\)'
,
output_item
):
result
=
re
.
search
(
r
"(?P<out_type>[a-zA-Z0-9_<>]+)\s*\((?P<name>\w+)\)"
,
output_item
)
out_type
=
result
.
group
(
'out_type'
)
assert
out_type
in
alllowd_output_types
,
\
f
"
{
api_name
}
: Output type error: the output type only support Tensor and std::vector<Tensor>,
\
but now is
{
out_type
}
."
return
out_type
,
result
.
group
(
'name'
)
else
:
if
output_item
.
strip
()
in
alllowd_output_types
:
return
output_item
.
strip
(),
'out'
else
:
raise
ValueError
(
"{} : Output type error: the output type only support Tensor and std::vector<Tensor>,
\
but now is {}."
.
format
(
api_name
,
out_type
))
temp_list
=
output_config
.
split
(
','
)
if
len
(
temp_list
)
==
1
:
out_type
,
out_name
=
parse_output_item
(
temp_list
[
0
])
return
[
out_type
],
out_name
else
:
out_type_list
=
[]
out_name_list
=
[]
for
output_item
in
temp_list
:
out_type
,
out_name
=
parse_output_item
(
output_item
)
out_type_list
.
append
(
out_type
)
out_name_list
.
append
(
out_name
)
return
out_type_list
,
", "
.
join
(
out_name_list
)
def
gene_kernel_select
(
api
,
input_names
,
attrs
,
kernel
)
->
str
:
kernel_key_item_init
=
"""
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
"""
# Check the tensor options
attr_backend_count
=
0
attr_layout_count
=
0
attr_data_type_count
=
0
for
attr_name
in
attrs
[
'names'
]:
if
attrs
[
'attr_info'
][
attr_name
][
0
]
==
'Backend'
:
assert
kernel
[
'backend'
]
is
not
None
,
\
f
"
{
api
}
api: When there is a parameter with 'Backend' type in attributes, you must set backend of kernel manually."
attr_backend_count
=
attr_backend_count
+
1
if
attrs
[
'attr_info'
][
attr_name
][
0
]
==
'DataLayout'
:
assert
kernel
[
'layout'
]
is
not
None
,
\
f
"
{
api
}
api: When there is a parameter with 'DataLayout' type in attributes, you must set layout of kernel manually."
attr_layout_count
=
attr_layout_count
+
1
if
attrs
[
'attr_info'
][
attr_name
][
0
]
==
'DataType'
:
assert
kernel
[
'data_type'
]
is
not
None
,
\
f
"
{
api
}
api: When there is a parameter with 'DataType' type in attributes, you must set data_type of kernel manually."
attr_data_type_count
=
attr_data_type_count
+
1
# preprocess kernel configures
kernel_select_code
=
""
if
kernel
[
'backend'
]
is
not
None
:
if
'>'
in
kernel
[
'backend'
]:
vars_list
=
kernel
[
'backend'
].
split
(
'>'
)
assert
len
(
vars_list
)
==
2
,
f
"
{
api
}
api: The number of params to set backend with '>' only allows 2, but received
{
len
(
vars_list
)
}
."
assert
(
vars_list
[
0
].
strip
()
in
attrs
[
'names'
])
and
(
attrs
[
'attr_info'
][
vars_list
[
0
].
strip
()][
0
]
==
'Backend'
),
\
f
"
{
api
}
api: When use '>' to set kernel backend, the first param should be a attribute with Backend type."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_backend = ParseBackendWithInputOrder(
{
vars_list
[
0
].
strip
()
}
,
{
vars_list
[
1
].
strip
()
}
);
"""
else
:
args_str
=
""
for
ele
in
kernel
[
'backend'
].
split
(
','
):
args_str
=
args_str
+
ele
.
strip
()
+
', '
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_backend = ParseBackend(
{
args_str
[:
-
2
]
}
);
"""
if
kernel
[
'layout'
]
is
not
None
:
if
'>'
in
kernel
[
'layout'
]:
vars_list
=
kernel
[
'layout'
].
split
(
'>'
)
assert
len
(
vars_list
)
==
2
,
f
"
{
api
}
api: The number of params to set layout with '>' only allows 2, but received
{
len
(
vars_list
)
}
."
assert
vars_list
[
0
].
strip
()
in
attrs
[
'names'
]
and
attrs
[
'attr_info'
][
vars_list
[
0
].
strip
()][
0
]
==
'DataLayout'
,
\
f
"
{
api
}
api: When use '>' to set kernel layout, the first param should be a attribute with DataLayout type."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_layout = ParseLayoutWithInputOrder(
{
vars_list
[
0
].
strip
()
}
,
{
vars_list
[
1
].
strip
()
}
);
"""
else
:
vars_list
=
kernel
[
'layout'
].
split
(
','
)
assert
len
(
vars_list
)
==
1
,
f
"
{
api
}
api: The number of params to set layout must be 1, but received
{
len
(
vars_list
)
}
."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_layout = ParseLayout(
{
vars_list
[
0
].
strip
()
}
);
"""
if
kernel
[
'data_type'
]
is
not
None
:
if
'>'
in
kernel
[
'data_type'
]:
vars_list
=
kernel
[
'data_type'
].
split
(
'>'
)
assert
len
(
vars_list
)
==
2
,
f
"
{
api
}
api: The number of params to set data_type with '>' only allows 2, but received
{
len
(
vars_list
)
}
."
assert
vars_list
[
0
].
strip
()
in
attrs
[
'names'
]
and
attrs
[
'attr_info'
][
vars_list
[
0
].
strip
()][
0
]
==
'DataType'
,
\
f
"
{
api
}
api: When use '>' to set kernel data_type, the first param should be a attribute with DataType type."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_data_type = ParseDataTypeWithInputOrder(
{
vars_list
[
0
].
strip
()
}
,
{
vars_list
[
1
].
strip
()
}
);
"""
else
:
vars_list
=
kernel
[
'data_type'
].
split
(
','
)
assert
len
(
vars_list
)
==
1
,
f
"
{
api
}
api: The number of params to set data_type only allows 2, but received
{
len
(
vars_list
)
}
."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_data_type = ParseDataType(
{
vars_list
[
0
].
strip
()
}
);
"""
if
len
(
input_names
)
==
0
:
assert
attr_backend_count
>
0
and
attr_layout_count
>
0
and
attr_data_type_count
>
0
,
\
f
"
{
api
}
api: When there is no input tensor, the args must have 'Backend', 'DataLayout' and 'DataType'."
kernel_select_args
=
""
for
input_name
in
input_names
:
kernel_select_args
=
kernel_select_args
+
input_name
+
", "
if
len
(
kernel_select_args
)
>
2
:
kernel_select_args
=
kernel_select_args
[:
-
2
]
kernel_select_code
=
kernel_key_item_init
+
kernel_select_code
if
len
(
input_names
)
>
0
:
kernel_select_code
=
kernel_select_code
+
f
"""
if (kernel_backend == Backend::UNDEFINED
|| kernel_layout == DataLayout::UNDEFINED
|| kernel_data_type == DataType::UNDEFINED ) {{
auto kernel_key_set = ParseKernelKeyByInputArgs(
{
kernel_select_args
}
);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {{
kernel_backend = kernel_key.backend();
}}
if (kernel_layout == DataLayout::UNDEFINED) {{
kernel_layout = kernel_key.layout();
}}
if (kernel_data_type == DataType::UNDEFINED) {{
kernel_data_type = kernel_key.dtype();
}}
}}"""
kernel_select_code
=
kernel_select_code
+
f
"""
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"
{
kernel
[
'func'
]
}
", {{kernel_backend, kernel_layout, kernel_data_type}});
VLOG(6) << "
{
api
}
API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
VLOG(6) << "
{
api
}
API kernel: " << kernel;"""
return
kernel_select_code
def
gene_infer_meta
(
input_names
,
attr_names
,
output_names
,
infer_meta
)
->
str
:
infer_meta_params
=
infer_meta
[
'param'
]
+
output_names
if
infer_meta
[
'param'
]
is
not
None
else
input_names
+
attr_names
+
output_names
# generate meta tensors
meta_tensor_code
=
""
param_code
=
""
for
param
in
infer_meta_params
:
if
param
in
input_names
:
param_code
=
param_code
+
"MakeMetaTensor(*"
+
PREFIX_TENSOR_NAME
+
param
+
"), "
elif
param
in
output_names
:
meta_tensor_code
=
meta_tensor_code
+
" pten::MetaTensor "
+
param
.
replace
(
PREFIX_TENSOR_NAME
,
PREFIX_META_TENSOR_NAME
)
+
"("
+
param
+
");
\n
"
param_code
=
param_code
+
"&"
+
param
.
replace
(
PREFIX_TENSOR_NAME
,
PREFIX_META_TENSOR_NAME
)
+
", "
elif
param
in
attr_names
:
param_code
=
param_code
+
param
+
", "
elif
isinstance
(
param
,
str
):
param_code
=
param_code
+
"
\"
"
+
param
+
"
\"
, "
elif
isinstance
(
param
,
bool
):
param_code
=
param_code
+
str
(
param
).
lower
()
+
", "
else
:
param_code
=
param_code
+
str
(
param
)
+
", "
param_code
=
param_code
[:
-
2
]
return
f
"""
{
meta_tensor_code
}
pten::
{
infer_meta
[
'func'
]
}
(
{
param_code
}
);
"""
def
get_kernel_args
(
inputs
,
attrs
,
out_type_list
,
kernel_param
,
data_transform
):
input_trans_map
=
{
'const Tensor&'
:
'const pten::DenseTensor&'
,
'const Tensor &'
:
'const pten::DenseTensor&'
,
'const std::vector<Tensor>&'
:
'const std::vector<pten::DenseTensor>&'
,
'const std::vector<Tensor> &'
:
'const std::vector<pten::DenseTensor>&'
}
out_trans_map
=
{
'Tensor'
:
'pten::DenseTensor*'
,
'std::vector<Tensor>'
:
'std::vector<pten::DenseTensor*>&'
}
input_names
=
inputs
[
'names'
]
input_infos
=
inputs
[
'input_info'
]
kernel_args_type_list
=
[
'const platform::DeviceContext&'
]
input_tensor_code
=
""
for
input_name
in
input_names
:
# set input code
input_tensor_code
=
input_tensor_code
+
f
"""
auto
{
PREFIX_TENSOR_NAME
}{
input_name
}
= TensorToDenseTensor(
{
input_name
}
);"""
attr_names
=
attrs
[
'names'
]
if
kernel_param
is
None
:
kernel_param
=
input_names
+
attr_names
input_tensor_code
=
""
for
i
,
input_name
in
enumerate
(
input_names
):
# set input code
if
input_name
in
kernel_param
:
trans_flag
=
"{}"
if
input_name
in
data_transform
[
'skip_transform'
]:
trans_flag
=
"{true}"
elif
input_name
in
data_transform
[
'support_trans_dtype'
]:
trans_flag
=
"{false, true}"
input_tensor_code
=
input_tensor_code
+
f
"""
auto
{
PREFIX_TENSOR_NAME
}{
input_name
}
= PrepareData(
{
input_name
}
, kernel.InputAt(
{
i
}
),
{
trans_flag
}
);"""
else
:
input_tensor_code
=
input_tensor_code
+
f
"""
auto
{
PREFIX_TENSOR_NAME
}{
input_name
}
= TensorToDenseTensor(
{
input_name
}
);"""
kernel_args
=
"*dev_ctx, "
for
param
in
kernel_param
:
if
param
in
input_names
:
kernel_args
=
kernel_args
+
"*"
+
PREFIX_TENSOR_NAME
+
param
+
", "
kernel_args_type_list
.
append
(
input_trans_map
[
input_infos
[
param
]])
elif
param
in
attr_names
:
# set attr for kernel_context
if
'ScalarArray'
in
attrs
[
'attr_info'
][
param
][
0
]:
kernel_args_type_list
.
append
(
'const pten::ScalarArray&'
)
param
=
'pten::ScalarArray('
+
param
+
')'
elif
'Scalar'
in
attrs
[
'attr_info'
][
param
][
0
]:
kernel_args_type_list
.
append
(
'const pten::Scalar&'
)
param
=
'pten::Scalar('
+
param
+
')'
else
:
kernel_args_type_list
.
append
(
attrs
[
'attr_info'
][
param
][
0
])
kernel_args
=
kernel_args
+
param
+
", "
elif
isinstance
(
param
,
bool
):
kernel_args
=
kernel_args
+
str
(
param
).
lower
()
+
", "
else
:
kernel_args
=
kernel_args
+
str
(
param
)
+
", "
for
out_type
in
out_type_list
:
kernel_args_type_list
.
append
(
out_trans_map
[
out_type
])
kernel_signature
=
"void(*)("
+
", "
.
join
(
kernel_args_type_list
)
+
")"
return
input_tensor_code
,
kernel_args
[:
-
2
],
kernel_signature
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录