提交 d5005bc8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5280 Revert "Avoid copy when create Tensor from numpy array"

Merge pull request !5280 from hewei/master_revert_no_copy_tensor
......@@ -117,75 +117,6 @@ static bool IsCContiguous(const py::array &input) {
return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0;
}
// TensorDataNumpy implements TensorData using numpy array.
class TensorDataNumpy : public TensorData {
public:
explicit TensorDataNumpy(const py::array &input) : data_(input) {}
/// Total number of elements.
ssize_t size() const override { return data_.size(); }
/// Byte size of a single element.
ssize_t itemsize() const override { return data_.itemsize(); }
/// Total number of bytes.
ssize_t nbytes() const override { return data_.nbytes(); }
/// Number of dimensions.
ssize_t ndim() const override { return data_.ndim(); }
/// Data pointer.
void *data() override {
EnsureDataContiguous();
return data_.request(true).ptr;
}
const void *const_data() const override {
EnsureDataContiguous();
return data_.request(false).ptr;
}
/// Is data equals.
bool equals(const TensorData &other) const override {
auto ptr = dynamic_cast<const TensorDataNumpy *>(&other);
if (ptr == nullptr) {
// Not same type, compare data byte by byte.
return TensorData::equals(other);
}
return NumpyEquals(*ptr);
}
bool NumpyEquals(const TensorDataNumpy &other) const {
auto all_data_equal = [&other, this]() -> bool {
auto np = py::module::import("numpy");
auto equal = np.attr("equal")(data_, other.data_);
auto all_equal = np.attr("all")(equal);
return all_equal.cast<bool>();
};
return this == &other || data_.is(other.data_) || all_data_equal();
}
/// To string.
std::string ToString(const TypeId type, const std::vector<int> &shape) const override {
return std::string(py::str(data_));
}
/// Data.
py::array data() const { return data_; }
private:
void EnsureDataContiguous() const {
if (!IsCContiguous(data_)) {
// Call numpy.ascontiguousarray() to convert data to C contiguous if it is not.
auto np = py::module::import("numpy");
auto convert = np.attr("ascontiguousarray");
data_ = convert(data_);
}
}
mutable py::array data_;
};
TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
// Get input buffer info.
py::buffer_info buf = input.request();
......@@ -199,13 +130,6 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
if (data_type == TypeId::kTypeUnknown) {
data_type = buf_type;
}
// Get tensor shape.
std::vector<int> shape(buf.shape.begin(), buf.shape.end());
if (data_type == buf_type) {
// Make a tensor with shared data with numpy if no type convertion needed.
auto tensor_data = std::make_shared<TensorDataNumpy>(input);
return std::make_shared<Tensor>(data_type, shape, tensor_data);
}
// Convert input array to C contiguous if need.
std::unique_ptr<char[]> tmp_buf;
if (!IsCContiguous(input)) {
......@@ -220,6 +144,8 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
PyBuffer_Release(&pybuf);
buf.ptr = tmp_buf.get();
}
// Get tensor shape.
std::vector<int> shape(buf.shape.begin(), buf.shape.end());
if (data_type == buf_type) {
// Use memory copy if input data type is the same as the required type.
return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf.size * buf.itemsize);
......@@ -260,16 +186,12 @@ py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
tensor.data_sync();
return AsNumpy(tensor);
auto info = GetPyBufferInfo(tensor);
py::object self = py::cast(&tensor);
return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
}
py::array TensorPy::AsNumpy(const Tensor &tensor) {
auto data_numpy = dynamic_cast<const TensorDataNumpy *>(&tensor.data());
if (data_numpy) {
// Return internal numpy array if tensor data is implemented base on it.
return data_numpy->data();
}
// Otherwise, create numpy array by buffer protocol.
auto info = GetPyBufferInfo(tensor);
py::object self = py::cast(&tensor);
return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
......
......@@ -145,7 +145,7 @@ class Parameter(MetaTensor):
data = data.to_tensor()
if isinstance(data, Tensor):
# make a copy of Tensor to init the parameter
return (Tensor, data.asnumpy().copy(),)
return (Tensor, data.asnumpy(),)
if isinstance(data, int):
return (Tensor, data, mstype.int32)
if isinstance(data, float):
......
......@@ -198,16 +198,10 @@ class TensorDataImpl : public TensorData {
return data_.get();
}
const void *const_data() const override {
// May return nullptr if data not initialized.
return data_.get();
}
bool equals(const TensorData &other) const override {
auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other);
if (ptr == nullptr) {
// Not same type, compare data byte by byte.
return TensorData::equals(other);
return false;
}
if (ptr == this) {
return true;
......
......@@ -50,23 +50,8 @@ class TensorData {
virtual ssize_t ndim() const = 0;
/// Data pointer.
virtual void *data() = 0;
/// Const Data pointer.
virtual const void *const_data() const = 0;
/// Is data equals.
virtual bool equals(const TensorData &other) const {
if (this == &other) {
return true;
}
// By default, compare data byte by byte.
auto this_data = static_cast<const uint8_t *>(const_data());
auto other_data = static_cast<const uint8_t *>(other.const_data());
if (this_data == nullptr || other_data == nullptr) {
// null means data not initialized, compare uninitialized data always return false.
return false;
}
return (this_data == other_data) || (ndim() == other.ndim() && nbytes() == other.nbytes() &&
std::equal(this_data, this_data + nbytes(), other_data));
}
virtual bool equals(const TensorData &other) const = 0;
/// To string.
virtual std::string ToString(const TypeId type, const ShapeVector &shape) const = 0;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册