未验证 提交 026de65c 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2Stat]Polish for zip in dy2stat (#37846) (#37912)

Polish for zip in dy2stat
上级 4114c4a1
...@@ -39,7 +39,7 @@ class CallTransformer(gast.NodeTransformer): ...@@ -39,7 +39,7 @@ class CallTransformer(gast.NodeTransformer):
Determines whether a function needs to be transformed by `convert_call`. Determines whether a function needs to be transformed by `convert_call`.
It doesn't need to be transformed when a function satisfies the following conditions: It doesn't need to be transformed when a function satisfies the following conditions:
1. It's a api of paddle 1. It's a api of paddle
2. It's a python builtin function not include `len` 2. It's a python builtin function not include `len` and `zip`
""" """
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
if is_paddle_api(node): if is_paddle_api(node):
...@@ -47,10 +47,11 @@ class CallTransformer(gast.NodeTransformer): ...@@ -47,10 +47,11 @@ class CallTransformer(gast.NodeTransformer):
func_str = ast_to_source_code(node.func).strip() func_str = ast_to_source_code(node.func).strip()
try: try:
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin, is_builtin_zip
is_builtin = eval("is_builtin({})".format(func_str)) is_builtin = eval("is_builtin({})".format(func_str))
is_builtin_len = eval("is_builtin_len({})".format(func_str)) is_builtin_len = eval("is_builtin_len({})".format(func_str))
return is_builtin and not is_builtin_len is_builtin_zip = eval("is_builtin_zip({})".format(func_str))
return is_builtin and not is_builtin_len and not is_builtin_zip
except Exception: except Exception:
return False return False
......
...@@ -27,7 +27,7 @@ import numpy ...@@ -27,7 +27,7 @@ import numpy
import six import six
from paddle.fluid.dygraph.container import Sequential from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len, convert_zip
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
...@@ -79,6 +79,10 @@ def is_builtin_len(func): ...@@ -79,6 +79,10 @@ def is_builtin_len(func):
return False return False
def is_builtin_zip(func):
return is_builtin(func) and func.__name__ == 'zip'
def is_unsupported(func): def is_unsupported(func):
""" """
Checks whether the func is supported by dygraph to static graph. Checks whether the func is supported by dygraph to static graph.
...@@ -164,6 +168,9 @@ def convert_call(func): ...@@ -164,6 +168,9 @@ def convert_call(func):
if is_builtin_len(func): if is_builtin_len(func):
return convert_len return convert_len
if is_builtin_zip(func):
return convert_zip
if is_builtin(func) or is_unsupported(func): if is_builtin(func) or is_unsupported(func):
return func return func
......
...@@ -298,6 +298,15 @@ def convert_len(var): ...@@ -298,6 +298,15 @@ def convert_len(var):
return len(var) return len(var)
def convert_zip(*args):
for i, arg in enumerate(args):
if isinstance(arg, Variable) and arg.shape[0] == -1:
raise RuntimeError(
"Not support zip(tensor, ...) when tensor.shape[0] == -1, "
"but found args[{}].shape[0] == -1 in 'zip'".format(str(i)))
return zip(*args)
def convert_var_shape(x, idx=None, in_control_flow=False): def convert_var_shape(x, idx=None, in_control_flow=False):
""" """
A function representation of the shape of variable. A function representation of the shape of variable.
......
...@@ -20,6 +20,7 @@ import unittest ...@@ -20,6 +20,7 @@ import unittest
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.static import InputSpec
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
...@@ -322,6 +323,24 @@ def for_original_tuple(): ...@@ -322,6 +323,24 @@ def for_original_tuple():
return z return z
# 23. for zip error
@paddle.jit.to_static(
input_spec=[InputSpec(shape=[None, 10]), InputSpec(shape=[None, 10])])
def for_zip_error(x, y):
for i, j in zip(x, y):
a = i + j
return x + y
# 24. for zip
@paddle.jit.to_static(
input_spec=[InputSpec(shape=[2, 10]), InputSpec(shape=[2, 10])])
def for_zip(x, y):
for i, j in zip(x, y):
a = i + j
return x + y
class TestTransformBase(unittest.TestCase): class TestTransformBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
...@@ -512,5 +531,14 @@ class TestForOriginalTuple(TestTransformForOriginalList): ...@@ -512,5 +531,14 @@ class TestForOriginalTuple(TestTransformForOriginalList):
self.transformed_result_compare() self.transformed_result_compare()
class TestForZip(unittest.TestCase):
def test_for_zip_error(self):
with self.assertRaises(RuntimeError):
paddle.jit.save(for_zip_error, './for_zip_error')
def test_for_zip(self):
paddle.jit.save(for_zip, './for_zip')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册