未验证 提交 074a71bd 编写于 作者: L liym27 提交者: GitHub

Support assignment to a Variable in dynamic mode but not deal with backward. (#27471)

* Support assignment to a Variable in dynamic mode. Note: not deal with backward.

* Rewrite VarBase __setitem__ for high-performance.

* try to test 3 means to do __setitem__ and test the performance of 3 means.

* Retain the means of the highest performance: C++ code and don't trace op.
上级 5218b7af
...@@ -563,6 +563,33 @@ void BindImperative(py::module *m_ptr) { ...@@ -563,6 +563,33 @@ void BindImperative(py::module *m_ptr) {
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value")) .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor")) .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
.def("__init__", &InitVarBaseFromNumpyWithKwargs) .def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def("__setitem__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
py::object &value_obj) {
auto self_tensor =
self->MutableVar()->GetMutable<framework::LoDTensor>();
auto self_numpy = TensorToPyArray(*self_tensor);
if (py::isinstance<py::array>(value_obj) ||
py::isinstance<py::int_>(value_obj) ||
py::isinstance<py::float_>(value_obj)) {
auto value_numpy = value_obj;
self_numpy[_index] = value_numpy;
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
} else {
auto value =
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
auto value_tensor =
value->MutableVar()->GetMutable<framework::LoDTensor>();
auto value_numpy = TensorToPyArray(*value_tensor);
self_numpy[_index] = value_numpy;
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
}
})
.def("__getitem__", .def("__getitem__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) { [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
std::vector<int> slice_axes, slice_starts, slice_ends, std::vector<int> slice_axes, slice_starts, slice_ends,
...@@ -797,7 +824,8 @@ void BindImperative(py::module *m_ptr) { ...@@ -797,7 +824,8 @@ void BindImperative(py::module *m_ptr) {
return framework::vectorize<int>( return framework::vectorize<int>(
self.Var().Get<framework::SelectedRows>().value().dims()); self.Var().Get<framework::SelectedRows>().value().dims());
} else { } else {
VLOG(2) << "It is meaningless to get shape of variable type " VLOG(2) << "It is meaningless to get shape of "
"variable type "
<< GetTypeName(self); << GetTypeName(self);
return std::vector<int>(); return std::vector<int>();
} }
......
...@@ -15,12 +15,14 @@ ...@@ -15,12 +15,14 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode import numpy as np
import six
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy as np import paddle.fluid.layers as layers
from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode
class TestVarBase(unittest.TestCase): class TestVarBase(unittest.TestCase):
...@@ -403,5 +405,52 @@ class TestVarBase(unittest.TestCase): ...@@ -403,5 +405,52 @@ class TestVarBase(unittest.TestCase):
self.assertListEqual(list(var_base.shape), list(static_var.shape)) self.assertListEqual(list(var_base.shape), list(static_var.shape))
class TestVarBaseSetitem(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
self.np_value = np.random.random((2, 3)).astype(np.float32)
self.tensor_value = paddle.to_tensor(self.np_value)
def _test(self, value):
paddle.disable_static()
id_origin = id(self.tensor_x)
self.tensor_x[0] = value
if isinstance(value, (six.integer_types, float)):
result = np.zeros((2, 3)).astype(np.float32) + value
else:
result = self.np_value
self.assertTrue(np.array_equal(self.tensor_x[0].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))
self.tensor_x[1:2] = value
self.assertTrue(np.array_equal(self.tensor_x[1].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))
self.tensor_x[...] = value
self.assertTrue(np.array_equal(self.tensor_x[3].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))
def test_value_tensor(self):
paddle.disable_static()
self._test(self.tensor_value)
def test_value_numpy(self):
paddle.disable_static()
self._test(self.np_value)
def test_value_int(self):
paddle.disable_static()
self._test(10)
def test_value_float(self):
paddle.disable_static()
self._test(3.3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册