Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
81005a30
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
81005a30
编写于
7月 24, 2020
作者:
C
Cathy Wong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
C++ API: Reorder code contents alphabetically
上级
e07f7436
变更
3
展开全部
隐藏空白更改
内联
并排
Showing
3 changed file
with
469 addition
and
447 deletion
+469
-447
mindspore/ccsrc/minddata/dataset/api/datasets.cc
mindspore/ccsrc/minddata/dataset/api/datasets.cc
+162
-143
mindspore/ccsrc/minddata/dataset/api/transforms.cc
mindspore/ccsrc/minddata/dataset/api/transforms.cc
+213
-212
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/datasets.h
+94
-92
未找到文件。
mindspore/ccsrc/minddata/dataset/api/datasets.cc
浏览文件 @
81005a30
...
...
@@ -17,12 +17,14 @@
#include <fstream>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/engine/dataset_iterator.h"
// Source dataset headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
// Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/map_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
...
...
@@ -31,6 +33,7 @@
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
// Sampler headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
...
...
@@ -79,6 +82,18 @@ Dataset::Dataset() {
connector_que_size_
=
cfg
->
op_connector_size
();
}
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
// (In alphabetical order)
// Function to create a Cifar10Dataset.
std
::
shared_ptr
<
Cifar10Dataset
>
Cifar10
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
{
auto
ds
=
std
::
make_shared
<
Cifar10Dataset
>
(
dataset_dir
,
num_samples
,
sampler
);
// Call derived class validation method.
return
ds
->
ValidateParams
()
?
ds
:
nullptr
;
}
// Function to create a ImageFolderDataset.
std
::
shared_ptr
<
ImageFolderDataset
>
ImageFolder
(
std
::
string
dataset_dir
,
bool
decode
,
std
::
shared_ptr
<
SamplerObj
>
sampler
,
std
::
set
<
std
::
string
>
extensions
,
...
...
@@ -101,14 +116,8 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
return
ds
->
ValidateParams
()
?
ds
:
nullptr
;
}
// Function to create a Cifar10Dataset.
std
::
shared_ptr
<
Cifar10Dataset
>
Cifar10
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
{
auto
ds
=
std
::
make_shared
<
Cifar10Dataset
>
(
dataset_dir
,
num_samples
,
sampler
);
// Call derived class validation method.
return
ds
->
ValidateParams
()
?
ds
:
nullptr
;
}
// FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
// (In alphabetical order)
// Function to create a Batch dataset
std
::
shared_ptr
<
BatchDataset
>
Dataset
::
Batch
(
int32_t
batch_size
,
bool
drop_remainder
)
{
...
...
@@ -127,14 +136,12 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
return
ds
;
}
// Function to create Repeat dataset.
std
::
shared_ptr
<
Dataset
>
Dataset
::
Repeat
(
int32_t
count
)
{
// Workaround for repeat == 1, do not inject repeat.
if
(
count
==
1
)
{
return
shared_from_this
();
}
auto
ds
=
std
::
make_shared
<
RepeatDataset
>
(
count
);
// Function to create a Map dataset.
std
::
shared_ptr
<
MapDataset
>
Dataset
::
Map
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations
,
std
::
vector
<
std
::
string
>
input_columns
,
std
::
vector
<
std
::
string
>
output_columns
,
const
std
::
vector
<
std
::
string
>
&
project_columns
)
{
auto
ds
=
std
::
make_shared
<
MapDataset
>
(
operations
,
input_columns
,
output_columns
,
project_columns
);
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -145,13 +152,10 @@ std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
return
ds
;
}
// Function to create a Map dataset.
std
::
shared_ptr
<
MapDataset
>
Dataset
::
Map
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations
,
std
::
vector
<
std
::
string
>
input_columns
,
std
::
vector
<
std
::
string
>
output_columns
,
const
std
::
vector
<
std
::
string
>
&
project_columns
)
{
auto
ds
=
std
::
make_shared
<
MapDataset
>
(
operations
,
input_columns
,
output_columns
,
project_columns
);
// Function to create a ProjectDataset.
std
::
shared_ptr
<
ProjectDataset
>
Dataset
::
Project
(
const
std
::
vector
<
std
::
string
>
&
columns
)
{
auto
ds
=
std
::
make_shared
<
ProjectDataset
>
(
columns
);
// Call derived class validation method.
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
}
...
...
@@ -161,11 +165,11 @@ std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOpera
return
ds
;
}
// Function to create a
ShuffleOp
std
::
shared_ptr
<
ShuffleDataset
>
Dataset
::
Shuffle
(
int32_t
shuffle_size
)
{
// Pass in reshuffle_each_epoch with true
auto
ds
=
std
::
make_shared
<
ShuffleDataset
>
(
shuffle_size
,
true
);
// Function to create a
RenameDataset.
std
::
shared_ptr
<
RenameDataset
>
Dataset
::
Rename
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
)
{
auto
ds
=
std
::
make_shared
<
RenameDataset
>
(
input_columns
,
output_columns
);
// Call derived class validation method.
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
}
...
...
@@ -175,11 +179,15 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
return
ds
;
}
// Function to create a SkipDataset.
std
::
shared_ptr
<
SkipDataset
>
Dataset
::
Skip
(
int32_t
count
)
{
auto
ds
=
std
::
make_shared
<
SkipDataset
>
(
count
);
// Function to create Repeat dataset.
std
::
shared_ptr
<
Dataset
>
Dataset
::
Repeat
(
int32_t
count
)
{
// Workaround for repeat == 1, do not inject repeat.
if
(
count
==
1
)
{
return
shared_from_this
();
}
auto
ds
=
std
::
make_shared
<
RepeatDataset
>
(
count
);
// Call derived class validation method.
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
}
...
...
@@ -189,10 +197,11 @@ std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
return
ds
;
}
// Function to create a ProjectDataset.
std
::
shared_ptr
<
ProjectDataset
>
Dataset
::
Project
(
const
std
::
vector
<
std
::
string
>
&
columns
)
{
auto
ds
=
std
::
make_shared
<
ProjectDataset
>
(
columns
);
// Call derived class validation method.
// Function to create a ShuffleOp
std
::
shared_ptr
<
ShuffleDataset
>
Dataset
::
Shuffle
(
int32_t
shuffle_size
)
{
// Pass in reshuffle_each_epoch with true
auto
ds
=
std
::
make_shared
<
ShuffleDataset
>
(
shuffle_size
,
true
);
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
}
...
...
@@ -202,10 +211,10 @@ std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string>
return
ds
;
}
// Function to create a
Rename
Dataset.
std
::
shared_ptr
<
RenameDataset
>
Dataset
::
Rename
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
)
{
auto
ds
=
std
::
make_shared
<
RenameDataset
>
(
input_columns
,
output_columns
);
// Function to create a
Skip
Dataset.
std
::
shared_ptr
<
SkipDataset
>
Dataset
::
Skip
(
int32_t
count
)
{
auto
ds
=
std
::
make_shared
<
SkipDataset
>
(
count
);
// Call derived class validation method.
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -231,6 +240,9 @@ std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Datas
return
ds
;
}
// OTHER FUNCTIONS
// (In alphabetical order)
// Helper function to create default RandomSampler.
std
::
shared_ptr
<
SamplerObj
>
CreateDefaultSampler
()
{
const
int32_t
num_samples
=
0
;
// 0 means to sample all ids.
...
...
@@ -240,6 +252,48 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() {
/* ####################################### Derived Dataset classes ################################# */
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
// (In alphabetical order)
// Constructor for Cifar10Dataset
Cifar10Dataset
::
Cifar10Dataset
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
:
dataset_dir_
(
dataset_dir
),
num_samples_
(
num_samples
),
sampler_
(
sampler
)
{}
bool
Cifar10Dataset
::
ValidateParams
()
{
if
(
dataset_dir_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"No dataset path is specified."
;
return
false
;
}
if
(
num_samples_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Number of samples cannot be negative"
;
return
false
;
}
return
true
;
}
// Function to build CifarOp
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Cifar10Dataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
if
(
sampler_
==
nullptr
)
{
sampler_
=
CreateDefaultSampler
();
}
// Do internal Schema generation.
auto
schema
=
std
::
make_unique
<
DataSchema
>
();
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
"image"
,
DataType
(
DataType
::
DE_UINT8
),
TensorImpl
::
kCv
,
1
)));
TensorShape
scalar
=
TensorShape
::
CreateScalar
();
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
"label"
,
DataType
(
DataType
::
DE_UINT32
),
TensorImpl
::
kFlexible
,
0
,
&
scalar
)));
node_ops
.
push_back
(
std
::
make_shared
<
CifarOp
>
(
CifarOp
::
CifarType
::
kCifar10
,
num_workers_
,
rows_per_buffer_
,
dataset_dir_
,
connector_que_size_
,
std
::
move
(
schema
),
std
::
move
(
sampler_
->
Build
())));
return
node_ops
;
}
ImageFolderDataset
::
ImageFolderDataset
(
std
::
string
dataset_dir
,
bool
decode
,
std
::
shared_ptr
<
SamplerObj
>
sampler
,
bool
recursive
,
std
::
set
<
std
::
string
>
extensions
,
std
::
map
<
std
::
string
,
int32_t
>
class_indexing
)
...
...
@@ -315,6 +369,9 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
return
node_ops
;
}
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
// (In alphabetical order)
BatchDataset
::
BatchDataset
(
int32_t
batch_size
,
bool
drop_remainder
,
bool
pad
,
std
::
vector
<
std
::
string
>
cols_to_map
,
std
::
map
<
std
::
string
,
std
::
pair
<
TensorShape
,
std
::
shared_ptr
<
Tensor
>>>
pad_map
)
:
batch_size_
(
batch_size
),
...
...
@@ -347,24 +404,6 @@ bool BatchDataset::ValidateParams() {
return
true
;
}
RepeatDataset
::
RepeatDataset
(
uint32_t
count
)
:
repeat_count_
(
count
)
{}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
RepeatDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
RepeatOp
>
(
repeat_count_
));
return
node_ops
;
}
bool
RepeatDataset
::
ValidateParams
()
{
if
(
repeat_count_
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"Repeat: Repeat count cannot be negative"
;
return
false
;
}
return
true
;
}
MapDataset
::
MapDataset
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations
,
std
::
vector
<
std
::
string
>
input_columns
,
std
::
vector
<
std
::
string
>
output_columns
,
const
std
::
vector
<
std
::
string
>
&
project_columns
)
:
operations_
(
operations
),
...
...
@@ -409,6 +448,69 @@ bool MapDataset::ValidateParams() {
return
true
;
}
// Function to build ProjectOp
ProjectDataset
::
ProjectDataset
(
const
std
::
vector
<
std
::
string
>
&
columns
)
:
columns_
(
columns
)
{}
bool
ProjectDataset
::
ValidateParams
()
{
if
(
columns_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"No columns are specified."
;
return
false
;
}
return
true
;
}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
ProjectDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
ProjectOp
>
(
columns_
));
return
node_ops
;
}
// Function to build RenameOp
RenameDataset
::
RenameDataset
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
)
:
input_columns_
(
input_columns
),
output_columns_
(
output_columns
)
{}
bool
RenameDataset
::
ValidateParams
()
{
if
(
input_columns_
.
empty
()
||
output_columns_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"input and output columns must be specified"
;
return
false
;
}
if
(
input_columns_
.
size
()
!=
output_columns_
.
size
())
{
MS_LOG
(
ERROR
)
<<
"input and output columns must be the same size"
;
return
false
;
}
return
true
;
}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
RenameDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
RenameOp
>
(
input_columns_
,
output_columns_
,
connector_que_size_
));
return
node_ops
;
}
RepeatDataset
::
RepeatDataset
(
uint32_t
count
)
:
repeat_count_
(
count
)
{}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
RepeatDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
RepeatOp
>
(
repeat_count_
));
return
node_ops
;
}
bool
RepeatDataset
::
ValidateParams
()
{
if
(
repeat_count_
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"Repeat: Repeat count cannot be negative"
;
return
false
;
}
return
true
;
}
// Constructor for ShuffleDataset
ShuffleDataset
::
ShuffleDataset
(
int32_t
shuffle_size
,
bool
reset_every_epoch
)
:
shuffle_size_
(
shuffle_size
),
shuffle_seed_
(
GetSeed
()),
reset_every_epoch_
(
reset_every_epoch
)
{}
...
...
@@ -455,64 +557,6 @@ bool SkipDataset::ValidateParams() {
return
true
;
}
// Constructor for Cifar10Dataset
Cifar10Dataset
::
Cifar10Dataset
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
:
dataset_dir_
(
dataset_dir
),
num_samples_
(
num_samples
),
sampler_
(
sampler
)
{}
bool
Cifar10Dataset
::
ValidateParams
()
{
if
(
dataset_dir_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"No dataset path is specified."
;
return
false
;
}
if
(
num_samples_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Number of samples cannot be negative"
;
return
false
;
}
return
true
;
}
// Function to build CifarOp
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Cifar10Dataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
if
(
sampler_
==
nullptr
)
{
sampler_
=
CreateDefaultSampler
();
}
// Do internal Schema generation.
auto
schema
=
std
::
make_unique
<
DataSchema
>
();
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
"image"
,
DataType
(
DataType
::
DE_UINT8
),
TensorImpl
::
kCv
,
1
)));
TensorShape
scalar
=
TensorShape
::
CreateScalar
();
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
"label"
,
DataType
(
DataType
::
DE_UINT32
),
TensorImpl
::
kFlexible
,
0
,
&
scalar
)));
node_ops
.
push_back
(
std
::
make_shared
<
CifarOp
>
(
CifarOp
::
CifarType
::
kCifar10
,
num_workers_
,
rows_per_buffer_
,
dataset_dir_
,
connector_que_size_
,
std
::
move
(
schema
),
std
::
move
(
sampler_
->
Build
())));
return
node_ops
;
}
// Function to build ProjectOp
ProjectDataset
::
ProjectDataset
(
const
std
::
vector
<
std
::
string
>
&
columns
)
:
columns_
(
columns
)
{}
bool
ProjectDataset
::
ValidateParams
()
{
if
(
columns_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"No columns are specified."
;
return
false
;
}
return
true
;
}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
ProjectDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
ProjectOp
>
(
columns_
));
return
node_ops
;
}
// Function to build ZipOp
ZipDataset
::
ZipDataset
()
{}
...
...
@@ -526,31 +570,6 @@ std::vector<std::shared_ptr<DatasetOp>> ZipDataset::Build() {
return
node_ops
;
}
// Function to build RenameOp
RenameDataset
::
RenameDataset
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
)
:
input_columns_
(
input_columns
),
output_columns_
(
output_columns
)
{}
bool
RenameDataset
::
ValidateParams
()
{
if
(
input_columns_
.
empty
()
||
output_columns_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"input and output columns must be specified"
;
return
false
;
}
if
(
input_columns_
.
size
()
!=
output_columns_
.
size
())
{
MS_LOG
(
ERROR
)
<<
"input and output columns must be the same size"
;
return
false
;
}
return
true
;
}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
RenameDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
RenameOp
>
(
input_columns_
,
output_columns_
,
connector_que_size_
));
return
node_ops
;
}
}
// namespace api
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/api/transforms.cc
浏览文件 @
81005a30
此差异已折叠。
点击以展开。
mindspore/ccsrc/minddata/dataset/include/datasets.h
浏览文件 @
81005a30
...
...
@@ -40,17 +40,29 @@ namespace api {
class
TensorOperation
;
class
SamplerObj
;
// Datasets classes (in alphabetical order)
class
Cifar10Dataset
;
class
ImageFolderDataset
;
class
MnistDataset
;
// Dataset Op classes (in alphabetical order)
class
BatchDataset
;
class
RepeatDataset
;
class
MapDataset
;
class
ProjectDataset
;
class
RenameDataset
;
class
RepeatDataset
;
class
ShuffleDataset
;
class
SkipDataset
;
class
Cifar10Dataset
;
class
ProjectDataset
;
class
ZipDataset
;
class
RenameDataset
;
/// \brief Function to create a Cifar10 Dataset
/// \notes The generated dataset has two columns ['image', 'label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] num_samples The number of images to be included in the dataset
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
Cifar10Dataset
>
Cifar10
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
);
/// \brief Function to create an ImageFolderDataset
/// \notes A source dataset that reads images from a tree of directories
...
...
@@ -76,16 +88,6 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
/// \return Shared pointer to the current MnistDataset
std
::
shared_ptr
<
MnistDataset
>
Mnist
(
std
::
string
dataset_dir
,
std
::
shared_ptr
<
SamplerObj
>
sampler
=
nullptr
);
/// \brief Function to create a Cifar10 Dataset
/// \notes The generated dataset has two columns ['image', 'label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] num_samples The number of images to be included in the dataset
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
Cifar10Dataset
>
Cifar10
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
);
/// \class Dataset datasets.h
/// \brief A base class to represent a dataset in the data pipeline.
class
Dataset
:
public
std
::
enable_shared_from_this
<
Dataset
>
{
...
...
@@ -128,14 +130,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current BatchDataset
std
::
shared_ptr
<
BatchDataset
>
Batch
(
int32_t
batch_size
,
bool
drop_remainder
=
false
);
/// \brief Function to create a RepeatDataset
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
/// \param[in] count Number of times the dataset should be repeated
/// \return Shared pointer to the current Dataset
/// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
/// due to a limitation in the current implementation
std
::
shared_ptr
<
Dataset
>
Repeat
(
int32_t
count
=
-
1
);
/// \brief Function to create a MapDataset
/// \notes Applies each operation in operations to this dataset
/// \param[in] operations Vector of operations to be applied on the dataset. Operations are
...
...
@@ -156,6 +150,28 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
std
::
vector
<
std
::
string
>
output_columns
=
{},
const
std
::
vector
<
std
::
string
>
&
project_columns
=
{});
/// \brief Function to create a Project Dataset
/// \notes Applies project to the dataset
/// \param[in] columns The name of columns to project
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
ProjectDataset
>
Project
(
const
std
::
vector
<
std
::
string
>
&
columns
);
/// \brief Function to create a Rename Dataset
/// \notes Renames the columns in the input dataset
/// \param[in] input_columns List of the input columns to rename
/// \param[in] output_columns List of the output columns
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
RenameDataset
>
Rename
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
);
/// \brief Function to create a RepeatDataset
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
/// \param[in] count Number of times the dataset should be repeated
/// \return Shared pointer to the current Dataset
/// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
/// due to a limitation in the current implementation
std
::
shared_ptr
<
Dataset
>
Repeat
(
int32_t
count
=
-
1
);
/// \brief Function to create a Shuffle Dataset
/// \notes Randomly shuffles the rows of this dataset
/// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
...
...
@@ -168,26 +184,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current SkipDataset
std
::
shared_ptr
<
SkipDataset
>
Skip
(
int32_t
count
);
/// \brief Function to create a Project Dataset
/// \notes Applies project to the dataset
/// \param[in] columns The name of columns to project
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
ProjectDataset
>
Project
(
const
std
::
vector
<
std
::
string
>
&
columns
);
/// \brief Function to create a Zip Dataset
/// \notes Applies zip to the dataset
/// \param[in] datasets A list of shared pointer to the datasets that we want to zip
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
ZipDataset
>
Zip
(
const
std
::
vector
<
std
::
shared_ptr
<
Dataset
>>
&
datasets
);
/// \brief Function to create a Rename Dataset
/// \notes Renames the columns in the input dataset
/// \param[in] input_columns List of the input columns to rename
/// \param[in] output_columns List of the output columns
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
RenameDataset
>
Rename
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
);
protected:
std
::
vector
<
std
::
shared_ptr
<
Dataset
>>
children
;
std
::
shared_ptr
<
Dataset
>
parent
;
...
...
@@ -199,6 +201,28 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/* ####################################### Derived Dataset classes ################################# */
class
Cifar10Dataset
:
public
Dataset
{
public:
/// \brief Constructor
Cifar10Dataset
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
);
/// \brief Destructor
~
Cifar10Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Build
()
override
;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool
ValidateParams
()
override
;
private:
std
::
string
dataset_dir_
;
int32_t
num_samples_
;
std
::
shared_ptr
<
SamplerObj
>
sampler_
;
};
/// \class ImageFolderDataset
/// \brief A Dataset derived class to represent ImageFolder dataset
class
ImageFolderDataset
:
public
Dataset
{
...
...
@@ -273,13 +297,14 @@ class BatchDataset : public Dataset {
std
::
map
<
std
::
string
,
std
::
pair
<
TensorShape
,
std
::
shared_ptr
<
Tensor
>>>
pad_map_
;
};
class
Repeat
Dataset
:
public
Dataset
{
class
Map
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
explicit
RepeatDataset
(
uint32_t
count
);
MapDataset
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations
,
std
::
vector
<
std
::
string
>
input_columns
=
{},
std
::
vector
<
std
::
string
>
output_columns
=
{},
const
std
::
vector
<
std
::
string
>
&
columns
=
{});
/// \brief Destructor
~
Repeat
Dataset
()
=
default
;
~
Map
Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -290,32 +315,19 @@ class RepeatDataset : public Dataset {
bool
ValidateParams
()
override
;
private:
uint32_t
repeat_count_
;
};
class
ShuffleDataset
:
public
Dataset
{
public:
ShuffleDataset
(
int32_t
shuffle_size
,
bool
reset_every_epoch
);
~
ShuffleDataset
()
=
default
;
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Build
()
override
;
bool
ValidateParams
()
override
;
private:
int32_t
shuffle_size_
;
uint32_t
shuffle_seed_
;
bool
reset_every_epoch_
;
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations_
;
std
::
vector
<
std
::
string
>
input_columns_
;
std
::
vector
<
std
::
string
>
output_columns_
;
std
::
vector
<
std
::
string
>
project_columns_
;
};
class
Skip
Dataset
:
public
Dataset
{
class
Project
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
explicit
SkipDataset
(
int32_t
count
);
explicit
ProjectDataset
(
const
std
::
vector
<
std
::
string
>
&
columns
);
/// \brief Destructor
~
Skip
Dataset
()
=
default
;
~
Project
Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -326,17 +338,16 @@ class SkipDataset : public Dataset {
bool
ValidateParams
()
override
;
private:
int32_t
skip_count
_
;
std
::
vector
<
std
::
string
>
columns
_
;
};
class
Map
Dataset
:
public
Dataset
{
class
Rename
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
MapDataset
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations
,
std
::
vector
<
std
::
string
>
input_columns
=
{},
std
::
vector
<
std
::
string
>
output_columns
=
{},
const
std
::
vector
<
std
::
string
>
&
columns
=
{});
explicit
RenameDataset
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
);
/// \brief Destructor
~
Map
Dataset
()
=
default
;
~
Rename
Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -347,19 +358,17 @@ class MapDataset : public Dataset {
bool
ValidateParams
()
override
;
private:
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations_
;
std
::
vector
<
std
::
string
>
input_columns_
;
std
::
vector
<
std
::
string
>
output_columns_
;
std
::
vector
<
std
::
string
>
project_columns_
;
};
class
Cifar10
Dataset
:
public
Dataset
{
class
Repeat
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
Cifar10Dataset
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
);
explicit
RepeatDataset
(
uint32_t
count
);
/// \brief Destructor
~
Cifar10
Dataset
()
=
default
;
~
Repeat
Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -370,38 +379,32 @@ class Cifar10Dataset : public Dataset {
bool
ValidateParams
()
override
;
private:
std
::
string
dataset_dir_
;
int32_t
num_samples_
;
std
::
shared_ptr
<
SamplerObj
>
sampler_
;
uint32_t
repeat_count_
;
};
class
Project
Dataset
:
public
Dataset
{
class
Shuffle
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
explicit
ProjectDataset
(
const
std
::
vector
<
std
::
string
>
&
columns
);
ShuffleDataset
(
int32_t
shuffle_size
,
bool
reset_every_epoch
);
/// \brief Destructor
~
ProjectDataset
()
=
default
;
~
ShuffleDataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Build
()
override
;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool
ValidateParams
()
override
;
private:
std
::
vector
<
std
::
string
>
columns_
;
int32_t
shuffle_size_
;
uint32_t
shuffle_seed_
;
bool
reset_every_epoch_
;
};
class
Z
ipDataset
:
public
Dataset
{
class
Sk
ipDataset
:
public
Dataset
{
public:
/// \brief Constructor
ZipDataset
(
);
explicit
SkipDataset
(
int32_t
count
);
/// \brief Destructor
~
Z
ipDataset
()
=
default
;
~
Sk
ipDataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -410,15 +413,18 @@ class ZipDataset : public Dataset {
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool
ValidateParams
()
override
;
private:
int32_t
skip_count_
;
};
class
Rename
Dataset
:
public
Dataset
{
class
Zip
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
explicit
RenameDataset
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
);
ZipDataset
(
);
/// \brief Destructor
~
Rename
Dataset
()
=
default
;
~
Zip
Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -427,10 +433,6 @@ class RenameDataset : public Dataset {
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool
ValidateParams
()
override
;
private:
std
::
vector
<
std
::
string
>
input_columns_
;
std
::
vector
<
std
::
string
>
output_columns_
;
};
}
// namespace api
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录