Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
026de65c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
026de65c
编写于
12月 09, 2021
作者:
0
0x45f
提交者:
GitHub
12月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Stat]Polish for zip in dy2stat (#37846) (#37912)
Polish for zip in dy2stat
上级
4114c4a1
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
49 addition
and
4 deletion
+49
-4
python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py
...addle/fluid/dygraph/dygraph_to_static/call_transformer.py
+4
-3
python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
...ddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
+8
-1
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
...ddle/fluid/dygraph/dygraph_to_static/convert_operators.py
+9
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py
...d/tests/unittests/dygraph_to_static/test_for_enumerate.py
+28
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py
浏览文件 @
026de65c
...
@@ -39,7 +39,7 @@ class CallTransformer(gast.NodeTransformer):
...
@@ -39,7 +39,7 @@ class CallTransformer(gast.NodeTransformer):
Determines whether a function needs to be transformed by `convert_call`.
Determines whether a function needs to be transformed by `convert_call`.
It doesn't need to be transformed when a function satisfies the following conditions:
It doesn't need to be transformed when a function satisfies the following conditions:
1. It's a api of paddle
1. It's a api of paddle
2. It's a python builtin function not include `len`
2. It's a python builtin function not include `len`
and `zip`
"""
"""
assert
isinstance
(
node
,
gast
.
Call
)
assert
isinstance
(
node
,
gast
.
Call
)
if
is_paddle_api
(
node
):
if
is_paddle_api
(
node
):
...
@@ -47,10 +47,11 @@ class CallTransformer(gast.NodeTransformer):
...
@@ -47,10 +47,11 @@ class CallTransformer(gast.NodeTransformer):
func_str
=
ast_to_source_code
(
node
.
func
).
strip
()
func_str
=
ast_to_source_code
(
node
.
func
).
strip
()
try
:
try
:
from
paddle.fluid.dygraph.dygraph_to_static.convert_call_func
import
is_builtin_len
,
is_builtin
from
paddle.fluid.dygraph.dygraph_to_static.convert_call_func
import
is_builtin_len
,
is_builtin
,
is_builtin_zip
is_builtin
=
eval
(
"is_builtin({})"
.
format
(
func_str
))
is_builtin
=
eval
(
"is_builtin({})"
.
format
(
func_str
))
is_builtin_len
=
eval
(
"is_builtin_len({})"
.
format
(
func_str
))
is_builtin_len
=
eval
(
"is_builtin_len({})"
.
format
(
func_str
))
return
is_builtin
and
not
is_builtin_len
is_builtin_zip
=
eval
(
"is_builtin_zip({})"
.
format
(
func_str
))
return
is_builtin
and
not
is_builtin_len
and
not
is_builtin_zip
except
Exception
:
except
Exception
:
return
False
return
False
...
...
python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
浏览文件 @
026de65c
...
@@ -27,7 +27,7 @@ import numpy
...
@@ -27,7 +27,7 @@ import numpy
import
six
import
six
from
paddle.fluid.dygraph.container
import
Sequential
from
paddle.fluid.dygraph.container
import
Sequential
from
paddle.fluid.dygraph.dygraph_to_static.convert_operators
import
convert_len
from
paddle.fluid.dygraph.dygraph_to_static.convert_operators
import
convert_len
,
convert_zip
from
paddle.fluid.dygraph.dygraph_to_static.logging_utils
import
TranslatorLogger
from
paddle.fluid.dygraph.dygraph_to_static.logging_utils
import
TranslatorLogger
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
StaticFunction
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
StaticFunction
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
convert_to_static
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
convert_to_static
...
@@ -79,6 +79,10 @@ def is_builtin_len(func):
...
@@ -79,6 +79,10 @@ def is_builtin_len(func):
return
False
return
False
def
is_builtin_zip
(
func
):
return
is_builtin
(
func
)
and
func
.
__name__
==
'zip'
def
is_unsupported
(
func
):
def
is_unsupported
(
func
):
"""
"""
Checks whether the func is supported by dygraph to static graph.
Checks whether the func is supported by dygraph to static graph.
...
@@ -164,6 +168,9 @@ def convert_call(func):
...
@@ -164,6 +168,9 @@ def convert_call(func):
if
is_builtin_len
(
func
):
if
is_builtin_len
(
func
):
return
convert_len
return
convert_len
if
is_builtin_zip
(
func
):
return
convert_zip
if
is_builtin
(
func
)
or
is_unsupported
(
func
):
if
is_builtin
(
func
)
or
is_unsupported
(
func
):
return
func
return
func
...
...
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
浏览文件 @
026de65c
...
@@ -298,6 +298,15 @@ def convert_len(var):
...
@@ -298,6 +298,15 @@ def convert_len(var):
return
len
(
var
)
return
len
(
var
)
def
convert_zip
(
*
args
):
for
i
,
arg
in
enumerate
(
args
):
if
isinstance
(
arg
,
Variable
)
and
arg
.
shape
[
0
]
==
-
1
:
raise
RuntimeError
(
"Not support zip(tensor, ...) when tensor.shape[0] == -1, "
"but found args[{}].shape[0] == -1 in 'zip'"
.
format
(
str
(
i
)))
return
zip
(
*
args
)
def
convert_var_shape
(
x
,
idx
=
None
,
in_control_flow
=
False
):
def
convert_var_shape
(
x
,
idx
=
None
,
in_control_flow
=
False
):
"""
"""
A function representation of the shape of variable.
A function representation of the shape of variable.
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py
浏览文件 @
026de65c
...
@@ -20,6 +20,7 @@ import unittest
...
@@ -20,6 +20,7 @@ import unittest
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.dygraph_to_static
import
ProgramTranslator
from
paddle.fluid.dygraph.dygraph_to_static
import
ProgramTranslator
from
paddle.static
import
InputSpec
program_translator
=
ProgramTranslator
()
program_translator
=
ProgramTranslator
()
...
@@ -322,6 +323,24 @@ def for_original_tuple():
...
@@ -322,6 +323,24 @@ def for_original_tuple():
return
z
return
z
# 23. for zip error
@
paddle
.
jit
.
to_static
(
input_spec
=
[
InputSpec
(
shape
=
[
None
,
10
]),
InputSpec
(
shape
=
[
None
,
10
])])
def
for_zip_error
(
x
,
y
):
for
i
,
j
in
zip
(
x
,
y
):
a
=
i
+
j
return
x
+
y
# 24. for zip
@
paddle
.
jit
.
to_static
(
input_spec
=
[
InputSpec
(
shape
=
[
2
,
10
]),
InputSpec
(
shape
=
[
2
,
10
])])
def
for_zip
(
x
,
y
):
for
i
,
j
in
zip
(
x
,
y
):
a
=
i
+
j
return
x
+
y
class
TestTransformBase
(
unittest
.
TestCase
):
class
TestTransformBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
(
self
.
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
(
...
@@ -512,5 +531,14 @@ class TestForOriginalTuple(TestTransformForOriginalList):
...
@@ -512,5 +531,14 @@ class TestForOriginalTuple(TestTransformForOriginalList):
self
.
transformed_result_compare
()
self
.
transformed_result_compare
()
class
TestForZip
(
unittest
.
TestCase
):
def
test_for_zip_error
(
self
):
with
self
.
assertRaises
(
RuntimeError
):
paddle
.
jit
.
save
(
for_zip_error
,
'./for_zip_error'
)
def
test_for_zip
(
self
):
paddle
.
jit
.
save
(
for_zip
,
'./for_zip'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录