diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py index 3e606139245d60425cd488691b7f78b18e9c1ae6..a80dfa11402c5c434f278ab2964cf6efda41b106 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py @@ -39,7 +39,7 @@ class CallTransformer(gast.NodeTransformer): Determines whether a function needs to be transformed by `convert_call`. It doesn't need to be transformed when a function satisfies the following conditions: 1. It's a api of paddle - 2. It's a python builtin function not include `len` + 2. It's a python builtin function not include `len` and `zip` """ assert isinstance(node, gast.Call) if is_paddle_api(node): @@ -47,10 +47,11 @@ class CallTransformer(gast.NodeTransformer): func_str = ast_to_source_code(node.func).strip() try: - from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin + from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin, is_builtin_zip is_builtin = eval("is_builtin({})".format(func_str)) is_builtin_len = eval("is_builtin_len({})".format(func_str)) - return is_builtin and not is_builtin_len + is_builtin_zip = eval("is_builtin_zip({})".format(func_str)) + return is_builtin and not is_builtin_len and not is_builtin_zip except Exception: return False diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index 300586969ff65bc3982cf89b21a8f718028dd9b5..0b009c0049dcb8ed883b16d05aecd2bfc3021a95 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -27,7 +27,7 @@ import numpy import six from paddle.fluid.dygraph.container import Sequential -from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len +from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len, convert_zip from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static @@ -79,6 +79,10 @@ def is_builtin_len(func): return False +def is_builtin_zip(func): + return is_builtin(func) and func.__name__ == 'zip' + + def is_unsupported(func): """ Checks whether the func is supported by dygraph to static graph. @@ -164,6 +168,9 @@ def convert_call(func): if is_builtin_len(func): return convert_len + if is_builtin_zip(func): + return convert_zip + if is_builtin(func) or is_unsupported(func): return func diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 0ac4da947a46bcd289656fa83af7ddc1a3a74dab..ba45dedc40faa473c3c1a7e1f2dfba5a47e2a381 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -298,6 +298,15 @@ def convert_len(var): return len(var) +def convert_zip(*args): + for i, arg in enumerate(args): + if isinstance(arg, Variable) and arg.shape[0] == -1: + raise RuntimeError( + "Not support zip(tensor, ...) when tensor.shape[0] == -1, " + "but found args[{}].shape[0] == -1 in 'zip'".format(str(i))) + return zip(*args) + + def convert_var_shape(x, idx=None, in_control_flow=False): """ A function representation of the shape of variable. diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py index 2aab27c03110d16f45c7c7627435f72630e748e1..750ed615e7109e407a2899beb96d9bffe6925125 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -20,6 +20,7 @@ import unittest import paddle import paddle.fluid as fluid from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator +from paddle.static import InputSpec program_translator = ProgramTranslator() @@ -322,6 +323,24 @@ def for_original_tuple(): return z +# 23. for zip error +@paddle.jit.to_static( + input_spec=[InputSpec(shape=[None, 10]), InputSpec(shape=[None, 10])]) +def for_zip_error(x, y): + for i, j in zip(x, y): + a = i + j + return x + y + + +# 24. for zip +@paddle.jit.to_static( + input_spec=[InputSpec(shape=[2, 10]), InputSpec(shape=[2, 10])]) +def for_zip(x, y): + for i, j in zip(x, y): + a = i + j + return x + y + + class TestTransformBase(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( @@ -512,5 +531,14 @@ class TestForOriginalTuple(TestTransformForOriginalList): self.transformed_result_compare() +class TestForZip(unittest.TestCase): + def test_for_zip_error(self): + with self.assertRaises(RuntimeError): + paddle.jit.save(for_zip_error, './for_zip_error') + + def test_for_zip(self): + paddle.jit.save(for_zip, './for_zip') + + if __name__ == '__main__': unittest.main()