Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9739d3b0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9739d3b0
编写于
3月 29, 2020
作者:
J
Junhan Hu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add CPP sampler support for GeneratorDataset
上级
9a781025
变更
31
隐藏空白更改
内联
并排
Showing
31 changed file
with
432 addition
and
127 deletion
+432
-127
mindspore/ccsrc/dataset/api/de_pipeline.cc
mindspore/ccsrc/dataset/api/de_pipeline.cc
+1
-1
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+10
-1
mindspore/ccsrc/dataset/core/tensor.cc
mindspore/ccsrc/dataset/core/tensor.cc
+2
-0
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
...spore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
...ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
...ore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
...t/engine/datasetops/source/sampler/distributed_sampler.cc
+3
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h
...et/engine/datasetops/source/sampler/distributed_sampler.h
+2
-4
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
...rc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
+9
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h
...src/dataset/engine/datasetops/source/sampler/pk_sampler.h
+4
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc
...ataset/engine/datasetops/source/sampler/random_sampler.cc
+2
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h
...dataset/engine/datasetops/source/sampler/random_sampler.h
+2
-4
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
...ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
+47
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
.../ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
+15
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
...et/engine/datasetops/source/sampler/sequential_sampler.cc
+1
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h
...set/engine/datasetops/source/sampler/sequential_sampler.h
+2
-4
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
...engine/datasetops/source/sampler/subset_random_sampler.cc
+2
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h
.../engine/datasetops/source/sampler/subset_random_sampler.h
+1
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
...gine/datasetops/source/sampler/weighted_random_sampler.cc
+16
-16
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
...ngine/datasetops/source/sampler/weighted_random_sampler.h
+4
-1
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
+1
-1
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+159
-35
mindspore/dataset/engine/samplers.py
mindspore/dataset/engine/samplers.py
+0
-1
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+32
-12
tests/ut/cpp/dataset/stand_alone_samplers_test.cc
tests/ut/cpp/dataset/stand_alone_samplers_test.cc
+2
-2
tests/ut/cpp/dataset/subset_random_sampler_test.cc
tests/ut/cpp/dataset/subset_random_sampler_test.cc
+6
-6
tests/ut/cpp/dataset/weighted_random_sampler_test.cc
tests/ut/cpp/dataset/weighted_random_sampler_test.cc
+12
-12
tests/ut/python/dataset/test_generator.py
tests/ut/python/dataset/test_generator.py
+71
-0
tests/ut/python/dataset/test_sampler.py
tests/ut/python/dataset/test_sampler.py
+21
-0
未找到文件。
mindspore/ccsrc/dataset/api/de_pipeline.cc
浏览文件 @
9739d3b0
...
@@ -517,7 +517,7 @@ Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<Datase
...
@@ -517,7 +517,7 @@ Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<Datase
std
::
string
key
=
py
::
str
(
arg
.
first
);
std
::
string
key
=
py
::
str
(
arg
.
first
);
py
::
handle
value
=
arg
.
second
;
py
::
handle
value
=
arg
.
second
;
if
(
!
value
.
is_none
())
{
if
(
!
value
.
is_none
())
{
if
(
key
==
"
generator_function
"
)
{
if
(
key
==
"
source
"
)
{
py
::
object
obj
=
py
::
cast
(
&
value
);
py
::
object
obj
=
py
::
cast
(
&
value
);
if
(
!
py
::
isinstance
<
py
::
function
>
(
obj
))
{
if
(
!
py
::
isinstance
<
py
::
function
>
(
obj
))
{
std
::
string
err_msg
=
"Error: generator is invalid or not set."
;
std
::
string
err_msg
=
"Error: generator is invalid or not set."
;
...
...
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
9739d3b0
...
@@ -384,7 +384,16 @@ void bindTensorOps4(py::module *m) {
...
@@ -384,7 +384,16 @@ void bindTensorOps4(py::module *m) {
}
}
void
bindSamplerOps
(
py
::
module
*
m
)
{
void
bindSamplerOps
(
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
Sampler
,
std
::
shared_ptr
<
Sampler
>>
(
*
m
,
"Sampler"
);
(
void
)
py
::
class_
<
Sampler
,
std
::
shared_ptr
<
Sampler
>>
(
*
m
,
"Sampler"
)
.
def
(
"set_num_rows"
,
[](
Sampler
&
self
,
int64_t
rows
)
{
THROW_IF_ERROR
(
self
.
SetNumRowsInDataset
(
rows
));
})
.
def
(
"set_num_samples"
,
[](
Sampler
&
self
,
int64_t
samples
)
{
THROW_IF_ERROR
(
self
.
SetNumSamples
(
samples
));
})
.
def
(
"initialize"
,
[](
Sampler
&
self
)
{
THROW_IF_ERROR
(
self
.
InitSampler
());
})
.
def
(
"get_indices"
,
[](
Sampler
&
self
)
{
py
::
array
ret
;
THROW_IF_ERROR
(
self
.
GetAllIdsThenReset
(
&
ret
));
return
ret
;
});
(
void
)
py
::
class_
<
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>>
(
*
m
,
"ShardOperator"
);
(
void
)
py
::
class_
<
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>>
(
*
m
,
"ShardOperator"
);
(
void
)
py
::
class_
<
DistributedSampler
,
Sampler
,
std
::
shared_ptr
<
DistributedSampler
>>
(
*
m
,
"DistributedSampler"
)
(
void
)
py
::
class_
<
DistributedSampler
,
Sampler
,
std
::
shared_ptr
<
DistributedSampler
>>
(
*
m
,
"DistributedSampler"
)
...
...
mindspore/ccsrc/dataset/core/tensor.cc
浏览文件 @
9739d3b0
...
@@ -491,6 +491,8 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
...
@@ -491,6 +491,8 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
// return data as numpy, should return status
// return data as numpy, should return status
Status
Tensor
::
GetDataAsNumpy
(
py
::
array
*
data
)
{
Status
Tensor
::
GetDataAsNumpy
(
py
::
array
*
data
)
{
RETURN_UNEXPECTED_IF_NULL
(
data_
);
RETURN_UNEXPECTED_IF_NULL
(
data
);
if
(
type_
==
DataType
::
DE_BOOL
)
{
if
(
type_
==
DataType
::
DE_BOOL
)
{
*
data
=
py
::
array_t
<
bool
>
(
shape_
.
AsVector
(),
reinterpret_cast
<
bool
*>
(
data_
));
*
data
=
py
::
array_t
<
bool
>
(
shape_
.
AsVector
(),
reinterpret_cast
<
bool
*>
(
data_
));
}
else
if
(
type_
==
DataType
::
DE_INT8
)
{
}
else
if
(
type_
==
DataType
::
DE_INT8
)
{
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
浏览文件 @
9739d3b0
...
@@ -100,7 +100,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() {
...
@@ -100,7 +100,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() {
RETURN_IF_NOT_OK
(
tree_
->
LaunchWorkers
(
num_workers_
,
std
::
bind
(
&
CelebAOp
::
WorkerEntry
,
this
,
std
::
placeholders
::
_1
)));
RETURN_IF_NOT_OK
(
tree_
->
LaunchWorkers
(
num_workers_
,
std
::
bind
(
&
CelebAOp
::
WorkerEntry
,
this
,
std
::
placeholders
::
_1
)));
TaskManager
::
FindMe
()
->
Post
();
TaskManager
::
FindMe
()
->
Post
();
RETURN_IF_NOT_OK
(
ParseImageAttrInfo
());
RETURN_IF_NOT_OK
(
ParseImageAttrInfo
());
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
浏览文件 @
9739d3b0
...
@@ -240,7 +240,7 @@ Status CifarOp::Reset() {
...
@@ -240,7 +240,7 @@ Status CifarOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status
CifarOp
::
InitSampler
()
{
Status
CifarOp
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
浏览文件 @
9739d3b0
...
@@ -258,7 +258,7 @@ Status ImageFolderOp::Reset() {
...
@@ -258,7 +258,7 @@ Status ImageFolderOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status
ImageFolderOp
::
InitSampler
()
{
Status
ImageFolderOp
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
浏览文件 @
9739d3b0
...
@@ -254,7 +254,7 @@ Status ManifestOp::Reset() {
...
@@ -254,7 +254,7 @@ Status ManifestOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status
ManifestOp
::
InitSampler
()
{
Status
ManifestOp
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
浏览文件 @
9739d3b0
...
@@ -205,7 +205,7 @@ Status MnistOp::Reset() {
...
@@ -205,7 +205,7 @@ Status MnistOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status
MnistOp
::
InitSampler
()
{
Status
MnistOp
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
浏览文件 @
9739d3b0
...
@@ -31,8 +31,9 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu
...
@@ -31,8 +31,9 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu
num_devices_
(
num_dev
),
num_devices_
(
num_dev
),
shuffle_
(
shuffle
)
{}
shuffle_
(
shuffle
)
{}
Status
DistributedSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
Status
DistributedSampler
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
Sampler
::
Init
(
op
));
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
,
"num_samples <= 0
\n
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
num_rows_
>
0
,
"num_rows <= 0
\n
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
device_id_
<
num_devices_
&&
device_id_
>=
0
&&
num_rows_
>
0
&&
num_samples_
>
0
,
CHECK_FAIL_RETURN_UNEXPECTED
(
device_id_
<
num_devices_
&&
device_id_
>=
0
&&
num_rows_
>
0
&&
num_samples_
>
0
,
"fail to init DistributedSampler"
);
"fail to init DistributedSampler"
);
rnd_
.
seed
(
seed_
++
);
rnd_
.
seed
(
seed_
++
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h
浏览文件 @
9739d3b0
...
@@ -41,10 +41,8 @@ class DistributedSampler : public Sampler {
...
@@ -41,10 +41,8 @@ class DistributedSampler : public Sampler {
// @return - The error code return
// @return - The error code return
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
// first handshake between StorageOp and Sampler
// Init sampler, called by base class or python
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
Status
InitSampler
()
override
;
// @return
Status
Init
(
const
RandomAccessOp
*
)
override
;
// for next epoch of sampleIds
// for next epoch of sampleIds
// @return - The error code return
// @return - The error code return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
浏览文件 @
9739d3b0
...
@@ -28,9 +28,7 @@ PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer)
...
@@ -28,9 +28,7 @@ PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer)
num_pk_samples_
(
0
),
num_pk_samples_
(
0
),
samples_per_class_
(
val
)
{}
samples_per_class_
(
val
)
{}
Status
PKSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
Status
PKSampler
::
InitSampler
()
{
RETURN_UNEXPECTED_IF_NULL
(
op
);
RETURN_IF_NOT_OK
(
op
->
GetClassIds
(
&
label_to_ids_
));
labels_
.
reserve
(
label_to_ids_
.
size
());
labels_
.
reserve
(
label_to_ids_
.
size
());
for
(
const
auto
&
pair
:
label_to_ids_
)
{
for
(
const
auto
&
pair
:
label_to_ids_
)
{
if
(
pair
.
second
.
empty
()
==
false
)
{
if
(
pair
.
second
.
empty
()
==
false
)
{
...
@@ -79,5 +77,13 @@ Status PKSampler::Reset() {
...
@@ -79,5 +77,13 @@ Status PKSampler::Reset() {
rnd_
.
seed
(
seed_
++
);
rnd_
.
seed
(
seed_
++
);
return
Status
::
OK
();
return
Status
::
OK
();
}
}
Status
PKSampler
::
HandshakeRandomAccessOp
(
const
RandomAccessOp
*
op
)
{
RETURN_UNEXPECTED_IF_NULL
(
op
);
RETURN_IF_NOT_OK
(
op
->
GetClassIds
(
&
label_to_ids_
));
RETURN_IF_NOT_OK
(
InitSampler
());
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace dataset
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h
浏览文件 @
9739d3b0
...
@@ -45,7 +45,10 @@ class PKSampler : public Sampler { // NOT YET FINISHED
...
@@ -45,7 +45,10 @@ class PKSampler : public Sampler { // NOT YET FINISHED
// first handshake between StorageOp and Sampler
// first handshake between StorageOp and Sampler
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
// @return
// @return
Status
Init
(
const
RandomAccessOp
*
op
)
override
;
Status
HandshakeRandomAccessOp
(
const
RandomAccessOp
*
op
)
override
;
// init sampler, to be called by python or Handshake
Status
InitSampler
()
override
;
// for next epoch of sampleIds
// for next epoch of sampleIds
// @return - The error code return
// @return - The error code return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc
浏览文件 @
9739d3b0
...
@@ -49,10 +49,9 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
...
@@ -49,10 +49,9 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
return
Status
::
OK
();
return
Status
::
OK
();
}
}
Status
RandomSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
Status
RandomSampler
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
Sampler
::
Init
(
op
));
num_samples_
=
(
user_num_samples_
<
num_samples_
)
?
user_num_samples_
:
num_samples_
;
num_samples_
=
(
user_num_samples_
<
num_samples_
)
?
user_num_samples_
:
num_samples_
;
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
num_rows_
>
0
,
"
Fail to init RandomSampler
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
num_rows_
>
0
,
"
both num_samples & num_rows need to be positive
"
);
samples_per_buffer_
=
samples_per_buffer_
>
num_samples_
?
num_samples_
:
samples_per_buffer_
;
samples_per_buffer_
=
samples_per_buffer_
>
num_samples_
?
num_samples_
:
samples_per_buffer_
;
if
(
replacement_
==
false
)
{
if
(
replacement_
==
false
)
{
shuffled_ids_
.
reserve
(
num_rows_
);
shuffled_ids_
.
reserve
(
num_rows_
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h
浏览文件 @
9739d3b0
...
@@ -42,10 +42,8 @@ class RandomSampler : public Sampler {
...
@@ -42,10 +42,8 @@ class RandomSampler : public Sampler {
// @return - The error code return
// @return - The error code return
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
// first handshake between StorageOp and Sampler
// meant to be called by base class or python
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
Status
InitSampler
()
override
;
// @return
Status
Init
(
const
RandomAccessOp
*
op
)
override
;
// for next epoch of sampleIds
// for next epoch of sampleIds
// @return - The error code return
// @return - The error code return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
浏览文件 @
9739d3b0
...
@@ -20,12 +20,13 @@ namespace dataset {
...
@@ -20,12 +20,13 @@ namespace dataset {
Sampler
::
Sampler
(
int64_t
samples_per_buffer
)
Sampler
::
Sampler
(
int64_t
samples_per_buffer
)
:
DatasetOp
(
0
),
num_rows_
(
0
),
num_samples_
(
0
),
samples_per_buffer_
(
samples_per_buffer
),
col_desc_
(
nullptr
)
{}
:
DatasetOp
(
0
),
num_rows_
(
0
),
num_samples_
(
0
),
samples_per_buffer_
(
samples_per_buffer
),
col_desc_
(
nullptr
)
{}
Status
Sampler
::
Init
(
const
RandomAccessOp
*
op
)
{
Status
Sampler
::
HandshakeRandomAccessOp
(
const
RandomAccessOp
*
op
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
op
!=
nullptr
&&
samples_per_buffer_
>
0
,
"Fail to init Sampler()
\n
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
op
!=
nullptr
,
"RandomAccessOp is nullptr
\n
"
);
RETURN_IF_NOT_OK
(
op
->
GetNumSamples
(
&
num_samples_
));
RETURN_IF_NOT_OK
(
op
->
GetNumSamples
(
&
num_samples_
));
RETURN_IF_NOT_OK
(
op
->
GetNumRowsInDataset
(
&
num_rows_
));
RETURN_IF_NOT_OK
(
op
->
GetNumRowsInDataset
(
&
num_rows_
));
// It's up to the derived class to check the validity of the two args
// It's up to the derived class to check the validity of the two args
// Because some sampler only needs one of the arg (weighted_random_sampler)
// Because some sampler only needs one of the arg (weighted_random_sampler)
RETURN_IF_NOT_OK
(
InitSampler
());
// init sampler after callback
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
@@ -42,5 +43,49 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
...
@@ -42,5 +43,49 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
(
void
)(
*
sample_ids
)
->
StartAddr
();
// allocate memory in case user forgets!
(
void
)(
*
sample_ids
)
->
StartAddr
();
// allocate memory in case user forgets!
return
Status
::
OK
();
return
Status
::
OK
();
}
}
Status
Sampler
::
GetAllIdsThenReset
(
py
::
array
*
data
)
{
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
shared_ptr
<
Tensor
>
sample_ids
;
// check samples_per_buffer is properly set and doesn't overflow
CHECK_FAIL_RETURN_UNEXPECTED
(
samples_per_buffer_
+
1
>
1
,
"samples_per_buffer invalid"
);
// A call to derived class to get sample ids wrapped inside a buffer
RETURN_IF_NOT_OK
(
GetNextBuffer
(
&
db
));
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
RETURN_IF_NOT_OK
(
db
->
GetTensor
(
&
sample_ids
,
0
,
0
));
// check this buffer is not a ctrl buffer
CHECK_FAIL_RETURN_UNEXPECTED
(
db
->
buffer_flags
()
==
DataBuffer
::
kDeBFlagNone
,
"ERROR ctrl buffer received"
);
{
py
::
gil_scoped_acquire
gil_acquire
;
if
(
Py_IsInitialized
()
==
0
)
{
return
Status
(
StatusCode
::
kPythonInterpreterFailure
,
"Python Interpreter is finalized"
);
}
try
{
RETURN_IF_NOT_OK
(
sample_ids
->
GetDataAsNumpy
(
data
));
}
catch
(
const
std
::
runtime_error
&
e
)
{
return
Status
(
StatusCode
::
kPyFuncException
,
e
.
what
());
}
}
// perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch
RETURN_IF_NOT_OK
(
GetNextBuffer
(
&
db
));
CHECK_FAIL_RETURN_UNEXPECTED
(
db
->
eoe
(),
"ERROR Non EOE received"
);
// Reset Sampler since this is the end of the epoch
RETURN_IF_NOT_OK
(
Reset
());
return
Status
::
OK
();
}
Status
Sampler
::
SetNumSamples
(
int64_t
num_samples
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples
>
0
,
"num_samples is negative or 0"
);
num_samples_
=
num_samples
;
return
Status
::
OK
();
}
Status
Sampler
::
SetNumRowsInDataset
(
int64_t
num_rows
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
num_rows
>
0
,
"num_rows is negative or 0"
);
num_rows_
=
num_rows
;
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace dataset
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
浏览文件 @
9739d3b0
...
@@ -78,14 +78,26 @@ class Sampler : public DatasetOp {
...
@@ -78,14 +78,26 @@ class Sampler : public DatasetOp {
// @return - The error code return
// @return - The error code return
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
=
0
;
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
=
0
;
// return all ids in one epoch as a numpy array, then call reset
Status
GetAllIdsThenReset
(
py
::
array
*
data
);
// for next epoch of sampleIds
// for next epoch of sampleIds
// @return - The error code return
// @return - The error code return
Status
Reset
()
override
=
0
;
Status
Reset
()
override
=
0
;
// first handshake between StorageOp and Sampler. Base class init will call both GetNumRows and GetNumSamples
// setter function for num_rows_
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
Status
SetNumRowsInDataset
(
int64_t
num_rows
);
// setter function for num_samples_
Status
SetNumSamples
(
int64_t
num_samples
);
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
// @return
// @return
virtual
Status
Init
(
const
RandomAccessOp
*
op
);
virtual
Status
HandshakeRandomAccessOp
(
const
RandomAccessOp
*
op
);
// initialize sampler and perform checks on certain vars
virtual
Status
InitSampler
()
{
return
Status
::
OK
();
}
// Not meant to be called
// Not meant to be called
// @return
// @return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
浏览文件 @
9739d3b0
...
@@ -41,9 +41,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
...
@@ -41,9 +41,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
return
Status
::
OK
();
return
Status
::
OK
();
}
}
Status
SequentialSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
Status
SequentialSampler
::
InitSampler
()
{
RETURN_UNEXPECTED_IF_NULL
(
op
);
RETURN_IF_NOT_OK
(
op
->
GetNumSamples
(
&
num_samples_
));
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
samples_per_buffer_
>
0
,
"Fail to init Sequential Sampler"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
samples_per_buffer_
>
0
,
"Fail to init Sequential Sampler"
);
samples_per_buffer_
=
samples_per_buffer_
>
num_samples_
?
num_samples_
:
samples_per_buffer_
;
samples_per_buffer_
=
samples_per_buffer_
>
num_samples_
?
num_samples_
:
samples_per_buffer_
;
return
Status
::
OK
();
return
Status
::
OK
();
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h
浏览文件 @
9739d3b0
...
@@ -32,10 +32,8 @@ class SequentialSampler : public Sampler {
...
@@ -32,10 +32,8 @@ class SequentialSampler : public Sampler {
// Destructor.
// Destructor.
~
SequentialSampler
()
=
default
;
~
SequentialSampler
()
=
default
;
// Initialize the sampler.
// init sampler, called by python
// @param op
Status
InitSampler
()
override
;
// @return Status
Status
Init
(
const
RandomAccessOp
*
op
)
override
;
// for next epoch of sampleIds
// for next epoch of sampleIds
// @return - The error code return
// @return - The error code return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
浏览文件 @
9739d3b0
...
@@ -31,9 +31,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
...
@@ -31,9 +31,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
:
Sampler
(
samples_per_buffer
),
indices_
(
indices
),
sample_id_
(
0
),
buffer_id_
(
0
)
{}
:
Sampler
(
samples_per_buffer
),
indices_
(
indices
),
sample_id_
(
0
),
buffer_id_
(
0
)
{}
// Initialized this Sampler.
// Initialized this Sampler.
Status
SubsetRandomSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
Status
SubsetRandomSampler
::
InitSampler
()
{
// Calling base class init.
CHECK_FAIL_RETURN_UNEXPECTED
(
num_rows_
>
0
,
"num_rows <= 0
\n
"
);
RETURN_IF_NOT_OK
(
Sampler
::
Init
(
op
));
// Initialize random generator with seed from config manager
// Initialize random generator with seed from config manager
rand_gen_
.
seed
(
GetSeed
());
rand_gen_
.
seed
(
GetSeed
());
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h
浏览文件 @
9739d3b0
...
@@ -38,9 +38,8 @@ class SubsetRandomSampler : public Sampler {
...
@@ -38,9 +38,8 @@ class SubsetRandomSampler : public Sampler {
~
SubsetRandomSampler
()
=
default
;
~
SubsetRandomSampler
()
=
default
;
// Initialize the sampler.
// Initialize the sampler.
// @param op (Not used in this sampler)
// @return Status
// @return Status
Status
Init
(
const
RandomAccessOp
*
op
)
override
;
Status
Init
Sampler
(
)
override
;
// Reset the internal variable to the initial state and reshuffle the indices.
// Reset the internal variable to the initial state and reshuffle the indices.
// @return Status
// @return Status
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
浏览文件 @
9739d3b0
...
@@ -29,21 +29,21 @@ namespace dataset {
...
@@ -29,21 +29,21 @@ namespace dataset {
// Constructor.
// Constructor.
WeightedRandomSampler
::
WeightedRandomSampler
(
const
std
::
vector
<
double
>
&
weights
,
int64_t
num_samples
,
bool
replacement
,
WeightedRandomSampler
::
WeightedRandomSampler
(
const
std
::
vector
<
double
>
&
weights
,
int64_t
num_samples
,
bool
replacement
,
int64_t
samples_per_buffer
)
int64_t
samples_per_buffer
)
:
Sampler
(
samples_per_buffer
),
weights_
(
weights
),
replacement_
(
replacement
),
sample_id_
(
0
),
buffer_id_
(
0
)
{
:
Sampler
(
samples_per_buffer
),
num_samples_
=
num_samples
;
// this variable is defined in base class sampler
weights_
(
weights
),
}
replacement_
(
replacement
),
sample_id_
(
0
),
buffer_id_
(
0
),
user_num_samples_
(
num_samples
)
{}
// Initialized this Sampler.
// Initialized this Sampler.
Status
WeightedRandomSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
Status
WeightedRandomSampler
::
InitSampler
()
{
RETURN_UNEXPECTED_IF_NULL
(
op
);
CHECK_FAIL_RETURN_UNEXPECTED
(
num_rows_
>
0
&&
user_num_samples_
,
"num_samples & num_rows need to be positive"
);
RETURN_IF_NOT_OK
(
op
->
GetNumRowsInDataset
(
&
num_rows_
));
CHECK_FAIL_RETURN_UNEXPECTED
(
samples_per_buffer_
>
0
,
"samples_per_buffer<=0
\n
"
);
// Initialize random generator with seed from config manager
// Initialize random generator with seed from config manager
rand_gen_
.
seed
(
GetSeed
());
rand_gen_
.
seed
(
GetSeed
());
samples_per_buffer_
=
(
samples_per_buffer_
>
num_samples_
)
?
num_samples_
:
samples_per_buffer_
;
samples_per_buffer_
=
(
samples_per_buffer_
>
user_num_samples_
)
?
user_num_samples_
:
samples_per_buffer_
;
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
samples_per_buffer_
>
0
,
"Fail to init WeightedRandomSampler"
);
if
(
!
replacement_
)
{
if
(
!
replacement_
)
{
exp_dist_
=
std
::
make_unique
<
std
::
exponential_distribution
<>>
(
1
);
exp_dist_
=
std
::
make_unique
<
std
::
exponential_distribution
<>>
(
1
);
...
@@ -65,8 +65,8 @@ void WeightedRandomSampler::InitOnePassSampling() {
...
@@ -65,8 +65,8 @@ void WeightedRandomSampler::InitOnePassSampling() {
}
}
// Partial sort the first `numSamples` elements.
// Partial sort the first `numSamples` elements.
std
::
partial_sort
(
val_idx
.
begin
(),
val_idx
.
begin
()
+
num_samples_
,
val_idx
.
end
());
std
::
partial_sort
(
val_idx
.
begin
(),
val_idx
.
begin
()
+
user_
num_samples_
,
val_idx
.
end
());
for
(
int64_t
i
=
0
;
i
<
num_samples_
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
user_
num_samples_
;
i
++
)
{
onepass_ids_
.
push_back
(
val_idx
[
i
].
second
);
onepass_ids_
.
push_back
(
val_idx
[
i
].
second
);
}
}
}
}
...
@@ -91,11 +91,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
...
@@ -91,11 +91,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors"
);
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors"
);
}
}
if
(
!
replacement_
&&
(
weights_
.
size
()
<
static_cast
<
size_t
>
(
num_samples_
)))
{
if
(
!
replacement_
&&
(
weights_
.
size
()
<
static_cast
<
size_t
>
(
user_
num_samples_
)))
{
RETURN_STATUS_UNEXPECTED
(
"Without replacement, sample weights less than numSamples"
);
RETURN_STATUS_UNEXPECTED
(
"Without replacement, sample weights less than numSamples"
);
}
}
if
(
sample_id_
==
num_samples_
)
{
if
(
sample_id_
==
user_
num_samples_
)
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagEOE
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagEOE
);
}
else
{
}
else
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagNone
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagNone
);
...
@@ -103,8 +103,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
...
@@ -103,8 +103,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
int64_t
last_id
=
sample_id_
+
samples_per_buffer_
;
int64_t
last_id
=
sample_id_
+
samples_per_buffer_
;
// Handling the return all samples at once, and when last draw is not a full batch.
// Handling the return all samples at once, and when last draw is not a full batch.
if
(
last_id
>
num_samples_
)
{
if
(
last_id
>
user_
num_samples_
)
{
last_id
=
num_samples_
;
last_id
=
user_
num_samples_
;
}
}
// Allocate tensor.
// Allocate tensor.
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
浏览文件 @
9739d3b0
...
@@ -43,7 +43,7 @@ class WeightedRandomSampler : public Sampler {
...
@@ -43,7 +43,7 @@ class WeightedRandomSampler : public Sampler {
// Initialize the sampler.
// Initialize the sampler.
// @param op (Not used in this sampler)
// @param op (Not used in this sampler)
// @return Status
// @return Status
Status
Init
(
const
RandomAccessOp
*
op
)
override
;
Status
Init
Sampler
(
)
override
;
// Reset the internal variable to the initial state and reshuffle the indices.
// Reset the internal variable to the initial state and reshuffle the indices.
Status
Reset
()
override
;
Status
Reset
()
override
;
...
@@ -69,6 +69,9 @@ class WeightedRandomSampler : public Sampler {
...
@@ -69,6 +69,9 @@ class WeightedRandomSampler : public Sampler {
// Random engine and device
// Random engine and device
std
::
mt19937
rand_gen_
;
std
::
mt19937
rand_gen_
;
// num_samples from user
int64_t
user_num_samples_
;
// Discrete distribution for generating weighted random numbers with replacement.
// Discrete distribution for generating weighted random numbers with replacement.
std
::
unique_ptr
<
std
::
discrete_distribution
<
int64_t
>>
discrete_dist_
;
std
::
unique_ptr
<
std
::
discrete_distribution
<
int64_t
>>
discrete_dist_
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
浏览文件 @
9739d3b0
...
@@ -220,7 +220,7 @@ Status VOCOp::ParseImageIds() {
...
@@ -220,7 +220,7 @@ Status VOCOp::ParseImageIds() {
}
}
Status
VOCOp
::
InitSampler
()
{
Status
VOCOp
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
9739d3b0
...
@@ -1748,14 +1748,70 @@ class MindDataset(SourceDataset):
...
@@ -1748,14 +1748,70 @@ class MindDataset(SourceDataset):
return
num_rows
return
num_rows
def
ds_fn
(
dataset
):
def
_iter_fn
(
dataset
,
num_samples
):
for
val
in
dataset
:
"""
# convert output tensors to ndarrays
Generator function wrapper for iterable dataset
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
"""
if
num_samples
is
not
None
:
ds_iter
=
iter
(
dataset
)
for
_
in
range
(
num_samples
):
try
:
val
=
next
(
ds_iter
)
except
StopIteration
:
return
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
else
:
for
val
in
dataset
:
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
def
_generator_fn
(
generator
,
num_samples
):
"""
Generator function wrapper for generator function dataset
"""
if
num_samples
is
not
None
:
gen_iter
=
generator
()
for
_
in
range
(
num_samples
):
try
:
val
=
next
(
gen_iter
)
except
StopIteration
:
return
yield
val
else
:
gen_iter
=
generator
()
for
val
in
gen_iter
:
yield
val
def
sampler_fn
(
sampler
,
dataset
):
def
_py_sampler_fn
(
sampler
,
num_samples
,
dataset
):
for
i
in
sampler
:
"""
Generator function wrapper for mappable dataset with python sampler
"""
if
num_samples
is
not
None
:
sampler_iter
=
iter
(
sampler
)
for
_
in
range
(
num_samples
):
try
:
idx
=
next
(
sampler_iter
)
except
StopIteration
:
return
val
=
dataset
[
idx
]
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
else
:
for
i
in
sampler
:
val
=
dataset
[
i
]
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
def
_cpp_sampler_fn
(
sampler
,
dataset
):
"""
Generator function wrapper for mappable dataset with cpp sampler
"""
indices
=
sampler
.
get_indices
()
for
i
in
indices
:
val
=
dataset
[
i
]
val
=
dataset
[
i
]
# convert output tensors to ndarrays
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
...
@@ -1763,49 +1819,122 @@ def sampler_fn(sampler, dataset):
...
@@ -1763,49 +1819,122 @@ def sampler_fn(sampler, dataset):
class
GeneratorDataset
(
SourceDataset
):
class
GeneratorDataset
(
SourceDataset
):
"""
"""
A source dataset that generate data from calling generator function each epoch.
A source dataset that generate data from python by invoking python data source each epoch.
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
below shows what input args are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
:widths: 25 25 50
:header-rows: 1
* - Parameter 'sampler'
- Parameter 'shuffle'
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Args:
Args:
generator_function (callable):
source (Callable/Iterable/Random Accessible):
A callable object that returns an Generator object that supports the iter() protocol.
A generator callable object, an iterable python object or a random accessible python object.
Generator object is required to return a tuple of numpy array as a row of the dataset on next().
Callable source is required to return a tuple of numpy array as a row of the dataset on source().next().
Iterable source is required to return a tuple of numpy array as a row of the dataset on iter(source).next().
Random accessible source is required to return a tuple of numpy array as a row of the dataset on
source[idx].
column_names (list[str]): List of column names of the dataset.
column_names (list[str]): List of column names of the dataset.
column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
If provided, sanity check will be performed on generator output.
If provided, sanity check will be performed on generator output.
prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).
schema (Schema/String, optional): Path to the json schema file or schema object (default=None).
sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
If the schema is not provided, the meta data from column_names and column_types is considered the schema.
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, all images).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
(default=None, expected order behavior shown in the table).
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
required.
(default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required.
Examples:
Examples:
>>> import mindspore.data
set as ds
>>> import mindspore.data
engine as de
>>> # 1)
generator function that generates multi-dimensional data
>>> # 1)
Multidimensional generator function as callable input
>>> def generator_md():
>>> def generator_md():
>>> for i in range(64):
>>> for i in range(64):
>>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
>>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
>>> # create multi_dimension_generator_dataset with GeneratorMD
()
and column name "multi_dimensional_data"
>>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data"
>>> multi_dimension_generator_dataset = d
s
.GeneratorDataset(generator_md, ["multi_dimensional_data"])
>>> multi_dimension_generator_dataset = d
e
.GeneratorDataset(generator_md, ["multi_dimensional_data"])
>>> # 2)
generator function that generates multi-columns data
>>> # 2)
Multi-column generator function as callable input
>>> def generator_mc(maxid = 64):
>>> def generator_mc(maxid = 64):
>>> for i in range(maxid):
>>> for i in range(maxid):
>>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
>>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
>>> # create multi_column_generator_dataset with GeneratorMC() and column names "col1" and "col2"
>>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2"
>>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1, col2"])
>>> multi_column_generator_dataset = de.GeneratorDataset(generator_mc, ["col1, col2"])
>>> # 3) Iterable dataset as iterable input
>>> class MyIterable():
>>> def __iter__(self):
>>> return # User implementation
>>> # create iterable_generator_dataset with MyIterable object
>>> iterable_generator_dataset = de.GeneratorDataset(MyIterable(), ["col1"])
>>> # 4) Random accessible dataset as Random accessible input
>>> class MyRA():
>>> def __getitem__(self, index):
>>> return # User implementation
>>> # create ra_generator_dataset with MyRA object
>>> ra_generator_dataset = de.GeneratorDataset(MyRA(), ["col1"])
>>> # List/Dict/Tuple is also random accessible
>>> list_generator = de.GeneratorDataset([(np.array(0),), (np.array(1)), (np.array(2))], ["col1"])
>>> # 5) Built-in Sampler
>>> my_generator = de.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
>>>
"""
"""
@
check_generatordataset
@
check_generatordataset
def
__init__
(
self
,
generator_function
,
column_names
,
column_types
=
None
,
prefetch_size
=
None
,
sampler
=
None
):
def
__init__
(
self
,
source
,
column_names
,
column_types
=
None
,
schema
=
None
,
num_samples
=
None
,
num_parallel_workers
=
1
,
super
().
__init__
(
1
)
shuffle
=
None
,
sampler
=
None
,
num_shards
=
None
,
shard_id
=
None
):
if
sampler
is
not
None
:
super
().
__init__
(
num_parallel_workers
)
self
.
generator_function
=
(
lambda
:
sampler_fn
(
sampler
,
generator_function
))
self
.
sampler
=
_select_sampler
(
num_samples
,
sampler
,
shuffle
,
num_shards
,
shard_id
)
if
self
.
sampler
is
not
None
and
hasattr
(
source
,
"__getitem__"
):
if
isinstance
(
self
.
sampler
,
(
samplers
.
SequentialSampler
,
samplers
.
DistributedSampler
,
samplers
.
RandomSampler
,
samplers
.
SubsetRandomSampler
,
samplers
.
WeightedRandomSampler
)):
if
num_samples
is
None
:
num_samples
=
len
(
source
)
sampler_instance
=
self
.
sampler
.
create
()
sampler_instance
.
set_num_rows
(
len
(
source
))
sampler_instance
.
set_num_samples
(
num_samples
)
sampler_instance
.
initialize
()
self
.
source
=
(
lambda
:
_cpp_sampler_fn
(
sampler_instance
,
source
))
else
:
self
.
source
=
(
lambda
:
_py_sampler_fn
(
self
.
sampler
,
num_samples
,
source
))
else
:
else
:
try
:
try
:
# test to see if generator_function is iterable
iter
(
source
)
iter
(
generator_function
)
except
TypeError
:
except
TypeError
:
#
generator_function was not iterable, assume it is a function
#
Use generator function if input callable
self
.
generator_function
=
generator_function
self
.
source
=
(
lambda
:
_generator_fn
(
source
,
num_samples
))
else
:
else
:
# generator_function was iterable, build a function around it
# Use iterator function if input is iterable
self
.
generator_function
=
(
lambda
:
ds_fn
(
generator_function
))
# Random accessible input is also iterable
self
.
source
=
(
lambda
:
_iter_fn
(
source
,
num_samples
))
self
.
column_names
=
column_names
self
.
column_names
=
column_names
...
@@ -1813,17 +1942,12 @@ class GeneratorDataset(SourceDataset):
...
@@ -1813,17 +1942,12 @@ class GeneratorDataset(SourceDataset):
self
.
column_types
=
mstypelist_to_detypelist
(
column_types
)
self
.
column_types
=
mstypelist_to_detypelist
(
column_types
)
else
:
else
:
self
.
column_types
=
column_types
self
.
column_types
=
column_types
self
.
distribution
=
""
self
.
prefetch_size
=
prefetch_size
self
.
sampler
=
sampler
def
get_args
(
self
):
def
get_args
(
self
):
args
=
super
().
get_args
()
args
=
super
().
get_args
()
args
[
"
generator_function"
]
=
self
.
generator_function
args
[
"
source"
]
=
self
.
source
args
[
"column_names"
]
=
self
.
column_names
args
[
"column_names"
]
=
self
.
column_names
args
[
"column_types"
]
=
self
.
column_types
args
[
"column_types"
]
=
self
.
column_types
args
[
"prefetch_size"
]
=
self
.
prefetch_size
args
[
"sampler"
]
=
self
.
sampler
return
args
return
args
def
get_dataset_size
(
self
):
def
get_dataset_size
(
self
):
...
...
mindspore/dataset/engine/samplers.py
浏览文件 @
9739d3b0
...
@@ -20,7 +20,6 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
...
@@ -20,7 +20,6 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
import
mindspore._c_dataengine
as
cde
import
mindspore._c_dataengine
as
cde
class
DistributedSampler
():
class
DistributedSampler
():
"""
"""
Sampler that access a shard of the dataset.
Sampler that access a shard of the dataset.
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
9739d3b0
...
@@ -543,28 +543,48 @@ def check_generatordataset(method):
...
@@ -543,28 +543,48 @@ def check_generatordataset(method):
def
new_method
(
*
args
,
**
kwargs
):
def
new_method
(
*
args
,
**
kwargs
):
param_dict
=
make_param_dict
(
method
,
args
,
kwargs
)
param_dict
=
make_param_dict
(
method
,
args
,
kwargs
)
nreq_param_int
=
[
'prefetch_size'
]
nreq_param_list
=
[
'column_names'
,
'column_types'
]
# check generator_function; required argument
# check generator_function; required argument
generator_function
=
param_dict
.
get
(
'generator_function'
)
source
=
param_dict
.
get
(
'source'
)
if
generator_function
is
None
:
if
source
is
None
:
raise
ValueError
(
"generator_function is not provided."
)
raise
ValueError
(
"source is not provided."
)
if
not
callable
(
source
):
try
:
iter
(
source
)
except
TypeError
:
raise
TypeError
(
"source should be callable, iterable or random accessible"
)
# check column_names; required argument
# check column_names; required argument
column_names
=
param_dict
.
get
(
'column_names'
)
column_names
=
param_dict
.
get
(
'column_names'
)
if
column_names
is
None
:
if
column_names
is
None
:
raise
ValueError
(
"column_names is not provided."
)
raise
ValueError
(
"column_names is not provided."
)
# check prefetch_size range
# check optional argument
prefetch_size
=
param_dict
.
get
(
'prefetch_size'
)
nreq_param_int
=
[
"num_samples"
,
"num_parallel_workers"
,
"num_shards"
,
"shard_id"
]
if
prefetch_size
is
not
None
and
(
prefetch_size
<=
0
or
prefetch_size
>
1024
):
raise
ValueError
(
"prefetch_size exceeds the boundary."
)
check_param_type
(
nreq_param_int
,
param_dict
,
int
)
check_param_type
(
nreq_param_int
,
param_dict
,
int
)
nreq_param_list
=
[
"column_types"
]
check_param_type
(
nreq_param_list
,
param_dict
,
list
)
check_param_type
(
nreq_param_list
,
param_dict
,
list
)
num_shards
=
param_dict
.
get
(
"num_shards"
)
shard_id
=
param_dict
.
get
(
"shard_id"
)
if
(
num_shards
is
None
)
!=
(
shard_id
is
None
):
# These two parameters appear together.
raise
ValueError
(
"num_shards and shard_id need to be passed in together"
)
if
num_shards
is
not
None
:
if
shard_id
>=
num_shards
:
raise
ValueError
(
"shard_id should be less than num_shards"
)
sampler
=
param_dict
.
get
(
"sampler"
)
if
sampler
is
not
None
:
if
isinstance
(
sampler
,
samplers
.
PKSampler
):
raise
ValueError
(
"PKSampler is not supported by GeneratorDataset"
)
if
not
isinstance
(
sampler
,
(
samplers
.
SequentialSampler
,
samplers
.
DistributedSampler
,
samplers
.
RandomSampler
,
samplers
.
SubsetRandomSampler
,
samplers
.
WeightedRandomSampler
)):
try
:
iter
(
sampler
)
except
TypeError
:
raise
TypeError
(
"sampler should be either iterable or from dataset.samplers.py"
)
return
method
(
*
args
,
**
kwargs
)
return
method
(
*
args
,
**
kwargs
)
return
new_method
return
new_method
...
...
tests/ut/cpp/dataset/stand_alone_samplers_test.cc
浏览文件 @
9739d3b0
...
@@ -75,7 +75,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
...
@@ -75,7 +75,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
std
::
shared_ptr
<
Tensor
>
tensor
;
std
::
shared_ptr
<
Tensor
>
tensor
;
for
(
int
i
=
0
;
i
<
6
;
i
++
)
{
for
(
int
i
=
0
;
i
<
6
;
i
++
)
{
std
::
unique_ptr
<
Sampler
>
sampler
=
std
::
make_unique
<
DistributedSampler
>
(
3
,
i
%
3
,
(
i
<
3
?
false
:
true
));
std
::
unique_ptr
<
Sampler
>
sampler
=
std
::
make_unique
<
DistributedSampler
>
(
3
,
i
%
3
,
(
i
<
3
?
false
:
true
));
sampler
->
Init
(
&
mock
);
sampler
->
HandshakeRandomAccessOp
(
&
mock
);
sampler
->
GetNextBuffer
(
&
db
);
sampler
->
GetNextBuffer
(
&
db
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
MS_LOG
(
DEBUG
)
<<
(
*
tensor
);
MS_LOG
(
DEBUG
)
<<
(
*
tensor
);
...
@@ -95,7 +95,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
...
@@ -95,7 +95,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
std
::
shared_ptr
<
Sampler
>
sampler
=
std
::
make_shared
<
SequentialSampler
>
(
3
);
std
::
shared_ptr
<
Sampler
>
sampler
=
std
::
make_shared
<
SequentialSampler
>
(
3
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
shared_ptr
<
Tensor
>
tensor
;
std
::
shared_ptr
<
Tensor
>
tensor
;
sampler
->
Init
(
&
mock
);
sampler
->
HandshakeRandomAccessOp
(
&
mock
);
sampler
->
GetNextBuffer
(
&
db
);
sampler
->
GetNextBuffer
(
&
db
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
EXPECT_TRUE
((
*
tensor
)
==
(
*
label1
));
EXPECT_TRUE
((
*
tensor
)
==
(
*
label1
));
...
...
tests/ut/cpp/dataset/subset_random_sampler_test.cc
浏览文件 @
9739d3b0
...
@@ -52,8 +52,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
...
@@ -52,8 +52,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
std
::
unordered_set
<
int64_t
>
in_set
(
in
.
begin
(),
in
.
end
());
std
::
unordered_set
<
int64_t
>
in_set
(
in
.
begin
(),
in
.
end
());
SubsetRandomSampler
sampler
(
in
);
SubsetRandomSampler
sampler
(
in
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
5
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
5
);
sampler
.
Init
(
&
dummy_random_access_o
p
);
sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
...
@@ -80,8 +80,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
...
@@ -80,8 +80,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
std
::
vector
<
int64_t
>
input
(
total_samples
,
1
);
std
::
vector
<
int64_t
>
input
(
total_samples
,
1
);
SubsetRandomSampler
sampler
(
input
,
samples_per_buffer
);
SubsetRandomSampler
sampler
(
input
,
samples_per_buffer
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
sampler
.
Init
(
&
dummy_random_access_o
p
);
sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
...
@@ -111,8 +111,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
...
@@ -111,8 +111,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
std
::
unordered_set
<
int64_t
>
in_set
(
in
.
begin
(),
in
.
end
());
std
::
unordered_set
<
int64_t
>
in_set
(
in
.
begin
(),
in
.
end
());
SubsetRandomSampler
sampler
(
in
);
SubsetRandomSampler
sampler
(
in
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
5
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
5
);
sampler
.
Init
(
&
dummy_random_access_o
p
);
sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
...
...
tests/ut/cpp/dataset/weighted_random_sampler_test.cc
浏览文件 @
9739d3b0
...
@@ -60,8 +60,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
...
@@ -60,8 +60,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
// create sampler with replacement = true
// create sampler with replacement = true
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
true
);
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
true
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
...
@@ -90,8 +90,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
...
@@ -90,8 +90,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
// create sampler with replacement = replacement
// create sampler with replacement = replacement
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
false
);
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
false
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
...
@@ -126,8 +126,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
...
@@ -126,8 +126,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
// create sampler with replacement = replacement
// create sampler with replacement = replacement
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
true
,
samples_per_buffer
);
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
true
,
samples_per_buffer
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
...
@@ -162,8 +162,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
...
@@ -162,8 +162,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
// create sampler with replacement = replacement
// create sampler with replacement = replacement
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
false
,
samples_per_buffer
);
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
false
,
samples_per_buffer
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
...
@@ -203,8 +203,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
...
@@ -203,8 +203,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
// create sampler with replacement = true
// create sampler with replacement = true
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
true
);
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
true
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
...
@@ -248,8 +248,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
...
@@ -248,8 +248,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
// create sampler with replacement = true
// create sampler with replacement = true
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
false
);
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
false
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
...
...
tests/ut/python/dataset/test_generator.py
浏览文件 @
9739d3b0
...
@@ -439,6 +439,74 @@ def test_case_error_4():
...
@@ -439,6 +439,74 @@ def test_case_error_4():
assert
"Unexpected error. Result of a tensorOp doesn't match output column names"
in
str
(
info
.
value
)
assert
"Unexpected error. Result of a tensorOp doesn't match output column names"
in
str
(
info
.
value
)
def
test_sequential_sampler
():
source
=
[(
np
.
array
([
x
]),)
for
x
in
range
(
64
)]
ds1
=
ds
.
GeneratorDataset
(
source
,
[
"data"
],
sampler
=
ds
.
SequentialSampler
())
i
=
0
for
data
in
ds1
.
create_dict_iterator
():
# each data is a dictionary
golden
=
np
.
array
([
i
])
assert
np
.
array_equal
(
data
[
"data"
],
golden
)
i
=
i
+
1
def
test_random_sampler
():
source
=
[(
np
.
array
([
x
]),)
for
x
in
range
(
64
)]
ds1
=
ds
.
GeneratorDataset
(
source
,
[
"data"
],
shuffle
=
True
)
for
data
in
ds1
.
create_dict_iterator
():
# each data is a dictionary
pass
def
test_distributed_sampler
():
source
=
[(
np
.
array
([
x
]),)
for
x
in
range
(
64
)]
for
sid
in
range
(
8
):
ds1
=
ds
.
GeneratorDataset
(
source
,
[
"data"
],
shuffle
=
False
,
num_shards
=
8
,
shard_id
=
sid
)
i
=
sid
for
data
in
ds1
.
create_dict_iterator
():
# each data is a dictionary
golden
=
np
.
array
([
i
])
assert
np
.
array_equal
(
data
[
"data"
],
golden
)
i
=
i
+
8
def
test_num_samples
():
source
=
[(
np
.
array
([
x
]),)
for
x
in
range
(
64
)]
num_samples
=
32
ds1
=
ds
.
GeneratorDataset
(
source
,
[
"data"
],
sampler
=
ds
.
SequentialSampler
(),
num_samples
=
num_samples
)
ds2
=
ds
.
GeneratorDataset
(
source
,
[
"data"
],
sampler
=
[
i
for
i
in
range
(
32
)],
num_samples
=
num_samples
)
ds3
=
ds
.
GeneratorDataset
(
generator_1d
,
[
"data"
],
num_samples
=
num_samples
)
count
=
0
for
_
in
ds1
.
create_dict_iterator
():
count
=
count
+
1
assert
count
==
num_samples
count
=
0
for
_
in
ds2
.
create_dict_iterator
():
count
=
count
+
1
assert
count
==
num_samples
count
=
0
for
_
in
ds3
.
create_dict_iterator
():
count
=
count
+
1
assert
count
==
num_samples
def
test_num_samples_underflow
():
source
=
[(
np
.
array
([
x
]),)
for
x
in
range
(
64
)]
num_samples
=
256
ds2
=
ds
.
GeneratorDataset
(
source
,
[
"data"
],
sampler
=
[
i
for
i
in
range
(
64
)],
num_samples
=
num_samples
)
ds3
=
ds
.
GeneratorDataset
(
generator_1d
,
[
"data"
],
num_samples
=
num_samples
)
count
=
0
for
_
in
ds2
.
create_dict_iterator
():
count
=
count
+
1
assert
count
==
64
count
=
0
for
_
in
ds3
.
create_dict_iterator
():
count
=
count
+
1
assert
count
==
64
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_case_0
()
test_case_0
()
test_case_1
()
test_case_1
()
...
@@ -458,3 +526,6 @@ if __name__ == "__main__":
...
@@ -458,3 +526,6 @@ if __name__ == "__main__":
test_case_error_2
()
test_case_error_2
()
test_case_error_3
()
test_case_error_3
()
test_case_error_4
()
test_case_error_4
()
test_sequential_sampler
()
test_distributed_sampler
()
test_random_sampler
()
tests/ut/python/dataset/test_sampler.py
浏览文件 @
9739d3b0
...
@@ -87,7 +87,28 @@ def test_random_sampler_multi_iter(print_res=False):
...
@@ -87,7 +87,28 @@ def test_random_sampler_multi_iter(print_res=False):
test_config
(
replacement
=
True
,
num_samples
=
5
,
num_repeats
=
5
,
validate
=
[
0
,
1
,
2
,
3
,
4
,
5
])
test_config
(
replacement
=
True
,
num_samples
=
5
,
num_repeats
=
5
,
validate
=
[
0
,
1
,
2
,
3
,
4
,
5
])
def
test_sampler_py_api
():
sampler
=
ds
.
SequentialSampler
().
create
()
sampler
.
set_num_rows
(
128
)
sampler
.
set_num_samples
(
64
)
sampler
.
initialize
()
sampler
.
get_indices
()
sampler
=
ds
.
RandomSampler
().
create
()
sampler
.
set_num_rows
(
128
)
sampler
.
set_num_samples
(
64
)
sampler
.
initialize
()
sampler
.
get_indices
()
sampler
=
ds
.
DistributedSampler
(
8
,
4
).
create
()
sampler
.
set_num_rows
(
128
)
sampler
.
set_num_samples
(
64
)
sampler
.
initialize
()
sampler
.
get_indices
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_sequential_sampler
(
True
)
test_sequential_sampler
(
True
)
test_random_sampler
(
True
)
test_random_sampler
(
True
)
test_random_sampler_multi_iter
(
True
)
test_random_sampler_multi_iter
(
True
)
test_sampler_py_api
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录