未验证 提交 5ea3818a 编写于 作者: C Chen Weihang 提交者: GitHub

fix slice bug in while loop, test=develop (#24326)

上级 420707c2
......@@ -647,6 +647,7 @@ def _getitem_impl_(var, item):
slice_step = []
use_strided_slice = False
reverse_axis = []
target_block = default_main_program().current_block()
def fill_constant(shape, value, force_cpu=False, out=None):
var.block.append_op(
......@@ -701,8 +702,8 @@ def _getitem_impl_(var, item):
if isinstance(slice_item, Variable):
temp_1 = var.block.create_var(dtype='int32')
fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = var.block.create_var(dtype='int32')
var.block.append_op(
temp_end = target_block.create_var(dtype='int32')
target_block.append_op(
type='elementwise_add',
inputs={'X': slice_item,
'Y': temp_1},
......@@ -785,11 +786,11 @@ def _getitem_impl_(var, item):
out = var
if use_strided_slice == False and len(slice_axis) > 0:
# append slice_op here
slice_out_var = var.block.create_var(
slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + "_slice"),
dtype=var.dtype)
var.block.append_op(
target_block.append_op(
type="slice",
inputs=inputs,
outputs={'Out': [slice_out_var]},
......@@ -797,11 +798,11 @@ def _getitem_impl_(var, item):
out = slice_out_var
elif use_strided_slice == True and len(slice_axis) > 0:
strided_slice_out_var = var.block.create_var(
strided_slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name +
"_strided_slice"),
dtype=var.dtype)
var.block.append_op(
target_block.append_op(
type="strided_slice",
inputs=inputs,
outputs={'Out': [strided_slice_out_var]},
......@@ -810,11 +811,11 @@ def _getitem_impl_(var, item):
out = strided_slice_out_var
if len(reverse_axis) > 0:
reverse_out_var = var.block.create_var(
reverse_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name +
"_slice_reverse"),
dtype=var.dtype)
var.block.append_op(
target_block.append_op(
type="reverse",
inputs={'X': out},
outputs={'Out': [reverse_out_var]},
......
......@@ -494,5 +494,33 @@ class TestApiWhileLoop_Error(unittest.TestCase):
value_error_body_returns_with_mutable_list)
class TestApiWhileLoopSliceInBody(unittest.TestCase):
def test_var_slice(self):
def cond(z, i):
return i + 1 <= x_shape[0]
def body(z, i):
z = z + x[i]
i += 1
return z, i
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = fluid.layers.data(name='x', shape=[5], dtype='int32')
z = fluid.layers.fill_constant([1], 'int32', 0)
x_shape = fluid.layers.shape(x)
i = fluid.layers.fill_constant([1], 'int32', 0)
z, _ = fluid.layers.while_loop(cond, body, [z, i])
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
np_x = np.array([1, 2, 3, 4, 5], dtype='int32')
res = exe.run(main_program, feed={'x': np_x}, fetch_list=[z])
self.assertTrue(np.array_equal(res[0], [np.sum(np_x)]))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册