未验证 提交 e014950e 编写于 作者: W wopeizl 提交者: GitHub

add slice support for dim < 0 (#16494)

* add slice support for dim < 0 test=develop
上级 8f7b5883
...@@ -789,9 +789,16 @@ class Variable(object): ...@@ -789,9 +789,16 @@ class Variable(object):
if isinstance(item, tuple): if isinstance(item, tuple):
if len(item) > len(self.shape): if len(item) > len(self.shape):
raise IndexError("Too many indexes") raise IndexError("Too many indexes")
fixedSize = True
for i in range(len(self.shape)):
if self.shape[i] == -1:
fixedSize = False
break
newitem = self._reconstructSliceinfo(item) or item newitem = self._reconstructSliceinfo(item) or item
if fixedSize:
check, info = self._detectContinuesSlice(newitem) check, info = self._detectContinuesSlice(newitem)
if check: if check and fixedSize:
starts = info[0] starts = info[0]
ends = info[1] ends = info[1]
axes = [i for i in range(len(starts))] axes = [i for i in range(len(starts))]
...@@ -800,6 +807,10 @@ class Variable(object): ...@@ -800,6 +807,10 @@ class Variable(object):
new_var = self new_var = self
for index, o in enumerate(newitem): for index, o in enumerate(newitem):
new_var = new_var._sliceAndConcatVar(o, index) new_var = new_var._sliceAndConcatVar(o, index)
else:
new_var = self
for index, o in enumerate(newitem):
new_var = new_var._sliceAndConcatVar(o, index)
else: else:
new_var = self._sliceAndConcatVar(item, 0) new_var = self._sliceAndConcatVar(item, 0)
return new_var return new_var
......
...@@ -61,7 +61,7 @@ class TestVariable(unittest.TestCase): ...@@ -61,7 +61,7 @@ class TestVariable(unittest.TestCase):
name='step_scopes', type=core.VarDesc.VarType.STEP_SCOPES) name='step_scopes', type=core.VarDesc.VarType.STEP_SCOPES)
self.assertEqual(core.VarDesc.VarType.STEP_SCOPES, var.type) self.assertEqual(core.VarDesc.VarType.STEP_SCOPES, var.type)
def _test_slice(self): def _test_slice(self, place):
b = default_main_program().current_block() b = default_main_program().current_block()
w = b.create_var(dtype="float64", shape=[784, 100, 100], lod_level=0) w = b.create_var(dtype="float64", shape=[784, 100, 100], lod_level=0)
...@@ -83,7 +83,6 @@ class TestVariable(unittest.TestCase): ...@@ -83,7 +83,6 @@ class TestVariable(unittest.TestCase):
self.assertEqual(0, nw.lod_level) self.assertEqual(0, nw.lod_level)
place = fluid.CPUPlace()
main = fluid.Program() main = fluid.Program()
with fluid.program_guard(main): with fluid.program_guard(main):
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -100,10 +99,23 @@ class TestVariable(unittest.TestCase): ...@@ -100,10 +99,23 @@ class TestVariable(unittest.TestCase):
var6 = var[1, 1:, 1:] var6 = var[1, 1:, 1:]
var7 = var[1, ..., 1:] var7 = var[1, ..., 1:]
var8 = var[1, ...] var8 = var[1, ...]
var_reshape = fluid.layers.reshape(var, [3, -1, 3])
var9 = var_reshape[1, ..., 2]
var10 = var_reshape[:, :, -1]
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.fc(input=x, size=1, act=None)
var11 = y[:, 0]
feeder = fluid.DataFeeder(place=place, feed_list=[x])
data = []
data.append((np.random.randint(10, size=[13]).astype('float32')))
exe.run(fluid.default_startup_program())
local_out = exe.run(main, local_out = exe.run(main,
feed=feeder.feed([data]),
fetch_list=[ fetch_list=[
var, var1, var2, var3, var4, var5, var6, var, var1, var2, var3, var4, var5, var6,
var7, var8 var7, var8, var9, var10, var11
]) ])
self.assertTrue((np.array(local_out[1]) == np.array(tensor_array[ self.assertTrue((np.array(local_out[1]) == np.array(tensor_array[
...@@ -122,38 +134,16 @@ class TestVariable(unittest.TestCase): ...@@ -122,38 +134,16 @@ class TestVariable(unittest.TestCase):
1, ..., 1:])).all()) 1, ..., 1:])).all())
self.assertTrue((np.array(local_out[8]) == np.array(tensor_array[ self.assertTrue((np.array(local_out[8]) == np.array(tensor_array[
1, ...])).all()) 1, ...])).all())
self.assertEqual(local_out[9].shape, (1, 3, 1))
self.assertEqual(local_out[10].shape, (3, 3, 1))
self.assertEqual(local_out[11].shape, (1, 1))
def test_slice(self): def test_slice(self):
self._test_slice() place = fluid.CPUPlace()
self._test_slice(place)
class TestVariableImperative(unittest.TestCase):
def _test_slice(self):
b = default_main_program().current_block()
w = b.create_var(dtype="float64", shape=[784, 100, 100], lod_level=0)
for i in range(3):
nw = w[i]
self.assertEqual([1, 100, 100], nw.shape)
nw = w[:]
self.assertEqual([784, 100, 100], nw.shape)
nw = w[:, :, :]
self.assertEqual([784, 100, 100], nw.shape)
nw = w[::2, ::2, :]
self.assertEqual([392, 50, 100], nw.shape)
nw = w[::-2, ::-2, :]
self.assertEqual([392, 50, 100], nw.shape)
nw = w[0::-2, 0::-2, :]
self.assertEqual([1, 1, 100], nw.shape)
def test_slice(self): if core.is_compiled_with_cuda():
with fluid.dygraph.guard(): self._test_slice(core.CUDAPlace(0))
self._test_slice()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册