magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 769ae609
Authored on Jun 04, 2020 by mindspore-ci-bot; committed by Gitee on Jun 04, 2020
!1808 consistent design for num_samples
Merge pull request !1808 from Jamie/numsamples
Parents: 06ee0296 51bc0c04
Showing 55 changed files, with 617 additions and 1154 deletions.
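The unifying idea behind this commit: num_samples stops being a per-op knob that every dataset op stores, clamps, and re-implements, and becomes a property of the sampler, with 0 still meaning "take everything". The diffs below are mostly mechanical consequences of that move: every SetNumSamples builder setter, num_samples_ member, and GetNumSamples override disappears from the source ops. A minimal sketch of the resulting ownership model, with illustrative class shapes rather than the actual MindSpore declarations:

// Sketch only: a simplified Sampler that owns num_samples; names are
// illustrative, not the exact MindSpore classes.
#include <cstdint>
#include <iostream>

class Sampler {
 public:
  // num_samples == 0 means "sample the entire dataset" (the convention this commit keeps).
  explicit Sampler(int64_t num_samples) : num_samples_(num_samples), num_rows_(0) {}
  // The dataset op now reports only its true row count; the sampler clamps itself.
  void Init(int64_t num_rows) {
    num_rows_ = num_rows;
    if (num_samples_ == 0 || num_samples_ > num_rows_) num_samples_ = num_rows_;
  }
  int64_t num_samples() const { return num_samples_; }

 protected:
  int64_t num_samples_;
  int64_t num_rows_;
};

class SequentialSampler : public Sampler {
 public:
  SequentialSampler(int64_t start_index, int64_t num_samples)
      : Sampler(num_samples), start_index_(start_index) {}
  int64_t start_index() const { return start_index_; }

 private:
  int64_t start_index_;
};

int main() {
  SequentialSampler sampler(/*start_index=*/0, /*num_samples=*/0);  // 0 -> take everything
  sampler.Init(/*num_rows=*/60000);                                 // op reports its raw size
  std::cout << "effective num_samples: " << sampler.num_samples() << "\n";  // 60000
  return 0;
}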
mindspore/ccsrc/dataset/api/de_pipeline.cc  +11 -23
mindspore/ccsrc/dataset/api/python_bindings.cc  +23 -33
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc  +14 -46
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h  +1 -22
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc  +10 -38
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h  +3 -25
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc  +11 -42
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h  +5 -28
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc  +17 -48
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h  +6 -27
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc  +0 -1
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc  +10 -38
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h  +3 -26
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt  +0 -1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc  +8 -2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h  +4 -3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc  +22 -12
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h  +6 -5
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc  +7 -2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h  +5 -2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc  +9 -13
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h  +4 -5
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc  +42 -6
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h  +32 -33
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc  +29 -16
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h  +8 -2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc  +15 -9
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h  +2 -1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc  +0 -85
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h  +0 -58
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc  +16 -13
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h  +2 -5
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc  +10 -10
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h  +5 -5
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc  +8 -37
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h  +3 -27
mindspore/dataset/__init__.py  +1 -1
mindspore/dataset/engine/datasets.py  +45 -70
mindspore/dataset/engine/samplers.py  +76 -104
mindspore/dataset/engine/serializer_deserializer.py  +14 -0
mindspore/dataset/engine/validators.py  +2 -2
tests/ut/cpp/dataset/celeba_op_test.cc  +6 -28
tests/ut/cpp/dataset/cifar_op_test.cc  +10 -34
tests/ut/cpp/dataset/image_folder_op_test.cc  +22 -63
tests/ut/cpp/dataset/manifest_op_test.cc  +14 -7
tests/ut/cpp/dataset/map_op_test.cc  +2 -3
tests/ut/cpp/dataset/mnist_op_test.cc  +10 -6
tests/ut/cpp/dataset/stand_alone_samplers_test.cc  +9 -15
tests/ut/cpp/dataset/subset_random_sampler_test.cc  +9 -16
tests/ut/cpp/dataset/weighted_random_sampler_test.cc  +10 -18
tests/ut/python/dataset/test_datasets_imagefolder.py  +2 -2
tests/ut/python/dataset/test_datasets_sharding.py  +2 -2
tests/ut/python/dataset/test_exceptions.py  +2 -7
tests/ut/python/dataset/test_generator.py  +3 -1
tests/ut/python/dataset/test_sampler.py  +27 -26
mindspore/ccsrc/dataset/api/de_pipeline.cc

@@ -856,9 +856,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
         (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -893,9 +891,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -930,9 +926,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -966,9 +960,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -1001,9 +993,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -1039,10 +1029,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
       (void)builder.SetNumWorkers(ToInt(value));
     } else if (key == "schema_file_path" || key == "schema_json_string") {
       schema_exists = true;
-    } else if (key == "num_samples") {
-      (void)builder.SetTotalRows(ToInt(value));
     } else if (key == "columns_list") {
       columns_to_load = ToStringVector(value);
+    } else if (key == "num_samples") {
+      // This is not sampling here. The random data op needs to know how much data to
+      // generate. It does not currently support sampling.
+      (void)builder.SetTotalRows(ToInt(value));
     }
   }
   if (schema_exists) {
@@ -1077,9 +1069,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -1121,8 +1111,6 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
       (void)builder->SetDecode(ToBool(value));
     } else if (key == "extensions") {
       (void)builder->SetExtensions(ToStringSet(value));
-    } else if (key == "num_samples") {
-      (void)builder->SetNumSamples(ToInt(value));
     } else if (key == "dataset_type") {
       (void)builder->SetDatasetType(ToString(value));
     }
@@ -1153,7 +1141,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
     } else if (key == "shuffle_files") {
       (void)builder->SetShuffleFiles(ToBool(value));
     } else if (key == "num_samples") {
-      (void)builder->SetNumSamples(ToInt(value));
+      (void)builder->SetTotalRows(ToInt(value));
     } else if (key == "num_shards") {
       (void)builder->SetNumDevices(ToInt(value));
     } else if (key == "shard_id") {
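Each ParseXxxOp above walks the Python kwargs dict and dispatches on the key; with this commit the num_samples value travels inside the sampler object, so the branch disappears from every parser, while RandomDataOp and TextFileOp keep consuming it directly as a row budget via SetTotalRows. A compile-and-run stand-in for that loop, using std::map in place of py::dict (the Builder and its setters are illustrative, not the MindSpore types):

// Minimal stand-in for the ParseXxxOp keyword loop; illustrative names only.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

struct Builder {
  int32_t num_workers = 0;
  int64_t total_rows = 0;
  void SetNumWorkers(int32_t n) { num_workers = n; }
  void SetTotalRows(int64_t n) { total_rows = n; }
};

int main() {
  std::map<std::string, int64_t> args = {{"num_parallel_workers", 4}, {"num_samples", 1000}};
  Builder builder;
  for (const auto &arg : args) {
    const std::string &key = arg.first;
    if (key == "num_parallel_workers") {
      builder.SetNumWorkers(static_cast<int32_t>(arg.second));
    } else if (key == "num_samples") {
      // After this commit, only RandomDataOp and TextFileOp still take
      // num_samples here, and they treat it as a row budget, not sampling.
      builder.SetTotalRows(arg.second);
    }
  }
  std::cout << builder.num_workers << " workers, " << builder.total_rows << " rows\n";
  return 0;
}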
mindspore/ccsrc/dataset/api/python_bindings.cc

@@ -49,7 +49,6 @@
 #include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/python_sampler.h"
@@ -143,17 +142,16 @@ void bindDatasetOps(py::module *m) {
     });
   (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
-    .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
+    .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
       int64_t count = 0;
-      THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
+      THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
       return count;
     });
   (void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
-    .def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
+    .def_static("get_num_rows_and_classes", [](const std::string &path) {
       int64_t count = 0, num_classes = 0;
-      THROW_IF_ERROR(
-        ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
+      THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
       return py::make_tuple(count, num_classes);
     });
@@ -172,22 +170,21 @@ void bindDatasetOps(py::module *m) {
   (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
     .def_static("get_num_rows_and_classes",
-                [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
+                [](const std::string &file, const py::dict &dict, const std::string &usage) {
                   int64_t count = 0, num_classes = 0;
-                  THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
+                  THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
                   return py::make_tuple(count, num_classes);
                 })
-    .def_static("get_class_indexing",
-                [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
-                  std::map<std::string, int32_t> output_class_indexing;
-                  THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
-                  return output_class_indexing;
-                });
+    .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) {
+      std::map<std::string, int32_t> output_class_indexing;
+      THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
+      return output_class_indexing;
+    });
   (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
-    .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
+    .def_static("get_num_rows", [](const std::string &dir) {
       int64_t count = 0;
-      THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
+      THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
       return count;
     });
@@ -206,13 +203,13 @@ void bindDatasetOps(py::module *m) {
                 [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
-                   const py::dict &dict, int64_t numSamples) {
+                   const py::dict &dict) {
                   int64_t count = 0;
-                  THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
+                  THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
                   return count;
                 })
     .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
-                                         const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
+                                         const std::string &task_mode, const py::dict &dict) {
      std::map<std::string, int32_t> output_class_indexing;
-      THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing));
+      THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
      return output_class_indexing;
    });
 }
@@ -452,25 +449,19 @@ void bindSamplerOps(py::module *m) {
   (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");

   (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
-    .def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
-         py::arg("seed"));
+    .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());

   (void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
-    .def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
+    .def(py::init<int64_t, int64_t, bool>());

   (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
-    .def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
-         py::arg("num_samples"))
-    .def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
+    .def(py::init<int64_t, bool, bool>());

   (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
-    .def(py::init<>());
-
-  (void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
-    .def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
+    .def(py::init<int64_t, int64_t>());

   (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
-    .def(py::init<std::vector<int64_t>>(), py::arg("indices"));
+    .def(py::init<int64_t, std::vector<int64_t>>());

   (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
     *m, "MindrecordSubsetRandomSampler")
@@ -487,11 +478,10 @@ void bindSamplerOps(py::module *m) {
     }));
   (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
-    .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
-         py::arg("replacement"));
+    .def(py::init<int64_t, std::vector<double>, bool>());
   (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
-    .def(py::init<py::object>(), py::arg("pySampler"));
+    .def(py::init<int64_t, py::object>());
 }

 void bindInfoObjects(py::module *m) {
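The rebinding above drops the named py::arg lists and settles on positional constructors with num_samples as the leading parameter, except SequentialSampler, whose builder call sites elsewhere in the commit suggest (start_index, num_samples). A hedged sketch of the bound constructor shapes; the parameter names are inferred from the old py::arg names and the init<...> type lists, not taken from the headers:

// Illustrative C++ constructor shapes mirroring the new init<...> lists above;
// parameter names are guesses, and these are not the full MindSpore classes.
#include <cstdint>
#include <iostream>
#include <vector>

struct DistributedSampler {
  DistributedSampler(int64_t /*num_samples*/, int64_t /*num_dev*/, int64_t /*dev_id*/, bool /*shuffle*/,
                     uint32_t /*seed*/) {}
};
struct PKSampler {
  PKSampler(int64_t /*num_samples*/, int64_t /*k_val*/, bool /*shuffle*/) {}
};
struct RandomSampler {
  RandomSampler(int64_t /*num_samples*/, bool /*replacement*/, bool /*reshuffle_each_epoch*/) {}
};
struct SequentialSampler {
  SequentialSampler(int64_t /*start_index*/, int64_t /*num_samples*/) {}
};
struct SubsetRandomSampler {
  SubsetRandomSampler(int64_t /*num_samples*/, std::vector<int64_t> /*indices*/) {}
};
struct WeightedRandomSampler {
  WeightedRandomSampler(int64_t /*num_samples*/, std::vector<double> /*weights*/, bool /*replacement*/) {}
};

int main() {
  RandomSampler rs(/*num_samples=*/512, /*replacement=*/false, /*reshuffle_each_epoch=*/true);
  SubsetRandomSampler srs(/*num_samples=*/0, {3, 1, 4, 1, 5});
  std::cout << "samplers constructed with the uniform num_samples-first ordering\n";
  return 0;
}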
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc

@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace dataset {
-CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr), builder_num_samples_(0) {
+CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_rows_per_buffer_ = cfg->rows_per_buffer();
@@ -38,7 +38,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
   MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << ".";
   RETURN_IF_NOT_OK(SanityCheck());
   if (builder_sampler_ == nullptr) {
-    builder_sampler_ = std::make_shared<SequentialSampler>();
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
   }
   builder_schema_ = std::make_unique<DataSchema>();
@@ -47,10 +49,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
   // label is like this:0 1 0 0 1......
   RETURN_IF_NOT_OK(
     builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
-  *op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
-                                   builder_op_connector_size_, builder_decode_, builder_dataset_type_,
-                                   builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_),
-                                   builder_num_samples_);
+  *op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
+                                   builder_op_connector_size_, builder_decode_, builder_dataset_type_,
+                                   builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
   if (*op == nullptr) {
     return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null");
   }
@@ -68,7 +69,7 @@ Status CelebAOp::Builder::SanityCheck() {
 CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
                    bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
-                   std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler, int64_t num_samples)
+                   std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_workers, queue_size),
       rows_per_buffer_(rows_per_buffer),
       folder_path_(dir),
@@ -77,8 +78,6 @@ CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::stri
       data_schema_(std::move(schema)),
       sampler_(std::move(sampler)),
       num_rows_in_attr_file_(0),
-      num_rows_exact_(0),
-      num_samples_(num_samples),
       dataset_type_(dataset_type) {
   // Set the column name map (base class field)
   for (int32_t index = 0; index < data_schema_->NumColumns(); index++) {
@@ -202,13 +201,6 @@ Status CelebAOp::ParseImageAttrInfo() {
   RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
   while (!image_infos.empty() && needMoreData) {
     for (uint32_t index = 0; index < image_infos.size(); index++) {
-      if (num_samples_ != 0 && image_labels_vec_.size() >= num_samples_) {
-        MS_LOG(WARNING) << "Image number(" << image_labels_vec_.size() << " is more than"
-                        << " rows num eval attr file(" << num_rows_in_attr_file_ << ") or num samples("
-                        << num_samples_ << ").";
-        needMoreData = false;
-        break;
-      }
       std::string image_info = image_infos[index];
       std::vector<std::string> split = Split(image_info);
       std::pair<std::string, std::vector<int32_t>> image_labels;
@@ -239,14 +231,13 @@ Status CelebAOp::ParseImageAttrInfo() {
     RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
   }

-  num_rows_exact_ = image_labels_vec_.size();
-  num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_exact_) ? num_rows_exact_ : num_samples_;
-  if (num_rows_exact_ == 0) {
+  num_rows_ = image_labels_vec_.size();
+  if (num_rows_ == 0) {
     RETURN_STATUS_UNEXPECTED(
       "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
       "validation first.");
   }
-  MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_exact_ << ".";
+  MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << ".";
   return Status::OK();
 }
@@ -268,28 +259,6 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) {
   return split;
 }

-// Derived from RandomAccessOp
-Status CelebAOp::GetNumSamples(int64_t *num) const {
-  if (num == nullptr || num_samples_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_samples_;
-  return Status::OK();
-}
-
-Status CelebAOp::GetNumRowsInDataset(int64_t *num) const {
-  if (num == nullptr || num_rows_exact_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
-      "validation first.");
-  }
-  *num = num_rows_exact_;
-  return Status::OK();
-}
-
 // Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
 Status CelebAOp::operator()() {
   RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
@@ -310,9 +279,8 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
     RETURN_IF_NOT_OK((*data_buffer)->PopRow(&sample_row));
     std::shared_ptr<Tensor> sample_ids = sample_row[0];
     for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
-      if ((*itr) >= num_rows_exact_) {
-        MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_exact_
-                        << ".";
+      if ((*itr) >= num_rows_) {
+        MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_ << ".";
         continue;
       }
       keys.push_back(*itr);
@@ -446,7 +414,7 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const {
   // Call the super class for displaying any common detailed info
   ParallelOp::Print(out, show_all);
   // Then show any custom derived-internal stuff
-  out << "\nNumber of rows:" << num_rows_exact_ << "\nceleba dir: " << folder_path_ << "\n\n";
+  out << "\nNumber of rows:" << num_rows_ << "\nceleba dir: " << folder_path_ << "\n\n";
 }
 }
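With num_samples_ gone from CelebAOp, the read loop's only guard is the bounds check against the shared row count; the sampler is trusted to emit the right number of ids. A tiny runnable version of that filtering step (all names illustrative):

// Sketch of the post-commit sampling loop: ids outside [0, num_rows) are
// skipped, and there is no per-op "enough rows read" break anymore, because the
// sampler already emitted exactly num_samples ids.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t num_rows = 5;                            // reported by the dataset op
  const std::vector<int64_t> sample_ids = {0, 2, 9, 4};  // produced by the sampler
  std::vector<int64_t> keys;
  for (int64_t id : sample_ids) {
    if (id >= num_rows) {
      std::cerr << "Sample Id (" << id << ") is out of bounds, skipping. Max id is " << num_rows << ".\n";
      continue;
    }
    keys.push_back(id);
  }
  std::cout << "queued " << keys.size() << " rows\n";  // 3
  return 0;
}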
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h

@@ -108,14 +108,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
       return *this;
     }

-    // Setter method
-    // @param int64_t num_samples
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumSamples(int64_t num_samples) {
-      builder_num_samples_ = num_samples;
-      return *this;
-    }
-
     // Setter method
     // @param const std::string dataset_type: type to be read
     // @return Builder setter method returns reference to the builder.
@@ -141,7 +133,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
     std::set<std::string> builder_extensions_;
     std::shared_ptr<Sampler> builder_sampler_;
     std::unique_ptr<DataSchema> builder_schema_;
-    int64_t builder_num_samples_;
     std::string builder_dataset_type_;
   };
@@ -153,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
   CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
            const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
-           std::shared_ptr<Sampler> sampler, int64_t num_samples);
+           std::shared_ptr<Sampler> sampler);

   ~CelebAOp() override = default;
@@ -163,16 +154,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // @return Status - The error code return
   Status operator()() override;

-  // Method derived from RandomAccess Op, enable Sampler to get numRows
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumSamples(int64_t *num) const override;
-
-  // Method derived from RandomAccess Op, enable Sampler to get numRows
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumRowsInDataset(int64_t *num) const override;
-
   // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
   // @param int32_t worker_id - id of each worker
   // @return Status - The error code return
@@ -233,11 +214,9 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   std::shared_ptr<Sampler> sampler_;
   std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_;
   int64_t num_rows_in_attr_file_;  // rows number specified in attr file
-  int64_t num_rows_exact_;         // exact rows number,maybe is less than rows_num_in_attr_file_
   QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
   WaitPost wp_;
   std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
-  int64_t num_samples_;
   std::string dataset_type_;
   std::ifstream partition_file_;
 };
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc

@@ -35,7 +35,7 @@ constexpr uint32_t kCifarImageChannel = 3;
 constexpr uint32_t kCifarBlockImageNum = 5;
 constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel;

-CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) {
+CifarOp::Builder::Builder() : sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   num_workers_ = cfg->num_parallel_workers();
   rows_per_buffer_ = cfg->rows_per_buffer();
@@ -46,7 +46,9 @@ CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) {
 Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
   if (sampler_ == nullptr) {
-    sampler_ = std::make_shared<SequentialSampler>();
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
   }
   schema_ = std::make_unique<DataSchema>();
   TensorShape scalar = TensorShape::CreateScalar();
@@ -62,7 +64,7 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
       ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
   }
-  *ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_, num_samples_,
-                                   std::move(schema_), std::move(sampler_));
+  *ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
+                                   std::move(schema_), std::move(sampler_));
   return Status::OK();
 }
@@ -76,16 +78,13 @@ Status CifarOp::Builder::SanityCheck() {
 }

 CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
-                 int32_t queue_size, int64_t num_samples, std::unique_ptr<DataSchema> data_schema,
-                 std::shared_ptr<Sampler> sampler)
+                 int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_works, queue_size),
       cifar_type_(type),
       rows_per_buffer_(rows_per_buf),
       folder_path_(file_dir),
-      num_samples_(num_samples),
       data_schema_(std::move(data_schema)),
       sampler_(std::move(sampler)),
-      num_rows_(0),
       row_cnt_(0),
       buf_cnt_(0) {
   // set the column name map (base class field)
@@ -112,8 +111,7 @@ Status CifarOp::operator()() {
     for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
       keys.push_back(*itr);
       row_cnt_++;
-      if ((*itr) >= num_rows_) continue;    // index out of bound, skipping
-      if (row_cnt_ >= num_samples_) break;  // enough row read, break for loop
+      if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
       if (row_cnt_ % rows_per_buffer_ == 0) {
         RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
           std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
@@ -255,30 +253,6 @@ Status CifarOp::InitSampler() {
   return Status::OK();
 }

-// Derived from RandomAccessOp
-Status CifarOp::GetNumSamples(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
-    std::string err_msg = "There is no valid data matching the dataset API " + api +
-                          ".Please check file path or dataset API validation first.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  (*num) = num_samples_;
-  return Status::OK();
-}
-
-// Derived from RandomAccessOp
-Status CifarOp::GetNumRowsInDataset(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
-    std::string err_msg = "There is no valid data matching the dataset API " + api +
-                          ".Please check file path or dataset API validation first.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  (*num) = num_rows_;
-  return Status::OK();
-}
-
 Status CifarOp::ReadCifarBlockDataAsync() {
   TaskManager::FindMe()->Post();
   RETURN_IF_NOT_OK(GetCifarFiles());
@@ -404,7 +378,6 @@ Status CifarOp::ParseCifarData() {
   }
   cifar_image_label_pairs_.shrink_to_fit();
   num_rows_ = cifar_image_label_pairs_.size();
-  num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
   if (num_rows_ == 0) {
     std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
     std::string err_msg = "There is no valid data matching the dataset API " + api +
@@ -432,11 +405,11 @@ Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
   return Status::OK();
 }

-Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count) {
+Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) {
   // the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block()
   std::shared_ptr<CifarOp> op;
   *count = 0;
-  RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetNumSamples(numSamples).SetCifarType(isCIFAR10).Build(&op));
+  RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op));
   RETURN_IF_NOT_OK(op->GetCifarFiles());
   if (op->cifar_type_ == kCifar10) {
     constexpr int64_t num_cifar10_records = 10000;
@@ -448,7 +421,6 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool
       }
       *count = *count + num_cifar10_records;
     }
-    *count = *count < numSamples || numSamples == 0 ? *count : numSamples;
     return Status::OK();
   } else {
     int64_t num_cifar100_records = 0;
@@ -470,7 +442,7 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool
         RETURN_STATUS_UNEXPECTED(err_msg);
       }
     }
-    *count = num_cifar100_records < numSamples || numSamples == 0 ? num_cifar100_records : numSamples;
+    *count = num_cifar100_records;
     return Status::OK();
   }
 }
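CifarOp::CountTotalRows now reports the raw record count and no longer folds numSamples into the result; any cap is the sampler's business. A sketch of the caller-side arithmetic under that split, using a hypothetical counting helper rather than the real implementation:

// Hypothetical helper: report raw size, apply the sampler's cap separately.
#include <algorithm>
#include <cstdint>
#include <iostream>

int64_t CountTotalRows(bool is_cifar10) {
  // CIFAR-10 ships five 10000-record training blocks; CIFAR-100 one 50000-record
  // training file. Real code walks the files; this stub just returns the totals.
  return is_cifar10 ? 5 * 10000 : 50000;
}

int main() {
  int64_t count = CountTotalRows(/*is_cifar10=*/true);
  int64_t num_samples = 1200;  // previously folded into the count itself
  int64_t effective = (num_samples == 0) ? count : std::min(count, num_samples);
  std::cout << count << " rows on disk, " << effective << " after the sampler's cap\n";
  return 0;
}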
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h

@@ -73,14 +73,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
       return *this;
     }

-    // Setter method
-    // @param uint64_t num_samples
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumSamples(uint64_t num_samples) {
-      num_samples_ = num_samples;
-      return *this;
-    }
-
     // Setter method
     // @param std::shared_ptr<Sampler> sampler
     // @return Builder setter method returns reference to the builder.
@@ -121,7 +113,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   private:
    std::string dir_;
    int32_t num_workers_;
-   uint64_t num_samples_;
    int32_t rows_per_buffer_;
    int32_t op_connect_size_;
    std::shared_ptr<Sampler> sampler_;
@@ -137,7 +128,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // @param uint32_t - queueSize - connector queue size
   // @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
   CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size,
-          int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+          std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);

   // Destructor.
   ~CifarOp() = default;
@@ -152,16 +143,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // @return Status - The error code return
   Status operator()() override;

-  // Method derived from RandomAccess Op, enable Sampler to get numRows
-  // @param uint64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumSamples(int64_t *num) const override;
-
-  // Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
-  // @param uint64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumRowsInDataset(int64_t *num) const override;
-
   // A print method typically used for debugging
   // @param out
   // @param show_all
@@ -169,11 +150,10 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // Function to count the number of samples in the CIFAR dataset
   // @param dir path to the CIFAR directory
-  // @param numSamples maximum number of samples requested
   // @param isCIFAR10 true if CIFAR10 and false if CIFAR100
-  // @param count output arg that will hold the minimum of the actual dataset size and numSamples
+  // @param count output arg that will hold the actual dataset size
   // @return
-  static Status CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count);
+  static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);

  private:
   // Initialize Sampler, calls sampler->Init() within
@@ -227,10 +207,8 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   CifarType cifar_type_;
   int32_t rows_per_buffer_;
   std::string folder_path_;
-  int64_t num_samples_;
   std::unique_ptr<DataSchema> data_schema_;
   std::shared_ptr<Sampler> sampler_;
-  int64_t num_rows_;
   int64_t row_cnt_;
   int64_t buf_cnt_;
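The header diffs here and in the other source ops delete both num_samples_ and num_rows_ as per-op members, while the .cc files keep assigning num_rows_. Given that sampler.h is among the touched files, one plausible reading is that the shared row count moved into the RandomAccessOp base; the sketch below is that guess, not the actual MindSpore class layout:

// Hypothetical sketch: a RandomAccessOp base owning the row count, so each
// derived source op drops its duplicate member.
#include <cstdint>
#include <iostream>

class RandomAccessOp {
 public:
  int64_t num_rows() const { return num_rows_; }

 protected:
  int64_t num_rows_ = 0;  // every derived source op fills this in during parsing
};

class CifarOp : public RandomAccessOp {
 public:
  void ParseData() { num_rows_ = 60000; }  // e.g. CIFAR-10 train + test records
};

int main() {
  CifarOp op;
  op.ParseData();
  std::cout << op.num_rows() << "\n";
  return 0;
}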
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc

@@ -26,8 +26,7 @@
 namespace mindspore {
 namespace dataset {
-ImageFolderOp::Builder::Builder()
-    : builder_decode_(false), builder_recursive_(false), builder_num_samples_(0), builder_sampler_(nullptr) {
+ImageFolderOp::Builder::Builder() : builder_decode_(false), builder_recursive_(false), builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_rows_per_buffer_ = cfg->rows_per_buffer();
@@ -37,7 +36,9 @@ ImageFolderOp::Builder::Builder()
 Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
   if (builder_sampler_ == nullptr) {
-    builder_sampler_ = std::make_shared<SequentialSampler>();
+    int64_t num_samples = 0;  // default num samples of 0 means to sample entire set of data
+    int64_t start_index = 0;
+    builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
   }
   builder_schema_ = std::make_unique<DataSchema>();
   TensorShape scalar = TensorShape::CreateScalar();
@@ -46,9 +47,9 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
   RETURN_IF_NOT_OK(
     builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
-  *ptr = std::make_shared<ImageFolderOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
-                                         builder_op_connector_size_, builder_num_samples_, builder_recursive_,
-                                         builder_decode_, builder_extensions_, builder_labels_to_read_,
-                                         std::move(builder_schema_), std::move(builder_sampler_));
+  *ptr = std::make_shared<ImageFolderOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
+                                         builder_op_connector_size_, builder_recursive_, builder_decode_,
+                                         builder_extensions_, builder_labels_to_read_, std::move(builder_schema_),
+                                         std::move(builder_sampler_));
   return Status::OK();
 }
@@ -61,20 +62,18 @@ Status ImageFolderOp::Builder::SanityCheck() {
 }

 ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
-                             int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts,
+                             bool recursive, bool do_decode, const std::set<std::string> &exts,
                              const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema,
                              std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_wkrs, queue_size),
       rows_per_buffer_(rows_per_buffer),
       folder_path_(file_dir),
-      num_samples_(num_samples),
       recursive_(recursive),
       decode_(do_decode),
       extensions_(exts),
       class_index_(map),
       data_schema_(std::move(data_schema)),
       sampler_(std::move(sampler)),
-      num_rows_(0),
       row_cnt_(0),
       buf_cnt_(0),
       sampler_ind_(0),
@@ -117,7 +116,6 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
   }
   image_label_pairs_.shrink_to_fit();
   num_rows_ = image_label_pairs_.size();
-  num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
   // free memory of two queues used for pre-scan
   folder_name_queue_->Reset();
   image_name_queue_->Reset();
@@ -138,8 +136,7 @@ Status ImageFolderOp::operator()() {
     std::shared_ptr<Tensor> sample_ids = sample_row[0];
     if (sample_ids->type() != DataType(DataType::DE_INT64))
       RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64");
     for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
-      if ((*itr) >= num_rows_) continue;    // index out of bound, skipping
-      if (row_cnt_ >= num_samples_) break;  // enough row read, break for loop
+      if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
       keys.push_back(*itr);
       row_cnt_++;
       if (row_cnt_ % rows_per_buffer_ == 0) {
@@ -272,28 +269,6 @@ Status ImageFolderOp::InitSampler() {
   return Status::OK();
 }

-// Derived from RandomAccessOp
-Status ImageFolderOp::GetNumSamples(int64_t *num) const {
-  if (num == nullptr || num_samples_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_samples_;
-  return Status::OK();
-}
-
-// Derived from RandomAccessOp
-Status ImageFolderOp::GetNumRowsInDataset(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_rows_;
-  return Status::OK();
-}
-
 // Derived from RandomAccessOp
 Status ImageFolderOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
   if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
@@ -413,16 +388,14 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() {
   return Status::OK();
 }

-Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t &num_samples,
-                                          const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
-                                          int64_t dev_id, int64_t num_dev) {
+Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts,
+                                          int64_t *num_rows, int64_t *num_classes, int64_t dev_id, int64_t num_dev) {
   Path dir(path);
   std::string err_msg = "";
   int64_t row_cnt = 0;
   err_msg += (dir.Exists() == false || dir.IsDirectory() == false) ? "unable to open dir " + path : "";
   err_msg += (num_classes == nullptr || num_rows == nullptr) ? "num_class/num_rows is null\n" : "";
   err_msg += (dev_id >= num_dev || num_dev <= 0) ? "invalid sharding config\n" : "";
-  err_msg += num_samples < 0 ? "num_samples can't be negative! set it to 0 to use all samples\n" : "";
   if (err_msg.empty() == false) {
     RETURN_STATUS_UNEXPECTED(err_msg);
   }
@@ -441,10 +414,6 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t
     while (dir_itr->hasNext()) {
       if (exts.empty() || exts.find(subdir.Extension()) != exts.end()) {
         ++row_cnt;
-        if (row_cnt == num_samples * num_dev) {
-          (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
-          return Status::OK();
-        }
       }
     }
     foldernames.pop();
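The block deleted from CountRowsAndClasses stopped the directory scan early once row_cnt reached num_samples * num_dev and returned ceil(row_cnt / num_dev); post-commit the scan always counts everything. The removed per-shard arithmetic, preserved as a small runnable reference:

// The removed early-exit formula, written the way the old code wrote it.
#include <cstdint>
#include <iostream>

int64_t RowsPerShard(int64_t row_cnt, int64_t num_dev) {
  // ceil division: each of num_dev shards holds at most this many rows
  return (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
}

int main() {
  std::cout << RowsPerShard(10, 4) << "\n";  // 3: 10 rows over 4 shards needs at most 3 per shard
  std::cout << RowsPerShard(8, 4) << "\n";   // 2: divides evenly
  return 0;
}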
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h

@@ -107,14 +107,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
       return *this;
     }

-    // Setter method
-    // @param int64_t num_samples
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumSamples(int64_t num_samples) {
-      builder_num_samples_ = num_samples;
-      return *this;
-    }
-
     // Setter method
     // @param std::shared_ptr<Sampler> sampler
     // @return Builder setter method returns reference to the builder.
@@ -153,7 +145,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
     bool builder_recursive_;
     std::string builder_dir_;
     int32_t builder_num_workers_;
-    int64_t builder_num_samples_;
     int32_t builder_rows_per_buffer_;
     int32_t builder_op_connector_size_;
     std::set<std::string> builder_extensions_;
@@ -169,10 +160,9 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   // @param int32_t queue_size - connector queue size
   // @param std::set<std::string> exts - set of file extensions to read, if empty, read everything under the dir
   // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
-  ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
-                int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts,
-                const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler);
+  ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive,
+                bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
+                std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler);

   // Destructor.
   ~ImageFolderOp() = default;
@@ -198,16 +188,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   // @return Status - The error code return
   Status operator()() override;

-  // Method derived from RandomAccess Op, enable Sampler to get numRows
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumSamples(int64_t *num) const override;
-
-  // Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumRowsInDataset(int64_t *num) const override;
-
   // Method derived from RandomAccess Op, enable Sampler to get all ids for each class
   // @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
   // @return Status - The error code return
@@ -221,9 +201,8 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   // This function is a hack! It is to return the num_class and num_rows the old storageOp does. The result
   // returned by this function may not be consistent with what image_folder_op is going to return
   // user this at your own risk!
-  static Status CountRowsAndClasses(const std::string &path, const int64_t &num_samples,
-                                    const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
-                                    int64_t dev_id = 0, int64_t num_dev = 1);
+  static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
+                                    int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1);

   // Base-class override for NodePass visitor acceptor.
   // @param p - Pointer to the NodePass to be accepted.
@@ -266,14 +245,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   int32_t rows_per_buffer_;
   std::string folder_path_;  // directory of image folder
-  int64_t num_samples_;
   bool recursive_;
   bool decode_;
   std::set<std::string> extensions_;  // extensions allowed
   std::map<std::string, int32_t> class_index_;
   std::unique_ptr<DataSchema> data_schema_;
   std::shared_ptr<Sampler> sampler_;
-  int64_t num_rows_;  // total number of images in ImageFolder
   int64_t row_cnt_;
   int64_t buf_cnt_;
   int64_t sampler_ind_;
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
浏览文件 @
769ae609
...
...
@@ -29,7 +29,7 @@
namespace
mindspore
{
namespace
dataset
{
ManifestOp
::
Builder
::
Builder
()
:
builder_sampler_
(
nullptr
),
builder_
num_samples_
(
0
),
builder_
decode_
(
false
)
{
ManifestOp
::
Builder
::
Builder
()
:
builder_sampler_
(
nullptr
),
builder_decode_
(
false
)
{
std
::
shared_ptr
<
ConfigManager
>
cfg
=
GlobalContext
::
config_manager
();
builder_num_workers_
=
cfg
->
num_parallel_workers
();
builder_rows_per_buffer_
=
cfg
->
rows_per_buffer
();
...
...
@@ -39,16 +39,18 @@ ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_
Status
ManifestOp
::
Builder
::
Build
(
std
::
shared_ptr
<
ManifestOp
>
*
ptr
)
{
RETURN_IF_NOT_OK
(
SanityCheck
());
if
(
builder_sampler_
==
nullptr
)
{
builder_sampler_
=
std
::
make_shared
<
SequentialSampler
>
();
int64_t
num_samples
=
0
;
int64_t
start_index
=
0
;
builder_sampler_
=
std
::
make_shared
<
SequentialSampler
>
(
start_index
,
num_samples
);
}
builder_schema_
=
std
::
make_unique
<
DataSchema
>
();
RETURN_IF_NOT_OK
(
builder_schema_
->
AddColumn
(
ColDescriptor
(
"image"
,
DataType
(
DataType
::
DE_UINT8
),
TensorImpl
::
kFlexible
,
1
)));
RETURN_IF_NOT_OK
(
builder_schema_
->
AddColumn
(
ColDescriptor
(
"label"
,
DataType
(
DataType
::
DE_UINT32
),
TensorImpl
::
kFlexible
,
1
)));
*
ptr
=
std
::
make_shared
<
ManifestOp
>
(
builder_num_workers_
,
builder_rows_per_buffer_
,
builder_file_
,
builder_op_connector_size_
,
builder_num_samples
_
,
builder_decode_
,
builder_labels_to_read_
,
std
::
move
(
builder_schema_
),
std
::
move
(
builder_sampler_
),
builder_usage_
);
*
ptr
=
std
::
make_shared
<
ManifestOp
>
(
builder_num_workers_
,
builder_rows_per_buffer_
,
builder_file_
,
builder_op_connector_size_
,
builder_decode_
,
builder_labels_to_read
_
,
std
::
move
(
builder_schema_
),
std
::
move
(
builder_sampler_
),
builder_usage_
);
return
Status
::
OK
();
}
...
...
@@ -59,9 +61,9 @@ Status ManifestOp::Builder::SanityCheck() {
return
err_msg
.
empty
()
?
Status
::
OK
()
:
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
err_msg
);
}
ManifestOp
::
ManifestOp
(
int32_t
num_works
,
int32_t
rows_per_buffer
,
std
::
string
file
,
int32_t
queue_size
,
int64_t
num_samples
,
bool
decode
,
const
std
::
map
<
std
::
string
,
int32_t
>
&
class_index
,
std
::
unique_ptr
<
DataSchema
>
data_schema
,
std
::
shared_ptr
<
Sampler
>
sampler
,
std
::
string
usage
)
ManifestOp
::
ManifestOp
(
int32_t
num_works
,
int32_t
rows_per_buffer
,
std
::
string
file
,
int32_t
queue_size
,
bool
decode
,
const
std
::
map
<
std
::
string
,
int32_t
>
&
class_index
,
std
::
unique_ptr
<
DataSchema
>
data_schema
,
std
::
shared_ptr
<
Sampler
>
sampler
,
std
::
string
usage
)
:
ParallelOp
(
num_works
,
queue_size
),
rows_per_buffer_
(
rows_per_buffer
),
io_block_pushed_
(
0
),
...
...
@@ -71,8 +73,6 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
file_
(
file
),
class_index_
(
class_index
),
sampler_
(
std
::
move
(
sampler
)),
num_samples_
(
num_samples
),
num_rows_
(
0
),
decode_
(
decode
),
usage_
(
usage
),
buf_cnt_
(
0
)
{
...
...
@@ -101,8 +101,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
RETURN_IF_NOT_OK
((
*
sampler_buffer
)
->
PopRow
(
&
sample_row
));
std
::
shared_ptr
<
Tensor
>
sample_ids
=
sample_row
[
0
];
for
(
auto
itr
=
sample_ids
->
begin
<
int64_t
>
();
itr
!=
sample_ids
->
end
<
int64_t
>
();
++
itr
)
{
if
((
*
itr
)
>=
num_rows_
)
continue
;
// index out of bound, skipping
if
(
row_cnt_
>=
num_samples_
)
break
;
// enough row read, break for loop
if
((
*
itr
)
>=
num_rows_
)
continue
;
// index out of bound, skipping
keys
.
push_back
(
*
itr
);
row_cnt_
++
;
if
(
row_cnt_
%
rows_per_buffer_
==
0
)
{
...
...
@@ -269,28 +268,6 @@ Status ManifestOp::InitSampler() {
return
Status
::
OK
();
}
// Derived from RandomAccessOp
Status
ManifestOp
::
GetNumSamples
(
int64_t
*
num
)
const
{
if
(
num
==
nullptr
||
num_rows_
==
0
)
{
RETURN_STATUS_UNEXPECTED
(
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
"validation first."
);
}
(
*
num
)
=
num_samples_
;
return
Status
::
OK
();
}
// Derived from RandomAccessOp
Status
ManifestOp
::
GetNumRowsInDataset
(
int64_t
*
num
)
const
{
if
(
num
==
nullptr
||
num_rows_
==
0
)
{
RETURN_STATUS_UNEXPECTED
(
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
"validation first."
);
}
(
*
num
)
=
num_rows_
;
return
Status
::
OK
();
}
// Derived from RandomAccessOp
Status
ManifestOp
::
GetClassIds
(
std
::
map
<
int32_t
,
std
::
vector
<
int64_t
>>
*
cls_ids
)
const
{
if
(
cls_ids
==
nullptr
||
!
cls_ids
->
empty
()
||
image_labelname_
.
empty
())
{
...
...
@@ -408,7 +385,6 @@ Status ManifestOp::CountDatasetInfo() {
}
num_rows_
=
static_cast
<
int64_t
>
(
image_labelname_
.
size
());
num_samples_
=
(
num_samples_
==
0
||
num_samples_
>
num_rows_
)
?
num_rows_
:
num_samples_
;
if
(
num_rows_
==
0
)
{
RETURN_STATUS_UNEXPECTED
(
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
...
...
@@ -417,8 +393,8 @@ Status ManifestOp::CountDatasetInfo() {
return
Status
::
OK
();
}
Status
ManifestOp
::
CountTotalRows
(
const
std
::
string
&
file
,
int64_t
numSamples
,
const
py
::
dict
&
dict
,
const
std
::
string
&
usage
,
int64_t
*
count
,
int64_t
*
numClasses
)
{
Status
ManifestOp
::
CountTotalRows
(
const
std
::
string
&
file
,
const
py
::
dict
&
dict
,
const
std
::
string
&
usage
,
int64_t
*
count
,
int64_t
*
numClasses
)
{
// the logic of counting the number of samples is copied from ParseManifestFile()
std
::
map
<
std
::
string
,
int32_t
>
map
;
for
(
auto
p
:
dict
)
{
...
...
@@ -428,17 +404,15 @@ Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, c
std
::
shared_ptr
<
ManifestOp
>
op
;
*
count
=
0
;
RETURN_IF_NOT_OK
(
Builder
().
SetManifestFile
(
file
).
SetNumSamples
(
numSamples
).
SetClassIndex
(
map
).
SetUsage
(
usage
).
Build
(
&
op
));
RETURN_IF_NOT_OK
(
Builder
().
SetManifestFile
(
file
).
SetClassIndex
(
map
).
SetUsage
(
usage
).
Build
(
&
op
));
RETURN_IF_NOT_OK
(
op
->
ParseManifestFile
());
*
numClasses
=
static_cast
<
int64_t
>
(
op
->
label_index_
.
size
());
*
count
=
static_cast
<
int64_t
>
(
op
->
image_labelname_
.
size
());
*
count
=
(
*
count
<
numSamples
||
numSamples
==
0
)
?
*
count
:
numSamples
;
return
Status
::
OK
();
}
Status
ManifestOp
::
GetClassIndexing
(
const
std
::
string
&
file
,
int64_t
numSamples
,
const
py
::
dict
&
dict
,
const
std
::
string
&
usage
,
std
::
map
<
std
::
string
,
int32_t
>
*
output_class_indexing
)
{
Status
ManifestOp
::
GetClassIndexing
(
const
std
::
string
&
file
,
const
py
::
dict
&
dict
,
const
std
::
string
&
usage
,
std
::
map
<
std
::
string
,
int32_t
>
*
output_class_indexing
)
{
std
::
map
<
std
::
string
,
int32_t
>
input_class_indexing
;
for
(
auto
p
:
dict
)
{
(
void
)
input_class_indexing
.
insert
(
std
::
pair
<
std
::
string
,
int32_t
>
(
py
::
reinterpret_borrow
<
py
::
str
>
(
p
.
first
),
...
...
@@ -449,12 +423,7 @@ Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples,
*
output_class_indexing
=
input_class_indexing
;
}
else
{
std
::
shared_ptr
<
ManifestOp
>
op
;
RETURN_IF_NOT_OK
(
Builder
()
.
SetManifestFile
(
file
)
.
SetNumSamples
(
numSamples
)
.
SetClassIndex
(
input_class_indexing
)
.
SetUsage
(
usage
)
.
Build
(
&
op
));
RETURN_IF_NOT_OK
(
Builder
().
SetManifestFile
(
file
).
SetClassIndex
(
input_class_indexing
).
SetUsage
(
usage
).
Build
(
&
op
));
RETURN_IF_NOT_OK
(
op
->
ParseManifestFile
());
RETURN_IF_NOT_OK
(
op
->
CountDatasetInfo
());
uint32_t
count
=
0
;
...
...
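Note: the hunks above delete the same clamp line that every leaf op used to carry. For reference, a minimal standalone sketch of the rule being centralized into the samplers (the helper name is ours for illustration, not part of this commit):

  #include <cstdint>

  // "Special value 0" rule: 0 requests the whole dataset, and a request larger
  // than the dataset is capped at the dataset's row count.
  int64_t EffectiveNumSamples(int64_t num_samples, int64_t num_rows) {
    return (num_samples == 0 || num_samples > num_rows) ? num_rows : num_samples;
  }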
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h

@@ -86,14 +86,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
       return *this;
     }
 
-    // Setter method
-    // @param int64_t num_samples
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumSamples(int64_t num_samples) {
-      builder_num_samples_ = num_samples;
-      return *this;
-    }
-
     // Setter method
     // @param std::shared_ptr<Sampler> sampler
     // @return Builder setter method returns reference to the builder.
...
@@ -129,7 +121,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   private:
    std::shared_ptr<Sampler> builder_sampler_;
-   int64_t builder_num_samples_;
    bool builder_decode_;
    std::string builder_file_;
...
@@ -147,8 +138,8 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   // @param std::string - file list of Manifest
   // @param int32_t queue_size - connector queue size
   // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
-  ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, int64_t num_samples,
-             bool decode, const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
+  ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
+             const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
              std::shared_ptr<Sampler> sampler, std::string usage);
 
   // Destructor.
   ~ManifestOp() = default;
...
@@ -164,16 +155,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   // @return Status - The error code return
   Status operator()() override;
 
-  // Method derived from RandomAccess Op, enable Sampler to get numRows
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumSamples(int64_t *num) const override;
-
-  // Method derived from RandomAccess Op, enable Sampler to get total number of Rows in dataset
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumRowsInDataset(int64_t *num) const override;
-
   // Method derived from RandomAccess Op, enable Sampler to get all ids for each class
   // @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
   // @return Status - The error code return
...
@@ -184,12 +165,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   // @param show_all
   void Print(std::ostream &out, bool show_all) const override;
 
-  static Status CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict,
-                               const std::string &usage, int64_t *count, int64_t *numClasses);
+  static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
+                               int64_t *count, int64_t *numClasses);
 
   // Get str-to-int mapping from label name to index
-  static Status GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict,
-                                 const std::string &usage, std::map<std::string, int32_t> *output_class_indexing);
+  static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
+                                 std::map<std::string, int32_t> *output_class_indexing);
 
  private:
   // Initialize Sampler, calls sampler->Init() within
...
@@ -240,8 +221,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   std::string file_;  // file that store the information of images
   std::map<std::string, int32_t> class_index_;
   std::shared_ptr<Sampler> sampler_;
-  int64_t num_samples_;
-  int64_t num_rows_;
   bool decode_;
   std::string usage_;
   int64_t buf_cnt_;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc

@@ -91,7 +91,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
       block_reader_(block_reader),
       buffers_needed_(0),
       buf_cnt_(0),
       num_rows_(0),
       ended_worker_(0),
       buffer_water_mark_(0) {
   io_blk_queues_.Init(num_workers_, op_connector_queue_size);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc

@@ -31,7 +31,7 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
 const int32_t kMnistImageRows = 28;
 const int32_t kMnistImageCols = 28;
 
-MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr) {
+MnistOp::Builder::Builder() : builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_rows_per_buffer_ = cfg->rows_per_buffer();
...
@@ -41,7 +41,9 @@ MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr)
 Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
   if (builder_sampler_ == nullptr) {
-    builder_sampler_ = std::make_shared<SequentialSampler>();
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
   }
   builder_schema_ = std::make_unique<DataSchema>();
   RETURN_IF_NOT_OK(
...
@@ -49,9 +51,8 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
   TensorShape scalar = TensorShape::CreateScalar();
   RETURN_IF_NOT_OK(
     builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
-  *ptr = std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
-                                   builder_op_connector_size_, builder_num_samples_, std::move(builder_schema_),
-                                   std::move(builder_sampler_));
+  *ptr = std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
+                                   builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_));
   return Status::OK();
 }
...
@@ -60,17 +61,14 @@ Status MnistOp::Builder::SanityCheck() {
   std::string err_msg;
   err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : "";
   err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : "";
-  err_msg += builder_num_samples_ < 0 ? "Number of samples is set to negative\n" : "";
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
 }
 
-MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
-                 int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
+MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
+                 std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_workers, queue_size),
       buf_cnt_(0),
       row_cnt_(0),
-      num_rows_(0),
-      num_samples_(num_samples),
       folder_path_(folder_path),
       rows_per_buffer_(rows_per_buffer),
       sampler_(std::move(sampler)),
...
@@ -84,8 +82,7 @@ MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folde
 Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) {
   for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
-    if ((*itr) >= num_rows_) continue;    // index out of bound, skipping
-    if (row_cnt_ >= num_samples_) break;  // enough row read, break for loop
+    if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
     keys->push_back(*itr);
     row_cnt_++;
     if (row_cnt_ % rows_per_buffer_ == 0) {
...
@@ -219,17 +216,6 @@ Status MnistOp::InitSampler() {
   return Status::OK();
 }
 
-// Derived from RandomAccessOp
-Status MnistOp::GetNumSamples(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_samples_;
-  return Status::OK();
-}
-
 // Derived from RandomAccessOp
 Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
   if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
...
@@ -364,7 +350,6 @@ Status MnistOp::ParseMnistData() {
   }
   image_label_pairs_.shrink_to_fit();
   num_rows_ = image_label_pairs_.size();
-  num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
   return Status::OK();
 }
...
@@ -414,11 +399,11 @@ Status MnistOp::LaunchThreadsAndInitOp() {
   return Status::OK();
 }
 
-Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count) {
+Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
   // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader()
   std::shared_ptr<MnistOp> op;
   *count = 0;
-  RETURN_IF_NOT_OK(Builder().SetDir(dir).SetNumSamples(numSamples).Build(&op));
+  RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op));
   RETURN_IF_NOT_OK(op->WalkAllFiles());
...
@@ -440,19 +425,6 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64
     label_reader.close();
   }
-  *count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
   return Status::OK();
 }
 
-// Derived from RandomAccessOp
-Status MnistOp::GetNumRowsInDataset(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_rows_;
-  return Status::OK();
-}
 }  // namespace dataset
...
...
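With the sample cap gone from TraversalSampleIds, the op's loop reduces to dropping out-of-range ids. A standalone sketch of that simplified filtering (hypothetical helper mirroring the loop above):

  #include <cstdint>
  #include <vector>

  // Collect valid sample ids; the sampler now owns the sample cap, so the op
  // only skips ids that fall outside the dataset.
  std::vector<int64_t> FilterSampleIds(const std::vector<int64_t> &sample_ids, int64_t num_rows) {
    std::vector<int64_t> keys;
    for (int64_t id : sample_ids) {
      if (id >= num_rows) continue;  // index out of bound, skipping
      keys.push_back(id);
    }
    return keys;
  }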
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h

@@ -78,14 +78,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
       return *this;
     }
 
-    // Setter method
-    // @param int64_t num_samples
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumSamples(int64_t num_samples) {
-      builder_num_samples_ = num_samples;
-      return *this;
-    }
-
     // Setter method
     // @param std::shared_ptr<Sampler> sampler
     // @return Builder setter method returns reference to the builder.
...
@@ -114,7 +106,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   private:
    std::string builder_dir_;
    int32_t builder_num_workers_;
-   int64_t builder_num_samples_;
    int32_t builder_rows_per_buffer_;
    int32_t builder_op_connector_size_;
    std::shared_ptr<Sampler> builder_sampler_;
...
@@ -126,11 +117,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   // @param int32_t rows_per_buffer - number of images (rows) in each buffer
   // @param std::string folder_path - dir directory of mnist
   // @param int32_t queue_size - connector queue size
-  // @param int64_t num_samples - number of samples to read
   // @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
   // @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read
-  MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
-          int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+  MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
+          std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
 
   // Destructor.
   ~MnistOp() = default;
...
@@ -146,16 +136,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   // @return Status - The error code return
   Status operator()() override;
 
-  // Method derived from RandomAccess Op, enable Sampler to get numRows
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumSamples(int64_t *num) const override;
-
-  // Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumRowsInDataset(int64_t *num) const override;
-
   // Method derived from RandomAccess Op, enable Sampler to get all ids for each class
   // @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
   // @return Status - The error code return
...
@@ -167,11 +147,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   void Print(std::ostream &out, bool show_all) const override;
 
   // Function to count the number of samples in the MNIST dataset
-  // @param dir path to the MNSIT directory
-  // @param numSamples maximum number of samples requested
+  // @param dir path to the MNIST directory
   // @param count output arg that will hold the minimum of the actual dataset size and numSamples
   // @return
-  static Status CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count);
+  static Status CountTotalRows(const std::string &dir, int64_t *count);
 
  private:
   // Initialize Sampler, calls sampler->Init() within
...
@@ -244,9 +223,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   int64_t buf_cnt_;
   int64_t row_cnt_;
-  int64_t num_rows_;  // total number of images in Mnist
   WaitPost wp_;
-  int64_t num_samples_;
   std::string folder_path_;  // directory of image folder
   int32_t rows_per_buffer_;
   std::shared_ptr<Sampler> sampler_;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt

@@ -8,6 +8,5 @@ add_library(engine-datasetops-source-sampler OBJECT
     sampler.cc
     sequential_sampler.cc
     subset_random_sampler.cc
-    subset_sampler.cc
     weighted_random_sampler.cc
     )
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc

@@ -23,8 +23,9 @@
 namespace mindspore {
 namespace dataset {
-DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle, uint32_t seed)
-    : Sampler(),
+DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
+                                       uint32_t seed)
+    : Sampler(num_samples, std::numeric_limits<int64_t>::max()),
       cnt_(0),
       seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
       device_id_(dev_id),
...
@@ -32,6 +33,11 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu
       shuffle_(shuffle) {}
 
 Status DistributedSampler::InitSampler() {
+  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
+  // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
+  if (num_samples_ == 0 || num_samples_ > num_rows_) {
+    num_samples_ = num_rows_;
+  }
   CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n");
   CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
   CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h

@@ -27,10 +27,11 @@ namespace mindspore {
 namespace dataset {
 class DistributedSampler : public Sampler {
  public:
-  // @param int64_t numDev
-  // @param int64_t devId
+  // @param num_samples
+  // @param int64_t num_dev
+  // @param int64_t dev_id
   // @param bool shuffle
-  DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle = true,
+  DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
                      uint32_t seed = std::numeric_limits<uint32_t>::max());
 
   // default destructor
...
...
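For context on the (num_samples, num_dev, dev_id) triple, here is a sketch of the ids one shard ends up with, assuming the round-robin sharding DistributedSampler implements; the helper is illustrative only and is not part of the commit:

  #include <cstdint>
  #include <vector>

  // Device k takes rows k, k + num_devices, k + 2 * num_devices, ... (assumption:
  // round-robin sharding over [0, num_rows)).
  std::vector<int64_t> ShardIds(int64_t num_rows, int64_t num_devices, int64_t device_id) {
    std::vector<int64_t> ids;
    for (int64_t i = device_id; i < num_rows; i += num_devices) {
      ids.push_back(i);
    }
    return ids;
  }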
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc

@@ -20,12 +20,11 @@
 namespace mindspore {
 namespace dataset {
-PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer)
-    : Sampler(samples_per_buffer),
+PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer)
+    : Sampler(num_samples, samples_per_buffer),
       shuffle_(shuffle),
       seed_(GetSeed()),
       next_id_(0),
-      num_pk_samples_(0),
       samples_per_class_(val) {}
 
 Status PKSampler::InitSampler() {
...
@@ -36,22 +35,34 @@ Status PKSampler::InitSampler() {
     }
   }
   rnd_.seed(seed_++);
-  num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
-  samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_;
-  num_samples_ = num_pk_samples_;
+
+  // The special handshake gives the list of classes and id's, but it did not set the num_rows_ to
+  // capture the total number of possible sample ids.
+  // Compute that here for this case to find the total number of samples that are available to return.
+  // (in this case, samples per class * total classes).
+  num_rows_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
+
+  // The user may have chosen to sample less than the total amount.
+  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
+  // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
+  if (num_samples_ == 0 || num_samples_ > num_rows_) {
+    num_samples_ = num_rows_;
+  }
+  samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
+
   if (shuffle_ == true) {
     std::shuffle(labels_.begin(), labels_.end(), rnd_);
   } else {
     std::sort(labels_.begin(), labels_.end());
   }
-  CHECK_FAIL_RETURN_UNEXPECTED(num_pk_samples_ > 0, "num_class or K (num samples per class) is not positive");
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_class or K (num samples per class) is not positive");
   return Status::OK();
 }
 
 Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
-  if (next_id_ > num_pk_samples_ || num_pk_samples_ == 0) {
+  if (next_id_ > num_samples_ || num_samples_ == 0) {
     RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler");
-  } else if (next_id_ == num_pk_samples_) {
+  } else if (next_id_ == num_samples_) {
     (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
   } else {
     if (HasChildSampler()) {
...
@@ -60,8 +71,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
     (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
     std::shared_ptr<Tensor> sample_ids;
-    int64_t last_id =
-      (samples_per_buffer_ + next_id_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_ + next_id_;
+    int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
     RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_));
     int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer());
     while (next_id_ < last_id) {
...
@@ -85,7 +95,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
 }
 
 Status PKSampler::Reset() {
-  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late");
+  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
   next_id_ = 0;
   rnd_.seed(seed_++);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h

@@ -28,10 +28,11 @@ namespace mindspore {
 namespace dataset {
 class PKSampler : public Sampler {  // NOT YET FINISHED
  public:
-  // @param int64_t kVal
+  // @param num_samples - the number of samples to draw. value of 0 means to take the full amount
+  // @param int64_t val
   // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2
   // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
-  explicit PKSampler(int64_t val, bool shuffle = false,
+  explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle,
                      int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
 
   // default destructor
...
@@ -42,8 +43,9 @@ class PKSampler : public Sampler {  // NOT YET FINISHED
   // @return - The error code return
   Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
 
-  // first handshake between StorageOp and Sampler
-  // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
+  // first handshake between leaf source op and Sampler. This func will determine the amount of data
+  // in the dataset that we can sample from.
+  // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
   // @return
   Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
...
@@ -58,7 +60,6 @@ class PKSampler : public Sampler {  // NOT YET FINISHED
   bool shuffle_;
   uint32_t seed_;
   int64_t next_id_;
-  int64_t num_pk_samples_;
   int64_t samples_per_class_;
   std::mt19937 rnd_;
   std::vector<int64_t> labels_;
...
...
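The PK ("P classes, K per class") arithmetic that InitSampler now routes through the shared num_rows_/num_samples_ pair is small enough to show standalone. A sketch, with an illustrative helper name of our own:

  #include <cstdint>

  // With P classes and K samples per class the derived row count is P * K;
  // the generic 0-means-all clamp is then applied to the user's request.
  int64_t PkEffectiveSamples(int64_t num_classes, int64_t k_per_class, int64_t requested) {
    int64_t num_rows = num_classes * k_per_class;
    return (requested == 0 || requested > num_rows) ? num_rows : requested;
  }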
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc

@@ -20,8 +20,8 @@
 namespace mindspore {
 namespace dataset {
-PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer)
-    : Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
+PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
+    : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
 
 Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
   if (need_to_reset_) {
...
@@ -65,6 +65,11 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
 Status PythonSampler::InitSampler() {
   CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0");
+  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
+  // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
+  if (num_samples_ == 0 || num_samples_ > num_rows_) {
+    num_samples_ = num_rows_;
+  }
   {
     py::gil_scoped_acquire gil_acquire;
     if (Py_IsInitialized() == 0) {
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h

@@ -26,8 +26,11 @@ namespace dataset {
 class PythonSampler : public Sampler {
  public:
   // Constructor
-  // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
-  explicit PythonSampler(py::object py_sampler_instance,
+  // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the
+  //     data from the dataset.
+  // @param py_sampler_instance - the python instance of the sampler
+  // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
+  explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance,
                          int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
 
   // Destructor.
...
...
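The InitSampler body above guards every touch of the user's Python object with the GIL and an interpreter-liveness check. A minimal sketch of that guard pattern in isolation (stub function name is ours; pybind11 types as in the source):

  #include <pybind11/pybind11.h>
  namespace py = pybind11;

  // Acquire the GIL before any py::object access, and bail out if the
  // interpreter has already been torn down.
  void WithPythonSampler(const py::object &py_sampler_instance) {
    py::gil_scoped_acquire gil_acquire;  // hold the GIL for the scope below
    if (Py_IsInitialized() == 0) {
      return;  // interpreter shut down; nothing is safe to call
    }
    // ... call into py_sampler_instance here ...
  }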
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc

@@ -22,12 +22,11 @@
 namespace mindspore {
 namespace dataset {
-RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples,
-                             int64_t samples_per_buffer)
-    : Sampler(samples_per_buffer),
+RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
+                             int64_t samples_per_buffer)
+    : Sampler(num_samples, samples_per_buffer),
       seed_(GetSeed()),
       replacement_(replacement),
-      user_num_samples_(num_samples),
       next_id_(0),
       reshuffle_each_epoch_(reshuffle_each_epoch),
       dist(nullptr) {}
...
@@ -70,27 +69,25 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
 }
 
 Status RandomSampler::InitSampler() {
+  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive.");
+  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
+  // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
+  if (num_samples_ == 0 || num_samples_ > num_rows_) {
+    num_samples_ = num_rows_;
+  }
-  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
-  samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
   rnd_.seed(seed_);
 
   if (replacement_ == false) {
-    num_samples_ = std::min(num_samples_, num_rows_);
-    num_samples_ = std::min(num_samples_, user_num_samples_);
     shuffled_ids_.reserve(num_rows_);
     for (int64_t i = 0; i < num_rows_; i++) {
       shuffled_ids_.push_back(i);
     }
     std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
   } else {
-    num_samples_ = std::min(num_samples_, user_num_samples_);
     dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
   }
 
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive.");
+  samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
   return Status::OK();
 }
...
@@ -119,7 +116,6 @@ void RandomSampler::Print(std::ostream &out, bool show_all) const {
   out << "(sampler): RandomSampler\n";
   if (show_all) {
-    out << "user_num_samples_: " << user_num_samples_ << '\n';
     out << "num_samples_: " << num_samples_ << '\n';
     out << "next_id_: " << next_id_ << '\n';
   }
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h

@@ -27,11 +27,11 @@ namespace dataset {
 class RandomSampler : public Sampler {
  public:
   // Constructor
+  // @param int64_t num_samples - number samples to draw
   // @param bool replacement - put he id back / or not after a sample
-  // @param int64_t numSamples - number samples to draw
+  // @param reshuffle_each_epoch - T/F to reshuffle after epoch
   // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
-  explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true,
-                         int64_t num_samples = std::numeric_limits<int64_t>::max(),
+  explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
                          int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
 
   // Destructor.
...
@@ -55,7 +55,6 @@ class RandomSampler : public Sampler {
  private:
   uint32_t seed_;
   bool replacement_;
-  int64_t user_num_samples_;
   std::vector<int64_t> shuffled_ids_;  // only used for NO REPLACEMENT
   int64_t next_id_;
   std::mt19937 rnd_;
...
...
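The two drawing modes InitSampler sets up (prefix of a shuffle without replacement, uniform draws with replacement) can be shown standalone. A sketch with an illustrative helper name, assuming num_samples has already been clamped as above:

  #include <algorithm>
  #include <cstdint>
  #include <numeric>
  #include <random>
  #include <vector>

  // Without replacement: shuffle [0, num_rows) and take the first num_samples ids.
  // With replacement: draw num_samples ids from a uniform integer distribution.
  std::vector<int64_t> DrawRandomIds(int64_t num_rows, int64_t num_samples, bool replacement, uint32_t seed) {
    std::mt19937 rnd(seed);
    std::vector<int64_t> ids;
    if (!replacement) {
      std::vector<int64_t> shuffled(num_rows);
      std::iota(shuffled.begin(), shuffled.end(), 0);
      std::shuffle(shuffled.begin(), shuffled.end(), rnd);
      ids.assign(shuffled.begin(), shuffled.begin() + num_samples);  // requires num_samples <= num_rows
    } else {
      std::uniform_int_distribution<int64_t> dist(0, num_rows - 1);
      for (int64_t i = 0; i < num_samples; i++) {
        ids.push_back(dist(rnd));
      }
    }
    return ids;
  }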
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc

@@ -19,8 +19,25 @@
 namespace mindspore {
 namespace dataset {
-Sampler::Sampler(int64_t samples_per_buffer)
-    : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
+Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
+  // The sampler base class itself does not compute it's own num_rows_ value.
+  // Instead, this value is computed by the derived leaf op during it's own initialization
+  // after it has interacted with it's storage layers.
+  // Here, it is just a getter method to return the value. However, it is invalid if there is
+  // not a value set for this count, so generate a failure if that is the case.
+  if (num == nullptr || num_rows_ == 0) {
+    RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed it's num rows yet.");
+  }
+  (*num) = num_rows_;
+  return Status::OK();
+}
+
+Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer)
+    : DatasetOp(0), num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer),
+      col_desc_(nullptr) {}
 
 Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
   std::shared_ptr<Sampler> child_sampler;
...
@@ -36,10 +53,10 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
   }
 
   CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
-  RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
+
+  // If there's a child sampler, set the row count to be it's sample count
   if (HasChildSampler()) {
-    int64_t child_num_samples = child_sampler->num_samples();
-    num_rows_ = child_num_samples;
+    num_rows_ = child_sampler->num_samples_;
   } else {
     RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
   }
...
@@ -105,7 +122,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
 }
 
 Status Sampler::SetNumSamples(int64_t num_samples) {
-  CHECK_FAIL_RETURN_UNEXPECTED(num_samples > 0, "num_samples is negative or 0");
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "num_samples is negative");
   num_samples_ = num_samples;
   return Status::OK();
 }
...
@@ -116,6 +133,16 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
   return Status::OK();
 }
 
+// inline op doesn't have it's own consumer, it's assigned from parent
+int32_t Sampler::num_consumers() const {
+  if (parent_.empty() || parent_[0] == nullptr) {
+    MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0.";
+    return 0;
+  } else {
+    return parent_[0]->num_consumers();
+  }
+}
+
 Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
   if (child == nullptr) {
     return Status::OK();
...
@@ -155,5 +182,14 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
   return Status::OK();
 }
 
+// inline op doesn't have it's own producers, it's assigned from child
+int32_t Sampler::num_producers() const {
+  if (child_.empty() || child_[0] == nullptr) {
+    MS_LOG(WARNING) << "Sampler with no child, num_producers is 0.";
+    return 0;
+  } else {
+    return child_[0]->num_producers();
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
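The net effect of the sampler.cc changes is a simpler handshake: the sampler pulls only a row count from the leaf op (or a child sampler) and applies its own sample cap. A minimal compilable sketch with stubbed types (not the real ops) that shows the flow:

  #include <cstdint>

  struct StubRandomAccessOp {
    int64_t num_rows = 100;  // assumption: a dataset with 100 rows
  };

  struct StubSampler {
    int64_t num_rows_ = 0;
    int64_t num_samples_ = 0;  // 0 requests the full dataset

    void HandshakeRandomAccessOp(const StubRandomAccessOp *op) {
      num_rows_ = op->num_rows;  // the row count comes from the leaf op...
      InitSampler();
    }
    void InitSampler() {
      if (num_samples_ == 0 || num_samples_ > num_rows_) {
        num_samples_ = num_rows_;  // ...the sample cap is the sampler's own business
      }
    }
  };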
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h

@@ -33,23 +33,10 @@ namespace dataset {
 // must inherit from if those leaf operator wish to support sampling.
 class RandomAccessOp {
  public:
-  // Sampler get numRows from StorageOp
-  // @param int64_t num - return number of rows, normally num of samples
-  // @return - The error code return
-  virtual Status GetNumSamples(int64_t *num_samples) const {
-    // CI complains num_samples not used if the following line is not added
-    CHECK_FAIL_RETURN_UNEXPECTED(num_samples != nullptr, "num_samples == nullptr");
-    RETURN_STATUS_UNEXPECTED("function GetNumSamples needs to overridden to support this sampler");
-  }
-
-  // Sampler get number of rows in the dataset!
+  // Sampler get number of rows in the dataset
   // @param int64_t num - return number of rows for this dataset
   // @return - The error code return
-  virtual Status GetNumRowsInDataset(int64_t *num_rows) const {
-    // CI complains num_rows not used if the following line is not added
-    CHECK_FAIL_RETURN_UNEXPECTED(num_rows != nullptr, "num_rows == nullptr");
-    RETURN_STATUS_UNEXPECTED("function GetNumRowsInDataset needs to overridden to support this sampler");
-  }
+  Status GetNumRowsInDataset(int64_t *num_rows) const;
 
   // sampler gets label , imageIds from storageOp, this function is unique to PK
   // @param std::map<int64_t, std::vector<int64_t>> * map
...
@@ -60,12 +47,20 @@ class RandomAccessOp {
   // default destructor
   virtual ~RandomAccessOp() = default;
 
+ protected:
+  // The amount of rows in the dataset itself. This is the before-sampling value, the
+  // total count of rows. A sampler may choose to sample less than this amount.
+  int64_t num_rows_;
 };
 
 class Sampler : public DatasetOp {
  public:
   // Constructor
+  // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
+  //     indicates that the sampler should produce the complete set of ids.
   // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
-  explicit Sampler(int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
+  explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);
 
   // default destructor
   ~Sampler() = default;
...
@@ -84,33 +79,36 @@ class Sampler : public DatasetOp {
   // @return - The error code return
   Status Reset() override = 0;
 
-  // setter function for num_rows_
-  Status SetNumRowsInDataset(int64_t num_rows);
-
-  // setter function for num_samples_
-  Status SetNumSamples(int64_t num_samples);
-
-  int64_t num_samples() { return num_samples_; }
-
-  // first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
-  // @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
+  // first handshake between leaf source op and Sampler. This func will determine the amount of data
+  // in the dataset that we can sample from.
+  // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
   // @return
   virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
 
   // initialize sampler and perform checks on certain vars
   virtual Status InitSampler() { return Status::OK(); }
 
-  // Not meant to be called
+  // setter for num samples
+  // @param num_samples - the number of samples to assign.
+  // @return status error code
+  Status SetNumSamples(int64_t num_samples);
+
+  // setter for num or records in the dataset
+  // @param num_rows - the number of records
+  // @return status error code
+  Status SetNumRowsInDataset(int64_t num_rows);
+
+  // Sampler is an inlined op and has no workers. Producers and consumers are computed.
+  // @return
   int32_t num_workers() const final { return 0; }
 
-  // Not meant to be called
+  // Identify num consumers (inlined op)
   // @return
-  int32_t num_consumers() const final { return 0; }
+  int32_t num_consumers() const final;
 
-  // Not meant to be called
+  // Identify num producers (inlined op)
   // @return
-  int32_t num_producers() const final { return 0; }
+  int32_t num_producers() const final;
 
   // Not meant to be called!
   // @return - The error code return
...
@@ -151,10 +149,11 @@ class Sampler : public DatasetOp {
   // output. Otherwise, num_rows_ is the number of rows in the dataset.
   int64_t num_rows_;
 
-  // Number of ids this sampler will return.
+  // The user may want to sample less than the full amount of data. num_samples_ reduces the number
+  // of id's returned as request by the user. Derived classes will choose how to sample the smaller
+  // amount.
   int64_t num_samples_;
 
   // The max number of ids a DataBuffer returned by this sampler will contain.
   int64_t samples_per_buffer_;
   std::unique_ptr<ColDescriptor> col_desc_;
   std::unique_ptr<DataBuffer> child_ids_;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc

@@ -20,34 +20,42 @@
 namespace mindspore {
 namespace dataset {
-SequentialSampler::SequentialSampler(int64_t samples_per_buffer) : Sampler(samples_per_buffer), next_id_(0) {}
+SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
+    : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}
 
 Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
-  if (next_id_ > num_samples_) {
-    RETURN_STATUS_UNEXPECTED("Sequential Sampler Internal Error");
-  } else if (next_id_ == num_samples_) {
+  if (id_count_ > num_samples_) {
+    RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error");
+  } else if (id_count_ == num_samples_) {
     (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
   } else {
     if (HasChildSampler()) {
       RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
     }
 
-    (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
+    (*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone);
     std::shared_ptr<Tensor> sampleIds;
-    int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
-    RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_));
+
+    // Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for
+    // samples per buffer though.
+    int64_t remaining_ids = num_samples_ - id_count_;
+    int64_t num_elements = std::min(remaining_ids, samples_per_buffer_);
+
+    RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements));
     int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
-    while (next_id_ < lastId) {
-      int64_t sampled_id = next_id_;
+    for (int64_t i = 0; i < num_elements; i++) {
+      int64_t sampled_id = current_id_;
       if (HasChildSampler()) {
        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
      }
 
       *idPtr = sampled_id;
-      next_id_++;
+      current_id_++;  // Move the current id to the next one in the sequence
       idPtr++;
     }
 
+    id_count_ += num_elements;  // Count the packed ids towards our overall sample count
+
     TensorRow row(1, sampleIds);
     (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
   }
...
@@ -55,19 +63,24 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
 }
 
 Status SequentialSampler::InitSampler() {
-  num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_;  // if num_samples < 0, try if num_rows is set
-  if (HasChildSampler()) {
-    num_samples_ = std::min(num_samples_, num_rows_);
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
+  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0, "num_samples < 0\n");
+
+  // Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample
+  // the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data.
+  int64_t available_row_count = num_rows_ - start_index_;
+  if (num_samples_ == 0 || num_samples_ > available_row_count) {
+    num_samples_ = available_row_count;
+  }
   CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
   samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
   return Status::OK();
 }
 
 Status SequentialSampler::Reset() {
-  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
-  next_id_ = 0;
+  CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late");
+  current_id_ = start_index_;
+  id_count_ = 0;
 
   if (HasChildSampler()) {
     RETURN_IF_NOT_OK(child_[0]->Reset());
...
...
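The reworked SequentialSampler id stream is easy to show end-to-end: ids run upward from start_index, and the 0-means-all request is capped by the rows remaining past start_index. A standalone sketch (illustrative helper name; assumes start_index < num_rows, matching the CHECK above):

  #include <cstdint>
  #include <numeric>
  #include <vector>

  // Produce the sequential id range [start_index, start_index + num_samples).
  std::vector<int64_t> SequentialIds(int64_t num_rows, int64_t start_index, int64_t num_samples) {
    int64_t available = num_rows - start_index;
    if (num_samples == 0 || num_samples > available) {
      num_samples = available;  // 0 means "all rows from start_index to the end"
    }
    std::vector<int64_t> ids(num_samples);
    std::iota(ids.begin(), ids.end(), start_index);
    return ids;
  }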
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h

@@ -26,8 +26,12 @@ namespace dataset {
 class SequentialSampler : public Sampler {
  public:
   // Constructor
+  // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the
+  //     full amount of ids from the dataset
+  // @param start_index - The starting index value
   // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
-  explicit SequentialSampler(int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
+  explicit SequentialSampler(int64_t num_samples, int64_t start_index,
+                             int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
 
   // Destructor.
   ~SequentialSampler() = default;
...
@@ -48,7 +52,9 @@ class SequentialSampler : public Sampler {
   void Print(std::ostream &out, bool show_all) const override;
 
  private:
-  int64_t next_id_;
+  int64_t current_id_;   // The id sequencer. Each new id increments from this
+  int64_t start_index_;  // The starting id. current_id_ begins from here.
+  int64_t id_count_;     // An internal counter that tracks how many ids have been produced
 };
 }  // namespace dataset
 }  // namespace mindspore
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc

@@ -27,22 +27,28 @@
 namespace mindspore {
 namespace dataset {
 // Constructor.
-SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t samples_per_buffer)
-    : Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
+SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
+                                         int64_t samples_per_buffer)
+    : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
 
 // Initialized this Sampler.
 Status SubsetRandomSampler::InitSampler() {
   CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
-  num_samples_ = indices_.size();
+
+  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
+  // In this case, the id's are provided by the user. Cap the num_samples on the number of id's given.
+  if (num_samples_ == 0 || num_samples_ > static_cast<int64_t>(indices_.size())) {
+    num_samples_ = static_cast<int64_t>(indices_.size());
+  }
 
   // Initialize random generator with seed from config manager
   rand_gen_.seed(GetSeed());
 
-  if (static_cast<size_t>(samples_per_buffer_) > indices_.size()) {
-    samples_per_buffer_ = static_cast<int64_t>(indices_.size());
+  if (samples_per_buffer_ > num_samples_) {
+    samples_per_buffer_ = num_samples_;
   }
 
+  // num_samples_ could be smaller than the total number of input id's.
+  // We will shuffle the full set of id's, but only select the first num_samples_ of them later.
   std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
 
   return Status::OK();
...
@@ -68,7 +74,7 @@ Status SubsetRandomSampler::Reset() {
 // Get the sample ids.
 Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
   // All samples have been drawn
-  if (sample_id_ == indices_.size()) {
+  if (sample_id_ == num_samples_) {
     (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
   } else {
     if (HasChildSampler()) {
...
@@ -80,8 +86,8 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
     int64_t last_id = sample_id_ + samples_per_buffer_;
     // Handling the return all samples at once, and when last draw is not a full batch.
-    if (static_cast<size_t>(last_id) > indices_.size()) {
-      last_id = indices_.size();
+    if (last_id > num_samples_) {
+      last_id = num_samples_;
     }
 
     // Allocate tensor
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h

@@ -28,10 +28,11 @@ namespace dataset {
 class SubsetRandomSampler : public Sampler {
  public:
   // Constructor.
+  // @param num_samples The number of samples to draw. 0 for the full amount.
   // @param indices List of indices from where we will randomly draw samples.
   // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
   //     When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
-  explicit SubsetRandomSampler(const std::vector<int64_t> &indices,
+  explicit SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
                                std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
 
   // Destructor.
...
...
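SubsetRandomSampler's new selection rule also fits in a few lines: shuffle the whole user-provided index list, then keep only the first num_samples entries. A standalone sketch (helper name ours):

  #include <algorithm>
  #include <cstdint>
  #include <random>
  #include <vector>

  // Shuffle all given indices, then select the first num_samples of them (0 keeps all).
  std::vector<int64_t> SubsetRandomIds(std::vector<int64_t> indices, int64_t num_samples, uint32_t seed) {
    if (num_samples == 0 || num_samples > static_cast<int64_t>(indices.size())) {
      num_samples = static_cast<int64_t>(indices.size());
    }
    std::mt19937 rand_gen(seed);
    std::shuffle(indices.begin(), indices.end(), rand_gen);
    indices.resize(num_samples);
    return indices;
  }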
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc (deleted, 100644 → 0)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"

#include <memory>
#include <string>

#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"

namespace mindspore {
namespace dataset {
// Constructor.
SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
    : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}

Status SubsetSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");

  num_samples_ = subset_size_;

  return Status::OK();
}

Status SubsetSampler::Reset() {
  current_id_ = 0;
  if (HasChildSampler()) {
    RETURN_IF_NOT_OK(child_[0]->Reset());
  }
  return Status::OK();
}

Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
  if (current_id_ > subset_size_) {
    RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error");
  } else if (current_id_ == subset_size_) {
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
    if (HasChildSampler()) {
      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
    }

    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
    std::shared_ptr<Tensor> sampled_ids;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_));

    int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer());

    while (current_id_ < subset_size_) {
      int64_t sampled_id = start_index_ + current_id_;
      if (HasChildSampler()) {
        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
      }

      *(sampled_ids_start_addr + current_id_) = sampled_id;
      current_id_++;
    }

    TensorRow sampled_ids_row(1, sampled_ids);
    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, sampled_ids_row));
  }

  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h (deleted, 100644 → 0)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_

#include <memory>
#include <vector>

#include "dataset/engine/datasetops/source/sampler/sampler.h"

namespace mindspore {
namespace dataset {
class SubsetSampler : public Sampler {
 public:
  // Constructor.
  // @param start_index The index we start sampling from.
  explicit SubsetSampler(int64_t start_index, int64_t subset_size);

  // Destructor.
  ~SubsetSampler() = default;

  // Initialize the sampler.
  // @return Status
  Status InitSampler() override;

  // Reset the internal variable to the initial state and reshuffle the indices.
  // @return Status
  Status Reset() override;

  // Get the sample ids.
  // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
  // @note the sample ids (int64_t) will be placed in one Tensor.
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;

 private:
  int64_t start_index_;
  int64_t subset_size_;
  int64_t current_id_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc

@@ -27,25 +27,28 @@
 namespace mindspore {
 namespace dataset {
 // Constructor.
-WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement,
-                                             int64_t samples_per_buffer)
-    : Sampler(samples_per_buffer),
+WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
+                                             int64_t samples_per_buffer)
+    : Sampler(num_samples, samples_per_buffer),
       weights_(weights),
       replacement_(replacement),
       sample_id_(0),
-      buffer_id_(0),
-      user_num_samples_(num_samples) {}
+      buffer_id_(0) {}
 
 // Initialized this Sampler.
 Status WeightedRandomSampler::InitSampler() {
-  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive");
+  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
+  // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
+  if (num_samples_ == 0 || num_samples_ > num_rows_) {
+    num_samples_ = num_rows_;
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive");
   CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
-  num_samples_ = user_num_samples_;
 
   // Initialize random generator with seed from config manager
   rand_gen_.seed(GetSeed());
 
-  samples_per_buffer_ = (samples_per_buffer_ > user_num_samples_) ? user_num_samples_ : samples_per_buffer_;
+  samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
 
   if (!replacement_) {
     exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
...
@@ -67,8 +70,8 @@ void WeightedRandomSampler::InitOnePassSampling() {
   }
 
   // Partial sort the first `numSamples` elements.
-  std::partial_sort(val_idx.begin(), val_idx.begin() + user_num_samples_, val_idx.end());
-  for (int64_t i = 0; i < user_num_samples_; i++) {
+  std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end());
+  for (int64_t i = 0; i < num_samples_; i++) {
     onepass_ids_.push_back(val_idx[i].second);
   }
...
@@ -98,11 +101,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
       "number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
   }
 
-  if (!replacement_ && (weights_.size() < static_cast<size_t>(user_num_samples_))) {
+  if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
     RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples");
   }
 
-  if (sample_id_ == user_num_samples_) {
+  if (sample_id_ == num_samples_) {
     (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
   } else {
     if (HasChildSampler()) {
...
@@ -114,8 +117,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
     int64_t last_id = sample_id_ + samples_per_buffer_;
     // Handling the return all samples at once, and when last draw is not a full batch.
-    if (last_id > user_num_samples_) {
-      last_id = user_num_samples_;
+    if (last_id > num_samples_) {
+      last_id = num_samples_;
     }
 
     // Allocate tensor.
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h

@@ -29,12 +29,12 @@ namespace dataset {
 class WeightedRandomSampler : public Sampler {
  public:
   // Constructor.
-  // @param weights A lift of sample weights.
   // @param num_samples Number of samples to be drawn.
+  // @param weights A lift of sample weights.
   // @param replacement Determine if samples are drawn with/without replacement.
   // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
   //     When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
-  WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement = true,
+  WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
                         int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
 
   // Destructor.
...
@@ -69,9 +69,6 @@ class WeightedRandomSampler : public Sampler {
   // Random engine and device
   std::mt19937 rand_gen_;
 
-  // num_samples from user
-  int64_t user_num_samples_;
-
   // Discrete distribution for generating weighted random numbers with replacement.
   std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_;
...
...
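For the with-replacement path, the discrete_dist_ member above is doing the real work: std::discrete_distribution picks index i with probability weights[i] / sum(weights). A standalone sketch of that draw loop (helper name ours):

  #include <cstdint>
  #include <random>
  #include <vector>

  // Draw num_samples ids with replacement, weighted by the given per-row weights.
  std::vector<int64_t> WeightedIdsWithReplacement(const std::vector<double> &weights, int64_t num_samples,
                                                  uint32_t seed) {
    std::mt19937 rand_gen(seed);
    std::discrete_distribution<int64_t> discrete_dist(weights.begin(), weights.end());
    std::vector<int64_t> ids;
    for (int64_t i = 0; i < num_samples; i++) {
      ids.push_back(discrete_dist(rand_gen));
    }
    return ids;
  }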
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc

@@ -33,7 +33,7 @@
 namespace mindspore {
 namespace dataset {
 TextFileOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+    : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();
...
@@ -62,7 +62,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
     builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
   std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
-    builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
+    builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
     std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
     builder_num_devices_, builder_device_id_);
   RETURN_IF_NOT_OK(text_file_op->Init());
...
@@ -71,14 +71,14 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
   return Status::OK();
 }
 
-TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
+TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
                        std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
                        int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
     : ParallelOp(num_workers, op_connector_size),
       device_id_(device_id),
       num_devices_(num_device),
       rows_per_buffer_(rows_per_buffer),
-      num_samples_(num_samples),
+      total_rows_(total_rows),
       text_files_list_(std::move(text_files_list)),
       shuffle_files_(shuffle_files),
       data_schema_(std::move(schema)),
...
@@ -104,9 +104,9 @@ void TextFileOp::Print(std::ostream &out, bool show_all) const {
   // Call the super class for displaying any common detailed info
   ParallelOp::Print(out, show_all);
   // Then show any custom derived-internal stuff
-  out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
-      << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
-      << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nText files list:\n";
+  out << "\nRows per buffer: " << rows_per_buffer_ << "\nRow count: " << total_rows_
+      << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
+      << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nText files list:\n";
   for (int i = 0; i < text_files_list_.size(); ++i) {
     out << " " << text_files_list_[i];
   }
...
@@ -404,9 +404,9 @@ Status TextFileOp::operator()() {
     RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
     if (buffer->eoe()) {
       workers_done++;
-    } else if (num_samples_ == 0 || rows_read < num_samples_) {
-      if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
-        int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
+    } else if (total_rows_ == 0 || rows_read < total_rows_) {
+      if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) {
+        int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read);
         RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
       }
       rows_read += buffer->NumRows();
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h

@@ -107,8 +107,8 @@ class TextFileOp : public ParallelOp {
     // Setter method.
     // @return Builder - setter method returns reference to the builder.
-    Builder &SetNumSamples(int64_t num_samples) {
-      builder_num_samples_ = num_samples;
+    Builder &SetTotalRows(int64_t total_rows) {
+      builder_total_rows_ = total_rows;
       return *this;
     }
@@ -118,7 +118,7 @@ class TextFileOp : public ParallelOp {
     int32_t builder_num_workers_;
     int32_t builder_op_connector_size_;
     int64_t builder_rows_per_buffer_;
-    int64_t builder_num_samples_;
+    int64_t builder_total_rows_;
     int32_t builder_worker_connector_size_;
     std::vector<std::string> builder_text_files_list_;
     bool builder_shuffle_files_;
@@ -136,7 +136,7 @@ class TextFileOp : public ParallelOp {
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
-  TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
+  TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
             std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
             bool shuffle_files, int32_t num_devices, int32_t device_id);
@@ -246,7 +246,7 @@ class TextFileOp : public ParallelOp {
   int32_t device_id_;
   int32_t num_devices_;
   int64_t rows_per_buffer_;
-  int64_t num_samples_;
+  int64_t total_rows_;
   std::vector<std::string> text_files_list_;
   bool shuffle_files_;
   std::unique_ptr<DataSchema> data_schema_;
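Note: the rename documents the behavior as much as it cleans up the name. In TextFileOp, `total_rows_ == 0` is the sentinel for "no cap", and a positive value trims the last buffer via `SliceOff`. A minimal Python sketch of that capping loop, assuming buffers arrive as plain lists of rows (names here are illustrative, not the real API):

def cap_rows(buffers, total_rows):
    """Yield rows from buffers, stopping after total_rows rows (0 means no cap)."""
    rows_read = 0
    for buf in buffers:
        if total_rows != 0 and rows_read >= total_rows:
            break
        if total_rows > 0 and rows_read + len(buf) > total_rows:
            buf = buf[:total_rows - rows_read]  # mirrors buffer->SliceOff(rowsToRemove)
        rows_read += len(buf)
        yield from buf

assert list(cap_rows([[1, 2], [3, 4], [5]], 3)) == [1, 2, 3]
assert list(cap_rows([[1, 2], [3]], 0)) == [1, 2, 3]  # 0 disables the cap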
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc

@@ -44,7 +44,7 @@ const char kSegmentationExtension[] = ".png";
 const char kAnnotationExtension[] = ".xml";
 const char kImageSetsExtension[] = ".txt";
 
-VOCOp::Builder::Builder() : builder_decode_(false), builder_num_samples_(0), builder_sampler_(nullptr) {
+VOCOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_rows_per_buffer_ = cfg->rows_per_buffer();
@@ -55,7 +55,9 @@ VOCOp::Builder::Builder() : builder_decode_(false), builder_num_samples_(0), bui
 Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
   if (builder_sampler_ == nullptr) {
-    builder_sampler_ = std::make_shared<SequentialSampler>();
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
   }
   builder_schema_ = std::make_unique<DataSchema>();
   if (builder_task_type_ == TaskType::Segmentation) {
@@ -71,8 +73,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
   }
   *ptr = std::make_shared<VOCOp>(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_,
                                  builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_,
-                                 builder_num_samples_, builder_decode_, std::move(builder_schema_),
-                                 std::move(builder_sampler_));
+                                 builder_decode_, std::move(builder_schema_), std::move(builder_sampler_));
   return Status::OK();
 }
@@ -81,20 +82,16 @@ Status VOCOp::Builder::SanityCheck() {
   std::string err_msg;
   err_msg += dir.IsDirectory() == false ? "VOC path is invalid or not set\n" : "";
   err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : "";
-  err_msg += builder_num_samples_ < 0 ? "num_samples is negative\n" : "";
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
 }
 
 VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
              const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
-             int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr<DataSchema> data_schema,
-             std::shared_ptr<Sampler> sampler)
+             int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_workers, queue_size),
       decode_(decode),
       row_cnt_(0),
       buf_cnt_(0),
       num_rows_(0),
-      num_samples_(num_samples),
       task_type_(task_type),
       task_mode_(task_mode),
       folder_path_(folder_path),
@@ -112,7 +109,6 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std:
 Status VOCOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) {
   for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
     if ((*itr) > num_rows_) continue;
-    if (row_cnt_ == num_samples_) break;
     keys->push_back(*itr);
     row_cnt_++;
     if (row_cnt_ % rows_per_buffer_ == 0) {
@@ -187,16 +183,6 @@ Status VOCOp::Reset() {
   return Status::OK();
 }
 
-Status VOCOp::GetNumSamples(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API VOCDataset.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_samples_;
-  return Status::OK();
-}
-
 Status VOCOp::LoadTensorRow(const std::string &image_id, TensorRow *trow) {
   if (task_type_ == TaskType::Segmentation) {
     std::shared_ptr<Tensor> image, target;
@@ -280,7 +266,6 @@ Status VOCOp::ParseImageIds() {
   in_file.close();
   image_ids_.shrink_to_fit();
   num_rows_ = image_ids_.size();
-  num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
   return Status::OK();
 }
@@ -305,7 +290,6 @@ Status VOCOp::ParseAnnotationIds() {
   }
   num_rows_ = image_ids_.size();
-  num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
   return Status::OK();
 }
@@ -432,19 +416,8 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescripto
   return Status::OK();
 }
 
-// Derived from RandomAccessOp
-Status VOCOp::GetNumRowsInDataset(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API VOCDataset.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_rows_;
-  return Status::OK();
-}
-
 Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
-                             const py::dict &dict, int64_t numSamples, int64_t *count) {
+                             const py::dict &dict, int64_t *count) {
   if (task_type == "Detection") {
     std::map<std::string, int32_t> input_class_indexing;
     for (auto p : dict) {
@@ -464,14 +437,12 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ
     RETURN_IF_NOT_OK(op->ParseImageIds());
     *count = static_cast<int64_t>(op->image_ids_.size());
   }
-  *count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
   return Status::OK();
 }
 
 Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
-                               const py::dict &dict, int64_t numSamples,
-                               std::map<std::string, int32_t> *output_class_indexing) {
+                               const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing) {
   std::map<std::string, int32_t> input_class_indexing;
   for (auto p : dict) {
     (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h

@@ -116,14 +116,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
       return *this;
     }
 
-    // Setter method.
-    // @param int64_t num_samples
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumSamples(int64_t num_samples) {
-      builder_num_samples_ = num_samples;
-      return *this;
-    }
-
     // Setter method.
     // @param std::shared_ptr<Sampler> sampler
     // @return Builder setter method returns reference to the builder.
@@ -157,7 +149,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
     int32_t builder_num_workers_;
     int32_t builder_op_connector_size_;
     int32_t builder_rows_per_buffer_;
-    int64_t builder_num_samples_;
     std::shared_ptr<Sampler> builder_sampler_;
     std::unique_ptr<DataSchema> builder_schema_;
     std::map<std::string, int32_t> builder_labels_to_read_;
@@ -171,14 +162,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   // @param int32_t num_workers - number of workers reading images in parallel
   // @param int32_t rows_per_buffer - number of images (rows) in each buffer
   // @param int32_t queue_size - connector queue size
-  // @param int64_t num_samples - number of samples to read
   // @param bool decode - whether to decode images
   // @param std::unique_ptr<DataSchema> data_schema - the schema of the VOC dataset
   // @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read
   VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
         const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
-        int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr<DataSchema> data_schema,
-        std::shared_ptr<Sampler> sampler);
+        int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
 
   // Destructor
   ~VOCOp() = default;
@@ -194,15 +183,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   // @return Status - The error code return
   Status operator()() override;
 
-  // Method derived from RandomAccessOp, enable Sampler to get numRows
-  // @param uint64_t num - to return numRows
-  // return Status - The error code return
-  Status GetNumSamples(int64_t *num) const override;
-
-  // Method derived from RandomAccessOp, enable Sampler to get total number of rows in dataset
-  // @param uint64_t num - to return numRows
-  Status GetNumRowsInDataset(int64_t *num) const override;
-
   // A print method typically used for debugging
   // @param out
   // @param show_all
@@ -212,10 +192,9 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   // @param const std::string &task_type - task type of reading voc job
   // @param const std::string &task_mode - task mode of reading voc job
   // @param const py::dict &dict - input dict of class index
-  // @param int64_t numSamples - samples number of VOCDataset
   // @param int64_t *count - output rows number of VOCDataset
   static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
-                               const py::dict &dict, int64_t numSamples, int64_t *count);
+                               const py::dict &dict, int64_t *count);
 
   // @param const std::string &dir - VOC dir path
   // @param const std::string &task_type - task type of reading voc job
@@ -224,8 +203,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   // @param int64_t numSamples - samples number of VOCDataset
   // @param std::map<std::string, int32_t> *output_class_indexing - output class index of VOCDataset
   static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
-                                 const py::dict &dict, int64_t numSamples,
-                                 std::map<std::string, int32_t> *output_class_indexing);
+                                 const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing);
 
  private:
   // Initialize Sampler, calls sampler->Init() within
@@ -283,8 +261,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   bool decode_;
   int64_t row_cnt_;
   int64_t buf_cnt_;
   int64_t num_rows_;
-  int64_t num_samples_;
   std::string folder_path_;
   TaskType task_type_;
   std::string task_mode_;
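Note: with `num_samples_` removed from VOCOp, any cap on rows is now owned entirely by the sampler the builder installs; the fallback is a SequentialSampler whose `num_samples = 0` means "take everything". A small self-contained sketch of that fallback semantics (a Python stand-in, not the real C++ sampler):

class SequentialSamplerSketch:
    """Stand-in for the C++ SequentialSampler(start_index, num_samples)."""
    def __init__(self, start_index=0, num_samples=0):
        self.start_index = start_index
        self.num_samples = num_samples  # 0 == sample all remaining rows

    def sample_ids(self, num_rows):
        end = num_rows if self.num_samples == 0 else min(num_rows, self.start_index + self.num_samples)
        return list(range(self.start_index, end))

assert SequentialSamplerSketch().sample_ids(5) == [0, 1, 2, 3, 4]  # the builder's fallback
assert SequentialSamplerSketch(1, 2).sample_ids(5) == [1, 2]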
mindspore/dataset/__init__.py

@@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
     GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
     Schema, Shuffle, zip, RandomDataset
 from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
-    WeightedRandomSampler, SubsetSampler, Sampler
+    WeightedRandomSampler, Sampler
 from .engine.serializer_deserializer import serialize, deserialize, show
 from .engine.graphdata import GraphData
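Note: since `SubsetSampler` is no longer exported, code that sampled a consecutive range should switch to `SequentialSampler`, which now carries the same information. A hypothetical migration (the directory path is a placeholder):

import mindspore.dataset as ds

# before: sampler = ds.SubsetSampler(100, 5)   # the next 5 samples starting at id 100
sampler = ds.SequentialSampler(start_index=100, num_samples=5)
data = ds.ImageFolderDatasetV2("path/to/imagefolder_directory", sampler=sampler)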
mindspore/dataset/engine/datasets.py

@@ -1261,8 +1261,8 @@ class MappableDataset(SourceDataset):
     def _get_sampler_dataset_size(self):
         if self.sampler is not None:
-            if hasattr(self.sampler, 'get_dataset_size'):
-                return self.sampler.get_dataset_size()
+            if hasattr(self.sampler, 'get_num_samples'):
+                return self.sampler.get_num_samples()
             if hasattr(self.sampler, '__len__'):
                 return len(self.sampler)
@@ -1355,7 +1355,7 @@ class MappableDataset(SourceDataset):
                 random_sampler.reshuffle_each_epoch = False
                 ds.add_sampler(random_sampler)
-            subset_sampler = samplers.SubsetSampler(current_split_start_index, size)
+            subset_sampler = samplers.SequentialSampler(current_split_start_index, size)
             ds.add_sampler(subset_sampler)
 
             # add sequential sampler, so that if user calls use_sampler, we will
@@ -2226,31 +2226,45 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
         num_shards (int): Number of shard for sharding.
         shard_id (int): Shard ID.
     """
+    if input_sampler is not None:
+        # If the user provided a sampler, then it doesn't matter what the other args are because
+        # we are being asked specifically to use the given sampler.
+        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
+        # be None. Consider this example:
+        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
+        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
+        # In this case, the user has given different sample-related arguments that contradict each other.
+        # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
+        if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
+                                       samplers.RandomSampler, samplers.SubsetRandomSampler,
+                                       samplers.WeightedRandomSampler, samplers.Sampler)) and
+                (num_shards is not None or shard_id is not None or shuffle is not None or num_samples is not None)):
+            raise ValueError(
+                'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
+                ' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle))
+        return input_sampler
     if shuffle is None:
-        if input_sampler is not None:
-            # If shuffle is not specified, user provided sampler, use user's sampler
-            return input_sampler
         if num_shards is not None:
             # If shuffle is not specified, sharding enabled, use distributed random sampler
             shuffle = True
-            return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
+            return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
         # If shuffle is not specified, sharding disabled, use random sampler
-        if num_samples is not None:
-            return samplers.RandomSampler(replacement=True, num_samples=num_samples)
-        return samplers.RandomSampler()
+        return samplers.RandomSampler(num_samples=num_samples)
     if shuffle is True:
         if num_shards is not None:
             # If shuffle enabled, sharding enabled, use distributed random sampler
-            return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
+            return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
         # If shuffle enabled, sharding disabled, use random sampler
-        if num_samples is not None:
-            return samplers.RandomSampler(replacement=True, num_samples=num_samples)
-        return samplers.RandomSampler()
+        return samplers.RandomSampler(num_samples=num_samples)
     if num_shards is not None:
         # If shuffle disabled, sharding enabled, use distributed sequential sampler
-        return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
+        return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
     # If shuffle disabled, sharding disabled, use sequential sampler
-    return samplers.SequentialSampler()
+    return samplers.SequentialSampler(num_samples=num_samples)
 
 
 class ImageFolderDatasetV2(MappableDataset):
@@ -2370,11 +2384,7 @@ class ImageFolderDatasetV2(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
-        num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
+        num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
 
         rows_per_shard = get_num_rows(num_rows, self.num_shards)
         rows_from_sampler = self._get_sampler_dataset_size()
@@ -2390,11 +2400,7 @@ class ImageFolderDatasetV2(MappableDataset):
         Return:
             Number, number of classes.
         """
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
-        return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1]
+        return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[1]
 
     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -2503,12 +2509,7 @@ class MnistDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
-        num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
+        num_rows = MnistOp.get_num_rows(self.dataset_dir)
 
         rows_per_shard = get_num_rows(num_rows, self.num_shards)
         rows_from_sampler = self._get_sampler_dataset_size()
@@ -2956,11 +2957,8 @@ class GeneratorDataset(MappableDataset):
             if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
                                          samplers.RandomSampler, samplers.SubsetRandomSampler,
                                          samplers.WeightedRandomSampler, samplers.Sampler)):
-                if num_samples is None:
-                    num_samples = len(source)
                 sampler_instance = self.sampler.create()
                 sampler_instance.set_num_rows(len(source))
-                sampler_instance.set_num_samples(num_samples)
                 sampler_instance.initialize()
                 if num_parallel_workers > 1:
                     self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers))
@@ -3304,17 +3302,12 @@ class ManifestDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
-
         if self.class_indexing is None:
             class_indexing = dict()
         else:
             class_indexing = self.class_indexing
 
-        num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
+        num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
 
         rows_per_shard = get_num_rows(num_rows, self.num_shards)
         rows_from_sampler = self._get_sampler_dataset_size()
@@ -3330,17 +3323,12 @@ class ManifestDataset(MappableDataset):
         Return:
             Number, number of classes.
         """
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
-
         if self.class_indexing is None:
             class_indexing = dict()
         else:
             class_indexing = self.class_indexing
 
-        return ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[1]
+        return ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[1]
 
     def get_class_indexing(self):
         """
@@ -3349,17 +3337,12 @@ class ManifestDataset(MappableDataset):
         Return:
             Dict, A str-to-int mapping from label name to index.
         """
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
-
         if self.class_indexing is None:
             class_indexing = dict()
         else:
             class_indexing = self.class_indexing
 
-        return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage)
+        return ManifestOp.get_class_indexing(self.dataset_file, class_indexing, self.usage)
 
     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -3473,12 +3456,8 @@ class Cifar10Dataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
-
-        num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
+        num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
 
         rows_per_shard = get_num_rows(num_rows, self.num_shards)
         rows_from_sampler = self._get_sampler_dataset_size()
@@ -3597,12 +3576,8 @@ class Cifar100Dataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
-
-        num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
+        num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
 
         rows_per_shard = get_num_rows(num_rows, self.num_shards)
         rows_from_sampler = self._get_sampler_dataset_size()
@@ -3631,7 +3606,7 @@ class RandomDataset(SourceDataset):
     Args:
         num_samples (int): number of samples to generate.
         schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
-            If the schema is not provided, the meta data from the TFRecord file is considered the schema.
+            If the schema is not provided, the random dataset generates a random schema.
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
         num_parallel_workers (int, optional): number of workers to read the data
             (default=None, number set in the config).
@@ -3644,9 +3619,12 @@ class RandomDataset(SourceDataset):
             schema_obj = Schema(schema)  # read the schema file and convert to schema object to validate it
         self.schema = schema
         self.columns_list = columns_list
-        self.num_samples = num_samples
+        if schema_obj is not None and num_samples is None:
+            self.num_samples = schema_obj.num_rows
+        elif num_samples is None:
+            self.num_samples = 0
+        else:
+            self.num_samples = num_samples
 
     def get_args(self):
         args = super().get_args()
@@ -4015,17 +3993,12 @@ class VOCDataset(MappableDataset):
         if self.task != "Detection":
             raise NotImplementedError()
 
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
-
         if self.class_indexing is None:
             class_indexing = dict()
         else:
             class_indexing = self.class_indexing
 
-        return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
+        return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing)
 
     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -4205,9 +4178,11 @@ class TextFileDataset(SourceDataset):
         if self._dataset_size is None:
             num_rows = TextFileOp.get_num_rows(self.dataset_files)
             num_rows = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples is None:
-                return num_rows
-            return min(self.num_samples, num_rows)
+            # If the user gave a num samples in the dataset, then the sampler will limit the rows returned
+            # to that amount.  Account for that here in the row count
+            if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples:
+                num_rows = self.num_samples
+            return num_rows
         return self._dataset_size
 
     def is_shuffled(self):
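Note: the reworked `_select_sampler` carries two rules worth calling out: an explicit sampler excludes every other sampling argument, and otherwise `num_samples` is threaded into whichever built-in sampler the shuffle/shard flags select. A condensed standalone sketch of that decision table (tuples stand in for real sampler objects, and unlike the real code this sketch applies the conflict check to every sampler, not only known built-in types):

def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    if input_sampler is not None:
        # A user-supplied sampler conflicts with any other sampling argument.
        if any(a is not None for a in (num_samples, shuffle, num_shards, shard_id)):
            raise ValueError("Conflicting arguments during sampler assignments.")
        return input_sampler
    if num_shards is not None:
        # Sharding enabled: distributed sampler; unspecified shuffle defaults to True.
        return ("DistributedSampler", num_shards, shard_id,
                True if shuffle is None else shuffle, num_samples)
    if shuffle is None or shuffle:
        return ("RandomSampler", num_samples)
    return ("SequentialSampler", num_samples)

assert select_sampler(None, None, False, None, None) == ("SequentialSampler", None)
try:
    select_sampler(10, object(), None, None, None)
except ValueError:
    pass  # a sampler plus num_samples is rejected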
mindspore/dataset/engine/samplers.py

@@ -22,7 +22,6 @@ User can also define custom sampler by extending from Sampler class.
 import numpy as np
 import mindspore._c_dataengine as cde
 
-
 class Sampler:
     """
     Base class for user defined sampler.
@@ -44,10 +43,10 @@ class Sampler:
         >>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler())
     """
 
-    def __init__(self):
+    def __init__(self, num_samples=None):
         self.dataset_size = 0
-        self.num_samples = 0
         self.child_sampler = None
+        self.num_samples = num_samples
 
     def __iter__(self):
         """
@@ -84,7 +83,8 @@ class Sampler:
     # Instance fetcher
     # Do not override this method!
     def create(self):
-        c_sampler = cde.PythonSampler(self)
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.PythonSampler(num_samples, self)
         c_child_sampler = self.create_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler
@@ -114,7 +114,7 @@ class Sampler:
         return self.child_sampler.is_sharded()
 
-    def get_dataset_size(self):
+    def get_num_samples(self):
         return self._get_indices().size
@@ -124,8 +124,9 @@ class BuiltinSampler:
     User should not extend this class.
     """
 
-    def __init__(self):
+    def __init__(self, num_samples=None):
         self.child_sampler = None
+        self.num_samples = num_samples
 
     def create(self):
         pass
@@ -149,11 +150,37 @@ class BuiltinSampler:
     def is_sharded(self):
         raise NotImplementedError("Sampler must implement is_sharded.")
 
-    def get_dataset_size(self):
+    def get_num_samples(self):
+        """
+        All samplers can contain a numeric num_samples value (or it could be set to None).
+        Child sampler can exist or be None.
+        if child sampler exists, then the child sampler count can be a numeric value or None.
+        Given these conditions, we need to output what the sampler count is for this sampler.
+        The following table shows the possible results from calling this function.
+
+            child sampler  num_samples  child_samples  result
+            -------------  -----------  -------------  ---------
+            T              x            y              min(x, y)
+            T              x            None           x
+            T              None         y              y
+            T              None         None           None
+            None           x            n/a            x
+            None           None         n/a            None
+
+        Returns:
+            int, The number of samples, or None
+        """
         if self.child_sampler is not None:
-            return self.child_sampler.get_dataset_size()
+            child_samples = self.child_sampler.get_num_samples()
+            if self.num_samples is not None:
+                if child_samples is not None:
+                    return min(self.num_samples, child_samples)
+
+                return self.num_samples
 
-        return None
+            return child_samples
+
+        return self.num_samples
 
 
 class DistributedSampler(BuiltinSampler):
@@ -164,6 +191,7 @@ class DistributedSampler(BuiltinSampler):
         num_shards (int): Number of shards to divide the dataset into.
         shard_id (int): Shard ID of the current shard within num_shards.
         shuffle (bool, optional): If true, the indices are shuffled (default=True).
+        num_samples (int, optional): The number of samples to draw (default=None, all elements).
 
     Examples:
         >>> import mindspore.dataset as ds
@@ -180,7 +208,7 @@ class DistributedSampler(BuiltinSampler):
         ValueError: If shuffle is not a boolean value.
     """
 
-    def __init__(self, num_shards, shard_id, shuffle=True):
+    def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None):
         if num_shards <= 0:
             raise ValueError("num_shards should be a positive integer value, but got num_shards={}".format(num_shards))
@@ -194,12 +222,13 @@ class DistributedSampler(BuiltinSampler):
         self.shard_id = shard_id
         self.shuffle = shuffle
         self.seed = 0
-        super().__init__()
+        super().__init__(num_samples)
 
     def create(self):
+        num_samples = self.num_samples if self.num_samples is not None else 0
        # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
         self.seed += 1
-        c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
+        c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id, self.shuffle, self.seed)
         c_child_sampler = self.create_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler
@@ -226,6 +255,7 @@ class PKSampler(BuiltinSampler):
         num_class (int, optional): Number of classes to sample (default=None, all classes).
         shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
         class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset.
+        num_samples (int, optional): The number of samples to draw (default=None, all elements).
 
     Examples:
         >>> import mindspore.dataset as ds
@@ -242,7 +272,7 @@ class PKSampler(BuiltinSampler):
         ValueError: If shuffle is not boolean.
     """
 
-    def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'):
+    def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
         if num_val <= 0:
             raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
@@ -255,10 +285,11 @@ class PKSampler(BuiltinSampler):
         self.num_val = num_val
         self.shuffle = shuffle
         self.class_column = class_column  # work for minddataset
-        super().__init__()
+        super().__init__(num_samples)
 
     def create(self):
-        c_sampler = cde.PKSampler(self.num_val, self.shuffle)
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle)
         c_child_sampler = self.create_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler
@@ -309,23 +340,18 @@ class RandomSampler(BuiltinSampler):
             raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
 
         if num_samples is not None:
-            if num_samples <= 0:
+            if num_samples < 0:
                 raise ValueError("num_samples should be a positive integer "
                                  "value, but got num_samples={}".format(num_samples))
 
         self.deterministic = False
         self.replacement = replacement
-        self.num_samples = num_samples
         self.reshuffle_each_epoch = True
-        super().__init__()
+        super().__init__(num_samples)
 
     def create(self):
-        c_sampler = None
-        if self.num_samples is None:
-            c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch)
-        else:
-            c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples)
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
         c_child_sampler = self.create_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler
@@ -339,84 +365,33 @@ class RandomSampler(BuiltinSampler):
         return self.child_sampler.is_sharded()
 
-    def get_dataset_size(self):
-        return self.num_samples
-
 
 class SequentialSampler(BuiltinSampler):
     """
     Samples the dataset elements sequentially, same as not having a sampler.
 
-    Examples:
-        >>> import mindspore.dataset as ds
-        >>>
-        >>> dataset_dir = "path/to/imagefolder_directory"
-        >>>
-        >>> # creates a SequentialSampler
-        >>> sampler = ds.SequentialSampler()
-        >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
-    """
-
-    def create(self):
-        c_sampler = cde.SequentialSampler()
-        c_child_sampler = self.create_child()
-        c_sampler.add_child(c_child_sampler)
-        return c_sampler
-
-    def is_shuffled(self):
-        if self.child_sampler is None:
-            return False
-
-        return self.child_sampler.is_shuffled()
-
-    def is_sharded(self):
-        if self.child_sampler is None:
-            return False
-
-        return self.child_sampler.is_sharded()
-
-
-class SubsetSampler(BuiltinSampler):
-    """
-    Samples a subset of elements consecutively from a given index.
-
     Args:
-        start_index (int): Index to start sampling at.
-        subset_size (int): How many samples to include in this subset.
+        start_index (int, optional): Index to start sampling at. (dafault=None starts at first id)
+        num_samples (int, optional): Number of elements to sample (default=None, all elements).
 
     Examples:
         >>> import mindspore.dataset as ds
         >>>
         >>> dataset_dir = "path/to/imagefolder_directory"
        >>>
-        >>> # creates a SubsetSampler, will sample the next 5 images from the 100th image.
-        >>> sampler = ds.SubsetSampler(100, 5)
+        >>> # creates a SequentialSampler
+        >>> sampler = ds.SequentialSampler()
         >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
-
-    Raises:
-        ValueError: If start_index is not a positive int.
-        ValueError: If subset_size is not a positive int.
     """
 
-    def __init__(self, start_index, subset_size):
-        if not isinstance(start_index, int):
-            raise ValueError("start_index should be an int.")
-
-        if start_index < 0:
-            raise ValueError("start_index should not be negative.")
-
-        if not isinstance(subset_size, int):
-            raise ValueError("start_index should be an int")
-
-        if subset_size < 0:
-            raise ValueError("subset_size should not be negative.")
-
+    def __init__(self, start_index=None, num_samples=None):
         self.start_index = start_index
-        self.subset_size = subset_size
-        super().__init__()
+        super().__init__(num_samples)
 
     def create(self):
-        c_sampler = cde.SubsetSampler(self.start_index, self.subset_size)
+        start_index = self.start_index if self.start_index is not None else 0
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.SequentialSampler(num_samples, start_index)
         c_child_sampler = self.create_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler
@@ -433,9 +408,6 @@ class SubsetSampler(BuiltinSampler):
         return self.child_sampler.is_sharded()
 
-    def get_dataset_size(self):
-        return self.subset_size
-
 
 class SubsetRandomSampler(BuiltinSampler):
     """
@@ -443,6 +415,7 @@ class SubsetRandomSampler(BuiltinSampler):
     Args:
         indices (list[int]): A sequence of indices.
+        num_samples (int, optional): Number of elements to sample (default=None, all elements).
 
     Examples:
         >>> import mindspore.dataset as ds
@@ -456,15 +429,16 @@ class SubsetRandomSampler(BuiltinSampler):
         >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
     """
 
-    def __init__(self, indices):
+    def __init__(self, indices, num_samples=None):
         if not isinstance(indices, list):
             indices = [indices]
 
         self.indices = indices
-        super().__init__()
+        super().__init__(num_samples)
 
     def create(self):
-        c_sampler = cde.SubsetRandomSampler(self.indices)
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.SubsetRandomSampler(num_samples, self.indices)
         c_child_sampler = self.create_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler
@@ -481,9 +455,9 @@ class SubsetRandomSampler(BuiltinSampler):
     def _create_for_minddataset(self):
         return cde.MindrecordSubsetRandomSampler(self.indices)
 
-    def get_dataset_size(self):
-        return len(self.indices)
+    def get_num_samples(self):
+        num_samples = super().get_num_samples()
+        return min(len(self.indices), num_samples)
 
 
 class WeightedRandomSampler(BuiltinSampler):
@@ -492,7 +466,7 @@ class WeightedRandomSampler(BuiltinSampler):
     Args:
         weights (list[float]): A sequence of weights, not necessarily summing up to 1.
-        num_samples (int): Number of elements to sample.
+        num_samples (int): Number of elements to sample (default=None, all elements).
         replacement (bool, optional): If True, put the sample ID back for the next draw (default=True).
 
     Examples:
@@ -511,24 +485,25 @@ class WeightedRandomSampler(BuiltinSampler):
         ValueError: If replacement is not boolean.
     """
 
-    def __init__(self, weights, num_samples, replacement=True):
+    def __init__(self, weights, num_samples=None, replacement=True):
         if not isinstance(weights, list):
             weights = [weights]
 
-        if num_samples <= 0:
-            raise ValueError("num_samples should be a positive integer "
-                             "value, but got num_samples={}".format(num_samples))
+        if num_samples is not None:
+            if num_samples < 0:
+                raise ValueError("num_samples should be a positive integer "
+                                 "value, but got num_samples={}".format(num_samples))
 
         if not isinstance(replacement, bool):
             raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
 
         self.weights = weights
-        self.num_samples = num_samples
         self.replacement = replacement
-        super().__init__()
+        super().__init__(num_samples)
 
     def create(self):
-        c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement)
         c_child_sampler = self.create_child()
         c_sampler.add_child(c_child_sampler)
         return c_sampler
@@ -541,6 +516,3 @@ class WeightedRandomSampler(BuiltinSampler):
             return False
 
         return self.child_sampler.is_sharded()
-
-    def get_dataset_size(self):
-        return self.num_samples
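Note: the `get_num_samples` table in this file's diff is the heart of the change; every built-in sampler now funnels its count through one merge rule. A minimal re-implementation to make the table concrete (a stand-in class with the same logic as the diff):

class SamplerSketch:
    """Stand-in mirroring BuiltinSampler.get_num_samples()."""
    def __init__(self, num_samples=None, child=None):
        self.num_samples = num_samples
        self.child_sampler = child

    def get_num_samples(self):
        if self.child_sampler is not None:
            child_samples = self.child_sampler.get_num_samples()
            if self.num_samples is not None:
                if child_samples is not None:
                    return min(self.num_samples, child_samples)
                return self.num_samples
            return child_samples
        return self.num_samples

assert SamplerSketch(8, SamplerSketch(5)).get_num_samples() == 5      # min(x, y)
assert SamplerSketch(8, SamplerSketch()).get_num_samples() == 8       # x
assert SamplerSketch(None, SamplerSketch(5)).get_num_samples() == 5   # y
assert SamplerSketch().get_num_samples() is None                      # None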
mindspore/dataset/engine/serializer_deserializer.py

@@ -161,6 +161,20 @@ def traverse(node):
         else:
             node_repr[k] = v
 
+    # If a sampler exists in this node, then the following 4 arguments must be set to None:
+    #    num_samples, shard_id, num_shards, shuffle
+    # These arguments get moved into the sampler itself, so they are no longer needed to
+    # be set at the dataset level.
+    if 'sampler' in node_args.keys():
+        if 'num_samples' in node_repr.keys():
+            node_repr['num_samples'] = None
+        if 'shuffle' in node_repr.keys():
+            node_repr['shuffle'] = None
+        if 'num_shards' in node_repr.keys():
+            node_repr['num_shards'] = None
+        if 'shard_id' in node_repr.keys():
+            node_repr['shard_id'] = None
+
     # Leaf node doesn't have input attribute.
     if not node.input:
         return node_repr
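Note: a compact sketch of the serializer rule added above; once a node carries a sampler, the dataset-level copies of the four sampling arguments are blanked in the serialized representation (plain dicts stand in for the real node structures):

def blank_sampling_args(node_args, node_repr):
    if 'sampler' in node_args:
        for key in ('num_samples', 'shuffle', 'num_shards', 'shard_id'):
            if key in node_repr:
                node_repr[key] = None
    return node_repr

out = blank_sampling_args({'sampler': 'RandomSampler'},
                          {'num_samples': 10, 'shuffle': True, 'op_type': 'ImageFolder'})
assert out == {'num_samples': None, 'shuffle': None, 'op_type': 'ImageFolder'}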
mindspore/dataset/engine/validators.py

@@ -283,8 +283,8 @@ def check_num_parallel_workers(value):
 def check_num_samples(value):
     check_type(value, 'num_samples', int)
-    if value <= 0:
-        raise ValueError("num_samples must be greater than 0!")
+    if value < 0:
+        raise ValueError("num_samples cannot be less than 0!")
 
 
 def check_dataset_dir(dataset_dir):
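Note: the relaxed validator admits `num_samples == 0`, which downstream code reads as "no cap / all samples"; only negatives are rejected now. A standalone sketch of the new check (the real code routes the type check through its own `check_type` helper):

def check_num_samples(value):
    if not isinstance(value, int):
        raise TypeError("num_samples must be an int")
    if value < 0:
        raise ValueError("num_samples cannot be less than 0!")

check_num_samples(0)    # now valid: means "all samples"
check_num_samples(10)   # valid
try:
    check_num_samples(-1)
except ValueError:
    pass  # still rejected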
tests/ut/cpp/dataset/celeba_op_test.cc

@@ -39,14 +39,13 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
 std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
 
 std::shared_ptr<CelebAOp> Celeba(int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size,
-                                 const std::string &dir, int64_t num_samples = 0,
-                                 std::unique_ptr<Sampler> sampler = nullptr,
-                                 bool decode = false, const std::string &dataset_type = "all") {
+                                 const std::string &dir, std::shared_ptr<Sampler> sampler = nullptr,
+                                 bool decode = false, const std::string &dataset_type = "all") {
   std::shared_ptr<CelebAOp> so;
   CelebAOp::Builder builder;
   Status rc = builder.SetNumWorkers(num_workers).SetCelebADir(dir).SetRowsPerBuffer(rows_per_buffer)
                 .SetOpConnectorSize(queue_size).SetSampler(std::move(sampler)).SetDecode(decode)
-                .SetNumSamples(num_samples).SetDatasetType(dataset_type).Build(&so);
+                .SetDatasetType(dataset_type).Build(&so);
   return so;
 }
@@ -116,11 +115,12 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
 TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) {
   std::vector<int64_t> indices({1});
-  std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices);
+  int64_t num_samples = 0;
+  std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices);
   uint32_t expect_labels[1][40] = {{0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
                                     0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1}};
   std::string dir = datasets_root_path_ + "/testCelebAData/";
   uint32_t count = 0;
-  auto tree = Build({Celeba(16, 2, 32, dir, 0, std::move(sampler))});
+  auto tree = Build({Celeba(16, 2, 32, dir, std::move(sampler))});
   tree->Prepare();
   Status rc = tree->Launch();
   if (rc.IsError()) {
@@ -143,25 +143,3 @@ TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) {
     EXPECT_TRUE(count == 1);
   }
 }
-
-TEST_F(MindDataTestCelebaDataset, TestCelebaNumSamples) {
-  std::string dir = datasets_root_path_ + "/testCelebAData/";
-  uint32_t count = 0;
-  auto tree = Build({Celeba(16, 2, 32, dir, 1)});
-  tree->Prepare();
-  Status rc = tree->Launch();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
-    EXPECT_TRUE(false);
-  } else {
-    DatasetIterator di(tree);
-    TensorMap tersor_map;
-    di.GetNextAsMap(&tersor_map);
-    EXPECT_TRUE(rc.IsOk());
-    while (tersor_map.size() != 0) {
-      count++;
-      di.GetNextAsMap(&tersor_map);
-    }
-    EXPECT_TRUE(count == 1);
-  }
-}
tests/ut/cpp/dataset/cifar_op_test.cc

@@ -45,13 +45,12 @@ std::shared_ptr<RepeatOp> Repeat(int repeatCnt);
 std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
 
 std::shared_ptr<CifarOp> Cifarop(uint64_t num_works, uint64_t rows, uint64_t conns, std::string path,
-                                 std::unique_ptr<Sampler> sampler = nullptr, uint64_t num_samples = 0,
-                                 bool cifar10 = true) {
+                                 std::shared_ptr<Sampler> sampler = nullptr, bool cifar10 = true) {
   std::shared_ptr<CifarOp> so;
   CifarOp::Builder builder;
   Status rc = builder.SetNumWorkers(num_works).SetCifarDir(path).SetRowsPerBuffer(rows)
                 .SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetCifarType(cifar10)
-                .SetNumSamples(num_samples).Build(&so);
+                .Build(&so);
   return so;
 }
@@ -66,7 +65,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
   //appear in this dataset
   //Example: python tests/dataset/data/prep_data.py
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
-  auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100)});
+  auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr)});
   tree->Prepare();
   Status rc = tree->Launch();
   if (rc.IsError()) {
@@ -79,7 +78,8 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
     EXPECT_TRUE(rc.IsOk());
     uint64_t i = 0;
     uint32_t label = 0;
-    while (tensor_map.size() != 0) {
+    // Note: only iterating first 100 rows then break out.
+    while (tensor_map.size() != 0 && i < 100) {
       tensor_map["label"]->GetItemAt<uint32_t>(&label, {});
       MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
       i++;
@@ -92,9 +92,9 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
 TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
   uint32_t original_seed = GlobalContext::config_manager()->seed();
   GlobalContext::config_manager()->set_seed(0);
-  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
+  std::shared_ptr<Sampler> sampler = std::make_unique<RandomSampler>(12, true, true);
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
-  auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)});
+  auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler))});
   tree->Prepare();
   Status rc = tree->Launch();
   if (rc.IsError()) {
@@ -118,34 +118,9 @@ TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
   GlobalContext::config_manager()->set_seed(original_seed);
 }
 
-TEST_F(MindDataTestCifarOp, TestCifar10NumSample) {
-  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
-  auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100)});
-  tree->Prepare();
-  Status rc = tree->Launch();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "Return code error detected during tree launch: " << common::SafeCStr(rc.ToString()) << ".";
-    EXPECT_TRUE(false);
-  } else {
-    DatasetIterator di(tree);
-    TensorMap tensor_map;
-    di.GetNextAsMap(&tensor_map);
-    EXPECT_TRUE(rc.IsOk());
-    uint64_t i = 0;
-    uint32_t label = 0;
-    while (tensor_map.size() != 0) {
-      tensor_map["label"]->GetItemAt<uint32_t>(&label, {});
-      MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
-      i++;
-      di.GetNextAsMap(&tensor_map);
-    }
-    EXPECT_TRUE(i == 100);
-  }
-}
-
 TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar100) {
   std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
-  auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100, false)});
+  auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, false)});
   tree->Prepare();
   Status rc = tree->Launch();
   if (rc.IsError()) {
@@ -159,7 +134,8 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar100) {
     uint64_t i = 0;
     uint32_t coarse = 0;
     uint32_t fine = 0;
-    while (tensor_map.size() != 0) {
+    // only iterate to 100 then break out of loop
+    while (tensor_map.size() != 0 && i < 100) {
       tensor_map["coarse_label"]->GetItemAt<uint32_t>(&coarse, {});
       tensor_map["fine_label"]->GetItemAt<uint32_t>(&fine, {});
       MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << " coarse:"
tests/ut/cpp/dataset/image_folder_op_test.cc

@@ -50,9 +50,8 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
 std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
 
 std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
-                                           bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
-                                           std::map<std::string, int32_t> map = {}, int64_t num_samples = 0,
-                                           bool decode = false) {
+                                           bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
+                                           std::map<std::string, int32_t> map = {}, bool decode = false) {
   std::shared_ptr<ImageFolderOp> so;
   ImageFolderOp::Builder builder;
   Status rc = builder.SetNumWorkers(num_works)
@@ -63,7 +62,6 @@ std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int6
                 .SetSampler(std::move(sampler))
                 .SetClassIndex(map)
                 .SetDecode(decode)
-                .SetNumSamples(num_samples)
                 .Build(&so);
   return so;
 }
@@ -138,7 +136,8 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) {
 TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
   int32_t original_seed = GlobalContext::config_manager()->seed();
   GlobalContext::config_manager()->set_seed(0);
-  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
+  int64_t num_samples = 12;
+  std::shared_ptr<Sampler> sampler = std::make_unique<RandomSampler>(num_samples, true, true);
   int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1};  // ground truth label
   std::string folder_path = datasets_root_path_ + "/testPK/data";
   auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
@@ -200,7 +199,8 @@ TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeatBatch)
 TEST_F(MindDataTestImageFolderSampler, TestSubsetRandomSamplerImageFolder) {
   // id range 0 - 10 is label 0, and id range 11 - 21 is label 1
   std::vector<int64_t> indices({0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11});
-  std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices);
+  int64_t num_samples = 0;
+  std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices);
   std::string folder_path = datasets_root_path_ + "/testPK/data";
   // Expect 6 samples for label 0 and 1
   int res[2] = {6, 6};
@@ -237,8 +237,8 @@ TEST_F(MindDataTestImageFolderSampler, TestWeightedRandomSamplerImageFolder) {
   std::vector<double> weights(total_samples, std::rand() % 100);
 
   // create sampler with replacement = replacement
-  std::unique_ptr<Sampler> sampler =
-    std::make_unique<WeightedRandomSampler>(weights, num_samples, true, samples_per_buffer);
+  std::shared_ptr<Sampler> sampler =
+    std::make_shared<WeightedRandomSampler>(num_samples, weights, true, samples_per_buffer);
 
   std::string folder_path = datasets_root_path_ + "/testPK/data";
   auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
@@ -295,7 +295,8 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderClassIndex) {
 }
 
 TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) {
-  std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(11, 10, false);
+  int64_t num_samples = 0;
+  std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 11, 10, false);
   std::string folder_path = datasets_root_path_ + "/testPK/data";
   auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)), Repeat(4)});
   tree->Prepare();
@@ -322,7 +323,8 @@ TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) {
 }
 
 TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) {
-  std::unique_ptr<Sampler> sampler = std::make_unique<PKSampler>(3, false, 4);
+  int64_t num_samples = 0;
+  std::shared_ptr<Sampler> sampler = std::make_shared<PKSampler>(num_samples, 3, false, 4);
   int32_t res[] = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3};  // ground truth label
   std::string folder_path = datasets_root_path_ + "/testPK/data";
   auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
@@ -349,39 +351,16 @@ TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) {
   }
 }
 
-TEST_F(MindDataTestImageFolderSampler, TestImageFolderNumSamples) {
-  std::string folder_path = datasets_root_path_ + "/testPK/data";
-  auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, nullptr, {}, 11), Repeat(2)});
-  tree->Prepare();
-  Status rc = tree->Launch();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "Return code error detected during tree launch: " << common::SafeCStr(rc.ToString()) << ".";
-    EXPECT_TRUE(false);
-  } else {
-    DatasetIterator di(tree);
-    TensorMap tensor_map;
-    di.GetNextAsMap(&tensor_map);
-    EXPECT_TRUE(rc.IsOk());
-    uint64_t i = 0;
-    int32_t label = 0;
-    while (tensor_map.size() != 0) {
-      tensor_map["label"]->GetItemAt<int32_t>(&label, {});
-      EXPECT_TRUE(0 == label);
-      MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
-      i++;
-      di.GetNextAsMap(&tensor_map);
-    }
-    EXPECT_TRUE(i == 22);
-  }
-}
-
 TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) {
   std::string folder_path = datasets_root_path_ + "/testPK/data";
   std::map<std::string, int32_t> map;
   map["class3"] = 333;
   map["class1"] = 111;
   map["wrong folder name"] = 1234;  // this is skipped
-  auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, nullptr, map, 20, true)});
+  int64_t num_samples = 20;
+  int64_t start_index = 0;
+  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+  auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(seq_sampler), map, true)});
   int64_t res[2] = {111, 333};
   tree->Prepare();
   Status rc = tree->Launch();
@@ -408,33 +387,12 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) {
   }
 }
 
-TEST_F(MindDataTestImageFolderSampler, TestImageFolderDatasetSize) {
-  std::string folder_path = datasets_root_path_ + "/testPK/data";
-  int64_t num_rows = 0;
-  int64_t num_classes = 0;
-  ImageFolderOp::CountRowsAndClasses(folder_path, 15, {}, &num_rows, &num_classes);
-  EXPECT_TRUE(num_rows == 15 && num_classes == 4);
-  ImageFolderOp::CountRowsAndClasses(folder_path, 44, {}, &num_rows, &num_classes);
-  EXPECT_TRUE(num_rows == 44 && num_classes == 4);
-  ImageFolderOp::CountRowsAndClasses(folder_path, 0, {}, &num_rows, &num_classes);
-  EXPECT_TRUE(num_rows == 44 && num_classes == 4);
-  ImageFolderOp::CountRowsAndClasses(folder_path, 55, {}, &num_rows, &num_classes);
-  EXPECT_TRUE(num_rows == 44 && num_classes == 4);
-  ImageFolderOp::CountRowsAndClasses(folder_path, 44, {}, &num_rows, &num_classes, 2, 3);
-  EXPECT_TRUE(num_rows == 15 && num_classes == 4);
-  ImageFolderOp::CountRowsAndClasses(folder_path, 33, {}, &num_rows, &num_classes, 0, 3);
-  EXPECT_TRUE(num_rows == 15 && num_classes == 4);
-  ImageFolderOp::CountRowsAndClasses(folder_path, 13, {}, &num_rows, &num_classes, 0, 11);
-  EXPECT_TRUE(num_rows == 4 && num_classes == 4);
-  ImageFolderOp::CountRowsAndClasses(folder_path, 3, {}, &num_rows, &num_classes, 0, 11);
-  EXPECT_TRUE(num_rows == 3 && num_classes == 4);
-}
-
 TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) {
-  std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(4, 0, false);
+  int64_t num_samples = 5;
+  std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 4, 0, false);
   std::string folder_path = datasets_root_path_ + "/testPK/data";
   // numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode
-  auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {}, 5)});
+  auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {})});
   tree->Prepare();
   Status rc = tree->Launch();
   int32_t labels[5] = {0, 0, 0, 1, 1};
@@ -460,10 +418,11 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) {
 }
 
 TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding2) {
-  std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(4, 3, false);
+  int64_t num_samples = 12;
+  std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 4, 3, false);
   std::string folder_path = datasets_root_path_ + "/testPK/data";
   // numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode
-  auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {}, 12)});
+  auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {})});
   tree->Prepare();
   Status rc = tree->Launch();
   uint32_t labels[11] = {0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3};
tests/ut/cpp/dataset/manifest_op_test.cc
浏览文件 @
769ae609
...
...
@@ -23,6 +23,7 @@
#include "dataset/core/client.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/util/de_error.h"
#include "dataset/util/status.h"
...
...
@@ -42,14 +43,13 @@ std::shared_ptr<RepeatOp> Repeat(int repeatCnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
std::shared_ptr<ManifestOp> Manifest(int32_t num_works, int32_t rows, int32_t conns, const std::string &file,
-                                     std::string usage = "train", std::unique_ptr<Sampler> sampler = nullptr,
-                                     std::map<std::string, int32_t> map = {}, uint64_t num_samples = 0, bool decode = false) {
+                                     std::string usage = "train", std::shared_ptr<Sampler> sampler = nullptr,
+                                     std::map<std::string, int32_t> map = {}, bool decode = false) {
  std::shared_ptr<ManifestOp> so;
  ManifestOp::Builder builder;
  Status rc = builder.SetNumWorkers(num_works).SetManifestFile(file).SetRowsPerBuffer(rows)
                  .SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetClassIndex(map).SetDecode(decode)
-                 .SetNumSamples(num_samples).SetUsage(usage).Build(&so);
+                 .SetUsage(usage).Build(&so);
  return so;
}
...
...
@@ -86,7 +86,8 @@ TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) {
TEST_F(MindDataTestManifest, TestSubsetRandomSamplerManifest) {
  std::vector<int64_t> indices({1});
-  std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices);
+  int64_t num_samples = 0;
+  std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices);
  std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
  // Expect 6 samples for label 0 and 1
  auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(sampler))});
...
...
@@ -145,7 +146,10 @@ TEST_F(MindDataTestManifest, MindDataTestManifestClassIndex) {
TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) {
  std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
-  auto tree = Build({Manifest(16, 2, 32, file, "train", nullptr, {}, 1), Repeat(4)});
+  int64_t num_samples = 1;
+  int64_t start_index = 0;
+  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+  auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}), Repeat(4)});
  tree->Prepare();
  Status rc = tree->Launch();
  if (rc.IsError()) {
...
...
@@ -171,7 +175,10 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) {
TEST_F(MindDataTestManifest, MindDataTestManifestEval) {
  std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
-  auto tree = Build({Manifest(16, 2, 32, file, "eval", nullptr, {}, 1)});
+  int64_t num_samples = 1;
+  int64_t start_index = 0;
+  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+  auto tree = Build({Manifest(16, 2, 32, file, "eval", std::move(seq_sampler), {})});
  tree->Prepare();
  Status rc = tree->Launch();
  if (rc.IsError()) {
...
...
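Tests that previously passed num_samples through the Manifest helper now express the cap with a SequentialSampler(num_samples, start_index). A hedged sketch of the user-facing Python equivalent, reusing the manifest file and the SequentialSampler spellings that test_sampler.py exercises later in this diff:

import mindspore.dataset as ds

manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"

# cap at 1 row via the sampler instead of ManifestDataset(..., num_samples=1)
sampler = ds.SequentialSampler(num_samples=1)
data = ds.ManifestDataset(manifest_file, sampler=sampler).repeat(4)
# 1 row per epoch, repeated 4 times -> 4 rows in total

# SequentialSampler also takes a start offset first, e.g. ds.SequentialSampler(2, 3)
# reads rows 2, 3, 4 (see the test_subset_sampler assertions below)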
tests/ut/cpp/dataset/map_op_test.cc    View file @ 769ae609
...
...
@@ -120,9 +120,8 @@ class MindDataTestMapOp : public UT::DatasetOpTesting {
};
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
-                                           bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
-                                           std::map<std::string, int32_t> map = {}, int64_t num_samples = 0,
-                                           bool decode = false);
+                                           bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
+                                           std::map<std::string, int32_t> map = {}, bool decode = false);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
...
...
tests/ut/cpp/dataset/mnist_op_test.cc    View file @ 769ae609
...
...
@@ -53,13 +53,11 @@ Status Create1DTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements,
                      DataType::Type data_type = DataType::DE_UINT32);
std::shared_ptr<MnistOp> CreateMnist(int64_t num_wrks, int64_t rows, int64_t conns, std::string path,
-                                     bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
-                                     int64_t num_samples = 0) {
+                                     bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr) {
  std::shared_ptr<MnistOp> so;
  MnistOp::Builder builder;
  Status rc = builder.SetNumWorkers(num_wrks).SetDir(path).SetRowsPerBuffer(rows)
-                 .SetOpConnectorSize(conns).SetSampler(std::move(sampler))
-                 .SetNumSamples(num_samples).Build(&so);
+                 .SetOpConnectorSize(conns).SetSampler(std::move(sampler)).Build(&so);
  return so;
}
...
...
@@ -74,7 +72,10 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
// appear in this dataset
// Example: python tests/dataset/data/prep_data.py
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
-  auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, nullptr, 10), Repeat(2)});
+  int64_t num_samples = 10;
+  int64_t start_index = 0;
+  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+  auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2)});
  tree->Prepare();
  uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  Status rc = tree->Launch();
...
...
@@ -101,7 +102,10 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) {
  std::string folder_path = datasets_root_path_ + "/testMnistData/";
-  auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, nullptr, 10), Repeat(2), Batch(5)});
+  int64_t num_samples = 10;
+  int64_t start_index = 0;
+  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+  auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2), Batch(5)});
  tree->Prepare();
  uint32_t res[4][5] = {{0, 0, 0, 0, 0},
                        {0, 0, 0, 0, 0},
...
...
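The MNIST hunks repeat the recipe: CreateMnist(..., nullptr, 10) becomes CreateMnist(..., std::move(seq_sampler)) with the 10-row cap carried by the SequentialSampler, and the repeat/batch arithmetic is unchanged. A sketch in Python terms, assuming MnistDataset as the Python counterpart of MnistOp (path taken from the test above):

import mindspore.dataset as ds

sampler = ds.SequentialSampler(num_samples=10)
data = ds.MnistDataset("../data/dataset/testMnistData/", sampler=sampler)
data = data.repeat(2).batch(5)
# 10 rows x 2 repeats, batched by 5 -> 4 batches, matching the res[4][5] check above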
tests/ut/cpp/dataset/stand_alone_samplers_test.cc    View file @ 769ae609
...
...
@@ -43,20 +43,11 @@ class MindDataTestStandAloneSampler : public UT::DatasetOpTesting {
protected:
  class MockStorageOp : public RandomAccessOp {
   public:
-    MockStorageOp(int64_t val) : m_val_(val) {}
-    Status GetNumSamples(int64_t *ptr) const override {
-      (*ptr) = m_val_;
-      return Status::OK();
-    }
-    Status GetNumRowsInDataset(int64_t *ptr) const override {
-      (*ptr) = m_val_;
-      return Status::OK();
-    }
-
-   private:
-    int64_t m_val_;
+    MockStorageOp(int64_t val) {
+      // row count is in base class as protected member
+      // GetNumRowsInDataset does not need an override, the default from base class is fine.
+      num_rows_ = val;
+    }
  };
};
...
...
@@ -73,8 +64,9 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
  MockStorageOp mock(20);
  std::unique_ptr<DataBuffer> db;
  std::shared_ptr<Tensor> tensor;
+  int64_t num_samples = 0;
  for (int i = 0; i < 6; i++) {
-    std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true));
+    std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 3, i % 3, (i < 3 ? false : true));
    sampler->HandshakeRandomAccessOp(&mock);
    sampler->GetNextBuffer(&db);
    db->GetTensor(&tensor, 0, 0);
...
...
@@ -92,7 +84,9 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
  std::shared_ptr<Tensor> label1, label2;
  CreateINT64Tensor(&label1, 3, reinterpret_cast<unsigned char *>(res));
  CreateINT64Tensor(&label2, 2, reinterpret_cast<unsigned char *>(res + 3));
-  std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3);
+  int64_t num_samples = 0;
+  int64_t start_index = 0;
+  std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(num_samples, start_index, 3);
  std::unique_ptr<DataBuffer> db;
  std::shared_ptr<Tensor> tensor;
  sampler->HandshakeRandomAccessOp(&mock);
...
...
tests/ut/cpp/dataset/subset_random_sampler_test.cc    View file @ 769ae609
...
...
@@ -31,26 +31,17 @@ class MindDataTestSubsetRandomSampler : public UT::Common {
public:
  class DummyRandomAccessOp : public RandomAccessOp {
   public:
-    DummyRandomAccessOp(int64_t num_rows) : num_rows_(num_rows) {};
-    Status GetNumSamples(int64_t *num) const {
-      *num = num_rows_;
-      return Status::OK();
-    }
-    Status GetNumRowsInDataset(int64_t *num) const {
-      *num = num_rows_;
-      return Status::OK();
-    }
-
-   private:
-    int64_t num_rows_;
+    DummyRandomAccessOp(int64_t num_rows) {
+      num_rows_ = num_rows;  // base class
+    };
  };
};
TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
  std::vector<int64_t> in({0, 1, 2, 3, 4});
  std::unordered_set<int64_t> in_set(in.begin(), in.end());
-  SubsetRandomSampler sampler(in);
+  int64_t num_samples = 0;
+  SubsetRandomSampler sampler(num_samples, in);
  DummyRandomAccessOp dummyRandomAccessOp(5);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
...
...
@@ -77,8 +68,9 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
  int64_t total_samples = 100000 - 5;
  int64_t samples_per_buffer = 10;
+  int64_t num_samples = 0;
  std::vector<int64_t> input(total_samples, 1);
-  SubsetRandomSampler sampler(input, samples_per_buffer);
+  SubsetRandomSampler sampler(num_samples, input, samples_per_buffer);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
...
...
@@ -109,7 +101,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
  std::vector<int64_t> in({0, 1, 2, 3, 4});
  std::unordered_set<int64_t> in_set(in.begin(), in.end());
-  SubsetRandomSampler sampler(in);
+  int64_t num_samples = 0;
+  SubsetRandomSampler sampler(num_samples, in);
  DummyRandomAccessOp dummyRandomAccessOp(5);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
...
...
tests/ut/cpp/dataset/weighted_random_sampler_test.cc    View file @ 769ae609
...
...
@@ -35,19 +35,11 @@ class MindDataTestWeightedRandomSampler : public UT::Common {
public:
  class DummyRandomAccessOp : public RandomAccessOp {
   public:
-    DummyRandomAccessOp(uint64_t num_rows) : num_rows_(num_rows) {};
-    Status GetNumSamples(int64_t *num) const {
-      *num = num_rows_;
-      return Status::OK();
-    }
-    Status GetNumRowsInDataset(int64_t *num) const {
-      *num = num_rows_;
-      return Status::OK();
-    }
-
-   private:
-    uint64_t num_rows_;
+    DummyRandomAccessOp(uint64_t num_rows) {
+      // row count is in base class as protected member
+      // GetNumRowsInDataset does not need an override, the default from base class is fine.
+      num_rows_ = num_rows;
+    }
  };
};
...
...
@@ -59,7 +51,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
  std::vector<uint64_t> freq(total_samples, 0);
  // create sampler with replacement = true
-  WeightedRandomSampler m_sampler(weights, num_samples, true);
+  WeightedRandomSampler m_sampler(num_samples, weights, true);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
...
...
@@ -89,7 +81,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
  std::vector<uint64_t> freq(total_samples, 0);
  // create sampler with replacement = replacement
-  WeightedRandomSampler m_sampler(weights, num_samples, false);
+  WeightedRandomSampler m_sampler(num_samples, weights, false);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
...
...
@@ -125,7 +117,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
  std::vector<double> weights(total_samples, std::rand() % 100);
  // create sampler with replacement = replacement
-  WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer);
+  WeightedRandomSampler m_sampler(num_samples, weights, true, samples_per_buffer);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
...
...
@@ -161,7 +153,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
  std::vector<uint64_t> freq(total_samples, 0);
  // create sampler with replacement = replacement
-  WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer);
+  WeightedRandomSampler m_sampler(num_samples, weights, false, samples_per_buffer);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
...
...
@@ -202,7 +194,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
  std::vector<uint64_t> freq(total_samples, 0);
  // create sampler with replacement = true
-  WeightedRandomSampler m_sampler(weights, num_samples, true);
+  WeightedRandomSampler m_sampler(num_samples, weights, true);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
...
...
@@ -247,7 +239,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
  std::vector<uint64_t> freq(total_samples, 0);
  // create sampler with replacement = true
-  WeightedRandomSampler m_sampler(weights, num_samples, false);
+  WeightedRandomSampler m_sampler(num_samples, weights, false);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
...
...
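For WeightedRandomSampler the change is purely positional: num_samples moves ahead of the weights vector in the C++ constructor, so every sampler spells its sample count first. In the Python API the count is a keyword; a rough sketch under that assumption (the weight values are illustrative, not from the tests):

import mindspore.dataset as ds

# draw 4 ids, biased toward low indices; replacement=True mirrors the "oneshot replacement" test
sampler = ds.WeightedRandomSampler([0.9, 0.8, 0.1, 0.1, 0.1], num_samples=4, replacement=True)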
tests/ut/python/dataset/test_datasets_imagefolder.py    View file @ 769ae609
...
...
@@ -58,7 +58,7 @@ def test_imagefolder_numsamples():
    assert num_iter == 10

    random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
-    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
+    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
    num_iter = 0
    for item in data1.create_dict_iterator():
...
...
@@ -67,7 +67,7 @@ def test_imagefolder_numsamples():
    assert num_iter == 3

    random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
-    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
+    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
    num_iter = 0
    for item in data1.create_dict_iterator():
...
...
tests/ut/python/dataset/test_datasets_sharding.py    View file @ 769ae609
...
...
@@ -162,8 +162,8 @@ def test_voc_shardings(print_res=False):
    voc_dir = "../data/dataset/testVOC2012"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
-        sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
-        data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_samples=num_samples)
+        sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
+        data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each data is a dictionary
...
...
tests/ut/python/dataset/test_exceptions.py    View file @ 769ae609
...
...
@@ -35,18 +35,13 @@ def test_exception_01():
def test_exception_02():
    """
-    Test multiple exceptions with invalid input
+    Test exceptions with invalid input, and test valid input
    """
    logger.info("test_exception_02")
-    num_samples = 0
-    with pytest.raises(ValueError) as info:
-        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
-    assert "num_samples must be greater than 0" in str(info.value)
-
    num_samples = -1
    with pytest.raises(ValueError) as info:
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
-    assert "num_samples must be greater than 0" in str(info.value)
+    assert "num_samples cannot be less than 0" in str(info.value)
+
+    num_samples = 1
+    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
...
...
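The exception test pins down the new validation contract: num_samples=0 is no longer rejected (it now means "use all samples", as the test_config(num_samples=0, ...) assertion in test_sampler.py below confirms), while negative values fail with "num_samples cannot be less than 0". Roughly, with DATA_DIR and SCHEMA_DIR being the constants defined at the top of that test module:

import mindspore.dataset as ds

# DATA_DIR / SCHEMA_DIR are the test module's constants (paths not shown in this diff)
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=0)  # ok: read everything
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=1)  # ok: one row
# num_samples=-1 -> ValueError("num_samples cannot be less than 0")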
tests/ut/python/dataset/test_generator.py    View file @ 769ae609
...
...
@@ -544,7 +544,7 @@ def test_distributed_sampler():
def test_num_samples():
    source = [(np.array([x]),) for x in range(64)]
    num_samples = 32
-    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples=num_samples)
+    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples))
    ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples)
    ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
...
...
@@ -660,4 +660,6 @@ if __name__ == "__main__":
    test_sequential_sampler()
    test_distributed_sampler()
    test_random_sampler()
+    test_num_samples()
+    test_num_samples_underflow()
    test_schema()
tests/ut/python/dataset/test_sampler.py    View file @ 769ae609
...
...
@@ -28,8 +28,8 @@ def test_sequential_sampler(print_res=False):
    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(num_samples, num_repeats=None):
-        sampler = ds.SequentialSampler()
-        data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
+        sampler = ds.SequentialSampler(num_samples=num_samples)
+        data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
        if num_repeats is not None:
            data1 = data1.repeat(num_repeats)
        res = []
...
...
@@ -43,6 +43,7 @@ def test_sequential_sampler(print_res=False):
    assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2]
    assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2
+    assert test_config(num_samples=0, num_repeats=2) == [0, 1, 2, 3, 4] * 2
    assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2
...
...
@@ -119,8 +120,8 @@ def test_python_sampler():
            return iter([i for i in range(self.dataset_size)])

    class Sp2(ds.Sampler):
-        def __init__(self):
-            super(Sp2, self).__init__()
+        def __init__(self, num_samples=None):
+            super(Sp2, self).__init__(num_samples)
            # at this stage, self.dataset_size and self.num_samples are not yet known
            self.cnt = 0
...
...
@@ -130,8 +131,8 @@ def test_python_sampler():
        def reset(self):
            self.cnt = (self.cnt + 1) % self.dataset_size

-    def test_config(num_samples, num_repeats, sampler):
-        data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
+    def test_config(num_repeats, sampler):
+        data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
        if num_repeats is not None:
            data1 = data1.repeat(num_repeats)
        res = []
...
...
@@ -154,8 +155,8 @@ def test_python_sampler():
            assert data[0] == (np.array(i),)
            i = i - 1

-    assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
-    assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
+    assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
+    assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
    test_generator()
    sp1 = Sp1().create()
...
...
@@ -169,9 +170,8 @@ def test_subset_sampler():
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

-    def test_config(num_samples, start_index, subset_size):
-        _ = num_samples
-        sampler = ds.SubsetSampler(start_index, subset_size)
+    def test_config(start_index, num_samples):
+        sampler = ds.SequentialSampler(start_index, num_samples)
        d = ds.ManifestDataset(manifest_file, sampler=sampler)
        res = []
...
...
@@ -180,19 +180,15 @@ def test_subset_sampler():
        return res

-    with pytest.raises(RuntimeError) as info:
-        test_config(5, 0, 0)
-    assert "subset_size <= 0" in str(info.value)
-
-    assert test_config(5, 0, 1) == [0]
-    assert test_config(5, 0, 2) == [0, 1]
-    assert test_config(5, 0, 3) == [0, 1, 2]
-    assert test_config(5, 0, 4) == [0, 1, 2, 3]
-    assert test_config(5, 0, 5) == [0, 1, 2, 3, 4]
-    assert test_config(5, 1, 1) == [1]
-    assert test_config(5, 2, 3) == [2, 3, 4]
-    assert test_config(5, 3, 2) == [3, 4]
-    assert test_config(5, 4, 1) == [4]
+    assert test_config(0, 1) == [0]
+    assert test_config(0, 2) == [0, 1]
+    assert test_config(0, 3) == [0, 1, 2]
+    assert test_config(0, 4) == [0, 1, 2, 3]
+    assert test_config(0, 5) == [0, 1, 2, 3, 4]
+    assert test_config(1, 1) == [1]
+    assert test_config(2, 3) == [2, 3, 4]
+    assert test_config(3, 2) == [3, 4]
+    assert test_config(4, 1) == [4]

def test_sampler_chain():
...
...
@@ -200,11 +196,11 @@ def test_sampler_chain():
    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(num_shards, shard_id):
-        sampler = ds.DistributedSampler(num_shards, shard_id, False)
+        sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=False, num_samples=5)
        child_sampler = ds.SequentialSampler()
        sampler.add_child(child_sampler)
-        data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler)
+        data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
        res = []
        for item in data1.create_dict_iterator():
...
@@ -234,6 +230,11 @@ def test_add_sampler_invalid_input():
    data1.use_sampler("sampler")
    assert "not an instance of a sampler" in str(info.value)

+    sampler = ds.SequentialSampler()
+    with pytest.raises(ValueError) as info:
+        data2 = ds.ManifestDataset(manifest_file, sampler=sampler, num_samples=20)
+    assert "Conflicting arguments during sampler assignments" in str(info.value)

if __name__ == '__main__':
    test_sequential_sampler(True)
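The last hunk shows the flip side of making num_samples a sampler property: passing both an explicit sampler and num_samples to a dataset is now a conflict. The chained-sampler test above gives the intended spelling; roughly (manifest_file as in the tests above):

import mindspore.dataset as ds

manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"

parent = ds.DistributedSampler(4, 0, shuffle=False, num_samples=5)
parent.add_child(ds.SequentialSampler())
data = ds.ManifestDataset(manifest_file, sampler=parent)  # ok: the cap rides on the sampler

# ds.ManifestDataset(manifest_file, sampler=ds.SequentialSampler(), num_samples=20)
# -> ValueError: "Conflicting arguments during sampler assignments"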
...
...