Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
65bf3ecc
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
65bf3ecc
编写于
8月 20, 2020
作者:
E
Eric
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adding data_helper class
上级
39e27911
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
556 addition
and
3 deletion
+556
-3
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/datasets.h
+3
-3
mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt
mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt
+1
-0
mindspore/ccsrc/minddata/dataset/util/data_helper.cc
mindspore/ccsrc/minddata/dataset/util/data_helper.cc
+142
-0
mindspore/ccsrc/minddata/dataset/util/data_helper.h
mindspore/ccsrc/minddata/dataset/util/data_helper.h
+214
-0
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+1
-0
tests/ut/cpp/dataset/data_helper_test.cc
tests/ut/cpp/dataset/data_helper_test.cc
+195
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/include/datasets.h
浏览文件 @
65bf3ecc
...
...
@@ -146,9 +146,9 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal)
/// Can be any of:
/// ShuffleMode
.
kFalse - No shuffling is performed.
/// ShuffleMode
.
kFiles - Shuffle files only.
/// ShuffleMode
.
kGlobal - Shuffle both the files and samples.
/// ShuffleMode
::
kFalse - No shuffling is performed.
/// ShuffleMode
::
kFiles - Shuffle files only.
/// ShuffleMode
::
kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
...
...
mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt
浏览文件 @
65bf3ecc
...
...
@@ -5,6 +5,7 @@ add_library(utils OBJECT
buddy.cc
cache_pool.cc
circular_pool.cc
data_helper.cc
memory_pool.cc
cond_var.cc
intrp_service.cc
...
...
mindspore/ccsrc/minddata/dataset/util/data_helper.cc
0 → 100644
浏览文件 @
65bf3ecc
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/util/data_helper.h"
#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <nlohmann/json.hpp>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/util/path.h"
namespace
mindspore
{
namespace
dataset
{
// Create a numbered json file from image folder
Status
DataHelper
::
CreateAlbum
(
const
std
::
string
&
in_dir
,
const
std
::
string
&
out_dir
)
{
// in check
Path
base_dir
=
Path
(
in_dir
);
if
(
!
base_dir
.
IsDirectory
()
||
!
base_dir
.
Exists
())
{
RETURN_STATUS_UNEXPECTED
(
"Input dir is not a directory or doesn't exist"
);
}
// check if output_dir exists and create it if it does not exist
Path
target_dir
=
Path
(
out_dir
);
RETURN_IF_NOT_OK
(
target_dir
.
CreateDirectory
());
// iterate over in dir and create json for all images
uint64_t
index
=
0
;
auto
dir_it
=
Path
::
DirIterator
::
OpenDirectory
(
&
base_dir
);
while
(
dir_it
->
hasNext
())
{
Path
v
=
dir_it
->
next
();
// check if found file fits image extension
// create json file in output dir with the path
std
::
string
out_file
=
out_dir
+
"/"
+
std
::
to_string
(
index
)
+
".json"
;
UpdateValue
(
out_file
,
"image"
,
v
.
toString
(),
out_file
);
index
++
;
}
return
Status
::
OK
();
}
// A print method typically used for debugging
void
DataHelper
::
Print
(
std
::
ostream
&
out
)
const
{
out
<<
" Data Helper"
<<
"
\n
"
;
}
Status
DataHelper
::
UpdateArray
(
const
std
::
string
&
in_file
,
const
std
::
string
&
key
,
const
std
::
vector
<
std
::
string
>
&
value
,
const
std
::
string
&
out_file
)
{
try
{
Path
in
=
Path
(
in_file
);
nlohmann
::
json
js
;
if
(
in
.
Exists
())
{
std
::
ifstream
in_stream
(
in_file
);
MS_LOG
(
INFO
)
<<
"Filename: "
<<
in_file
<<
"."
;
in_stream
>>
js
;
in_stream
.
close
();
}
js
[
key
]
=
value
;
MS_LOG
(
INFO
)
<<
"Write outfile is: "
<<
js
<<
"."
;
if
(
out_file
==
""
)
{
std
::
ofstream
o
(
in_file
,
std
::
ofstream
::
trunc
);
o
<<
js
;
o
.
close
();
}
else
{
std
::
ofstream
o
(
out_file
,
std
::
ofstream
::
trunc
);
o
<<
js
;
o
.
close
();
}
}
// Catch any exception and convert to Status return code
catch
(
const
std
::
exception
&
err
)
{
RETURN_STATUS_UNEXPECTED
(
"Update json failed "
);
}
return
Status
::
OK
();
}
Status
DataHelper
::
RemoveKey
(
const
std
::
string
&
in_file
,
const
std
::
string
&
key
,
const
std
::
string
&
out_file
)
{
try
{
Path
in
=
Path
(
in_file
);
nlohmann
::
json
js
;
if
(
in
.
Exists
())
{
std
::
ifstream
in_stream
(
in_file
);
MS_LOG
(
INFO
)
<<
"Filename: "
<<
in_file
<<
"."
;
in_stream
>>
js
;
in_stream
.
close
();
}
js
.
erase
(
key
);
MS_LOG
(
INFO
)
<<
"Write outfile is: "
<<
js
<<
"."
;
if
(
out_file
==
""
)
{
std
::
ofstream
o
(
in_file
,
std
::
ofstream
::
trunc
);
o
<<
js
;
o
.
close
();
}
else
{
std
::
ofstream
o
(
out_file
,
std
::
ofstream
::
trunc
);
o
<<
js
;
o
.
close
();
}
}
// Catch any exception and convert to Status return code
catch
(
const
std
::
exception
&
err
)
{
RETURN_STATUS_UNEXPECTED
(
"Update json failed "
);
}
return
Status
::
OK
();
}
size_t
DataHelper
::
DumpTensor
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
void
*
addr
,
const
size_t
&
buffer_size
)
{
// get tensor size
size_t
tensor_size
=
input
->
SizeInBytes
();
// iterate over entire tensor
const
unsigned
char
*
tensor_addr
=
input
->
GetBuffer
();
// tensor iterator print
// write to address, input order is: destination, source
errno_t
ret
=
memcpy_s
(
addr
,
buffer_size
,
tensor_addr
,
tensor_size
);
if
(
ret
!=
0
)
{
// memcpy failed
MS_LOG
(
ERROR
)
<<
"memcpy tensor memory failed"
<<
"."
;
return
0
;
// amount of data copied is 0, error
}
return
tensor_size
;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/util/data_helper.h
0 → 100644
浏览文件 @
65bf3ecc
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_DATA_HELPER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_DATA_HELPER_H_
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include <nlohmann/json.hpp>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
/// \brief Simple class to do data manipulation, contains helper function to update json files in dataset
class
DataHelper
{
public:
/// \brief constructor
DataHelper
()
{}
/// \brief Destructor
~
DataHelper
()
=
default
;
/// \brief Create an Album dataset while taking in a path to a image folder
/// Creates the output directory if doesn't exist
/// \param[in] in_dir Image folder directory that takes in images
/// \param[in] out_dir Directory containing output json files
Status
CreateAlbum
(
const
std
::
string
&
in_dir
,
const
std
::
string
&
out_dir
);
/// \brief Update a json file field with a vector of integers
/// \param in_file The input file name to read in
/// \param key Key of field to write to
/// \param value Value array to write to file
/// \param out_file Optional input for output file path, will write to input file if not specified
/// \return Status The error code return
Status
UpdateArray
(
const
std
::
string
&
in_file
,
const
std
::
string
&
key
,
const
std
::
vector
<
std
::
string
>
&
value
,
const
std
::
string
&
out_file
=
""
);
/// \brief Update a json file field with a vector of type T values
/// \param in_file The input file name to read in
/// \param key Key of field to write to
/// \param value Value array to write to file
/// \param out_file Optional parameter for output file path, will write to input file if not specified
/// \return Status The error code return
template
<
typename
T
>
Status
UpdateArray
(
const
std
::
string
&
in_file
,
const
std
::
string
&
key
,
const
std
::
vector
<
T
>
&
value
,
const
std
::
string
&
out_file
=
""
)
{
try
{
Path
in
=
Path
(
in_file
);
nlohmann
::
json
js
;
if
(
in
.
Exists
())
{
std
::
ifstream
in
(
in_file
);
MS_LOG
(
INFO
)
<<
"Filename: "
<<
in_file
<<
"."
;
in
>>
js
;
in
.
close
();
}
js
[
key
]
=
value
;
MS_LOG
(
INFO
)
<<
"Write outfile is: "
<<
js
<<
"."
;
if
(
out_file
==
""
)
{
std
::
ofstream
o
(
in_file
,
std
::
ofstream
::
trunc
);
o
<<
js
;
o
.
close
();
}
else
{
std
::
ofstream
o
(
out_file
,
std
::
ofstream
::
trunc
);
o
<<
js
;
o
.
close
();
}
}
// Catch any exception and convert to Status return code
catch
(
const
std
::
exception
&
err
)
{
RETURN_STATUS_UNEXPECTED
(
"Update json failed "
);
}
return
Status
::
OK
();
}
/// \brief Update a json file field with a single value of of type T
/// \param in_file The input file name to read in
/// \param key Key of field to write to
/// \param value Value to write to file
/// \param out_file Optional parameter for output file path, will write to input file if not specified
/// \return Status The error code return
template
<
typename
T
>
Status
UpdateValue
(
const
std
::
string
&
in_file
,
const
std
::
string
&
key
,
const
T
&
value
,
const
std
::
string
&
out_file
=
""
)
{
try
{
Path
in
=
Path
(
in_file
);
nlohmann
::
json
js
;
if
(
in
.
Exists
())
{
std
::
ifstream
in
(
in_file
);
MS_LOG
(
INFO
)
<<
"Filename: "
<<
in_file
<<
"."
;
in
>>
js
;
in
.
close
();
}
js
[
key
]
=
value
;
MS_LOG
(
INFO
)
<<
"Write outfile is: "
<<
js
<<
"."
;
if
(
out_file
==
""
)
{
std
::
ofstream
o
(
in_file
,
std
::
ofstream
::
trunc
);
o
<<
js
;
o
.
close
();
}
else
{
std
::
ofstream
o
(
out_file
,
std
::
ofstream
::
trunc
);
o
<<
js
;
o
.
close
();
}
}
// Catch any exception and convert to Status return code
catch
(
const
std
::
exception
&
err
)
{
RETURN_STATUS_UNEXPECTED
(
"Update json failed "
);
}
return
Status
::
OK
();
}
/// \brief Template function to write tensor to file
/// \param[in] in_file File to write to
/// \param[in] data Array of type T values
/// \return Status The error code return
template
<
typename
T
>
Status
WriteBinFile
(
const
std
::
string
&
in_file
,
const
std
::
vector
<
T
>
&
data
)
{
try
{
std
::
ofstream
o
(
in_file
,
std
::
ios
::
binary
|
std
::
ios
::
out
);
if
(
!
o
.
is_open
())
{
RETURN_STATUS_UNEXPECTED
(
"Error opening Bin file to write"
);
}
size_t
length
=
data
.
size
();
o
.
write
(
reinterpret_cast
<
const
char
*>
(
&
data
[
0
]),
std
::
streamsize
(
length
*
sizeof
(
T
)));
o
.
close
();
}
// Catch any exception and convert to Status return code
catch
(
const
std
::
exception
&
err
)
{
RETURN_STATUS_UNEXPECTED
(
"Write bin file failed "
);
}
return
Status
::
OK
();
}
/// \brief Write pointer to bin, use pointer to avoid memcpy
/// \param[in] in_file File name to write to
/// \param[in] data Pointer to data
/// \param[in] length Length of values to write from pointer
/// \return Status The error code return
template
<
typename
T
>
Status
WriteBinFile
(
const
std
::
string
&
in_file
,
T
*
data
,
size_t
length
)
{
try
{
std
::
ofstream
o
(
in_file
,
std
::
ios
::
binary
|
std
::
ios
::
out
);
if
(
!
o
.
is_open
())
{
RETURN_STATUS_UNEXPECTED
(
"Error opening Bin file to write"
);
}
o
.
write
(
reinterpret_cast
<
const
char
*>
(
data
),
std
::
streamsize
(
length
*
sizeof
(
T
)));
o
.
close
();
}
// Catch any exception and convert to Status return code
catch
(
const
std
::
exception
&
err
)
{
RETURN_STATUS_UNEXPECTED
(
"Write bin file failed "
);
}
return
Status
::
OK
();
}
/// \brief Helper function to copy content of a tensor to buffer
/// \note This function iterates over the tensor in bytes, since
/// \param[in] input The tensor to copy value from
/// \param[out] addr The address to copy tensor data to
/// \param[in] buffer_size The buffer size of addr
/// \return The size of the tensor (bytes copied
size_t
DumpTensor
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
void
*
addr
,
const
size_t
&
buffer_size
);
/// \brief Helper function to delete key in json file
/// note This function will return okay even if key not found
/// \param[in] in_file Json file to remove key from
/// \param[in] key The key to remove
/// \return Status The error code return
Status
RemoveKey
(
const
std
::
string
&
in_file
,
const
std
::
string
&
key
,
const
std
::
string
&
out_file
=
""
);
/// \brief A print method typically used for debugging
/// \param out - The output stream to write output to
void
Print
(
std
::
ostream
&
out
)
const
;
/// \brief << Stream output operator overload
/// \notes This allows you to write the debug print info using stream operators
/// \param out Reference to the output stream being overloaded
/// \param ds Reference to the DataSchema to display
/// \return The output stream must be returned
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
DataHelper
&
dh
)
{
dh
.
Print
(
out
);
return
out
;
}
};
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_DATA_HELPER_H_
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
65bf3ecc
...
...
@@ -121,6 +121,7 @@ SET(DE_UT_SRCS
solarize_op_test.cc
swap_red_blue_test.cc
distributed_sampler_test.cc
data_helper_test.cc
)
if
(
ENABLE_PYTHON
)
...
...
tests/ut/cpp/dataset/data_helper_test.cc
0 → 100644
浏览文件 @
65bf3ecc
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include "common/common.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "minddata/dataset/util/data_helper.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "securec.h"
using
namespace
mindspore
::
dataset
;
using
mindspore
::
MsLogLevel
::
ERROR
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
class
MindDataTestDataHelper
:
public
UT
::
DatasetOpTesting
{
protected:
};
TEST_F
(
MindDataTestDataHelper
,
MindDataTestHelper
)
{
std
::
string
file_path
=
datasets_root_path_
+
"/testAlbum/images/1.json"
;
DataHelper
dh
;
std
::
vector
<
std
::
string
>
new_label
=
{
"3"
,
"4"
};
Status
rc
=
dh
.
UpdateArray
(
file_path
,
"label"
,
new_label
);
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"Return code error detected during label update: "
<<
"."
;
EXPECT_TRUE
(
false
);
}
}
TEST_F
(
MindDataTestDataHelper
,
MindDataTestAlbumGen
)
{
std
::
string
file_path
=
datasets_root_path_
+
"/testAlbum/original"
;
std
::
string
out_path
=
datasets_root_path_
+
"/testAlbum/testout"
;
DataHelper
dh
;
Status
rc
=
dh
.
CreateAlbum
(
file_path
,
out_path
);
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"Return code error detected during album generation: "
<<
"."
;
EXPECT_TRUE
(
false
);
}
}
TEST_F
(
MindDataTestDataHelper
,
MindDataTestTemplateUpdateArrayInt
)
{
std
::
string
file_path
=
datasets_root_path_
+
"/testAlbum/testout/2.json"
;
DataHelper
dh
;
std
::
vector
<
int
>
new_label
=
{
3
,
4
};
Status
rc
=
dh
.
UpdateArray
(
file_path
,
"label"
,
new_label
);
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"Return code error detected during json int array update: "
<<
"."
;
EXPECT_TRUE
(
false
);
}
}
TEST_F
(
MindDataTestDataHelper
,
MindDataTestTemplateUpdateArrayString
)
{
std
::
string
file_path
=
datasets_root_path_
+
"/testAlbum/testout/3.json"
;
DataHelper
dh
;
std
::
vector
<
std
::
string
>
new_label
=
{
"3"
,
"4"
};
Status
rc
=
dh
.
UpdateArray
(
file_path
,
"label"
,
new_label
);
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"Return code error detected during json string array update: "
<<
"."
;
EXPECT_TRUE
(
false
);
}
}
TEST_F
(
MindDataTestDataHelper
,
MindDataTestTemplateUpdateValueInt
)
{
std
::
string
file_path
=
datasets_root_path_
+
"/testAlbum/testout/4.json"
;
DataHelper
dh
;
int
new_label
=
3
;
Status
rc
=
dh
.
UpdateValue
(
file_path
,
"label"
,
new_label
);
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"Return code error detected during json int update: "
<<
"."
;
EXPECT_TRUE
(
false
);
}
}
TEST_F
(
MindDataTestDataHelper
,
MindDataTestTemplateUpdateString
)
{
std
::
string
file_path
=
datasets_root_path_
+
"/testAlbum/testout/5.json"
;
DataHelper
dh
;
std
::
string
new_label
=
"new label"
;
Status
rc
=
dh
.
UpdateValue
(
file_path
,
"label"
,
new_label
);
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"Return code error detected during json string update: "
<<
"."
;
EXPECT_TRUE
(
false
);
}
}
TEST_F
(
MindDataTestDataHelper
,
MindDataTestDeleteKey
)
{
std
::
string
file_path
=
datasets_root_path_
+
"/testAlbum/testout/5.json"
;
DataHelper
dh
;
Status
rc
=
dh
.
RemoveKey
(
file_path
,
"label"
);
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"Return code error detected during json key remove: "
<<
"."
;
EXPECT_TRUE
(
false
);
}
}
TEST_F
(
MindDataTestDataHelper
,
MindDataTestBinWrite
)
{
std
::
string
file_path
=
datasets_root_path_
+
"/testAlbum/1.bin"
;
DataHelper
dh
;
std
::
vector
<
float
>
bin_content
=
{
3
,
4
};
Status
rc
=
dh
.
WriteBinFile
(
file_path
,
bin_content
);
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"Return code error detected during bin file write: "
<<
"."
;
EXPECT_TRUE
(
false
);
}
}
TEST_F
(
MindDataTestDataHelper
,
MindDataTestBinWritePointer
)
{
std
::
string
file_path
=
datasets_root_path_
+
"/testAlbum/2.bin"
;
DataHelper
dh
;
std
::
vector
<
float
>
bin_content
=
{
3
,
4
};
Status
rc
=
dh
.
WriteBinFile
(
file_path
,
&
bin_content
[
0
],
bin_content
.
size
());
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"Return code error detected during binfile write: "
<<
"."
;
EXPECT_TRUE
(
false
);
}
}
TEST_F
(
MindDataTestDataHelper
,
MindDataTestTensorWriteFloat
)
{
// create tensor
std
::
vector
<
float
>
y
=
{
2.5
,
3.0
,
3.5
,
4.0
};
std
::
shared_ptr
<
Tensor
>
t
;
Tensor
::
CreateFromVector
(
y
,
&
t
);
// create buffer using system mempool
DataHelper
dh
;
void
*
data
=
malloc
(
t
->
SizeInBytes
());
auto
bytes_copied
=
dh
.
DumpTensor
(
std
::
move
(
t
),
data
,
t
->
SizeInBytes
());
if
(
bytes_copied
!=
t
->
SizeInBytes
())
{
EXPECT_TRUE
(
false
);
}
float
*
array
=
static_cast
<
float
*>
(
data
);
if
(
array
[
0
]
!=
2.5
)
{
EXPECT_TRUE
(
false
);
}
if
(
array
[
1
]
!=
3.0
)
{
EXPECT_TRUE
(
false
);
}
if
(
array
[
2
]
!=
3.5
)
{
EXPECT_TRUE
(
false
);
}
if
(
array
[
3
]
!=
4.0
)
{
EXPECT_TRUE
(
false
);
}
std
::
free
(
data
);
}
TEST_F
(
MindDataTestDataHelper
,
MindDataTestTensorWriteUInt
)
{
// create tensor
std
::
vector
<
uint8_t
>
y
=
{
1
,
2
,
3
,
4
};
std
::
shared_ptr
<
Tensor
>
t
;
Tensor
::
CreateFromVector
(
y
,
&
t
);
uint8_t
o
;
t
->
GetItemAt
<
uint8_t
>
(
&
o
,
{
0
,
0
});
MS_LOG
(
INFO
)
<<
"before op :"
<<
std
::
to_string
(
o
)
<<
"."
;
// create buffer using system mempool
DataHelper
dh
;
void
*
data
=
malloc
(
t
->
SizeInBytes
());
auto
bytes_copied
=
dh
.
DumpTensor
(
t
,
data
,
t
->
SizeInBytes
());
if
(
bytes_copied
!=
t
->
SizeInBytes
())
{
EXPECT_TRUE
(
false
);
}
t
->
GetItemAt
<
uint8_t
>
(
&
o
,
{});
MS_LOG
(
INFO
)
<<
"after op :"
<<
std
::
to_string
(
o
)
<<
"."
;
uint8_t
*
array
=
static_cast
<
uint8_t
*>
(
data
);
if
(
array
[
0
]
!=
1
)
{
EXPECT_TRUE
(
false
);
}
if
(
array
[
1
]
!=
2
)
{
EXPECT_TRUE
(
false
);
}
if
(
array
[
2
]
!=
3
)
{
EXPECT_TRUE
(
false
);
}
if
(
array
[
3
]
!=
4
)
{
EXPECT_TRUE
(
false
);
}
std
::
free
(
data
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录