diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index ace62d210b368d1a20fa5aa890f23cd78162bdf7..66eaed5adb832e2666aabed1221afc07646ba0bb 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -746,7 +746,7 @@ void BindImperative(py::module *m_ptr) {
              // inplace operator for the VarBase self.
              self->BumpInplaceVersion();
            })
-      .def("__getitem__",
+      .def("_getitem_index_not_tensor",
            [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
              std::vector<int> slice_axes, slice_starts, slice_ends,
                  slice_strides, decrease_axis, infer_flags;
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index 64209aee875ba02364481674f796a0b519f771f5..11bc150b281aa95bbd838433ae95ac7ec0c23410 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -21,7 +21,7 @@ import paddle
 from .. import framework
 from .. import core
 from .. import unique_name
-from ..framework import Variable, Parameter, ParamBase
+from ..framework import Variable, Parameter, ParamBase, _getitem_impl_
 from .base import switch_to_static_graph
 from .math_op_patch import monkey_patch_math_varbase
 from .parallel import scale_loss
@@ -437,6 +437,31 @@ def monkey_patch_varbase():
     def __array__(self, dtype=None):
         return self.numpy().astype(dtype)
 
+    def __getitem__(self, item):
+        def contain_tensor(item):
+            if not isinstance(item, tuple):
+                item = [item]
+
+            for slice_item in item:
+                if isinstance(slice_item, slice):
+                    if isinstance(slice_item.start, Variable) \
+                            or isinstance(slice_item.stop, Variable) \
+                            or isinstance(slice_item.step, Variable):
+                        return True
+                else:
+                    if isinstance(slice_item, Variable):
+                        return True
+            return False
+
+        if contain_tensor(item):
+            # 1. Call _getitem_impl_ when the item contains a tensor.
+            # Why not call a C++ function? Because the item can't be parsed when it contains a tensor.
+            return _getitem_impl_(self, item)
+
+        else:
+            # 2. Call the C++ func _getitem_index_not_tensor to speed up.
+            return self._getitem_index_not_tensor(item)
+
     for method_name, method in (
         ("__bool__", __bool__), ("__nonzero__", __nonzero__),
         ("_to_static_var", _to_static_var), ("set_value", set_value),
@@ -445,7 +470,8 @@ def monkey_patch_varbase():
         ("gradient", gradient), ("register_hook", register_hook),
         ("__str__", __str__), ("__repr__", __str__),
        ("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
-        ("__name__", "Tensor"), ("__array__", __array__)):
+        ("__name__", "Tensor"), ("__array__", __array__),
+        ("__getitem__", __getitem__)):
         setattr(core.VarBase, method_name, method)
 
     # NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py
index 77432a59de273c570e87dda388adf8e11475eff7..7901df79171216d864118e16bd7b11c8e327774c 100644
--- a/python/paddle/fluid/tests/unittests/test_var_base.py
+++ b/python/paddle/fluid/tests/unittests/test_var_base.py
@@ -473,6 +473,70 @@ class TestVarBase(unittest.TestCase):
             np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
         self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))
 
+    def _test_slice_for_tensor_attr(self):
+        tensor_array = np.array(
+            [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+             [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
+             [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]).astype('float32')
+
+        var = paddle.to_tensor(tensor_array)
+
+        one = paddle.ones(shape=[1], dtype="int32")
+        two = paddle.full(shape=[1], fill_value=2, dtype="int32")
+        negative_one = paddle.full(shape=[1], fill_value=-1, dtype="int32")
+        four = paddle.full(shape=[1], fill_value=4, dtype="int32")
+
+        var = fluid.dygraph.to_variable(tensor_array)
+        var1 = var[0, one, one]
+        var2 = var[one:]
+        var3 = var[0:one]
+        var4 = var[::negative_one]
+        var5 = var[one, one:, one:]
+        var_reshape = fluid.layers.reshape(var, [3, negative_one, 3])
+        var6 = var_reshape[:, :, negative_one]
+        var7 = var[:, :, :negative_one]
+        var8 = var[:one, :one, :1]
+        var9 = var[:-1, :negative_one, :negative_one]
+        var10 = var[::negative_one, :one, :negative_one]
+        var11 = var[:negative_one, ::-1, negative_one:]
+        var12 = var[one:2, 2:, ::negative_one]
+        var13 = var[two:10, 2:, -2:negative_one]
+        var14 = var[1:negative_one, 0:2, ::negative_one]
+        var15 = var[::negative_one, ::-1, ::negative_one]
+        var16 = var[-4:4]
+
+        vars = [
+            var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10,
+            var11, var12, var13, var14, var15, var16
+        ]
+        local_out = [var.numpy() for var in vars]
+
+        self.assertTrue(np.array_equal(local_out[1], tensor_array[0, 1, 1:2]))
+        self.assertTrue(np.array_equal(local_out[2], tensor_array[1:]))
+        self.assertTrue(np.array_equal(local_out[3], tensor_array[0:1]))
+        self.assertTrue(np.array_equal(local_out[4], tensor_array[::-1]))
+        self.assertTrue(np.array_equal(local_out[5], tensor_array[1, 1:, 1:]))
+        self.assertTrue(
+            np.array_equal(local_out[6],
+                           tensor_array.reshape((3, -1, 3))[:, :, -1]))
+        self.assertTrue(np.array_equal(local_out[7], tensor_array[:, :, :-1]))
+        self.assertTrue(np.array_equal(local_out[8], tensor_array[:1, :1, :1]))
+        self.assertTrue(
+            np.array_equal(local_out[9], tensor_array[:-1, :-1, :-1]))
+        self.assertTrue(
+            np.array_equal(local_out[10], tensor_array[::-1, :1, :-1]))
+        self.assertTrue(
+            np.array_equal(local_out[11], tensor_array[:-1, ::-1, -1:]))
+        self.assertTrue(
+            np.array_equal(local_out[12], tensor_array[1:2, 2:, ::-1]))
+        self.assertTrue(
+            np.array_equal(local_out[13], tensor_array[2:10, 2:, -2:-1]))
+        self.assertTrue(
+            np.array_equal(local_out[14], tensor_array[1:-1, 0:2, ::-1]))
+        self.assertTrue(
+            np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
+        self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))
+
     def _test_for_var(self):
         np_value = np.random.random((30, 100, 100)).astype('float32')
         w = fluid.dygraph.to_variable(np_value)
@@ -483,6 +547,7 @@ class TestVarBase(unittest.TestCase):
     def test_slice(self):
         with fluid.dygraph.guard():
             self._test_slice()
+            self._test_slice_for_tensor_attr()
            self._test_for_var()
 
             var = fluid.dygraph.to_variable(self.array)
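For context, a minimal usage sketch (not part of the diff) of the two dispatch paths the patched __getitem__ takes. It assumes a Paddle build that includes this change, running in the default dygraph mode; the variable names below are illustrative only.

import numpy as np
import paddle

x = paddle.to_tensor(np.arange(27).reshape([3, 3, 3]).astype('float32'))

# Plain Python indices contain no tensors, so contain_tensor() returns False and
# indexing goes through the C++ fast path core.VarBase._getitem_index_not_tensor.
a = x[0:2, ::-1]

# A Tensor index makes contain_tensor() return True, so the slice is assembled
# in Python by _getitem_impl_ instead of being parsed in C++.
one = paddle.ones(shape=[1], dtype="int32")
b = x[one:, :one]

print(a.shape, b.shape)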