From b2483d78a8261c9e493d63164af2c61ca4b507c3 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Fri, 8 Jan 2021 11:18:49 +0800 Subject: [PATCH] Fix test_slice: avoid unnecessary copying of TensorArray from subblock to parent block(#30168) In control flow, don't copy TensorArray from subblock to parent block when TensorArray is created in parent block. --- python/paddle/fluid/layers/control_flow.py | 10 +++++++--- .../tests/unittests/dygraph_to_static/test_slice.py | 5 ++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 2ab807d1cf..b735ae247f 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -2278,9 +2278,13 @@ def copy_var_to_parent_block(var, layer_helper): assert parent_idx >= 0, "Got wrong parent block index when assigning var to parent scope in control_flow" parent_block = prog.block(parent_idx) - parent_block_var = parent_block.create_var( - dtype=var.dtype, shape=var.shape, type=var.type) - assign(var, parent_block_var) + if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ + and parent_block._find_var_recursive(var.name): + parent_block_var = var + else: + parent_block_var = parent_block.create_var( + dtype=var.dtype, shape=var.shape, type=var.type) + assign(var, parent_block_var) return parent_block_var diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py index bf74299806..13bdbaedbe 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py @@ -40,9 +40,12 @@ def test_slice_in_if(x): if x.numpy()[0] > 0: a.append(x) else: - a.append(paddle.full(shape=[1, 2], fill_value=9, dtype="int64")) + a.append(paddle.full(shape=[1, 2], fill_value=9, dtype="int32")) + if x.numpy()[0] > 0: a[0] = x + + a[0] = x + 1 out = a[0] return out -- GitLab