未验证 提交 5ea3818a 编写于 作者: C Chen Weihang 提交者: GitHub

fix slice bug in while loop, test=develop (#24326)

上级 420707c2
...@@ -647,6 +647,7 @@ def _getitem_impl_(var, item): ...@@ -647,6 +647,7 @@ def _getitem_impl_(var, item):
slice_step = [] slice_step = []
use_strided_slice = False use_strided_slice = False
reverse_axis = [] reverse_axis = []
target_block = default_main_program().current_block()
def fill_constant(shape, value, force_cpu=False, out=None): def fill_constant(shape, value, force_cpu=False, out=None):
var.block.append_op( var.block.append_op(
...@@ -701,8 +702,8 @@ def _getitem_impl_(var, item): ...@@ -701,8 +702,8 @@ def _getitem_impl_(var, item):
if isinstance(slice_item, Variable): if isinstance(slice_item, Variable):
temp_1 = var.block.create_var(dtype='int32') temp_1 = var.block.create_var(dtype='int32')
fill_constant([1], 1, force_cpu=True, out=temp_1) fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = var.block.create_var(dtype='int32') temp_end = target_block.create_var(dtype='int32')
var.block.append_op( target_block.append_op(
type='elementwise_add', type='elementwise_add',
inputs={'X': slice_item, inputs={'X': slice_item,
'Y': temp_1}, 'Y': temp_1},
...@@ -785,11 +786,11 @@ def _getitem_impl_(var, item): ...@@ -785,11 +786,11 @@ def _getitem_impl_(var, item):
out = var out = var
if use_strided_slice == False and len(slice_axis) > 0: if use_strided_slice == False and len(slice_axis) > 0:
# append slice_op here # append slice_op here
slice_out_var = var.block.create_var( slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + "_slice"), name=unique_name.generate_with_ignorable_key(var.name + "_slice"),
dtype=var.dtype) dtype=var.dtype)
var.block.append_op( target_block.append_op(
type="slice", type="slice",
inputs=inputs, inputs=inputs,
outputs={'Out': [slice_out_var]}, outputs={'Out': [slice_out_var]},
...@@ -797,11 +798,11 @@ def _getitem_impl_(var, item): ...@@ -797,11 +798,11 @@ def _getitem_impl_(var, item):
out = slice_out_var out = slice_out_var
elif use_strided_slice == True and len(slice_axis) > 0: elif use_strided_slice == True and len(slice_axis) > 0:
strided_slice_out_var = var.block.create_var( strided_slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + name=unique_name.generate_with_ignorable_key(var.name +
"_strided_slice"), "_strided_slice"),
dtype=var.dtype) dtype=var.dtype)
var.block.append_op( target_block.append_op(
type="strided_slice", type="strided_slice",
inputs=inputs, inputs=inputs,
outputs={'Out': [strided_slice_out_var]}, outputs={'Out': [strided_slice_out_var]},
...@@ -810,11 +811,11 @@ def _getitem_impl_(var, item): ...@@ -810,11 +811,11 @@ def _getitem_impl_(var, item):
out = strided_slice_out_var out = strided_slice_out_var
if len(reverse_axis) > 0: if len(reverse_axis) > 0:
reverse_out_var = var.block.create_var( reverse_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + name=unique_name.generate_with_ignorable_key(var.name +
"_slice_reverse"), "_slice_reverse"),
dtype=var.dtype) dtype=var.dtype)
var.block.append_op( target_block.append_op(
type="reverse", type="reverse",
inputs={'X': out}, inputs={'X': out},
outputs={'Out': [reverse_out_var]}, outputs={'Out': [reverse_out_var]},
......
...@@ -494,5 +494,33 @@ class TestApiWhileLoop_Error(unittest.TestCase): ...@@ -494,5 +494,33 @@ class TestApiWhileLoop_Error(unittest.TestCase):
value_error_body_returns_with_mutable_list) value_error_body_returns_with_mutable_list)
class TestApiWhileLoopSliceInBody(unittest.TestCase):
def test_var_slice(self):
def cond(z, i):
return i + 1 <= x_shape[0]
def body(z, i):
z = z + x[i]
i += 1
return z, i
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = fluid.layers.data(name='x', shape=[5], dtype='int32')
z = fluid.layers.fill_constant([1], 'int32', 0)
x_shape = fluid.layers.shape(x)
i = fluid.layers.fill_constant([1], 'int32', 0)
z, _ = fluid.layers.while_loop(cond, body, [z, i])
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
np_x = np.array([1, 2, 3, 4, 5], dtype='int32')
res = exe.run(main_program, feed={'x': np_x}, fetch_list=[z])
self.assertTrue(np.array_equal(res[0], [np.sum(np_x)]))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册