提交 c21550ce 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5402 Provide Tensor.from_numpy() to avoid copy

Merge pull request !5402 from hewei/tensor_from_numpy
...@@ -117,6 +117,76 @@ static bool IsCContiguous(const py::array &input) { ...@@ -117,6 +117,76 @@ static bool IsCContiguous(const py::array &input) {
return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0; return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0;
} }
// TensorDataNumpy implements TensorData using numpy array.
class TensorDataNumpy : public TensorData {
public:
explicit TensorDataNumpy(const py::array &input) : data_(input) {
if (!IsCContiguous(data_)) {
// Call numpy.ascontiguousarray() to convert data to C contiguous if it is not.
auto np = py::module::import("numpy");
auto convert = np.attr("ascontiguousarray");
data_ = convert(data_);
}
}
/// Total number of elements.
ssize_t size() const override { return data_.size(); }
/// Byte size of a single element.
ssize_t itemsize() const override { return data_.itemsize(); }
/// Total number of bytes.
ssize_t nbytes() const override { return data_.nbytes(); }
/// Number of dimensions.
ssize_t ndim() const override { return data_.ndim(); }
/// Data pointer.
void *data() override { return data_.request().ptr; }
const void *const_data() const override { return data_.request().ptr; }
/// Is data equals.
bool equals(const TensorData &other) const override {
auto ptr = dynamic_cast<const TensorDataNumpy *>(&other);
if (ptr == nullptr) {
// Not same type, compare data byte by byte.
return TensorData::equals(other);
}
return NumpyEquals(*ptr);
}
bool NumpyEquals(const TensorDataNumpy &other) const {
auto all_data_equal = [&other, this]() -> bool {
auto np = py::module::import("numpy");
auto equal = np.attr("equal")(data_, other.data_);
auto all_equal = np.attr("all")(equal);
return all_equal.cast<bool>();
};
return this == &other || data_.is(other.data_) || all_data_equal();
}
/// To string.
std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override {
if (use_comma) {
// Call python np.array2string(data_, separator=', ') to convert string with comma.
py::dict kwargs;
kwargs["separator"] = ", ";
auto np = py::module::import("numpy");
auto array2string = np.attr("array2string");
return py::str(array2string(data_, **kwargs));
}
// without comma.
return py::str(data_);
}
/// py::array object.
py::array py_array() const { return data_; }
private:
mutable py::array data_;
};
TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) { TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
// Get input buffer info. // Get input buffer info.
py::buffer_info buf = input.request(); py::buffer_info buf = input.request();
...@@ -145,7 +215,7 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) ...@@ -145,7 +215,7 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
buf.ptr = tmp_buf.get(); buf.ptr = tmp_buf.get();
} }
// Get tensor shape. // Get tensor shape.
std::vector<int> shape(buf.shape.begin(), buf.shape.end()); ShapeVector shape(buf.shape.begin(), buf.shape.end());
if (data_type == buf_type) { if (data_type == buf_type) {
// Use memory copy if input data type is the same as the required type. // Use memory copy if input data type is the same as the required type.
return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf.size * buf.itemsize); return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf.size * buf.itemsize);
...@@ -154,6 +224,22 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) ...@@ -154,6 +224,22 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf_type); return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf_type);
} }
/// Creates a Tensor from a numpy array without copy
TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
// Get input buffer info.
py::buffer_info buf = input.request();
// Get tensor dtype and check it.
auto dtype = GetDataType(buf);
if (dtype == TypeId::kTypeUnknown) {
MS_LOG(EXCEPTION) << "Unsupported data type!";
}
// Get tensor shape.
ShapeVector shape(buf.shape.begin(), buf.shape.end());
// Make a tensor with shared data with numpy array.
auto tensor_data = std::make_shared<TensorDataNumpy>(input);
return std::make_shared<Tensor>(dtype, shape, tensor_data);
}
static std::vector<ssize_t> GetStrides(const std::vector<ssize_t> &shape, ssize_t item_size) { static std::vector<ssize_t> GetStrides(const std::vector<ssize_t> &shape, ssize_t item_size) {
std::vector<ssize_t> strides; std::vector<ssize_t> strides;
strides.reserve(shape.size()); strides.reserve(shape.size());
...@@ -186,19 +272,23 @@ py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) { ...@@ -186,19 +272,23 @@ py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) { py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
tensor.data_sync(); tensor.data_sync();
auto info = GetPyBufferInfo(tensor); return AsNumpy(tensor);
py::object self = py::cast(&tensor);
return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
} }
py::array TensorPy::AsNumpy(const Tensor &tensor) { py::array TensorPy::AsNumpy(const Tensor &tensor) {
auto data_numpy = dynamic_cast<const TensorDataNumpy *>(&tensor.data());
if (data_numpy) {
// Return internal numpy array if tensor data is implemented base on it.
return data_numpy->py_array();
}
// Otherwise, create numpy array by buffer protocol.
auto info = GetPyBufferInfo(tensor); auto info = GetPyBufferInfo(tensor);
py::object self = py::cast(&tensor); py::object self = py::cast(&tensor);
return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
} }
static std::vector<int> GetShapeFromTuple(const py::tuple &tuple) { static ShapeVector GetShapeFromTuple(const py::tuple &tuple) {
std::vector<int> shape; ShapeVector shape;
const size_t size = tuple.size(); const size_t size = tuple.size();
shape.reserve(tuple.size()); shape.reserve(tuple.size());
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
...@@ -210,7 +300,7 @@ static std::vector<int> GetShapeFromTuple(const py::tuple &tuple) { ...@@ -210,7 +300,7 @@ static std::vector<int> GetShapeFromTuple(const py::tuple &tuple) {
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
// Define python MetaTensor class. // Define python MetaTensor class.
(void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor") (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
.def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape")) .def(py::init<TypePtr, const ShapeVector>(), py::arg("dtype"), py::arg("shape"))
.def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
.def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info) .def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
...@@ -224,7 +314,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { ...@@ -224,7 +314,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }
/* Create a new C++ instance */ /* Create a new C++ instance */
MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>()); MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<ShapeVector>());
return tensor; return tensor;
})); }));
// Define python Tensor class. // Define python Tensor class.
...@@ -288,6 +378,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { ...@@ -288,6 +378,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.shape() >>> data.shape()
(3, 3) (3, 3)
)mydelimiter") )mydelimiter")
.def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter(
Creates a Tensor from a numpy.ndarray without copy.
Arg:
array (numpy.ndarray): The input ndarray.
Returns:
Tensor, tensor with shared data to input ndarray.
Examples:
>>> a = np.ones((2, 3))
>>> t = mindspore.Tensor.from_numpy(a)
)mydelimiter")
.def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter( .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter(
Convert tensor to numpy.ndarray. Convert tensor to numpy.ndarray.
......
...@@ -99,6 +99,11 @@ class TensorPy { ...@@ -99,6 +99,11 @@ class TensorPy {
// param data_type [TypeId] Data type of the tensor. // param data_type [TypeId] Data type of the tensor.
static TensorPtr MakeTensor(const py::array &input, const TypePtr &data_type = nullptr); static TensorPtr MakeTensor(const py::array &input, const TypePtr &data_type = nullptr);
// brief Create Tensor from a numpy array without copy.
//
// param input [py::array] Data value of the tensor.
static TensorPtr MakeTensorNoCopy(const py::array &input);
static py::array SyncAsNumpy(const Tensor &tensor); static py::array SyncAsNumpy(const Tensor &tensor);
static py::array AsNumpy(const Tensor &tensor); static py::array AsNumpy(const Tensor &tensor);
......
...@@ -198,10 +198,16 @@ class TensorDataImpl : public TensorData { ...@@ -198,10 +198,16 @@ class TensorDataImpl : public TensorData {
return data_.get(); return data_.get();
} }
const void *const_data() const override {
// May return nullptr if data not initialized.
return data_.get();
}
bool equals(const TensorData &other) const override { bool equals(const TensorData &other) const override {
auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other); auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other);
if (ptr == nullptr) { if (ptr == nullptr) {
return false; // Not same type, compare data byte by byte.
return TensorData::equals(other);
} }
if (ptr == this) { if (ptr == this) {
return true; return true;
......
...@@ -50,8 +50,23 @@ class TensorData { ...@@ -50,8 +50,23 @@ class TensorData {
virtual ssize_t ndim() const = 0; virtual ssize_t ndim() const = 0;
/// Data pointer. /// Data pointer.
virtual void *data() = 0; virtual void *data() = 0;
/// Const Data pointer.
virtual const void *const_data() const = 0;
/// Is data equals. /// Is data equals.
virtual bool equals(const TensorData &other) const = 0; virtual bool equals(const TensorData &other) const {
if (this == &other) {
return true;
}
// By default, compare data byte by byte.
auto this_data = static_cast<const uint8_t *>(const_data());
auto other_data = static_cast<const uint8_t *>(other.const_data());
if (this_data == nullptr || other_data == nullptr) {
// null means data not initialized, compare uninitialized data always return false.
return false;
}
return (this_data == other_data) || (ndim() == other.ndim() && nbytes() == other.nbytes() &&
std::equal(this_data, this_data + nbytes(), other_data));
}
/// To string. /// To string.
virtual std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const = 0; virtual std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const = 0;
}; };
......
...@@ -476,3 +476,16 @@ def test_tensor_operation(): ...@@ -476,3 +476,16 @@ def test_tensor_operation():
assert np.all(x.asnumpy() == np.ones((3, 3))) assert np.all(x.asnumpy() == np.ones((3, 3)))
res = 5 // x res = 5 // x
assert np.all(x.asnumpy() == np.ones((3, 3))) assert np.all(x.asnumpy() == np.ones((3, 3)))
def test_tensor_from_numpy():
a = np.ones((2, 3))
t = ms.Tensor.from_numpy(a)
assert np.all(t.asnumpy() == 1)
# 't' and 'a' share same data.
a[1] = 2
assert np.all(t.asnumpy()[0] == 1)
assert np.all(t.asnumpy()[1] == 2)
# 't' is still valid after 'a' deleted.
del a
assert np.all(t.asnumpy()[0] == 1)
assert np.all(t.asnumpy()[1] == 2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册