From 86d4af395b4fa9be466c8d8260672b4ef448b823 Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Wed, 15 Sep 2021 17:35:33 +0800
Subject: [PATCH] Change the invoking method of setitem from numpy to
 set_value op when value isn't tensor (#35701)

* Change the invoking method of setitem from numpy to set_value op when
  value is not tensor

* fix the check logic for inplace in setitem

* fix the unittest problem caused by setitem not supporting fp16

* modify some code format in setitem
---
 paddle/fluid/operators/slice_utils.h           |  13 +-
 paddle/fluid/pybind/imperative.cc              | 424 ++++++++++++------
 .../tests/unittests/test_tensor_fill_.py       |   7 +-
 .../tests/unittests/test_tensor_zero_.py       |   7 +-
 .../fluid/tests/unittests/test_var_base.py     |  14 +-
 python/paddle/nn/layer/common.py               |   3 +-
 6 files changed, 303 insertions(+), 165 deletions(-)

diff --git a/paddle/fluid/operators/slice_utils.h b/paddle/fluid/operators/slice_utils.h
index 60782a9a924..290df94774b 100644
--- a/paddle/fluid/operators/slice_utils.h
+++ b/paddle/fluid/operators/slice_utils.h
@@ -36,17 +36,18 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
     if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
       continue;
     }
-    T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
-    start = std::max(start, static_cast<T>(0));
-
-    T end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
-    end = std::min(end, dim_value);
-
     T step = steps == nullptr ? 1 : (*steps)[i];
     PADDLE_ENFORCE_NE(
         step, 0, platform::errors::InvalidArgument(
                      "Step should not be 0, but received step = %d.", step));
 
+    T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
+    start = std::max(start, static_cast<T>(0));
+
+    T end =
+        0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
+    end = std::min(end, dim_value);
+
     if (step > 0) {
       start = std::min(start, dim_value);
       end = std::max(end, static_cast<T>(0));
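The change above normalizes a negative `end` only when `step` is positive, so that end = -1 keeps its role as the "stop past index 0" sentinel in reversed slices. A minimal Python sketch of the intended behavior; the concrete numbers are illustrative and assume a Paddle build that includes this patch:

    import paddle

    x = paddle.to_tensor([0., 1., 2., 3., 4.])

    # For x[2::-1], Python supplies end = -1 as the "run past index 0"
    # sentinel. Unconditionally adding dim_value would turn it into 4 and
    # make the reverse slice empty; with the `0 < step` guard the
    # assignment reaches elements 2, 1 and 0.
    x[2::-1] = 9.
    print(x.numpy())  # expected: [9. 9. 9. 3. 4.]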
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 8e230a6a108..62279449e3c 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -46,6 +46,7 @@ limitations under the License. */
 #include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/memory/allocation/mmap_allocator.h"
+#include "paddle/fluid/operators/utils.h"
 #include "paddle/fluid/pybind/op_function.h"
 #include "paddle/fluid/pybind/pybind_boost_headers.h"
 #include "paddle/fluid/pybind/tensor_py.h"
@@ -340,6 +341,51 @@ static bool IsNumpyType(PyObject *obj) {
   return type_name == "numpy.int64" || type_name == "numpy.longlong" ||
          type_name == "numpy.int32" || type_name == "numpy.int16";
 }
+
+static bool PyCheckTensor(PyObject *obj) {
+  return py::isinstance<imperative::VarBase>(obj);
+}
+
+// cast numpy type from S to T, this may allocate new memory
+template <typename T, typename S>
+static py::array_t<T> CastNumpyType(py::array_t<S> array) {
+  if (std::is_same<T, S>::value) {
+    return array;
+  }
+  auto dim = array.ndim();
+  std::vector<py::ssize_t> result_shape(dim);
+  for (auto i = 0; i < dim; i++) {
+    result_shape[i] = array.shape(i);
+  }
+
+  py::array_t<T> result(result_shape);
+
+  return py::vectorize([](S s) { return static_cast<T>(s); })(array);
+}
+
+template <typename T>
+static py::array_t<T> CastNumpyArray(const py::object &array) {
+  if (py::isinstance<py::array_t<float>>(array)) {
+    return CastNumpyType<T>(array.cast<py::array_t<float>>());
+  } else if (py::isinstance<py::array_t<double>>(array)) {
+    return CastNumpyType<T>(array.cast<py::array_t<double>>());
+  } else if (py::isinstance<py::array_t<int32_t>>(array)) {
+    return CastNumpyType<T>(array.cast<py::array_t<int32_t>>());
+  } else if (py::isinstance<py::array_t<int64_t>>(array)) {
+    return CastNumpyType<T>(array.cast<py::array_t<int64_t>>());
+  } else if (py::isinstance<py::array_t<bool>>(array)) {
+    return CastNumpyType<T>(array.cast<py::array_t<bool>>());
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Value type error. The assign numpy value allows integer, float, "
+        "double and bool, "
+        "but received %s.",
+        Py_TYPE(array.ptr())->tp_name));
+  }
+  // can't reach here
+  return py::array_t<T>();
+}
+
 static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
     const PyNameVarBaseMap &map) {
   imperative::NameVarBaseMap result;
@@ -364,6 +410,27 @@ static bool PyCheckInteger(PyObject *obj) {
 #endif
 }
 
+static Py_ssize_t GetSliceIndexFromTensor(
+    const std::shared_ptr<imperative::VarBase> &tensor_index) {
+  const auto &tensor = tensor_index->Var().Get<framework::LoDTensor>();
+  if (tensor.numel() == 1) {
+    if (tensor.type() == framework::proto::VarType::INT32) {
+      return static_cast<Py_ssize_t>(operators::GetValue<int32_t>(&tensor));
+    } else if (tensor.type() == framework::proto::VarType::INT64) {
+      return static_cast<Py_ssize_t>(operators::GetValue<int64_t>(&tensor));
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Currently, the type of tensor in slice indices only allows "
+          "int32 and int64, please check the type of index tensor."));
+    }
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Currently, tensor in slice indices only allows 1 element, "
+        "but received %d.",
+        tensor.numel()));
+  }
+}
+
 // NOTE(zhiqiu): Revised version of PySlice_GetIndices. From:
 // https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103
 // Original PySlice_GetIndices return wrong result when
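GetSliceIndexFromTensor above lets a one-element int32 or int64 Tensor stand in for a Python integer inside a slice, which is what the test_var_base change below exercises. A hedged usage sketch (results assume a build with this patch):

    import paddle

    x = paddle.arange(6, dtype='float32')
    lo = paddle.to_tensor(1, dtype='int32')   # one-element integer tensors
    hi = paddle.to_tensor(4, dtype='int64')   # are now legal slice bounds

    x[lo:hi] = 0.
    print(x.numpy())  # expected: [0. 0. 0. 0. 4. 5.]

A tensor with a non-integer dtype, or with more than one element, hits the two InvalidArgument branches defined above.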
@@ -382,10 +449,13 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
   } else {
     if (PyCheckInteger(r->step) || IsNumpyType(r->step)) {
       *step = PyLong_AsLong(r->step);
+    } else if (PyCheckTensor(r->step)) {
+      *step = GetSliceIndexFromTensor(
+          py::cast<std::shared_ptr<imperative::VarBase>>(r->step));
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "Currently, VarBase.__getitem__() only allows None or integers in "
-          "slice item, but received %s.",
+          "Currently, slice indices only allows None, integers, "
+          "tensor(int) and numpy(int) in slice item, but received %s.",
           std::string(Py_TYPE(r->step)->tp_name)));
     }
   }
@@ -394,10 +464,13 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
   } else {
     if (PyCheckInteger(r->start) || IsNumpyType(r->start)) {
       *start = PyLong_AsLong(r->start);
+    } else if (PyCheckTensor(r->start)) {
+      *start = GetSliceIndexFromTensor(
+          py::cast<std::shared_ptr<imperative::VarBase>>(r->start));
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "Currently, VarBase.__getitem__() only allows None or integers in "
-          "slice item, but received %s.",
+          "Currently, slice indices only allows None, integers, "
+          "tensor(int) and numpy(int) in slice item, but received %s.",
           std::string(Py_TYPE(r->start)->tp_name)));
     }
     if (*start < 0) *start += length;
@@ -408,13 +481,16 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
   } else {
     if (PyCheckInteger(r->stop) || IsNumpyType(r->stop)) {
       *stop = PyLong_AsLong(r->stop);
+    } else if (PyCheckTensor(r->stop)) {
+      *stop = GetSliceIndexFromTensor(
+          py::cast<std::shared_ptr<imperative::VarBase>>(r->stop));
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "Currently, VarBase.__getitem__() only allows None or integers in "
-          "slice item, but received %s.",
+          "Currently, slice indices only allows None, integers, "
+          "tensor(int) and numpy(int) in slice item, but received %s.",
           std::string(Py_TYPE(r->stop)->tp_name)));
     }
-    if (*stop < 0) *stop += length;
+    if (0 < *step && *stop < 0) *stop += length;
     *stop = std::min(*stop, length);
   }
   if (*stop > length) return -1;
@@ -554,7 +630,7 @@ static void ParseIndexingSlice(
 
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "Currently, VarBase.__getitem__() only allows indexing "
+          "Currently, Tensor.__indices__() only allows indexing "
          "by Integers, Slices, Ellipsis, None, tuples of these types "
          "and list of Bool and Integers, but received "
          "%s in %dth slice item",
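With the widened _PySlice_GetIndices, a slice bound may be None, a Python or NumPy integer, or a one-element integer Tensor; anything else triggers the rewritten error message. A small illustrative sketch:

    import numpy as np
    import paddle

    x = paddle.zeros([4, 4], dtype='float32')

    x[np.int64(0):np.int64(2)] = 1.               # numpy ints: accepted
    x[paddle.to_tensor(1, dtype='int32'):3] = 2.  # tensor bound: accepted

    try:
        x[1.5:3] = 3.                             # float bound: rejected
    except Exception as err:
        # "... slice indices only allows None, integers, tensor(int) and
        #  numpy(int) in slice item ..."
        print(err)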
@@ -824,148 +900,204 @@ void BindImperative(py::module *m_ptr) {
       .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
       .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
       .def("__init__", &InitVarBaseFromNumpyWithKwargs)
-      .def("__setitem_varbase__",
-           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
-              py::object &value_obj) {
-             VLOG(4) << "Call __setitem_varbase__";
-
-             auto self_tensor =
-                 self->MutableVar()->GetMutable<framework::LoDTensor>();
-             // NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
-             // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
-             PyObject *index_ptr = !PyTuple_Check(_index.ptr())
-                                       ? PyTuple_Pack(1, _index.ptr())
-                                       : _index.ptr();
-             DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
-               if (!PyTuple_Check(_index.ptr())) {
-                 Py_DECREF(index_ptr);
-                 VLOG(4) << "Call Py_DECREF";
-               }
-             });
-             // 1. Check argumnets
-             // 1.1 Check whether value obj is a tensor.
-             bool value_is_tensor = true;
-             bool parse_index = true;
-             if (py::isinstance<py::array>(value_obj) ||
-                 py::isinstance<py::int_>(value_obj) ||
-                 py::isinstance<py::float_>(value_obj)) {
-               value_is_tensor = false;
-             }
-
-             auto is_tensor = [](py::handle var) {
-               if (!var.ptr() || var.ptr() == Py_None) {
-                 return false;
-               }
-               try {
-                 py::cast<std::shared_ptr<imperative::VarBase>>(var);
-                 return true;
-               } catch (py::cast_error &) {
-                 return false;
-               }
-             };
-
-             // 1.2 Check whether _index can be parsed.
-             const int size = PyTuple_GET_SIZE(index_ptr);
-             for (int dim = 0; dim < size; ++dim) {
-               PyObject *slice_item = PyTuple_GetItem(index_ptr, dim);
-               if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
-                     slice_item == Py_Ellipsis || slice_item == Py_None)) {
-                 parse_index = false;
-                 break;
-               }
-             }
+      .def(
+          "__setitem_varbase__",
+          [](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
+             py::object &value_obj) {
+            VLOG(4) << "Call __setitem_varbase__";
+
+            auto self_tensor =
+                self->MutableVar()->GetMutable<framework::LoDTensor>();
+            // NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
+            // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
+            PyObject *index_ptr = !PyTuple_Check(_index.ptr())
+                                      ? PyTuple_Pack(1, _index.ptr())
+                                      : _index.ptr();
+            DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
+              if (!PyTuple_Check(_index.ptr())) {
+                Py_DECREF(index_ptr);
+                VLOG(4) << "Call Py_DECREF";
+              }
+            });
-             // 2. Call op set_value to speed up if the condition is met,
-             // otherwise call TensorToPyArray.
-             // TODO(liym27): Try not to call TensorToPyArray because it always
-             // copys data to cpu place, which reduces performance.
-             if (parse_index && value_is_tensor) {
-               std::vector<int> axes, starts, ends, steps, decrease_axes,
-                   none_axes, infer_flags, list_select_idxs;
-               // if index is a list, list_select_flag will be true
-               bool list_select_flag;
-               ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
-                                  &steps, &decrease_axes, &none_axes,
-                                  &infer_flags, &list_select_idxs,
-                                  &list_select_flag);
+            auto is_tensor = [](py::handle var) {
+              if (!var.ptr() || var.ptr() == Py_None) {
+                return false;
+              }
+              try {
+                py::cast<std::shared_ptr<imperative::VarBase>>(var);
+                return true;
+              } catch (py::cast_error &) {
+                return false;
+              }
+            };
+
+            // 1. Check arguments
+            bool parse_index = true;
+
+            // Check whether _index can be parsed.
+            const int size = PyTuple_GET_SIZE(index_ptr);
+            for (int dim = 0; dim < size; ++dim) {
+              PyObject *slice_item = PyTuple_GetItem(index_ptr, dim);
+              if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
+                    slice_item == Py_Ellipsis || slice_item == Py_None)) {
+                parse_index = false;
+                break;
+              }
+            }
-               framework::AttributeMap attrs = {
-                   {"axes", axes},
-                   {"starts", starts},
-                   {"ends", ends},
-                   {"steps", steps},
-                   {"decrease_axes", decrease_axes},
-                   {"none_axes", none_axes}};
+            // 2. Call op set_value to speed up if the condition is met,
+            // otherwise call TensorToPyArray.
+            // TODO(liym27): Try not to call TensorToPyArray because it always
+            // copies data to cpu place, which reduces performance.
+            if (parse_index) {
+              std::vector<int> axes, starts, ends, steps, decrease_axes,
+                  none_axes, infer_flags, list_select_idxs;
+              // if index is a list, list_select_flag will be true
+              bool list_select_flag = false;
+              ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
+                                 &steps, &decrease_axes, &none_axes,
+                                 &infer_flags, &list_select_idxs,
+                                 &list_select_flag);
+
+              framework::AttributeMap attrs = {{"axes", axes},
+                                               {"starts", starts},
+                                               {"ends", ends},
+                                               {"steps", steps},
+                                               {"decrease_axes", decrease_axes},
+                                               {"none_axes", none_axes}};
+
+              imperative::NameVarBaseMap ins = {{"Input", {self}}};
+              imperative::NameVarBaseMap outs = {{"Out", {self}}};
+
+              const auto &tracer = imperative::GetCurrentTracer();
+
+              if (tracer->HasGrad()) {
+                PADDLE_ENFORCE_EQ(
+                    self->IsLeaf() && !self->OverridedStopGradient(), false,
+                    platform::errors::InvalidArgument(
+                        "Leaf Tensor (%s) that doesn't stop gradient can't use "
+                        "inplace strategy.",
+                        self->Name()));
+              }
-               imperative::NameVarBaseMap ins = {{"Input", {self}}};
-               imperative::NameVarBaseMap outs = {{"Out", {self}}};
+              if (PyCheckTensor(value_obj.ptr())) {
+                auto value_tensor =
+                    value_obj.cast<std::shared_ptr<imperative::VarBase>>();
+                ins.insert({"ValueTensor", {value_tensor}});
+              } else if (py::isinstance<py::array>(value_obj)) {
+                auto value_tensor = std::shared_ptr<imperative::VarBase>(
+                    new imperative::VarBase(false,
+                                            tracer->GenerateUniqueName()));
+                py::object value = value_obj;
+                if (self->DataType() == framework::proto::VarType::FP32) {
+                  if (!py::isinstance<py::array_t<float>>(value_obj)) {
+                    value = CastNumpyArray<float>(value_obj);
+                  }
+                } else if (self->DataType() ==
+                           framework::proto::VarType::FP64) {
+                  if (!py::isinstance<py::array_t<double>>(value_obj)) {
+                    value = CastNumpyArray<double>(value_obj);
+                  }
+                } else if (self->DataType() ==
+                           framework::proto::VarType::INT32) {
+                  if (!py::isinstance<py::array_t<int32_t>>(value_obj)) {
+                    value = CastNumpyArray<int32_t>(value_obj);
+                  }
+                } else if (self->DataType() ==
+                           framework::proto::VarType::INT64) {
+                  if (!py::isinstance<py::array_t<int64_t>>(value_obj)) {
+                    value = CastNumpyArray<int64_t>(value_obj);
+                  }
+                } else if (self->DataType() ==
+                           framework::proto::VarType::BOOL) {
+                  if (!py::isinstance<py::array_t<bool>>(value_obj)) {
+                    value = CastNumpyArray<bool>(value_obj);
+                  }
+                } else {
+                  PADDLE_THROW(platform::errors::InvalidArgument(
+                      "When assign a numpy.np value to a paddle.Tensor, "
+                      "the data type of the paddle.Tensor must be bool, "
+                      "float32, int32 or int64, "
+                      "please check the type of tensor."));
+                }
+
+                SetTensorFromPyArray(value_tensor->MutableVar()
+                                         ->GetMutable<framework::LoDTensor>(),
+                                     value, self->Place(), false);
+                ins.insert({"ValueTensor", {value_tensor}});
+
+              } else {
+                // convert the value to self data type
+                if (py::isinstance<py::float_>(value_obj) ||
+                    py::isinstance<py::int_>(value_obj) ||
+                    py::isinstance<py::bool_>(value_obj)) {
+                  if (self->DataType() == framework::proto::VarType::FP32) {
+                    attrs["fp32_values"] =
+                        std::vector<float>{value_obj.cast<float>()};
+                  } else if (self->DataType() ==
+                             framework::proto::VarType::FP64) {
+                    attrs["fp64_values"] =
+                        std::vector<double>{value_obj.cast<double>()};
+                  } else if (self->DataType() ==
+                             framework::proto::VarType::INT32) {
+                    attrs["int32_values"] =
+                        std::vector<int32_t>{value_obj.cast<int32_t>()};
+                  } else if (self->DataType() ==
+                             framework::proto::VarType::INT64) {
+                    attrs["int64_values"] =
+                        std::vector<int64_t>{value_obj.cast<int64_t>()};
+                  } else if (self->DataType() ==
+                             framework::proto::VarType::BOOL) {
+                    attrs["bool_values"] =
+                        std::vector<int>{value_obj.cast<int>()};
+                  } else {
+                    PADDLE_THROW(platform::errors::InvalidArgument(
+                        "When assign a value to a paddle.Tensor, "
+                        "the data type of the paddle.Tensor must be bool, "
+                        "float32, int32 or int64, "
+                        "please check the type of tensor."));
+                  }
+                  attrs["shape"] =
+                      std::vector<int64_t>{1};
+
+                } else {
+                  PADDLE_THROW(platform::errors::InvalidArgument(
+                      "Value type error. The assign value allows "
+                      "numpy.ndarray, integer, float or bool, "
+                      "but received %s.",
+                      Py_TYPE(value_obj.ptr())));
+                }
+              }
-               PADDLE_ENFORCE_EQ(
-                   self->IsLeaf() && !self->OverridedStopGradient(), false,
-                   platform::errors::InvalidArgument(
-                       "Leaf Tensor (%s) that doesn't stop gradient can't use "
-                       "inplace strategy.",
-                       self->Name()));
-
-               auto value_tensor =
-                   value_obj.cast<std::shared_ptr<imperative::VarBase>>();
-               ins.insert({"ValueTensor", {value_tensor}});
-
-               const auto &tracer = imperative::GetCurrentTracer();
-               {
-                 // Release gil and do tracing
-                 py::gil_scoped_release release;
-                 tracer->TraceOp("set_value", ins, outs, std::move(attrs),
-                                 {{"Input", "Out"}});
-               }
-             } else {
-               auto self_numpy = TensorToPyArray(*self_tensor);
-               VLOG(4) << "parse_index is false";
-
-               if (value_is_tensor) {
-                 VLOG(4) << "value is tensor";
-                 auto value =
-                     value_obj.cast<std::shared_ptr<imperative::VarBase>>();
-                 auto value_tensor =
-                     value->MutableVar()->GetMutable<framework::LoDTensor>();
-                 auto value_numpy = TensorToPyArray(*value_tensor);
-                 if (is_tensor(_index)) {
-                   VLOG(4) << "index is tensor";
-                   auto index_var =
-                       py::cast<std::shared_ptr<imperative::VarBase>>(_index);
-                   auto index_tensor = index_var->MutableVar()
-                                           ->GetMutable<framework::LoDTensor>();
-                   auto index_numpy = TensorToPyArray(*index_tensor);
-                   self_numpy[index_numpy] = value_numpy;
-                 } else {
-                   VLOG(4) << "index is not tensor";
-                   self_numpy[_index] = value_numpy;
-                 }
-                 SetTensorFromPyArray(self_tensor, self_numpy,
-                                      self_tensor->place(), true);
-               } else {
-                 VLOG(4) << "value is not tensor";
-                 if (is_tensor(_index)) {
-                   VLOG(4) << "index is tensor";
-                   auto index_var =
-                       py::cast<std::shared_ptr<imperative::VarBase>>(_index);
-                   auto index_tensor = index_var->MutableVar()
-                                           ->GetMutable<framework::LoDTensor>();
-                   auto index_numpy = TensorToPyArray(*index_tensor);
-                   self_numpy[index_numpy] = value_obj;
-                 } else {
-                   VLOG(4) << "index is not tensor";
-                   self_numpy[_index] = value_obj;
-                 }
-                 SetTensorFromPyArray(self_tensor, self_numpy,
-                                      self_tensor->place(), true);
-               }
-             }
-             // NOTE(liym27):
-             // Increase the version of VarBase self because __setitem__ is an
-             // inplace operator for the VarBase self.
-             self->BumpInplaceVersion();
-           })
+              {
+                // Release gil and do tracing
+                py::gil_scoped_release release;
+                tracer->TraceOp("set_value", ins, outs, std::move(attrs),
+                                {{"Input", "Out"}});
+              }
+            } else {
+              auto self_numpy = TensorToPyArray(*self_tensor);
+              VLOG(4) << "parse_index is false";
+              if (is_tensor(_index)) {
+                VLOG(4) << "index is tensor";
+                auto index_var =
+                    py::cast<std::shared_ptr<imperative::VarBase>>(_index);
+                auto index_tensor =
+                    index_var->MutableVar()->GetMutable<framework::LoDTensor>();
+                auto index_numpy = TensorToPyArray(*index_tensor);
+                self_numpy[index_numpy] = value_obj;
+              } else {
+                VLOG(4) << "index is not tensor";
+                self_numpy[_index] = value_obj;
+              }
+              SetTensorFromPyArray(self_tensor, self_numpy,
+                                   self_tensor->place(), false);
+            }
+            // NOTE(liym27):
+            // Increase the version of VarBase self because __setitem__ is an
+            // inplace operator for the VarBase self.
+            self->BumpInplaceVersion();
+          })
       .def("_getitem_index_not_tensor",
            [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
              VLOG(4) << "Call _getitem_index_not_tensor";
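The net effect of the __setitem_varbase__ rewrite: whenever the index can be parsed, the assignment is traced as a set_value op, with the value carried either as a ValueTensor input (tensor or numpy value) or as *_values/shape attributes (Python scalar), instead of round-tripping the whole tensor through numpy. A sketch of the traced path; the gradient shown is illustrative of set_value's backward and assumes a build with this patch:

    import paddle

    x = paddle.ones([2, 3], dtype='float32')
    x.stop_gradient = False
    y = x * 2          # non-leaf, so inplace setitem is permitted

    # Traced as set_value with fp32_values=[10.] and shape=[1], staying
    # on-device and recorded by the autograd tracer.
    y[0, 1] = 10.
    y.sum().backward()
    print(x.grad.numpy())  # expected: 2.0 everywhere except 0.0 at [0, 1],
                           # since the overwritten element no longer depends on x

Assigning through a leaf tensor that requires grad now raises the "Leaf Tensor ... can't use inplace strategy" error instead of silently mutating shared data, which is what the test_var_base changes below assert.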
diff --git a/python/paddle/fluid/tests/unittests/test_tensor_fill_.py b/python/paddle/fluid/tests/unittests/test_tensor_fill_.py
index bc4456bb969..5891aee5bd3 100644
--- a/python/paddle/fluid/tests/unittests/test_tensor_fill_.py
+++ b/python/paddle/fluid/tests/unittests/test_tensor_fill_.py
@@ -40,12 +40,11 @@ class TensorFill_Test(unittest.TestCase):
             for dtype in typelist:
                 var = 1.
                 tensor = paddle.to_tensor(np_arr, place=p, dtype=dtype)
-                newtensor = tensor.clone()
-                newtensor[...] = var
+                target = tensor.numpy()
+                target[...] = var
 
                 tensor.fill_(var)  #var type is basic type in typelist
-                self.assertEqual((tensor.numpy() == newtensor.numpy()).all(),
-                                 True)
+                self.assertEqual((tensor.numpy() == target).all(), True)
 
     def test_tensor_fill_backward(self):
         typelist = ['float32']
diff --git a/python/paddle/fluid/tests/unittests/test_tensor_zero_.py b/python/paddle/fluid/tests/unittests/test_tensor_zero_.py
index 716607710f1..65620038fc4 100644
--- a/python/paddle/fluid/tests/unittests/test_tensor_zero_.py
+++ b/python/paddle/fluid/tests/unittests/test_tensor_zero_.py
@@ -35,12 +35,11 @@ class TensorFill_Test(unittest.TestCase):
             np.array(six.moves.range(np.prod(self.shape))), self.shape)
         for dtype in typelist:
             tensor = paddle.to_tensor(np_arr, place=p, dtype=dtype)
-            newtensor = tensor.clone()
-            newtensor[...] = 0
+            target = tensor.numpy()
+            target[...] = 0
 
             tensor.zero_()
-            self.assertEqual(
-                (tensor.numpy() == newtensor.numpy()).all().item(), True)
+            self.assertEqual((tensor.numpy() == target).all().item(), True)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py
index cbfb9860fa6..cfaef15c1d3 100644
--- a/python/paddle/fluid/tests/unittests/test_var_base.py
+++ b/python/paddle/fluid/tests/unittests/test_var_base.py
@@ -391,6 +391,11 @@ class TestVarBase(unittest.TestCase):
             self.assertTrue(cmp_float(x.grad.numpy(), [20.0]))
             self.assertTrue(cmp_float(detach_x.grad.numpy(), [60.0]))
 
+            with self.assertRaises(ValueError):
+                detach_x[:] = 5.0
+
+            detach_x.stop_gradient = True
+
             # Due to sharing of data with origin Tensor, There are some unsafe operations:
             with self.assertRaises(RuntimeError):
                 y = 2**x
@@ -438,10 +443,11 @@ class TestVarBase(unittest.TestCase):
             self.assertTrue(np.array_equal(y.numpy(), y_copy.numpy()))
             self.assertNotEqual(id(x), id(x_copy))
 
-            x_copy[:] = 5.
-            self.assertTrue(np.array_equal(x_copy.numpy(), [5.]))
             self.assertTrue(np.array_equal(x.numpy(), [2.]))
 
+            with self.assertRaises(ValueError):
+                x_copy[:] = 5.
+
             with self.assertRaises(RuntimeError):
                 copy.deepcopy(z)
 
@@ -805,8 +811,8 @@ class TestVarBase(unittest.TestCase):
         # case2:
         tensor_x = paddle.to_tensor(
             np.zeros(12).reshape(2, 6).astype(np.float32))
-        tensor_y1 = paddle.zeros([1]) + 2
-        tensor_y2 = paddle.zeros([1]) + 5
+        tensor_y1 = paddle.zeros([1], dtype='int32') + 2
+        tensor_y2 = paddle.zeros([1], dtype='int32') + 5
         tensor_x[:, tensor_y1:tensor_y2] = 42
         res = tensor_x.numpy()
         exp = np.array([[0., 0., 42., 42., 42., 0.],
diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py
index 357c9bc7d40..d5557bd9ea4 100644
--- a/python/paddle/nn/layer/common.py
+++ b/python/paddle/nn/layer/common.py
@@ -1390,7 +1390,8 @@ class Embedding(Layer):
             is_bias=False)
 
         if in_dygraph_mode() and padding_idx != -1:
-            self.weight[padding_idx] = 0.0
+            with paddle.no_grad():
+                self.weight[padding_idx] = 0.0
 
     def forward(self, x):
         return F.embedding(
-- 
GitLab
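The Embedding change in common.py above follows directly from the new inplace check: self.weight is a leaf Parameter that does not stop gradient, so zeroing the padding row must happen under paddle.no_grad(), where the tracer records nothing and the leaf check is skipped. A minimal sketch of the same pattern (shapes and the raised exception type are assumptions based on the tests above):

    import paddle

    w = paddle.create_parameter([4, 8], dtype='float32')  # leaf, requires grad

    try:
        w[0] = 0.          # rejected by the new leaf-inplace check
    except ValueError:
        pass

    with paddle.no_grad():  # mirrors the Embedding.__init__ fix above
        w[0] = 0.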