magicwindyyd / mindspore (forked from MindSpore / mindspore, in sync with the upstream fork source)
Commit ea947568, authored Sep 10, 2020 by mindspore-ci-bot, committed via Gitee on Sep 10, 2020.
!5605 Introduce usage flag to MNIST and CIFAR dataset

Merge pull request !5605 from ZiruiWu/add_usage_to_cifar_mnist_coco

Parents: ae7e8a74, 1bb93580
Showing 40 changed files with 656 additions and 482 deletions (+656 -482).
mindspore/ccsrc/minddata/dataset/api/datasets.cc (+49 -51)
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc (+4 -4)
mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc (+12 -17)
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc (+9 -9)
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h (+8 -8)
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc (+76 -60)
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h (+17 -6)
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc (+32 -38)
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h (+14 -6)
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc (+9 -9)
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h (+6 -6)
mindspore/ccsrc/minddata/dataset/include/datasets.h (+28 -23)
mindspore/dataset/core/validator_helpers.py (+6 -0)
mindspore/dataset/engine/datasets.py (+36 -21)
mindspore/dataset/engine/serializer_deserializer.py (+3 -3)
mindspore/dataset/engine/validators.py (+12 -8)
tests/ut/cpp/dataset/CMakeLists.txt (+3 -3)
tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc (+11 -11)
tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc (+6 -7)
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc (+27 -29)
tests/ut/cpp/dataset/c_api_datasets_test.cc (+7 -6)
tests/ut/cpp/dataset/c_api_transforms_test.cc (+35 -31)
tests/ut/cpp/dataset/celeba_op_test.cc (+7 -3)
tests/ut/cpp/dataset/voc_op_test.cc (+4 -9)
tests/ut/python/dataset/test_bounding_box_augment.py (+13 -13)
tests/ut/python/dataset/test_datasets_celeba.py (+4 -2)
tests/ut/python/dataset/test_datasets_cifarop.py (+55 -0)
tests/ut/python/dataset/test_datasets_get_dataset_size.py (+33 -0)
tests/ut/python/dataset/test_datasets_mnist.py (+36 -0)
tests/ut/python/dataset/test_datasets_voc.py (+13 -13)
tests/ut/python/dataset/test_epoch_ctrl.py (+14 -1)
tests/ut/python/dataset/test_get_col_names.py (+1 -1)
tests/ut/python/dataset/test_noop_mode.py (+2 -2)
tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py (+10 -10)
tests/ut/python/dataset/test_random_crop_with_bbox.py (+14 -14)
tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py (+11 -15)
tests/ut/python/dataset/test_random_resize_with_bbox.py (+8 -12)
tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py (+11 -18)
tests/ut/python/dataset/test_resize_with_bbox.py (+8 -12)
tests/ut/python/dataset/test_serdes_dataset.py (+2 -1)
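For orientation, a minimal sketch of how the new flag surfaces in the Python dataset API after this change (directory paths are placeholders):

    import mindspore.dataset as ds

    # `usage` selects the split to read: "train", "test", or "all".
    mnist_train = ds.MnistDataset("/path/to/mnist", usage="train")
    cifar10_all = ds.Cifar10Dataset("/path/to/cifar-10-batches-bin", usage="all")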
mindspore/ccsrc/minddata/dataset/api/datasets.cc
@@ -15,7 +15,7 @@
  */
 #include <fstream>
+#include <unordered_set>
 #include "minddata/dataset/include/datasets.h"
 #include "minddata/dataset/include/samplers.h"
 #include "minddata/dataset/include/transforms.h"
@@ -132,26 +132,28 @@ std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::s
 }

 // Function to create a CelebADataset.
-std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type,
+std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
                                       const std::shared_ptr<SamplerObj> &sampler, bool decode,
                                       const std::set<std::string> &extensions) {
-  auto ds = std::make_shared<CelebADataset>(dataset_dir, dataset_type, sampler, decode, extensions);
+  auto ds = std::make_shared<CelebADataset>(dataset_dir, usage, sampler, decode, extensions);
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
 }

 // Function to create a Cifar10Dataset.
-std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
+std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
+                                        const std::shared_ptr<SamplerObj> &sampler) {
+  auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, usage, sampler);
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
 }

 // Function to create a Cifar100Dataset.
-std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler);
+std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
+                                          const std::shared_ptr<SamplerObj> &sampler) {
+  auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, usage, sampler);
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -217,8 +219,9 @@ std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const
 #endif

 // Function to create a MnistDataset.
-std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler);
+std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
+                                    const std::shared_ptr<SamplerObj> &sampler) {
+  auto ds = std::make_shared<MnistDataset>(dataset_dir, usage, sampler);
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -244,10 +247,10 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase
 #ifndef ENABLE_ANDROID
 // Function to create a VOCDataset.
-std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode,
+std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
                                 const std::map<std::string, int32_t> &class_indexing, bool decode,
                                 const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_indexing, decode, sampler);
+  auto ds = std::make_shared<VOCDataset>(dataset_dir, task, usage, class_indexing, decode, sampler);
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -727,6 +730,10 @@ bool ValidateDatasetSampler(const std::string &dataset_name, const std::shared_p
   return true;
 }

+bool ValidateStringValue(const std::string &str, const std::unordered_set<std::string> &valid_strings) {
+  return valid_strings.find(str) != valid_strings.end();
+}
+
 // Helper function to validate dataset input/output column parameter
 bool ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
                                 const std::vector<std::string> &columns) {
@@ -802,29 +809,14 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumDataset::Build() {
 }

 // Constructor for CelebADataset
-CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
+CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &usage,
                              const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
                              const std::set<std::string> &extensions)
-    : dataset_dir_(dataset_dir), dataset_type_(dataset_type), sampler_(sampler), decode_(decode),
-      extensions_(extensions) {}
+    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {}

 bool CelebADataset::ValidateParams() {
-  if (!ValidateDatasetDirParam("CelebADataset", dataset_dir_)) {
-    return false;
-  }
-  if (!ValidateDatasetSampler("CelebADataset", sampler_)) {
-    return false;
-  }
-  std::set<std::string> dataset_type_list = {"all", "train", "valid", "test"};
-  auto iter = dataset_type_list.find(dataset_type_);
-  if (iter == dataset_type_list.end()) {
-    MS_LOG(ERROR) << "dataset_type should be one of 'all', 'train', 'valid' or 'test'.";
-    return false;
-  }
-  return true;
+  return ValidateDatasetDirParam("CelebADataset", dataset_dir_) && ValidateDatasetSampler("CelebADataset", sampler_) &&
+         ValidateStringValue(usage_, {"all", "train", "valid", "test"});
 }

 // Function to build CelebADataset
@@ -839,17 +831,20 @@ std::vector<std::shared_ptr<DatasetOp>> CelebADataset::Build() {
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
   node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
-                                                decode_, dataset_type_, extensions_, std::move(schema),
+                                                decode_, usage_, extensions_, std::move(schema),
                                                 std::move(sampler_->Build())));
   return node_ops;
 }

 // Constructor for Cifar10Dataset
-Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir), sampler_(sampler) {}
+Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string &usage,
+                               std::shared_ptr<SamplerObj> sampler)
+    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

 bool Cifar10Dataset::ValidateParams() {
-  return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_) && ValidateDatasetSampler("Cifar10Dataset", sampler_);
+  return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_) &&
+         ValidateDatasetSampler("Cifar10Dataset", sampler_) && ValidateStringValue(usage_, {"train", "test", "all", ""});
 }

 // Function to build CifarOp for Cifar10
@@ -864,19 +859,21 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));

-  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
+  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
                                                dataset_dir_, connector_que_size_, std::move(schema),
                                                std::move(sampler_->Build())));
   return node_ops;
 }

 // Constructor for Cifar100Dataset
-Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir), sampler_(sampler) {}
+Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::string &usage,
+                                 std::shared_ptr<SamplerObj> sampler)
+    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

 bool Cifar100Dataset::ValidateParams() {
   return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_) &&
-         ValidateDatasetSampler("Cifar100Dataset", sampler_);
+         ValidateDatasetSampler("Cifar100Dataset", sampler_) && ValidateStringValue(usage_, {"train", "test", "all", ""});
 }

 // Function to build CifarOp for Cifar100
@@ -893,7 +890,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));

-  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, num_workers_, rows_per_buffer_,
+  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
                                                dataset_dir_, connector_que_size_, std::move(schema),
                                                std::move(sampler_->Build())));
   return node_ops;
@@ -1360,11 +1357,12 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestDataset::Build() {
 }
 #endif

-MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir), sampler_(sampler) {}
+MnistDataset::MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler)
+    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

 bool MnistDataset::ValidateParams() {
-  return ValidateDatasetDirParam("MnistDataset", dataset_dir_) && ValidateDatasetSampler("MnistDataset", sampler_);
+  return ValidateStringValue(usage_, {"train", "test", "all", ""}) &&
+         ValidateDatasetDirParam("MnistDataset", dataset_dir_) && ValidateDatasetSampler("MnistDataset", sampler_);
 }

 std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
@@ -1378,8 +1376,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));

-  node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
-                                               std::move(schema), std::move(sampler_->Build())));
+  node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
+                                               connector_que_size_, std::move(schema), std::move(sampler_->Build())));
   return node_ops;
 }
@@ -1570,12 +1568,12 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordDataset::Build() {
 #ifndef ENABLE_ANDROID
 // Constructor for VOCDataset
-VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
+VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage,
                        const std::map<std::string, int32_t> &class_indexing, bool decode,
                        std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir), task_(task), mode_(mode), class_index_(class_indexing), decode_(decode),
+    : dataset_dir_(dataset_dir), task_(task), usage_(usage), class_index_(class_indexing), decode_(decode),
       sampler_(sampler) {}
@@ -1594,15 +1592,15 @@ bool VOCDataset::ValidateParams() {
       MS_LOG(ERROR) << "class_indexing is invalid in Segmentation task.";
       return false;
     }
-    Path imagesets_file = dir / "ImageSets" / "Segmentation" / mode_ + ".txt";
+    Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt";
     if (!imagesets_file.Exists()) {
-      MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
+      MS_LOG(ERROR) << "Invalid mode: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
       return false;
     }
   } else if (task_ == "Detection") {
-    Path imagesets_file = dir / "ImageSets" / "Main" / mode_ + ".txt";
+    Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt";
     if (!imagesets_file.Exists()) {
-      MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
+      MS_LOG(ERROR) << "Invalid mode: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
       return false;
     }
   } else {
@@ -1641,7 +1639,7 @@ std::vector<std::shared_ptr<DatasetOp>> VOCDataset::Build() {
   }

   std::shared_ptr<VOCOp> voc_op;
-  voc_op = std::make_shared<VOCOp>(task_type_, mode_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
+  voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
                                    connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
   node_ops.push_back(voc_op);
   return node_ops;
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc
@@ -41,9 +41,9 @@ namespace dataset {
 PYBIND_REGISTER(CifarOp, 1, ([](const py::module *m) {
                   (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
-                    .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
+                    .def_static("get_num_rows", [](const std::string &dir, const std::string &usage, bool isCifar10) {
                       int64_t count = 0;
-                      THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
+                      THROW_IF_ERROR(CifarOp::CountTotalRows(dir, usage, isCifar10, &count));
                       return count;
                     });
                 }));
@@ -131,9 +131,9 @@ PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
 PYBIND_REGISTER(MnistOp, 1, ([](const py::module *m) {
                   (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
-                    .def_static("get_num_rows", [](const std::string &dir) {
+                    .def_static("get_num_rows", [](const std::string &dir, const std::string &usage) {
                       int64_t count = 0;
-                      THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
+                      THROW_IF_ERROR(MnistOp::CountTotalRows(dir, usage, &count));
                       return count;
                     });
                 }));
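As a rough illustration of the updated binding signatures above, a Python sketch follows; the direct import of the internal binding module is an assumption for demonstration only, since in practice these statics are invoked internally when a dataset's size is computed:

    # Hypothetical direct call into the C++ bindings (normally internal).
    from mindspore._c_dataengine import CifarOp, MnistOp

    cifar_rows = CifarOp.get_num_rows("/path/to/cifar", "train", True)  # dir, usage, isCifar10
    mnist_rows = MnistOp.get_num_rows("/path/to/mnist", "test")         # dir, usage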
mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc
@@ -1354,25 +1354,14 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
 Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                               std::shared_ptr<DatasetOp> *bottom) {
-  if (args["dataset_dir"].is_none()) {
-    std::string err_msg = "Error: No dataset path specified";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  if (args["task"].is_none()) {
-    std::string err_msg = "Error: No task specified";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  if (args["mode"].is_none()) {
-    std::string err_msg = "Error: No mode specified";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED(!args["dataset_dir"].is_none(), "Error: No dataset path specified.");
+  CHECK_FAIL_RETURN_UNEXPECTED(!args["task"].is_none(), "Error: No task specified.");
+  CHECK_FAIL_RETURN_UNEXPECTED(!args["usage"].is_none(), "Error: No usage specified.");
   std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
   (void)builder->SetDir(ToString(args["dataset_dir"]));
   (void)builder->SetTask(ToString(args["task"]));
-  (void)builder->SetMode(ToString(args["mode"]));
+  (void)builder->SetUsage(ToString(args["usage"]));
   for (auto arg : args) {
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
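On the Python side this corresponds to VOCDataset accepting `usage` where it previously took `mode`; a minimal sketch, assuming placeholder paths:

    import mindspore.dataset as ds

    # `usage` names the image-set split file, e.g. ImageSets/Segmentation/train.txt.
    voc_train = ds.VOCDataset("/path/to/VOCdevkit/VOC2012", task="Segmentation", usage="train")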
@@ -1461,6 +1450,8 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
       std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
       (void)builder->SetSampler(std::move(sampler));
-    }
+    } else if (key == "usage") {
+      (void)builder->SetUsage(ToString(value));
+    }
   }
 }
@@ -1495,6 +1486,8 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
       std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
       (void)builder->SetSampler(std::move(sampler));
-    }
+    } else if (key == "usage") {
+      (void)builder->SetUsage(ToString(value));
+    }
   }
 }
@@ -1608,6 +1601,8 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
       std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
       (void)builder->SetSampler(std::move(sampler));
-    }
+    } else if (key == "usage") {
+      (void)builder->SetUsage(ToString(value));
+    }
   }
 }
@@ -1645,8 +1640,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
       (void)builder->SetDecode(ToBool(value));
     } else if (key == "extensions") {
       (void)builder->SetExtensions(ToStringSet(value));
-    } else if (key == "dataset_type") {
-      (void)builder->SetDatasetType(ToString(value));
+    } else if (key == "usage") {
+      (void)builder->SetUsage(ToString(value));
     }
   }
 }
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
@@ -36,7 +36,7 @@ CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr)
 Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
   MS_LOG(DEBUG) << "Celeba dataset directory is " << builder_dir_.c_str() << ".";
-  MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << ".";
+  MS_LOG(DEBUG) << "Celeba dataset type is " << builder_usage_.c_str() << ".";
   RETURN_IF_NOT_OK(SanityCheck());
   if (builder_sampler_ == nullptr) {
     const int64_t num_samples = 0;
@@ -51,8 +51,8 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
   RETURN_IF_NOT_OK(
     builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
   *op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
-                                   builder_op_connector_size_, builder_decode_, builder_dataset_type_,
-                                   builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
+                                   builder_op_connector_size_, builder_decode_, builder_usage_, builder_extensions_,
+                                   std::move(builder_schema_), std::move(builder_sampler_));
   if (*op == nullptr) {
     return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null");
   }
@@ -69,7 +69,7 @@ Status CelebAOp::Builder::SanityCheck() {
 }

-CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
-                   const std::string &dataset_type, const std::set<std::string> &exts,
+CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
+                   bool decode, const std::string &usage, const std::set<std::string> &exts,
                    std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_workers, queue_size, std::move(sampler)),
       rows_per_buffer_(rows_per_buffer),
@@ -78,7 +78,7 @@ CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::stri
       extensions_(exts),
       data_schema_(std::move(schema)),
       num_rows_in_attr_file_(0),
-      dataset_type_(dataset_type) {
+      usage_(usage) {
   attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size);
   io_block_queues_.Init(num_workers_, queue_size);
 }
@@ -135,7 +135,7 @@ Status CelebAOp::ParseAttrFile() {
   std::vector<std::string> image_infos;
   image_infos.reserve(oc_queue_size_);
   while (getline(attr_file, image_info)) {
-    if ((image_info.empty()) || (dataset_type_ != "all" && !CheckDatasetTypeValid())) {
+    if ((image_info.empty()) || (usage_ != "all" && !CheckDatasetTypeValid())) {
       continue;
     }
     image_infos.push_back(image_info);
@@ -179,11 +179,11 @@ bool CelebAOp::CheckDatasetTypeValid() {
     return false;
   }
   // train:0, valid=1, test=2
-  if (dataset_type_ == "train" && (type == 0)) {
+  if (usage_ == "train" && (type == 0)) {
     return true;
-  } else if (dataset_type_ == "valid" && (type == 1)) {
+  } else if (usage_ == "valid" && (type == 1)) {
     return true;
-  } else if (dataset_type_ == "test" && (type == 2)) {
+  } else if (usage_ == "test" && (type == 2)) {
     return true;
   }
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h
@@ -109,10 +109,10 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
     }

     // Setter method
-    // @param const std::string dataset_type: type to be read
+    // @param const std::string usage: type to be read
     // @return Builder setter method returns reference to the builder.
-    Builder &SetDatasetType(const std::string &dataset_type) {
-      builder_dataset_type_ = dataset_type;
+    Builder &SetUsage(const std::string &usage) {
+      builder_usage_ = usage;
       return *this;
     }

     // Check validity of input args
@@ -133,7 +133,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
     std::set<std::string> builder_extensions_;
     std::shared_ptr<Sampler> builder_sampler_;
     std::unique_ptr<DataSchema> builder_schema_;
-    std::string builder_dataset_type_;
+    std::string builder_usage_;
   };

   // Constructor
@@ -143,12 +143,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // @param int32_t queueSize - connector queue size
   // @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
   CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
-           const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
+           const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
            std::shared_ptr<Sampler> sampler);

   ~CelebAOp() override = default;

-  // Main Loop of CelebaOp
+  // Main Loop of CelebAOp
   // Master thread: Fill IOBlockQueue, then goes to sleep
   // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
   // @return Status - The error code return
@@ -177,7 +177,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // Op name getter
   // @return Name of the current Op
-  std::string Name() const { return "CelebAOp"; }
+  std::string Name() const override { return "CelebAOp"; }

  private:
   // Called first when function is called
@@ -232,7 +232,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
   WaitPost wp_;
   std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
-  std::string dataset_type_;
+  std::string usage_;
   std::ifstream partition_file_;
 };
 }  // namespace dataset
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc
@@ -18,15 +18,16 @@
 #include <algorithm>
 #include <fstream>
 #include <iomanip>
+#include <set>
 #include <utility>
-#include "utils/ms_utils.h"
 #include "minddata/dataset/core/config_manager.h"
 #include "minddata/dataset/core/tensor_shape.h"
 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
 #include "minddata/dataset/engine/db_connector.h"
 #include "minddata/dataset/engine/execution_tree.h"
 #include "minddata/dataset/engine/opt/pass.h"
+#include "utils/ms_utils.h"

 namespace mindspore {
 namespace dataset {
@@ -36,7 +37,7 @@ constexpr uint32_t kCifarImageChannel = 3;
 constexpr uint32_t kCifarBlockImageNum = 5;
 constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel;

-CifarOp::Builder::Builder() : sampler_(nullptr) {
+CifarOp::Builder::Builder() : sampler_(nullptr), usage_("") {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   num_workers_ = cfg->num_parallel_workers();
   rows_per_buffer_ = cfg->rows_per_buffer();
@@ -65,23 +66,27 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
       ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
   }

-  *ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
+  *ptr = std::make_shared<CifarOp>(cifar_type_, usage_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
                                    std::move(schema_), std::move(sampler_));
   return Status::OK();
 }

 Status CifarOp::Builder::SanityCheck() {
+  const std::set<std::string> valid = {"test", "train", "all", ""};
   Path dir(dir_);
   std::string err_msg;
   err_msg += dir.IsDirectory() == false ? "Cifar path is invalid or not set\n" : "";
   err_msg += num_workers_ <= 0 ? "Num of parallel workers is negative or 0\n" : "";
+  err_msg += valid.find(usage_) == valid.end() ? "usage needs to be 'train','test' or 'all'\n" : "";
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
 }

-CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
-                 int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
+CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
+                 const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
+                 std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_works, queue_size, std::move(sampler)),
       cifar_type_(type),
+      usage_(usage),
       rows_per_buffer_(rows_per_buf),
       folder_path_(file_dir),
       data_schema_(std::move(data_schema)),
@@ -258,21 +263,32 @@ Status CifarOp::ReadCifarBlockDataAsync() {
 }

 Status CifarOp::ReadCifar10BlockData() {
   // CIFAR 10 has 6 bin files. data_batch_1.bin ... data_batch_5.bin and 1 test_batch.bin file
   // each of the file has exactly 10K images and labels and size is 30,730 KB
   // each image has the dimension of 32 x 32 x 3 = 3072 plus 1 label (label has 10 classes) so each row has 3073 bytes
   constexpr uint32_t num_cifar10_records = 10000;
   uint32_t block_size = (kCifarImageSize + 1) * kCifarBlockImageNum;  // about 2M
   std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
   for (auto &file : cifar_files_) {
-    std::ifstream in(file, std::ios::binary);
-    if (!in.is_open()) {
-      std::string err_msg = file + " can not be opened.";
-      RETURN_STATUS_UNEXPECTED(err_msg);
+    // check the validity of the file path
+    Path file_path(file);
+    CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
+    std::string file_name = file_path.Basename();
+
+    if (usage_ == "train") {
+      if (file_name.find("data_batch") == std::string::npos) continue;
+    } else if (usage_ == "test") {
+      if (file_name.find("test_batch") == std::string::npos) continue;
+    } else {  // get all the files that contain the word batch, aka any cifar 100 files
+      if (file_name.find("batch") == std::string::npos) continue;
     }
+
+    std::ifstream in(file, std::ios::binary);
+    CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
     for (uint32_t index = 0; index < num_cifar10_records / kCifarBlockImageNum; ++index) {
       (void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
-      if (in.fail()) {
-        RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file);
-      }
+      CHECK_FAIL_RETURN_UNEXPECTED(!in.fail(), "Fail to read cifar file" + file);
       (void)cifar_raw_data_block_->EmplaceBack(image_data);
     }
     in.close();
@@ -283,15 +299,21 @@ Status CifarOp::ReadCifar10BlockData() {
 }

 Status CifarOp::ReadCifar100BlockData() {
   // CIFAR 100 has 2 bin files. train.bin (60K imgs) 153,700KB and test.bin (30,740KB) (10K imgs)
   // each img has two labels. Each row then is 32 * 32 *5 + 2 = 3,074 Bytes
   uint32_t num_cifar100_records = 0;  // test:10000, train:50000
   uint32_t block_size = (kCifarImageSize + 2) * kCifarBlockImageNum;  // about 2M
   std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
   for (auto &file : cifar_files_) {
-    int pos = file.find_last_of('/');
-    if (pos == std::string::npos) {
-      RETURN_STATUS_UNEXPECTED("Invalid cifar100 file path");
-    }
-    std::string file_name(file.substr(pos + 1));
+    // check the validity of the file path
+    Path file_path(file);
+    CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
+    std::string file_name = file_path.Basename();
+
+    // if usage is train/test, get only these 2 files
+    if (usage_ == "train" && file_name.find("train") == std::string::npos) continue;
+    if (usage_ == "test" && file_name.find("test") == std::string::npos) continue;
+
     if (file_name.find("test") != std::string::npos) {
       num_cifar100_records = 10000;
     } else if (file_name.find("train") != std::string::npos) {
@@ -301,15 +323,11 @@ Status CifarOp::ReadCifar100BlockData() {
     }
     std::ifstream in(file, std::ios::binary);
-    if (!in.is_open()) {
-      RETURN_STATUS_UNEXPECTED(file + " can not be opened.");
-    }
+    CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
     for (uint32_t index = 0; index < num_cifar100_records / kCifarBlockImageNum; index++) {
       (void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
-      if (in.fail()) {
-        RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file);
-      }
+      CHECK_FAIL_RETURN_UNEXPECTED(!in.fail(), "Fail to read cifar file" + file);
       (void)cifar_raw_data_block_->EmplaceBack(image_data);
     }
     in.close();
@@ -319,26 +337,20 @@ Status CifarOp::ReadCifar100BlockData() {
 }

 Status CifarOp::GetCifarFiles() {
-  // Initialize queue to hold the file names
   const std::string kExtension = ".bin";
-  Path dataset_directory(folder_path_);
-  auto dirIt = Path::DirIterator::OpenDirectory(&dataset_directory);
+  Path dir_path(folder_path_);
+  auto dirIt = Path::DirIterator::OpenDirectory(&dir_path);
   if (dirIt) {
     while (dirIt->hasNext()) {
       Path file = dirIt->next();
-      std::string filename = file.toString();
-      if (filename.find(kExtension) != std::string::npos) {
-        cifar_files_.push_back(filename);
-        MS_LOG(INFO) << "Cifar operator found file at " << filename << ".";
+      if (file.Extension() == kExtension) {
+        cifar_files_.push_back(file.toString());
       }
     }
   } else {
-    std::string err_msg = "Unable to open directory " + dataset_directory.toString();
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  if (cifar_files_.size() == 0) {
-    RETURN_STATUS_UNEXPECTED("No .bin files found under " + folder_path_);
+    RETURN_STATUS_UNEXPECTED("Unable to open directory " + dir_path.toString());
   }
+  CHECK_FAIL_RETURN_UNEXPECTED(!cifar_files_.empty(), "No .bin files found under " + folder_path_);
   std::sort(cifar_files_.begin(), cifar_files_.end());
   return Status::OK();
 }
@@ -378,9 +390,8 @@ Status CifarOp::ParseCifarData() {
   num_rows_ = cifar_image_label_pairs_.size();
   if (num_rows_ == 0) {
     std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
-    std::string err_msg = "There is no valid data matching the dataset API " + api +
-                          ".Please check file path or dataset API validation first.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
+    RETURN_STATUS_UNEXPECTED("There is no valid data matching the dataset API " + api +
+                             ".Please check file path or dataset API validation first.");
   }
   cifar_raw_data_block_->Reset();
   return Status::OK();
@@ -403,46 +414,51 @@ Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
   return Status::OK();
 }

-Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) {
+Status CifarOp::CountTotalRows(const std::string &dir, const std::string &usage, bool isCIFAR10, int64_t *count) {
   // the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block()
   std::shared_ptr<CifarOp> op;
   *count = 0;
-  RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op));
+  RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).SetUsage(usage).Build(&op));
   RETURN_IF_NOT_OK(op->GetCifarFiles());
   if (op->cifar_type_ == kCifar10) {
     constexpr int64_t num_cifar10_records = 10000;
     for (auto &file : op->cifar_files_) {
-      std::ifstream in(file, std::ios::binary);
-      if (!in.is_open()) {
-        std::string err_msg = file + " can not be opened.";
-        RETURN_STATUS_UNEXPECTED(err_msg);
+      Path file_path(file);
+      CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
+      std::string file_name = file_path.Basename();
+
+      if (op->usage_ == "train") {
+        if (file_name.find("data_batch") == std::string::npos) continue;
+      } else if (op->usage_ == "test") {
+        if (file_name.find("test_batch") == std::string::npos) continue;
+      } else {  // get all the files that contain the word batch, aka any cifar 100 files
+        if (file_name.find("batch") == std::string::npos) continue;
       }
+
+      std::ifstream in(file, std::ios::binary);
+      CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
       *count = *count + num_cifar10_records;
     }
     return Status::OK();
   } else {
     int64_t num_cifar100_records = 0;
     for (auto &file : op->cifar_files_) {
-      size_t pos = file.find_last_of('/');
-      if (pos == std::string::npos) {
-        std::string err_msg = "Invalid cifar100 file path";
-        RETURN_STATUS_UNEXPECTED(err_msg);
-      }
-      std::string file_name;
-      if (file.size() > 0)
-        file_name = file.substr(pos + 1);
-      else
-        RETURN_STATUS_UNEXPECTED("Invalid string length!");
+      Path file_path(file);
+      std::string file_name = file_path.Basename();
+      CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
+
+      if (op->usage_ == "train" && file_path.Basename().find("train") == std::string::npos) continue;
+      if (op->usage_ == "test" && file_path.Basename().find("test") == std::string::npos) continue;
+
       if (file_name.find("test") != std::string::npos) {
-        num_cifar100_records = 10000;
+        num_cifar100_records += 10000;
       } else if (file_name.find("train") != std::string::npos) {
-        num_cifar100_records = 50000;
+        num_cifar100_records += 50000;
       }
       std::ifstream in(file, std::ios::binary);
-      if (!in.is_open()) {
-        std::string err_msg = file + " can not be opened.";
-        RETURN_STATUS_UNEXPECTED(err_msg);
-      }
+      CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
     }
     *count = num_cifar100_records;
     return Status::OK();
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h
@@ -83,15 +83,23 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
     // Setter method
     // @param const std::string & dir
-    // @return
+    // @return Builder setter method returns reference to the builder.
     Builder &SetCifarDir(const std::string &dir) {
       dir_ = dir;
       return *this;
     }

+    // Setter method
+    // @param const std::string &usage
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetUsage(const std::string &usage) {
+      usage_ = usage;
+      return *this;
+    }
+
     // Setter method
     // @param const std::string & dir
-    // @return
+    // @return Builder setter method returns reference to the builder.
     Builder &SetCifarType(const bool cifar10) {
       if (cifar10) {
         cifar_type_ = kCifar10;
@@ -112,6 +120,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
    private:
    std::string dir_;
+   std::string usage_;
    int32_t num_workers_;
    int32_t rows_per_buffer_;
    int32_t op_connect_size_;
@@ -122,13 +131,15 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // Constructor
   // @param CifarType type - Cifar10 or Cifar100
+  // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
   // @param uint32_t numWorks - Num of workers reading images in parallel
   // @param uint32_t - rowsPerBuffer Number of images (rows) in each buffer
   // @param std::string - dir directory of cifar dataset
   // @param uint32_t - queueSize - connector queue size
   // @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
-  CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size,
-          std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+  CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
+          const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
+          std::shared_ptr<Sampler> sampler);

   // Destructor.
   ~CifarOp() = default;
@@ -153,7 +164,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // @param isCIFAR10 true if CIFAR10 and false if CIFAR100
   // @param count output arg that will hold the actual dataset size
   // @return
-  static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
+  static Status CountTotalRows(const std::string &dir, const std::string &usage, bool isCIFAR10, int64_t *count);

   /// \brief Base-class override for NodePass visitor acceptor
   /// \param[in] p Pointer to the NodePass to be accepted
@@ -224,7 +235,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   std::unique_ptr<DataSchema> data_schema_;

   int64_t row_cnt_;
   int64_t buf_cnt_;
+  const std::string usage_;  // can only be either "train" or "test"
   WaitPost wp_;
   QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
   std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_;
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc
@@ -17,6 +17,7 @@
 #include <fstream>
 #include <iomanip>
+#include <set>
 #include "utils/ms_utils.h"
 #include "minddata/dataset/core/config_manager.h"
 #include "minddata/dataset/core/tensor_shape.h"
@@ -32,7 +33,7 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
 const int32_t kMnistImageRows = 28;
 const int32_t kMnistImageCols = 28;

-MnistOp::Builder::Builder() : builder_sampler_(nullptr) {
+MnistOp::Builder::Builder() : builder_sampler_(nullptr), builder_usage_("") {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_rows_per_buffer_ = cfg->rows_per_buffer();
@@ -52,22 +53,25 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
   TensorShape scalar = TensorShape::CreateScalar();
   RETURN_IF_NOT_OK(builder_schema_->AddColumn(
     ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
-  *ptr = std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
+  *ptr = std::make_shared<MnistOp>(builder_usage_, builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
                                    builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_));
   return Status::OK();
 }

 Status MnistOp::Builder::SanityCheck() {
+  const std::set<std::string> valid = {"test", "train", "all", ""};
   Path dir(builder_dir_);
   std::string err_msg;
   err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : "";
   err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : "";
+  err_msg += valid.find(builder_usage_) == valid.end() ? "usage needs to be 'train','test' or 'all'\n" : "";
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
 }

-MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
-                 std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
+MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
+                 int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_workers, queue_size, std::move(sampler)),
+      usage_(usage),
       buf_cnt_(0),
       row_cnt_(0),
      folder_path_(folder_path),
@@ -226,9 +230,7 @@ Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
 Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) {
   uint32_t res = 0;
   reader->read(reinterpret_cast<char *>(&res), 4);
-  if (reader->fail()) {
-    RETURN_STATUS_UNEXPECTED("Failed to read 4 bytes from file");
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED(!reader->fail(), "Failed to read 4 bytes from file");
   *result = SwapEndian(res);
   return Status::OK();
 }
@@ -239,15 +241,12 @@ uint32_t MnistOp::SwapEndian(uint32_t val) const {
 }

 Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) {
-  if (image_reader->is_open() == false) {
-    RETURN_STATUS_UNEXPECTED("Cannot open mnist image file: " + file_name);
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED(image_reader->is_open(), "Cannot open mnist image file: " + file_name);
   int64_t image_len = image_reader->seekg(0, std::ios::end).tellg();
   (void)image_reader->seekg(0, std::ios::beg);
   // The first 16 bytes of the image file are type, number, row and column
-  if (image_len < 16) {
-    RETURN_STATUS_UNEXPECTED("Mnist file is corrupted.");
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED(image_len >= 16, "Mnist file is corrupted.");
   uint32_t magic_number;
   RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number));
   CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber,
@@ -260,35 +259,25 @@ Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_re
   uint32_t cols;
   RETURN_IF_NOT_OK(ReadFromReader(image_reader, &cols));
   // The image size of the Mnist dataset is fixed at [28,28]
-  if ((rows != kMnistImageRows) || (cols != kMnistImageCols)) {
-    RETURN_STATUS_UNEXPECTED("Wrong shape of image.");
-  }
-  if ((image_len - 16) != num_items * rows * cols) {
-    RETURN_STATUS_UNEXPECTED("Wrong number of image.");
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED((rows == kMnistImageRows) && (cols == kMnistImageCols), "Wrong shape of image.");
+  CHECK_FAIL_RETURN_UNEXPECTED((image_len - 16) == num_items * rows * cols, "Wrong number of image.");
   *num_images = num_items;
   return Status::OK();
 }

 Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) {
-  if (label_reader->is_open() == false) {
-    RETURN_STATUS_UNEXPECTED("Cannot open mnist label file: " + file_name);
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED(label_reader->is_open(), "Cannot open mnist label file: " + file_name);
   int64_t label_len = label_reader->seekg(0, std::ios::end).tellg();
   (void)label_reader->seekg(0, std::ios::beg);
   // The first 8 bytes of the image file are type and number
-  if (label_len < 8) {
-    RETURN_STATUS_UNEXPECTED("Mnist file is corrupted.");
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED(label_len >= 8, "Mnist file is corrupted.");
   uint32_t magic_number;
   RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number));
   CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber,
                                "This is not the mnist label file: " + file_name);
   uint32_t num_items;
   RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items));
-  if ((label_len - 8) != num_items) {
-    RETURN_STATUS_UNEXPECTED("Wrong number of labels!");
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED((label_len - 8) == num_items, "Wrong number of labels!");
   *num_labels = num_items;
   return Status::OK();
 }
@@ -330,6 +319,9 @@ Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *la
 }

 Status MnistOp::ParseMnistData() {
+  // MNIST contains 4 files, idx3 are image files, idx 1 are labels
+  // training files contain 60K examples and testing files contain 10K examples
+  // t10k-images-idx3-ubyte  t10k-labels-idx1-ubyte  train-images-idx3-ubyte  train-labels-idx1-ubyte
   for (size_t i = 0; i < image_names_.size(); ++i) {
     std::ifstream image_reader, label_reader;
     image_reader.open(image_names_[i], std::ios::binary);
@@ -354,18 +346,22 @@ Status MnistOp::ParseMnistData() {
 Status MnistOp::WalkAllFiles() {
   const std::string kImageExtension = "idx3-ubyte";
   const std::string kLabelExtension = "idx1-ubyte";
+  const std::string train_prefix = "train";
+  const std::string test_prefix = "t10k";

   Path dir(folder_path_);
   auto dir_it = Path::DirIterator::OpenDirectory(&dir);
+  std::string prefix;  // empty string, used to match usage = "" (default) or usage == "all"
+  if (usage_ == "train" || usage_ == "test") prefix = (usage_ == "test" ? test_prefix : train_prefix);
   if (dir_it != nullptr) {
     while (dir_it->hasNext()) {
       Path file = dir_it->next();
-      std::string filename = file.toString();
-      if (filename.find(kImageExtension) != std::string::npos) {
-        image_names_.push_back(filename);
+      std::string filename = file.Basename();
+      if (filename.find(prefix + "-images-" + kImageExtension) != std::string::npos) {
+        image_names_.push_back(file.toString());
         MS_LOG(INFO) << "Mnist operator found image file at " << filename << ".";
-      } else if (filename.find(kLabelExtension) != std::string::npos) {
-        label_names_.push_back(filename);
+      } else if (filename.find(prefix + "-labels-" + kLabelExtension) != std::string::npos) {
+        label_names_.push_back(file.toString());
         MS_LOG(INFO) << "Mnist Operator found label file at " << filename << ".";
       }
     }
@@ -376,9 +372,7 @@ Status MnistOp::WalkAllFiles() {
   std::sort(image_names_.begin(), image_names_.end());
   std::sort(label_names_.begin(), label_names_.end());

-  if (image_names_.size() != label_names_.size()) {
-    RETURN_STATUS_UNEXPECTED("num of images does not equal to num of labels");
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED(image_names_.size() == label_names_.size(), "num of idx3 files != num of idx1 files");

   return Status::OK();
 }
@@ -397,11 +391,11 @@ Status MnistOp::LaunchThreadsAndInitOp() {
   return Status::OK();
 }

-Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
+Status MnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
   // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader()
   std::shared_ptr<MnistOp> op;
   *count = 0;
-  RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op));
+  RETURN_IF_NOT_OK(Builder().SetDir(dir).SetUsage(usage).Build(&op));
   RETURN_IF_NOT_OK(op->WalkAllFiles());
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h
@@ -47,8 +47,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   class Builder {
    public:
     // Constructor for Builder class of MnistOp
-    // @param uint32_t numWrks - number of parallel workers
-    // @param dir - directory folder got ImageNetFolder
     Builder();

     // Destructor.
@@ -87,13 +85,20 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
     }

     // Setter method
-    // @param const std::string & dir
+    // @param const std::string &dir
     // @return
     Builder &SetDir(const std::string &dir) {
       builder_dir_ = dir;
       return *this;
     }

+    // Setter method
+    // @param const std::string &usage
+    // @return
+    Builder &SetUsage(const std::string &usage) {
+      builder_usage_ = usage;
+      return *this;
+    }
+
     // Check validity of input args
     // @return - The error code return
     Status SanityCheck();
@@ -105,6 +110,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
    private:
    std::string builder_dir_;
+   std::string builder_usage_;
    int32_t builder_num_workers_;
    int32_t builder_rows_per_buffer_;
    int32_t builder_op_connector_size_;
@@ -113,14 +119,15 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   };

   // Constructor
+  // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
   // @param int32_t num_workers - number of workers reading images in parallel
   // @param int32_t rows_per_buffer - number of images (rows) in each buffer
   // @param std::string folder_path - dir directory of mnist
   // @param int32_t queue_size - connector queue size
   // @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
   // @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read
-  MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
-          std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+  MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
+          int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);

   // Destructor.
   ~MnistOp() = default;
@@ -150,7 +157,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   // @param dir path to the MNIST directory
   // @param count output arg that will hold the minimum of the actual dataset size and numSamples
   // @return
-  static Status CountTotalRows(const std::string &dir, int64_t *count);
+  static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);

   /// \brief Base-class override for NodePass visitor acceptor
   /// \param[in] p Pointer to the NodePass to be accepted
@@ -241,6 +248,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   WaitPost wp_;
   std::string folder_path_;  // directory of image folder
   int32_t rows_per_buffer_;
+  const std::string usage_;  // can only be either "train" or "test"
   std::unique_ptr<DataSchema> data_schema_;
   std::vector<MnistLabelPair> image_label_pairs_;
   std::vector<std::string> image_names_;
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc
@@ -18,14 +18,15 @@
 #include <algorithm>
 #include <fstream>
 #include <iomanip>
 #include "./tinyxml2.h"
-#include "utils/ms_utils.h"
 #include "minddata/dataset/core/config_manager.h"
 #include "minddata/dataset/core/tensor_shape.h"
 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
 #include "minddata/dataset/engine/db_connector.h"
 #include "minddata/dataset/engine/execution_tree.h"
 #include "minddata/dataset/engine/opt/pass.h"
+#include "utils/ms_utils.h"

 using tinyxml2::XMLDocument;
 using tinyxml2::XMLElement;
@@ -81,7 +82,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
RETURN_IF_NOT_OK
(
builder_schema_
->
AddColumn
(
ColDescriptor
(
std
::
string
(
kColumnTruncate
),
DataType
(
DataType
::
DE_UINT32
),
TensorImpl
::
kFlexible
,
1
)));
}
*
ptr
=
std
::
make_shared
<
VOCOp
>
(
builder_task_type_
,
builder_
task_mod
e_
,
builder_dir_
,
builder_labels_to_read_
,
*
ptr
=
std
::
make_shared
<
VOCOp
>
(
builder_task_type_
,
builder_
usag
e_
,
builder_dir_
,
builder_labels_to_read_
,
builder_num_workers_
,
builder_rows_per_buffer_
,
builder_op_connector_size_
,
builder_decode_
,
std
::
move
(
builder_schema_
),
std
::
move
(
builder_sampler_
));
return
Status
::
OK
();
...
...
@@ -103,7 +104,7 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std:
row_cnt_
(
0
),
buf_cnt_
(
0
),
task_type_
(
task_type
),
task_mod
e_
(
task_mode
),
usag
e_
(
task_mode
),
folder_path_
(
folder_path
),
class_index_
(
class_index
),
rows_per_buffer_
(
rows_per_buffer
),
...
...
@@ -251,10 +252,9 @@ Status VOCOp::WorkerEntry(int32_t worker_id) {
Status
VOCOp
::
ParseImageIds
()
{
std
::
string
image_sets_file
;
if
(
task_type_
==
TaskType
::
Segmentation
)
{
image_sets_file
=
folder_path_
+
std
::
string
(
kImageSetsSegmentation
)
+
task_mode_
+
std
::
string
(
kImageSetsExtension
);
image_sets_file
=
folder_path_
+
std
::
string
(
kImageSetsSegmentation
)
+
usage_
+
std
::
string
(
kImageSetsExtension
);
}
else
if
(
task_type_
==
TaskType
::
Detection
)
{
image_sets_file
=
folder_path_
+
std
::
string
(
kImageSetsMain
)
+
task_mod
e_
+
std
::
string
(
kImageSetsExtension
);
image_sets_file
=
folder_path_
+
std
::
string
(
kImageSetsMain
)
+
usag
e_
+
std
::
string
(
kImageSetsExtension
);
}
std
::
ifstream
in_file
;
in_file
.
open
(
image_sets_file
);
...
...
@@ -431,13 +431,13 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ
std
::
shared_ptr
<
VOCOp
>
op
;
RETURN_IF_NOT_OK
(
Builder
().
SetDir
(
dir
).
SetTask
(
task_type
).
Set
Mod
e
(
task_mode
).
SetClassIndex
(
input_class_indexing
).
Build
(
&
op
));
Builder
().
SetDir
(
dir
).
SetTask
(
task_type
).
Set
Usag
e
(
task_mode
).
SetClassIndex
(
input_class_indexing
).
Build
(
&
op
));
RETURN_IF_NOT_OK
(
op
->
ParseImageIds
());
RETURN_IF_NOT_OK
(
op
->
ParseAnnotationIds
());
*
count
=
static_cast
<
int64_t
>
(
op
->
image_ids_
.
size
());
}
else
if
(
task_type
==
"Segmentation"
)
{
std
::
shared_ptr
<
VOCOp
>
op
;
RETURN_IF_NOT_OK
(
Builder
().
SetDir
(
dir
).
SetTask
(
task_type
).
Set
Mod
e
(
task_mode
).
Build
(
&
op
));
RETURN_IF_NOT_OK
(
Builder
().
SetDir
(
dir
).
SetTask
(
task_type
).
Set
Usag
e
(
task_mode
).
Build
(
&
op
));
RETURN_IF_NOT_OK
(
op
->
ParseImageIds
());
*
count
=
static_cast
<
int64_t
>
(
op
->
image_ids_
.
size
());
}
...
...
@@ -458,7 +458,7 @@ Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_t
}
else
{
std
::
shared_ptr
<
VOCOp
>
op
;
RETURN_IF_NOT_OK
(
Builder
().
SetDir
(
dir
).
SetTask
(
task_type
).
Set
Mod
e
(
task_mode
).
SetClassIndex
(
input_class_indexing
).
Build
(
&
op
));
Builder
().
SetDir
(
dir
).
SetTask
(
task_type
).
Set
Usag
e
(
task_mode
).
SetClassIndex
(
input_class_indexing
).
Build
(
&
op
));
RETURN_IF_NOT_OK
(
op
->
ParseImageIds
());
RETURN_IF_NOT_OK
(
op
->
ParseAnnotationIds
());
for
(
const
auto
label
:
op
->
label_index_
)
{
...
...
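ParseImageIds shows what `usage` actually selects for VOC: the data-list text file under ImageSets. A small Python sketch of the equivalent path construction (the helper name is ours, not part of the API; paths are hypothetical):

import os

def voc_image_sets_file(folder_path, task_type, usage):
    # Mirrors VOCOp::ParseImageIds: "Segmentation" reads ImageSets/Segmentation/<usage>.txt,
    # "Detection" reads ImageSets/Main/<usage>.txt.
    subdir = "Segmentation" if task_type == "Segmentation" else "Main"
    return os.path.join(folder_path, "ImageSets", subdir, usage + ".txt")

print(voc_image_sets_file("/data/VOCdevkit/VOC2012", "Detection", "train"))
# /data/VOCdevkit/VOC2012/ImageSets/Main/train.txt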
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h
@@ -73,7 +73,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
    }

    // Setter method.
-   // @param const std::string & task_type
+   // @param const std::string &task_type
    // @return Builder setter method returns reference to the builder.
    Builder &SetTask(const std::string &task_type) {
      if (task_type == "Segmentation") {
@@ -85,10 +85,10 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
    }

    // Setter method.
-   // @param const std::string &task_mode
+   // @param const std::string &usage
    // @return Builder setter method returns reference to the builder.
-   Builder &SetMode(const std::string &task_mode) {
-     builder_task_mode_ = task_mode;
+   Builder &SetUsage(const std::string &usage) {
+     builder_usage_ = usage;
      return *this;
    }
@@ -145,7 +145,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
    bool builder_decode_;
    std::string builder_dir_;
    TaskType builder_task_type_;
-   std::string builder_task_mode_;
+   std::string builder_usage_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
    int32_t builder_rows_per_buffer_;
@@ -279,7 +279,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   int64_t buf_cnt_;
   std::string folder_path_;
   TaskType task_type_;
-  std::string task_mode_;
+  std::string usage_;
   int32_t rows_per_buffer_;
   std::unique_ptr<DataSchema> data_schema_;
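At the user level the SetMode/SetUsage rename is purely a keyword change. A sketch, assuming a standard VOC directory layout (the "val" image set is a conventional part of VOC releases, not something this commit adds):

import mindspore.dataset as ds

# Before this commit the keyword was `mode`; after it, the keyword is `usage`.
voc_val = ds.VOCDataset("/path/to/VOC2012", task="Detection", usage="val", decode=True)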
mindspore/ccsrc/minddata/dataset/include/datasets.h
@@ -111,34 +111,36 @@ std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::s
 /// \brief Function to create a CelebADataset
 /// \notes The generated dataset has two columns ['image', 'attr'].
-//      The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
+///     The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
 /// \param[in] dataset_dir Path to the root directory that contains the dataset.
-/// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test'.
+/// \param[in] usage One of "all", "train", "valid" or "test".
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
 ///    a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
 /// \param[in] decode Decode the images after reading (default=false).
 /// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
 /// \return Shared pointer to the current Dataset
-std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all",
+std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage = "all",
                                       const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
                                       bool decode = false, const std::set<std::string> &extensions = {});

 /// \brief Function to create a Cifar10 Dataset
-/// \notes The generated dataset has two columns ['image', 'label']
+/// \notes The generated dataset has two columns ["image", "label"]
 /// \param[in] dataset_dir Path to the root directory that contains the dataset
+/// \param[in] usage Usage of CIFAR10, can be "train", "test" or "all"
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
 ///    a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
 /// \return Shared pointer to the current Dataset
-std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
-                                        const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
+std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage = std::string(),
+                                        const std::shared_ptr<SamplerObj> &sampler = RandomSampler());

 /// \brief Function to create a Cifar100 Dataset
-/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
+/// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"]
 /// \param[in] dataset_dir Path to the root directory that contains the dataset
+/// \param[in] usage Usage of CIFAR100, can be "train", "test" or "all"
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
 ///    a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
 /// \return Shared pointer to the current Dataset
-std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
-                                          const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
+std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage = std::string(),
+                                          const std::shared_ptr<SamplerObj> &sampler = RandomSampler());

 /// \brief Function to create a CLUEDataset
@@ -212,7 +214,7 @@ std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, c
 /// \brief Function to create an ImageFolderDataset
 /// \notes A source dataset that reads images from a tree of directories
 ///    All images within one folder have the same label
-///    The generated dataset has two columns ['image', 'label']
+///    The generated dataset has two columns ["image", "label"]
 /// \param[in] dataset_dir Path to the root directory that contains the dataset
 /// \param[in] decode A flag to decode in ImageFolder
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
@@ -227,7 +229,7 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir,
 #ifndef ENABLE_ANDROID
 /// \brief Function to create a ManifestDataset
-/// \notes The generated dataset has two columns ['image', 'label']
+/// \notes The generated dataset has two columns ["image", "label"]
 /// \param[in] dataset_file The dataset file to be read
 /// \param[in] usage Need "train", "eval" or "inference" data (default="train")
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
@@ -243,12 +245,13 @@ std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const
 #endif

 /// \brief Function to create a MnistDataset
-/// \notes The generated dataset has two columns ['image', 'label']
+/// \notes The generated dataset has two columns ["image", "label"]
 /// \param[in] dataset_dir Path to the root directory that contains the dataset
+/// \param[in] usage Usage of MNIST, can be "train", "test" or "all"
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
 ///    a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
 /// \return Shared pointer to the current MnistDataset
-std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir,
-                                    const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
+std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage = std::string(),
+                                    const std::shared_ptr<SamplerObj> &sampler = RandomSampler());

 /// \brief Function to create a ConcatDataset
@@ -404,14 +407,14 @@ std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &datase
 /// - task='Segmentation', column: [['image', dtype=uint8], ['target', dtype=uint8]].
 /// \param[in] dataset_dir Path to the root directory that contains the dataset
 /// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection"
-/// \param[in] mode Set the data list txt file to be read
+/// \param[in] usage The type of data list text file to be read
 /// \param[in] class_indexing A str-to-int mapping from label name to index, only valid in "Detection" task
 /// \param[in] decode Decode the images after reading
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
 ///    a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
 /// \return Shared pointer to the current Dataset
 std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
-                                const std::string &mode = "train",
+                                const std::string &usage = "train",
                                 const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
                                 const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
 #endif
@@ -702,9 +705,8 @@ class AlbumDataset : public Dataset {
 class CelebADataset : public Dataset {
  public:
   /// \brief Constructor
-  CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
-                const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
-                const std::set<std::string> &extensions);
+  CelebADataset(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
+                const bool &decode, const std::set<std::string> &extensions);

   /// \brief Destructor
   ~CelebADataset() = default;
@@ -719,7 +721,7 @@ class CelebADataset : public Dataset {
  private:
   std::string dataset_dir_;
-  std::string dataset_type_;
+  std::string usage_;
   bool decode_;
   std::set<std::string> extensions_;
   std::shared_ptr<SamplerObj> sampler_;
@@ -730,7 +732,7 @@ class CelebADataset : public Dataset {
 class Cifar10Dataset : public Dataset {
  public:
   /// \brief Constructor
-  Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
+  Cifar10Dataset(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler);

   /// \brief Destructor
   ~Cifar10Dataset() = default;
@@ -745,13 +747,14 @@ class Cifar10Dataset : public Dataset {
  private:
   std::string dataset_dir_;
+  std::string usage_;
   std::shared_ptr<SamplerObj> sampler_;
 };

 class Cifar100Dataset : public Dataset {
  public:
   /// \brief Constructor
-  Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
+  Cifar100Dataset(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler);

   /// \brief Destructor
   ~Cifar100Dataset() = default;
@@ -766,6 +769,7 @@ class Cifar100Dataset : public Dataset {
  private:
   std::string dataset_dir_;
+  std::string usage_;
   std::shared_ptr<SamplerObj> sampler_;
 };
@@ -831,7 +835,7 @@ class CocoDataset : public Dataset {
 enum CsvType : uint8_t { INT = 0, FLOAT, STRING };

 /// \brief Base class of CSV Record
-struct CsvBase {
+class CsvBase {
+ public:
   CsvBase() = default;
   explicit CsvBase(CsvType t) : type(t) {}
@@ -936,7 +940,7 @@ class ManifestDataset : public Dataset {
 class MnistDataset : public Dataset {
  public:
   /// \brief Constructor
-  MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler);
+  MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler);

   /// \brief Destructor
   ~MnistDataset() = default;
@@ -951,6 +955,7 @@ class MnistDataset : public Dataset {
  private:
   std::string dataset_dir_;
+  std::string usage_;
   std::shared_ptr<SamplerObj> sampler_;
 };
@@ -1087,7 +1092,7 @@ class TFRecordDataset : public Dataset {
 class VOCDataset : public Dataset {
  public:
   /// \brief Constructor
-  VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
+  VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage,
              const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler);

   /// \brief Destructor
@@ -1110,7 +1115,7 @@ class VOCDataset : public Dataset {
   const std::string kColumnTruncate = "truncate";
   std::string dataset_dir_;
   std::string task_;
-  std::string mode_;
+  std::string usage_;
   std::map<std::string, int32_t> class_index_;
   bool decode_;
   std::shared_ptr<SamplerObj> sampler_;
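The C++ entry points default `usage` to an empty std::string, which mirrors Python's `usage=None`; both mean "read everything". A sketch of the equivalent Python calls (path hypothetical; counts are the standard CIFAR-10 split sizes from the docstrings below):

import mindspore.dataset as ds

cifar_dir = "/path/to/cifar10"  # hypothetical path to the CIFAR-10 binary files

cifar_all = ds.Cifar10Dataset(cifar_dir)                   # default: all 60,000 samples
cifar_train = ds.Cifar10Dataset(cifar_dir, usage="train")  # 50,000 train samples
cifar_test = ds.Cifar10Dataset(cifar_dir, usage="test")    # 10,000 test samples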
mindspore/dataset/core/validator_helpers.py
@@ -132,6 +132,12 @@ def check_valid_detype(type_):
     return True


+def check_valid_str(value, valid_strings, arg_name=""):
+    type_check(value, (str,), arg_name)
+    if value not in valid_strings:
+        raise ValueError("Input {0} is not within the valid set of {1}.".format(arg_name, str(valid_strings)))
+
+
 def check_columns(columns, name):
     """
     Validate strings in column_names.
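A quick sketch of how the new validator behaves, using the value set the MNIST/CIFAR check passes to it (see validators.py below):

from mindspore.dataset.core.validator_helpers import check_valid_str

check_valid_str("train", ["train", "test", "all"], "usage")  # passes silently
try:
    check_valid_str("eval", ["train", "test", "all"], "usage")
except ValueError as e:
    print(e)  # Input usage is not within the valid set of ['train', 'test', 'all'].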
mindspore/dataset/engine/datasets.py
@@ -2877,6 +2877,9 @@ class MnistDataset(MappableDataset):
     Args:
         dataset_dir (str): Path to the root directory that contains the dataset.
+        usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from
+            60,000 train samples, "test" will read from 10,000 test samples, "all" will read from all 70,000
+            samples (default=None, all samples).
         num_samples (int, optional): The number of images to be included in the dataset
             (default=None, all images).
         num_parallel_workers (int, optional): Number of workers to read the data
@@ -2906,11 +2909,12 @@ class MnistDataset(MappableDataset):
     """

     @check_mnist_cifar_dataset
-    def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
+    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
                  shuffle=None, sampler=None, num_shards=None, shard_id=None):
         super().__init__(num_parallel_workers)
         self.dataset_dir = dataset_dir
+        self.usage = usage
         self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.shuffle_level = shuffle
@@ -2920,6 +2924,7 @@ class MnistDataset(MappableDataset):
     def get_args(self):
         args = super().get_args()
         args["dataset_dir"] = self.dataset_dir
+        args["usage"] = self.usage
         args["num_samples"] = self.num_samples
         args["shuffle"] = self.shuffle_level
         args["sampler"] = self.sampler
@@ -2935,7 +2940,7 @@ class MnistDataset(MappableDataset):
             Number, number of batches.
         """
         if self.dataset_size is None:
-            num_rows = MnistOp.get_num_rows(self.dataset_dir)
+            num_rows = MnistOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage)
            self.dataset_size = get_num_rows(num_rows, self.num_shards)
            rows_from_sampler = self._get_sampler_dataset_size()
            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
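Because usage is now forwarded to MnistOp.get_num_rows, the reported size tracks the chosen split. A sketch (path hypothetical; the size assumes the standard MNIST distribution):

import mindspore.dataset as ds

data = ds.MnistDataset("/path/to/mnist", usage="test")
print(data.get_dataset_size())  # 10000, not 70000: only the test split is counted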
@@ -3913,6 +3918,9 @@ class Cifar10Dataset(MappableDataset):
     Args:
         dataset_dir (str): Path to the root directory that contains the dataset.
+        usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from
+            50,000 train samples, "test" will read from 10,000 test samples, "all" will read from all 60,000
+            samples (default=None, all samples).
         num_samples (int, optional): The number of images to be included in the dataset.
             (default=None, all images).
         num_parallel_workers (int, optional): Number of workers to read the data
@@ -3946,11 +3954,12 @@ class Cifar10Dataset(MappableDataset):
     """

     @check_mnist_cifar_dataset
-    def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
+    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
                  shuffle=None, sampler=None, num_shards=None, shard_id=None):
         super().__init__(num_parallel_workers)
         self.dataset_dir = dataset_dir
+        self.usage = usage
         self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
@@ -3960,6 +3969,7 @@ class Cifar10Dataset(MappableDataset):
     def get_args(self):
         args = super().get_args()
         args["dataset_dir"] = self.dataset_dir
+        args["usage"] = self.usage
         args["num_samples"] = self.num_samples
         args["sampler"] = self.sampler
         args["num_shards"] = self.num_shards
@@ -3975,7 +3985,7 @@ class Cifar10Dataset(MappableDataset):
             Number, number of batches.
         """
         if self.dataset_size is None:
-            num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
+            num_rows = CifarOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage, True)
            self.dataset_size = get_num_rows(num_rows, self.num_shards)
            rows_from_sampler = self._get_sampler_dataset_size()
@@ -4051,6 +4061,9 @@ class Cifar100Dataset(MappableDataset):
     Args:
         dataset_dir (str): Path to the root directory that contains the dataset.
+        usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from
+            50,000 train samples, "test" will read from 10,000 test samples, "all" will read from all 60,000
+            samples (default=None, all samples).
         num_samples (int, optional): The number of images to be included in the dataset.
             (default=None, all images).
         num_parallel_workers (int, optional): Number of workers to read the data
@@ -4082,11 +4095,12 @@ class Cifar100Dataset(MappableDataset):
     """

     @check_mnist_cifar_dataset
-    def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
+    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
                  shuffle=None, sampler=None, num_shards=None, shard_id=None):
         super().__init__(num_parallel_workers)
         self.dataset_dir = dataset_dir
+        self.usage = usage
         self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
@@ -4096,6 +4110,7 @@ class Cifar100Dataset(MappableDataset):
     def get_args(self):
         args = super().get_args()
         args["dataset_dir"] = self.dataset_dir
+        args["usage"] = self.usage
         args["num_samples"] = self.num_samples
         args["sampler"] = self.sampler
         args["num_shards"] = self.num_shards
@@ -4111,7 +4126,7 @@ class Cifar100Dataset(MappableDataset):
             Number, number of batches.
         """
         if self.dataset_size is None:
-            num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
+            num_rows = CifarOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage, False)
            self.dataset_size = get_num_rows(num_rows, self.num_shards)
            rows_from_sampler = self._get_sampler_dataset_size()
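Cifar100Dataset gains the same flag; its rows carry the three documented columns. A minimal iteration sketch (path hypothetical):

import mindspore.dataset as ds

cifar100 = ds.Cifar100Dataset("/path/to/cifar100", usage="train", num_samples=2)
for row in cifar100.create_dict_iterator():
    print(row["image"].shape, row["coarse_label"], row["fine_label"])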
@@ -4467,7 +4482,7 @@ class VOCDataset(MappableDataset):
         dataset_dir (str): Path to the root directory that contains the dataset.
         task (str): Set the task type of reading voc data, now only support "Segmentation" or "Detection"
             (default="Segmentation").
-        mode (str): Set the data list txt file to be read (default="train").
+        usage (str): The type of data list text file to be read (default="train").
         class_indexing (dict, optional): A str-to-int mapping from label name to index, only valid in
             "Detection" task (default=None, the folder names will be sorted alphabetically and each
             class will be given a unique index starting from 0).
@@ -4502,24 +4517,24 @@ class VOCDataset(MappableDataset):
         >>> import mindspore.dataset as ds
         >>> dataset_dir = "/path/to/voc_dataset_directory"
         >>> # 1) read VOC data for segmentation train
-        >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", mode="train")
+        >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", usage="train")
         >>> # 2) read VOC data for detection train
-        >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train")
+        >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train")
         >>> # 3) read all VOC dataset samples in dataset_dir with 8 threads in random order:
-        >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", num_parallel_workers=8)
+        >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train", num_parallel_workers=8)
         >>> # 4) read then decode all VOC dataset samples in dataset_dir in sequence:
-        >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", decode=True, shuffle=False)
+        >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train", decode=True, shuffle=False)
         >>> # in VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target"
         >>> # in VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation"
     """

     @check_vocdataset
-    def __init__(self, dataset_dir, task="Segmentation", mode="train", class_indexing=None, num_samples=None,
+    def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None,
                  num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None,
                  shard_id=None):
         super().__init__(num_parallel_workers)
         self.dataset_dir = dataset_dir
         self.task = task
-        self.mode = mode
+        self.usage = usage
         self.class_indexing = class_indexing
         self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
@@ -4532,7 +4547,7 @@ class VOCDataset(MappableDataset):
         args = super().get_args()
         args["dataset_dir"] = self.dataset_dir
         args["task"] = self.task
-        args["mode"] = self.mode
+        args["usage"] = self.usage
         args["class_indexing"] = self.class_indexing
         args["num_samples"] = self.num_samples
         args["sampler"] = self.sampler
@@ -4560,7 +4575,7 @@ class VOCDataset(MappableDataset):
         else:
             class_indexing = self.class_indexing
-        num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
+        num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.usage, class_indexing, num_samples)
         self.dataset_size = get_num_rows(num_rows, self.num_shards)
         rows_from_sampler = self._get_sampler_dataset_size()
@@ -4584,7 +4599,7 @@ class VOCDataset(MappableDataset):
         else:
             class_indexing = self.class_indexing
-        return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing)
+        return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.usage, class_indexing)

     def is_shuffled(self):
         if self.shuffle_level is None:
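get_class_indexing now passes `usage` through as well; per the validator it is only meaningful for the Detection task. A sketch (path hypothetical; the returned mapping depends entirely on the annotations present):

import mindspore.dataset as ds

voc = ds.VOCDataset("/path/to/VOC2012", task="Detection", usage="train")
print(voc.get_class_indexing())  # e.g. {'car': 0, 'cat': 1, ...}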
@@ -4824,7 +4839,7 @@ class CelebADataset(MappableDataset):
         dataset_dir (str): Path to the root directory that contains the dataset.
         num_parallel_workers (int, optional): Number of workers to read the data (default=value set in the config).
         shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None).
-        dataset_type (str): one of 'all', 'train', 'valid' or 'test'.
+        usage (str): one of 'all', 'train', 'valid' or 'test'.
         sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
         decode (bool, optional): decode the images after reading (default=False).
         extensions (list[str], optional): List of file extensions to be
@@ -4838,8 +4853,8 @@ class CelebADataset(MappableDataset):
     """

     @check_celebadataset
-    def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, dataset_type='all', sampler=None,
-                 decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None):
+    def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None,
+                 decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None):
         super().__init__(num_parallel_workers)
         self.dataset_dir = dataset_dir
         self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
@@ -4847,7 +4862,7 @@ class CelebADataset(MappableDataset):
         self.decode = decode
         self.extensions = extensions
         self.num_samples = num_samples
-        self.dataset_type = dataset_type
+        self.usage = usage
         self.num_shards = num_shards
         self.shard_id = shard_id
         self.shuffle_level = shuffle
@@ -4860,7 +4875,7 @@ class CelebADataset(MappableDataset):
         args["decode"] = self.decode
         args["extensions"] = self.extensions
         args["num_samples"] = self.num_samples
-        args["dataset_type"] = self.dataset_type
+        args["usage"] = self.usage
         args["num_shards"] = self.num_shards
         args["shard_id"] = self.shard_id
         return args
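For CelebA the change is a pure rename of `dataset_type` to `usage`; the accepted values are unchanged. A sketch (path hypothetical):

import mindspore.dataset as ds

# 'valid' selects the CelebA validation partition; 'all' (the default) reads every sample.
celeba = ds.CelebADataset("/path/to/celeba", usage="valid", decode=True)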
mindspore/dataset/engine/serializer_deserializer.py
@@ -273,7 +273,7 @@ def create_node(node):
     elif dataset_op == 'MnistDataset':
         sampler = construct_sampler(node.get('sampler'))
-        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
+        pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
                         node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))

     elif dataset_op == 'MindDataset':
@@ -296,12 +296,12 @@ def create_node(node):
     elif dataset_op == 'Cifar10Dataset':
         sampler = construct_sampler(node.get('sampler'))
-        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
+        pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
                         node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))

     elif dataset_op == 'Cifar100Dataset':
         sampler = construct_sampler(node.get('sampler'))
-        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
+        pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
                         node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))

     elif dataset_op == 'VOCDataset':
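Because create_node indexes node['usage'] directly (not .get), a serialized pipeline must carry the key, which the get_args changes above guarantee. A round-trip sketch (path and file name hypothetical):

import mindspore.dataset as ds

data = ds.Cifar10Dataset("/path/to/cifar10", usage="test")
ds.serialize(data, "cifar10_pipeline.json")                    # 'usage' is written into the JSON
data2 = ds.deserialize(json_filepath="cifar10_pipeline.json")  # reconstructed with the same usage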
mindspore/dataset/engine/validators.py
@@ -27,7 +27,7 @@ from mindspore.dataset.callback import DSCallback
 from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
     INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
     validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
-    check_columns, check_pos_int32
+    check_columns, check_pos_int32, check_valid_str
 from . import datasets
 from . import samplers
@@ -74,6 +74,10 @@ def check_mnist_cifar_dataset(method):
         dataset_dir = param_dict.get('dataset_dir')
         check_dir(dataset_dir)

+        usage = param_dict.get('usage')
+        if usage is not None:
+            check_valid_str(usage, ["train", "test", "all"], "usage")
+
         validate_dataset_param_value(nreq_param_int, param_dict, int)
         validate_dataset_param_value(nreq_param_bool, param_dict, bool)
@@ -154,15 +158,15 @@ def check_vocdataset(method):
         task = param_dict.get('task')
         type_check(task, (str,), "task")

-        mode = param_dict.get('mode')
-        type_check(mode, (str,), "mode")
+        usage = param_dict.get('usage')
+        type_check(usage, (str,), "usage")

         if task == "Segmentation":
-            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", mode + ".txt")
+            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
             if param_dict.get('class_indexing') is not None:
                 raise ValueError("class_indexing is invalid in Segmentation task")
         elif task == "Detection":
-            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", mode + ".txt")
+            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
         else:
             raise ValueError("Invalid task : " + task)
@@ -235,9 +239,9 @@ def check_celebadataset(method):
         validate_dataset_param_value(nreq_param_list, param_dict, list)
         validate_dataset_param_value(nreq_param_str, param_dict, str)

-        dataset_type = param_dict.get('dataset_type')
-        if dataset_type is not None and dataset_type not in ('all', 'train', 'valid', 'test'):
-            raise ValueError("dataset_type should be one of 'all', 'train', 'valid' or 'test'.")
+        usage = param_dict.get('usage')
+        if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
+            raise ValueError("usage should be one of 'all', 'train', 'valid' or 'test'.")

         check_sampler_shuffle_shard_options(param_dict)
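End to end, the validators reject bad usage values at construction time. A sketch of the VOC failure mode (path hypothetical; the validator type-checks usage as a string before touching the filesystem):

import mindspore.dataset as ds

try:
    ds.VOCDataset("/path/to/VOC2012", task="Detection", usage=3)  # usage must be a str
except TypeError as e:
    print(e)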
tests/ut/cpp/dataset/CMakeLists.txt
@@ -5,7 +5,7 @@ SET(DE_UT_SRCS
    common/cvop_common.cc
    common/bboxop_common.cc
    auto_contrast_op_test.cc
    album_op_test.cc
    batch_op_test.cc
    bit_functions_test.cc
    storage_container_test.cc
@@ -62,8 +62,8 @@ SET(DE_UT_SRCS
    rescale_op_test.cc
    resize_op_test.cc
    resize_with_bbox_op_test.cc
    rgba_to_bgr_op_test.cc
    rgba_to_rgb_op_test.cc
    schema_test.cc
    skip_op_test.cc
    shuffle_op_test.cc
tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc
@@ -28,7 +28,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
  // Create a Cifar10 Dataset
  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
- std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
@@ -45,10 +45,10 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image->shape();
    iter->GetNextRow(&row);
  }

  EXPECT_EQ(i, 10);
@@ -62,7 +62,7 @@ TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
  // Create a Cifar100 Dataset
  std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
- std::shared_ptr<Dataset> ds = Cifar100(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
@@ -96,7 +96,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetFail1) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100DatasetFail1.";

  // Create a Cifar100 Dataset
- std::shared_ptr<Dataset> ds = Cifar100("", RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Cifar100("", std::string(), RandomSampler(false, 10));
  EXPECT_EQ(ds, nullptr);
 }
@@ -104,7 +104,7 @@ TEST_F(MindDataTestPipeline, TestCifar10DatasetFail1) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10DatasetFail1.";

  // Create a Cifar10 Dataset
- std::shared_ptr<Dataset> ds = Cifar10("", RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Cifar10("", std::string(), RandomSampler(false, 10));
  EXPECT_EQ(ds, nullptr);
 }
@@ -113,7 +113,7 @@ TEST_F(MindDataTestPipeline, TestCifar10DatasetWithNullSampler) {
  // Create a Cifar10 Dataset
  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
- std::shared_ptr<Dataset> ds = Cifar10(folder_path, nullptr);
+ std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), nullptr);
  // Expect failure: sampler can not be nullptr
  EXPECT_EQ(ds, nullptr);
 }
@@ -123,7 +123,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetWithNullSampler) {
  // Create a Cifar100 Dataset
  std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
- std::shared_ptr<Dataset> ds = Cifar100(folder_path, nullptr);
+ std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), nullptr);
  // Expect failure: sampler can not be nullptr
  EXPECT_EQ(ds, nullptr);
 }
@@ -133,7 +133,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetWithWrongSampler) {
  // Create a Cifar100 Dataset
  std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
- std::shared_ptr<Dataset> ds = Cifar100(folder_path, RandomSampler(false, -10));
+ std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), RandomSampler(false, -10));
  // Expect failure: sampler is not constructed correctly
  EXPECT_EQ(ds, nullptr);
 }
tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc
@@ -28,7 +28,7 @@ TEST_F(MindDataTestPipeline, TestIteratorEmptyColumn) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorEmptyColumn.";

  // Create a Cifar10 Dataset
  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
- std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 5));
+ std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 5));
  EXPECT_NE(ds, nullptr);

  // Create a Rename operation on ds
@@ -64,7 +64,7 @@ TEST_F(MindDataTestPipeline, TestIteratorOneColumn) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn.";

  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 4));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 4));
  EXPECT_NE(ds, nullptr);

  // Create a Batch operation on ds
@@ -103,7 +103,7 @@ TEST_F(MindDataTestPipeline, TestIteratorReOrder) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorReOrder.";

  // Create a Cifar10 Dataset
  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
- std::shared_ptr<Dataset> ds = Cifar10(folder_path, SequentialSampler(false, 4));
+ std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), SequentialSampler(false, 4));
  EXPECT_NE(ds, nullptr);

  // Create a Take operation on ds
@@ -160,9 +160,8 @@ TEST_F(MindDataTestPipeline, TestIteratorTwoColumns) {
  // Iterate the dataset and get each row
  std::vector<std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);
  std::vector<TensorShape> expect = {TensorShape({173673}), TensorShape({1, 4}),
                                     TensorShape({173673}), TensorShape({1, 4}),
                                     TensorShape({147025}), TensorShape({1, 4}),
                                     TensorShape({211653}), TensorShape({1, 4})};

  uint64_t i = 0;
@@ -187,7 +186,7 @@ TEST_F(MindDataTestPipeline, TestIteratorWrongColumn) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorWrongColumn.";

  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 4));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 4));
  EXPECT_NE(ds, nullptr);

  // Pass wrong column name
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
@@ -40,7 +40,7 @@ TEST_F(MindDataTestPipeline, TestBatchAndRepeat) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a Repeat operation on ds
@@ -82,7 +82,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess1) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a BucketBatchByLength operation on ds
@@ -118,13 +118,12 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a BucketBatchByLength operation on ds
  std::map<std::string, std::pair<mindspore::dataset::TensorShape, std::shared_ptr<Tensor>>> pad_info;
  ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, 2, 3}, &BucketBatchTestFunction, pad_info, true, true);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
@@ -157,7 +156,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail1) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a BucketBatchByLength operation on ds
@@ -172,7 +171,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail2) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a BucketBatchByLength operation on ds
@@ -187,7 +186,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail3) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a BucketBatchByLength operation on ds
@@ -202,7 +201,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail4) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a BucketBatchByLength operation on ds
@@ -217,7 +216,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail5) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a BucketBatchByLength operation on ds
@@ -232,7 +231,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail6) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a BucketBatchByLength operation on ds
  ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, -2, 3});
@@ -246,7 +245,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail7) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a BucketBatchByLength operation on ds
@@ -313,7 +312,7 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess) {
  // Create a Cifar10 Dataset
  // Column names: {"image", "label"}
  folder_path = datasets_root_path_ + "/testCifar10Data/";
- std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 9));
+ std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 9));
  EXPECT_NE(ds2, nullptr);

  // Create a Project operation on ds
@@ -365,7 +364,7 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess2) {
  // Create a Cifar10 Dataset
  // Column names: {"image", "label"}
  folder_path = datasets_root_path_ + "/testCifar10Data/";
- std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 9));
+ std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 9));
  EXPECT_NE(ds2, nullptr);

  // Create a Project operation on ds
@@ -704,11 +703,11 @@ TEST_F(MindDataTestPipeline, TestRenameSuccess) {
 }

 TEST_F(MindDataTestPipeline, TestRepeatDefault) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatDefault.";

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a Repeat operation on ds
@@ -723,21 +722,21 @@ TEST_F(MindDataTestPipeline, TestRepeatDefault) {
  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // iterate over the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  uint64_t i = 0;
  while (row.size() != 0) {
    // manually stop
    if (i == 100) {
      break;
    }
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image->shape();
    iter->GetNextRow(&row);
  }
@@ -747,11 +746,11 @@ TEST_F(MindDataTestPipeline, TestRepeatDefault) {
 }

 TEST_F(MindDataTestPipeline, TestRepeatOne) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatOne.";

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a Repeat operation on ds
@@ -766,17 +765,17 @@ TEST_F(MindDataTestPipeline, TestRepeatOne) {
  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // iterate over the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image->shape();
    iter->GetNextRow(&row);
  }
@@ -1013,7 +1012,7 @@ TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 20));
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 20));
  EXPECT_NE(ds, nullptr);

  // Create a Repeat operation on ds
@@ -1060,7 +1059,6 @@ TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) {
  iter->Stop();
 }

 TEST_F(MindDataTestPipeline, TestZipFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipFail.";
  // We expect this test to fail because both datasets being zipped have "image" and "label" columns
@@ -1128,7 +1126,7 @@ TEST_F(MindDataTestPipeline, TestZipSuccess) {
  EXPECT_NE(ds1, nullptr);

  folder_path = datasets_root_path_ + "/testCifar10Data/";
- std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds2, nullptr);

  // Create a Project operation on ds
tests/ut/cpp/dataset/c_api_datasets_test.cc
@@ -43,10 +43,11 @@ TEST_F(MindDataTestPipeline, TestCelebADataset) {
  // Check if CelebAOp read correct images/attr
  std::string expect_file[] = {"1.JPEG", "2.jpg"};
  std::vector<std::vector<uint32_t>> expect_attr_vector = {
    {0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
     0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1},
    {0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
     0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1}};

  uint64_t i = 0;
  while (row.size() != 0) {
    auto image = row["image"];
@@ -132,7 +133,7 @@ TEST_F(MindDataTestPipeline, TestMnistFailWithWrongDatasetDir) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistFailWithWrongDatasetDir.";

  // Create a Mnist Dataset
- std::shared_ptr<Dataset> ds = Mnist("", RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Mnist("", std::string(), RandomSampler(false, 10));
  EXPECT_EQ(ds, nullptr);
 }
@@ -141,7 +142,7 @@ TEST_F(MindDataTestPipeline, TestMnistFailWithNullSampler) {
  // Create a Mnist Dataset
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
- std::shared_ptr<Dataset> ds = Mnist(folder_path, nullptr);
+ std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), nullptr);
  // Expect failure: sampler can not be nullptr
  EXPECT_EQ(ds, nullptr);
 }
tests/ut/cpp/dataset/c_api_transforms_test.cc
@@ -30,7 +30,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
  // Create a Cifar10 Dataset
  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
  int number_of_classes = 10;
- std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create objects for the tensor ops
@@ -38,7 +38,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
  EXPECT_NE(hwc_to_chw, nullptr);

  // Create a Map operation on ds
- ds = ds->Map({hwc_to_chw},{"image"});
+ ds = ds->Map({hwc_to_chw}, {"image"});
  EXPECT_NE(ds, nullptr);

  // Create a Batch operation on ds
@@ -51,10 +51,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
  EXPECT_NE(one_hot_op, nullptr);

  // Create a Map operation on ds
- ds = ds->Map({one_hot_op},{"label"});
+ ds = ds->Map({one_hot_op}, {"label"});
  EXPECT_NE(ds, nullptr);

- std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0);
+ std::shared_ptr<TensorOperation> cutmix_batch_op =
+   vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0);
  EXPECT_NE(cutmix_batch_op, nullptr);

  // Create a Map operation on ds
@@ -77,10 +78,12 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
  auto label = row["label"];
  MS_LOG(INFO) << "Tensor image shape: " << image->shape();
  MS_LOG(INFO) << "Label shape: " << label->shape();
  EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 3 == image->shape()[1] &&
              32 == image->shape()[2] && 32 == image->shape()[3],
            true);
  EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] &&
              number_of_classes == label->shape()[1],
            true);
  iter->GetNextRow(&row);
 }
@@ -95,7 +98,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) {
  // Create a Cifar10 Dataset
  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
  int number_of_classes = 10;
- std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+ std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
  EXPECT_NE(ds, nullptr);

  // Create a Batch operation on ds
@@ -108,7 +111,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) {
  EXPECT_NE(one_hot_op, nullptr);

  // Create a Map operation on ds
- ds = ds->Map({one_hot_op},{"label"});
+ ds = ds->Map({one_hot_op}, {"label"});
  EXPECT_NE(ds, nullptr);

  std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC);
@@ -134,10 +137,12 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) {
  auto label = row["label"];
  MS_LOG(INFO) << "Tensor image shape: " << image->shape();
  MS_LOG(INFO) << "Label shape: " << label->shape();
  EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 32 == image->shape()[1] &&
              32 == image->shape()[2] && 3 == image->shape()[3],
            true);
  EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] &&
              number_of_classes == label->shape()[1],
            true);
  iter->GetNextRow(&row);
 }
@@ -151,7 +156,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) {
// Must fail because alpha can't be negative
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
RandomSampler
(
false
,
10
));
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
std
::
string
(),
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
...
...
@@ -164,10 +169,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) {
EXPECT_NE
(
one_hot_op
,
nullptr
);
// Create a Map operation on ds
ds
=
ds
->
Map
({
one_hot_op
},{
"label"
});
ds
=
ds
->
Map
({
one_hot_op
},
{
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
std
::
shared_ptr
<
TensorOperation
>
cutmix_batch_op
=
vision
::
CutMixBatch
(
mindspore
::
dataset
::
ImageBatchFormat
::
kNHWC
,
-
1
,
0.5
);
std
::
shared_ptr
<
TensorOperation
>
cutmix_batch_op
=
vision
::
CutMixBatch
(
mindspore
::
dataset
::
ImageBatchFormat
::
kNHWC
,
-
1
,
0.5
);
EXPECT_EQ
(
cutmix_batch_op
,
nullptr
);
}
...
...
@@ -175,7 +181,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
// Must fail because prob can't be negative
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
RandomSampler
(
false
,
10
));
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
std
::
string
(),
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
...
...
@@ -188,20 +194,19 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
   EXPECT_NE(one_hot_op, nullptr);
   // Create a Map operation on ds
-  ds = ds->Map({one_hot_op},{"label"});
+  ds = ds->Map({one_hot_op}, {"label"});
   EXPECT_NE(ds, nullptr);
-  std::shared_ptr<TensorOperation> cutmix_batch_op =
-      vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5);
+  std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5);
   EXPECT_EQ(cutmix_batch_op, nullptr);
 }

 TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) {
   // Must fail because alpha can't be zero
   // Create a Cifar10 Dataset
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
-  std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
   EXPECT_NE(ds, nullptr);
   // Create a Batch operation on ds
...
...
@@ -214,11 +219,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) {
   EXPECT_NE(one_hot_op, nullptr);
   // Create a Map operation on ds
-  ds = ds->Map({one_hot_op},{"label"});
+  ds = ds->Map({one_hot_op}, {"label"});
   EXPECT_NE(ds, nullptr);
-  std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(
-    mindspore::dataset::ImageBatchFormat::kNHWC, 0.0, 0.5);
+  std::shared_ptr<TensorOperation> cutmix_batch_op =
+    vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 0.0, 0.5);
   EXPECT_EQ(cutmix_batch_op, nullptr);
 }
...
...
@@ -371,7 +376,7 @@ TEST_F(MindDataTestPipeline, TestHwcToChw) {
 TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) {
   // Create a Cifar10 Dataset
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
-  std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
   EXPECT_NE(ds, nullptr);
   // Create a Batch operation on ds
...
...
@@ -395,7 +400,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) {
   // This should fail because alpha can't be zero
   // Create a Cifar10 Dataset
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
-  std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
   EXPECT_NE(ds, nullptr);
   // Create a Batch operation on ds
...
...
@@ -418,7 +423,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) {
 TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
   // Create a Cifar10 Dataset
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
-  std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
   EXPECT_NE(ds, nullptr);
   // Create a Batch operation on ds
...
...
@@ -467,7 +472,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
 TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) {
   // Create a Cifar10 Dataset
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
-  std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
   EXPECT_NE(ds, nullptr);
   // Create a Batch operation on ds
...
...
@@ -871,8 +876,7 @@ TEST_F(MindDataTestPipeline, TestRandomPosterizeSuccess1) {
   EXPECT_NE(ds, nullptr);

   // Create objects for the tensor ops
-  std::shared_ptr<TensorOperation> posterize =
-      vision::RandomPosterize({1, 4});
+  std::shared_ptr<TensorOperation> posterize = vision::RandomPosterize({1, 4});
   EXPECT_NE(posterize, nullptr);
   // Create a Map operation on ds
...
...
@@ -1114,7 +1118,7 @@ TEST_F(MindDataTestPipeline, TestRandomRotation) {
 TEST_F(MindDataTestPipeline, TestUniformAugWithOps) {
   // Create a Mnist Dataset
   std::string folder_path = datasets_root_path_ + "/testMnistData/";
-  std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 20));
+  std::shared_ptr<Dataset> ds = Mnist(folder_path, "", RandomSampler(false, 20));
   EXPECT_NE(ds, nullptr);
   // Create a Repeat operation on ds
...
...
tests/ut/cpp/dataset/celeba_op_test.cc
View file @ ea947568
...
...
@@ -42,9 +42,13 @@ std::shared_ptr<CelebAOp> Celeba(int32_t num_workers, int32_t rows_per_buffer, i
                                  bool decode = false, const std::string &dataset_type = "all") {
   std::shared_ptr<CelebAOp> so;
   CelebAOp::Builder builder;
-  Status rc = builder.SetNumWorkers(num_workers).SetCelebADir(dir).SetRowsPerBuffer(rows_per_buffer)
-                .SetOpConnectorSize(queue_size).SetSampler(std::move(sampler)).SetDecode(decode)
-                .SetDatasetType(dataset_type).Build(&so);
+  Status rc = builder.SetNumWorkers(num_workers)
+                .SetCelebADir(dir)
+                .SetRowsPerBuffer(rows_per_buffer)
+                .SetOpConnectorSize(queue_size)
+                .SetSampler(std::move(sampler))
+                .SetDecode(decode)
+                .SetUsage(dataset_type).Build(&so);
   return so;
 }
...
...
tests/ut/cpp/dataset/voc_op_test.cc
View file @ ea947568
...
...
@@ -63,9 +63,7 @@ TEST_F(MindDataTestVOCOp, TestVOCDetection) {
   std::string task_mode("train");
   std::shared_ptr<VOCOp> my_voc_op;
   VOCOp::Builder builder;
-  Status rc = builder.SetDir(dataset_path)
-                .SetTask(task_type)
-                .SetMode(task_mode)
+  Status rc = builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode)
                 .Build(&my_voc_op);
   ASSERT_TRUE(rc.IsOk());
...
...
@@ -116,9 +114,7 @@ TEST_F(MindDataTestVOCOp, TestVOCSegmentation) {
   std::string task_mode("train");
   std::shared_ptr<VOCOp> my_voc_op;
   VOCOp::Builder builder;
-  Status rc = builder.SetDir(dataset_path)
-                .SetTask(task_type)
-                .SetMode(task_mode)
+  Status rc = builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode)
                 .Build(&my_voc_op);
   ASSERT_TRUE(rc.IsOk());
...
...
@@ -173,9 +169,8 @@ TEST_F(MindDataTestVOCOp, TestVOCClassIndex) {
   class_index["train"] = 5;
   std::shared_ptr<VOCOp> my_voc_op;
   VOCOp::Builder builder;
-  Status rc = builder.SetDir(dataset_path)
-                .SetTask(task_type)
-                .SetMode(task_mode)
+  Status rc = builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode)
                 .SetClassIndex(class_index)
                 .Build(&my_voc_op);
   ASSERT_TRUE(rc.IsOk());
...
...
tests/ut/python/dataset/test_bounding_box_augment.py
View file @ ea947568
...
...
@@ -42,8 +42,8 @@ def test_bounding_box_augment_with_rotation_op(plot_vis=False):
     original_seed = config_get_set_seed(0)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)

-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     # Ratio is set to 1 to apply rotation on all bounding boxes.
     test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1)
...
...
@@ -81,8 +81,8 @@ def test_bounding_box_augment_with_crop_op(plot_vis=False):
     original_seed = config_get_set_seed(0)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)

-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     # Ratio is set to 0.9 to apply RandomCrop of size (50, 50) on 90% of the bounding boxes.
     test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(50), 0.9)
...
...
@@ -120,8 +120,8 @@ def test_bounding_box_augment_valid_ratio_c(plot_vis=False):
     original_seed = config_get_set_seed(1)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)

-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9)
...
...
@@ -188,8 +188,8 @@ def test_bounding_box_augment_valid_edge_c(plot_vis=False):
     original_seed = config_get_set_seed(1)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)

-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1)
...
...
@@ -232,7 +232,7 @@ def test_bounding_box_augment_invalid_ratio_c():
     """
     logger.info("test_bounding_box_augment_invalid_ratio_c")

-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     try:
         # ratio range is from 0 - 1
...
...
@@ -256,13 +256,13 @@ def test_bounding_box_augment_invalid_bounds_c():
     test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1)

-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features")
...
...
tests/ut/python/dataset/test_datasets_celeba.py
View file @ ea947568
...
...
@@ -20,7 +20,7 @@ DATA_DIR = "../data/dataset/testCelebAData/"

 def test_celeba_dataset_label():
-    data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
+    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
     expect_labels = [[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1],
...
...
@@ -85,11 +85,13 @@ def test_celeba_dataset_distribute():
         count = count + 1
     assert count == 1


 def test_celeba_get_dataset_size():
-    data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
+    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
     size = data.get_dataset_size()
     assert size == 2


 if __name__ == '__main__':
     test_celeba_dataset_label()
     test_celeba_dataset_op()
...
...
tests/ut/python/dataset/test_datasets_cifarop.py
View file @ ea947568
...
...
@@ -392,6 +392,59 @@ def test_cifar100_visualize(plot=False):
         visualize_dataset(image_list, label_list)


+def test_cifar_usage():
+    """
+    test usage of cifar
+    """
+    logger.info("Test Cifar10Dataset and Cifar100Dataset usage flag")
+
+    # flag, if True, test cifar10 else test cifar100
+    def test_config(usage, flag=True, cifar_path=None):
+        if cifar_path is None:
+            cifar_path = DATA_DIR_10 if flag else DATA_DIR_100
+        try:
+            data = ds.Cifar10Dataset(cifar_path, usage=usage) if flag else ds.Cifar100Dataset(cifar_path, usage=usage)
+            num_rows = 0
+            for _ in data.create_dict_iterator():
+                num_rows += 1
+        except (ValueError, TypeError, RuntimeError) as e:
+            return str(e)
+        return num_rows
+
+    # test the usage of CIFAR10
+    assert test_config("train") == 10000
+    assert test_config("all") == 10000
+    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
+    assert "Argument usage with value ['list'] is not of type (<class 'str'>,)" in test_config(["list"])
+    assert "no valid data matching the dataset API Cifar10Dataset" in test_config("test")
+
+    # test the usage of CIFAR100
+    assert test_config("test", False) == 10000
+    assert test_config("all", False) == 10000
+    assert "no valid data matching the dataset API Cifar100Dataset" in test_config("train", False)
+    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid", False)
+
+    # change this directory to the folder that contains all cifar10 files
+    all_cifar10 = None
+    if all_cifar10 is not None:
+        assert test_config("train", True, all_cifar10) == 50000
+        assert test_config("test", True, all_cifar10) == 10000
+        assert test_config("all", True, all_cifar10) == 60000
+        assert ds.Cifar10Dataset(all_cifar10, usage="train").get_dataset_size() == 50000
+        assert ds.Cifar10Dataset(all_cifar10, usage="test").get_dataset_size() == 10000
+        assert ds.Cifar10Dataset(all_cifar10, usage="all").get_dataset_size() == 60000
+
+    # change this directory to the folder that contains all cifar100 files
+    all_cifar100 = None
+    if all_cifar100 is not None:
+        assert test_config("train", False, all_cifar100) == 50000
+        assert test_config("test", False, all_cifar100) == 10000
+        assert test_config("all", False, all_cifar100) == 60000
+        assert ds.Cifar100Dataset(all_cifar100, usage="train").get_dataset_size() == 50000
+        assert ds.Cifar100Dataset(all_cifar100, usage="test").get_dataset_size() == 10000
+        assert ds.Cifar100Dataset(all_cifar100, usage="all").get_dataset_size() == 60000
+
+
 if __name__ == '__main__':
     test_cifar10_content_check()
     test_cifar10_basic()
...
...
@@ -405,3 +458,5 @@ if __name__ == '__main__':
     test_cifar100_pk_sampler()
     test_cifar100_exception()
     test_cifar100_visualize(plot=False)
+    test_cifar_usage()
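The test above pins down the accepted values of the new flag ("train", "test", "all") and the row counts each should yield. As a rough usage sketch of the resulting Python API (the directory path below is a hypothetical placeholder, not one of the repository's test paths):

import mindspore.dataset as ds

# Pick a CIFAR-10 split via the new usage flag; per the asserts above,
# only "train", "test" and "all" pass validation.
cifar10_train = ds.Cifar10Dataset("/path/to/cifar10", usage="train")
cifar10_all = ds.Cifar10Dataset("/path/to/cifar10", usage="all")
print(cifar10_train.get_dataset_size(), cifar10_all.get_dataset_size())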
tests/ut/python/dataset/test_datasets_get_dataset_size.py
View file @ ea947568
...
...
@@ -58,6 +58,14 @@ def test_mnist_dataset_size():
     ds_total = ds.MnistDataset(MNIST_DATA_DIR)
     assert ds_total.get_dataset_size() == 10000

+    # test get dataset_size with the usage arg
+    test_size = ds.MnistDataset(MNIST_DATA_DIR, usage="test").get_dataset_size()
+    assert test_size == 10000
+    train_size = ds.MnistDataset(MNIST_DATA_DIR, usage="train").get_dataset_size()
+    assert train_size == 0
+    all_size = ds.MnistDataset(MNIST_DATA_DIR, usage="all").get_dataset_size()
+    assert all_size == 10000
+
     ds_shard_1_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 10000
...
...
@@ -86,6 +94,14 @@ def test_cifar10_dataset_size():
     ds_total = ds.Cifar10Dataset(CIFAR10_DATA_DIR)
     assert ds_total.get_dataset_size() == 10000

+    # test get_dataset_size with usage flag
+    train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
+    assert train_size == 10000
+    test_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="test").get_dataset_size()
+    assert test_size == 0
+    all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
+    assert all_size == 10000
+
     ds_shard_1_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 10000
...
...
@@ -103,6 +119,14 @@ def test_cifar100_dataset_size():
     ds_total = ds.Cifar100Dataset(CIFAR100_DATA_DIR)
     assert ds_total.get_dataset_size() == 10000

+    # test get_dataset_size with usage flag
+    train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
+    assert train_size == 0
+    test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
+    assert test_size == 10000
+    all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()
+    assert all_size == 10000
+
     ds_shard_1_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 10000
...
...
@@ -111,3 +135,12 @@ def test_cifar100_dataset_size():
     ds_shard_3_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=3, shard_id=0)
     assert ds_shard_3_0.get_dataset_size() == 3334
+
+
+if __name__ == '__main__':
+    test_imagenet_rawdata_dataset_size()
+    test_imagenet_tf_file_dataset_size()
+    test_mnist_dataset_size()
+    test_manifest_dataset_size()
+    test_cifar10_dataset_size()
+    test_cifar100_dataset_size()
tests/ut/python/dataset/test_datasets_mnist.py
View file @ ea947568
...
...
@@ -229,6 +229,41 @@ def test_mnist_visualize(plot=False):
         visualize_dataset(image_list, label_list)


+def test_mnist_usage():
+    """
+    Validate MnistDataset usage flag
+    """
+    logger.info("Test MnistDataset usage flag")
+
+    def test_config(usage, mnist_path=None):
+        mnist_path = DATA_DIR if mnist_path is None else mnist_path
+        try:
+            data = ds.MnistDataset(mnist_path, usage=usage, shuffle=False)
+            num_rows = 0
+            for _ in data.create_dict_iterator():
+                num_rows += 1
+        except (ValueError, TypeError, RuntimeError) as e:
+            return str(e)
+        return num_rows
+
+    assert test_config("test") == 10000
+    assert test_config("all") == 10000
+    assert " no valid data matching the dataset API MnistDataset" in test_config("train")
+    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
+    assert "Argument usage with value ['list'] is not of type (<class 'str'>,)" in test_config(["list"])
+
+    # change this directory to the folder that contains all mnist files
+    all_files_path = None
+    # the following tests run on the entire datasets
+    if all_files_path is not None:
+        assert test_config("train", all_files_path) == 60000
+        assert test_config("test", all_files_path) == 10000
+        assert test_config("all", all_files_path) == 70000
+        assert ds.MnistDataset(all_files_path, usage="train").get_dataset_size() == 60000
+        assert ds.MnistDataset(all_files_path, usage="test").get_dataset_size() == 10000
+        assert ds.MnistDataset(all_files_path, usage="all").get_dataset_size() == 70000
+
+
 if __name__ == '__main__':
     test_mnist_content_check()
     test_mnist_basic()
...
...
@@ -236,3 +271,4 @@ if __name__ == '__main__':
     test_mnist_sequential_sampler()
     test_mnist_exception()
     test_mnist_visualize(plot=True)
+    test_mnist_usage()
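The asserted error strings above also document the failure modes: an out-of-set value is rejected by the validator, while a split with no matching files fails when the data is read. A minimal sketch of the first case (the MNIST path is a hypothetical placeholder; the exception type follows the ValueError branch that test_config catches):

import mindspore.dataset as ds

try:
    ds.MnistDataset("/path/to/mnist", usage="validation")  # not in ['train', 'test', 'all']
except ValueError as err:
    print(err)  # "usage is not within the valid set of ['train', 'test', 'all']"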
tests/ut/python/dataset/test_datasets_voc.py
View file @ ea947568
...
...
@@ -21,7 +21,7 @@ TARGET_SHAPE = [680, 680, 680, 680, 642, 607, 561, 596, 612, 680]

 def test_voc_segmentation():
-    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
     num = 0
     for item in data1.create_dict_iterator(num_epochs=1):
         assert item["image"].shape[0] == IMAGE_SHAPE[num]
...
...
@@ -31,7 +31,7 @@ def test_voc_segmentation():

 def test_voc_detection():
-    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     num = 0
     count = [0, 0, 0, 0, 0, 0]
     for item in data1.create_dict_iterator(num_epochs=1):
...
...
@@ -45,7 +45,7 @@ def test_voc_detection():

 def test_voc_class_index():
     class_index = {'car': 0, 'cat': 1, 'train': 5}
-    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", class_indexing=class_index, decode=True)
+    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True)
     class_index1 = data1.get_class_indexing()
     assert (class_index1 == {'car': 0, 'cat': 1, 'train': 5})
     data1 = data1.shuffle(4)
...
...
@@ -63,7 +63,7 @@ def test_voc_class_index():

 def test_voc_get_class_indexing():
-    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)
+    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True)
     class_index1 = data1.get_class_indexing()
     assert (class_index1 == {'car': 0, 'cat': 1, 'chair': 2, 'dog': 3, 'person': 4, 'train': 5})
     data1 = data1.shuffle(4)
...
...
@@ -81,7 +81,7 @@ def test_voc_get_class_indexing():

 def test_case_0():
-    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True)

     resize_op = vision.Resize((224, 224))
...
...
@@ -99,7 +99,7 @@ def test_case_0():

 def test_case_1():
-    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)
+    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True)

     resize_op = vision.Resize((224, 224))
...
...
@@ -116,7 +116,7 @@ def test_case_1():

 def test_case_2():
-    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True)
     sizes = [0.5, 0.5]
     randomize = False
     dataset1, dataset2 = data1.split(sizes=sizes, randomize=randomize)
...
...
@@ -134,7 +134,7 @@ def test_case_2():

 def test_voc_exception():
     try:
-        data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True)
+        data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", usage="train", decode=True)
         for _ in data1.create_dict_iterator(num_epochs=1):
             pass
         assert False
...
...
@@ -142,7 +142,7 @@ def test_voc_exception():
         pass

     try:
-        data2 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", class_indexing={"cat": 0}, decode=True)
+        data2 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", class_indexing={"cat": 0}, decode=True)
         for _ in data2.create_dict_iterator(num_epochs=1):
             pass
         assert False
...
...
@@ -150,7 +150,7 @@ def test_voc_exception():
         pass

     try:
-        data3 = ds.VOCDataset(DATA_DIR, task="Detection", mode="notexist", decode=True)
+        data3 = ds.VOCDataset(DATA_DIR, task="Detection", usage="notexist", decode=True)
         for _ in data3.create_dict_iterator(num_epochs=1):
             pass
         assert False
...
...
@@ -158,7 +158,7 @@ def test_voc_exception():
         pass

     try:
-        data4 = ds.VOCDataset(DATA_DIR, task="Detection", mode="xmlnotexist", decode=True)
+        data4 = ds.VOCDataset(DATA_DIR, task="Detection", usage="xmlnotexist", decode=True)
         for _ in data4.create_dict_iterator(num_epochs=1):
             pass
         assert False
...
...
@@ -166,7 +166,7 @@ def test_voc_exception():
         pass

     try:
-        data5 = ds.VOCDataset(DATA_DIR, task="Detection", mode="invalidxml", decode=True)
+        data5 = ds.VOCDataset(DATA_DIR, task="Detection", usage="invalidxml", decode=True)
         for _ in data5.create_dict_iterator(num_epochs=1):
             pass
         assert False
...
...
@@ -174,7 +174,7 @@ def test_voc_exception():
         pass

     try:
-        data6 = ds.VOCDataset(DATA_DIR, task="Detection", mode="xmlnoobject", decode=True)
+        data6 = ds.VOCDataset(DATA_DIR, task="Detection", usage="xmlnoobject", decode=True)
         for _ in data6.create_dict_iterator(num_epochs=1):
             pass
         assert False
...
...
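Across this file the change is mechanical: the VOCDataset keyword mode becomes usage, while task, class_indexing, decode and shuffle are untouched. A before/after sketch using the same DATA_DIR as the tests:

# before this commit
data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
# after this commit
data = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True, shuffle=False)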
tests/ut/python/dataset/test_epoch_ctrl.py
View file @ ea947568
...
...
@@ -35,6 +35,7 @@ def diff_mse(in1, in2):
     mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
     return mse * 100

+
 def test_cifar10():
     """
     dataset parameter
...
...
@@ -45,7 +46,7 @@ def test_cifar10():
     batch_size = 32
     limit_dataset = 100
     # apply dataset operations
-    data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)
+    data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
     data1 = data1.repeat(num_repeat)
     data1 = data1.batch(batch_size, True)
     num_epoch = 5
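This hunk is a knock-on effect rather than a cleanup: once usage occupies the second positional slot of Cifar10Dataset, a bare second argument would no longer be read as a sample count, so the call switches to an explicit keyword. A sketch of the pitfall (assuming usage now precedes num_samples in the signature, as the rest of this diff suggests):

data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)              # old: second positional arg was num_samples
data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)  # new: pass the count by keyword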
...
...
@@ -139,6 +140,7 @@ def test_generator_dict_0():
         np.testing.assert_array_equal(item["data"], golden)
         i = i + 1

+
 def test_generator_dict_1():
     """
     test generator dict 1
...
...
@@ -158,6 +160,7 @@ def test_generator_dict_1():
         i = i + 1
     assert i == 64

+
 def test_generator_dict_2():
     """
     test generator dict 2
...
...
@@ -180,6 +183,7 @@ def test_generator_dict_2():
     assert item1
     # rely on garbage collector to destroy iter1

+
 def test_generator_dict_3():
     """
     test generator dict 3
...
...
@@ -226,6 +230,7 @@ def test_generator_dict_4():
     err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
     assert err_msg in str(info.value)

+
 def test_generator_dict_4_1():
     """
     test generator dict 4_1
...
...
@@ -249,6 +254,7 @@ def test_generator_dict_4_1():
     err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
     assert err_msg in str(info.value)

+
 def test_generator_dict_4_2():
     """
     test generator dict 4_2
...
...
@@ -274,6 +280,7 @@ def test_generator_dict_4_2():
     err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
     assert err_msg in str(info.value)

+
 def test_generator_dict_5():
     """
     test generator dict 5
...
...
@@ -305,6 +312,7 @@ def test_generator_dict_5():
     err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
     assert err_msg in str(info.value)

+
 # Test tuple iterator
 def test_generator_tuple_0():
...
...
@@ -323,6 +331,7 @@ def test_generator_tuple_0():
         np.testing.assert_array_equal(item[0], golden)
         i = i + 1

+
 def test_generator_tuple_1():
     """
     test generator tuple 1
...
...
@@ -342,6 +351,7 @@ def test_generator_tuple_1():
         i = i + 1
     assert i == 64

+
 def test_generator_tuple_2():
     """
     test generator tuple 2
...
...
@@ -364,6 +374,7 @@ def test_generator_tuple_2():
     assert item1
     # rely on garbage collector to destroy iter1

+
 def test_generator_tuple_3():
     """
     test generator tuple 3
...
...
@@ -442,6 +453,7 @@ def test_generator_tuple_5():
     err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
     assert err_msg in str(info.value)

+
 # Test with repeat
 def test_generator_tuple_repeat_1():
     """
...
...
@@ -536,6 +548,7 @@ def test_generator_tuple_repeat_repeat_2():
         iter1.__next__()
     assert "object has no attribute 'depipeline'" in str(info.value)

+
 def test_generator_tuple_repeat_repeat_3():
     """
     test generator tuple repeat repeat 3
...
...
tests/ut/python/dataset/test_get_col_names.py
View file @ ea947568
...
...
@@ -149,7 +149,7 @@ def test_get_column_name_to_device():

 def test_get_column_name_voc():
-    data = ds.VOCDataset(VOC_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    data = ds.VOCDataset(VOC_DIR, task="Segmentation", usage="train", decode=True, shuffle=False)
     assert data.get_col_names() == ["image", "target"]
...
...
tests/ut/python/dataset/test_noop_mode.py
View file @ ea947568
...
...
@@ -22,7 +22,7 @@ DATA_DIR = "../data/dataset/testVOC2012"

 def test_noop_pserver():
     os.environ['MS_ROLE'] = 'MS_PSERVER'
-    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
     num = 0
     for _ in data1.create_dict_iterator(num_epochs=1):
         num += 1
...
...
@@ -32,7 +32,7 @@ def test_noop_pserver():

 def test_noop_sched():
     os.environ['MS_ROLE'] = 'MS_SCHED'
-    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
     num = 0
     for _ in data1.create_dict_iterator(num_epochs=1):
         num += 1
...
...
tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py
View file @ ea947568
...
...
@@ -42,8 +42,8 @@ def test_random_resized_crop_with_bbox_op_c(plot_vis=False):
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)

     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
...
...
@@ -108,8 +108,8 @@ def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False):
     logger.info("test_random_resized_crop_with_bbox_op_edge_c")

     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
...
...
@@ -142,7 +142,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c():
     logger.info("test_random_resized_crop_with_bbox_op_invalid_c")

     # Load dataset, only Augmented Dataset as test will raise ValueError
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     try:
         # If input range of scale is not in the order of (min, max), ValueError will be raised.
...
...
@@ -168,7 +168,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c():
     """
     logger.info("test_random_resized_crop_with_bbox_op_invalid2_c")
     # Load dataset; only loading the AugDataset as the test will fail on this
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     try:
         # If input range of ratio is not in the order of (min, max), ValueError will be raised.
...
...
@@ -195,13 +195,13 @@ def test_random_resized_crop_with_bbox_op_bad_c():
     logger.info("test_random_resized_crop_with_bbox_op_bad_c")
     test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))

-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
...
...
tests/ut/python/dataset/test_random_crop_with_bbox.py
View file @ ea947568
...
...
@@ -39,8 +39,8 @@ def test_random_crop_with_bbox_op_c(plot_vis=False):
     logger.info("test_random_crop_with_bbox_op_c")

     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     # define test OP with values to match existing Op UT
     test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])
...
...
@@ -101,8 +101,8 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)

     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     # define test OP with values to match existing Op unit - test
     test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255))
...
...
@@ -138,8 +138,8 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False):
     logger.info("test_random_crop_with_bbox_op3_c")

     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     # define test OP with values to match existing Op unit - test
     test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
...
...
@@ -168,8 +168,8 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
     logger.info("test_random_crop_with_bbox_op_edge_c")

     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     # define test OP with values to match existing Op unit - test
     test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
...
...
@@ -205,7 +205,7 @@ def test_random_crop_with_bbox_op_invalid_c():
     logger.info("test_random_crop_with_bbox_op_invalid_c")

     # Load dataset
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     try:
         # define test OP with values to match existing Op unit - test
...
...
@@ -231,13 +231,13 @@ def test_random_crop_with_bbox_op_bad_c():
     logger.info("test_random_crop_with_bbox_op_bad_c")
     test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])

-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
...
...
@@ -247,7 +247,7 @@ def test_random_crop_with_bbox_op_bad_padding():
     """
     logger.info("test_random_crop_with_bbox_op_bad_padding")
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     try:
         test_op = c_vision.RandomCropWithBBox([512, 512], padding=-1)
...
...
tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py
View file @ ea947568
...
...
@@ -37,11 +37,9 @@ def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False):
     logger.info("test_random_horizontal_flip_with_bbox_op_c")
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.RandomHorizontalFlipWithBBox(1)
...
...
@@ -102,11 +100,9 @@ def test_random_horizontal_flip_with_bbox_valid_rand_c(plot_vis=False):
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.RandomHorizontalFlipWithBBox(0.6)
...
...
@@ -140,8 +136,8 @@ def test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False):
     """
     logger.info("test_horizontal_flip_with_bbox_valid_edge_c")

-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.RandomHorizontalFlipWithBBox(1)
...
...
@@ -178,7 +174,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c():
     """
     logger.info("test_random_horizontal_bbox_invalid_prob_c")

-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     try:
         # Note: Valid range of prob should be [0.0, 1.0]
...
...
@@ -201,13 +197,13 @@ def test_random_horizontal_flip_with_bbox_invalid_bounds_c():
     test_op = c_vision.RandomHorizontalFlipWithBBox(1)

-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features")
...
...
tests/ut/python/dataset/test_random_resize_with_bbox.py
View file @ ea947568
...
...
@@ -39,11 +39,9 @@ def test_random_resize_with_bbox_op_voc_c(plot_vis=False):
     original_seed = config_get_set_seed(123)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.RandomResizeWithBBox(100)
...
...
@@ -120,11 +118,9 @@ def test_random_resize_with_bbox_op_edge_c(plot_vis=False):
     box has dimensions as the image itself.
     """
     logger.info("test_random_resize_with_bbox_op_edge_c")
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.RandomResizeWithBBox(500)
...
...
@@ -197,13 +193,13 @@ def test_random_resize_with_bbox_op_bad_c():
     logger.info("test_random_resize_with_bbox_op_bad_c")
     test_op = c_vision.RandomResizeWithBBox((400, 300))

-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
...
...
tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py
View file @ ea947568
...
...
@@ -37,11 +37,9 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False):
     """
     logger.info("test_random_vertical_flip_with_bbox_op_c")
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.RandomVerticalFlipWithBBox(1)
...
...
@@ -102,11 +100,9 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.RandomVerticalFlipWithBBox(0.8)
...
...
@@ -139,11 +135,9 @@ def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False):
     applied on dynamically generated edge case, expected to pass
     """
     logger.info("test_random_vertical_flip_with_bbox_op_edge_c")
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.RandomVerticalFlipWithBBox(1)
...
...
@@ -174,8 +168,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
     Test RandomVerticalFlipWithBBox Op on invalid constructor parameters, expected to raise ValueError
     """
     logger.info("test_random_vertical_flip_with_bbox_op_invalid_c")
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)

     try:
         test_op = c_vision.RandomVerticalFlipWithBBox(2)
...
...
@@ -201,13 +194,13 @@ def test_random_vertical_flip_with_bbox_op_bad_c():
     logger.info("test_random_vertical_flip_with_bbox_op_bad_c")
     test_op = c_vision.RandomVerticalFlipWithBBox(1)

-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
...
...
tests/ut/python/dataset/test_resize_with_bbox.py
View file @ ea947568
...
...
@@ -39,11 +39,9 @@ def test_resize_with_bbox_op_voc_c(plot_vis=False):
     logger.info("test_resize_with_bbox_op_voc_c")
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.ResizeWithBBox(100)
...
...
@@ -110,11 +108,9 @@ def test_resize_with_bbox_op_edge_c(plot_vis=False):
     box has dimensions as the image itself.
     """
     logger.info("test_resize_with_bbox_op_edge_c")
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

     test_op = c_vision.ResizeWithBBox(500)
...
...
@@ -163,13 +159,13 @@ def test_resize_with_bbox_op_bad_c():
     logger.info("test_resize_with_bbox_op_bad_c")
     test_op = c_vision.ResizeWithBBox((200, 300))

-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
...
...
tests/ut/python/dataset/test_serdes_dataset.py
View file @ ea947568
...
...
@@ -32,6 +32,7 @@ from mindspore.dataset.vision import Inter

+
 def test_imagefolder(remove_json_files=True):
     """
     Test simulating resnet50 dataset pipeline.
...
...
@@ -103,7 +104,7 @@ def test_mnist_dataset(remove_json_files=True):
     data_dir = "../data/dataset/testMnistData"
     ds.config.set_seed(1)

-    data1 = ds.MnistDataset(data_dir, 100)
+    data1 = ds.MnistDataset(data_dir, num_samples=100)
     one_hot_encode = c.OneHot(10)  # num_classes is input argument
     data1 = data1.map(input_columns="label", operations=one_hot_encode)
...
...