Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7b70b792
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7b70b792
编写于
2月 10, 2022
作者:
Z
zyfncg
提交者:
GitHub
2月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Pten】Refactor C++ API code-gen (#39408)
* refactor C++ API code-gen * fix windows problem of C++ API
上级
224bc511
变更
5
展开全部
显示空白变更内容
内联
并排
Showing
5 changed file
with
528 addition
and
579 deletion
+528
-579
paddle/pten/api/lib/CMakeLists.txt
paddle/pten/api/lib/CMakeLists.txt
+3
-3
python/paddle/utils/code_gen/api_base.py
python/paddle/utils/code_gen/api_base.py
+487
-0
python/paddle/utils/code_gen/api_gen.py
python/paddle/utils/code_gen/api_gen.py
+11
-86
python/paddle/utils/code_gen/backward_api_gen.py
python/paddle/utils/code_gen/backward_api_gen.py
+27
-124
python/paddle/utils/code_gen/gen_utils.py
python/paddle/utils/code_gen/gen_utils.py
+0
-366
未找到文件。
paddle/pten/api/lib/CMakeLists.txt
浏览文件 @
7b70b792
...
...
@@ -16,7 +16,7 @@ cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)
cc_library
(
op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor
)
set
(
api_gen_
utils
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/gen_utils
.py
)
set
(
api_gen_
base
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/api_base
.py
)
# forward api file
set
(
api_gen_file
${
CMAKE_SOURCE_DIR
}
/python/paddle/utils/code_gen/api_gen.py
)
...
...
@@ -49,7 +49,7 @@ add_custom_command(
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
api_header_file_tmp
}
${
api_header_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
api_source_file_tmp
}
${
api_source_file
}
COMMENT
"copy_if_different
${
api_header_file
}
${
api_source_file
}
"
DEPENDS
${
api_yaml_file
}
${
api_gen_file
}
${
api_gen_
utils
}
DEPENDS
${
api_yaml_file
}
${
api_gen_file
}
${
api_gen_
base
}
VERBATIM
)
# generate backward api
...
...
@@ -62,7 +62,7 @@ add_custom_command(
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
bw_api_header_file_tmp
}
${
bw_api_header_file
}
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
${
bw_api_source_file_tmp
}
${
bw_api_source_file
}
COMMENT
"copy_if_different
${
bw_api_header_file
}
${
bw_api_source_file
}
"
DEPENDS
${
bw_api_yaml_file
}
${
bw_api_gen_file
}
${
api_gen_
utils
}
DEPENDS
${
bw_api_yaml_file
}
${
bw_api_gen_file
}
${
api_gen_
base
}
VERBATIM
)
cc_library
(
pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform
)
...
...
python/paddle/utils/code_gen/api_base.py
0 → 100644
浏览文件 @
7b70b792
此差异已折叠。
点击以展开。
python/paddle/utils/code_gen/api_gen.py
浏览文件 @
7b70b792
...
...
@@ -16,64 +16,19 @@ import os
import
yaml
import
argparse
import
gen_utils
from
api_base
import
BaseAPI
class
API
:
class
ForwardAPI
(
BaseAPI
)
:
prefix_tensor_name
=
'dense_'
def
__init__
(
self
,
api_item_yaml
):
self
.
api
=
api_item_yaml
[
'api'
]
# args:
# inputs:
# names : [], list of input names
# input_info : {input_name : type}
# attrs:
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
self
.
args
=
gen_utils
.
parse_args
(
self
.
api
,
api_item_yaml
[
'args'
])
self
.
out_type_list
,
_
=
gen_utils
.
parse_output
(
self
.
api
,
api_item_yaml
[
'output'
])
self
.
return_type
=
self
.
out_type_list
[
0
]
if
len
(
self
.
out_type_list
)
==
1
else
"std::tuple<"
+
","
.
join
(
self
.
out_type_list
)
+
">"
self
.
is_base_api
=
True
if
'invoke'
in
api_item_yaml
:
self
.
is_base_api
=
False
self
.
invoke
=
api_item_yaml
[
'invoke'
]
else
:
self
.
kernel
=
api_item_yaml
[
'kernel'
]
if
'backend'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'backend'
])
==
0
:
self
.
kernel
[
'backend'
]
=
None
if
'layout'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'layout'
])
==
0
:
self
.
kernel
[
'layout'
]
=
None
if
'data_type'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'data_type'
])
==
0
:
self
.
kernel
[
'data_type'
]
=
None
if
'param'
not
in
self
.
kernel
:
self
.
kernel
[
'param'
]
=
None
self
.
infer_meta
=
api_item_yaml
[
'infer_meta'
]
if
'param'
not
in
self
.
infer_meta
:
self
.
infer_meta
[
'param'
]
=
None
self
.
data_transform
=
{
'skip_transform'
:
[],
'support_trans_dtype'
:
[]
}
if
'data_transform'
in
api_item_yaml
:
if
'skip_transform'
in
api_item_yaml
[
'data_transform'
]:
self
.
data_transform
[
'skip_transform'
]
=
api_item_yaml
[
'data_transform'
][
'skip_transform'
]
if
'support_trans_dtype'
in
api_item_yaml
[
'data_transform'
]:
self
.
data_transform
[
'support_trans_dtype'
]
=
api_item_yaml
[
'data_transform'
][
'support_trans_dtype'
]
def
gene_api_declaration
(
self
):
return
f
"""
PADDLE_API
{
self
.
return_type
}
{
self
.
api
}
(
{
self
.
args
[
'args_declare'
]
}
);
"""
super
(
ForwardAPI
,
self
).
__init__
(
api_item_yaml
)
def
get_return_type
(
self
,
out_type_list
):
return
out_type_list
[
0
]
if
len
(
out_type_list
)
==
1
else
"std::tuple<"
+
","
.
join
(
out_type_list
)
+
">"
def
gene_output
(
self
,
output_type_list
):
kernel_output
=
""
...
...
@@ -84,12 +39,12 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
kernel_output
=
'dense_out'
output_names
.
append
(
'dense_out'
)
output_create
=
f
"""
{
self
.
return_type
}
out;
{
self
.
outputs
[
'return_type'
]
}
out;
auto dense_out = SetKernelOutput(kernel_backend, &out);"""
elif
len
(
output_type_list
)
>
1
:
output_create
=
f
"""
{
self
.
return_type
}
out;"""
{
self
.
outputs
[
'return_type'
]
}
out;"""
for
i
in
range
(
len
(
output_type_list
)):
kernel_output
=
kernel_output
+
f
'dense_out_
{
i
}
, '
...
...
@@ -105,36 +60,6 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
return
kernel_output
,
output_names
,
output_create
def
gene_api_code
(
self
):
if
self
.
is_base_api
:
input_tensors
,
kernel_args
,
kernel_signature
=
gen_utils
.
get_kernel_args
(
self
.
args
[
'inputs'
],
self
.
args
[
'attrs'
],
self
.
out_type_list
,
self
.
kernel
[
'param'
],
self
.
data_transform
)
outputs_args
,
output_names
,
output_create
=
self
.
gene_output
(
self
.
out_type_list
)
return
f
"""
PADDLE_API
{
self
.
return_type
}
{
self
.
api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
{
gen_utils
.
gene_kernel_select
(
self
.
api
,
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
kernel
)
}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{
input_tensors
}
{
output_create
}
{
gen_utils
.
gene_infer_meta
(
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
][
'names'
],
output_names
,
self
.
infer_meta
)
}
using kernel_signature =
{
kernel_signature
}
;
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(
{
kernel_args
}
,
{
outputs_args
}
);
return out;
}}
"""
else
:
return
f
"""
PADDLE_API
{
self
.
return_type
}
{
self
.
api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
return
{
self
.
invoke
}
;
}}
"""
def
header_include
():
return
"""
...
...
@@ -203,7 +128,7 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
source_file
.
write
(
namespace
[
0
])
for
api
in
apis
:
api_code
=
API
(
api
)
api_code
=
Forward
API
(
api
)
print
(
api_code
.
gene_api_declaration
())
header_file
.
write
(
api_code
.
gene_api_declaration
())
source_file
.
write
(
api_code
.
gene_api_code
())
...
...
python/paddle/utils/code_gen/backward_api_gen.py
浏览文件 @
7b70b792
...
...
@@ -17,52 +17,16 @@ import yaml
import
argparse
import
re
import
gen_utils
from
api_base
import
BaseAPI
class
BackwardAPI
:
class
BackwardAPI
(
BaseAPI
)
:
def
__init__
(
self
,
backward_item_yaml
):
self
.
backward_api
=
backward_item_yaml
[
'backward_api'
]
self
.
args
,
self
.
output_type_list
,
self
.
return_comment
=
self
.
parse_and_check_args
(
backward_item_yaml
[
'forward'
],
backward_item_yaml
[
'args'
],
backward_item_yaml
[
'output'
])
self
.
return_type
=
self
.
output_type_list
[
0
]
if
len
(
self
.
output_type_list
)
==
1
else
"std::vector<std::vector<Tensor>>"
self
.
is_base_api
=
True
if
'invoke'
in
backward_item_yaml
:
self
.
is_base_api
=
False
self
.
invoke
=
backward_item_yaml
[
'invoke'
]
else
:
self
.
kernel
=
backward_item_yaml
[
'kernel'
]
if
'backend'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'backend'
])
==
0
:
self
.
kernel
[
'backend'
]
=
None
if
'layout'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'layout'
])
==
0
:
self
.
kernel
[
'layout'
]
=
None
if
'data_type'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'data_type'
])
==
0
:
self
.
kernel
[
'data_type'
]
=
None
if
'param'
not
in
self
.
kernel
or
len
(
self
.
kernel
[
'param'
])
==
0
:
self
.
kernel
[
'param'
]
=
None
self
.
infer_meta
=
backward_item_yaml
[
'infer_meta'
]
if
'param'
not
in
self
.
infer_meta
or
len
(
self
.
infer_meta
[
'param'
])
==
0
:
self
.
infer_meta
[
'param'
]
=
None
self
.
data_transform
=
{
'skip_transform'
:
[],
'support_trans_dtype'
:
[]
}
if
'data_transform'
in
backward_item_yaml
:
if
'skip_transform'
in
backward_item_yaml
[
'data_transform'
]:
self
.
data_transform
[
'skip_transform'
]
=
backward_item_yaml
[
'data_transform'
][
'skip_transform'
]
if
'support_trans_dtype'
in
backward_item_yaml
[
'data_transform'
]:
self
.
data_transform
[
'support_trans_dtype'
]
=
backward_item_yaml
[
'data_transform'
][
'support_trans_dtype'
]
super
(
BackwardAPI
,
self
).
__init__
(
backward_item_yaml
)
self
.
check_args
(
backward_item_yaml
[
'forward'
])
def
get_api_name
(
self
,
api_item_yaml
):
return
api_item_yaml
[
'backward_api'
]
def
parse_forward_config
(
self
,
forward_config
):
# api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
...
...
@@ -71,51 +35,39 @@ class BackwardAPI:
forward_config
)
api
=
result
.
group
(
'api'
)
outputs
=
[
item
.
strip
()
for
item
in
result
.
group
(
'outputs'
).
split
(
','
)]
forward_args
=
gen_utils
.
parse_args
(
api
,
result
.
group
(
'args'
))
fw_inputs
,
fw_attrs
,
_
,
=
self
.
parse_input_and_attr
(
api
,
result
.
group
(
'args'
))
return
api
,
f
orward_args
[
'inputs'
],
forward_args
[
'attrs'
]
,
outputs
return
api
,
f
w_inputs
,
fw_attrs
,
outputs
def
parse_and_check_args
(
self
,
forward_config
,
args_config
,
output
_config
):
def
check_args
(
self
,
forward
_config
):
# parse the forward and backward config
_
,
fw_inputs
,
fw_attrs
,
fw_outputs
=
self
.
parse_forward_config
(
forward_config
)
bw_args
=
gen_utils
.
parse_args
(
self
.
backward_api
,
args_config
)
# check the inputs of backward
for
input
in
bw_args
[
'inputs'
]
[
'names'
]:
for
input
in
self
.
inputs
[
'names'
]:
if
input
not
in
fw_inputs
and
input
not
in
fw_outputs
:
if
input
.
endswith
(
'_grad'
):
original_name
=
input
[:
-
5
]
assert
original_name
in
fw_outputs
,
\
f
"
{
self
.
backward_
api
}
: Input Tensor error: the input tensor(
{
input
}
) of backward should be an input or output or grad of output in forward api.
\
Please check the forward of
{
self
.
backward_
api
}
in yaml."
f
"
{
self
.
api
}
: Input Tensor error: the input tensor(
{
input
}
) of backward should be an input or output or grad of output in forward api.
\
Please check the forward of
{
self
.
api
}
in yaml."
# check the attributes of backward
for
attr
in
bw_args
[
'attrs'
]
[
'names'
]:
assert
attr
in
fw_attrs
[
'names'
]
and
bw_args
[
'attrs'
]
[
'attr_info'
][
attr
][
0
]
==
fw_attrs
[
'attr_info'
][
attr
][
0
],
\
f
"
{
self
.
backward_
api
}
: Attribute error: The attribute(
{
attr
}
) of backward isn't consistent with forward api.
\
Please check the args of
{
self
.
backward_
api
}
in yaml."
for
attr
in
self
.
attrs
[
'names'
]:
assert
attr
in
fw_attrs
[
'names'
]
and
self
.
attrs
[
'attr_info'
][
attr
][
0
]
==
fw_attrs
[
'attr_info'
][
attr
][
0
],
\
f
"
{
self
.
api
}
: Attribute error: The attribute(
{
attr
}
) of backward isn't consistent with forward api.
\
Please check the args of
{
self
.
api
}
in yaml."
# check the output of backward
out_type_list
,
return_comment
=
gen_utils
.
parse_output
(
self
.
backward_api
,
output_config
)
assert
len
(
out_type_list
)
<=
len
(
fw_inputs
[
'names'
]),
\
f
"
{
self
.
backward_api
}
: Output error: The number of ouputs should be less then the number of inputs of forward api.
\
Please check the output of
{
self
.
backward_api
}
in yaml."
return
bw_args
,
out_type_list
,
return_comment
def
gene_api_declaration
(
self
):
if
self
.
return_comment
:
return
f
"""
//
{
self
.
return_comment
}
{
self
.
return_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
'args_declare'
]
}
);
"""
assert
len
(
self
.
outputs
[
'types'
])
<=
len
(
fw_inputs
[
'names'
]),
\
f
"
{
self
.
api
}
: Output error: The number of ouputs should be less then the number of inputs of forward api.
\
Please check the output of
{
self
.
api
}
in yaml."
else
:
return
f
"""
{
self
.
return_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
'args_declare'
]
}
);
"""
def
get_return_type
(
self
,
out_type_list
):
return
out_type_list
[
0
]
if
len
(
out_type_list
)
==
1
else
"std::vector<std::vector<Tensor>>"
def
gene_output
(
self
,
output_type_list
):
kernel_output
=
""
...
...
@@ -126,12 +78,12 @@ class BackwardAPI:
kernel_output
=
'dense_out'
output_names
.
append
(
'dense_out'
)
output_create
=
f
"""
{
self
.
return_type
}
out;
{
self
.
outputs
[
'return_type'
]
}
out;
auto dense_out = SetKernelOutput(kernel_backend, &out);"""
elif
len
(
output_type_list
)
>
1
:
output_create
=
f
"""
{
self
.
return_type
}
out(
{
len
(
output_type_list
)
}
);"""
{
self
.
outputs
[
'return_type'
]
}
out(
{
len
(
output_type_list
)
}
);"""
for
i
,
out_type_item
in
enumerate
(
output_type_list
):
kernel_output
=
kernel_output
+
f
'dense_out_
{
i
}
, '
...
...
@@ -150,58 +102,10 @@ class BackwardAPI:
else
:
raise
ValueError
(
"{} : Output error: the output should not be empty."
.
format
(
self
.
backward_
api
))
self
.
api
))
return
kernel_output
,
output_names
,
output_create
def
gene_api_code
(
self
):
if
self
.
is_base_api
:
input_tensors
,
kernel_args
,
kernel_signature
=
gen_utils
.
get_kernel_args
(
self
.
args
[
'inputs'
],
self
.
args
[
'attrs'
],
self
.
output_type_list
,
self
.
kernel
[
'param'
],
self
.
data_transform
)
outputs_args
,
output_names
,
output_create
=
self
.
gene_output
(
self
.
output_type_list
)
return
f
"""
//
{
self
.
return_comment
}
{
self
.
return_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
{
gen_utils
.
gene_kernel_select
(
self
.
backward_api
,
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
kernel
)
}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{
input_tensors
}
{
output_create
}
{
gen_utils
.
gene_infer_meta
(
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
][
'names'
],
output_names
,
self
.
infer_meta
)
}
using kernel_signature =
{
kernel_signature
}
;
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(
{
kernel_args
}
,
{
outputs_args
}
);
return out;
}}
"""
else
:
inveke_func_name
=
self
.
invoke
.
split
(
'('
)[
0
].
strip
()
if
inveke_func_name
in
self
.
args
[
'attrs'
][
'names'
]:
# Adjust the param whose name is same with api invoked.
pattern
=
'\W'
+
inveke_func_name
+
'[^A-Za-z0-9_(]'
def
adjust_name
(
matched
):
matched_str
=
matched
.
group
()
return
matched_str
[
0
:
-
1
]
+
'_val'
+
matched_str
[
-
1
]
invoke_code
=
re
.
sub
(
pattern
,
adjust_name
,
self
.
invoke
)
params_code
=
re
.
sub
(
pattern
,
adjust_name
,
self
.
args
[
"args_define"
])
else
:
invoke_code
=
self
.
invoke
params_code
=
self
.
args
[
"args_define"
]
return
f
"""
//
{
self
.
return_comment
}
{
self
.
return_type
}
{
self
.
backward_api
}
(
{
params_code
}
) {{
return
{
invoke_code
}
;
}}
"""
def
header_include
():
return
"""
...
...
@@ -263,7 +167,6 @@ def generate_backward_api(backward_yaml_path, header_file_path,
for
bw_api
in
bw_apis
:
bw_api
=
BackwardAPI
(
bw_api
)
# print(api_code.gene_api_declaration())
header_file
.
write
(
bw_api
.
gene_api_declaration
())
source_file
.
write
(
bw_api
.
gene_api_code
())
...
...
python/paddle/utils/code_gen/gen_utils.py
已删除
100644 → 0
浏览文件 @
224bc511
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
re
PREFIX_TENSOR_NAME
=
'dense_'
PREFIX_META_TENSOR_NAME
=
'meta_'
def
parse_args
(
api_name
,
args_str
):
"""
Returns:
{ inputs : {
names : [] // list of input names
input_info : { input_name : type }
}
attrs: {
names : [] // list of attribute names
attr_info : { attr_name : (type, default_value)}
}
args_declare : "str" // str of funtion params with default value. Example: (..., bool flag=false)
args_define : "str" // str of funtion params without default value. Example: (..., bool flag)
}
"""
inputs
=
{
'names'
:
[],
'input_info'
:
{}}
attrs
=
{
'names'
:
[],
'attr_info'
:
{}}
args_str
=
args_str
.
strip
()
assert
args_str
.
startswith
(
'('
)
and
args_str
.
endswith
(
')'
),
\
f
"Args declaration should start with '(' and end with ')', please check the args of
{
api_name
}
in yaml."
args_str
=
args_str
[
1
:
-
1
]
args_list
=
args_str
.
split
(
','
)
input_types
=
[
'const Tensor&'
,
'const Tensor &'
,
'const std::vector<Tensor>&'
,
'const std::vector<Tensor> &'
]
attr_types
=
[
'const Scalar&'
,
'const Scalar &'
,
'const ScalarArray&'
,
'const ScalarArray &'
,
\
'int'
,
'int32_t'
,
'int64_t'
,
'size_t'
,
'float'
,
'double'
,
'bool'
,
\
'const std::vector<int64_t>&'
,
'Backend'
,
'DataLayout'
,
'DataType'
]
args_declare_str
=
""
args_define_str
=
""
for
item
in
args_list
:
item
=
item
.
strip
()
# match the input tensor
has_input
=
False
for
in_type
in
input_types
:
if
item
.
startswith
(
in_type
):
input_name
=
item
[
len
(
in_type
):].
strip
()
assert
len
(
input_name
)
>
0
,
\
f
"The input tensor name should not be empty. Please check the args of
{
api_name
}
in yaml."
assert
len
(
attrs
[
'names'
])
==
0
,
\
f
"The input Tensor should appear before attributes. please check the position of
{
api_name
}
:input(
{
input_name
}
) in yaml"
inputs
[
'names'
].
append
(
input_name
)
inputs
[
'input_info'
][
input_name
]
=
in_type
args_declare_str
=
args_declare_str
+
in_type
+
' '
+
input_name
+
', '
args_define_str
=
args_define_str
+
in_type
+
' '
+
input_name
+
', '
has_input
=
True
break
if
has_input
:
continue
# match the attribute
for
attr_type
in
attr_types
:
if
item
.
startswith
(
attr_type
):
attr_name
=
item
[
len
(
attr_type
):].
strip
()
assert
len
(
attr_name
)
>
0
,
\
f
"The attribute name should not be empty. Please check the args of
{
api_name
}
in yaml."
default_value
=
None
if
'='
in
attr_name
:
attr_infos
=
attr_name
.
split
(
'='
)
attr_name
=
attr_infos
[
0
].
strip
()
default_value
=
attr_infos
[
1
].
strip
()
default_value_str
=
""
if
default_value
is
None
else
'='
+
default_value
args_declare_str
=
args_declare_str
+
attr_type
+
' '
+
attr_name
+
default_value_str
+
', '
args_define_str
=
args_define_str
+
attr_type
+
' '
+
attr_name
+
', '
attrs
[
'names'
].
append
(
attr_name
)
attrs
[
'attr_info'
][
attr_name
]
=
(
attr_type
,
default_value
)
break
args
=
{
'inputs'
:
inputs
,
'attrs'
:
attrs
,
'args_declare'
:
args_declare_str
[:
-
2
],
'args_define'
:
args_define_str
[:
-
2
]
}
return
args
def
parse_output
(
api_name
,
output_config
):
def
parse_output_item
(
output_item
):
alllowd_output_types
=
[
'Tensor'
,
'std::vector<Tensor>'
]
if
re
.
search
(
r
'\(\w*\)'
,
output_item
):
result
=
re
.
search
(
r
"(?P<out_type>[a-zA-Z0-9_<>]+)\s*\((?P<name>\w+)\)"
,
output_item
)
out_type
=
result
.
group
(
'out_type'
)
assert
out_type
in
alllowd_output_types
,
\
f
"
{
api_name
}
: Output type error: the output type only support Tensor and std::vector<Tensor>,
\
but now is
{
out_type
}
."
return
out_type
,
result
.
group
(
'name'
)
else
:
if
output_item
.
strip
()
in
alllowd_output_types
:
return
output_item
.
strip
(),
'out'
else
:
raise
ValueError
(
"{} : Output type error: the output type only support Tensor and std::vector<Tensor>,
\
but now is {}."
.
format
(
api_name
,
out_type
))
temp_list
=
output_config
.
split
(
','
)
if
len
(
temp_list
)
==
1
:
out_type
,
out_name
=
parse_output_item
(
temp_list
[
0
])
return
[
out_type
],
out_name
else
:
out_type_list
=
[]
out_name_list
=
[]
for
output_item
in
temp_list
:
out_type
,
out_name
=
parse_output_item
(
output_item
)
out_type_list
.
append
(
out_type
)
out_name_list
.
append
(
out_name
)
return
out_type_list
,
", "
.
join
(
out_name_list
)
def
gene_kernel_select
(
api
,
input_names
,
attrs
,
kernel
)
->
str
:
kernel_key_item_init
=
"""
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
"""
# Check the tensor options
attr_backend_count
=
0
attr_layout_count
=
0
attr_data_type_count
=
0
for
attr_name
in
attrs
[
'names'
]:
if
attrs
[
'attr_info'
][
attr_name
][
0
]
==
'Backend'
:
assert
kernel
[
'backend'
]
is
not
None
,
\
f
"
{
api
}
api: When there is a parameter with 'Backend' type in attributes, you must set backend of kernel manually."
attr_backend_count
=
attr_backend_count
+
1
if
attrs
[
'attr_info'
][
attr_name
][
0
]
==
'DataLayout'
:
assert
kernel
[
'layout'
]
is
not
None
,
\
f
"
{
api
}
api: When there is a parameter with 'DataLayout' type in attributes, you must set layout of kernel manually."
attr_layout_count
=
attr_layout_count
+
1
if
attrs
[
'attr_info'
][
attr_name
][
0
]
==
'DataType'
:
assert
kernel
[
'data_type'
]
is
not
None
,
\
f
"
{
api
}
api: When there is a parameter with 'DataType' type in attributes, you must set data_type of kernel manually."
attr_data_type_count
=
attr_data_type_count
+
1
# preprocess kernel configures
kernel_select_code
=
""
if
kernel
[
'backend'
]
is
not
None
:
if
'>'
in
kernel
[
'backend'
]:
vars_list
=
kernel
[
'backend'
].
split
(
'>'
)
assert
len
(
vars_list
)
==
2
,
f
"
{
api
}
api: The number of params to set backend with '>' only allows 2, but received
{
len
(
vars_list
)
}
."
assert
(
vars_list
[
0
].
strip
()
in
attrs
[
'names'
])
and
(
attrs
[
'attr_info'
][
vars_list
[
0
].
strip
()][
0
]
==
'Backend'
),
\
f
"
{
api
}
api: When use '>' to set kernel backend, the first param should be a attribute with Backend type."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_backend = ParseBackendWithInputOrder(
{
vars_list
[
0
].
strip
()
}
,
{
vars_list
[
1
].
strip
()
}
);
"""
else
:
args_str
=
""
for
ele
in
kernel
[
'backend'
].
split
(
','
):
args_str
=
args_str
+
ele
.
strip
()
+
', '
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_backend = ParseBackend(
{
args_str
[:
-
2
]
}
);
"""
if
kernel
[
'layout'
]
is
not
None
:
if
'>'
in
kernel
[
'layout'
]:
vars_list
=
kernel
[
'layout'
].
split
(
'>'
)
assert
len
(
vars_list
)
==
2
,
f
"
{
api
}
api: The number of params to set layout with '>' only allows 2, but received
{
len
(
vars_list
)
}
."
assert
vars_list
[
0
].
strip
()
in
attrs
[
'names'
]
and
attrs
[
'attr_info'
][
vars_list
[
0
].
strip
()][
0
]
==
'DataLayout'
,
\
f
"
{
api
}
api: When use '>' to set kernel layout, the first param should be a attribute with DataLayout type."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_layout = ParseLayoutWithInputOrder(
{
vars_list
[
0
].
strip
()
}
,
{
vars_list
[
1
].
strip
()
}
);
"""
else
:
vars_list
=
kernel
[
'layout'
].
split
(
','
)
assert
len
(
vars_list
)
==
1
,
f
"
{
api
}
api: The number of params to set layout must be 1, but received
{
len
(
vars_list
)
}
."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_layout = ParseLayout(
{
vars_list
[
0
].
strip
()
}
);
"""
if
kernel
[
'data_type'
]
is
not
None
:
if
'>'
in
kernel
[
'data_type'
]:
vars_list
=
kernel
[
'data_type'
].
split
(
'>'
)
assert
len
(
vars_list
)
==
2
,
f
"
{
api
}
api: The number of params to set data_type with '>' only allows 2, but received
{
len
(
vars_list
)
}
."
assert
vars_list
[
0
].
strip
()
in
attrs
[
'names'
]
and
attrs
[
'attr_info'
][
vars_list
[
0
].
strip
()][
0
]
==
'DataType'
,
\
f
"
{
api
}
api: When use '>' to set kernel data_type, the first param should be a attribute with DataType type."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_data_type = ParseDataTypeWithInputOrder(
{
vars_list
[
0
].
strip
()
}
,
{
vars_list
[
1
].
strip
()
}
);
"""
else
:
vars_list
=
kernel
[
'data_type'
].
split
(
','
)
assert
len
(
vars_list
)
==
1
,
f
"
{
api
}
api: The number of params to set data_type only allows 2, but received
{
len
(
vars_list
)
}
."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_data_type = ParseDataType(
{
vars_list
[
0
].
strip
()
}
);
"""
if
len
(
input_names
)
==
0
:
assert
attr_backend_count
>
0
and
attr_layout_count
>
0
and
attr_data_type_count
>
0
,
\
f
"
{
api
}
api: When there is no input tensor, the args must have 'Backend', 'DataLayout' and 'DataType'."
kernel_select_args
=
""
for
input_name
in
input_names
:
kernel_select_args
=
kernel_select_args
+
input_name
+
", "
if
len
(
kernel_select_args
)
>
2
:
kernel_select_args
=
kernel_select_args
[:
-
2
]
kernel_select_code
=
kernel_key_item_init
+
kernel_select_code
if
len
(
input_names
)
>
0
:
kernel_select_code
=
kernel_select_code
+
f
"""
if (kernel_backend == Backend::UNDEFINED
|| kernel_layout == DataLayout::UNDEFINED
|| kernel_data_type == DataType::UNDEFINED ) {{
auto kernel_key_set = ParseKernelKeyByInputArgs(
{
kernel_select_args
}
);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {{
kernel_backend = kernel_key.backend();
}}
if (kernel_layout == DataLayout::UNDEFINED) {{
kernel_layout = kernel_key.layout();
}}
if (kernel_data_type == DataType::UNDEFINED) {{
kernel_data_type = kernel_key.dtype();
}}
}}"""
kernel_select_code
=
kernel_select_code
+
f
"""
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"
{
kernel
[
'func'
]
}
", {{kernel_backend, kernel_layout, kernel_data_type}});
VLOG(6) << "
{
api
}
API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
VLOG(6) << "
{
api
}
API kernel: " << kernel;"""
return
kernel_select_code
def
gene_infer_meta
(
input_names
,
attr_names
,
output_names
,
infer_meta
)
->
str
:
infer_meta_params
=
infer_meta
[
'param'
]
+
output_names
if
infer_meta
[
'param'
]
is
not
None
else
input_names
+
attr_names
+
output_names
# generate meta tensors
meta_tensor_code
=
""
param_code
=
""
for
param
in
infer_meta_params
:
if
param
in
input_names
:
param_code
=
param_code
+
"MakeMetaTensor(*"
+
PREFIX_TENSOR_NAME
+
param
+
"), "
elif
param
in
output_names
:
meta_tensor_code
=
meta_tensor_code
+
" pten::MetaTensor "
+
param
.
replace
(
PREFIX_TENSOR_NAME
,
PREFIX_META_TENSOR_NAME
)
+
"("
+
param
+
");
\n
"
param_code
=
param_code
+
"&"
+
param
.
replace
(
PREFIX_TENSOR_NAME
,
PREFIX_META_TENSOR_NAME
)
+
", "
elif
param
in
attr_names
:
param_code
=
param_code
+
param
+
", "
elif
isinstance
(
param
,
str
):
param_code
=
param_code
+
"
\"
"
+
param
+
"
\"
, "
elif
isinstance
(
param
,
bool
):
param_code
=
param_code
+
str
(
param
).
lower
()
+
", "
else
:
param_code
=
param_code
+
str
(
param
)
+
", "
param_code
=
param_code
[:
-
2
]
return
f
"""
{
meta_tensor_code
}
pten::
{
infer_meta
[
'func'
]
}
(
{
param_code
}
);
"""
def
get_kernel_args
(
inputs
,
attrs
,
out_type_list
,
kernel_param
,
data_transform
):
input_trans_map
=
{
'const Tensor&'
:
'const pten::DenseTensor&'
,
'const Tensor &'
:
'const pten::DenseTensor&'
,
'const std::vector<Tensor>&'
:
'const std::vector<pten::DenseTensor>&'
,
'const std::vector<Tensor> &'
:
'const std::vector<pten::DenseTensor>&'
}
out_trans_map
=
{
'Tensor'
:
'pten::DenseTensor*'
,
'std::vector<Tensor>'
:
'std::vector<pten::DenseTensor*>&'
}
input_names
=
inputs
[
'names'
]
input_infos
=
inputs
[
'input_info'
]
kernel_args_type_list
=
[
'const platform::DeviceContext&'
]
input_tensor_code
=
""
for
input_name
in
input_names
:
# set input code
input_tensor_code
=
input_tensor_code
+
f
"""
auto
{
PREFIX_TENSOR_NAME
}{
input_name
}
= TensorToDenseTensor(
{
input_name
}
);"""
attr_names
=
attrs
[
'names'
]
if
kernel_param
is
None
:
kernel_param
=
input_names
+
attr_names
input_tensor_code
=
""
for
i
,
input_name
in
enumerate
(
input_names
):
# set input code
if
input_name
in
kernel_param
:
trans_flag
=
"{}"
if
input_name
in
data_transform
[
'skip_transform'
]:
trans_flag
=
"{true}"
elif
input_name
in
data_transform
[
'support_trans_dtype'
]:
trans_flag
=
"{false, true}"
input_tensor_code
=
input_tensor_code
+
f
"""
auto
{
PREFIX_TENSOR_NAME
}{
input_name
}
= PrepareData(
{
input_name
}
, kernel.InputAt(
{
i
}
),
{
trans_flag
}
);"""
else
:
input_tensor_code
=
input_tensor_code
+
f
"""
auto
{
PREFIX_TENSOR_NAME
}{
input_name
}
= TensorToDenseTensor(
{
input_name
}
);"""
kernel_args
=
"*dev_ctx, "
for
param
in
kernel_param
:
if
param
in
input_names
:
kernel_args
=
kernel_args
+
"*"
+
PREFIX_TENSOR_NAME
+
param
+
", "
kernel_args_type_list
.
append
(
input_trans_map
[
input_infos
[
param
]])
elif
param
in
attr_names
:
# set attr for kernel_context
if
'ScalarArray'
in
attrs
[
'attr_info'
][
param
][
0
]:
kernel_args_type_list
.
append
(
'const pten::ScalarArray&'
)
param
=
'pten::ScalarArray('
+
param
+
')'
elif
'Scalar'
in
attrs
[
'attr_info'
][
param
][
0
]:
kernel_args_type_list
.
append
(
'const pten::Scalar&'
)
param
=
'pten::Scalar('
+
param
+
')'
else
:
kernel_args_type_list
.
append
(
attrs
[
'attr_info'
][
param
][
0
])
kernel_args
=
kernel_args
+
param
+
", "
elif
isinstance
(
param
,
bool
):
kernel_args
=
kernel_args
+
str
(
param
).
lower
()
+
", "
else
:
kernel_args
=
kernel_args
+
str
(
param
)
+
", "
for
out_type
in
out_type_list
:
kernel_args_type_list
.
append
(
out_trans_map
[
out_type
])
kernel_signature
=
"void(*)("
+
", "
.
join
(
kernel_args_type_list
)
+
")"
return
input_tensor_code
,
kernel_args
[:
-
2
],
kernel_signature
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录