Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
81005a30
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
81005a30
编写于
7月 24, 2020
作者:
C
Cathy Wong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
C++ API: Reorder code contents alphabetically
上级
e07f7436
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
469 addition
and
447 deletion
+469
-447
mindspore/ccsrc/minddata/dataset/api/datasets.cc
mindspore/ccsrc/minddata/dataset/api/datasets.cc
+162
-143
mindspore/ccsrc/minddata/dataset/api/transforms.cc
mindspore/ccsrc/minddata/dataset/api/transforms.cc
+213
-212
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/datasets.h
+94
-92
未找到文件。
mindspore/ccsrc/minddata/dataset/api/datasets.cc
浏览文件 @
81005a30
...
...
@@ -17,12 +17,14 @@
#include <fstream>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/engine/dataset_iterator.h"
// Source dataset headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
// Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/map_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
...
...
@@ -31,6 +33,7 @@
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
// Sampler headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
...
...
@@ -79,6 +82,18 @@ Dataset::Dataset() {
connector_que_size_
=
cfg
->
op_connector_size
();
}
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
// (In alphabetical order)
// Function to create a Cifar10Dataset.
std
::
shared_ptr
<
Cifar10Dataset
>
Cifar10
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
{
auto
ds
=
std
::
make_shared
<
Cifar10Dataset
>
(
dataset_dir
,
num_samples
,
sampler
);
// Call derived class validation method.
return
ds
->
ValidateParams
()
?
ds
:
nullptr
;
}
// Function to create a ImageFolderDataset.
std
::
shared_ptr
<
ImageFolderDataset
>
ImageFolder
(
std
::
string
dataset_dir
,
bool
decode
,
std
::
shared_ptr
<
SamplerObj
>
sampler
,
std
::
set
<
std
::
string
>
extensions
,
...
...
@@ -101,14 +116,8 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
return
ds
->
ValidateParams
()
?
ds
:
nullptr
;
}
// Function to create a Cifar10Dataset.
std
::
shared_ptr
<
Cifar10Dataset
>
Cifar10
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
{
auto
ds
=
std
::
make_shared
<
Cifar10Dataset
>
(
dataset_dir
,
num_samples
,
sampler
);
// Call derived class validation method.
return
ds
->
ValidateParams
()
?
ds
:
nullptr
;
}
// FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
// (In alphabetical order)
// Function to create a Batch dataset
std
::
shared_ptr
<
BatchDataset
>
Dataset
::
Batch
(
int32_t
batch_size
,
bool
drop_remainder
)
{
...
...
@@ -127,14 +136,12 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
return
ds
;
}
// Function to create Repeat dataset.
std
::
shared_ptr
<
Dataset
>
Dataset
::
Repeat
(
int32_t
count
)
{
// Workaround for repeat == 1, do not inject repeat.
if
(
count
==
1
)
{
return
shared_from_this
();
}
auto
ds
=
std
::
make_shared
<
RepeatDataset
>
(
count
);
// Function to create a Map dataset.
std
::
shared_ptr
<
MapDataset
>
Dataset
::
Map
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations
,
std
::
vector
<
std
::
string
>
input_columns
,
std
::
vector
<
std
::
string
>
output_columns
,
const
std
::
vector
<
std
::
string
>
&
project_columns
)
{
auto
ds
=
std
::
make_shared
<
MapDataset
>
(
operations
,
input_columns
,
output_columns
,
project_columns
);
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -145,13 +152,10 @@ std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
return
ds
;
}
// Function to create a Map dataset.
std
::
shared_ptr
<
MapDataset
>
Dataset
::
Map
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations
,
std
::
vector
<
std
::
string
>
input_columns
,
std
::
vector
<
std
::
string
>
output_columns
,
const
std
::
vector
<
std
::
string
>
&
project_columns
)
{
auto
ds
=
std
::
make_shared
<
MapDataset
>
(
operations
,
input_columns
,
output_columns
,
project_columns
);
// Function to create a ProjectDataset.
std
::
shared_ptr
<
ProjectDataset
>
Dataset
::
Project
(
const
std
::
vector
<
std
::
string
>
&
columns
)
{
auto
ds
=
std
::
make_shared
<
ProjectDataset
>
(
columns
);
// Call derived class validation method.
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
}
...
...
@@ -161,11 +165,11 @@ std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOpera
return
ds
;
}
// Function to create a
ShuffleOp
std
::
shared_ptr
<
ShuffleDataset
>
Dataset
::
Shuffle
(
int32_t
shuffle_size
)
{
// Pass in reshuffle_each_epoch with true
auto
ds
=
std
::
make_shared
<
ShuffleDataset
>
(
shuffle_size
,
true
);
// Function to create a
RenameDataset.
std
::
shared_ptr
<
RenameDataset
>
Dataset
::
Rename
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
)
{
auto
ds
=
std
::
make_shared
<
RenameDataset
>
(
input_columns
,
output_columns
);
// Call derived class validation method.
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
}
...
...
@@ -175,11 +179,15 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
return
ds
;
}
// Function to create a SkipDataset.
std
::
shared_ptr
<
SkipDataset
>
Dataset
::
Skip
(
int32_t
count
)
{
auto
ds
=
std
::
make_shared
<
SkipDataset
>
(
count
);
// Function to create Repeat dataset.
std
::
shared_ptr
<
Dataset
>
Dataset
::
Repeat
(
int32_t
count
)
{
// Workaround for repeat == 1, do not inject repeat.
if
(
count
==
1
)
{
return
shared_from_this
();
}
auto
ds
=
std
::
make_shared
<
RepeatDataset
>
(
count
);
// Call derived class validation method.
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
}
...
...
@@ -189,10 +197,11 @@ std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
return
ds
;
}
// Function to create a ProjectDataset.
std
::
shared_ptr
<
ProjectDataset
>
Dataset
::
Project
(
const
std
::
vector
<
std
::
string
>
&
columns
)
{
auto
ds
=
std
::
make_shared
<
ProjectDataset
>
(
columns
);
// Call derived class validation method.
// Function to create a ShuffleOp
std
::
shared_ptr
<
ShuffleDataset
>
Dataset
::
Shuffle
(
int32_t
shuffle_size
)
{
// Pass in reshuffle_each_epoch with true
auto
ds
=
std
::
make_shared
<
ShuffleDataset
>
(
shuffle_size
,
true
);
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
}
...
...
@@ -202,10 +211,10 @@ std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string>
return
ds
;
}
// Function to create a
Rename
Dataset.
std
::
shared_ptr
<
RenameDataset
>
Dataset
::
Rename
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
)
{
auto
ds
=
std
::
make_shared
<
RenameDataset
>
(
input_columns
,
output_columns
);
// Function to create a
Skip
Dataset.
std
::
shared_ptr
<
SkipDataset
>
Dataset
::
Skip
(
int32_t
count
)
{
auto
ds
=
std
::
make_shared
<
SkipDataset
>
(
count
);
// Call derived class validation method.
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -231,6 +240,9 @@ std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Datas
return
ds
;
}
// OTHER FUNCTIONS
// (In alphabetical order)
// Helper function to create default RandomSampler.
std
::
shared_ptr
<
SamplerObj
>
CreateDefaultSampler
()
{
const
int32_t
num_samples
=
0
;
// 0 means to sample all ids.
...
...
@@ -240,6 +252,48 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() {
/* ####################################### Derived Dataset classes ################################# */
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
// (In alphabetical order)
// Constructor for Cifar10Dataset
Cifar10Dataset
::
Cifar10Dataset
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
:
dataset_dir_
(
dataset_dir
),
num_samples_
(
num_samples
),
sampler_
(
sampler
)
{}
bool
Cifar10Dataset
::
ValidateParams
()
{
if
(
dataset_dir_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"No dataset path is specified."
;
return
false
;
}
if
(
num_samples_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Number of samples cannot be negative"
;
return
false
;
}
return
true
;
}
// Function to build CifarOp
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Cifar10Dataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
if
(
sampler_
==
nullptr
)
{
sampler_
=
CreateDefaultSampler
();
}
// Do internal Schema generation.
auto
schema
=
std
::
make_unique
<
DataSchema
>
();
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
"image"
,
DataType
(
DataType
::
DE_UINT8
),
TensorImpl
::
kCv
,
1
)));
TensorShape
scalar
=
TensorShape
::
CreateScalar
();
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
"label"
,
DataType
(
DataType
::
DE_UINT32
),
TensorImpl
::
kFlexible
,
0
,
&
scalar
)));
node_ops
.
push_back
(
std
::
make_shared
<
CifarOp
>
(
CifarOp
::
CifarType
::
kCifar10
,
num_workers_
,
rows_per_buffer_
,
dataset_dir_
,
connector_que_size_
,
std
::
move
(
schema
),
std
::
move
(
sampler_
->
Build
())));
return
node_ops
;
}
ImageFolderDataset
::
ImageFolderDataset
(
std
::
string
dataset_dir
,
bool
decode
,
std
::
shared_ptr
<
SamplerObj
>
sampler
,
bool
recursive
,
std
::
set
<
std
::
string
>
extensions
,
std
::
map
<
std
::
string
,
int32_t
>
class_indexing
)
...
...
@@ -315,6 +369,9 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
return
node_ops
;
}
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
// (In alphabetical order)
BatchDataset
::
BatchDataset
(
int32_t
batch_size
,
bool
drop_remainder
,
bool
pad
,
std
::
vector
<
std
::
string
>
cols_to_map
,
std
::
map
<
std
::
string
,
std
::
pair
<
TensorShape
,
std
::
shared_ptr
<
Tensor
>>>
pad_map
)
:
batch_size_
(
batch_size
),
...
...
@@ -347,24 +404,6 @@ bool BatchDataset::ValidateParams() {
return
true
;
}
RepeatDataset
::
RepeatDataset
(
uint32_t
count
)
:
repeat_count_
(
count
)
{}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
RepeatDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
RepeatOp
>
(
repeat_count_
));
return
node_ops
;
}
bool
RepeatDataset
::
ValidateParams
()
{
if
(
repeat_count_
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"Repeat: Repeat count cannot be negative"
;
return
false
;
}
return
true
;
}
MapDataset
::
MapDataset
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations
,
std
::
vector
<
std
::
string
>
input_columns
,
std
::
vector
<
std
::
string
>
output_columns
,
const
std
::
vector
<
std
::
string
>
&
project_columns
)
:
operations_
(
operations
),
...
...
@@ -409,6 +448,69 @@ bool MapDataset::ValidateParams() {
return
true
;
}
// Function to build ProjectOp
ProjectDataset
::
ProjectDataset
(
const
std
::
vector
<
std
::
string
>
&
columns
)
:
columns_
(
columns
)
{}
bool
ProjectDataset
::
ValidateParams
()
{
if
(
columns_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"No columns are specified."
;
return
false
;
}
return
true
;
}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
ProjectDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
ProjectOp
>
(
columns_
));
return
node_ops
;
}
// Function to build RenameOp
RenameDataset
::
RenameDataset
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
)
:
input_columns_
(
input_columns
),
output_columns_
(
output_columns
)
{}
bool
RenameDataset
::
ValidateParams
()
{
if
(
input_columns_
.
empty
()
||
output_columns_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"input and output columns must be specified"
;
return
false
;
}
if
(
input_columns_
.
size
()
!=
output_columns_
.
size
())
{
MS_LOG
(
ERROR
)
<<
"input and output columns must be the same size"
;
return
false
;
}
return
true
;
}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
RenameDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
RenameOp
>
(
input_columns_
,
output_columns_
,
connector_que_size_
));
return
node_ops
;
}
RepeatDataset
::
RepeatDataset
(
uint32_t
count
)
:
repeat_count_
(
count
)
{}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
RepeatDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
RepeatOp
>
(
repeat_count_
));
return
node_ops
;
}
bool
RepeatDataset
::
ValidateParams
()
{
if
(
repeat_count_
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"Repeat: Repeat count cannot be negative"
;
return
false
;
}
return
true
;
}
// Constructor for ShuffleDataset
ShuffleDataset
::
ShuffleDataset
(
int32_t
shuffle_size
,
bool
reset_every_epoch
)
:
shuffle_size_
(
shuffle_size
),
shuffle_seed_
(
GetSeed
()),
reset_every_epoch_
(
reset_every_epoch
)
{}
...
...
@@ -455,64 +557,6 @@ bool SkipDataset::ValidateParams() {
return
true
;
}
// Constructor for Cifar10Dataset
Cifar10Dataset
::
Cifar10Dataset
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
:
dataset_dir_
(
dataset_dir
),
num_samples_
(
num_samples
),
sampler_
(
sampler
)
{}
bool
Cifar10Dataset
::
ValidateParams
()
{
if
(
dataset_dir_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"No dataset path is specified."
;
return
false
;
}
if
(
num_samples_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Number of samples cannot be negative"
;
return
false
;
}
return
true
;
}
// Function to build CifarOp
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Cifar10Dataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
if
(
sampler_
==
nullptr
)
{
sampler_
=
CreateDefaultSampler
();
}
// Do internal Schema generation.
auto
schema
=
std
::
make_unique
<
DataSchema
>
();
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
"image"
,
DataType
(
DataType
::
DE_UINT8
),
TensorImpl
::
kCv
,
1
)));
TensorShape
scalar
=
TensorShape
::
CreateScalar
();
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
"label"
,
DataType
(
DataType
::
DE_UINT32
),
TensorImpl
::
kFlexible
,
0
,
&
scalar
)));
node_ops
.
push_back
(
std
::
make_shared
<
CifarOp
>
(
CifarOp
::
CifarType
::
kCifar10
,
num_workers_
,
rows_per_buffer_
,
dataset_dir_
,
connector_que_size_
,
std
::
move
(
schema
),
std
::
move
(
sampler_
->
Build
())));
return
node_ops
;
}
// Function to build ProjectOp
ProjectDataset
::
ProjectDataset
(
const
std
::
vector
<
std
::
string
>
&
columns
)
:
columns_
(
columns
)
{}
bool
ProjectDataset
::
ValidateParams
()
{
if
(
columns_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"No columns are specified."
;
return
false
;
}
return
true
;
}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
ProjectDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
ProjectOp
>
(
columns_
));
return
node_ops
;
}
// Function to build ZipOp
ZipDataset
::
ZipDataset
()
{}
...
...
@@ -526,31 +570,6 @@ std::vector<std::shared_ptr<DatasetOp>> ZipDataset::Build() {
return
node_ops
;
}
// Function to build RenameOp
RenameDataset
::
RenameDataset
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
)
:
input_columns_
(
input_columns
),
output_columns_
(
output_columns
)
{}
bool
RenameDataset
::
ValidateParams
()
{
if
(
input_columns_
.
empty
()
||
output_columns_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"input and output columns must be specified"
;
return
false
;
}
if
(
input_columns_
.
size
()
!=
output_columns_
.
size
())
{
MS_LOG
(
ERROR
)
<<
"input and output columns must be the same size"
;
return
false
;
}
return
true
;
}
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
RenameDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
RenameOp
>
(
input_columns_
,
output_columns_
,
connector_que_size_
));
return
node_ops
;
}
}
// namespace api
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/api/transforms.cc
浏览文件 @
81005a30
...
...
@@ -16,18 +16,19 @@
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/image/center_crop_op.h"
#include "minddata/dataset/kernels/image/cut_out_op.h"
#include "minddata/dataset/kernels/image/decode_op.h"
#include "minddata/dataset/kernels/image/resize_op.h"
#include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/image/pad_op.h"
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
#include "minddata/dataset/kernels/image/random_crop_op.h"
#include "minddata/dataset/kernels/image/center_crop_op.h"
#include "minddata/dataset/kernels/image/uniform_aug_op.h"
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
#include "minddata/dataset/kernels/image/random_rotation_op.h"
#include "minddata/dataset/kernels/image/
cut_out
_op.h"
#include "minddata/dataset/kernels/image/r
andom_color_adjust
_op.h"
#include "minddata/dataset/kernels/image/
pad
_op.h"
#include "minddata/dataset/kernels/image/
random_vertical_flip
_op.h"
#include "minddata/dataset/kernels/image/r
esize
_op.h"
#include "minddata/dataset/kernels/image/
uniform_aug
_op.h"
namespace
mindspore
{
namespace
dataset
{
...
...
@@ -38,9 +39,9 @@ TensorOperation::TensorOperation() {}
// Transform operations for computer vision.
namespace
vision
{
// Function to create
Normalize
Operation.
std
::
shared_ptr
<
NormalizeOperation
>
Normalize
(
std
::
vector
<
float
>
mean
,
std
::
vector
<
float
>
std
)
{
auto
op
=
std
::
make_shared
<
NormalizeOperation
>
(
mean
,
std
);
// Function to create
CenterCrop
Operation.
std
::
shared_ptr
<
CenterCropOperation
>
CenterCrop
(
std
::
vector
<
int32_t
>
size
)
{
auto
op
=
std
::
make_shared
<
CenterCropOperation
>
(
size
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -48,9 +49,9 @@ std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vect
return
op
;
}
// Function to create
DecodeOperation
.
std
::
shared_ptr
<
DecodeOperation
>
Decode
(
bool
rgb
)
{
auto
op
=
std
::
make_shared
<
DecodeOperation
>
(
rgb
);
// Function to create
CutOutOp
.
std
::
shared_ptr
<
CutOutOperation
>
CutOut
(
int32_t
length
,
int32_t
num_patches
)
{
auto
op
=
std
::
make_shared
<
CutOutOperation
>
(
length
,
num_patches
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -58,9 +59,9 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb) {
return
op
;
}
// Function to create
Resiz
eOperation.
std
::
shared_ptr
<
ResizeOperation
>
Resize
(
std
::
vector
<
int32_t
>
size
,
InterpolationMode
interpolation
)
{
auto
op
=
std
::
make_shared
<
ResizeOperation
>
(
size
,
interpolation
);
// Function to create
Decod
eOperation.
std
::
shared_ptr
<
DecodeOperation
>
Decode
(
bool
rgb
)
{
auto
op
=
std
::
make_shared
<
DecodeOperation
>
(
rgb
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -68,10 +69,9 @@ std::shared_ptr<ResizeOperation> Resize(std::vector<int32_t> size, Interpolation
return
op
;
}
// Function to create RandomCropOperation.
std
::
shared_ptr
<
RandomCropOperation
>
RandomCrop
(
std
::
vector
<
int32_t
>
size
,
std
::
vector
<
int32_t
>
padding
,
bool
pad_if_needed
,
std
::
vector
<
uint8_t
>
fill_value
)
{
auto
op
=
std
::
make_shared
<
RandomCropOperation
>
(
size
,
padding
,
pad_if_needed
,
fill_value
);
// Function to create NormalizeOperation.
std
::
shared_ptr
<
NormalizeOperation
>
Normalize
(
std
::
vector
<
float
>
mean
,
std
::
vector
<
float
>
std
)
{
auto
op
=
std
::
make_shared
<
NormalizeOperation
>
(
mean
,
std
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -79,9 +79,10 @@ std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::
return
op
;
}
// Function to create CenterCropOperation.
std
::
shared_ptr
<
CenterCropOperation
>
CenterCrop
(
std
::
vector
<
int32_t
>
size
)
{
auto
op
=
std
::
make_shared
<
CenterCropOperation
>
(
size
);
// Function to create PadOperation.
std
::
shared_ptr
<
PadOperation
>
Pad
(
std
::
vector
<
int32_t
>
padding
,
std
::
vector
<
uint8_t
>
fill_value
,
BorderType
padding_mode
)
{
auto
op
=
std
::
make_shared
<
PadOperation
>
(
padding
,
fill_value
,
padding_mode
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -89,10 +90,11 @@ std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size) {
return
op
;
}
// Function to create UniformAugOperation.
std
::
shared_ptr
<
UniformAugOperation
>
UniformAugment
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
transforms
,
int32_t
num_ops
)
{
auto
op
=
std
::
make_shared
<
UniformAugOperation
>
(
transforms
,
num_ops
);
// Function to create RandomColorAdjustOperation.
std
::
shared_ptr
<
RandomColorAdjustOperation
>
RandomColorAdjust
(
std
::
vector
<
float
>
brightness
,
std
::
vector
<
float
>
contrast
,
std
::
vector
<
float
>
saturation
,
std
::
vector
<
float
>
hue
)
{
auto
op
=
std
::
make_shared
<
RandomColorAdjustOperation
>
(
brightness
,
contrast
,
saturation
,
hue
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -100,9 +102,10 @@ std::shared_ptr<UniformAugOperation> UniformAugment(std::vector<std::shared_ptr<
return
op
;
}
// Function to create RandomHorizontalFlipOperation.
std
::
shared_ptr
<
RandomHorizontalFlipOperation
>
RandomHorizontalFlip
(
float
prob
)
{
auto
op
=
std
::
make_shared
<
RandomHorizontalFlipOperation
>
(
prob
);
// Function to create RandomCropOperation.
std
::
shared_ptr
<
RandomCropOperation
>
RandomCrop
(
std
::
vector
<
int32_t
>
size
,
std
::
vector
<
int32_t
>
padding
,
bool
pad_if_needed
,
std
::
vector
<
uint8_t
>
fill_value
)
{
auto
op
=
std
::
make_shared
<
RandomCropOperation
>
(
size
,
padding
,
pad_if_needed
,
fill_value
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -110,9 +113,9 @@ std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob)
return
op
;
}
// Function to create Random
Vertic
alFlipOperation.
std
::
shared_ptr
<
Random
VerticalFlipOperation
>
RandomVertic
alFlip
(
float
prob
)
{
auto
op
=
std
::
make_shared
<
Random
Vertic
alFlipOperation
>
(
prob
);
// Function to create Random
Horizont
alFlipOperation.
std
::
shared_ptr
<
Random
HorizontalFlipOperation
>
RandomHorizont
alFlip
(
float
prob
)
{
auto
op
=
std
::
make_shared
<
Random
Horizont
alFlipOperation
>
(
prob
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -132,10 +135,9 @@ std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degre
return
op
;
}
// Function to create PadOperation.
std
::
shared_ptr
<
PadOperation
>
Pad
(
std
::
vector
<
int32_t
>
padding
,
std
::
vector
<
uint8_t
>
fill_value
,
BorderType
padding_mode
)
{
auto
op
=
std
::
make_shared
<
PadOperation
>
(
padding
,
fill_value
,
padding_mode
);
// Function to create RandomVerticalFlipOperation.
std
::
shared_ptr
<
RandomVerticalFlipOperation
>
RandomVerticalFlip
(
float
prob
)
{
auto
op
=
std
::
make_shared
<
RandomVerticalFlipOperation
>
(
prob
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -143,9 +145,9 @@ std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint
return
op
;
}
// Function to create
CutOutOp
.
std
::
shared_ptr
<
CutOutOperation
>
CutOut
(
int32_t
length
,
int32_t
num_patches
)
{
auto
op
=
std
::
make_shared
<
CutOutOperation
>
(
length
,
num_patches
);
// Function to create
ResizeOperation
.
std
::
shared_ptr
<
ResizeOperation
>
Resize
(
std
::
vector
<
int32_t
>
size
,
InterpolationMode
interpolation
)
{
auto
op
=
std
::
make_shared
<
ResizeOperation
>
(
size
,
interpolation
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -153,11 +155,10 @@ std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches) {
return
op
;
}
// Function to create RandomColorAdjustOperation.
std
::
shared_ptr
<
RandomColorAdjustOperation
>
RandomColorAdjust
(
std
::
vector
<
float
>
brightness
,
std
::
vector
<
float
>
contrast
,
std
::
vector
<
float
>
saturation
,
std
::
vector
<
float
>
hue
)
{
auto
op
=
std
::
make_shared
<
RandomColorAdjustOperation
>
(
brightness
,
contrast
,
saturation
,
hue
);
// Function to create UniformAugOperation.
std
::
shared_ptr
<
UniformAugOperation
>
UniformAugment
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
transforms
,
int32_t
num_ops
)
{
auto
op
=
std
::
make_shared
<
UniformAugOperation
>
(
transforms
,
num_ops
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -167,104 +168,6 @@ std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float>
/* ####################################### Derived TensorOperation classes ################################# */
// NormalizeOperation
NormalizeOperation
::
NormalizeOperation
(
std
::
vector
<
float
>
mean
,
std
::
vector
<
float
>
std
)
:
mean_
(
mean
),
std_
(
std
)
{}
bool
NormalizeOperation
::
ValidateParams
()
{
if
(
mean_
.
size
()
!=
3
)
{
MS_LOG
(
ERROR
)
<<
"Normalize: mean vector has incorrect size: "
<<
mean_
.
size
();
return
false
;
}
if
(
std_
.
size
()
!=
3
)
{
MS_LOG
(
ERROR
)
<<
"Normalize: std vector has incorrect size: "
<<
std_
.
size
();
return
false
;
}
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
NormalizeOperation
::
Build
()
{
return
std
::
make_shared
<
NormalizeOp
>
(
mean_
[
0
],
mean_
[
1
],
mean_
[
2
],
std_
[
0
],
std_
[
1
],
std_
[
2
]);
}
// DecodeOperation
DecodeOperation
::
DecodeOperation
(
bool
rgb
)
:
rgb_
(
rgb
)
{}
bool
DecodeOperation
::
ValidateParams
()
{
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
DecodeOperation
::
Build
()
{
return
std
::
make_shared
<
DecodeOp
>
(
rgb_
);
}
// ResizeOperation
ResizeOperation
::
ResizeOperation
(
std
::
vector
<
int32_t
>
size
,
InterpolationMode
interpolation
)
:
size_
(
size
),
interpolation_
(
interpolation
)
{}
bool
ResizeOperation
::
ValidateParams
()
{
if
(
size_
.
empty
()
||
size_
.
size
()
>
2
)
{
MS_LOG
(
ERROR
)
<<
"Resize: size vector has incorrect size: "
<<
size_
.
size
();
return
false
;
}
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
ResizeOperation
::
Build
()
{
int32_t
height
=
size_
[
0
];
int32_t
width
=
0
;
// User specified the width value.
if
(
size_
.
size
()
==
2
)
{
width
=
size_
[
1
];
}
return
std
::
make_shared
<
ResizeOp
>
(
height
,
width
,
interpolation_
);
}
// RandomCropOperation
RandomCropOperation
::
RandomCropOperation
(
std
::
vector
<
int32_t
>
size
,
std
::
vector
<
int32_t
>
padding
,
bool
pad_if_needed
,
std
::
vector
<
uint8_t
>
fill_value
)
:
size_
(
size
),
padding_
(
padding
),
pad_if_needed_
(
pad_if_needed
),
fill_value_
(
fill_value
)
{}
bool
RandomCropOperation
::
ValidateParams
()
{
if
(
size_
.
empty
()
||
size_
.
size
()
>
2
)
{
MS_LOG
(
ERROR
)
<<
"RandomCrop: size vector has incorrect size: "
<<
size_
.
size
();
return
false
;
}
if
(
padding_
.
empty
()
||
padding_
.
size
()
!=
4
)
{
MS_LOG
(
ERROR
)
<<
"RandomCrop: padding vector has incorrect size: padding.size()"
;
return
false
;
}
if
(
fill_value_
.
empty
()
||
fill_value_
.
size
()
!=
3
)
{
MS_LOG
(
ERROR
)
<<
"RandomCrop: fill_value vector has incorrect size: fill_value.size()"
;
return
false
;
}
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
RandomCropOperation
::
Build
()
{
int32_t
crop_height
=
size_
[
0
];
int32_t
crop_width
=
0
;
int32_t
pad_top
=
padding_
[
0
];
int32_t
pad_bottom
=
padding_
[
1
];
int32_t
pad_left
=
padding_
[
2
];
int32_t
pad_right
=
padding_
[
3
];
uint8_t
fill_r
=
fill_value_
[
0
];
uint8_t
fill_g
=
fill_value_
[
1
];
uint8_t
fill_b
=
fill_value_
[
2
];
// User has specified the crop_width value.
if
(
size_
.
size
()
==
2
)
{
crop_width
=
size_
[
1
];
}
auto
tensor_op
=
std
::
make_shared
<
RandomCropOp
>
(
crop_height
,
crop_width
,
pad_top
,
pad_bottom
,
pad_left
,
pad_right
,
BorderType
::
kConstant
,
pad_if_needed_
,
fill_r
,
fill_g
,
fill_b
);
return
tensor_op
;
}
// CenterCropOperation
CenterCropOperation
::
CenterCropOperation
(
std
::
vector
<
int32_t
>
size
)
:
size_
(
size
)
{}
...
...
@@ -289,71 +192,52 @@ std::shared_ptr<TensorOp> CenterCropOperation::Build() {
return
tensor_op
;
}
// UniformAugOperation
UniformAugOperation
::
UniformAugOperation
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
transforms
,
int32_t
num_ops
)
:
transforms_
(
transforms
),
num_ops_
(
num_ops
)
{}
bool
UniformAugOperation
::
ValidateParams
()
{
return
true
;
}
// CutOutOperation
CutOutOperation
::
CutOutOperation
(
int32_t
length
,
int32_t
num_patches
)
:
length_
(
length
),
num_patches_
(
num_patches
)
{}
std
::
shared_ptr
<
TensorOp
>
UniformAugOperation
::
Build
()
{
std
::
vector
<
std
::
shared_ptr
<
TensorOp
>>
tensor_ops
;
(
void
)
std
::
transform
(
transforms_
.
begin
(),
transforms_
.
end
(),
std
::
back_inserter
(
tensor_ops
),
[](
std
::
shared_ptr
<
TensorOperation
>
op
)
->
std
::
shared_ptr
<
TensorOp
>
{
return
op
->
Build
();
});
std
::
shared_ptr
<
UniformAugOp
>
tensor_op
=
std
::
make_shared
<
UniformAugOp
>
(
tensor_ops
,
num_ops_
);
return
tensor_op
;
bool
CutOutOperation
::
ValidateParams
()
{
if
(
length_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"CutOut: length cannot be negative"
;
return
false
;
}
if
(
num_patches_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"CutOut: number of patches cannot be negative"
;
return
false
;
}
return
true
;
}
// RandomHorizontalFlipOperation
RandomHorizontalFlipOperation
::
RandomHorizontalFlipOperation
(
float
probability
)
:
probability_
(
probability
)
{}
bool
RandomHorizontalFlipOperation
::
ValidateParams
()
{
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
RandomHorizontalFlipOperation
::
Build
()
{
std
::
shared_ptr
<
RandomHorizontalFlipOp
>
tensor_op
=
std
::
make_shared
<
RandomHorizontalFlipOp
>
(
probability_
);
std
::
shared_ptr
<
TensorOp
>
CutOutOperation
::
Build
()
{
std
::
shared_ptr
<
CutOutOp
>
tensor_op
=
std
::
make_shared
<
CutOutOp
>
(
length_
,
length_
,
num_patches_
,
false
,
0
,
0
,
0
);
return
tensor_op
;
}
//
RandomVerticalFlip
Operation
RandomVerticalFlipOperation
::
RandomVerticalFlipOperation
(
float
probability
)
:
probability_
(
probability
)
{}
//
Decode
Operation
DecodeOperation
::
DecodeOperation
(
bool
rgb
)
:
rgb_
(
rgb
)
{}
bool
RandomVerticalFlip
Operation
::
ValidateParams
()
{
return
true
;
}
bool
Decode
Operation
::
ValidateParams
()
{
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
RandomVerticalFlipOperation
::
Build
()
{
std
::
shared_ptr
<
RandomVerticalFlipOp
>
tensor_op
=
std
::
make_shared
<
RandomVerticalFlipOp
>
(
probability_
);
return
tensor_op
;
}
std
::
shared_ptr
<
TensorOp
>
DecodeOperation
::
Build
()
{
return
std
::
make_shared
<
DecodeOp
>
(
rgb_
);
}
// Function to create RandomRotationOperation.
RandomRotationOperation
::
RandomRotationOperation
(
std
::
vector
<
float
>
degrees
,
InterpolationMode
interpolation_mode
,
bool
expand
,
std
::
vector
<
float
>
center
,
std
::
vector
<
uint8_t
>
fill_value
)
:
degrees_
(
degrees
),
interpolation_mode_
(
interpolation_mode
),
expand_
(
expand
),
center_
(
center
),
fill_value_
(
fill_value
)
{}
// NormalizeOperation
NormalizeOperation
::
NormalizeOperation
(
std
::
vector
<
float
>
mean
,
std
::
vector
<
float
>
std
)
:
mean_
(
mean
),
std_
(
std
)
{}
bool
RandomRotationOperation
::
ValidateParams
()
{
if
(
degrees_
.
empty
()
||
degrees_
.
size
()
!=
2
)
{
MS_LOG
(
ERROR
)
<<
"RandomRotation: degrees vector has incorrect size: degrees.size()"
;
return
false
;
}
if
(
center_
.
empty
()
||
center_
.
size
()
!=
2
)
{
MS_LOG
(
ERROR
)
<<
"RandomRotation: center vector has incorrect size: center.size()"
;
bool
NormalizeOperation
::
ValidateParams
()
{
if
(
mean_
.
size
()
!=
3
)
{
MS_LOG
(
ERROR
)
<<
"Normalize: mean vector has incorrect size: "
<<
mean_
.
size
();
return
false
;
}
if
(
fill_value_
.
empty
()
||
fill_value_
.
size
()
!=
3
)
{
MS_LOG
(
ERROR
)
<<
"RandomRotation: fill_value vector has incorrect size: fill_value.size()"
;
if
(
std_
.
size
()
!=
3
)
{
MS_LOG
(
ERROR
)
<<
"Normalize: std vector has incorrect size: "
<<
std_
.
size
();
return
false
;
}
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
RandomRotationOperation
::
Build
()
{
std
::
shared_ptr
<
RandomRotationOp
>
tensor_op
=
std
::
make_shared
<
RandomRotationOp
>
(
degrees_
[
0
],
degrees_
[
1
],
center_
[
0
],
center_
[
1
],
interpolation_mode_
,
expand_
,
fill_value_
[
0
],
fill_value_
[
1
],
fill_value_
[
2
]);
return
tensor_op
;
std
::
shared_ptr
<
TensorOp
>
NormalizeOperation
::
Build
()
{
return
std
::
make_shared
<
NormalizeOp
>
(
mean_
[
0
],
mean_
[
1
],
mean_
[
2
],
std_
[
0
],
std_
[
1
],
std_
[
2
]);
}
// PadOperation
...
...
@@ -411,26 +295,6 @@ std::shared_ptr<TensorOp> PadOperation::Build() {
return
tensor_op
;
}
// CutOutOperation
CutOutOperation
::
CutOutOperation
(
int32_t
length
,
int32_t
num_patches
)
:
length_
(
length
),
num_patches_
(
num_patches
)
{}
bool
CutOutOperation
::
ValidateParams
()
{
if
(
length_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"CutOut: length cannot be negative"
;
return
false
;
}
if
(
num_patches_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"CutOut: number of patches cannot be negative"
;
return
false
;
}
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
CutOutOperation
::
Build
()
{
std
::
shared_ptr
<
CutOutOp
>
tensor_op
=
std
::
make_shared
<
CutOutOp
>
(
length_
,
length_
,
num_patches_
,
false
,
0
,
0
,
0
);
return
tensor_op
;
}
// RandomColorAdjustOperation.
RandomColorAdjustOperation
::
RandomColorAdjustOperation
(
std
::
vector
<
float
>
brightness
,
std
::
vector
<
float
>
contrast
,
std
::
vector
<
float
>
saturation
,
std
::
vector
<
float
>
hue
)
...
...
@@ -485,6 +349,143 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
return
tensor_op
;
}
// RandomCropOperation
RandomCropOperation
::
RandomCropOperation
(
std
::
vector
<
int32_t
>
size
,
std
::
vector
<
int32_t
>
padding
,
bool
pad_if_needed
,
std
::
vector
<
uint8_t
>
fill_value
)
:
size_
(
size
),
padding_
(
padding
),
pad_if_needed_
(
pad_if_needed
),
fill_value_
(
fill_value
)
{}
bool
RandomCropOperation
::
ValidateParams
()
{
if
(
size_
.
empty
()
||
size_
.
size
()
>
2
)
{
MS_LOG
(
ERROR
)
<<
"RandomCrop: size vector has incorrect size: "
<<
size_
.
size
();
return
false
;
}
if
(
padding_
.
empty
()
||
padding_
.
size
()
!=
4
)
{
MS_LOG
(
ERROR
)
<<
"RandomCrop: padding vector has incorrect size: padding.size()"
;
return
false
;
}
if
(
fill_value_
.
empty
()
||
fill_value_
.
size
()
!=
3
)
{
MS_LOG
(
ERROR
)
<<
"RandomCrop: fill_value vector has incorrect size: fill_value.size()"
;
return
false
;
}
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
RandomCropOperation
::
Build
()
{
int32_t
crop_height
=
size_
[
0
];
int32_t
crop_width
=
0
;
int32_t
pad_top
=
padding_
[
0
];
int32_t
pad_bottom
=
padding_
[
1
];
int32_t
pad_left
=
padding_
[
2
];
int32_t
pad_right
=
padding_
[
3
];
uint8_t
fill_r
=
fill_value_
[
0
];
uint8_t
fill_g
=
fill_value_
[
1
];
uint8_t
fill_b
=
fill_value_
[
2
];
// User has specified the crop_width value.
if
(
size_
.
size
()
==
2
)
{
crop_width
=
size_
[
1
];
}
auto
tensor_op
=
std
::
make_shared
<
RandomCropOp
>
(
crop_height
,
crop_width
,
pad_top
,
pad_bottom
,
pad_left
,
pad_right
,
BorderType
::
kConstant
,
pad_if_needed_
,
fill_r
,
fill_g
,
fill_b
);
return
tensor_op
;
}
// RandomHorizontalFlipOperation
RandomHorizontalFlipOperation
::
RandomHorizontalFlipOperation
(
float
probability
)
:
probability_
(
probability
)
{}
bool
RandomHorizontalFlipOperation
::
ValidateParams
()
{
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
RandomHorizontalFlipOperation
::
Build
()
{
std
::
shared_ptr
<
RandomHorizontalFlipOp
>
tensor_op
=
std
::
make_shared
<
RandomHorizontalFlipOp
>
(
probability_
);
return
tensor_op
;
}
// Function to create RandomRotationOperation.
RandomRotationOperation
::
RandomRotationOperation
(
std
::
vector
<
float
>
degrees
,
InterpolationMode
interpolation_mode
,
bool
expand
,
std
::
vector
<
float
>
center
,
std
::
vector
<
uint8_t
>
fill_value
)
:
degrees_
(
degrees
),
interpolation_mode_
(
interpolation_mode
),
expand_
(
expand
),
center_
(
center
),
fill_value_
(
fill_value
)
{}
bool
RandomRotationOperation
::
ValidateParams
()
{
if
(
degrees_
.
empty
()
||
degrees_
.
size
()
!=
2
)
{
MS_LOG
(
ERROR
)
<<
"RandomRotation: degrees vector has incorrect size: degrees.size()"
;
return
false
;
}
if
(
center_
.
empty
()
||
center_
.
size
()
!=
2
)
{
MS_LOG
(
ERROR
)
<<
"RandomRotation: center vector has incorrect size: center.size()"
;
return
false
;
}
if
(
fill_value_
.
empty
()
||
fill_value_
.
size
()
!=
3
)
{
MS_LOG
(
ERROR
)
<<
"RandomRotation: fill_value vector has incorrect size: fill_value.size()"
;
return
false
;
}
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
RandomRotationOperation
::
Build
()
{
std
::
shared_ptr
<
RandomRotationOp
>
tensor_op
=
std
::
make_shared
<
RandomRotationOp
>
(
degrees_
[
0
],
degrees_
[
1
],
center_
[
0
],
center_
[
1
],
interpolation_mode_
,
expand_
,
fill_value_
[
0
],
fill_value_
[
1
],
fill_value_
[
2
]);
return
tensor_op
;
}
// RandomVerticalFlipOperation
RandomVerticalFlipOperation
::
RandomVerticalFlipOperation
(
float
probability
)
:
probability_
(
probability
)
{}
bool
RandomVerticalFlipOperation
::
ValidateParams
()
{
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
RandomVerticalFlipOperation
::
Build
()
{
std
::
shared_ptr
<
RandomVerticalFlipOp
>
tensor_op
=
std
::
make_shared
<
RandomVerticalFlipOp
>
(
probability_
);
return
tensor_op
;
}
// ResizeOperation
ResizeOperation
::
ResizeOperation
(
std
::
vector
<
int32_t
>
size
,
InterpolationMode
interpolation
)
:
size_
(
size
),
interpolation_
(
interpolation
)
{}
bool
ResizeOperation
::
ValidateParams
()
{
if
(
size_
.
empty
()
||
size_
.
size
()
>
2
)
{
MS_LOG
(
ERROR
)
<<
"Resize: size vector has incorrect size: "
<<
size_
.
size
();
return
false
;
}
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
ResizeOperation
::
Build
()
{
int32_t
height
=
size_
[
0
];
int32_t
width
=
0
;
// User specified the width value.
if
(
size_
.
size
()
==
2
)
{
width
=
size_
[
1
];
}
return
std
::
make_shared
<
ResizeOp
>
(
height
,
width
,
interpolation_
);
}
// UniformAugOperation
UniformAugOperation
::
UniformAugOperation
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
transforms
,
int32_t
num_ops
)
:
transforms_
(
transforms
),
num_ops_
(
num_ops
)
{}
bool
UniformAugOperation
::
ValidateParams
()
{
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
UniformAugOperation
::
Build
()
{
std
::
vector
<
std
::
shared_ptr
<
TensorOp
>>
tensor_ops
;
(
void
)
std
::
transform
(
transforms_
.
begin
(),
transforms_
.
end
(),
std
::
back_inserter
(
tensor_ops
),
[](
std
::
shared_ptr
<
TensorOperation
>
op
)
->
std
::
shared_ptr
<
TensorOp
>
{
return
op
->
Build
();
});
std
::
shared_ptr
<
UniformAugOp
>
tensor_op
=
std
::
make_shared
<
UniformAugOp
>
(
tensor_ops
,
num_ops_
);
return
tensor_op
;
}
}
// namespace vision
}
// namespace api
}
// namespace dataset
...
...
mindspore/ccsrc/minddata/dataset/include/datasets.h
浏览文件 @
81005a30
...
...
@@ -40,17 +40,29 @@ namespace api {
class
TensorOperation
;
class
SamplerObj
;
// Datasets classes (in alphabetical order)
class
Cifar10Dataset
;
class
ImageFolderDataset
;
class
MnistDataset
;
// Dataset Op classes (in alphabetical order)
class
BatchDataset
;
class
RepeatDataset
;
class
MapDataset
;
class
ProjectDataset
;
class
RenameDataset
;
class
RepeatDataset
;
class
ShuffleDataset
;
class
SkipDataset
;
class
Cifar10Dataset
;
class
ProjectDataset
;
class
ZipDataset
;
class
RenameDataset
;
/// \brief Function to create a Cifar10 Dataset
/// \notes The generated dataset has two columns ['image', 'label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] num_samples The number of images to be included in the dataset
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
Cifar10Dataset
>
Cifar10
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
);
/// \brief Function to create an ImageFolderDataset
/// \notes A source dataset that reads images from a tree of directories
...
...
@@ -76,16 +88,6 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
/// \return Shared pointer to the current MnistDataset
std
::
shared_ptr
<
MnistDataset
>
Mnist
(
std
::
string
dataset_dir
,
std
::
shared_ptr
<
SamplerObj
>
sampler
=
nullptr
);
/// \brief Function to create a Cifar10 Dataset
/// \notes The generated dataset has two columns ['image', 'label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] num_samples The number of images to be included in the dataset
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
Cifar10Dataset
>
Cifar10
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
);
/// \class Dataset datasets.h
/// \brief A base class to represent a dataset in the data pipeline.
class
Dataset
:
public
std
::
enable_shared_from_this
<
Dataset
>
{
...
...
@@ -128,14 +130,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current BatchDataset
std
::
shared_ptr
<
BatchDataset
>
Batch
(
int32_t
batch_size
,
bool
drop_remainder
=
false
);
/// \brief Function to create a RepeatDataset
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
/// \param[in] count Number of times the dataset should be repeated
/// \return Shared pointer to the current Dataset
/// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
/// due to a limitation in the current implementation
std
::
shared_ptr
<
Dataset
>
Repeat
(
int32_t
count
=
-
1
);
/// \brief Function to create a MapDataset
/// \notes Applies each operation in operations to this dataset
/// \param[in] operations Vector of operations to be applied on the dataset. Operations are
...
...
@@ -156,6 +150,28 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
std
::
vector
<
std
::
string
>
output_columns
=
{},
const
std
::
vector
<
std
::
string
>
&
project_columns
=
{});
/// \brief Function to create a Project Dataset
/// \notes Applies project to the dataset
/// \param[in] columns The name of columns to project
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
ProjectDataset
>
Project
(
const
std
::
vector
<
std
::
string
>
&
columns
);
/// \brief Function to create a Rename Dataset
/// \notes Renames the columns in the input dataset
/// \param[in] input_columns List of the input columns to rename
/// \param[in] output_columns List of the output columns
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
RenameDataset
>
Rename
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
);
/// \brief Function to create a RepeatDataset
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
/// \param[in] count Number of times the dataset should be repeated
/// \return Shared pointer to the current Dataset
/// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
/// due to a limitation in the current implementation
std
::
shared_ptr
<
Dataset
>
Repeat
(
int32_t
count
=
-
1
);
/// \brief Function to create a Shuffle Dataset
/// \notes Randomly shuffles the rows of this dataset
/// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
...
...
@@ -168,26 +184,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current SkipDataset
std
::
shared_ptr
<
SkipDataset
>
Skip
(
int32_t
count
);
/// \brief Function to create a Project Dataset
/// \notes Applies project to the dataset
/// \param[in] columns The name of columns to project
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
ProjectDataset
>
Project
(
const
std
::
vector
<
std
::
string
>
&
columns
);
/// \brief Function to create a Zip Dataset
/// \notes Applies zip to the dataset
/// \param[in] datasets A list of shared pointer to the datasets that we want to zip
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
ZipDataset
>
Zip
(
const
std
::
vector
<
std
::
shared_ptr
<
Dataset
>>
&
datasets
);
/// \brief Function to create a Rename Dataset
/// \notes Renames the columns in the input dataset
/// \param[in] input_columns List of the input columns to rename
/// \param[in] output_columns List of the output columns
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
RenameDataset
>
Rename
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
);
protected:
std
::
vector
<
std
::
shared_ptr
<
Dataset
>>
children
;
std
::
shared_ptr
<
Dataset
>
parent
;
...
...
@@ -199,6 +201,28 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/* ####################################### Derived Dataset classes ################################# */
class
Cifar10Dataset
:
public
Dataset
{
public:
/// \brief Constructor
Cifar10Dataset
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
);
/// \brief Destructor
~
Cifar10Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Build
()
override
;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool
ValidateParams
()
override
;
private:
std
::
string
dataset_dir_
;
int32_t
num_samples_
;
std
::
shared_ptr
<
SamplerObj
>
sampler_
;
};
/// \class ImageFolderDataset
/// \brief A Dataset derived class to represent ImageFolder dataset
class
ImageFolderDataset
:
public
Dataset
{
...
...
@@ -273,13 +297,14 @@ class BatchDataset : public Dataset {
std
::
map
<
std
::
string
,
std
::
pair
<
TensorShape
,
std
::
shared_ptr
<
Tensor
>>>
pad_map_
;
};
class
Repeat
Dataset
:
public
Dataset
{
class
Map
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
explicit
RepeatDataset
(
uint32_t
count
);
MapDataset
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations
,
std
::
vector
<
std
::
string
>
input_columns
=
{},
std
::
vector
<
std
::
string
>
output_columns
=
{},
const
std
::
vector
<
std
::
string
>
&
columns
=
{});
/// \brief Destructor
~
Repeat
Dataset
()
=
default
;
~
Map
Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -290,32 +315,19 @@ class RepeatDataset : public Dataset {
bool
ValidateParams
()
override
;
private:
uint32_t
repeat_count_
;
};
class
ShuffleDataset
:
public
Dataset
{
public:
ShuffleDataset
(
int32_t
shuffle_size
,
bool
reset_every_epoch
);
~
ShuffleDataset
()
=
default
;
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Build
()
override
;
bool
ValidateParams
()
override
;
private:
int32_t
shuffle_size_
;
uint32_t
shuffle_seed_
;
bool
reset_every_epoch_
;
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations_
;
std
::
vector
<
std
::
string
>
input_columns_
;
std
::
vector
<
std
::
string
>
output_columns_
;
std
::
vector
<
std
::
string
>
project_columns_
;
};
class
Skip
Dataset
:
public
Dataset
{
class
Project
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
explicit
SkipDataset
(
int32_t
count
);
explicit
ProjectDataset
(
const
std
::
vector
<
std
::
string
>
&
columns
);
/// \brief Destructor
~
Skip
Dataset
()
=
default
;
~
Project
Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -326,17 +338,16 @@ class SkipDataset : public Dataset {
bool
ValidateParams
()
override
;
private:
int32_t
skip_count
_
;
std
::
vector
<
std
::
string
>
columns
_
;
};
class
Map
Dataset
:
public
Dataset
{
class
Rename
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
MapDataset
(
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations
,
std
::
vector
<
std
::
string
>
input_columns
=
{},
std
::
vector
<
std
::
string
>
output_columns
=
{},
const
std
::
vector
<
std
::
string
>
&
columns
=
{});
explicit
RenameDataset
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
);
/// \brief Destructor
~
Map
Dataset
()
=
default
;
~
Rename
Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -347,19 +358,17 @@ class MapDataset : public Dataset {
bool
ValidateParams
()
override
;
private:
std
::
vector
<
std
::
shared_ptr
<
TensorOperation
>>
operations_
;
std
::
vector
<
std
::
string
>
input_columns_
;
std
::
vector
<
std
::
string
>
output_columns_
;
std
::
vector
<
std
::
string
>
project_columns_
;
};
class
Cifar10
Dataset
:
public
Dataset
{
class
Repeat
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
Cifar10Dataset
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
);
explicit
RepeatDataset
(
uint32_t
count
);
/// \brief Destructor
~
Cifar10
Dataset
()
=
default
;
~
Repeat
Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -370,38 +379,32 @@ class Cifar10Dataset : public Dataset {
bool
ValidateParams
()
override
;
private:
std
::
string
dataset_dir_
;
int32_t
num_samples_
;
std
::
shared_ptr
<
SamplerObj
>
sampler_
;
uint32_t
repeat_count_
;
};
class
Project
Dataset
:
public
Dataset
{
class
Shuffle
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
explicit
ProjectDataset
(
const
std
::
vector
<
std
::
string
>
&
columns
);
ShuffleDataset
(
int32_t
shuffle_size
,
bool
reset_every_epoch
);
/// \brief Destructor
~
ProjectDataset
()
=
default
;
~
ShuffleDataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Build
()
override
;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool
ValidateParams
()
override
;
private:
std
::
vector
<
std
::
string
>
columns_
;
int32_t
shuffle_size_
;
uint32_t
shuffle_seed_
;
bool
reset_every_epoch_
;
};
class
Z
ipDataset
:
public
Dataset
{
class
Sk
ipDataset
:
public
Dataset
{
public:
/// \brief Constructor
ZipDataset
(
);
explicit
SkipDataset
(
int32_t
count
);
/// \brief Destructor
~
Z
ipDataset
()
=
default
;
~
Sk
ipDataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -410,15 +413,18 @@ class ZipDataset : public Dataset {
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool
ValidateParams
()
override
;
private:
int32_t
skip_count_
;
};
class
Rename
Dataset
:
public
Dataset
{
class
Zip
Dataset
:
public
Dataset
{
public:
/// \brief Constructor
explicit
RenameDataset
(
const
std
::
vector
<
std
::
string
>
&
input_columns
,
const
std
::
vector
<
std
::
string
>
&
output_columns
);
ZipDataset
(
);
/// \brief Destructor
~
Rename
Dataset
()
=
default
;
~
Zip
Dataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
...
...
@@ -427,10 +433,6 @@ class RenameDataset : public Dataset {
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool
ValidateParams
()
override
;
private:
std
::
vector
<
std
::
string
>
input_columns_
;
std
::
vector
<
std
::
string
>
output_columns_
;
};
}
// namespace api
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录