未验证 提交 026de65c 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2Stat]Polish for zip in dy2stat (#37846) (#37912)

Polish for zip in dy2stat
上级 4114c4a1
......@@ -39,7 +39,7 @@ class CallTransformer(gast.NodeTransformer):
Determines whether a function needs to be transformed by `convert_call`.
It doesn't need to be transformed when a function satisfies the following conditions:
1. It's a api of paddle
2. It's a python builtin function not include `len`
2. It's a python builtin function not include `len` and `zip`
"""
assert isinstance(node, gast.Call)
if is_paddle_api(node):
......@@ -47,10 +47,11 @@ class CallTransformer(gast.NodeTransformer):
func_str = ast_to_source_code(node.func).strip()
try:
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin, is_builtin_zip
is_builtin = eval("is_builtin({})".format(func_str))
is_builtin_len = eval("is_builtin_len({})".format(func_str))
return is_builtin and not is_builtin_len
is_builtin_zip = eval("is_builtin_zip({})".format(func_str))
return is_builtin and not is_builtin_len and not is_builtin_zip
except Exception:
return False
......
......@@ -27,7 +27,7 @@ import numpy
import six
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len, convert_zip
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
......@@ -79,6 +79,10 @@ def is_builtin_len(func):
return False
def is_builtin_zip(func):
return is_builtin(func) and func.__name__ == 'zip'
def is_unsupported(func):
"""
Checks whether the func is supported by dygraph to static graph.
......@@ -164,6 +168,9 @@ def convert_call(func):
if is_builtin_len(func):
return convert_len
if is_builtin_zip(func):
return convert_zip
if is_builtin(func) or is_unsupported(func):
return func
......
......@@ -298,6 +298,15 @@ def convert_len(var):
return len(var)
def convert_zip(*args):
for i, arg in enumerate(args):
if isinstance(arg, Variable) and arg.shape[0] == -1:
raise RuntimeError(
"Not support zip(tensor, ...) when tensor.shape[0] == -1, "
"but found args[{}].shape[0] == -1 in 'zip'".format(str(i)))
return zip(*args)
def convert_var_shape(x, idx=None, in_control_flow=False):
"""
A function representation of the shape of variable.
......
......@@ -20,6 +20,7 @@ import unittest
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.static import InputSpec
program_translator = ProgramTranslator()
......@@ -322,6 +323,24 @@ def for_original_tuple():
return z
# 23. for zip error
@paddle.jit.to_static(
input_spec=[InputSpec(shape=[None, 10]), InputSpec(shape=[None, 10])])
def for_zip_error(x, y):
for i, j in zip(x, y):
a = i + j
return x + y
# 24. for zip
@paddle.jit.to_static(
input_spec=[InputSpec(shape=[2, 10]), InputSpec(shape=[2, 10])])
def for_zip(x, y):
for i, j in zip(x, y):
a = i + j
return x + y
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
......@@ -512,5 +531,14 @@ class TestForOriginalTuple(TestTransformForOriginalList):
self.transformed_result_compare()
class TestForZip(unittest.TestCase):
def test_for_zip_error(self):
with self.assertRaises(RuntimeError):
paddle.jit.save(for_zip_error, './for_zip_error')
def test_for_zip(self):
paddle.jit.save(for_zip, './for_zip')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册