Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
33b3e28a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
33b3e28a
编写于
1月 26, 2022
作者:
Z
zyfncg
提交者:
GitHub
1月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change output of backward_api (#39229)
上级
30470853
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
82 addition
and
42 deletion
+82
-42
python/paddle/utils/code_gen/api_gen.py
python/paddle/utils/code_gen/api_gen.py
+37
-6
python/paddle/utils/code_gen/backward_api_gen.py
python/paddle/utils/code_gen/backward_api_gen.py
+42
-11
python/paddle/utils/code_gen/gen_utils.py
python/paddle/utils/code_gen/gen_utils.py
+3
-25
未找到文件。
python/paddle/utils/code_gen/api_gen.py
浏览文件 @
33b3e28a
...
@@ -31,7 +31,12 @@ class API:
...
@@ -31,7 +31,12 @@ class API:
# names : [], list of attribute names
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
# attr_info : { attr_name : (type, default_values)}
self
.
args
=
gen_utils
.
parse_args
(
self
.
api
,
api_item_yaml
[
'args'
])
self
.
args
=
gen_utils
.
parse_args
(
self
.
api
,
api_item_yaml
[
'args'
])
self
.
output
=
api_item_yaml
[
'output'
]
self
.
out_type_list
,
_
=
gen_utils
.
parse_output
(
self
.
api
,
api_item_yaml
[
'output'
])
self
.
return_type
=
self
.
out_type_list
[
0
]
if
len
(
self
.
out_type_list
)
==
1
else
"std::tuple<"
+
","
.
join
(
self
.
out_type_list
)
+
">"
self
.
is_base_api
=
True
self
.
is_base_api
=
True
if
'invoke'
in
api_item_yaml
:
if
'invoke'
in
api_item_yaml
:
self
.
is_base_api
=
False
self
.
is_base_api
=
False
...
@@ -54,18 +59,44 @@ class API:
...
@@ -54,18 +59,44 @@ class API:
def
gene_api_declaration
(
self
):
def
gene_api_declaration
(
self
):
return
f
"""
return
f
"""
PADDLE_API
{
self
.
output
}
{
self
.
api
}
(
{
self
.
args
[
'args_declare'
]
}
);
PADDLE_API
{
self
.
return_type
}
{
self
.
api
}
(
{
self
.
args
[
'args_declare'
]
}
);
"""
"""
def
gene_output
(
self
,
output_type_list
):
kernel_output
=
""
output_create
=
""
if
len
(
output_type_list
)
==
1
:
kernel_output
=
'dense_out'
output_create
=
f
"""
{
self
.
return_type
}
out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""
elif
len
(
output_type_list
)
>
1
:
output_create
=
f
"""
{
self
.
return_type
}
out;"""
for
i
in
range
(
len
(
output_type_list
)):
kernel_output
=
kernel_output
+
f
'dense_out_
{
i
}
, '
output_create
=
output_create
+
f
"""
auto dense_out_
{
i
}
= SetKernelOutput(std::get<
{
i
}
>(out_meta), kernel_backend, &std::get<
{
i
}
>(out));"""
kernel_output
=
kernel_output
[:
-
2
]
else
:
raise
ValueError
(
"{} : Output error: the output should not be empty."
.
format
(
self
.
api
))
return
kernel_output
,
output_create
def
gene_api_code
(
self
):
def
gene_api_code
(
self
):
if
self
.
is_base_api
:
if
self
.
is_base_api
:
input_tensors
,
kernel_args
=
gen_utils
.
get_kernel_args
(
input_tensors
,
kernel_args
=
gen_utils
.
get_kernel_args
(
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
kernel
[
'param'
])
self
.
kernel
[
'param'
])
out_type
,
_
=
gen_utils
.
parse_output
(
self
.
api
,
self
.
output
)
outputs_args
,
output_create
=
self
.
gene_output
(
self
.
out_type_list
)
outputs_args
,
output_create
=
gen_utils
.
gene_output
(
out_type
)
return
f
"""
return
f
"""
PADDLE_API
{
self
.
output
}
{
self
.
api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
PADDLE_API
{
self
.
return_type
}
{
self
.
api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
{
gen_utils
.
gene_kernel_select
(
self
.
api
,
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
kernel
)
}
{
gen_utils
.
gene_kernel_select
(
self
.
api
,
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
kernel
)
}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
...
@@ -82,7 +113,7 @@ PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{
...
@@ -82,7 +113,7 @@ PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{
else
:
else
:
return
f
"""
return
f
"""
PADDLE_API
{
self
.
output
}
{
self
.
api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
PADDLE_API
{
self
.
return_type
}
{
self
.
api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
return
{
self
.
invoke
}
;
return
{
self
.
invoke
}
;
}}
}}
"""
"""
...
...
python/paddle/utils/code_gen/backward_api_gen.py
浏览文件 @
33b3e28a
...
@@ -23,9 +23,11 @@ import gen_utils
...
@@ -23,9 +23,11 @@ import gen_utils
class
BackwardAPI
:
class
BackwardAPI
:
def
__init__
(
self
,
backward_item_yaml
):
def
__init__
(
self
,
backward_item_yaml
):
self
.
backward_api
=
backward_item_yaml
[
'backward_api'
]
self
.
backward_api
=
backward_item_yaml
[
'backward_api'
]
self
.
args
,
self
.
output_type
,
self
.
return_comment
=
self
.
parse_and_check_args
(
self
.
args
,
self
.
output_type
_list
,
self
.
return_comment
=
self
.
parse_and_check_args
(
backward_item_yaml
[
'forward'
],
backward_item_yaml
[
'args'
],
backward_item_yaml
[
'forward'
],
backward_item_yaml
[
'args'
],
backward_item_yaml
[
'output'
])
backward_item_yaml
[
'output'
])
self
.
return_type
=
self
.
output_type_list
[
0
]
if
len
(
self
.
output_type_list
)
==
1
else
"std::vector<std::vector<Tensor>>"
self
.
is_base_api
=
True
self
.
is_base_api
=
True
if
'invoke'
in
backward_item_yaml
:
if
'invoke'
in
backward_item_yaml
:
...
@@ -81,36 +83,65 @@ class BackwardAPI:
...
@@ -81,36 +83,65 @@ class BackwardAPI:
Please check the args of
{
self
.
backward_api
}
in yaml."
Please check the args of
{
self
.
backward_api
}
in yaml."
# check the output of backward
# check the output of backward
out
put_type
,
return_comment
=
gen_utils
.
parse_output
(
self
.
backward_api
,
out
_type_list
,
return_comment
=
gen_utils
.
parse_output
(
output_config
)
self
.
backward_api
,
output_config
)
assert
output_type
.
count
(
'Tensor'
)
<=
len
(
fw_inputs
[
'names'
]),
\
assert
len
(
out_type_list
)
<=
len
(
fw_inputs
[
'names'
]),
\
f
"
{
self
.
backward_api
}
: Output error: The number of ouputs should be less then the number of inputs of forward api.
\
f
"
{
self
.
backward_api
}
: Output error: The number of ouputs should be less then the number of inputs of forward api.
\
Please check the output of
{
self
.
backward_api
}
in yaml."
Please check the output of
{
self
.
backward_api
}
in yaml."
return
bw_args
,
out
put_type
,
return_comment
return
bw_args
,
out
_type_list
,
return_comment
def
gene_api_declaration
(
self
):
def
gene_api_declaration
(
self
):
if
self
.
return_comment
:
if
self
.
return_comment
:
return
f
"""
return
f
"""
//
{
self
.
return_comment
}
//
{
self
.
return_comment
}
{
self
.
output
_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
'args_declare'
]
}
);
{
self
.
return
_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
'args_declare'
]
}
);
"""
"""
else
:
else
:
return
f
"""
return
f
"""
{
self
.
output
_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
'args_declare'
]
}
);
{
self
.
return
_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
'args_declare'
]
}
);
"""
"""
def
gene_output
(
self
,
output_type_list
):
kernel_output
=
""
output_create
=
""
if
len
(
output_type_list
)
==
1
:
return_type
=
output_type_list
[
0
]
kernel_output
=
'dense_out'
output_create
=
f
"""
{
self
.
return_type
}
out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""
elif
len
(
output_type_list
)
>
1
:
output_create
=
f
"""
{
self
.
return_type
}
out;"""
for
i
,
out_type_item
in
enumerate
(
output_type_list
):
kernel_output
=
kernel_output
+
f
'dense_out_
{
i
}
, '
get_out_code
=
f
'&out[
{
i
}
][0]'
if
out_type_item
==
'Tensor'
else
f
'&out[
{
i
}
]'
output_create
=
output_create
+
f
"""
auto dense_out_
{
i
}
= SetKernelOutput(std::get<
{
i
}
>(out_meta), kernel_backend,
{
get_out_code
}
);"""
kernel_output
=
kernel_output
[:
-
2
]
else
:
raise
ValueError
(
"{} : Output error: the output should not be empty."
.
format
(
self
.
backward_api
))
return
kernel_output
,
output_create
def
gene_api_code
(
self
):
def
gene_api_code
(
self
):
if
self
.
is_base_api
:
if
self
.
is_base_api
:
input_tensors
,
kernel_args
=
gen_utils
.
get_kernel_args
(
input_tensors
,
kernel_args
=
gen_utils
.
get_kernel_args
(
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
kernel
[
'param'
])
self
.
kernel
[
'param'
])
outputs_args
,
output_create
=
gen_utils
.
gene_output
(
outputs_args
,
output_create
=
self
.
gene_output
(
self
.
output_type
)
self
.
output_type
_list
)
return
f
"""
return
f
"""
//
{
self
.
return_comment
}
//
{
self
.
return_comment
}
{
self
.
output
_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
{
self
.
return
_type
}
{
self
.
backward_api
}
(
{
self
.
args
[
"args_define"
]
}
) {{
{
gen_utils
.
gene_kernel_select
(
self
.
backward_api
,
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
kernel
)
}
{
gen_utils
.
gene_kernel_select
(
self
.
backward_api
,
self
.
args
[
'inputs'
][
'names'
],
self
.
args
[
'attrs'
],
self
.
kernel
)
}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
...
@@ -143,7 +174,7 @@ class BackwardAPI:
...
@@ -143,7 +174,7 @@ class BackwardAPI:
params_code
=
self
.
args
[
"args_define"
]
params_code
=
self
.
args
[
"args_define"
]
return
f
"""
return
f
"""
//
{
self
.
return_comment
}
//
{
self
.
return_comment
}
{
self
.
output
_type
}
{
self
.
backward_api
}
(
{
params_code
}
) {{
{
self
.
return
_type
}
{
self
.
backward_api
}
(
{
params_code
}
) {{
return
{
invoke_code
}
;
return
{
invoke_code
}
;
}}
}}
"""
"""
...
...
python/paddle/utils/code_gen/gen_utils.py
浏览文件 @
33b3e28a
...
@@ -124,7 +124,7 @@ def parse_output(api_name, output_config):
...
@@ -124,7 +124,7 @@ def parse_output(api_name, output_config):
if
len
(
temp_list
)
==
1
:
if
len
(
temp_list
)
==
1
:
out_type
,
out_name
=
parse_output_item
(
temp_list
[
0
])
out_type
,
out_name
=
parse_output_item
(
temp_list
[
0
])
return
out_type
,
out_name
return
[
out_type
]
,
out_name
else
:
else
:
out_type_list
=
[]
out_type_list
=
[]
out_name_list
=
[]
out_name_list
=
[]
...
@@ -133,8 +133,7 @@ def parse_output(api_name, output_config):
...
@@ -133,8 +133,7 @@ def parse_output(api_name, output_config):
out_type_list
.
append
(
out_type
)
out_type_list
.
append
(
out_type
)
out_name_list
.
append
(
out_name
)
out_name_list
.
append
(
out_name
)
return
"std::tuple<"
+
","
.
join
(
out_type_list
)
+
">"
,
", "
.
join
(
return
out_type_list
,
", "
.
join
(
out_name_list
)
out_name_list
)
def
gene_kernel_select
(
api
,
input_names
,
attrs
,
kernel
)
->
str
:
def
gene_kernel_select
(
api
,
input_names
,
attrs
,
kernel
)
->
str
:
...
@@ -315,24 +314,3 @@ def get_kernel_args(input_names, attrs, kernel_param):
...
@@ -315,24 +314,3 @@ def get_kernel_args(input_names, attrs, kernel_param):
else
:
else
:
kernel_args
=
kernel_args
+
str
(
param
)
+
", "
kernel_args
=
kernel_args
+
str
(
param
)
+
", "
return
input_tensor_code
,
kernel_args
[:
-
2
]
return
input_tensor_code
,
kernel_args
[:
-
2
]
def
gene_output
(
output_type
):
kernel_output
=
""
output_create
=
f
"""
{
output_type
}
out;"""
if
output_type
==
'Tensor'
or
output_type
==
'std::vector<Tensor>'
:
kernel_output
=
'dense_out'
output_create
=
output_create
+
"""
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""
elif
re
.
match
(
r
'std::tuple<.*>$'
,
output_type
):
out_num
=
output_type
.
count
(
'Tensor'
)
for
i
in
range
(
out_num
):
kernel_output
=
kernel_output
+
f
'dense_out_
{
i
}
, '
output_create
=
output_create
+
f
"""
auto dense_out_
{
i
}
= SetKernelOutput(std::get<
{
i
}
>(out_meta), kernel_backend, &std::get<
{
i
}
>(out));"""
kernel_output
=
kernel_output
[:
-
2
]
return
kernel_output
,
output_create
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录