diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index da9900e2b271d08394cbc5e397f31b84e3b4d156..289540c8049a953c333562f3b3de2542c8b76676 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -563,6 +563,33 @@ void BindImperative(py::module *m_ptr) { .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value")) .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor")) .def("__init__", &InitVarBaseFromNumpyWithKwargs) + .def("__setitem__", + [](std::shared_ptr &self, py::handle _index, + py::object &value_obj) { + auto self_tensor = + self->MutableVar()->GetMutable(); + auto self_numpy = TensorToPyArray(*self_tensor); + + if (py::isinstance(value_obj) || + py::isinstance(value_obj) || + py::isinstance(value_obj)) { + auto value_numpy = value_obj; + self_numpy[_index] = value_numpy; + SetTensorFromPyArray(self_tensor, self_numpy, + self_tensor->place(), true); + + } else { + auto value = + value_obj.cast>(); + auto value_tensor = + value->MutableVar()->GetMutable(); + auto value_numpy = TensorToPyArray(*value_tensor); + + self_numpy[_index] = value_numpy; + SetTensorFromPyArray(self_tensor, self_numpy, + self_tensor->place(), true); + } + }) .def("__getitem__", [](std::shared_ptr &self, py::handle _index) { std::vector slice_axes, slice_starts, slice_ends, @@ -797,7 +824,8 @@ void BindImperative(py::module *m_ptr) { return framework::vectorize( self.Var().Get().value().dims()); } else { - VLOG(2) << "It is meaningless to get shape of variable type " + VLOG(2) << "It is meaningless to get shape of " + "variable type " << GetTypeName(self); return std::vector(); } diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index deb49a3ffc2b5febf97680bc652e9695fb253373..e3edf82ab9959dbe4e2673e31059e573819ffb20 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -15,12 +15,14 @@ from __future__ import print_function import unittest -from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode +import numpy as np +import six + import paddle import paddle.fluid as fluid -import paddle.fluid.layers as layers import paddle.fluid.core as core -import numpy as np +import paddle.fluid.layers as layers +from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode class TestVarBase(unittest.TestCase): @@ -403,5 +405,52 @@ class TestVarBase(unittest.TestCase): self.assertListEqual(list(var_base.shape), list(static_var.shape)) +class TestVarBaseSetitem(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32)) + self.np_value = np.random.random((2, 3)).astype(np.float32) + self.tensor_value = paddle.to_tensor(self.np_value) + + def _test(self, value): + paddle.disable_static() + id_origin = id(self.tensor_x) + + self.tensor_x[0] = value + + if isinstance(value, (six.integer_types, float)): + result = np.zeros((2, 3)).astype(np.float32) + value + + else: + result = self.np_value + + self.assertTrue(np.array_equal(self.tensor_x[0].numpy(), result)) + self.assertEqual(id_origin, id(self.tensor_x)) + + self.tensor_x[1:2] = value + self.assertTrue(np.array_equal(self.tensor_x[1].numpy(), result)) + self.assertEqual(id_origin, id(self.tensor_x)) + + self.tensor_x[...] = value + self.assertTrue(np.array_equal(self.tensor_x[3].numpy(), result)) + self.assertEqual(id_origin, id(self.tensor_x)) + + def test_value_tensor(self): + paddle.disable_static() + self._test(self.tensor_value) + + def test_value_numpy(self): + paddle.disable_static() + self._test(self.np_value) + + def test_value_int(self): + paddle.disable_static() + self._test(10) + + def test_value_float(self): + paddle.disable_static() + self._test(3.3) + + if __name__ == '__main__': unittest.main()