提交 43e09670 编写于 作者: H He Wei

Decouple ir::Tensor class from python

上级 363a232c
......@@ -27,6 +27,7 @@
#include "utils/symbolic.h"
#include "ir/meta_func_graph.h"
#include "ir/param_value_py.h"
#include "ir/tensor_py.h"
#include "pipeline/parse/python_adapter.h"
#include "pipeline/parse/resolve.h"
#include "operator/composite/composite.h"
......@@ -39,6 +40,8 @@
#include "utils/context/ms_context.h"
#include "operator/ops.h"
using mindspore::tensor::TensorPy;
namespace mindspore {
// max number of elements in sequence
const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF;
......@@ -399,7 +402,7 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const Valu
oss << value->DumpText();
} else if (value->isa<tensor::Tensor>()) {
auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
oss << value->DumpText() << "@" << DumpObject(tensor_ptr->data(), "T");
oss << value->DumpText() << "@" << DumpObject(TensorPy::AsNumpy(*tensor_ptr), "T");
} else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<NullObj>()) {
oss << value->DumpText();
} else if (value->isa<ValueSequeue>()) {
......@@ -1813,7 +1816,7 @@ class IrParser {
if (tensor_data == nullptr) {
return TOK_ERROR;
}
*val_ptr = std::make_shared<tensor::Tensor>(tensor_data, TypeIdToType(type));
*val_ptr = TensorPy::MakeTensor(tensor_data, TypeIdToType(type));
return lexer_.GetNextToken();
}
......
......@@ -117,7 +117,7 @@ void DebugServices::check_watchpoints(std::vector<std::string> *name, std::vecto
continue;
}
float *start_addr = reinterpret_cast<float *>(tensor_ptr->data_c(false));
float *start_addr = reinterpret_cast<float *>(tensor_ptr->data_c());
unsigned int num_elements = (tensor_ptr->data().nbytes()) / sizeof(float);
std::unordered_map<unsigned int, watchpoint_t>::iterator it_w_table_check;
......@@ -144,7 +144,7 @@ void DebugServices::check_watchpoints(std::vector<std::string> *name, std::vecto
name->push_back(name_no_slot);
slot->push_back(std::to_string(tensor_list[i]->GetSlot()));
data_ptr->push_back(reinterpret_cast<char *>(tensor_ptr->data_c(false)));
data_ptr->push_back(reinterpret_cast<char *>(tensor_ptr->data_c()));
data_size->push_back(tensor_ptr->data().nbytes());
int condition_item = -1;
......@@ -182,7 +182,7 @@ void DebugServices::read_nodes_tensors(std::vector<std::string> name, std::vecto
continue;
}
ret_name->push_back(std::get<0>(result));
data_ptr->push_back(reinterpret_cast<char *>(std::get<1>(result)->GetTensor()->data_c(false)));
data_ptr->push_back(reinterpret_cast<char *>(std::get<1>(result)->GetTensor()->data_c()));
data_size->push_back(std::get<1>(result)->GetTensor()->data().nbytes());
dtype->push_back(std::get<1>(result)->GetTensor()->Dtype());
shape->push_back(std::get<1>(result)->GetTensor()->shape());
......
......@@ -329,12 +329,12 @@ bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &file
MS_LOG(INFO) << "E2E Dump path is " << path;
mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(host_type, host_shape);
size_t host_size = out_tensor->data().nbytes();
ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c(true));
ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c());
if (!ret) {
MS_LOG(ERROR) << "Copy device mem to host failed";
return ret;
}
ret = mindspore::Dump::DumpToFile(path, out_tensor->data_c(false), host_size);
ret = mindspore::Dump::DumpToFile(path, out_tensor->data_c(), host_size);
} else {
auto host_tmp = std::vector<uint8_t>(size_);
auto ret_rt_memcpy = rtMemcpy(host_tmp.data(), size_, ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
......@@ -364,7 +364,7 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
MS_LOG(INFO) << "E2E tensor name is " << tensor_name;
mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(host_type, host_shape);
size_t host_size = out_tensor->data().nbytes();
ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c(true));
ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c());
if (!ret) {
MS_LOG(ERROR) << "Copy device mem to host failed";
return ret;
......@@ -379,7 +379,7 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
} else {
mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(type_id_, host_shape);
size_t host_size = out_tensor->data().nbytes();
auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(true), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST);
auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST);
auto tensor_data = std::make_shared<mindspore::TensorData>();
tensor_data->SetName(tensor_name);
......
......@@ -80,11 +80,11 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32);
if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) {
address->ptr_ = tensor->data_c(false);
address->ptr_ = tensor->data_c();
} else {
address->ptr_ = resource_manager_.MemMalloc(tensor_size);
if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(false))) {
tensor->data_c())) {
MS_LOG(EXCEPTION) << "Value node sync host to device failed!";
}
}
......@@ -177,7 +177,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k
tensor->set_device_address(address);
need_sync_outputs->emplace_back(tensor);
} else {
address->ptr_ = tensor->data_c(true);
address->ptr_ = tensor->data_c();
address->ref_count_ = INIT_NODE_REF;
(void)bound_addresses->insert(address);
}
......@@ -220,11 +220,11 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
size_t tensor_size =
std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies<size_t>());
if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) {
address->ptr_ = tensor->data_c(false);
address->ptr_ = tensor->data_c();
} else {
address->ptr_ = resource_manager_.MemMalloc(tensor_size);
if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(false))) {
tensor->data_c())) {
MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!";
}
tensor->set_dirty(true);
......
......@@ -390,7 +390,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
tensor->set_device_address(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(false))) {
tensor->data_c())) {
MS_LOG(INFO) << "SyncHostToDevice failed.";
return false;
}
......@@ -407,14 +407,14 @@ void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) {
tensor::TensorPtr loop_count_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
MS_EXCEPTION_IF_NULL(loop_count_tensor);
int32_t *val = nullptr;
val = static_cast<int32_t *>(loop_count_tensor->data_c(true));
val = static_cast<int32_t *>(loop_count_tensor->data_c());
MS_EXCEPTION_IF_NULL(val);
*val = 0;
inputs->push_back(loop_count_tensor);
tensor::TensorPtr iter_loop_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
MS_EXCEPTION_IF_NULL(iter_loop_tensor);
val = static_cast<int32_t *>(iter_loop_tensor->data_c(true));
val = static_cast<int32_t *>(iter_loop_tensor->data_c());
MS_EXCEPTION_IF_NULL(val);
*val = SizeToInt(LongToSize(ConfigManager::GetInstance().iter_num()));
MS_LOG(INFO) << "iter_loop_tensor = " << *val;
......@@ -422,14 +422,14 @@ void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) {
tensor::TensorPtr zero_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
MS_EXCEPTION_IF_NULL(zero_tensor);
val = static_cast<int32_t *>(zero_tensor->data_c(true));
val = static_cast<int32_t *>(zero_tensor->data_c());
MS_EXCEPTION_IF_NULL(val);
*val = 0;
inputs->push_back(zero_tensor);
tensor::TensorPtr one_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
MS_EXCEPTION_IF_NULL(one_tensor);
val = static_cast<int32_t *>(one_tensor->data_c(true));
val = static_cast<int32_t *>(one_tensor->data_c());
MS_EXCEPTION_IF_NULL(val);
*val = 1;
inputs->push_back(one_tensor);
......
......@@ -543,7 +543,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
}
AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
tensor->data_c(false))) {
tensor->data_c())) {
MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
<< AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is "
<< AnfAlgo::GetOutputInferDataType(value_node, output_idx);
......
......@@ -115,7 +115,7 @@ class MetaTensor : public Value {
// order it represents.
//
// return A const vector<int> which represents the shape of the tensor.
std::vector<int> shape() const { return shape_; }
const std::vector<int> &shape() const { return shape_; }
// brief Sets the shape of a tensor.
//
......
此差异已折叠。
......@@ -20,9 +20,7 @@
#include <memory>
#include <string>
#include <vector>
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include <numeric>
#include "Eigen/Core"
#include "device/device_address.h"
......@@ -30,63 +28,8 @@
#include "include/ms_tensor.h"
#include "utils/log_adapter.h"
namespace py = pybind11;
using float16 = Eigen::half;
namespace pybind11 {
namespace detail {
// Similar to enums in `pybind11/numpy.h`. Determined by doing:
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int NPY_FLOAT16 = 23;
template <typename T>
struct npy_scalar_caster {
PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
using Array = array_t<T>;
bool load(handle src, bool convert) {
// Taken from Eigen casters. Permits either scalar dtype or scalar array.
handle type = dtype::of<T>().attr("type");
if (!convert && !isinstance<Array>(src) && !isinstance(src, type)) return false;
Array tmp = Array::ensure(src);
if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
this->value = *tmp.data();
return true;
}
return false;
}
static handle cast(T src, return_value_policy, handle) {
Array tmp({1});
tmp.mutable_at(0) = src;
tmp.resize({});
// You could also just return the array if you want a scalar array.
object scalar = tmp[tuple()];
return scalar.release();
}
};
template <>
struct npy_format_descriptor<float16> {
static constexpr auto name = "float16";
static pybind11::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
return reinterpret_borrow<pybind11::dtype>(ptr);
}
virtual ~npy_format_descriptor<float16>() {}
};
template <>
struct type_caster<float16> : public npy_scalar_caster<float16> {
static constexpr auto name = "float16";
};
} // namespace detail
} // namespace pybind11
using mindspore::device::DeviceAddress;
using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>;
// brief mindspore namespace.
......@@ -98,179 +41,195 @@ namespace mindspore {
//
// A sub namespace in ME to support tensor related definition.
namespace tensor {
// Tensor data interface.
class TensorData {
public:
/// Total number of elements.
virtual ssize_t size() const = 0;
/// Byte size of a single element.
virtual ssize_t itemsize() const = 0;
/// Total number of bytes.
virtual ssize_t nbytes() const = 0;
/// Number of dimensions.
virtual ssize_t ndim() const = 0;
/// Data pointer.
virtual void *data() = 0;
/// Is data equals.
virtual bool equals(const TensorData &other) const = 0;
/// To string.
virtual std::string ToString() const = 0;
};
using TensorDataPtr = std::shared_ptr<TensorData>;
// Tensor entity class
class Tensor : public MetaTensor {
public:
Tensor() = default;
abstract::AbstractBasePtr ToAbstract() override;
// brief Constructor for Python.
// brief Create tensor from another tensor, data is shared.
//
// param tensor [Tensor] The input tensor.
explicit Tensor(const Tensor &tensor);
// brief Create tensor with given data type from another tensor.
//
// param type_ptr [TypePty] Data type of the tensor.
// param py_shape [py::tuple] The shape represented by py::tuple of the tensor.
Tensor(const TypePtr &type_ptr, const py::tuple &shape);
// param tensor [Tensor] The input tensor.
// param data_type [TypeId] The new tensor data type.
Tensor(const Tensor &tensor, TypeId data_type);
// brief Constructor for C++.
// brief Create tensor with the given shared tensor data.
//
// param data_type [TypeId] Data type of the tensor.
// param shape The shape represented by std::vector<int> of the tensor.
// param data The shared tensor data.
Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data);
// brief Create an all zero tensor.
//
// param data_type [TypeId] Data type of the tensor.
// param shape The shape represented by std::vector<int> of the tensor.
Tensor(TypeId data_type, const std::vector<int> &shape);
// brief Constructor for Python.
// brief Create a tensor with input data buffer.
//
// param input [py::array] Data value of the tensor.
// param data_type [TypeId] Data type of the tensor.
explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr);
// param shape The shape represented by std::vector<int> of the tensor.
// param data The input data to be copied into tensor.
// param data_len The length of data in bytes.
Tensor(TypeId data_type, const std::vector<int> &shape, void *data, size_t data_len);
// brief Constructor
// brief Create a tensor with input data buffer and given source data type.
//
// param input [py::list] the data for tensor
// param data_type [TypeId] data type
explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr);
// param data_type [TypeId] Data type of the tensor.
// param shape The shape represented by std::vector<int> of the tensor.
// param data The input data to be copied into tensor.
// param src_data_type The source data type.
Tensor(TypeId data_type, const std::vector<int> &shape, void *data, TypeId src_data_type);
// brief Constructor
// brief Create 1 dimension tensor from an int vector.
//
// param input [py::tuple] the data for tensor
// param input [std::vector<int64_t>] the data for tensor
// param data_type [TypeId] data type
explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr);
explicit Tensor(const std::vector<int64_t> &input, const TypePtr &data_type = nullptr);
// brief Constructor
// brief Create 1 dimension tensor from a float vector.
//
// param input [py::float_] the data for tensor
// param input [std::vector<double>] the data for tensor
// param data_type [TypeId] data type
explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr);
explicit Tensor(const std::vector<double> &input, const TypePtr &data_type = nullptr);
// brief Constructor
// brief Create 0 dimension tensor from an int scalar.
//
// param input [py::int_] the data for tensor
// param input [int64] the data for tensor
// param data_type [TypeId] data type
explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr);
explicit Tensor(int64_t input, const TypePtr &data_type = nullptr);
// brief Constructor
// brief Create 0 dimension tensor from a float scalar.
//
// param input [Tensor] the data for tensor
// param input [double] the data for tensor
// param data_type [TypeId] data type
Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr);
explicit Tensor(double input, const TypePtr &data_type = nullptr);
~Tensor() override = default;
MS_DECLARE_PARENT(Tensor, MetaTensor);
// brief Overloads operator = for Tensor.
//
// The constructed Tensor object has the same type and shape with tensor.
//
// param tensor An existing Tensor object.
Tensor &operator=(const Tensor &tensor);
// brief Compares two Tensor objects.
//
// Compare two tensor objects to see if they have same data type, shape and
// data value.
// Compare two tensor objects to see if they have same data type, shape and data address.
//
// param tensor The Tensor object to be compared.
// return true: If having same type, shape and data, return true, or return false.
// return true: If having same type, shape and data address, return true, or return false.
bool operator==(const Tensor &tensor) const;
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqual(const Tensor &other) const;
// assgin value to this tensor
Tensor &AssignValue(const Tensor &tensor);
// It is different from 'operator==' which just compare shape/type/address,
// it do real value comparison.
bool ValueEqual(const Tensor &tensor) const;
bool operator==(const Value &other) const override {
if (other.isa<Tensor>()) {
auto other_ = static_cast<const Tensor &>(other);
auto &other_ = static_cast<const Tensor &>(other);
return *this == other_;
} else {
return false;
}
return false;
}
py::tuple GetPyTupleShape() const;
// brief Gets tensor's dimension
//
// return The number of dimensions of the tensor data.
int DataDim() const;
int DataDim() const { return static_cast<int>(data().ndim()); }
// brief Getting tensor data size
//
// return The total number of elements of the tensor data.
int DataSize() const;
// brief Tensor's data value.
//
// return [py::array] The tensor's data in py::array.
py::array data() const;
int DataSize() const { return static_cast<int>(data().size()); }
// brief Get the data type fo the tensor for C++
//
// return [int] The tensor's data type will be cast to int to return.
int data_type_c() const;
int data_type_c() const { return static_cast<int>(data_type_); }
// brief Get the tensor's shape for C++
//
// return [std::vector<int>]
std::vector<int> shape_c(void) const;
std::vector<int> shape_c(void) const { return shape(); }
// brief Get Tensor data pointer for c++ type
//
// param writable true if writable, false if read only
// return The pointer to the object
void *data_c(bool writable = false);
void *data_c() { return data().data(); }
// brief Get Tensor data byte-size for c++ type
//
// return byte size of Tensor data
size_t Size() const { return this->data().nbytes(); }
size_t Size() const { return data().nbytes(); }
// brief Get data type from tensor data.
void *data_c() const { return data_->data(); }
// brief Sync data with device.
void data_sync() const;
// brief Get the internal data object.
//
// param buf The buffer info of the py::array data.
// return The [TypeId] of the tensor data.
TypeId GetDataType(const py::buffer_info &buf) const;
// return The reference to internal data object.
TensorData &data() { return *data_; }
// brief Sets the data type of a tensor.
// brief Get the internal data shared pointer.
//
// param data_type The data type of the tensor to be set.
// return The reference to internal data object.
const TensorDataPtr &data_ptr() const { return data_; }
// brief Get the internal data object.
//
// return The reference to internal data object.
const TensorData &data() const { return *data_; }
TypeId set_data_type(const TypeId data_type) override;
TypePtr SetDtype(const TypePtr type_ptr) override;
std::string GetShapeAndDataTypeInfo() const;
std::string ToString() const override;
std::string ToStringRepr() const;
py::array data_; // < Tensor's data value
const bool parse_info_ = true;
bool is_init();
void set_init_flag(bool flag);
private:
// brief init tensor
//
// param input [py::array] the data for tensor
// param data_type [TypeId] data type
// return true if succeed, false if failed.
void init(const py::array &input, const TypeId &data_type);
void init(const py::array &input, const TypePtr &type_ptr);
bool init_flag_{false};
// brief init tensor attribute
//
// param data_type [TypeId] Data type of the tensor.
// param shape [py::array] The shape of the tensor.
// return true if succeed, false if failed.
void init(TypeId data_type, const std::vector<int> &shape, py::array *data);
std::string ToStringRepr() const;
bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type);
bool is_init() { return init_flag_; }
void set_init_flag(bool flag) { init_flag_ = flag; }
public:
bool is_dirty() const { return dirty_; }
void set_dirty(const bool dirty) { dirty_ = dirty; }
DeviceAddressPtr device_address() const { return device_address_; }
void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
py::array data_sync();
std::string id() const { return id_; }
const bool parse_info_ = true;
private:
bool init_flag_{false};
TensorDataPtr data_{nullptr};
bool dirty_{true};
std::string id_{""};
DeviceAddressPtr device_address_{nullptr};
......@@ -282,8 +241,6 @@ using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
namespace inference {
class Tensor : public MSTensor {
public:
Tensor();
Tensor(TypeId data_type, const std::vector<int> &shape);
explicit Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr);
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/tensor_py.h"
#include <functional>
#include <numeric>
#include <vector>
#include <sstream>
#include <string>
#include "device/device_address.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include "pipeline/static_analysis/abstract_value.h"
namespace mindspore {
namespace tensor {
static TypeId GetDataType(const py::buffer_info &buf) {
if (buf.format.size() == 1) {
switch (buf.format.front()) {
case 'e':
case 'f':
case 'd':
switch (buf.itemsize) {
case 2:
return TypeId::kNumberTypeFloat16;
case 4:
return TypeId::kNumberTypeFloat32;
case 8:
return TypeId::kNumberTypeFloat64;
}
break;
case 'b':
case 'h':
case 'i':
case 'l':
case 'q':
switch (buf.itemsize) {
case 1:
return TypeId::kNumberTypeInt8;
case 2:
return TypeId::kNumberTypeInt16;
case 4:
return TypeId::kNumberTypeInt32;
case 8:
return TypeId::kNumberTypeInt64;
}
break;
case 'B':
case 'H':
case 'I':
case 'L':
case 'Q':
switch (buf.itemsize) {
case 1:
return TypeId::kNumberTypeUInt8;
case 2:
return TypeId::kNumberTypeUInt16;
case 4:
return TypeId::kNumberTypeUInt32;
case 8:
return TypeId::kNumberTypeUInt64;
}
break;
case '?':
return TypeId::kNumberTypeBool;
}
}
MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize;
return TypeId::kTypeUnknown;
}
static std::string GetPyTypeFormat(TypeId data_type) {
switch (data_type) {
case TypeId::kNumberTypeFloat16:
return "e";
case TypeId::kNumberTypeFloat32:
return py::format_descriptor<float>::format();
case TypeId::kNumberTypeFloat64:
return py::format_descriptor<double>::format();
case TypeId::kNumberTypeUInt8:
return py::format_descriptor<uint8_t>::format();
case TypeId::kNumberTypeUInt16:
return py::format_descriptor<uint16_t>::format();
case TypeId::kNumberTypeUInt32:
return py::format_descriptor<uint32_t>::format();
case TypeId::kNumberTypeUInt64:
return py::format_descriptor<uint64_t>::format();
case TypeId::kNumberTypeInt8:
return py::format_descriptor<int8_t>::format();
case TypeId::kNumberTypeInt16:
return py::format_descriptor<int16_t>::format();
case TypeId::kNumberTypeInt32:
return py::format_descriptor<int32_t>::format();
case TypeId::kNumberTypeInt64:
return py::format_descriptor<int64_t>::format();
case TypeId::kNumberTypeBool:
return py::format_descriptor<bool>::format();
default:
MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
return "";
}
}
static bool IsCContiguous(const py::array &input) {
auto flags = static_cast<unsigned int>(input.flags());
return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0;
}
TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
// Get input buffer info.
py::buffer_info buf = input.request();
// Check data types.
auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown;
auto buf_type = GetDataType(buf);
if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) {
MS_LOG(EXCEPTION) << "Unsupported tensor type!";
}
// Use buf type as data type if type_ptr not set.
if (data_type == TypeId::kTypeUnknown) {
data_type = buf_type;
}
// Convert input array to C contiguous if need.
std::unique_ptr<char[]> tmp_buf;
if (!IsCContiguous(input)) {
Py_buffer pybuf;
if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS)) {
MS_LOG(EXCEPTION) << "Failed to get buffer from the input!";
}
tmp_buf = std::make_unique<char[]>(pybuf.len);
if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C')) {
MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer.";
}
PyBuffer_Release(&pybuf);
buf.ptr = tmp_buf.get();
}
// Get tensor shape.
std::vector<int> shape(buf.shape.begin(), buf.shape.end());
if (data_type == buf_type) {
// Use memory copy if input data type is same as the required type.
return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf.size * buf.itemsize);
}
// Create tensor with data type converted.
return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf_type);
}
static std::vector<ssize_t> GetStrides(const std::vector<ssize_t> &shape, ssize_t item_size) {
std::vector<ssize_t> strides;
strides.reserve(shape.size());
const auto ndim = shape.size();
for (size_t i = 0; i < ndim; ++i) {
auto stride = item_size;
for (size_t j = i + 1; j < ndim; ++j) {
stride *= shape[j];
}
strides.push_back(stride);
}
return strides;
}
static py::buffer_info GetPyBufferInfo(const Tensor &tensor) {
std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end());
std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize());
return py::buffer_info{
tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides};
}
py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
auto &shape = tensor.shape();
py::tuple dims(shape.size());
for (size_t i = 0; i < dims.size(); ++i) {
dims[i] = py::int_(shape[i]);
}
return dims;
}
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
tensor.data_sync();
auto info = GetPyBufferInfo(tensor);
py::object self = py::cast(&tensor);
return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
}
py::array TensorPy::AsNumpy(const Tensor &tensor) {
auto info = GetPyBufferInfo(tensor);
py::object self = py::cast(&tensor);
return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
}
static std::vector<int> GetShapeFromTuple(const py::tuple &tuple) {
std::vector<int> shape;
const size_t size = tuple.size();
shape.reserve(tuple.size());
for (size_t i = 0; i < size; ++i) {
shape.push_back(py::int_(tuple[i]));
}
return shape;
}
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
// Define python Tensor class.
// dtype should define before Tensor, because Tensor init depend dtype
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor")
.def(py::init([](const Tensor &tensor) { return std::make_shared<Tensor>(tensor); }),
py::arg("input"))
.def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) {
TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown;
if (data_type == kTypeUnknown || tensor.data_type() == data_type) {
return std::make_shared<Tensor>(tensor);
}
return std::make_shared<Tensor>(tensor, data_type);
}),
py::arg("input"), py::arg("dtype"))
.def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) {
auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
}),
py::arg("dtype"), py::arg("shape"))
.def(py::init([](const py::array &input, const TypePtr &type_ptr) {
return TensorPy::MakeTensor(input, type_ptr);
}),
py::arg("input"), py::arg("dtype") = nullptr)
.def(py::init([](py::float_ input, const TypePtr &type_ptr) {
return TensorPy::MakeTensor(py::array(input), type_ptr);
}),
py::arg("input"), py::arg("dtype") = nullptr)
.def(py::init([](py::int_ input, const TypePtr &type_ptr) {
return TensorPy::MakeTensor(py::array(input), type_ptr);
}),
py::arg("input"), py::arg("dtype") = nullptr)
.def(py::init([](py::list input, const TypePtr &type_ptr) {
return TensorPy::MakeTensor(py::array(input), type_ptr);
}),
py::arg("input"), py::arg("dtype") = nullptr)
.def(py::init([](py::tuple input, const TypePtr &type_ptr) {
return TensorPy::MakeTensor(py::array(input), type_ptr);
}),
py::arg("input"), py::arg("dtype") = nullptr)
.def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_)
.def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter(
Get the tensor's data type.
Returns:
type, the data type of tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
>>> data.dtype
Int32
)mydelimiter")
.def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter(
Get the tensor's shape.
Returns:
tuple[int], the shape of tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((3, 3)))
>>> data.shape()
(3, 3)
)mydelimiter")
.def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter(
Convert tensor to numpy.ndarray.
Returns:
numpy.ndarray.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> array = data.asnumpy()
>>> array
array([[1., 1., 1.],
[1., 1., 1.]])
)mydelimiter")
.def("size", &Tensor::DataSize, R"mydelimiter(
Get tensor's data size.
Returns:
int, the size of tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data.size()
6
)mydelimiter")
.def("is_init", &Tensor::is_init, R"mydelimiter(
Get tensor init_flag.
Returns:
bool, whether the tensor init.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data.is_init()
False
)mydelimiter")
.def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter(
Set tensor init_flag.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data.set_init_flag(True)
)mydelimiter")
.def("dim", &Tensor::DataDim, R"mydelimiter(
Get tensor's data dimension.
Returns:
int, the dimension of tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data.dim()
2
)mydelimiter")
.def("set_dtype", &Tensor::SetDtype, R"mydelimiter(
Set the tensor's data type.
Arg:
dtype (:class:`mindspore.dtype`): The type of output tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
>>> data.set_dtype(mindspore.int32)
mindspore.int32
)mydelimiter")
.def("__str__", &Tensor::ToString)
.def("__repr__", &Tensor::ToStringRepr)
.def(py::pickle(
[](const Tensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(TensorPy::AsNumpy(t));
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
return TensorPy::MakeTensor(t[0].cast<py::array>());
}));
// Define python MetaTensor class.
(void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
.def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
.def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
.def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
.def(py::pickle(
[](const MetaTensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 2) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
return tensor;
}));
}));
} // namespace tensor
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_TENSOR_PY_H_
#define MINDSPORE_CCSRC_IR_TENSOR_PY_H_
#include <memory>
#include <string>
#include <vector>
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "ir/tensor.h"
namespace py = pybind11;
namespace pybind11 {
namespace detail {
// Similar to enums in `pybind11/numpy.h`. Determined by doing:
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int NPY_FLOAT16 = 23;
template <typename T>
struct npy_scalar_caster {
PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
using Array = array_t<T>;
bool load(handle src, bool convert) {
// Taken from Eigen casters. Permits either scalar dtype or scalar array.
handle type = dtype::of<T>().attr("type");
if (!convert && !isinstance<Array>(src) && !isinstance(src, type)) return false;
Array tmp = Array::ensure(src);
if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
this->value = *tmp.data();
return true;
}
return false;
}
static handle cast(T src, return_value_policy, handle) {
Array tmp({1});
tmp.mutable_at(0) = src;
tmp.resize({});
// You could also just return the array if you want a scalar array.
object scalar = tmp[tuple()];
return scalar.release();
}
};
template <>
struct npy_format_descriptor<float16> {
static constexpr auto name = "float16";
static pybind11::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
return reinterpret_borrow<pybind11::dtype>(ptr);
}
virtual ~npy_format_descriptor<float16>() {}
};
template <>
struct type_caster<float16> : public npy_scalar_caster<float16> {
static constexpr auto name = "float16";
};
} // namespace detail
} // namespace pybind11
using mindspore::device::DeviceAddress;
using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>;
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of Mindsporeession project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore {
// brief mindspore::tensor namespace
//
// A sub namespace in ME to support tensor related definition.
namespace tensor {
// Tensor python wrapper and adapter class.
class TensorPy {
public:
// brief Create Tensor from a numpy array object.
//
// param input [py::array] Data value of the tensor.
// param data_type [TypeId] Data type of the tensor.
static TensorPtr MakeTensor(const py::array &input, const TypePtr &data_type = nullptr);
static py::array SyncAsNumpy(const Tensor &tensor);
static py::array AsNumpy(const Tensor &tensor);
static py::tuple GetPyTupleShape(const Tensor &tensor);
};
} // namespace tensor
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_TENSOR_PY_H_
......@@ -23,6 +23,7 @@
#include <algorithm>
#include <functional>
#include "ir/tensor_py.h"
#include "ir/param_value_py.h"
#include "debug/anf_ir_utils.h"
#include "operator/ops.h"
......@@ -257,7 +258,7 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
auto data = value->cast<tensor::TensorPtr>();
tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast<size_t>(data->data().nbytes()));
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type();
auto shape = data->shape_c();
tensor_proto->set_data_type(GetOnnxDataType(dtype));
......
......@@ -27,6 +27,7 @@
#include "proto/onnx.pb.h"
#include "operator/ops.h"
#include "ir/param_value_py.h"
#include "ir/tensor_py.h"
namespace mindspore {
enum OpMergeMode {
......@@ -1190,7 +1191,7 @@ void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *cons
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
auto data = dyn_cast<tensor::Tensor>(value);
tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast<size_t>(data->data().nbytes()));
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type();
auto shape = data->shape_c();
......
......@@ -21,6 +21,9 @@
#include "pipeline/static_analysis/param_validator.h"
#include "operator/ops.h"
#include "utils/convert_utils.h"
#include "ir/tensor_py.h"
using mindspore::tensor::TensorPy;
namespace mindspore {
namespace abstract {
......@@ -554,7 +557,7 @@ AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitiveP
py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
py::array data = py::array(data_tuple);
auto tensor = std::make_shared<tensor::Tensor>(data);
auto tensor = TensorPy::MakeTensor(data);
auto ret = tensor->ToAbstract();
ret->set_value(tensor);
MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
......
......@@ -153,7 +153,7 @@ class TensorMultiplyBase : public AnfVisitor {
}
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
return tensor_ptr->data_c(writable);
return tensor_ptr->data_c();
}
// Make a new tensor (when possible) with the same shape as of `node`
......@@ -171,7 +171,7 @@ class TensorMultiplyBase : public AnfVisitor {
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
if (x == nullptr) {
std::memset(data, 0, mem_size);
......@@ -546,7 +546,7 @@ class ConstantDuplicateMul : public AnfVisitor {
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
memcpy(data, data_out, mem_size);
auto new_vnode = NewValueNode(new_tensor_ptr);
......
......@@ -191,7 +191,7 @@ inline void ResetSharedOp() {
tensor::TensorPtr ConstData() {
std::vector<int> shp = {1};
tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
auto *val = static_cast<int32_t *>(const_data->data_c(true));
auto *val = static_cast<int32_t *>(const_data->data_c());
*val = 0;
return const_data;
}
......@@ -267,7 +267,7 @@ CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNod
auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>();
std::vector<int> shp = {1};
tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
auto *val = static_cast<int32_t *>(const_data->data_c(true));
auto *val = static_cast<int32_t *>(const_data->data_c());
*val = 0;
// for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same
// switch the other use the opposite
......
......@@ -178,7 +178,7 @@ class ZeroLikeFillZero : public AnfVisitor {
tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
(void)memset_s(data, mem_size, 0, mem_size);
auto new_cnode = NewValueNode(new_tensor_ptr);
......
......@@ -71,7 +71,7 @@ class SpecializeTransform {
continue;
}
if (value_args[i] != nullptr) {
auto const_tensor = *value_args[i];
auto &const_tensor = *value_args[i];
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
AnfNodePtr arg = NewValueNode(const_tensor_ptr);
(void)mng->Replace(params[i], arg);
......
......@@ -210,8 +210,8 @@ OperatorVector CreateSubOp(int32_t sub_value) {
OperatorName operator_name = SUB;
OperatorAttrs operator_attrs;
py::tuple tuple = py::make_tuple(sub_value);
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tuple, kInt32);
std::vector<int64_t> tensor_data = {sub_value};
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tensor_data, kInt32);
ValuePtr op_param_value = MakeValue(tensor_ptr);
Attr op1_param = std::make_pair("", op_param_value);
......
......@@ -204,8 +204,8 @@ ForwardOp CreatReduceMeanForwardOp(const std::vector<Group> &forward_group, cons
OperatorName operator1_name = REAL_DIV;
std::vector<Device> device_list = forward_group[0].GetDevicesList();
auto divisor = static_cast<float>(device_list.size());
py::tuple tuple = py::make_tuple(divisor);
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tuple, dtype);
std::vector<double> tensor_data = {divisor};
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tensor_data, dtype);
ValuePtr op1_param_value = MakeValue(tensor_ptr);
Attr op1_param = std::make_pair("divisor", op1_param_value);
OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
......
......@@ -156,11 +156,11 @@ void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors)
if (py::isinstance<py::float_>(item.second.attr("default_input"))) {
// convert float to tensor with shape([1])
tensor = std::make_shared<Tensor>(kNumberTypeFloat32, std::vector<int>({1}));
*(static_cast<float *>(tensor->data_c(true))) = py::cast<float>(item.second.attr("default_input"));
*(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("default_input"));
} else if (py::isinstance<py::int_>(item.second.attr("default_input"))) {
// convert int to tensor with shape([1])
tensor = std::make_shared<Tensor>(kNumberTypeInt32, std::vector<int>({1}));
*(static_cast<float *>(tensor->data_c(true))) = py::cast<float>(item.second.attr("default_input"));
*(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("default_input"));
} else if (py::hasattr(item.second.attr("default_input"), PYTHON_TENSOR_FLAG)) {
// cast tensor
tensor = py::cast<std::shared_ptr<Tensor>>(item.second.attr("default_input"));
......@@ -330,7 +330,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::t
MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString();
}
auto shape_me = shape->cast<abstract::ShapePtr>()->shape();
auto shape_ge = py::cast<Tensor>(data[*count]).shape();
auto shape_ge = py::cast<Tensor &>(data[*count]).shape();
if (shape_ge != shape_me) {
MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge
<< " is not the same as the shape of the tensor derived: " << shape_me;
......
......@@ -44,7 +44,7 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
indices_tensor->set_device_info(device_info);
// 2 set value of tensor
auto data_ptr = indices_tensor->data_c(true);
auto data_ptr = indices_tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
std::vector<Eigen::half> half_data;
for (size_t i = 0; i < last_dim; ++i) {
......
......@@ -348,7 +348,7 @@ tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_pt
MS_EXCEPTION_IF_NULL(tensor);
tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
tensor->set_device_info(device_info);
auto data_ptr = tensor->data_c(true);
auto data_ptr = tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
auto elem_num = values.size() * data_length;
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
......
......@@ -538,7 +538,7 @@ bool Kernel2Ms::KernelInput2MS(const std::vector<TensorPtr> &input_tensors) {
auto match_idx = match_to_rel_idxs[j];
auto real_tensor = input_tensors[match_idx];
auto real_size = LongToSize(real_tensor->data().nbytes());
auto real_data = real_tensor->data_c(false);
auto real_data = real_tensor->data_c();
MS_EXCEPTION_IF_NULL(real_data);
if (sub_ms_graph_->allTensors[cache_idx] != nullptr) {
sub_ms_graph_->allTensors[cache_idx]->data.resize(real_size);
......
......@@ -22,6 +22,7 @@
#include <unordered_set>
#include <algorithm>
#include "ir/tensor_py.h"
#include "ir/param_value_py.h"
#include "utils/any.h"
#include "utils/utils.h"
......@@ -51,6 +52,8 @@
#include "pynative/pynative_execute_ge.h"
#endif
using mindspore::tensor::TensorPy;
const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
......@@ -171,7 +174,8 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
(*out_args_list)[i] = py_args[i];
} else {
py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
double arg_value = py::cast<py::float_>(py_args[i]);
py_args[i] = std::make_shared<tensor::Tensor>(arg_value, tensor_ptr->Dtype());
(*out_args_list)[i] = py_args[i];
}
continue;
......@@ -262,7 +266,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
result[i] = py::getattr(input, "data");
} else {
auto tensor = py::cast<tensor::TensorPtr>(op_inputs[i]);
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data());
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
result[i] = new_tensor;
}
}
......@@ -366,13 +370,14 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr
if (py::isinstance<tensor::Tensor>(input_object)) {
tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
} else if (py::isinstance<py::float_>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32);
double input_value = py::cast<py::float_>(input_object);
tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
*tensor_mask = kValueNodeTensorMask;
} else if (py::isinstance<py::int_>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
*tensor_mask = kValueNodeTensorMask;
} else if (py::isinstance<py::array>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
} else if (py::isinstance<py::list>(input_object)) {
auto list_inputs = py::cast<py::list>(input_object);
py::tuple tuple_inputs(list_inputs.size());
......
......@@ -26,6 +26,7 @@
#include <stack>
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "pynative/base.h"
#include "utils/context/ms_context.h"
......
......@@ -28,9 +28,12 @@
#include "pipeline/parse/data_converter.h"
#include "pipeline/static_analysis/prim.h"
#include "session/session_factory.h"
#include "ir/tensor_py.h"
const char SINGLE_OP_GRAPH[] = "single_op_graph";
using mindspore::tensor::TensorPy;
namespace mindspore {
namespace pynative {
using MeTensor = mindspore::tensor::Tensor;
......@@ -56,15 +59,15 @@ MeTensorPtr ConvertPyObjToTensor(const py::object &obj) {
if (py::isinstance<MeTensor>(obj)) {
me_tensor_ptr = py::cast<MeTensorPtr>(obj);
} else if (py::isinstance<py::tuple>(obj)) {
me_tensor_ptr = std::make_shared<MeTensor>(py::cast<py::tuple>(obj), nullptr);
me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast<py::tuple>(obj)), nullptr);
} else if (py::isinstance<py::float_>(obj)) {
me_tensor_ptr = std::make_shared<MeTensor>(py::cast<py::float_>(obj), nullptr);
me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast<py::float_>(obj)), nullptr);
} else if (py::isinstance<py::int_>(obj)) {
me_tensor_ptr = std::make_shared<MeTensor>(py::cast<py::int_>(obj), nullptr);
me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast<py::int_>(obj)), nullptr);
} else if (py::isinstance<py::list>(obj)) {
me_tensor_ptr = std::make_shared<MeTensor>(py::cast<py::list>(obj), nullptr);
me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast<py::list>(obj)), nullptr);
} else if (py::isinstance<py::array>(obj)) {
me_tensor_ptr = std::make_shared<MeTensor>(py::cast<py::array>(obj), nullptr);
me_tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(obj), nullptr);
} else {
MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
}
......
......@@ -16,6 +16,7 @@
#include "session/ascend_inference_session.h"
#include "operator/ops.h"
#include "ir/tensor.h"
#include "ir/tensor_py.h"
#include "ir/anf.h"
#include "ir/param_value_py.h"
#include "device/kernel_runtime.h"
......@@ -26,6 +27,8 @@
#include "utils/config_manager.h"
#include "utils/base_ref_extends.h"
using mindspore::tensor::TensorPy;
namespace mindspore {
namespace session {
void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
......@@ -51,7 +54,7 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
auto py_param = param_value->value();
MS_EXCEPTION_IF_NULL(py_param);
py::array py_array = py_param.cast<py::array>();
tensor = std::make_shared<tensor::Tensor>(py_array);
tensor = TensorPy::MakeTensor(py_array);
} else {
tensor = inputs[no_weight_input++];
}
......@@ -78,7 +81,7 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(false))) {
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
......
......@@ -989,7 +989,7 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
MS_EXCEPTION_IF_NULL(condition_graph);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int>{1});
int32_t *val = nullptr;
val = static_cast<int32_t *>(tensor->data_c(true));
val = static_cast<int32_t *>(tensor->data_c());
MS_EXCEPTION_IF_NULL(val);
*val = 0;
auto value_node = std::make_shared<ValueNode>(tensor);
......@@ -1523,7 +1523,7 @@ void AscendSession::SyncInitialTenosrToDevice() {
auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
MS_EXCEPTION_IF_NULL(addr);
if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
front_tensor->data_type(), front_tensor->data_c(false))) {
front_tensor->data_type(), front_tensor->data_c())) {
MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
}
}
......
......@@ -129,7 +129,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(false))) {
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
......
......@@ -96,8 +96,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
tensor->set_dirty(false);
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(true))) {
LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
MS_LOG(INFO) << "output sync device to host error!!!";
tensor->set_dirty(false);
}
......@@ -218,7 +217,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
}
auto tensor = (*inputs_params)[0];
MS_EXCEPTION_IF_NULL(tensor);
auto *val = static_cast<int32_t *>(tensor->data_c(true));
auto *val = static_cast<int32_t *>(tensor->data_c());
MS_EXCEPTION_IF_NULL(val);
*val = 0;
tensor->set_dirty(true);
......@@ -720,7 +719,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(false))) {
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
......@@ -815,7 +814,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
continue;
}
if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
tensor->data_type(), tensor->data_c(true))) {
tensor->data_type(), tensor->data_c())) {
MS_LOG(ERROR) << "Failed to sync output from device to host.";
}
tensor->set_dirty(false);
......
......@@ -342,7 +342,7 @@ MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const
MeTensor me_tensor(me_type, me_dims);
// Get the writable data pointer of the tensor and cast it to its data type
auto me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c(true));
auto me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c());
size_t me_data_size = static_cast<size_t>(me_tensor.data().nbytes());
MS_EXCEPTION_IF_NULL(me_data_ptr);
MS_EXCEPTION_IF_NULL(ge_tensor);
......
......@@ -579,11 +579,12 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
}
tensor::TensorPtr tensor = nullptr;
if (scalar->isa<FloatImm>()) {
tensor = std::make_shared<tensor::Tensor>(py::float_(GetValue<float>(scalar)), kFloat32);
tensor = std::make_shared<tensor::Tensor>(static_cast<double>(GetValue<float>(scalar)), kFloat32);
} else if (scalar->isa<IntergerImm>()) {
tensor = std::make_shared<tensor::Tensor>(py::int_(GetValue<int>(scalar)), kInt32);
tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int>(scalar)), kInt32);
} else if (scalar->isa<BoolImm>()) {
tensor = std::make_shared<tensor::Tensor>(py::array(py::bool_(GetValue<bool>(scalar))), kBool);
const int64_t bool_value = GetValue<bool>(scalar) ? 1 : 0;
tensor = std::make_shared<tensor::Tensor>(bool_value, kBool);
} else {
auto type = scalar->type();
auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
......
......@@ -22,12 +22,14 @@
#include <vector>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "ir/tensor.h"
#include "ir/tensor_py.h"
#include "ir/param_value_py.h"
#include "operator/ops.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "proto/onnx.pb.h"
#include "utils/log_adapter.h"
using mindspore::tensor::TensorPy;
using std::string;
namespace mindspore {
......@@ -117,11 +119,11 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) {
const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()];
std::string initial_data = initialize_proto.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c(true));
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
MS_EXCEPTION_IF_NULL(tensor_data_buf);
memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size());
py::array array_data = tensor_info->data();
py::array array_data = TensorPy::AsNumpy(*tensor_info);
ParamValuePyPtr para_value_ptr = std::make_shared<ParamValuePy>();
MS_EXCEPTION_IF_NULL(para_value_ptr);
para_value_ptr->set_value(array_data);
......@@ -249,7 +251,7 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c(true));
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
auto new_value_node = NewValueNode(MakeValue(tensor_info));
MS_EXCEPTION_IF_NULL(new_value_node);
......
......@@ -87,7 +87,7 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
const size_t &memory_size) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(print_tensor);
auto *tensor_data_ptr = static_cast<uint8_t *>(print_tensor->data_c(true));
auto *tensor_data_ptr = static_cast<uint8_t *>(print_tensor->data_c());
MS_EXCEPTION_IF_NULL(tensor_data_ptr);
auto cp_ret =
memcpy_s(tensor_data_ptr, static_cast<size_t>(print_tensor->data().nbytes()), str_data_ptr, memory_size);
......
......@@ -61,9 +61,9 @@ class Tensor(Tensor_):
if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
input_data = np.ascontiguousarray(input_data)
if dtype is None:
super(Tensor, self).__init__(input_data)
Tensor_.__init__(self, input_data)
else:
super(Tensor, self).__init__(input_data, dtype)
Tensor_.__init__(self, input_data, dtype)
self._virtual_flag = False
self._init_flag = False
......
......@@ -55,6 +55,7 @@ def rmsprop_numpy(variable, gradients, mean_square, moment,
mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients
moment = momentum * moment + learning_rate / np.sqrt(mean_square + epsilon) * gradients
variable = variable - moment
return variable, gradients, mean_square, moment
def rmspropcented_numpy(variable, gradients, mean_gradients, mean_square, moment,
......@@ -64,7 +65,7 @@ def rmspropcented_numpy(variable, gradients, mean_gradients, mean_square, moment
moment = momentum * moment + learning_rate / np.sqrt(
mean_square - mean_gradients * mean_gradients + epsilon) * gradients
variable = variable - moment
return variable, gradients, mean_gradients, mean_square, moment
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
......@@ -85,12 +86,14 @@ def test_rmsprop():
moment_ms = Tensor(moment_np)
if centered:
variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np = \
rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
learning_rate, decay, momentum, epsilon)
net = NetCenteredRMSProp(learning_rate, decay, momentum, epsilon)
_ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms)
else:
variable_np, gradients_np, mean_square_np, moment_np = \
rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
learning_rate, decay, momentum, epsilon)
net = NetRMSProp(learning_rate, decay, momentum, epsilon)
......@@ -136,11 +139,13 @@ def test_rmspropcenter():
moment_ms = Tensor(moment_np)
if centered:
variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np = \
rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
learning_rate, decay, momentum, epsilon)
net = NetCenteredRMSProp(learning_rate, decay, momentum, epsilon)
_ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms)
else:
variable_np, gradients_np, mean_square_np, moment_np = \
rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
learning_rate, decay, momentum, epsilon)
net = NetRMSProp(learning_rate, decay, momentum, epsilon)
......
......@@ -22,6 +22,9 @@
#include "securec/include/securec.h"
#include "ir/tensor.h"
#include "ir/tensor_py.h"
using mindspore::tensor::TensorPy;
namespace mindspore {
namespace tensor {
......@@ -90,9 +93,7 @@ TEST_F(TestMetaTensor, EqualTest) {
class TestTensor : public UT::Common {
public:
TestTensor() {}
virtual void SetUp() {
UT::InitPythonPath();
}
virtual void SetUp() { UT::InitPythonPath(); }
};
py::array_t<float, py::array::c_style> BuildInputTensor() {
......@@ -124,7 +125,7 @@ TEST_F(TestTensor, PyArrayScalarTest) {
TEST_F(TestTensor, InitScalarTest) {
std::vector<int> dimensions;
Tensor tensor(TypeId::kNumberTypeInt64, dimensions);
uint8_t *data_buf = reinterpret_cast<uint8_t *>(tensor.data_c(true));
uint8_t *data_buf = reinterpret_cast<uint8_t *>(tensor.data_c());
int64_t num = 1;
errno_t ret = memcpy_s(data_buf, sizeof(int64_t), &num, sizeof(int64_t));
......@@ -172,9 +173,9 @@ TEST_F(TestTensor, InitTensorPtrTest) {
}
TEST_F(TestTensor, InitByTupleTest) {
py::tuple dimensions = py::make_tuple(2, 3, 4);
const std::vector<int> shape = {2, 3, 4};
TypePtr data_type = kFloat32;
Tensor tuple_tensor = Tensor(data_type, dimensions);
Tensor tuple_tensor(data_type->type_id(), shape);
ASSERT_EQ(2, tuple_tensor.DimensionSize(0));
ASSERT_EQ(3, tuple_tensor.DimensionSize(1));
ASSERT_EQ(4, tuple_tensor.DimensionSize(2));
......@@ -184,8 +185,8 @@ TEST_F(TestTensor, InitByTupleTest) {
ASSERT_EQ(TypeId::kNumberTypeFloat32, tuple_tensor.data_type());
py::tuple tuple = py::make_tuple(1.0, 2.0, 3, 4, 5, 6);
TensorPtr tensor = std::make_shared<Tensor>(tuple, kFloat64);
py::array array = tensor->data();
TensorPtr tensor = TensorPy::MakeTensor(py::array(tuple), kFloat64);
py::array array = TensorPy::AsNumpy(*tensor);
std::cout << "Dim: " << array.ndim() << std::endl;
ASSERT_EQ(1, array.ndim());
......@@ -203,24 +204,24 @@ TEST_F(TestTensor, InitByTupleTest) {
TEST_F(TestTensor, EqualTest) {
py::tuple tuple = py::make_tuple(1, 2, 3, 4, 5, 6);
TensorPtr tensor_int8 = std::make_shared<Tensor>(tuple, kInt8);
TensorPtr tensor_int8 = TensorPy::MakeTensor(py::array(tuple), kInt8);
ASSERT_TRUE(*tensor_int8 == *tensor_int8);
ASSERT_EQ(TypeId::kNumberTypeInt8, tensor_int8->data_type_c());
TensorPtr tensor_int16 = std::make_shared<Tensor>(tuple, kInt16);
TensorPtr tensor_int16 = TensorPy::MakeTensor(py::array(tuple), kInt16);
ASSERT_EQ(TypeId::kNumberTypeInt16, tensor_int16->data_type_c());
TensorPtr tensor_int32 = std::make_shared<Tensor>(tuple, kInt32);
TensorPtr tensor_int32 = TensorPy::MakeTensor(py::array(tuple), kInt32);
ASSERT_EQ(TypeId::kNumberTypeInt32, tensor_int32->data_type_c());
TensorPtr tensor_float16 = std::make_shared<Tensor>(tuple, kFloat16);
TensorPtr tensor_float16 = TensorPy::MakeTensor(py::array(tuple), kFloat16);
ASSERT_EQ(TypeId::kNumberTypeFloat16, tensor_float16->data_type_c());
TensorPtr tensor_float32 = std::make_shared<Tensor>(tuple, kFloat32);
TensorPtr tensor_float32 = TensorPy::MakeTensor(py::array(tuple), kFloat32);
ASSERT_EQ(TypeId::kNumberTypeFloat32, tensor_float32->data_type_c());
TensorPtr tensor_float64 = std::make_shared<Tensor>(tuple, kFloat64);
TensorPtr tensor_float64 = TensorPy::MakeTensor(py::array(tuple), kFloat64);
ASSERT_EQ(TypeId::kNumberTypeFloat64, tensor_float64->data_type_c());
}
......@@ -247,7 +248,7 @@ TEST_F(TestTensor, PyArrayTest) {
TEST_F(TestTensor, InitByFloatArrayDataCTest) {
// Init tensor data by py::array_t<float>
auto tensor = std::make_shared<Tensor>(BuildInputTensor());
auto tensor = TensorPy::MakeTensor(BuildInputTensor());
// Print some information of the tensor
std::cout << "Datatype: " << tensor->data_type() << std::endl;
......@@ -269,7 +270,7 @@ TEST_F(TestTensor, InitByFloatArrayDataCTest) {
TEST_F(TestTensor, InitByFloatArrayDataTest) {
// Init tensor data by py::array_t<float>
TensorPtr tensor = std::make_shared<Tensor>(BuildInputTensor());
TensorPtr tensor = TensorPy::MakeTensor(BuildInputTensor());
// Print some information of the tensor
std::cout << "Datatype: " << tensor->data_type() << std::endl;
......@@ -291,7 +292,7 @@ TEST_F(TestTensor, InitByFloatArrayDataTest) {
// Print each elements
std::cout << "Elements: " << std::endl;
py::array_t<float> data = (py::array_t<float>)tensor->data();
py::array_t<float> data = py::cast<py::array_t<float>>(TensorPy::AsNumpy(*tensor));
auto array = data.unchecked<2>();
for (int i = 0; i < array.shape(0); i++) {
for (int j = 0; j < array.shape(1); j++) {
......@@ -319,17 +320,17 @@ TEST_F(TestTensor, TensorDataTest) {
float ge_tensor_data[] = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6};
// Create a Tensor with wanted data type and shape
Tensor tensor = Tensor(TypeId::kNumberTypeFloat32, std::vector<int>({2, 3}));
Tensor tensor(TypeId::kNumberTypeFloat32, std::vector<int>({2, 3}));
// Get the writable data pointer from the tensor
float *me_tensor_data = reinterpret_cast<float *>(tensor.data_c(true));
float *me_tensor_data = reinterpret_cast<float *>(tensor.data_c());
// Copy data from buffer to tensor's data
errno_t ret = memcpy_s(me_tensor_data, tensor.data().nbytes(), ge_tensor_data, sizeof(ge_tensor_data));
ASSERT_EQ(0, ret);
// Testify if the data has been copied to the tensor data
py::array_t<float> data = (py::array_t<float>)tensor.data();
py::array_t<float> data = py::cast<py::array_t<float>>(TensorPy::AsNumpy(tensor));
auto array = data.mutable_unchecked();
for (int i = 0; i < array.shape(0); i++) {
for (int j = 0; j < array.shape(1); j++) {
......@@ -340,5 +341,17 @@ TEST_F(TestTensor, TensorDataTest) {
}
}
TEST_F(TestTensor, TensorPyCast) {
std::vector<int> shape{2, 3, 4, 5};
py::tuple py_tuple = py::make_tuple(std::make_shared<Tensor>(kNumberTypeFloat32, shape));
auto shape1 = py::cast<Tensor &>(py_tuple[0]).shape();
const py::tuple &t = py_tuple;
auto shape2 = py::cast<const Tensor &>(t[0]).shape();
auto shape3 = py::cast<Tensor &>(t[0]).shape();
ASSERT_EQ(shape, shape1);
ASSERT_EQ(shape, shape2);
ASSERT_EQ(shape, shape3);
}
} // namespace tensor
} // namespace mindspore
......@@ -60,15 +60,9 @@ CNodePtr Make_Node(Shape x, Shape y, Shape out, int condition = 0) {
BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x);
BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y);
BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out);
std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>();
inputs_x->set_data_type(kNumberTypeInt32);
inputs_x->set_shape(x);
std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>();
inputs_y->set_data_type(kNumberTypeInt32);
inputs_y->set_shape(y);
std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>();
inputs_out->set_data_type(kNumberTypeInt32);
inputs_out->set_shape(out);
std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>(kNumberTypeInt32, x);
std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>(kNumberTypeInt32, y);
std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>(kNumberTypeInt32, out);
AbstractBasePtr abstract1 = abstract::FromValue(inputs_x, true);
AbstractBasePtr abstract2 = abstract::FromValue(inputs_y, true);
AbstractBasePtr abstract3 = abstract::FromValue(inputs_out, true);
......@@ -127,21 +121,11 @@ FuncGraphManagerPtr Make_Manager(int condition = 0) {
ParameterPtr param1 = func_graph->add_parameter();
ParameterPtr param2 = func_graph->add_parameter();
ParameterPtr param3 = func_graph->add_parameter();
std::shared_ptr<tensor::Tensor> inputs_x_dim = std::make_shared<tensor::Tensor>();
inputs_x_dim->set_data_type(kNumberTypeInt32);
inputs_x_dim->set_shape(inputs_x);
std::shared_ptr<tensor::Tensor> inputs_y_dim = std::make_shared<tensor::Tensor>();
inputs_y_dim->set_data_type(kNumberTypeInt32);
inputs_y_dim->set_shape(inputs_y);
std::shared_ptr<tensor::Tensor> inputs_z_dim = std::make_shared<tensor::Tensor>();
inputs_z_dim->set_data_type(kNumberTypeInt32);
inputs_z_dim->set_shape(inputs_z);
std::shared_ptr<tensor::Tensor> inputs_out1_dim = std::make_shared<tensor::Tensor>();
inputs_out1_dim->set_data_type(kNumberTypeInt32);
inputs_out1_dim->set_shape(outputs_1);
std::shared_ptr<tensor::Tensor> inputs_out2_dim = std::make_shared<tensor::Tensor>();
inputs_out2_dim->set_data_type(kNumberTypeInt32);
inputs_out2_dim->set_shape(outputs_2);
std::shared_ptr<tensor::Tensor> inputs_x_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_x);
std::shared_ptr<tensor::Tensor> inputs_y_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_y);
std::shared_ptr<tensor::Tensor> inputs_z_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_z);
std::shared_ptr<tensor::Tensor> inputs_out1_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_1);
std::shared_ptr<tensor::Tensor> inputs_out2_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_2);
AbstractBasePtr abstract_x = abstract::FromValue(inputs_x_dim, true);
AbstractBasePtr abstract_y = abstract::FromValue(inputs_y_dim, true);
AbstractBasePtr abstract_z = abstract::FromValue(inputs_z_dim, true);
......
......@@ -113,12 +113,8 @@ TEST_F(TestData, test_build_shape) {
std::vector<int> weight1_dims = {2, 20, 5, 5};
std::vector<int> weight2_dims = {2, 2, 5, 5};
tensor::TensorPtr weight1 = std::make_shared<tensor::Tensor>();
weight1->set_data_type(kNumberTypeInt32);
weight1->set_shape(weight1_dims);
tensor::TensorPtr weight2 = std::make_shared<tensor::Tensor>();
weight2->set_data_type(kNumberTypeInt32);
weight2->set_shape(weight2_dims);
tensor::TensorPtr weight1 = std::make_shared<tensor::Tensor>(kNumberTypeInt32, weight1_dims);
tensor::TensorPtr weight2 = std::make_shared<tensor::Tensor>(kNumberTypeInt32, weight2_dims);
AbstractBasePtr abstract_weight1 = FromValue(weight1, true);
AbstractBasePtr abstract_weight2 = FromValue(weight2, true);
......
......@@ -104,7 +104,7 @@ TEST_F(TestHWConstInputToTensorInput, test_value_tuple_tensor_input) {
EXPECT_TRUE(IsValueNode<tensor::Tensor>(input1));
auto tensor = input1->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
ASSERT_TRUE(tensor != nullptr);
auto data = tensor->data_c(false);
auto data = tensor->data_c();
EXPECT_EQ(std::vector<int>((int *)data, (int *)data + 4), std::vector<int>({2, 4, 2, 2}));
}
} // namespace opt
......
......@@ -706,7 +706,7 @@ TEST_F(TestConvert, TestConvertTensor) {
auto type_id = kNumberTypeFloat32;
MeTensor me_tensor(type_id, dims);
// Get the writable data pointer of the tensor and cast it to its data type
uint8_t* me_data_ptr = reinterpret_cast<uint8_t*>(me_tensor.data_c(true));
uint8_t* me_data_ptr = reinterpret_cast<uint8_t*>(me_tensor.data_c());
// Copy or use the writable data pointer of the ME tensor
memcpy_s(me_data_ptr, me_tensor.data().nbytes(), data, 12 * sizeof(float));
auto me_tensor_ptr = std::make_shared<MeTensor>(me_tensor);
......
......@@ -18,6 +18,7 @@
#include <memory>
#include "common/common_test.h"
#include "ir/dtype.h"
#include "ir/tensor_py.h"
#include "transform/transform_base_test.h"
#include "common/py_func_graph_fetcher.h"
#include "pipeline/static_analysis/static_analysis.h"
......@@ -35,6 +36,8 @@
#define private public
#include "transform/graph_runner.h"
using mindspore::tensor::TensorPy;
namespace mindspore {
namespace transform {
class TestGraphRunner : public UT::Common {
......@@ -70,7 +73,7 @@ std::shared_ptr<DfGraphConvertor> MakeGeGraph() {
return std::make_shared<DfGraphConvertor>(anf_graph);
}
namespace {
std::shared_ptr<std::vector<MeTensorPtr>> DoExecGraph(const std::vector<MeTensorPtr>& inputs) {
std::shared_ptr<std::vector<MeTensorPtr>> DoExecGraph(const std::vector<MeTensorPtr> &inputs) {
std::vector<GeTensorPtr> ge_tensor_ptrs = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW);
std::vector<GeTensorPtr> ge_outputs;
......@@ -109,7 +112,7 @@ TEST_F(TestGraphRunner, TestGeTensorConstructor) {
MeTensor tensor = MeTensor(TypeId::kNumberTypeFloat32, std::vector<int>({1, 2, 3}));
// Get the writable data pointer from the tensor
float* me_tensor_data = reinterpret_cast<float*>(tensor.data_c(true));
float *me_tensor_data = reinterpret_cast<float *>(tensor.data_c());
// Copy data from buffer to tensor's data
memcpy_s(me_tensor_data, static_cast<size_t>(tensor.data().nbytes()), ge_tensor_data, sizeof(ge_tensor_data));
......@@ -119,11 +122,11 @@ TEST_F(TestGraphRunner, TestGeTensorConstructor) {
py::tuple py_tuple =
py::make_tuple(py::make_tuple(py::make_tuple(1.1f, 2.2f, 3.3f), py::make_tuple(4.4f, 5.5f, 6.6f)));
py::array my_arry = py::array(py_tuple).attr("astype").cast<py::function>()("float32").cast<py::array>();
MeTensor tensor_tuple = MeTensor(my_arry, kFloat32);
PrintMeTensor(&tensor_tuple);
auto tensor_tuple = TensorPy::MakeTensor(my_arry, kFloat32);
PrintMeTensor(tensor_tuple.get());
py::array tensor_array = tensor.data();
py::array tensor_tuple_array = tensor_tuple.data();
py::array tensor_array = TensorPy::AsNumpy(tensor);
py::array tensor_tuple_array = TensorPy::AsNumpy(*tensor_tuple);
assert(memcmp(ge_tensor_data, tensor_array.data(), sizeof(ge_tensor_data)) == 0);
assert(memcmp(ge_tensor_data, tensor_tuple_array.data(), sizeof(ge_tensor_data)) == 0);
}
......@@ -131,7 +134,7 @@ TEST_F(TestGraphRunner, TestGeTensorConstructor) {
#if (!defined ENABLE_GE)
TEST_F(TestGraphRunner, TestRunGraphException) {
DfGraphManager& graph_manager = DfGraphManager::GetInstance();
DfGraphManager &graph_manager = DfGraphManager::GetInstance();
graph_manager.ClearGraph();
std::map<string, MeTensorPtr> dict;
......@@ -167,7 +170,7 @@ TEST_F(TestGraphRunner, TestRunGraphException) {
}
TEST_F(TestGraphRunner, TestRunGraph) {
DfGraphManager& graph_manager = DfGraphManager::GetInstance();
DfGraphManager &graph_manager = DfGraphManager::GetInstance();
graph_manager.ClearGraph();
std::shared_ptr<DfGraphConvertor> convertor = MakeGeGraph();
......@@ -183,7 +186,7 @@ TEST_F(TestGraphRunner, TestRunGraph) {
py::make_tuple(py::make_tuple(py::make_tuple(1.0, 2.0, 3.0, 4.0), py::make_tuple(4.0, 5.0, 6.0, 7.0))),
py::make_tuple(py::make_tuple(py::make_tuple(1.0, 2.0, 3.0, 4.0), py::make_tuple(4.0, 5.0, 6.0, 7.0))));
py::array array = py::array(tuple);
MeTensorPtr me_tensor_ptr = std::make_shared<MeTensor>(array, type_id);
MeTensorPtr me_tensor_ptr = TensorPy::MakeTensor(array, type_id);
MS_LOG(INFO) << "inputs me tensor data is: ";
PrintMeTensor(&(*me_tensor_ptr));
......@@ -204,7 +207,7 @@ TEST_F(TestGraphRunner, TestRunGraph) {
}
TEST_F(TestGraphRunner, TestAPI) {
DfGraphManager& graph_manager = DfGraphManager::GetInstance();
DfGraphManager &graph_manager = DfGraphManager::GetInstance();
graph_manager.ClearGraph();
std::shared_ptr<DfGraphConvertor> convertor = MakeGeGraph();
......
......@@ -16,6 +16,9 @@
#include <iostream>
#include "common/common_test.h"
#include "transform/transform_base_test.h"
#include "ir/tensor_py.h"
using mindspore::tensor::TensorPy;
namespace mindspore {
namespace transform {
......@@ -55,10 +58,10 @@ void PrintMeTensor(MeTensor* tensor) {
}
std::cout << "the py::str() data is: " << std::endl;
py::array tensor_data = (*tensor).data();
py::array tensor_data = TensorPy::AsNumpy(*tensor);
std::cout << std::string(py::str(tensor_data)) << std::endl;
std::cout << "tensor dtype is: " << std::string(tensor->data().dtype().str()) << std::endl;
std::cout << "tensor dtype is: " << std::string(tensor_data.dtype().str()) << std::endl;
}
FuncGraphPtr MakeFuncGraph(const PrimitivePtr prim, unsigned int nparam) {
......@@ -73,7 +76,7 @@ FuncGraphPtr MakeFuncGraph(const PrimitivePtr prim, unsigned int nparam) {
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim));
for (unsigned int i = 0; i < nparam; i++) {
if ((prim->name() == "ScalarSummary" || prim->name() == "TensorSummary" ||
if ((prim->name() == "ScalarSummary" || prim->name() == "TensorSummary" ||
prim->name() == "ImageSummary" || prim->name() == "HistogramSummary") &&
i == 0) {
auto input = NewValueNode("testSummary");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册