未验证 提交 e014950e 编写于 作者: W wopeizl 提交者: GitHub

add slice support for dim < 0 (#16494)

* add slice support for dim < 0 test=develop
上级 8f7b5883
......@@ -789,13 +789,24 @@ class Variable(object):
if isinstance(item, tuple):
if len(item) > len(self.shape):
raise IndexError("Too many indexes")
fixedSize = True
for i in range(len(self.shape)):
if self.shape[i] == -1:
fixedSize = False
break
newitem = self._reconstructSliceinfo(item) or item
check, info = self._detectContinuesSlice(newitem)
if check:
starts = info[0]
ends = info[1]
axes = [i for i in range(len(starts))]
return self._sliceVar(axes, starts, ends)
if fixedSize:
check, info = self._detectContinuesSlice(newitem)
if check and fixedSize:
starts = info[0]
ends = info[1]
axes = [i for i in range(len(starts))]
return self._sliceVar(axes, starts, ends)
else:
new_var = self
for index, o in enumerate(newitem):
new_var = new_var._sliceAndConcatVar(o, index)
else:
new_var = self
for index, o in enumerate(newitem):
......
......@@ -61,7 +61,7 @@ class TestVariable(unittest.TestCase):
name='step_scopes', type=core.VarDesc.VarType.STEP_SCOPES)
self.assertEqual(core.VarDesc.VarType.STEP_SCOPES, var.type)
def _test_slice(self):
def _test_slice(self, place):
b = default_main_program().current_block()
w = b.create_var(dtype="float64", shape=[784, 100, 100], lod_level=0)
......@@ -83,7 +83,6 @@ class TestVariable(unittest.TestCase):
self.assertEqual(0, nw.lod_level)
place = fluid.CPUPlace()
main = fluid.Program()
with fluid.program_guard(main):
exe = fluid.Executor(place)
......@@ -100,10 +99,23 @@ class TestVariable(unittest.TestCase):
var6 = var[1, 1:, 1:]
var7 = var[1, ..., 1:]
var8 = var[1, ...]
var_reshape = fluid.layers.reshape(var, [3, -1, 3])
var9 = var_reshape[1, ..., 2]
var10 = var_reshape[:, :, -1]
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.fc(input=x, size=1, act=None)
var11 = y[:, 0]
feeder = fluid.DataFeeder(place=place, feed_list=[x])
data = []
data.append((np.random.randint(10, size=[13]).astype('float32')))
exe.run(fluid.default_startup_program())
local_out = exe.run(main,
feed=feeder.feed([data]),
fetch_list=[
var, var1, var2, var3, var4, var5, var6,
var7, var8
var7, var8, var9, var10, var11
])
self.assertTrue((np.array(local_out[1]) == np.array(tensor_array[
......@@ -122,38 +134,16 @@ class TestVariable(unittest.TestCase):
1, ..., 1:])).all())
self.assertTrue((np.array(local_out[8]) == np.array(tensor_array[
1, ...])).all())
self.assertEqual(local_out[9].shape, (1, 3, 1))
self.assertEqual(local_out[10].shape, (3, 3, 1))
self.assertEqual(local_out[11].shape, (1, 1))
def test_slice(self):
self._test_slice()
class TestVariableImperative(unittest.TestCase):
def _test_slice(self):
b = default_main_program().current_block()
w = b.create_var(dtype="float64", shape=[784, 100, 100], lod_level=0)
for i in range(3):
nw = w[i]
self.assertEqual([1, 100, 100], nw.shape)
nw = w[:]
self.assertEqual([784, 100, 100], nw.shape)
nw = w[:, :, :]
self.assertEqual([784, 100, 100], nw.shape)
nw = w[::2, ::2, :]
self.assertEqual([392, 50, 100], nw.shape)
nw = w[::-2, ::-2, :]
self.assertEqual([392, 50, 100], nw.shape)
nw = w[0::-2, 0::-2, :]
self.assertEqual([1, 1, 100], nw.shape)
place = fluid.CPUPlace()
self._test_slice(place)
def test_slice(self):
with fluid.dygraph.guard():
self._test_slice()
if core.is_compiled_with_cuda():
self._test_slice(core.CUDAPlace(0))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册