Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b91e5637
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b91e5637
编写于
8月 13, 2020
作者:
X
xiefangqi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add randomdataset and schema
上级
2cc6230f
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
732 addition
and
82 deletion
+732
-82
mindspore/ccsrc/minddata/dataset/api/datasets.cc
mindspore/ccsrc/minddata/dataset/api/datasets.cc
+228
-0
mindspore/ccsrc/minddata/dataset/api/de_tensor.cc
mindspore/ccsrc/minddata/dataset/api/de_tensor.cc
+4
-65
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc
...nddata/dataset/engine/datasetops/source/random_data_op.cc
+7
-14
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h
...inddata/dataset/engine/datasetops/source/random_data_op.h
+2
-3
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/datasets.h
+131
-0
mindspore/ccsrc/minddata/dataset/include/type_id.h
mindspore/ccsrc/minddata/dataset/include/type_id.h
+88
-0
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+1
-0
tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc
tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc
+271
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/api/datasets.cc
浏览文件 @
b91e5637
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
// Dataset operator headers (in alphabetical order)
// Dataset operator headers (in alphabetical order)
...
@@ -100,6 +101,15 @@ Dataset::Dataset() {
...
@@ -100,6 +101,15 @@ Dataset::Dataset() {
worker_connector_size_
=
cfg
->
worker_connector_size
();
worker_connector_size_
=
cfg
->
worker_connector_size
();
}
}
/// \brief Function to create a SchemaObj
/// \param[in] schema_file Path of schema file
/// \return Shared pointer to the current schema
std
::
shared_ptr
<
SchemaObj
>
Schema
(
const
std
::
string
&
schema_file
)
{
auto
schema
=
std
::
make_shared
<
SchemaObj
>
(
schema_file
);
return
schema
->
init
()
?
schema
:
nullptr
;
}
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
// (In alphabetical order)
// (In alphabetical order)
...
@@ -353,6 +363,163 @@ std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Datas
...
@@ -353,6 +363,163 @@ std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Datas
return
ds
->
ValidateParams
()
?
ds
:
nullptr
;
return
ds
->
ValidateParams
()
?
ds
:
nullptr
;
}
}
SchemaObj
::
SchemaObj
(
const
std
::
string
&
schema_file
)
:
schema_file_
(
schema_file
),
num_rows_
(
0
),
dataset_type_
(
""
)
{}
// SchemaObj init function
bool
SchemaObj
::
init
()
{
if
(
schema_file_
!=
""
)
{
Path
schema_file
(
schema_file_
);
if
(
!
schema_file
.
Exists
())
{
MS_LOG
(
ERROR
)
<<
"The file "
<<
schema_file
<<
" does not exist or permission denied!"
;
return
false
;
}
nlohmann
::
json
js
;
try
{
std
::
ifstream
in
(
schema_file_
);
in
>>
js
;
}
catch
(
const
std
::
exception
&
err
)
{
MS_LOG
(
ERROR
)
<<
"Schema file failed to load"
;
return
false
;
}
return
from_json
(
js
);
}
return
true
;
}
// Function to add a column to schema with a mstype de_type
bool
SchemaObj
::
add_column
(
std
::
string
name
,
TypeId
de_type
,
std
::
vector
<
int32_t
>
shape
)
{
nlohmann
::
json
new_column
;
new_column
[
"name"
]
=
name
;
// if de_type is mstype
DataType
data_type
=
dataset
::
MSTypeToDEType
(
de_type
);
new_column
[
"type"
]
=
data_type
.
ToString
();
if
(
shape
.
size
()
>
0
)
{
new_column
[
"shape"
]
=
shape
;
new_column
[
"rank"
]
=
shape
.
size
();
}
else
{
new_column
[
"rank"
]
=
1
;
}
columns_
.
push_back
(
new_column
);
return
true
;
}
// Function to add a column to schema with a string de_type
bool
SchemaObj
::
add_column
(
std
::
string
name
,
std
::
string
de_type
,
std
::
vector
<
int32_t
>
shape
)
{
nlohmann
::
json
new_column
;
new_column
[
"name"
]
=
name
;
DataType
data_type
(
de_type
);
new_column
[
"type"
]
=
data_type
.
ToString
();
if
(
shape
.
size
()
>
0
)
{
new_column
[
"shape"
]
=
shape
;
new_column
[
"rank"
]
=
shape
.
size
();
}
else
{
new_column
[
"rank"
]
=
1
;
}
columns_
.
push_back
(
new_column
);
return
true
;
}
std
::
string
SchemaObj
::
to_json
()
{
nlohmann
::
json
json_file
;
json_file
[
"columns"
]
=
columns_
;
if
(
dataset_type_
!=
""
)
{
json_file
[
"datasetType"
]
=
dataset_type_
;
}
if
(
num_rows_
>
0
)
{
json_file
[
"numRows"
]
=
num_rows_
;
}
return
json_file
.
dump
(
2
);
}
bool
SchemaObj
::
parse_column
(
nlohmann
::
json
columns
)
{
std
::
string
name
,
de_type
;
std
::
vector
<
int32_t
>
shape
;
columns_
.
clear
();
if
(
columns
.
type
()
==
nlohmann
::
json
::
value_t
::
array
)
{
// reference to python list
for
(
auto
column
:
columns
)
{
auto
key_name
=
column
.
find
(
"name"
);
if
(
key_name
==
column
.
end
())
{
MS_LOG
(
ERROR
)
<<
"Column's name is missing"
;
return
false
;
}
name
=
*
key_name
;
auto
key_type
=
column
.
find
(
"type"
);
if
(
key_type
==
column
.
end
())
{
MS_LOG
(
ERROR
)
<<
"Column's type is missing"
;
return
false
;
}
de_type
=
*
key_type
;
shape
.
clear
();
auto
key_shape
=
column
.
find
(
"shape"
);
if
(
key_shape
!=
column
.
end
())
{
shape
.
insert
(
shape
.
end
(),
(
*
key_shape
).
begin
(),
(
*
key_shape
).
end
());
}
if
(
!
add_column
(
name
,
de_type
,
shape
))
{
return
false
;
}
}
}
else
if
(
columns
.
type
()
==
nlohmann
::
json
::
value_t
::
object
)
{
for
(
const
auto
&
it_child
:
columns
.
items
())
{
name
=
it_child
.
key
();
auto
key_type
=
it_child
.
value
().
find
(
"type"
);
if
(
key_type
==
it_child
.
value
().
end
())
{
MS_LOG
(
ERROR
)
<<
"Column's type is missing"
;
return
false
;
}
de_type
=
*
key_type
;
shape
.
clear
();
auto
key_shape
=
it_child
.
value
().
find
(
"shape"
);
if
(
key_shape
!=
it_child
.
value
().
end
())
{
shape
.
insert
(
shape
.
end
(),
(
*
key_shape
).
begin
(),
(
*
key_shape
).
end
());
}
if
(
!
add_column
(
name
,
de_type
,
shape
))
{
return
false
;
}
}
}
else
{
MS_LOG
(
ERROR
)
<<
"columns must be dict or list, columns contain name, type, shape(optional)."
;
return
false
;
}
return
true
;
}
bool
SchemaObj
::
from_json
(
nlohmann
::
json
json_obj
)
{
for
(
const
auto
&
it_child
:
json_obj
.
items
())
{
if
(
it_child
.
key
()
==
"datasetType"
)
{
dataset_type_
=
it_child
.
value
();
}
else
if
(
it_child
.
key
()
==
"numRows"
)
{
num_rows_
=
it_child
.
value
();
}
else
if
(
it_child
.
key
()
==
"columns"
)
{
if
(
!
parse_column
(
it_child
.
value
()))
{
MS_LOG
(
ERROR
)
<<
"parse columns failed"
;
return
false
;
}
}
else
{
MS_LOG
(
ERROR
)
<<
"Unknown field "
<<
it_child
.
key
();
return
false
;
}
}
if
(
columns_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"Columns are missing."
;
return
false
;
}
if
(
num_rows_
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"numRows must be greater than 0"
;
return
false
;
}
return
true
;
}
// OTHER FUNCTIONS
// OTHER FUNCTIONS
// (In alphabetical order)
// (In alphabetical order)
...
@@ -864,6 +1031,67 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
...
@@ -864,6 +1031,67 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
return
node_ops
;
return
node_ops
;
}
}
// ValideParams for RandomDataset
bool
RandomDataset
::
ValidateParams
()
{
if
(
total_rows_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"RandomDataset: total_rows must be greater than 0, now get "
<<
total_rows_
;
return
false
;
}
return
true
;
}
int32_t
RandomDataset
::
GenRandomInt
(
int32_t
min
,
int32_t
max
)
{
std
::
uniform_int_distribution
<
int32_t
>
uniDist
(
min
,
max
);
return
uniDist
(
rand_gen_
);
}
// Build for RandomDataset
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
RandomDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
rand_gen_
.
seed
(
GetSeed
());
// seed the random generator
// If total rows was not given, then randomly pick a number
std
::
shared_ptr
<
SchemaObj
>
schema_obj
;
if
(
!
schema_path_
.
empty
())
schema_obj
=
std
::
make_shared
<
SchemaObj
>
(
schema_path_
);
if
(
schema_obj
!=
nullptr
&&
total_rows_
==
0
)
{
total_rows_
=
schema_obj
->
get_num_rows
();
}
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
if
(
sampler_
==
nullptr
)
{
sampler_
=
CreateDefaultSampler
();
}
std
::
string
schema_json_string
,
schema_file_path
;
if
(
schema_
!=
nullptr
)
{
schema_
->
set_dataset_type
(
"Random"
);
if
(
total_rows_
!=
0
)
{
schema_
->
set_num_rows
(
total_rows_
);
}
schema_json_string
=
schema_
->
to_json
();
}
else
{
schema_file_path
=
schema_path_
;
}
std
::
unique_ptr
<
DataSchema
>
data_schema
;
std
::
vector
<
std
::
string
>
columns_to_load
;
if
(
!
schema_file_path
.
empty
()
||
!
schema_json_string
.
empty
())
{
data_schema
=
std
::
make_unique
<
DataSchema
>
();
if
(
!
schema_file_path
.
empty
())
{
data_schema
->
LoadSchemaFile
(
schema_file_path
,
columns_to_load
);
}
else
if
(
!
schema_json_string
.
empty
())
{
data_schema
->
LoadSchemaString
(
schema_json_string
,
columns_to_load
);
}
}
std
::
shared_ptr
<
RandomDataOp
>
op
;
op
=
std
::
make_shared
<
RandomDataOp
>
(
num_workers_
,
connector_que_size_
,
rows_per_buffer_
,
total_rows_
,
std
::
move
(
data_schema
),
std
::
move
(
sampler_
->
Build
()));
node_ops
.
push_back
(
op
);
return
node_ops
;
}
// Constructor for TextFileDataset
// Constructor for TextFileDataset
TextFileDataset
::
TextFileDataset
(
std
::
vector
<
std
::
string
>
dataset_files
,
int32_t
num_samples
,
ShuffleMode
shuffle
,
TextFileDataset
::
TextFileDataset
(
std
::
vector
<
std
::
string
>
dataset_files
,
int32_t
num_samples
,
ShuffleMode
shuffle
,
int32_t
num_shards
,
int32_t
shard_id
)
int32_t
num_shards
,
int32_t
shard_id
)
...
...
mindspore/ccsrc/minddata/dataset/api/de_tensor.cc
浏览文件 @
b91e5637
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
*/
*/
#include "minddata/dataset/include/de_tensor.h"
#include "minddata/dataset/include/de_tensor.h"
#include "minddata/dataset/include/type_id.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/data_type.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "mindspore/core/ir/dtype/type_id.h"
...
@@ -23,68 +24,6 @@
...
@@ -23,68 +24,6 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
tensor
{
namespace
tensor
{
dataset
::
DataType
MSTypeToDEType
(
TypeId
data_type
)
{
switch
(
data_type
)
{
case
kNumberTypeBool
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_BOOL
);
case
kNumberTypeInt8
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_INT8
);
case
kNumberTypeUInt8
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_UINT8
);
case
kNumberTypeInt16
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_INT16
);
case
kNumberTypeUInt16
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_UINT16
);
case
kNumberTypeInt32
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_INT32
);
case
kNumberTypeUInt32
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_UINT32
);
case
kNumberTypeInt64
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_INT64
);
case
kNumberTypeUInt64
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_UINT64
);
case
kNumberTypeFloat16
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_FLOAT16
);
case
kNumberTypeFloat32
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_FLOAT32
);
case
kNumberTypeFloat64
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_FLOAT64
);
default:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_UNKNOWN
);
}
}
TypeId
DETypeToMSType
(
dataset
::
DataType
data_type
)
{
switch
(
data_type
.
value
())
{
case
dataset
::
DataType
::
DE_BOOL
:
return
mindspore
::
TypeId
::
kNumberTypeBool
;
case
dataset
::
DataType
::
DE_INT8
:
return
mindspore
::
TypeId
::
kNumberTypeInt8
;
case
dataset
::
DataType
::
DE_UINT8
:
return
mindspore
::
TypeId
::
kNumberTypeUInt8
;
case
dataset
::
DataType
::
DE_INT16
:
return
mindspore
::
TypeId
::
kNumberTypeInt16
;
case
dataset
::
DataType
::
DE_UINT16
:
return
mindspore
::
TypeId
::
kNumberTypeUInt16
;
case
dataset
::
DataType
::
DE_INT32
:
return
mindspore
::
TypeId
::
kNumberTypeInt32
;
case
dataset
::
DataType
::
DE_UINT32
:
return
mindspore
::
TypeId
::
kNumberTypeUInt32
;
case
dataset
::
DataType
::
DE_INT64
:
return
mindspore
::
TypeId
::
kNumberTypeInt64
;
case
dataset
::
DataType
::
DE_UINT64
:
return
mindspore
::
TypeId
::
kNumberTypeUInt64
;
case
dataset
::
DataType
::
DE_FLOAT16
:
return
mindspore
::
TypeId
::
kNumberTypeFloat16
;
case
dataset
::
DataType
::
DE_FLOAT32
:
return
mindspore
::
TypeId
::
kNumberTypeFloat32
;
case
dataset
::
DataType
::
DE_FLOAT64
:
return
mindspore
::
TypeId
::
kNumberTypeFloat64
;
default:
return
kTypeUnknown
;
}
}
MSTensor
*
DETensor
::
CreateTensor
(
TypeId
data_type
,
const
std
::
vector
<
int
>
&
shape
)
{
MSTensor
*
DETensor
::
CreateTensor
(
TypeId
data_type
,
const
std
::
vector
<
int
>
&
shape
)
{
return
new
DETensor
(
data_type
,
shape
);
return
new
DETensor
(
data_type
,
shape
);
}
}
...
@@ -100,7 +39,7 @@ DETensor::DETensor(TypeId data_type, const std::vector<int> &shape) {
...
@@ -100,7 +39,7 @@ DETensor::DETensor(TypeId data_type, const std::vector<int> &shape) {
t_shape
.
reserve
(
shape
.
size
());
t_shape
.
reserve
(
shape
.
size
());
std
::
transform
(
shape
.
begin
(),
shape
.
end
(),
std
::
back_inserter
(
t_shape
),
std
::
transform
(
shape
.
begin
(),
shape
.
end
(),
std
::
back_inserter
(
t_shape
),
[](
int
s
)
->
dataset
::
dsize_t
{
return
static_cast
<
dataset
::
dsize_t
>
(
s
);
});
[](
int
s
)
->
dataset
::
dsize_t
{
return
static_cast
<
dataset
::
dsize_t
>
(
s
);
});
dataset
::
Tensor
::
CreateEmpty
(
dataset
::
TensorShape
(
t_shape
),
MSTypeToDEType
(
data_type
),
&
this
->
tensor_impl_
);
dataset
::
Tensor
::
CreateEmpty
(
dataset
::
TensorShape
(
t_shape
),
dataset
::
MSTypeToDEType
(
data_type
),
&
this
->
tensor_impl_
);
}
}
DETensor
::
DETensor
(
std
::
shared_ptr
<
dataset
::
Tensor
>
tensor_ptr
)
{
this
->
tensor_impl_
=
std
::
move
(
tensor_ptr
);
}
DETensor
::
DETensor
(
std
::
shared_ptr
<
dataset
::
Tensor
>
tensor_ptr
)
{
this
->
tensor_impl_
=
std
::
move
(
tensor_ptr
);
}
...
@@ -120,14 +59,14 @@ std::shared_ptr<dataset::Tensor> DETensor::tensor() const {
...
@@ -120,14 +59,14 @@ std::shared_ptr<dataset::Tensor> DETensor::tensor() const {
TypeId
DETensor
::
data_type
()
const
{
TypeId
DETensor
::
data_type
()
const
{
MS_ASSERT
(
this
->
tensor_impl_
!=
nullptr
);
MS_ASSERT
(
this
->
tensor_impl_
!=
nullptr
);
return
DETypeToMSType
(
this
->
tensor_impl_
->
type
());
return
dataset
::
DETypeToMSType
(
this
->
tensor_impl_
->
type
());
}
}
TypeId
DETensor
::
set_data_type
(
TypeId
data_type
)
{
TypeId
DETensor
::
set_data_type
(
TypeId
data_type
)
{
MS_ASSERT
(
this
->
tensor_impl_
!=
nullptr
);
MS_ASSERT
(
this
->
tensor_impl_
!=
nullptr
);
if
(
data_type
!=
this
->
data_type
())
{
if
(
data_type
!=
this
->
data_type
())
{
std
::
shared_ptr
<
dataset
::
Tensor
>
temp
;
std
::
shared_ptr
<
dataset
::
Tensor
>
temp
;
dataset
::
Tensor
::
CreateFromMemory
(
this
->
tensor_impl_
->
shape
(),
MSTypeToDEType
(
data_type
),
dataset
::
Tensor
::
CreateFromMemory
(
this
->
tensor_impl_
->
shape
(),
dataset
::
MSTypeToDEType
(
data_type
),
this
->
tensor_impl_
->
GetBuffer
(),
&
temp
);
this
->
tensor_impl_
->
GetBuffer
(),
&
temp
);
this
->
tensor_impl_
=
temp
;
this
->
tensor_impl_
=
temp
;
}
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc
浏览文件 @
b91e5637
...
@@ -50,13 +50,6 @@ Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) {
...
@@ -50,13 +50,6 @@ Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) {
std
::
make_shared
<
RandomDataOp
>
(
builder_num_workers_
,
builder_op_connector_size_
,
builder_rows_per_buffer_
,
std
::
make_shared
<
RandomDataOp
>
(
builder_num_workers_
,
builder_op_connector_size_
,
builder_rows_per_buffer_
,
builder_total_rows_
,
std
::
move
(
builder_data_schema_
),
std
::
move
(
builder_sampler_
));
builder_total_rows_
,
std
::
move
(
builder_data_schema_
),
std
::
move
(
builder_sampler_
));
// If the user did not provide a schema, then we will ask the op to generate a pseudo-random
// schema.
// See details of generateSchema function to learn what type of schema it will create.
if
((
*
out_op
)
->
data_schema_
==
nullptr
)
{
RETURN_IF_NOT_OK
((
*
out_op
)
->
GenerateSchema
());
}
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
@@ -85,6 +78,12 @@ RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64
...
@@ -85,6 +78,12 @@ RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64
if
(
total_rows_
==
0
)
{
if
(
total_rows_
==
0
)
{
total_rows_
=
GenRandomInt
(
1
,
kMaxTotalRows
);
total_rows_
=
GenRandomInt
(
1
,
kMaxTotalRows
);
}
}
// If the user did not provide a schema, then we will ask the op to generate a pseudo-random
// schema.
// See details of generateSchema function to learn what type of schema it will create.
if
(
data_schema_
==
nullptr
)
{
GenerateSchema
();
}
// Everyone is already out from the sync area.
// Everyone is already out from the sync area.
all_out_
.
Set
();
all_out_
.
Set
();
}
}
...
@@ -106,11 +105,7 @@ void RandomDataOp::Print(std::ostream &out, bool show_all) const {
...
@@ -106,11 +105,7 @@ void RandomDataOp::Print(std::ostream &out, bool show_all) const {
}
}
// Helper function to produce a default/random schema if one didn't exist
// Helper function to produce a default/random schema if one didn't exist
Status
RandomDataOp
::
GenerateSchema
()
{
void
RandomDataOp
::
GenerateSchema
()
{
if
(
data_schema_
!=
nullptr
)
{
return
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
"Generating a schema but one already exists!"
);
}
// To randomly create a schema, we need to choose:
// To randomly create a schema, we need to choose:
// a) how many columns
// a) how many columns
// b) the type of each column
// b) the type of each column
...
@@ -144,8 +139,6 @@ Status RandomDataOp::GenerateSchema() {
...
@@ -144,8 +139,6 @@ Status RandomDataOp::GenerateSchema() {
data_schema_
->
AddColumn
(
*
newCol
);
data_schema_
->
AddColumn
(
*
newCol
);
}
}
return
Status
::
OK
();
}
}
// Class functor operator () override.
// Class functor operator () override.
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h
浏览文件 @
b91e5637
...
@@ -213,9 +213,8 @@ class RandomDataOp : public ParallelOp {
...
@@ -213,9 +213,8 @@ class RandomDataOp : public ParallelOp {
/**
/**
* Helper function to produce a default/random schema if one didn't exist
* Helper function to produce a default/random schema if one didn't exist
@return Status - The error code return
*/
*/
void
GenerateSchema
();
Status
GenerateSchema
();
/**
/**
* Performs a synchronization between workers at the end of an epoch
* Performs a synchronization between workers at the end of an epoch
...
...
mindspore/ccsrc/minddata/dataset/include/datasets.h
浏览文件 @
b91e5637
...
@@ -24,9 +24,11 @@
...
@@ -24,9 +24,11 @@
#include <utility>
#include <utility>
#include <string>
#include <string>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/include/iterator.h"
#include "minddata/dataset/include/iterator.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/type_id.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
dataset
{
namespace
dataset
{
...
@@ -40,6 +42,7 @@ class TensorShape;
...
@@ -40,6 +42,7 @@ class TensorShape;
namespace
api
{
namespace
api
{
class
TensorOperation
;
class
TensorOperation
;
class
SchemaObj
;
class
SamplerObj
;
class
SamplerObj
;
// Datasets classes (in alphabetical order)
// Datasets classes (in alphabetical order)
class
CelebADataset
;
class
CelebADataset
;
...
@@ -49,6 +52,7 @@ class CLUEDataset;
...
@@ -49,6 +52,7 @@ class CLUEDataset;
class
CocoDataset
;
class
CocoDataset
;
class
ImageFolderDataset
;
class
ImageFolderDataset
;
class
MnistDataset
;
class
MnistDataset
;
class
RandomDataset
;
class
TextFileDataset
;
class
TextFileDataset
;
class
VOCDataset
;
class
VOCDataset
;
// Dataset Op classes (in alphabetical order)
// Dataset Op classes (in alphabetical order)
...
@@ -63,6 +67,11 @@ class SkipDataset;
...
@@ -63,6 +67,11 @@ class SkipDataset;
class
TakeDataset
;
class
TakeDataset
;
class
ZipDataset
;
class
ZipDataset
;
/// \brief Function to create a SchemaObj
/// \param[in] schema_file Path of schema file
/// \return Shared pointer to the current schema
std
::
shared_ptr
<
SchemaObj
>
Schema
(
const
std
::
string
&
schema_file
=
""
);
/// \brief Function to create a CelebADataset
/// \brief Function to create a CelebADataset
/// \notes The generated dataset has two columns ['image', 'attr'].
/// \notes The generated dataset has two columns ['image', 'attr'].
// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
...
@@ -167,6 +176,21 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
...
@@ -167,6 +176,21 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
std
::
shared_ptr
<
ConcatDataset
>
operator
+
(
const
std
::
shared_ptr
<
Dataset
>
&
datasets1
,
std
::
shared_ptr
<
ConcatDataset
>
operator
+
(
const
std
::
shared_ptr
<
Dataset
>
&
datasets1
,
const
std
::
shared_ptr
<
Dataset
>
&
datasets2
);
const
std
::
shared_ptr
<
Dataset
>
&
datasets2
);
/// \brief Function to create a RandomDataset
/// \param[in] total_rows Number of rows for the dataset to generate (default=0, number of rows is random)
/// \param[in] schema SchemaObj to set column type, data type and data shape
/// \param[in] columns_list List of columns to be read (default=None, read all columns)
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
template
<
typename
T
=
std
::
shared_ptr
<
SchemaObj
>
>
std
::
shared_ptr
<
RandomDataset
>
RandomData
(
const
int32_t
&
total_rows
=
0
,
T
schema
=
nullptr
,
std
::
vector
<
std
::
string
>
columns_list
=
{},
std
::
shared_ptr
<
SamplerObj
>
sampler
=
nullptr
)
{
auto
ds
=
std
::
make_shared
<
RandomDataset
>
(
total_rows
,
schema
,
std
::
move
(
columns_list
),
std
::
move
(
sampler
));
return
ds
->
ValidateParams
()
?
ds
:
nullptr
;
}
/// \brief Function to create a TextFileDataset
/// \brief Function to create a TextFileDataset
/// \notes The generated dataset has one column ['text']
/// \notes The generated dataset has one column ['text']
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
...
@@ -335,6 +359,66 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
...
@@ -335,6 +359,66 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
int32_t
worker_connector_size_
;
int32_t
worker_connector_size_
;
};
};
class
SchemaObj
{
public:
/// \brief Constructor
explicit
SchemaObj
(
const
std
::
string
&
schema_file
=
""
);
/// \brief Destructor
~
SchemaObj
()
=
default
;
/// \brief SchemaObj init function
/// \return bool true if schema init success
bool
init
();
/// \brief Add new column to the schema
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(TypeId).
/// \param[in] shape shape of the column.
/// \return bool true if schema init success
bool
add_column
(
std
::
string
name
,
TypeId
de_type
,
std
::
vector
<
int32_t
>
shape
);
/// \brief Add new column to the schema
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(std::string).
/// \param[in] shape shape of the column.
/// \return bool true if schema init success
bool
add_column
(
std
::
string
name
,
std
::
string
de_type
,
std
::
vector
<
int32_t
>
shape
);
/// \brief Get a JSON string of the schema
/// \return JSON string of the schema
std
::
string
to_json
();
/// \brief Get a JSON string of the schema
std
::
string
to_string
()
{
return
to_json
();
}
/// \brief set a new value to dataset_type
inline
void
set_dataset_type
(
std
::
string
dataset_type
)
{
dataset_type_
=
dataset_type
;
}
/// \brief set a new value to num_rows
inline
void
set_num_rows
(
int32_t
num_rows
)
{
num_rows_
=
num_rows
;
}
/// \brief get the current num_rows
inline
int32_t
get_num_rows
()
{
return
num_rows_
;
}
private:
/// \brief Parse the columns and add it to columns
/// \param[in] columns dataset attribution information, decoded from schema file.
/// support both nlohmann::json::value_t::array and nlohmann::json::value_t::onject.
/// \return JSON string of the schema
bool
parse_column
(
nlohmann
::
json
columns
);
/// \brief Get schema file from json file
/// \param[in] json_obj object of json parsed.
/// \return bool true if json dump success
bool
from_json
(
nlohmann
::
json
json_obj
);
int32_t
num_rows_
;
std
::
string
dataset_type_
;
std
::
string
schema_file_
;
nlohmann
::
json
columns_
;
};
/* ####################################### Derived Dataset classes ################################# */
/* ####################################### Derived Dataset classes ################################# */
// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
...
@@ -517,6 +601,53 @@ class MnistDataset : public Dataset {
...
@@ -517,6 +601,53 @@ class MnistDataset : public Dataset {
std
::
shared_ptr
<
SamplerObj
>
sampler_
;
std
::
shared_ptr
<
SamplerObj
>
sampler_
;
};
};
class
RandomDataset
:
public
Dataset
{
public:
// Some constants to provide limits to random generation.
static
constexpr
int32_t
kMaxNumColumns
=
4
;
static
constexpr
int32_t
kMaxRank
=
4
;
static
constexpr
int32_t
kMaxDimValue
=
32
;
/// \brief Constructor
RandomDataset
(
const
int32_t
&
total_rows
,
std
::
shared_ptr
<
SchemaObj
>
schema
,
std
::
vector
<
std
::
string
>
columns_list
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
:
total_rows_
(
total_rows
),
schema_path_
(
""
),
schema_
(
std
::
move
(
schema
)),
columns_list_
(
columns_list
),
sampler_
(
std
::
move
(
sampler
))
{}
/// \brief Constructor
RandomDataset
(
const
int32_t
&
total_rows
,
std
::
string
schema_path
,
std
::
vector
<
std
::
string
>
columns_list
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
:
total_rows_
(
total_rows
),
schema_path_
(
schema_path
),
columns_list_
(
columns_list
),
sampler_
(
std
::
move
(
sampler
))
{}
/// \brief Destructor
~
RandomDataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Build
()
override
;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool
ValidateParams
()
override
;
private:
/// \brief A quick inline for producing a random number between (and including) min/max
/// \param[in] min minimum number that can be generated.
/// \param[in] max maximum number that can be generated.
/// \return The generated random number
int32_t
GenRandomInt
(
int32_t
min
,
int32_t
max
);
int32_t
total_rows_
;
std
::
string
schema_path_
;
std
::
shared_ptr
<
SchemaObj
>
schema_
;
std
::
vector
<
std
::
string
>
columns_list_
;
std
::
shared_ptr
<
SamplerObj
>
sampler_
;
std
::
mt19937
rand_gen_
;
};
/// \class TextFileDataset
/// \class TextFileDataset
/// \brief A Dataset derived class to represent TextFile dataset
/// \brief A Dataset derived class to represent TextFile dataset
class
TextFileDataset
:
public
Dataset
{
class
TextFileDataset
:
public
Dataset
{
...
...
mindspore/ccsrc/minddata/dataset/include/type_id.h
0 → 100644
浏览文件 @
b91e5637
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TYPEID_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TYPEID_H_
#include "minddata/dataset/core/data_type.h"
#include "mindspore/core/ir/dtype/type_id.h"
namespace
mindspore
{
namespace
dataset
{
inline
dataset
::
DataType
MSTypeToDEType
(
TypeId
data_type
)
{
switch
(
data_type
)
{
case
kNumberTypeBool
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_BOOL
);
case
kNumberTypeInt8
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_INT8
);
case
kNumberTypeUInt8
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_UINT8
);
case
kNumberTypeInt16
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_INT16
);
case
kNumberTypeUInt16
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_UINT16
);
case
kNumberTypeInt32
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_INT32
);
case
kNumberTypeUInt32
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_UINT32
);
case
kNumberTypeInt64
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_INT64
);
case
kNumberTypeUInt64
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_UINT64
);
case
kNumberTypeFloat16
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_FLOAT16
);
case
kNumberTypeFloat32
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_FLOAT32
);
case
kNumberTypeFloat64
:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_FLOAT64
);
default:
return
dataset
::
DataType
(
dataset
::
DataType
::
DE_UNKNOWN
);
}
}
inline
TypeId
DETypeToMSType
(
dataset
::
DataType
data_type
)
{
switch
(
data_type
.
value
())
{
case
dataset
::
DataType
::
DE_BOOL
:
return
mindspore
::
TypeId
::
kNumberTypeBool
;
case
dataset
::
DataType
::
DE_INT8
:
return
mindspore
::
TypeId
::
kNumberTypeInt8
;
case
dataset
::
DataType
::
DE_UINT8
:
return
mindspore
::
TypeId
::
kNumberTypeUInt8
;
case
dataset
::
DataType
::
DE_INT16
:
return
mindspore
::
TypeId
::
kNumberTypeInt16
;
case
dataset
::
DataType
::
DE_UINT16
:
return
mindspore
::
TypeId
::
kNumberTypeUInt16
;
case
dataset
::
DataType
::
DE_INT32
:
return
mindspore
::
TypeId
::
kNumberTypeInt32
;
case
dataset
::
DataType
::
DE_UINT32
:
return
mindspore
::
TypeId
::
kNumberTypeUInt32
;
case
dataset
::
DataType
::
DE_INT64
:
return
mindspore
::
TypeId
::
kNumberTypeInt64
;
case
dataset
::
DataType
::
DE_UINT64
:
return
mindspore
::
TypeId
::
kNumberTypeUInt64
;
case
dataset
::
DataType
::
DE_FLOAT16
:
return
mindspore
::
TypeId
::
kNumberTypeFloat16
;
case
dataset
::
DataType
::
DE_FLOAT32
:
return
mindspore
::
TypeId
::
kNumberTypeFloat32
;
case
dataset
::
DataType
::
DE_FLOAT64
:
return
mindspore
::
TypeId
::
kNumberTypeFloat64
;
default:
return
kTypeUnknown
;
}
}
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TYPEID_H_
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
b91e5637
...
@@ -100,6 +100,7 @@ SET(DE_UT_SRCS
...
@@ -100,6 +100,7 @@ SET(DE_UT_SRCS
c_api_dataset_clue_test.cc
c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc
c_api_dataset_coco_test.cc
c_api_dataset_filetext_test.cc
c_api_dataset_filetext_test.cc
c_api_dataset_randomdata_test.cc
c_api_dataset_voc_test.cc
c_api_dataset_voc_test.cc
c_api_datasets_test.cc
c_api_datasets_test.cc
c_api_dataset_iterator_test.cc
c_api_dataset_iterator_test.cc
...
...
tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc
0 → 100644
浏览文件 @
b91e5637
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h"
#include "mindspore/core/ir/dtype/type_id.h"
using
namespace
mindspore
::
dataset
;
using
namespace
mindspore
::
dataset
::
api
;
using
mindspore
::
dataset
::
Tensor
;
using
mindspore
::
dataset
::
TensorShape
;
using
mindspore
::
dataset
::
DataType
;
class
MindDataTestPipeline
:
public
UT
::
DatasetOpTesting
{
protected:
};
TEST_F
(
MindDataTestPipeline
,
TestRandomDatasetBasic1
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestRandomDatasetBasic1."
;
// Create a RandomDataset
std
::
shared_ptr
<
SchemaObj
>
schema
=
Schema
();
schema
->
add_column
(
"image"
,
mindspore
::
TypeId
::
kNumberTypeUInt8
,
{
2
});
schema
->
add_column
(
"label"
,
mindspore
::
TypeId
::
kNumberTypeUInt8
,
{
1
});
std
::
shared_ptr
<
Dataset
>
ds
=
RandomData
(
50
,
schema
);
EXPECT_NE
(
ds
,
nullptr
);
ds
=
ds
->
SetNumWorkers
(
4
);
EXPECT_NE
(
ds
,
nullptr
);
// Create a Repeat operation on ds
ds
=
ds
->
Repeat
(
4
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
();
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
// Check if RandomDataOp read correct columns
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
auto
image
=
row
[
"image"
];
auto
label
=
row
[
"label"
];
MS_LOG
(
INFO
)
<<
"Tensor image shape: "
<<
image
->
shape
();
MS_LOG
(
INFO
)
<<
"Tensor label shape: "
<<
label
->
shape
();
iter
->
GetNextRow
(
&
row
);
i
++
;
}
EXPECT_EQ
(
i
,
200
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestRandomDatasetBasic2
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestRandomDatasetBasic2."
;
// Create a RandomDataset
std
::
shared_ptr
<
Dataset
>
ds
=
RandomData
(
10
);
EXPECT_NE
(
ds
,
nullptr
);
ds
=
ds
->
SetNumWorkers
(
1
);
EXPECT_NE
(
ds
,
nullptr
);
// Create a Repeat operation on ds
ds
=
ds
->
Repeat
(
2
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
();
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
// Check if RandomDataOp read correct columns
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
auto
image
=
row
[
"image"
];
auto
label
=
row
[
"label"
];
MS_LOG
(
INFO
)
<<
"Tensor image shape: "
<<
image
->
shape
();
MS_LOG
(
INFO
)
<<
"Tensor label shape: "
<<
label
->
shape
();
iter
->
GetNextRow
(
&
row
);
i
++
;
}
EXPECT_EQ
(
i
,
20
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestRandomDatasetBasic3
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestRandomDatasetBasic3."
;
// Create a RandomDataset
u_int32_t
curr_seed
=
GlobalContext
::
config_manager
()
->
seed
();
GlobalContext
::
config_manager
()
->
set_seed
(
246
);
std
::
string
SCHEMA_FILE
=
datasets_root_path_
+
"/testTFTestAllTypes/datasetSchema.json"
;
std
::
shared_ptr
<
SchemaObj
>
schema
=
Schema
(
SCHEMA_FILE
);
std
::
shared_ptr
<
Dataset
>
ds
=
RandomData
(
0
,
schema
);
EXPECT_NE
(
ds
,
nullptr
);
// Create a Repeat operation on ds
ds
=
ds
->
Repeat
(
2
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
();
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
// Check if RandomDataOp read correct columns
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
auto
col_sint16
=
row
[
"col_sint16"
];
auto
col_sint32
=
row
[
"col_sint32"
];
auto
col_sint64
=
row
[
"col_sint64"
];
auto
col_float
=
row
[
"col_float"
];
auto
col_1d
=
row
[
"col_1d"
];
auto
col_2d
=
row
[
"col_2d"
];
auto
col_3d
=
row
[
"col_3d"
];
auto
col_binary
=
row
[
"col_binary"
];
// validate shape
ASSERT_EQ
(
col_sint16
->
shape
(),
TensorShape
({
1
}));
ASSERT_EQ
(
col_sint32
->
shape
(),
TensorShape
({
1
}));
ASSERT_EQ
(
col_sint64
->
shape
(),
TensorShape
({
1
}));
ASSERT_EQ
(
col_float
->
shape
(),
TensorShape
({
1
}));
ASSERT_EQ
(
col_1d
->
shape
(),
TensorShape
({
2
}));
ASSERT_EQ
(
col_2d
->
shape
(),
TensorShape
({
2
,
2
}));
ASSERT_EQ
(
col_3d
->
shape
(),
TensorShape
({
2
,
2
,
2
}));
ASSERT_EQ
(
col_binary
->
shape
(),
TensorShape
({
1
}));
// validate Rank
ASSERT_EQ
(
col_sint16
->
Rank
(),
1
);
ASSERT_EQ
(
col_sint32
->
Rank
(),
1
);
ASSERT_EQ
(
col_sint64
->
Rank
(),
1
);
ASSERT_EQ
(
col_float
->
Rank
(),
1
);
ASSERT_EQ
(
col_1d
->
Rank
(),
1
);
ASSERT_EQ
(
col_2d
->
Rank
(),
2
);
ASSERT_EQ
(
col_3d
->
Rank
(),
3
);
ASSERT_EQ
(
col_binary
->
Rank
(),
1
);
// validate type
ASSERT_EQ
(
col_sint16
->
type
(),
DataType
::
DE_INT16
);
ASSERT_EQ
(
col_sint32
->
type
(),
DataType
::
DE_INT32
);
ASSERT_EQ
(
col_sint64
->
type
(),
DataType
::
DE_INT64
);
ASSERT_EQ
(
col_float
->
type
(),
DataType
::
DE_FLOAT32
);
ASSERT_EQ
(
col_1d
->
type
(),
DataType
::
DE_INT64
);
ASSERT_EQ
(
col_2d
->
type
(),
DataType
::
DE_INT64
);
ASSERT_EQ
(
col_3d
->
type
(),
DataType
::
DE_INT64
);
ASSERT_EQ
(
col_binary
->
type
(),
DataType
::
DE_UINT8
);
iter
->
GetNextRow
(
&
row
);
i
++
;
}
EXPECT_EQ
(
i
,
984
);
// Manually terminate the pipeline
iter
->
Stop
();
GlobalContext
::
config_manager
()
->
set_seed
(
curr_seed
);
}
TEST_F
(
MindDataTestPipeline
,
TestRandomDatasetBasic4
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestRandomDatasetBasic3."
;
// Create a RandomDataset
u_int32_t
curr_seed
=
GlobalContext
::
config_manager
()
->
seed
();
GlobalContext
::
config_manager
()
->
set_seed
(
246
);
std
::
string
SCHEMA_FILE
=
datasets_root_path_
+
"/testTFTestAllTypes/datasetSchema.json"
;
std
::
shared_ptr
<
Dataset
>
ds
=
RandomData
(
0
,
SCHEMA_FILE
);
EXPECT_NE
(
ds
,
nullptr
);
// Create a Repeat operation on ds
ds
=
ds
->
Repeat
(
2
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
();
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
// Check if RandomDataOp read correct columns
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
auto
col_sint16
=
row
[
"col_sint16"
];
auto
col_sint32
=
row
[
"col_sint32"
];
auto
col_sint64
=
row
[
"col_sint64"
];
auto
col_float
=
row
[
"col_float"
];
auto
col_1d
=
row
[
"col_1d"
];
auto
col_2d
=
row
[
"col_2d"
];
auto
col_3d
=
row
[
"col_3d"
];
auto
col_binary
=
row
[
"col_binary"
];
// validate shape
ASSERT_EQ
(
col_sint16
->
shape
(),
TensorShape
({
1
}));
ASSERT_EQ
(
col_sint32
->
shape
(),
TensorShape
({
1
}));
ASSERT_EQ
(
col_sint64
->
shape
(),
TensorShape
({
1
}));
ASSERT_EQ
(
col_float
->
shape
(),
TensorShape
({
1
}));
ASSERT_EQ
(
col_1d
->
shape
(),
TensorShape
({
2
}));
ASSERT_EQ
(
col_2d
->
shape
(),
TensorShape
({
2
,
2
}));
ASSERT_EQ
(
col_3d
->
shape
(),
TensorShape
({
2
,
2
,
2
}));
ASSERT_EQ
(
col_binary
->
shape
(),
TensorShape
({
1
}));
// validate Rank
ASSERT_EQ
(
col_sint16
->
Rank
(),
1
);
ASSERT_EQ
(
col_sint32
->
Rank
(),
1
);
ASSERT_EQ
(
col_sint64
->
Rank
(),
1
);
ASSERT_EQ
(
col_float
->
Rank
(),
1
);
ASSERT_EQ
(
col_1d
->
Rank
(),
1
);
ASSERT_EQ
(
col_2d
->
Rank
(),
2
);
ASSERT_EQ
(
col_3d
->
Rank
(),
3
);
ASSERT_EQ
(
col_binary
->
Rank
(),
1
);
// validate type
ASSERT_EQ
(
col_sint16
->
type
(),
DataType
::
DE_INT16
);
ASSERT_EQ
(
col_sint32
->
type
(),
DataType
::
DE_INT32
);
ASSERT_EQ
(
col_sint64
->
type
(),
DataType
::
DE_INT64
);
ASSERT_EQ
(
col_float
->
type
(),
DataType
::
DE_FLOAT32
);
ASSERT_EQ
(
col_1d
->
type
(),
DataType
::
DE_INT64
);
ASSERT_EQ
(
col_2d
->
type
(),
DataType
::
DE_INT64
);
ASSERT_EQ
(
col_3d
->
type
(),
DataType
::
DE_INT64
);
ASSERT_EQ
(
col_binary
->
type
(),
DataType
::
DE_UINT8
);
iter
->
GetNextRow
(
&
row
);
i
++
;
}
EXPECT_EQ
(
i
,
984
);
// Manually terminate the pipeline
iter
->
Stop
();
GlobalContext
::
config_manager
()
->
set_seed
(
curr_seed
);
}
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录