diff --git a/mindspore/ccsrc/minddata/dataset/api/de_tensor.cc b/mindspore/ccsrc/minddata/dataset/api/de_tensor.cc index efb2dbdf973cfbbdb44bf3fff4545cfbaf85d991..498861647912362347cfd46b4bdf32ccabff3566 100644 --- a/mindspore/ccsrc/minddata/dataset/api/de_tensor.cc +++ b/mindspore/ccsrc/minddata/dataset/api/de_tensor.cc @@ -95,6 +95,19 @@ MSTensor *DETensor::CreateTensor(const std::string &path) { return new DETensor(std::move(t)); } +MSTensor *DETensor::CreateFromMemory(TypeId data_type, const std::vector &shape, void *data) { + std::shared_ptr t; + // prepare shape info + std::vector t_shape; + + std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape), + [](int s) -> dataset::dsize_t { return static_cast(s); }); + + (void)dataset::Tensor::CreateFromMemory(dataset::TensorShape(t_shape), MSTypeToDEType(data_type), + static_cast(data), &t); + return new DETensor(std::move(t)); +} + DETensor::DETensor(TypeId data_type, const std::vector &shape) { std::vector t_shape; t_shape.reserve(shape.size()); diff --git a/mindspore/ccsrc/minddata/dataset/include/de_tensor.h b/mindspore/ccsrc/minddata/dataset/include/de_tensor.h index 749b9d35d923bd4d3789646800d39dbe91011a29..cdbe00b371356027f80f43b75c815bc9e8baa499 100644 --- a/mindspore/ccsrc/minddata/dataset/include/de_tensor.h +++ b/mindspore/ccsrc/minddata/dataset/include/de_tensor.h @@ -37,6 +37,14 @@ class DETensor : public MSTensor { /// \return - MSTensor pointer. static MSTensor *CreateTensor(const std::string &path); + /// \brief Create a MSTensor pointer. + /// \note This function returns null_ptr if tensor creation fails. + /// \param[data_type] DataTypeId of tensor to be created. + /// \param[shape] Shape of tensor to be created. + /// \param[data] Data pointer. + /// \return - MSTensor pointer. + static MSTensor *CreateFromMemory(TypeId data_type, const std::vector &shape, void *data); + DETensor(TypeId data_type, const std::vector &shape); explicit DETensor(std::shared_ptr tensor_ptr); diff --git a/mindspore/lite/test/ut/src/dataset/de_tensor_test.cc b/mindspore/lite/test/ut/src/dataset/de_tensor_test.cc index 6d2505f957dc545e2ac7ffc7464b12b488e0e0fb..3a0323d4487531c8a385d90def48d659b57119f7 100644 --- a/mindspore/lite/test/ut/src/dataset/de_tensor_test.cc +++ b/mindspore/lite/test/ut/src/dataset/de_tensor_test.cc @@ -96,3 +96,13 @@ TEST_F(MindDataTestTensorDE, MSTensorHash) { auto ms_tensor = std::shared_ptr(new DETensor(t)); ASSERT_EQ(ms_tensor->hash() == 11093771382437, true); } + +TEST_F(MindDataTestTensorDE, MSTensorCreateFromMemory) { + std::vector x = {2.5, 2.5, 2.5, 2.5}; + auto mem_tensor = DETensor::CreateFromMemory(mindspore::TypeId::kNumberTypeFloat32, {2, 2}, &x[0]); + std::shared_ptr t; + Tensor::CreateFromVector(x, TensorShape({2, 2}), &t); + auto ms_tensor = std::shared_ptr(new DETensor(t)); + ASSERT_EQ(ms_tensor->hash() == mem_tensor->hash(), true); +} +