未验证 提交 aceec7fb 编写于 作者: L liym27 提交者: GitHub

[slice] Support index is Tensor for slice in dynamic mode (#32435)

上级 25e723e7
......@@ -746,7 +746,7 @@ void BindImperative(py::module *m_ptr) {
// inplace operator for the VarBase self.
self->BumpInplaceVersion();
})
.def("__getitem__",
.def("_getitem_index_not_tensor",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
std::vector<int> slice_axes, slice_starts, slice_ends,
slice_strides, decrease_axis, infer_flags;
......
......@@ -21,7 +21,7 @@ import paddle
from .. import framework
from .. import core
from .. import unique_name
from ..framework import Variable, Parameter, ParamBase
from ..framework import Variable, Parameter, ParamBase, _getitem_impl_
from .base import switch_to_static_graph
from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss
......@@ -437,6 +437,31 @@ def monkey_patch_varbase():
def __array__(self, dtype=None):
return self.numpy().astype(dtype)
def __getitem__(self, item):
def contain_tensor(item):
if not isinstance(item, tuple):
item = [item]
for slice_item in item:
if isinstance(slice_item, slice):
if isinstance(slice_item.start, Variable) \
or isinstance(slice_item.stop, Variable) \
or isinstance(slice_item.step, Variable):
return True
else:
if isinstance(slice_item, Variable):
return True
return False
if contain_tensor(item):
# 1. Call _getitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return _getitem_impl_(self, item)
else:
# 2. Call c++ func getitem_index_not_tensor to speedup.
return self._getitem_index_not_tensor(item)
for method_name, method in (
("__bool__", __bool__), ("__nonzero__", __nonzero__),
("_to_static_var", _to_static_var), ("set_value", set_value),
......@@ -445,7 +470,8 @@ def monkey_patch_varbase():
("gradient", gradient), ("register_hook", register_hook),
("__str__", __str__), ("__repr__", __str__),
("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
("__name__", "Tensor"), ("__array__", __array__)):
("__name__", "Tensor"), ("__array__", __array__),
("__getitem__", __getitem__)):
setattr(core.VarBase, method_name, method)
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
......
......@@ -473,6 +473,70 @@ class TestVarBase(unittest.TestCase):
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))
def _test_slice_for_tensor_attr(self):
tensor_array = np.array(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]]).astype('float32')
var = paddle.to_tensor(tensor_array)
one = paddle.ones(shape=[1], dtype="int32")
two = paddle.full(shape=[1], fill_value=2, dtype="int32")
negative_one = paddle.full(shape=[1], fill_value=-1, dtype="int32")
four = paddle.full(shape=[1], fill_value=4, dtype="int32")
var = fluid.dygraph.to_variable(tensor_array)
var1 = var[0, one, one]
var2 = var[one:]
var3 = var[0:one]
var4 = var[::negative_one]
var5 = var[one, one:, one:]
var_reshape = fluid.layers.reshape(var, [3, negative_one, 3])
var6 = var_reshape[:, :, negative_one]
var7 = var[:, :, :negative_one]
var8 = var[:one, :one, :1]
var9 = var[:-1, :negative_one, :negative_one]
var10 = var[::negative_one, :one, :negative_one]
var11 = var[:negative_one, ::-1, negative_one:]
var12 = var[one:2, 2:, ::negative_one]
var13 = var[two:10, 2:, -2:negative_one]
var14 = var[1:negative_one, 0:2, ::negative_one]
var15 = var[::negative_one, ::-1, ::negative_one]
var16 = var[-4:4]
vars = [
var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10,
var11, var12, var13, var14, var15, var16
]
local_out = [var.numpy() for var in vars]
self.assertTrue(np.array_equal(local_out[1], tensor_array[0, 1, 1:2]))
self.assertTrue(np.array_equal(local_out[2], tensor_array[1:]))
self.assertTrue(np.array_equal(local_out[3], tensor_array[0:1]))
self.assertTrue(np.array_equal(local_out[4], tensor_array[::-1]))
self.assertTrue(np.array_equal(local_out[5], tensor_array[1, 1:, 1:]))
self.assertTrue(
np.array_equal(local_out[6],
tensor_array.reshape((3, -1, 3))[:, :, -1]))
self.assertTrue(np.array_equal(local_out[7], tensor_array[:, :, :-1]))
self.assertTrue(np.array_equal(local_out[8], tensor_array[:1, :1, :1]))
self.assertTrue(
np.array_equal(local_out[9], tensor_array[:-1, :-1, :-1]))
self.assertTrue(
np.array_equal(local_out[10], tensor_array[::-1, :1, :-1]))
self.assertTrue(
np.array_equal(local_out[11], tensor_array[:-1, ::-1, -1:]))
self.assertTrue(
np.array_equal(local_out[12], tensor_array[1:2, 2:, ::-1]))
self.assertTrue(
np.array_equal(local_out[13], tensor_array[2:10, 2:, -2:-1]))
self.assertTrue(
np.array_equal(local_out[14], tensor_array[1:-1, 0:2, ::-1]))
self.assertTrue(
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))
def _test_for_var(self):
np_value = np.random.random((30, 100, 100)).astype('float32')
w = fluid.dygraph.to_variable(np_value)
......@@ -483,6 +547,7 @@ class TestVarBase(unittest.TestCase):
def test_slice(self):
with fluid.dygraph.guard():
self._test_slice()
self._test_slice_for_tensor_attr()
self._test_for_var()
var = fluid.dygraph.to_variable(self.array)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册