Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
cf026096
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cf026096
编写于
4月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!183 Mindspore.dataset CPP sampler for GeneratorDataset
Merge pull request !183 from JunhanHu/cpp_sampler
上级
ff464bbc
9739d3b0
变更
31
隐藏空白更改
内联
并排
Showing
31 changed file
with
432 addition
and
127 deletion
+432
-127
mindspore/ccsrc/dataset/api/de_pipeline.cc
mindspore/ccsrc/dataset/api/de_pipeline.cc
+1
-1
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+10
-1
mindspore/ccsrc/dataset/core/tensor.cc
mindspore/ccsrc/dataset/core/tensor.cc
+2
-0
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
...spore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
...ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
...ore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
...t/engine/datasetops/source/sampler/distributed_sampler.cc
+3
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h
...et/engine/datasetops/source/sampler/distributed_sampler.h
+2
-4
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
...rc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
+9
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h
...src/dataset/engine/datasetops/source/sampler/pk_sampler.h
+4
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc
...ataset/engine/datasetops/source/sampler/random_sampler.cc
+2
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h
...dataset/engine/datasetops/source/sampler/random_sampler.h
+2
-4
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
...ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
+47
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
.../ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
+15
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
...et/engine/datasetops/source/sampler/sequential_sampler.cc
+1
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h
...set/engine/datasetops/source/sampler/sequential_sampler.h
+2
-4
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
...engine/datasetops/source/sampler/subset_random_sampler.cc
+2
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h
.../engine/datasetops/source/sampler/subset_random_sampler.h
+1
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
...gine/datasetops/source/sampler/weighted_random_sampler.cc
+16
-16
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
...ngine/datasetops/source/sampler/weighted_random_sampler.h
+4
-1
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
+1
-1
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+159
-35
mindspore/dataset/engine/samplers.py
mindspore/dataset/engine/samplers.py
+0
-1
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+32
-12
tests/ut/cpp/dataset/stand_alone_samplers_test.cc
tests/ut/cpp/dataset/stand_alone_samplers_test.cc
+2
-2
tests/ut/cpp/dataset/subset_random_sampler_test.cc
tests/ut/cpp/dataset/subset_random_sampler_test.cc
+6
-6
tests/ut/cpp/dataset/weighted_random_sampler_test.cc
tests/ut/cpp/dataset/weighted_random_sampler_test.cc
+12
-12
tests/ut/python/dataset/test_generator.py
tests/ut/python/dataset/test_generator.py
+71
-0
tests/ut/python/dataset/test_sampler.py
tests/ut/python/dataset/test_sampler.py
+21
-0
未找到文件。
mindspore/ccsrc/dataset/api/de_pipeline.cc
浏览文件 @
cf026096
...
...
@@ -517,7 +517,7 @@ Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<Datase
std
::
string
key
=
py
::
str
(
arg
.
first
);
py
::
handle
value
=
arg
.
second
;
if
(
!
value
.
is_none
())
{
if
(
key
==
"
generator_function
"
)
{
if
(
key
==
"
source
"
)
{
py
::
object
obj
=
py
::
cast
(
&
value
);
if
(
!
py
::
isinstance
<
py
::
function
>
(
obj
))
{
std
::
string
err_msg
=
"Error: generator is invalid or not set."
;
...
...
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
cf026096
...
...
@@ -388,7 +388,16 @@ void bindTensorOps4(py::module *m) {
}
void
bindSamplerOps
(
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
Sampler
,
std
::
shared_ptr
<
Sampler
>>
(
*
m
,
"Sampler"
);
(
void
)
py
::
class_
<
Sampler
,
std
::
shared_ptr
<
Sampler
>>
(
*
m
,
"Sampler"
)
.
def
(
"set_num_rows"
,
[](
Sampler
&
self
,
int64_t
rows
)
{
THROW_IF_ERROR
(
self
.
SetNumRowsInDataset
(
rows
));
})
.
def
(
"set_num_samples"
,
[](
Sampler
&
self
,
int64_t
samples
)
{
THROW_IF_ERROR
(
self
.
SetNumSamples
(
samples
));
})
.
def
(
"initialize"
,
[](
Sampler
&
self
)
{
THROW_IF_ERROR
(
self
.
InitSampler
());
})
.
def
(
"get_indices"
,
[](
Sampler
&
self
)
{
py
::
array
ret
;
THROW_IF_ERROR
(
self
.
GetAllIdsThenReset
(
&
ret
));
return
ret
;
});
(
void
)
py
::
class_
<
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>>
(
*
m
,
"ShardOperator"
);
(
void
)
py
::
class_
<
DistributedSampler
,
Sampler
,
std
::
shared_ptr
<
DistributedSampler
>>
(
*
m
,
"DistributedSampler"
)
...
...
mindspore/ccsrc/dataset/core/tensor.cc
浏览文件 @
cf026096
...
...
@@ -491,6 +491,8 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
// return data as numpy, should return status
Status
Tensor
::
GetDataAsNumpy
(
py
::
array
*
data
)
{
RETURN_UNEXPECTED_IF_NULL
(
data_
);
RETURN_UNEXPECTED_IF_NULL
(
data
);
if
(
type_
==
DataType
::
DE_BOOL
)
{
*
data
=
py
::
array_t
<
bool
>
(
shape_
.
AsVector
(),
reinterpret_cast
<
bool
*>
(
data_
));
}
else
if
(
type_
==
DataType
::
DE_INT8
)
{
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
浏览文件 @
cf026096
...
...
@@ -100,7 +100,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() {
RETURN_IF_NOT_OK
(
tree_
->
LaunchWorkers
(
num_workers_
,
std
::
bind
(
&
CelebAOp
::
WorkerEntry
,
this
,
std
::
placeholders
::
_1
)));
TaskManager
::
FindMe
()
->
Post
();
RETURN_IF_NOT_OK
(
ParseImageAttrInfo
());
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
浏览文件 @
cf026096
...
...
@@ -240,7 +240,7 @@ Status CifarOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status
CifarOp
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
浏览文件 @
cf026096
...
...
@@ -258,7 +258,7 @@ Status ImageFolderOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status
ImageFolderOp
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
浏览文件 @
cf026096
...
...
@@ -254,7 +254,7 @@ Status ManifestOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status
ManifestOp
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
浏览文件 @
cf026096
...
...
@@ -205,7 +205,7 @@ Status MnistOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status
MnistOp
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
浏览文件 @
cf026096
...
...
@@ -31,8 +31,9 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu
num_devices_
(
num_dev
),
shuffle_
(
shuffle
)
{}
Status
DistributedSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
RETURN_IF_NOT_OK
(
Sampler
::
Init
(
op
));
Status
DistributedSampler
::
InitSampler
()
{
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
,
"num_samples <= 0
\n
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
num_rows_
>
0
,
"num_rows <= 0
\n
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
device_id_
<
num_devices_
&&
device_id_
>=
0
&&
num_rows_
>
0
&&
num_samples_
>
0
,
"fail to init DistributedSampler"
);
rnd_
.
seed
(
seed_
++
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h
浏览文件 @
cf026096
...
...
@@ -41,10 +41,8 @@ class DistributedSampler : public Sampler {
// @return - The error code return
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
// first handshake between StorageOp and Sampler
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
// @return
Status
Init
(
const
RandomAccessOp
*
)
override
;
// Init sampler, called by base class or python
Status
InitSampler
()
override
;
// for next epoch of sampleIds
// @return - The error code return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
浏览文件 @
cf026096
...
...
@@ -28,9 +28,7 @@ PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer)
num_pk_samples_
(
0
),
samples_per_class_
(
val
)
{}
Status
PKSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
RETURN_UNEXPECTED_IF_NULL
(
op
);
RETURN_IF_NOT_OK
(
op
->
GetClassIds
(
&
label_to_ids_
));
Status
PKSampler
::
InitSampler
()
{
labels_
.
reserve
(
label_to_ids_
.
size
());
for
(
const
auto
&
pair
:
label_to_ids_
)
{
if
(
pair
.
second
.
empty
()
==
false
)
{
...
...
@@ -79,5 +77,13 @@ Status PKSampler::Reset() {
rnd_
.
seed
(
seed_
++
);
return
Status
::
OK
();
}
Status
PKSampler
::
HandshakeRandomAccessOp
(
const
RandomAccessOp
*
op
)
{
RETURN_UNEXPECTED_IF_NULL
(
op
);
RETURN_IF_NOT_OK
(
op
->
GetClassIds
(
&
label_to_ids_
));
RETURN_IF_NOT_OK
(
InitSampler
());
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h
浏览文件 @
cf026096
...
...
@@ -45,7 +45,10 @@ class PKSampler : public Sampler { // NOT YET FINISHED
// first handshake between StorageOp and Sampler
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
// @return
Status
Init
(
const
RandomAccessOp
*
op
)
override
;
Status
HandshakeRandomAccessOp
(
const
RandomAccessOp
*
op
)
override
;
// init sampler, to be called by python or Handshake
Status
InitSampler
()
override
;
// for next epoch of sampleIds
// @return - The error code return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc
浏览文件 @
cf026096
...
...
@@ -49,10 +49,9 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
return
Status
::
OK
();
}
Status
RandomSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
RETURN_IF_NOT_OK
(
Sampler
::
Init
(
op
));
Status
RandomSampler
::
InitSampler
()
{
num_samples_
=
(
user_num_samples_
<
num_samples_
)
?
user_num_samples_
:
num_samples_
;
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
num_rows_
>
0
,
"
Fail to init RandomSampler
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
num_rows_
>
0
,
"
both num_samples & num_rows need to be positive
"
);
samples_per_buffer_
=
samples_per_buffer_
>
num_samples_
?
num_samples_
:
samples_per_buffer_
;
if
(
replacement_
==
false
)
{
shuffled_ids_
.
reserve
(
num_rows_
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h
浏览文件 @
cf026096
...
...
@@ -42,10 +42,8 @@ class RandomSampler : public Sampler {
// @return - The error code return
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
// first handshake between StorageOp and Sampler
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
// @return
Status
Init
(
const
RandomAccessOp
*
op
)
override
;
// meant to be called by base class or python
Status
InitSampler
()
override
;
// for next epoch of sampleIds
// @return - The error code return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
浏览文件 @
cf026096
...
...
@@ -20,12 +20,13 @@ namespace dataset {
Sampler
::
Sampler
(
int64_t
samples_per_buffer
)
:
DatasetOp
(
0
),
num_rows_
(
0
),
num_samples_
(
0
),
samples_per_buffer_
(
samples_per_buffer
),
col_desc_
(
nullptr
)
{}
Status
Sampler
::
Init
(
const
RandomAccessOp
*
op
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
op
!=
nullptr
&&
samples_per_buffer_
>
0
,
"Fail to init Sampler()
\n
"
);
Status
Sampler
::
HandshakeRandomAccessOp
(
const
RandomAccessOp
*
op
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
op
!=
nullptr
,
"RandomAccessOp is nullptr
\n
"
);
RETURN_IF_NOT_OK
(
op
->
GetNumSamples
(
&
num_samples_
));
RETURN_IF_NOT_OK
(
op
->
GetNumRowsInDataset
(
&
num_rows_
));
// It's up to the derived class to check the validity of the two args
// Because some sampler only needs one of the arg (weighted_random_sampler)
RETURN_IF_NOT_OK
(
InitSampler
());
// init sampler after callback
return
Status
::
OK
();
}
...
...
@@ -42,5 +43,49 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
(
void
)(
*
sample_ids
)
->
StartAddr
();
// allocate memory in case user forgets!
return
Status
::
OK
();
}
Status
Sampler
::
GetAllIdsThenReset
(
py
::
array
*
data
)
{
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
shared_ptr
<
Tensor
>
sample_ids
;
// check samples_per_buffer is properly set and doesn't overflow
CHECK_FAIL_RETURN_UNEXPECTED
(
samples_per_buffer_
+
1
>
1
,
"samples_per_buffer invalid"
);
// A call to derived class to get sample ids wrapped inside a buffer
RETURN_IF_NOT_OK
(
GetNextBuffer
(
&
db
));
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
RETURN_IF_NOT_OK
(
db
->
GetTensor
(
&
sample_ids
,
0
,
0
));
// check this buffer is not a ctrl buffer
CHECK_FAIL_RETURN_UNEXPECTED
(
db
->
buffer_flags
()
==
DataBuffer
::
kDeBFlagNone
,
"ERROR ctrl buffer received"
);
{
py
::
gil_scoped_acquire
gil_acquire
;
if
(
Py_IsInitialized
()
==
0
)
{
return
Status
(
StatusCode
::
kPythonInterpreterFailure
,
"Python Interpreter is finalized"
);
}
try
{
RETURN_IF_NOT_OK
(
sample_ids
->
GetDataAsNumpy
(
data
));
}
catch
(
const
std
::
runtime_error
&
e
)
{
return
Status
(
StatusCode
::
kPyFuncException
,
e
.
what
());
}
}
// perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch
RETURN_IF_NOT_OK
(
GetNextBuffer
(
&
db
));
CHECK_FAIL_RETURN_UNEXPECTED
(
db
->
eoe
(),
"ERROR Non EOE received"
);
// Reset Sampler since this is the end of the epoch
RETURN_IF_NOT_OK
(
Reset
());
return
Status
::
OK
();
}
Status
Sampler
::
SetNumSamples
(
int64_t
num_samples
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples
>
0
,
"num_samples is negative or 0"
);
num_samples_
=
num_samples
;
return
Status
::
OK
();
}
Status
Sampler
::
SetNumRowsInDataset
(
int64_t
num_rows
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
num_rows
>
0
,
"num_rows is negative or 0"
);
num_rows_
=
num_rows
;
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
浏览文件 @
cf026096
...
...
@@ -78,14 +78,26 @@ class Sampler : public DatasetOp {
// @return - The error code return
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
=
0
;
// return all ids in one epoch as a numpy array, then call reset
Status
GetAllIdsThenReset
(
py
::
array
*
data
);
// for next epoch of sampleIds
// @return - The error code return
Status
Reset
()
override
=
0
;
// first handshake between StorageOp and Sampler. Base class init will call both GetNumRows and GetNumSamples
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
// setter function for num_rows_
Status
SetNumRowsInDataset
(
int64_t
num_rows
);
// setter function for num_samples_
Status
SetNumSamples
(
int64_t
num_samples
);
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
// @return
virtual
Status
Init
(
const
RandomAccessOp
*
op
);
virtual
Status
HandshakeRandomAccessOp
(
const
RandomAccessOp
*
op
);
// initialize sampler and perform checks on certain vars
virtual
Status
InitSampler
()
{
return
Status
::
OK
();
}
// Not meant to be called
// @return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
浏览文件 @
cf026096
...
...
@@ -41,9 +41,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
return
Status
::
OK
();
}
Status
SequentialSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
RETURN_UNEXPECTED_IF_NULL
(
op
);
RETURN_IF_NOT_OK
(
op
->
GetNumSamples
(
&
num_samples_
));
Status
SequentialSampler
::
InitSampler
()
{
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
samples_per_buffer_
>
0
,
"Fail to init Sequential Sampler"
);
samples_per_buffer_
=
samples_per_buffer_
>
num_samples_
?
num_samples_
:
samples_per_buffer_
;
return
Status
::
OK
();
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h
浏览文件 @
cf026096
...
...
@@ -32,10 +32,8 @@ class SequentialSampler : public Sampler {
// Destructor.
~
SequentialSampler
()
=
default
;
// Initialize the sampler.
// @param op
// @return Status
Status
Init
(
const
RandomAccessOp
*
op
)
override
;
// init sampler, called by python
Status
InitSampler
()
override
;
// for next epoch of sampleIds
// @return - The error code return
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
浏览文件 @
cf026096
...
...
@@ -31,9 +31,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
:
Sampler
(
samples_per_buffer
),
indices_
(
indices
),
sample_id_
(
0
),
buffer_id_
(
0
)
{}
// Initialize this Sampler.
Status
SubsetRandomSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
// Calling base class init.
RETURN_IF_NOT_OK
(
Sampler
::
Init
(
op
));
Status
SubsetRandomSampler
::
InitSampler
()
{
CHECK_FAIL_RETURN_UNEXPECTED
(
num_rows_
>
0
,
"num_rows <= 0
\n
"
);
// Initialize random generator with seed from config manager
rand_gen_
.
seed
(
GetSeed
());
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h
浏览文件 @
cf026096
...
...
@@ -38,9 +38,8 @@ class SubsetRandomSampler : public Sampler {
~
SubsetRandomSampler
()
=
default
;
// Initialize the sampler.
// @param op (Not used in this sampler)
// @return Status
Status
Init
(
const
RandomAccessOp
*
op
)
override
;
Status
Init
Sampler
(
)
override
;
// Reset the internal variable to the initial state and reshuffle the indices.
// @return Status
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
浏览文件 @
cf026096
...
...
@@ -29,21 +29,21 @@ namespace dataset {
// Constructor.
WeightedRandomSampler
::
WeightedRandomSampler
(
const
std
::
vector
<
double
>
&
weights
,
int64_t
num_samples
,
bool
replacement
,
int64_t
samples_per_buffer
)
:
Sampler
(
samples_per_buffer
),
weights_
(
weights
),
replacement_
(
replacement
),
sample_id_
(
0
),
buffer_id_
(
0
)
{
num_samples_
=
num_samples
;
// this variable is defined in base class sampler
}
:
Sampler
(
samples_per_buffer
),
weights_
(
weights
),
replacement_
(
replacement
),
sample_id_
(
0
),
buffer_id_
(
0
),
user_num_samples_
(
num_samples
)
{}
// Initialize this Sampler.
Status
WeightedRandomSampler
::
Init
(
const
RandomAccessOp
*
op
)
{
RETURN_UNEXPECTED_IF_NULL
(
op
);
RETURN_IF_NOT_OK
(
op
->
GetNumRowsInDataset
(
&
num_rows_
));
Status
WeightedRandomSampler
::
InitSampler
()
{
CHECK_FAIL_RETURN_UNEXPECTED
(
num_rows_
>
0
&&
user_num_samples_
,
"num_samples & num_rows need to be positive"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
samples_per_buffer_
>
0
,
"samples_per_buffer<=0
\n
"
);
// Initialize random generator with seed from config manager
rand_gen_
.
seed
(
GetSeed
());
samples_per_buffer_
=
(
samples_per_buffer_
>
num_samples_
)
?
num_samples_
:
samples_per_buffer_
;
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
samples_per_buffer_
>
0
,
"Fail to init WeightedRandomSampler"
);
samples_per_buffer_
=
(
samples_per_buffer_
>
user_num_samples_
)
?
user_num_samples_
:
samples_per_buffer_
;
if
(
!
replacement_
)
{
exp_dist_
=
std
::
make_unique
<
std
::
exponential_distribution
<>>
(
1
);
...
...
@@ -65,8 +65,8 @@ void WeightedRandomSampler::InitOnePassSampling() {
}
// Partial sort the first `numSamples` elements.
std
::
partial_sort
(
val_idx
.
begin
(),
val_idx
.
begin
()
+
num_samples_
,
val_idx
.
end
());
for
(
int64_t
i
=
0
;
i
<
num_samples_
;
i
++
)
{
std
::
partial_sort
(
val_idx
.
begin
(),
val_idx
.
begin
()
+
user_
num_samples_
,
val_idx
.
end
());
for
(
int64_t
i
=
0
;
i
<
user_
num_samples_
;
i
++
)
{
onepass_ids_
.
push_back
(
val_idx
[
i
].
second
);
}
}
...
...
@@ -91,11 +91,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors"
);
}
if
(
!
replacement_
&&
(
weights_
.
size
()
<
static_cast
<
size_t
>
(
num_samples_
)))
{
if
(
!
replacement_
&&
(
weights_
.
size
()
<
static_cast
<
size_t
>
(
user_
num_samples_
)))
{
RETURN_STATUS_UNEXPECTED
(
"Without replacement, sample weights less than numSamples"
);
}
if
(
sample_id_
==
num_samples_
)
{
if
(
sample_id_
==
user_
num_samples_
)
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagEOE
);
}
else
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagNone
);
...
...
@@ -103,8 +103,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
int64_t
last_id
=
sample_id_
+
samples_per_buffer_
;
// Handling the return all samples at once, and when last draw is not a full batch.
if
(
last_id
>
num_samples_
)
{
last_id
=
num_samples_
;
if
(
last_id
>
user_
num_samples_
)
{
last_id
=
user_
num_samples_
;
}
// Allocate tensor.
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
浏览文件 @
cf026096
...
...
@@ -43,7 +43,7 @@ class WeightedRandomSampler : public Sampler {
// Initialize the sampler.
// @param op (Not used in this sampler)
// @return Status
Status
Init
(
const
RandomAccessOp
*
op
)
override
;
Status
Init
Sampler
(
)
override
;
// Reset the internal variable to the initial state and reshuffle the indices.
Status
Reset
()
override
;
...
...
@@ -69,6 +69,9 @@ class WeightedRandomSampler : public Sampler {
// Random engine and device
std
::
mt19937
rand_gen_
;
// num_samples from user
int64_t
user_num_samples_
;
// Discrete distribution for generating weighted random numbers with replacement.
std
::
unique_ptr
<
std
::
discrete_distribution
<
int64_t
>>
discrete_dist_
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
浏览文件 @
cf026096
...
...
@@ -220,7 +220,7 @@ Status VOCOp::ParseImageIds() {
}
Status
VOCOp
::
InitSampler
()
{
RETURN_IF_NOT_OK
(
sampler_
->
Init
(
this
));
RETURN_IF_NOT_OK
(
sampler_
->
HandshakeRandomAccessOp
(
this
));
return
Status
::
OK
();
}
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
cf026096
...
...
@@ -1758,14 +1758,70 @@ class MindDataset(SourceDataset):
return
num_rows
def
ds_fn
(
dataset
):
for
val
in
dataset
:
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
def
_iter_fn
(
dataset
,
num_samples
):
"""
Generator function wrapper for iterable dataset
"""
if
num_samples
is
not
None
:
ds_iter
=
iter
(
dataset
)
for
_
in
range
(
num_samples
):
try
:
val
=
next
(
ds_iter
)
except
StopIteration
:
return
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
else
:
for
val
in
dataset
:
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
def
_generator_fn
(
generator
,
num_samples
):
"""
Generator function wrapper for generator function dataset
"""
if
num_samples
is
not
None
:
gen_iter
=
generator
()
for
_
in
range
(
num_samples
):
try
:
val
=
next
(
gen_iter
)
except
StopIteration
:
return
yield
val
else
:
gen_iter
=
generator
()
for
val
in
gen_iter
:
yield
val
def
sampler_fn
(
sampler
,
dataset
):
for
i
in
sampler
:
def
_py_sampler_fn
(
sampler
,
num_samples
,
dataset
):
"""
Generator function wrapper for mappable dataset with python sampler
"""
if
num_samples
is
not
None
:
sampler_iter
=
iter
(
sampler
)
for
_
in
range
(
num_samples
):
try
:
idx
=
next
(
sampler_iter
)
except
StopIteration
:
return
val
=
dataset
[
idx
]
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
else
:
for
i
in
sampler
:
val
=
dataset
[
i
]
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
def
_cpp_sampler_fn
(
sampler
,
dataset
):
"""
Generator function wrapper for mappable dataset with cpp sampler
"""
indices
=
sampler
.
get_indices
()
for
i
in
indices
:
val
=
dataset
[
i
]
# convert output tensors to ndarrays
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
...
...
@@ -1773,49 +1829,122 @@ def sampler_fn(sampler, dataset):
class
GeneratorDataset
(
SourceDataset
):
"""
A source dataset that generate data from calling generator function each epoch.
A source dataset that generates data from Python by invoking a Python data source each epoch.
This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
below shows which input arguments are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
:widths: 25 25 50
:header-rows: 1
* - Parameter 'sampler'
- Parameter 'shuffle'
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Args:
generator_function (callable):
A callable object that returns an Generator object that supports the iter() protocol.
Generator object is required to return a tuple of numpy array as a row of the dataset on next().
source (Callable/Iterable/Random Accessible):
A generator callable object, an iterable Python object or a random accessible Python object.
Callable source is required to return a tuple of numpy arrays as a row of the dataset on source().next().
Iterable source is required to return a tuple of numpy arrays as a row of the dataset on iter(source).next().
Random accessible source is required to return a tuple of numpy arrays as a row of the dataset on
source[idx].
column_names (list[str]): List of column names of the dataset.
column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
If provided, sanity check will be performed on generator output.
prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).
sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
schema (Schema/String, optional): Path to the json schema file or schema object (default=None).
If the schema is not provided, the meta data from column_names and column_types is considered the schema.
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, all samples).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
(default=None, expected order behavior shown in the table).
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
required.
(default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required.
Examples:
>>> import mindspore.data
set as ds
>>> # 1)
generator function that generates multi-dimensional data
>>> import mindspore.data
engine as de
>>> # 1)
Multidimensional generator function as callable input
>>> def generator_md():
>>> for i in range(64):
>>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
>>> # create multi_dimension_generator_dataset with GeneratorMD
()
and column name "multi_dimensional_data"
>>> multi_dimension_generator_dataset = d
s
.GeneratorDataset(generator_md, ["multi_dimensional_data"])
>>> # 2)
generator function that generates multi-columns data
>>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data"
>>> multi_dimension_generator_dataset = d
e
.GeneratorDataset(generator_md, ["multi_dimensional_data"])
>>> # 2)
Multi-column generator function as callable input
>>> def generator_mc(maxid = 64):
>>> for i in range(maxid):
>>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
>>> # create multi_column_generator_dataset with GeneratorMC() and column names "col1" and "col2"
>>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1", "col2"])
>>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2"
>>> multi_column_generator_dataset = de.GeneratorDataset(generator_mc, ["col1", "col2"])
>>> # 3) Iterable dataset as iterable input
>>> class MyIterable():
>>> def __iter__(self):
>>> return # User implementation
>>> # create iterable_generator_dataset with MyIterable object
>>> iterable_generator_dataset = de.GeneratorDataset(MyIterable(), ["col1"])
>>> # 4) Random accessible dataset as Random accessible input
>>> class MyRA():
>>> def __getitem__(self, index):
>>> return # User implementation
>>> # create ra_generator_dataset with MyRA object
>>> ra_generator_dataset = de.GeneratorDataset(MyRA(), ["col1"])
>>> # List/Dict/Tuple is also random accessible
>>> list_generator = de.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col1"])
>>> # 5) Built-in Sampler
>>> my_generator = de.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
>>>
"""
@
check_generatordataset
def
__init__
(
self
,
generator_function
,
column_names
,
column_types
=
None
,
prefetch_size
=
None
,
sampler
=
None
):
super
().
__init__
(
1
)
if
sampler
is
not
None
:
self
.
generator_function
=
(
lambda
:
sampler_fn
(
sampler
,
generator_function
))
def
__init__
(
self
,
source
,
column_names
,
column_types
=
None
,
schema
=
None
,
num_samples
=
None
,
num_parallel_workers
=
1
,
shuffle
=
None
,
sampler
=
None
,
num_shards
=
None
,
shard_id
=
None
):
super
().
__init__
(
num_parallel_workers
)
self
.
sampler
=
_select_sampler
(
num_samples
,
sampler
,
shuffle
,
num_shards
,
shard_id
)
if
self
.
sampler
is
not
None
and
hasattr
(
source
,
"__getitem__"
):
if
isinstance
(
self
.
sampler
,
(
samplers
.
SequentialSampler
,
samplers
.
DistributedSampler
,
samplers
.
RandomSampler
,
samplers
.
SubsetRandomSampler
,
samplers
.
WeightedRandomSampler
)):
if
num_samples
is
None
:
num_samples
=
len
(
source
)
sampler_instance
=
self
.
sampler
.
create
()
sampler_instance
.
set_num_rows
(
len
(
source
))
sampler_instance
.
set_num_samples
(
num_samples
)
sampler_instance
.
initialize
()
self
.
source
=
(
lambda
:
_cpp_sampler_fn
(
sampler_instance
,
source
))
else
:
self
.
source
=
(
lambda
:
_py_sampler_fn
(
self
.
sampler
,
num_samples
,
source
))
else
:
try
:
# test to see if generator_function is iterable
iter
(
generator_function
)
iter
(
source
)
except
TypeError
:
#
generator_function was not iterable, assume it is a function
self
.
generator_function
=
generator_function
#
Use generator function if input callable
self
.
source
=
(
lambda
:
_generator_fn
(
source
,
num_samples
))
else
:
# generator_function was iterable, build a function around it
self
.
generator_function
=
(
lambda
:
ds_fn
(
generator_function
))
# Use iterator function if input is iterable
# Random accessible input is also iterable
self
.
source
=
(
lambda
:
_iter_fn
(
source
,
num_samples
))
self
.
column_names
=
column_names
...
...
@@ -1823,17 +1952,12 @@ class GeneratorDataset(SourceDataset):
self
.
column_types
=
mstypelist_to_detypelist
(
column_types
)
else
:
self
.
column_types
=
column_types
self
.
distribution
=
""
self
.
prefetch_size
=
prefetch_size
self
.
sampler
=
sampler
def
get_args
(
self
):
args
=
super
().
get_args
()
args
[
"
generator_function"
]
=
self
.
generator_function
args
[
"
source"
]
=
self
.
source
args
[
"column_names"
]
=
self
.
column_names
args
[
"column_types"
]
=
self
.
column_types
args
[
"prefetch_size"
]
=
self
.
prefetch_size
args
[
"sampler"
]
=
self
.
sampler
return
args
def
get_dataset_size
(
self
):
...
...
mindspore/dataset/engine/samplers.py
浏览文件 @
cf026096
...
...
@@ -20,7 +20,6 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
import
mindspore._c_dataengine
as
cde
class
DistributedSampler
():
"""
Sampler that access a shard of the dataset.
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
cf026096
...
...
@@ -543,28 +543,48 @@ def check_generatordataset(method):
def
new_method
(
*
args
,
**
kwargs
):
param_dict
=
make_param_dict
(
method
,
args
,
kwargs
)
nreq_param_int
=
[
'prefetch_size'
]
nreq_param_list
=
[
'column_names'
,
'column_types'
]
# check generator_function; required argument
generator_function
=
param_dict
.
get
(
'generator_function'
)
if
generator_function
is
None
:
raise
ValueError
(
"generator_function is not provided."
)
source
=
param_dict
.
get
(
'source'
)
if
source
is
None
:
raise
ValueError
(
"source is not provided."
)
if
not
callable
(
source
):
try
:
iter
(
source
)
except
TypeError
:
raise
TypeError
(
"source should be callable, iterable or random accessible"
)
# check column_names; required argument
column_names
=
param_dict
.
get
(
'column_names'
)
if
column_names
is
None
:
raise
ValueError
(
"column_names is not provided."
)
# check prefetch_size range
prefetch_size
=
param_dict
.
get
(
'prefetch_size'
)
if
prefetch_size
is
not
None
and
(
prefetch_size
<=
0
or
prefetch_size
>
1024
):
raise
ValueError
(
"prefetch_size exceeds the boundary."
)
# check optional argument
nreq_param_int
=
[
"num_samples"
,
"num_parallel_workers"
,
"num_shards"
,
"shard_id"
]
check_param_type
(
nreq_param_int
,
param_dict
,
int
)
nreq_param_list
=
[
"column_types"
]
check_param_type
(
nreq_param_list
,
param_dict
,
list
)
num_shards
=
param_dict
.
get
(
"num_shards"
)
shard_id
=
param_dict
.
get
(
"shard_id"
)
if
(
num_shards
is
None
)
!=
(
shard_id
is
None
):
# These two parameters appear together.
raise
ValueError
(
"num_shards and shard_id need to be passed in together"
)
if
num_shards
is
not
None
:
if
shard_id
>=
num_shards
:
raise
ValueError
(
"shard_id should be less than num_shards"
)
sampler
=
param_dict
.
get
(
"sampler"
)
if
sampler
is
not
None
:
if
isinstance
(
sampler
,
samplers
.
PKSampler
):
raise
ValueError
(
"PKSampler is not supported by GeneratorDataset"
)
if
not
isinstance
(
sampler
,
(
samplers
.
SequentialSampler
,
samplers
.
DistributedSampler
,
samplers
.
RandomSampler
,
samplers
.
SubsetRandomSampler
,
samplers
.
WeightedRandomSampler
)):
try
:
iter
(
sampler
)
except
TypeError
:
raise
TypeError
(
"sampler should be either iterable or from dataset.samplers.py"
)
return
method
(
*
args
,
**
kwargs
)
return
new_method
...
...
tests/ut/cpp/dataset/stand_alone_samplers_test.cc
浏览文件 @
cf026096
...
...
@@ -75,7 +75,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
std
::
shared_ptr
<
Tensor
>
tensor
;
for
(
int
i
=
0
;
i
<
6
;
i
++
)
{
std
::
unique_ptr
<
Sampler
>
sampler
=
std
::
make_unique
<
DistributedSampler
>
(
3
,
i
%
3
,
(
i
<
3
?
false
:
true
));
sampler
->
Init
(
&
mock
);
sampler
->
HandshakeRandomAccessOp
(
&
mock
);
sampler
->
GetNextBuffer
(
&
db
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
MS_LOG
(
DEBUG
)
<<
(
*
tensor
);
...
...
@@ -95,7 +95,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
std
::
shared_ptr
<
Sampler
>
sampler
=
std
::
make_shared
<
SequentialSampler
>
(
3
);
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
shared_ptr
<
Tensor
>
tensor
;
sampler
->
Init
(
&
mock
);
sampler
->
HandshakeRandomAccessOp
(
&
mock
);
sampler
->
GetNextBuffer
(
&
db
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
EXPECT_TRUE
((
*
tensor
)
==
(
*
label1
));
...
...
tests/ut/cpp/dataset/subset_random_sampler_test.cc
浏览文件 @
cf026096
...
...
@@ -52,8 +52,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
std
::
unordered_set
<
int64_t
>
in_set
(
in
.
begin
(),
in
.
end
());
SubsetRandomSampler
sampler
(
in
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
5
);
sampler
.
Init
(
&
dummy_random_access_o
p
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
5
);
sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
...
...
@@ -80,8 +80,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
std
::
vector
<
int64_t
>
input
(
total_samples
,
1
);
SubsetRandomSampler
sampler
(
input
,
samples_per_buffer
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
sampler
.
Init
(
&
dummy_random_access_o
p
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
...
...
@@ -111,8 +111,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
std
::
unordered_set
<
int64_t
>
in_set
(
in
.
begin
(),
in
.
end
());
SubsetRandomSampler
sampler
(
in
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
5
);
sampler
.
Init
(
&
dummy_random_access_o
p
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
5
);
sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
...
...
tests/ut/cpp/dataset/weighted_random_sampler_test.cc
浏览文件 @
cf026096
...
...
@@ -60,8 +60,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
// create sampler with replacement = true
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
true
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
...
...
@@ -90,8 +90,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
// create sampler with replacement = replacement
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
false
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
...
...
@@ -126,8 +126,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
// create sampler with replacement = replacement
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
true
,
samples_per_buffer
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
...
...
@@ -162,8 +162,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
// create sampler with replacement = replacement
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
false
,
samples_per_buffer
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
...
...
@@ -203,8 +203,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
// create sampler with replacement = true
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
true
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
...
...
@@ -248,8 +248,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
// create sampler with replacement = false
WeightedRandomSampler
m_sampler
(
weights
,
num_samples
,
false
);
DummyRandomAccessOp
dummy
_random_access_o
p
(
total_samples
);
m_sampler
.
Init
(
&
dummy_random_access_o
p
);
DummyRandomAccessOp
dummy
RandomAccessO
p
(
total_samples
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessO
p
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
...
...
tests/ut/python/dataset/test_generator.py
浏览文件 @
cf026096
...
...
@@ -439,6 +439,74 @@ def test_case_error_4():
assert
"Unexpected error. Result of a tensorOp doesn't match output column names"
in
str
(
info
.
value
)
def test_sequential_sampler():
    """Check that a SequentialSampler yields the rows of a list source in order.

    The source holds 64 single-element arrays ([0], [1], ...); every row read
    back through the dict iterator must equal the array for its position.
    """
    rows = [(np.array([value]),) for value in range(64)]
    dataset = ds.GeneratorDataset(rows, ["data"], sampler=ds.SequentialSampler())
    # Each item produced by the iterator is a dictionary keyed by column name.
    for position, item in enumerate(dataset.create_dict_iterator()):
        expected = np.array([position])
        assert np.array_equal(item["data"], expected)
def test_random_sampler():
    """Smoke test: a list source with shuffle=True can be iterated to the end.

    No value is checked; the test only drains the dict iterator.
    """
    rows = [(np.array([value]),) for value in range(64)]
    shuffled = ds.GeneratorDataset(rows, ["data"], shuffle=True)
    # Each item produced by the iterator is a dictionary; nothing is inspected.
    for _ in shuffled.create_dict_iterator():
        pass
def test_distributed_sampler():
    """Check the rows seen by each of 8 shards over a 64-row list source.

    With shuffle disabled, shard `sid` is expected to read the values
    sid, sid + 8, sid + 16, ... in that order.
    """
    rows = [(np.array([value]),) for value in range(64)]
    shard_count = 8
    for sid in range(shard_count):
        sharded = ds.GeneratorDataset(rows, ["data"], shuffle=False, num_shards=8, shard_id=sid)
        # Each item produced by the iterator is a dictionary keyed by column name.
        for step, item in enumerate(sharded.create_dict_iterator()):
            expected = np.array([sid + step * shard_count])
            assert np.array_equal(item["data"], expected)
def test_num_samples():
    """Check that num_samples caps the number of rows for three kinds of input.

    Covers a built-in sampler, a plain Python list used as the sampler, and a
    generator function; each dataset must yield exactly num_samples rows.
    """
    rows = [(np.array([value]),) for value in range(64)]
    num_samples = 32
    # All datasets are constructed up front, before any of them is iterated.
    datasets = [
        ds.GeneratorDataset(rows, ["data"], sampler=ds.SequentialSampler(), num_samples=num_samples),
        ds.GeneratorDataset(rows, ["data"], sampler=list(range(32)), num_samples=num_samples),
        ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples),
    ]
    for dataset in datasets:
        produced = sum(1 for _ in dataset.create_dict_iterator())
        assert produced == num_samples
def test_num_samples_underflow():
    """Check the row count when num_samples is larger than the data available.

    num_samples is 256, but both datasets are expected to yield only 64 rows.
    """
    rows = [(np.array([value]),) for value in range(64)]
    num_samples = 256
    # Both datasets are constructed up front, before either is iterated.
    datasets = [
        ds.GeneratorDataset(rows, ["data"], sampler=list(range(64)), num_samples=num_samples),
        ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples),
    ]
    for dataset in datasets:
        produced = sum(1 for _ in dataset.create_dict_iterator())
        assert produced == 64
if
__name__
==
"__main__"
:
test_case_0
()
test_case_1
()
...
...
@@ -458,3 +526,6 @@ if __name__ == "__main__":
test_case_error_2
()
test_case_error_3
()
test_case_error_4
()
test_sequential_sampler
()
test_distributed_sampler
()
test_random_sampler
()
tests/ut/python/dataset/test_sampler.py
浏览文件 @
cf026096
...
...
@@ -87,7 +87,28 @@ def test_random_sampler_multi_iter(print_res=False):
test_config
(
replacement
=
True
,
num_samples
=
5
,
num_repeats
=
5
,
validate
=
[
0
,
1
,
2
,
3
,
4
,
5
])
def test_sampler_py_api():
    """Smoke test of the sampler methods called from Python.

    For a SequentialSampler, a RandomSampler and a DistributedSampler(8, 4),
    runs the same call sequence on the object returned by create():
    set_num_rows(128), set_num_samples(64), initialize(), get_indices().
    No return value is checked.
    """
    # Factories rather than instances, so each sampler is built right before use.
    factories = (
        ds.SequentialSampler,
        ds.RandomSampler,
        lambda: ds.DistributedSampler(8, 4),
    )
    for build in factories:
        sampler = build().create()
        sampler.set_num_rows(128)
        sampler.set_num_samples(64)
        sampler.initialize()
        sampler.get_indices()
if
__name__
==
'__main__'
:
test_sequential_sampler
(
True
)
test_random_sampler
(
True
)
test_random_sampler_multi_iter
(
True
)
test_sampler_py_api
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录