提交 cc23f1d8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4630 Adding wrapper around CreateFromMemory

Merge pull request !4630 from EricZ/md_tensor_from_mem
......@@ -95,6 +95,19 @@ MSTensor *DETensor::CreateTensor(const std::string &path) {
return new DETensor(std::move(t));
}
MSTensor *DETensor::CreateFromMemory(TypeId data_type, const std::vector<int> &shape, void *data) {
std::shared_ptr<dataset::Tensor> t;
// prepare shape info
std::vector<dataset::dsize_t> t_shape;
std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
[](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });
(void)dataset::Tensor::CreateFromMemory(dataset::TensorShape(t_shape), MSTypeToDEType(data_type),
static_cast<uchar *>(data), &t);
return new DETensor(std::move(t));
}
DETensor::DETensor(TypeId data_type, const std::vector<int> &shape) {
std::vector<dataset::dsize_t> t_shape;
t_shape.reserve(shape.size());
......
......@@ -37,6 +37,14 @@ class DETensor : public MSTensor {
/// \return - MSTensor pointer.
static MSTensor *CreateTensor(const std::string &path);
/// \brief Create a MSTensor pointer.
/// \note This function returns null_ptr if tensor creation fails.
/// \param[data_type] DataTypeId of tensor to be created.
/// \param[shape] Shape of tensor to be created.
/// \param[data] Data pointer.
/// \return - MSTensor pointer.
static MSTensor *CreateFromMemory(TypeId data_type, const std::vector<int> &shape, void *data);
DETensor(TypeId data_type, const std::vector<int> &shape);
explicit DETensor(std::shared_ptr<dataset::Tensor> tensor_ptr);
......
......@@ -96,3 +96,13 @@ TEST_F(MindDataTestTensorDE, MSTensorHash) {
auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t));
ASSERT_EQ(ms_tensor->hash() == 11093771382437, true);
}
TEST_F(MindDataTestTensorDE, MSTensorCreateFromMemory) {
std::vector<float> x = {2.5, 2.5, 2.5, 2.5};
auto mem_tensor = DETensor::CreateFromMemory(mindspore::TypeId::kNumberTypeFloat32, {2, 2}, &x[0]);
std::shared_ptr<Tensor> t;
Tensor::CreateFromVector(x, TensorShape({2, 2}), &t);
auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t));
ASSERT_EQ(ms_tensor->hash() == mem_tensor->hash(), true);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册