Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d2a911b4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d2a911b4
编写于
3月 04, 2022
作者:
Z
Zhanlue Yang
提交者:
GitHub
3月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Yaml]Support parsing fwd & bwd returns with name (#40107)
上级
73a4fe6c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
14 addition
and
27 deletion
+14
-27
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
...er/auto_code_generator/final_state_generator/eager_gen.py
+14
-27
未找到文件。
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
浏览文件 @
d2a911b4
...
@@ -208,39 +208,26 @@ def ParseYamlArgs(string):
...
@@ -208,39 +208,26 @@ def ParseYamlArgs(string):
def
ParseYamlReturns
(
string
):
def
ParseYamlReturns
(
string
):
# Example: Tensor, Tensor
# Example0: Tensor(out), Tensor(out1)
# Example1: Tensor, Tensor
# list = [ ["", ret_type, orig_position], ...]
# Example2: Tensor[](out), Tensor
returns_list
=
[]
returns
=
[
x
.
strip
()
for
x
in
string
.
strip
().
split
(
","
)]
for
i
in
range
(
len
(
returns
)):
ret_type
=
returns
[
i
]
assert
ret_type
in
yaml_types_mapping
.
keys
()
ret_type
=
yaml_types_mapping
[
ret_type
]
returns_list
.
append
([
""
,
ret_type
,
i
])
return
returns_list
def
ParseYamlReturnsWithName
(
string
):
# Example: Tensor(out), Tensor(out1)
# list = [ [ret_name, ret_type, orig_position], ...]
# list = [ [ret_name, ret_type, orig_position], ...]
returns_list
=
[]
returns_list
=
[]
returns
=
[
x
.
strip
()
for
x
in
string
.
strip
().
split
(
","
)]
returns
=
[
x
.
strip
()
for
x
in
string
.
strip
().
split
(
","
)]
atype
=
r
'(.*?)'
aname
=
r
'(.*?)'
pattern
=
f
'
{
atype
}
\(
{
aname
}
\)'
for
i
in
range
(
len
(
returns
)):
for
i
in
range
(
len
(
returns
)):
ret
=
returns
[
i
]
ret
=
returns
[
i
]
m
=
re
.
search
(
pattern
,
ret
)
ret_type
=
m
.
group
(
1
)
ret_name
=
""
ret_name
=
m
.
group
(
2
)
if
"("
in
ret
and
")"
in
ret
:
# Remove trailing ')'
ret
=
ret
[:
-
1
]
ret_type
=
ret
.
split
(
"("
)[
0
].
strip
()
ret_name
=
ret
.
split
(
"("
)[
1
].
strip
()
else
:
ret_type
=
ret
.
strip
()
assert
ret_type
in
yaml_types_mapping
.
keys
()
assert
ret_type
in
yaml_types_mapping
.
keys
()
ret_type
=
yaml_types_mapping
[
ret_type
]
ret_type
=
yaml_types_mapping
[
ret_type
]
...
@@ -266,7 +253,7 @@ def ParseYamlForwardFromBackward(string):
...
@@ -266,7 +253,7 @@ def ParseYamlForwardFromBackward(string):
function_returns
=
m
.
group
(
3
)
function_returns
=
m
.
group
(
3
)
forward_inputs_list
,
forward_attrs_list
=
ParseYamlArgs
(
function_args
)
forward_inputs_list
,
forward_attrs_list
=
ParseYamlArgs
(
function_args
)
forward_returns_list
=
ParseYamlReturns
WithName
(
function_returns
)
forward_returns_list
=
ParseYamlReturns
(
function_returns
)
return
forward_inputs_list
,
forward_attrs_list
,
forward_returns_list
return
forward_inputs_list
,
forward_attrs_list
,
forward_returns_list
...
@@ -296,7 +283,7 @@ def ParseYamlBackward(args_str, returns_str):
...
@@ -296,7 +283,7 @@ def ParseYamlBackward(args_str, returns_str):
args_str
=
re
.
search
(
args_pattern
,
args_str
).
group
(
1
)
args_str
=
re
.
search
(
args_pattern
,
args_str
).
group
(
1
)
inputs_list
,
attrs_list
=
ParseYamlArgs
(
args_str
)
inputs_list
,
attrs_list
=
ParseYamlArgs
(
args_str
)
returns_list
=
ParseYamlReturns
WithName
(
returns_str
)
returns_list
=
ParseYamlReturns
(
returns_str
)
return
inputs_list
,
attrs_list
,
returns_list
return
inputs_list
,
attrs_list
,
returns_list
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录