diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index 4b3b9fcf298855ae09856636e3e7af40ae8ae6da..cbe6b8a0ff942846d9e751a3daa2e02e4eefbb6b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -368,5 +368,8 @@ class StaticAnalysisVisitor(object): if isinstance(node.func, gast.Name): return self.var_env.get_var_type(node.func.id) + if isinstance(node, gast.Subscript): + if self.is_tensor_node(node.value): + return {NodeVarType.TENSOR} return {NodeVarType.STATEMENT} diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py index 0243ef3a6ddae92efc14ee20d058481be8fa669e..e630c2b9c6feb35e01fd0a1882811639b880813a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest +import paddle import numpy as np import paddle.fluid as fluid from paddle.fluid.dygraph.jit import declarative @@ -61,6 +62,33 @@ def test_list_append_in_for_loop(x, iter_num): return a[0] +paddle.jit.set_code_level(100) + + +def test_list_append_in_for_subscript(x): + x = fluid.dygraph.to_variable(x) + iter_num = paddle.shape(x)[0] + a = [] + for i in range(iter_num): + x = x + 1 + a.append(x) + out = paddle.concat(a) + return out[0] + + +def test_list_append_in_while_loop_subscript(x): + x = fluid.dygraph.to_variable(x) + iter_num = paddle.shape(x)[0] + a = [] + i = 0 + while i < iter_num: + x = x + 1 + a.append(x) + i += 1 + out = paddle.concat(a) + return out[0] + + def test_list_append_in_for_loop_with_concat(x, iter_num): x = fluid.dygraph.to_variable(x) a = [] @@ -261,5 +289,16 @@ class TestListInForLoopWithConcat(TestListInWhileLoopWithStack): self.all_dygraph_funcs = [test_list_append_in_for_loop_with_concat, ] +class TestListInForLoopWithSubscript(TestListWithoutControlFlow): + def init_dygraph_func(self): + self.all_dygraph_funcs = [ + test_list_append_in_for_subscript, + test_list_append_in_while_loop_subscript + ] + + def init_data(self): + self.input = np.random.random((3, 4)).astype('float32') + + if __name__ == '__main__': unittest.main()