From 5ea3818ab1688bab1c1013521612a92c6f9fd720 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 7 May 2020 20:13:44 +0800 Subject: [PATCH] fix slice bug in while loop, test=develop (#24326) --- python/paddle/fluid/framework.py | 17 +++++------ .../tests/unittests/test_while_loop_op.py | 28 +++++++++++++++++++ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index bb2a250854..b77d3ef549 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -647,6 +647,7 @@ def _getitem_impl_(var, item): slice_step = [] use_strided_slice = False reverse_axis = [] + target_block = default_main_program().current_block() def fill_constant(shape, value, force_cpu=False, out=None): var.block.append_op( @@ -701,8 +702,8 @@ def _getitem_impl_(var, item): if isinstance(slice_item, Variable): temp_1 = var.block.create_var(dtype='int32') fill_constant([1], 1, force_cpu=True, out=temp_1) - temp_end = var.block.create_var(dtype='int32') - var.block.append_op( + temp_end = target_block.create_var(dtype='int32') + target_block.append_op( type='elementwise_add', inputs={'X': slice_item, 'Y': temp_1}, @@ -785,11 +786,11 @@ def _getitem_impl_(var, item): out = var if use_strided_slice == False and len(slice_axis) > 0: # append slice_op here - slice_out_var = var.block.create_var( + slice_out_var = target_block.create_var( name=unique_name.generate_with_ignorable_key(var.name + "_slice"), dtype=var.dtype) - var.block.append_op( + target_block.append_op( type="slice", inputs=inputs, outputs={'Out': [slice_out_var]}, @@ -797,11 +798,11 @@ def _getitem_impl_(var, item): out = slice_out_var elif use_strided_slice == True and len(slice_axis) > 0: - strided_slice_out_var = var.block.create_var( + strided_slice_out_var = target_block.create_var( name=unique_name.generate_with_ignorable_key(var.name + "_strided_slice"), dtype=var.dtype) - var.block.append_op( + target_block.append_op( type="strided_slice", inputs=inputs, outputs={'Out': [strided_slice_out_var]}, @@ -810,11 +811,11 @@ def _getitem_impl_(var, item): out = strided_slice_out_var if len(reverse_axis) > 0: - reverse_out_var = var.block.create_var( + reverse_out_var = target_block.create_var( name=unique_name.generate_with_ignorable_key(var.name + "_slice_reverse"), dtype=var.dtype) - var.block.append_op( + target_block.append_op( type="reverse", inputs={'X': out}, outputs={'Out': [reverse_out_var]}, diff --git a/python/paddle/fluid/tests/unittests/test_while_loop_op.py b/python/paddle/fluid/tests/unittests/test_while_loop_op.py index b9c3d9dbf3..224dfd7f0a 100644 --- a/python/paddle/fluid/tests/unittests/test_while_loop_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_loop_op.py @@ -494,5 +494,33 @@ class TestApiWhileLoop_Error(unittest.TestCase): value_error_body_returns_with_mutable_list) +class TestApiWhileLoopSliceInBody(unittest.TestCase): + def test_var_slice(self): + def cond(z, i): + return i + 1 <= x_shape[0] + + def body(z, i): + z = z + x[i] + i += 1 + return z, i + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = fluid.layers.data(name='x', shape=[5], dtype='int32') + z = fluid.layers.fill_constant([1], 'int32', 0) + x_shape = fluid.layers.shape(x) + i = fluid.layers.fill_constant([1], 'int32', 0) + z, _ = fluid.layers.while_loop(cond, body, [z, i]) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + + np_x = np.array([1, 2, 3, 4, 5], dtype='int32') + res = exe.run(main_program, feed={'x': np_x}, fetch_list=[z]) + self.assertTrue(np.array_equal(res[0], [np.sum(np_x)])) + + if __name__ == '__main__': unittest.main() -- GitLab