magicwindyyd / mindspore
Forked from MindSpore / mindspore (kept in sync with the fork source)
Commit 2e3d55ed
Authored May 22, 2020 by mindspore-ci-bot; committed via Gitee on May 22, 2020
!1281 Implementation of SplitOp
Merge pull request !1281 from Peilin/splitOp
Parents: 39b9aedf, 71e8bb19
Showing 24 changed files with 1507 additions and 46 deletions (+1507 −46).
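Taken together, the commit introduces two user-facing features: Dataset.split() for carving a dataset into subsets, and sampler chaining via add_child(). A minimal usage sketch assembled from the docstrings and tests in this diff (the manifest path is a placeholder):

```python
import mindspore.dataset as ds

# Split a dataset 80/20; sizes may be absolute row counts or fractions.
data = ds.ManifestDataset("path/to/manifest.json", shuffle=False)
train, val = data.split([0.8, 0.2])

# Chain samplers: the DistributedSampler shards the ids its child produces.
sampler = ds.DistributedSampler(2, 0)
sampler.add_child(ds.SequentialSampler())
sharded = ds.ManifestDataset("path/to/manifest.json", sampler=sampler)
```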
mindspore/ccsrc/dataset/api/de_pipeline.cc                                           +12  −0
mindspore/ccsrc/dataset/api/python_bindings.cc                                       +15  −7
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt              +1   −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc      +35  −3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h       +2   −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc               +19  −1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc           +17  −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc           +50  −6
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h            +5   −1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc                  +73  −1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h                   +34  −1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc       +26  −1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h        +2   −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc    +18  −1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc           +85  −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h            +58  −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc  +17  −1
mindspore/dataset/engine/datasets.py                                                 +425 −11
mindspore/dataset/engine/samplers.py                                                 +201 −10
mindspore/dataset/engine/validators.py                                               +41  −0
tests/ut/cpp/dataset/cifar_op_test.cc                                                +1   −1
tests/ut/cpp/dataset/image_folder_op_test.cc                                         +1   −1
tests/ut/python/dataset/test_sampler.py                                              +27  −0
tests/ut/python/dataset/test_split.py                                                +342 −0
mindspore/ccsrc/dataset/api/de_pipeline.cc (+12 −0)

```diff
@@ -364,6 +364,18 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
     std::string err_msg = "Error: Shuffle buffer size is missing";
     RETURN_STATUS_UNEXPECTED(err_msg);
   }
+
+  // Optional arguments
+  for (auto arg : args) {
+    std::string key = py::str(arg.first);
+    py::handle value = arg.second;
+    if (!value.is_none()) {
+      if (key == "reshuffle_each_epoch") {
+        (void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"]));
+      }
+    }
+  }
+
   std::shared_ptr<ShuffleOp> op;
   RETURN_IF_NOT_OK(builder->Build(&op));
   *ptr = op;
```
mindspore/ccsrc/dataset/api/python_bindings.cc (+15 −7)

```diff
@@ -51,6 +51,7 @@
 #include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/python_sampler.h"
@@ -425,11 +426,14 @@ void bindSamplerOps(py::module *m) {
     .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
     .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
     .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
     .def("get_indices", [](Sampler &self) {
       py::array ret;
       THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
       return ret;
-    });
+    })
+    .def("add_child",
+         [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });

 (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
@@ -441,12 +445,16 @@ void bindSamplerOps(py::module *m) {
     .def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
 (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
-    .def(py::init<bool, int64_t>(), py::arg("replacement"), py::arg("numSamples"))
-    .def(py::init<bool>(), py::arg("replacement"));
+    .def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
+         py::arg("num_samples"))
+    .def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
 (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
     .def(py::init<>());
+(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
+    .def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
 (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
     .def(py::init<std::vector<int64_t>>(), py::arg("indices"));
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt (+1 −0)

```diff
@@ -8,5 +8,6 @@ add_library(engine-datasetops-source-sampler OBJECT
     sampler.cc
     sequential_sampler.cc
     subset_random_sampler.cc
+    subset_sampler.cc
     weighted_random_sampler.cc
     )
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc (+35 −3)

```diff
@@ -55,13 +55,27 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
   } else if (cnt_ == samples_per_buffer_) {
     (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
   } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+    }
+
     (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);
     std::shared_ptr<Tensor> sample_ids;
     RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_));
     int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer());
     while (cnt_ < samples_per_buffer_) {
-      int64_t next_id = (num_devices_ * (cnt_++) + device_id_) % num_rows_;
-      *(id_ptr++) = shuffle_ ? shuffle_vec_[static_cast<size_t>(next_id)] : next_id;
+      int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_;
+      if (shuffle_) {
+        sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
+      }
+
+      if (HasChildSampler()) {
+        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+      }
+
+      *id_ptr = sampled_id;
+      id_ptr++;
+      cnt_++;
     }
     TensorRow row(1, sample_ids);
     (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
@@ -72,11 +86,29 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
 Status DistributedSampler::Reset() {
   CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
   cnt_ = 0;
-  rnd_.seed(seed_++);
   if (shuffle_ == true) {
+    rnd_.seed(seed_);
+    seed_++;
     std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
   }
+
+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->Reset());
+  }
+
   return Status::OK();
 }
+
+void DistributedSampler::Print(std::ostream &out, bool show_all) const {
+  out << "(sampler): DistributedSampler\n";
+
+  if (show_all) {
+    out << "seed_: " << seed_ << '\n';
+    out << "device_id_: " << device_id_ << '\n';
+    out << "num_devices_: " << num_devices_ << '\n';
+    out << "shuffle_: " << shuffle_ << '\n';
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
```
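The rewritten loop above is the pattern this commit applies to every sampler: compute the sampler's own id, optionally remap it through shuffle_vec_, then, if a child sampler is attached, treat it as an index into the child's output via GetAssociatedChildId(). A pure-Python mirror of that loop, for illustration only (the helper name and arguments are mine, not the C++ API):

```python
def distributed_ids(num_devices, device_id, num_rows, samples_per_buffer,
                    shuffle_vec=None, child_ids=None):
    """Mimic DistributedSampler::GetNextBuffer id generation."""
    out = []
    for cnt in range(samples_per_buffer):
        sampled_id = (num_devices * cnt + device_id) % num_rows
        if shuffle_vec is not None:            # shuffle_ == true
            sampled_id = shuffle_vec[sampled_id]
        if child_ids is not None:              # HasChildSampler()
            sampled_id = child_ids[sampled_id]  # GetAssociatedChildId()
        out.append(sampled_id)
    return out

# Shard 0 of 2 over a 5-row sequential child: ceil(5 / 2) = 3 ids per shard.
assert distributed_ids(2, 0, 5, 3, child_ids=[0, 1, 2, 3, 4]) == [0, 2, 4]
```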
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h (+2 −0)

```diff
@@ -48,6 +48,8 @@ class DistributedSampler : public Sampler {
   // @return - The error code return
   Status Reset() override;

+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
   int64_t cnt_;  // number of samples that have already been filled in to buffer
   uint32_t seed_;
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc (+19 −1)

```diff
@@ -38,6 +38,7 @@ Status PKSampler::InitSampler() {
   rnd_.seed(seed_++);
   num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
   samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_;
+  num_samples_ = num_pk_samples_;
   if (shuffle_ == true) {
     std::shuffle(labels_.begin(), labels_.end(), rnd_);
   } else {
@@ -53,6 +54,10 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
   } else if (next_id_ == num_pk_samples_) {
     (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
   } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+    }
+
     (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
     std::shared_ptr<Tensor> sample_ids;
     int64_t last_id =
@@ -63,8 +68,16 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
       int64_t cls_id = next_id_++ / samples_per_class_;
       const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]];
       int64_t rnd_ind = std::uniform_int_distribution<int64_t>(0, samples.size() - 1)(rnd_);
-      *(id_ptr++) = samples[rnd_ind];
+      int64_t sampled_id = samples[rnd_ind];
+
+      if (HasChildSampler()) {
+        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+      }
+
+      *id_ptr = sampled_id;
+      id_ptr++;
     }
     TensorRow row(1, sample_ids);
     (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
   }
@@ -75,6 +88,11 @@ Status PKSampler::Reset() {
   CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late");
   next_id_ = 0;
   rnd_.seed(seed_++);
+
+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->Reset());
+  }
+
   return Status::OK();
 }
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc (+17 −0)

```diff
@@ -27,6 +27,10 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
   if (need_to_reset_) {
     (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
   } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+    }
+
     std::shared_ptr<Tensor> sample_ids;
     {
       py::gil_scoped_acquire gil_acquire;
@@ -38,6 +42,14 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
         py::object py_ret = py_sampler_instance.attr("_get_indices")();
         py::array np_sample_ids = py_ret.cast<py::array>();
         Tensor::CreateTensor(&sample_ids, np_sample_ids);  // copy numpy to tensor
+
+        if (HasChildSampler()) {
+          for (auto it = sample_ids->begin<int64_t>(); it != sample_ids->end<int64_t>(); ++it) {
+            int64_t associated_child_id = 0;
+            RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, associated_child_id));
+            *it = associated_child_id;
+          }
+        }
       } catch (const py::error_already_set &e) {
         return Status(StatusCode::kPyFuncException, e.what());
       } catch (const py::cast_error &e) {
@@ -79,6 +91,11 @@ Status PythonSampler::Reset() {
   } catch (const py::error_already_set &e) {
     return Status(StatusCode::kPyFuncException, e.what());
   }
+
+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->Reset());
+  }
+
   return Status::OK();
 }
 }  // namespace dataset
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc (+50 −6)

```diff
@@ -14,18 +14,22 @@
  * limitations under the License.
  */
 #include "dataset/engine/datasetops/source/sampler/random_sampler.h"
+#include <algorithm>
 #include <limits>
 #include <memory>
 #include "dataset/util/random.h"

 namespace mindspore {
 namespace dataset {
-RandomSampler::RandomSampler(bool replacement, int64_t num_samples, int64_t samples_per_buffer)
+RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples,
+                             int64_t samples_per_buffer)
     : Sampler(samples_per_buffer),
       seed_(GetSeed()),
       replacement_(replacement),
       user_num_samples_(num_samples),
       next_id_(0),
+      reshuffle_each_epoch_(reshuffle_each_epoch),
       dist(nullptr) {}

 Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
@@ -34,13 +38,29 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
   } else if (next_id_ == num_samples_) {
     (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
   } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+    }
+
     (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
     std::shared_ptr<Tensor> sampleIds;
-    int64_t last_id = samples_per_buffer_ + next_id_ > num_samples_ ? num_samples_ : samples_per_buffer_ + next_id_;
+    int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_);
     RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_));
     int64_t *id_ptr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
     for (int64_t i = 0; i < (last_id - next_id_); i++) {
-      *(id_ptr + i) = replacement_ ? (*dist)(rnd_) : shuffled_ids_[static_cast<size_t>(i + next_id_)];
+      int64_t sampled_id = 0;
+      if (replacement_) {
+        sampled_id = (*dist)(rnd_);
+      } else {
+        sampled_id = shuffled_ids_[static_cast<size_t>(i + next_id_)];
+      }
+
+      if (HasChildSampler()) {
+        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+      }
+
+      *(id_ptr + i) = sampled_id;
     }
     next_id_ = last_id;
     TensorRow row(1, sampleIds);
@@ -53,7 +73,9 @@ Status RandomSampler::InitSampler() {
   num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
   CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
   samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
-  rnd_.seed(seed_++);
+  rnd_.seed(seed_);
+
   if (replacement_ == false) {
     shuffled_ids_.reserve(num_rows_);
     for (int64_t i = 0; i < num_rows_; i++) {
@@ -69,11 +91,33 @@ Status RandomSampler::InitSampler() {
 Status RandomSampler::Reset() {
   CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
   next_id_ = 0;
-  rnd_.seed(seed_++);
-  if (replacement_ == false) {
+
+  if (reshuffle_each_epoch_) {
+    seed_++;
+  }
+
+  rnd_.seed(seed_);
+
+  if (replacement_ == false && reshuffle_each_epoch_) {
     std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
   }
+
+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->Reset());
+  }
+
   return Status::OK();
 }
+
+void RandomSampler::Print(std::ostream &out, bool show_all) const {
+  out << "(sampler): RandomSampler\n";
+
+  if (show_all) {
+    out << "user_num_samples_: " << user_num_samples_ << '\n';
+    out << "num_samples_: " << num_samples_ << '\n';
+    out << "next_id_: " << next_id_ << '\n';
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
```
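Reset() now advances the seed only when reshuffle_each_epoch_ is set, so with the flag cleared every epoch replays the same permutation. A small Python sketch of that seed discipline (the class name and structure are mine, for illustration):

```python
import random

class EpochIds:
    """Mimic RandomSampler's no-replacement id handling across epochs."""
    def __init__(self, num_rows, reshuffle_each_epoch=True, seed=0):
        self.ids = list(range(num_rows))
        self.reshuffle_each_epoch = reshuffle_each_epoch
        self.seed = seed
        random.Random(self.seed).shuffle(self.ids)   # InitSampler()

    def reset(self):                                 # RandomSampler::Reset()
        if self.reshuffle_each_epoch:
            self.seed += 1
            random.Random(self.seed).shuffle(self.ids)

s = EpochIds(5, reshuffle_each_epoch=False)
epoch1 = list(s.ids)
s.reset()
assert s.ids == epoch1  # identical order every epoch when not reshuffling
```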
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h (+5 −1)

```diff
@@ -30,7 +30,8 @@ class RandomSampler : public Sampler {
   // @param bool replacement - put he id back / or not after a sample
   // @param int64_t numSamples - number samples to draw
   // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
-  explicit RandomSampler(bool replacement = false, int64_t num_samples = std::numeric_limits<int64_t>::max(),
+  explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true,
+                         int64_t num_samples = std::numeric_limits<int64_t>::max(),
                          int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

   // Destructor.
@@ -49,6 +50,8 @@ class RandomSampler : public Sampler {
   // @return - The error code return
   Status Reset() override;

+  virtual void Print(std::ostream &out, bool show_all) const;
+
  private:
   uint32_t seed_;
   bool replacement_;
@@ -57,6 +60,7 @@ class RandomSampler : public Sampler {
   int64_t next_id_;
   std::mt19937 rnd_;
   std::unique_ptr<std::uniform_int_distribution<int64_t>> dist;
+  bool reshuffle_each_epoch_;
 };
 }  // namespace dataset
 }  // namespace mindspore
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc (+73 −1)

```diff
@@ -15,18 +15,41 @@
  */
 #include "dataset/engine/datasetops/source/sampler/sampler.h"
+#include <string>

 namespace mindspore {
 namespace dataset {
 Sampler::Sampler(int64_t samples_per_buffer)
     : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}

 Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
+  std::shared_ptr<Sampler> child_sampler;
+  if (HasChildSampler()) {
+    child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]);
+    if (!child_sampler) {
+      std::string err_msg("Cannot handshake, child is not a sampler object.");
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    }
+
+    // Handshake and init child first.
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
+    }
+  }
+
   CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
   RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
-  RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
+  if (HasChildSampler()) {
+    int64_t child_num_samples = child_sampler->num_samples();
+    num_rows_ = child_num_samples;
+  } else {
+    RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
+  }
+
   // It's up to the derived class to check the validity of the two args
   // Because some sampler only needs one of the arg (weighted_random_sampler)
   RETURN_IF_NOT_OK(InitSampler());  // init sampler after callback
   return Status::OK();
 }
@@ -44,6 +67,15 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
   return Status::OK();
 }

+void Sampler::Print(std::ostream &out, bool show_all) const {
+  out << "(sampler): base\n";
+
+  if (show_all) {
+    out << "num_rows_: " << num_rows_ << '\n';
+    out << "num_samples_: " << num_samples_ << '\n';
+  }
+}
+
 Status Sampler::GetAllIdsThenReset(py::array *data) {
   std::unique_ptr<DataBuffer> db;
   std::shared_ptr<Tensor> sample_ids;
@@ -84,5 +116,45 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
   num_rows_ = num_rows;
   return Status::OK();
 }
+
+Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
+  if (child == nullptr) {
+    return Status::OK();
+  }
+
+  // Only samplers can be added, not any other DatasetOp.
+  std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child);
+  if (!sampler) {
+    std::string err_msg("Cannot add child, child is not a sampler object.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // Samplers can have at most 1 child.
+  if (!child_.empty()) {
+    std::string err_msg("Cannot add child sampler, this sampler already has a child.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  child_.push_back(child);
+
+  // doesn't work, protected?
+  // child->AddParent(this);
+  return Status::OK();
+}
+
+bool Sampler::HasChildSampler() { return !child_.empty(); }
+
+Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
+  if (child_ids_ == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!");
+  }
+
+  TensorRow sample_row;
+  RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row));
+  std::shared_ptr<Tensor> sample_ids = sample_row[0];
+  RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
```
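AddChild(), HasChildSampler(), and GetAssociatedChildId() define the whole chaining contract: at most one child per sampler, the parent's num_rows_ becomes the child's num_samples() during the handshake, and each parent id is then simply an index into the child's output. A compact Python model of that contract (class and method names are illustrative, not the library API):

```python
class ChainedSampler:
    """Toy model of the Sampler child-chaining contract."""
    def __init__(self, child=None):
        self.child = child               # at most one child (AddChild)

    def own_ids(self, num_rows):         # each subclass generates its own ids
        raise NotImplementedError

    def sample(self, dataset_rows):
        if self.child is None:
            return self.own_ids(dataset_rows)
        child_ids = self.child.sample(dataset_rows)
        # HandshakeRandomAccessOp: parent's num_rows_ = len(child output);
        # GetAssociatedChildId: each parent id indexes into that output.
        return [child_ids[i] for i in self.own_ids(len(child_ids))]
```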
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h (+34 −1)

```diff
@@ -90,6 +90,8 @@ class Sampler : public DatasetOp {
   // setter function for num_samples_
   Status SetNumSamples(int64_t num_samples);

+  int64_t num_samples() { return num_samples_; }
+
   // first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
   // @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
   // @return
@@ -114,17 +116,48 @@ class Sampler : public DatasetOp {
   // @return - The error code return
   Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); }

+  // Adds a sampler to become our child.
+  // @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
+  // @return - The error code returned.
+  Status AddChild(std::shared_ptr<DatasetOp> child);
+
   // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
   // @param std::shared_ptr<Tensor>* sampleIds
   // @param int64_t numElements - must be a non 0 number
-  // @return
+  // @return - The error code returned.
   Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);

+  void Print(std::ostream &out, bool show_all) const override;
+
+  friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
+    sampler.Print(out, false);
+    return out;
+  }
+
+  // Checks if this sampler has a child sampler.
+  // @return - tre if there is a child sampler, false otherwise.
+  bool HasChildSampler();
+
+  // Uses id as an index for the list of ids generated by the child sampler, and gets the
+  // associated id.
+  // @param int64_t* out_associated_id - Out parameter, contains the associated id.
+  // @param int64_t id - The id used as an index to get the associated child id.
+  // @return - The error code returned.
+  Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
+
  protected:
+  // Number of rows of data from the place this sampler is sampling from. If this sampler
+  // has a child sampler, num_rows_ is the number of ids the child sampler will
+  // output. Otherwise, num_rows_ is the number of rows in the dataset.
   int64_t num_rows_;
+
+  // Number of ids this sampler will return.
   int64_t num_samples_;
+
+  // The max number of ids a DataBuffer returned by this sampler will contain.
   int64_t samples_per_buffer_;
   std::unique_ptr<ColDescriptor> col_desc_;
+  std::unique_ptr<DataBuffer> child_ids_;
 };
 }  // namespace dataset
 }  // namespace mindspore
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc (+26 −1)

```diff
@@ -15,6 +15,7 @@
  */
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include <algorithm>
 #include <memory>

 namespace mindspore {
@@ -27,14 +28,26 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
   } else if (next_id_ == num_samples_) {
     (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
   } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+    }
+
     (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
     std::shared_ptr<Tensor> sampleIds;
     int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
     RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_));
     int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
     while (next_id_ < lastId) {
-      *(idPtr++) = next_id_++;
+      int64_t sampled_id = next_id_;
+      if (HasChildSampler()) {
+        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+      }
+
+      *idPtr = sampled_id;
+      next_id_++;
+      idPtr++;
     }
+
     TensorRow row(1, sampleIds);
     (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
   }
@@ -43,6 +56,10 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
 Status SequentialSampler::InitSampler() {
   num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_;  // if num_samples < 0, try if num_rows is set
+  if (HasChildSampler()) {
+    num_samples_ = std::min(num_samples_, num_rows_);
+  }
+
   CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
   samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
   return Status::OK();
@@ -51,7 +68,15 @@ Status SequentialSampler::InitSampler() {
 Status SequentialSampler::Reset() {
   CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
   next_id_ = 0;
+
+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->Reset());
+  }
+
   return Status::OK();
 }
+
+void SequentialSampler::Print(std::ostream &out, bool show_all) const {
+  out << "(sampler): SequentialSampler\n";
+}
 }  // namespace dataset
 }  // namespace mindspore
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h (+2 −0)

```diff
@@ -45,6 +45,8 @@ class SequentialSampler : public Sampler {
   // @return - The error code return
   Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;

+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
   int64_t next_id_;
 };
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc (+18 −1)

```diff
@@ -34,6 +34,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
 Status SubsetRandomSampler::InitSampler() {
   CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");

+  num_samples_ = indices_.size();
+
   // Initialize random generator with seed from config manager
   rand_gen_.seed(GetSeed());
@@ -56,6 +58,10 @@ Status SubsetRandomSampler::Reset() {
   rand_gen_.seed(GetSeed());
   std::shuffle(indices_.begin(), indices_.end(), rand_gen_);

+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->Reset());
+  }
+
   return Status::OK();
 }
@@ -65,6 +71,10 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
   if (sample_id_ == indices_.size()) {
     (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
   } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+    }
+
     (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
     std::shared_ptr<Tensor> outputIds;
@@ -87,7 +97,14 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
       RETURN_STATUS_UNEXPECTED(err_msg);
     }

-    *(id_ptr++) = indices_[sample_id_++];
+    int64_t sampled_id = indices_[sample_id_];
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+    }
+
+    *id_ptr = sampled_id;
+    id_ptr++;
+    sample_id_++;
   }

   // Create a TensorTable from that single tensor and push into DataBuffer
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc (new file, +85 −0)

```cpp
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"

#include <memory>
#include <string>

#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"

namespace mindspore {
namespace dataset {
// Constructor.
SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
    : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}

Status SubsetSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");

  num_samples_ = subset_size_;

  return Status::OK();
}

Status SubsetSampler::Reset() {
  current_id_ = 0;

  if (HasChildSampler()) {
    RETURN_IF_NOT_OK(child_[0]->Reset());
  }

  return Status::OK();
}

Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
  if (current_id_ > subset_size_) {
    RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error");
  } else if (current_id_ == subset_size_) {
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
    if (HasChildSampler()) {
      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
    }

    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
    std::shared_ptr<Tensor> sampled_ids;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_));

    int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer());

    while (current_id_ < subset_size_) {
      int64_t sampled_id = start_index_ + current_id_;
      if (HasChildSampler()) {
        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
      }

      *(sampled_ids_start_addr + current_id_) = sampled_id;
      current_id_++;
    }

    TensorRow sampled_ids_row(1, sampled_ids);
    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, sampled_ids_row));
  }

  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
```
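SubsetSampler is the simplest sampler in the commit: it emits the contiguous ids [start_index, start_index + subset_size) after bounds-checking against num_rows_. A one-function sketch of that behavior (the function name is mine):

```python
def subset_ids(start_index, subset_size, num_rows):
    """Mirror SubsetSampler::InitSampler checks and its id output."""
    assert subset_size > 0 and start_index >= 0
    assert start_index < num_rows
    assert start_index + subset_size - 1 < num_rows   # final index in bounds
    return list(range(start_index, start_index + subset_size))

assert subset_ids(100, 5, 200) == [100, 101, 102, 103, 104]
```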
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h (new file, +58 −0)

```cpp
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_

#include <memory>
#include <vector>

#include "dataset/engine/datasetops/source/sampler/sampler.h"

namespace mindspore {
namespace dataset {
class SubsetSampler : public Sampler {
 public:
  // Constructor.
  // @param start_index The index we start sampling from.
  explicit SubsetSampler(int64_t start_index, int64_t subset_size);

  // Destructor.
  ~SubsetSampler() = default;

  // Initialize the sampler.
  // @return Status
  Status InitSampler() override;

  // Reset the internal variable to the initial state and reshuffle the indices.
  // @return Status
  Status Reset() override;

  // Get the sample ids.
  // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
  // @note the sample ids (int64_t) will be placed in one Tensor.
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;

 private:
  int64_t start_index_;
  int64_t subset_size_;
  int64_t current_id_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
```
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc (+17 −1)

```diff
@@ -40,6 +40,8 @@ WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights,
 Status WeightedRandomSampler::InitSampler() {
   CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive");
   CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");

+  num_samples_ = user_num_samples_;
+
   // Initialize random generator with seed from config manager
   rand_gen_.seed(GetSeed());
@@ -81,6 +83,11 @@ Status WeightedRandomSampler::Reset() {
   } else {
     discrete_dist_->reset();
   }
+
+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->Reset());
+  }
+
   return Status::OK();
 }
@@ -98,6 +105,10 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
   if (sample_id_ == user_num_samples_) {
     (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
   } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+    }
+
     (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
     std::shared_ptr<Tensor> outputIds;
@@ -127,7 +138,12 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
       RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound).");
     }

-    *(id_ptr++) = genId;
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
+    }
+
+    *id_ptr = genId;
+    id_ptr++;
     sample_id_++;
   }
```
mindspore/dataset/engine/datasets.py (+425 −11)

This diff is collapsed in the page view and is not shown here.
mindspore/dataset/engine/samplers.py (+201 −10)

```diff
@@ -47,6 +47,7 @@ class Sampler:
     def __init__(self):
         self.dataset_size = 0
         self.num_samples = 0
+        self.child_sampler = None

     def __iter__(self):
         """
@@ -83,7 +84,35 @@ class Sampler:
     # Instance fetcher
     # Do not override this method!
     def create(self):
-        return cde.PythonSampler(self)
+        c_sampler = cde.PythonSampler(self)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
+    def add_child(self, sampler):
+        self.child_sampler = sampler
+
+    def get_child(self):
+        return self.child_sampler
+
+    def create_child(self):
+        c_child_sampler = None
+        if self.child_sampler is not None:
+            c_child_sampler = self.child_sampler.create()
+
+        return c_child_sampler
+
+    def is_shuffled(self):
+        if self.child_sampler is None:
+            return False
+
+        return self.child_sampler.is_shuffled()
+
+    def is_sharded(self):
+        if self.child_sampler is None:
+            return False
+
+        return self.child_sampler.is_sharded()


 class BuiltinSampler:
@@ -93,11 +122,30 @@ class BuiltinSampler:
     User should not extend this class.
     """
     def __init__(self):
-        pass
+        self.child_sampler = None

     def create(self):
         pass

+    def add_child(self, sampler):
+        self.child_sampler = sampler
+
+    def get_child(self):
+        return self.child_sampler
+
+    def create_child(self):
+        c_child_sampler = None
+        if self.child_sampler is not None:
+            c_child_sampler = self.child_sampler.create()
+
+        return c_child_sampler
+
+    def is_shuffled(self):
+        raise NotImplementedError("Sampler must implement is_shuffled.")
+
+    def is_sharded(self):
+        raise NotImplementedError("Sampler must implement is_sharded.")


 class DistributedSampler(BuiltinSampler):
@@ -142,7 +190,22 @@ class DistributedSampler(BuiltinSampler):
     def create(self):
         # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
         self.seed += 1
-        return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
+        c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
+    def is_shuffled(self):
+        if self.child_sampler is None:
+            return self.shuffle
+
+        return self.child_sampler.is_shuffled()
+
+    def is_sharded(self):
+        if self.child_sampler is None:
+            return self.num_shards > 1
+
+        return self.child_sampler.is_sharded()


 class PKSampler(BuiltinSampler):
@@ -186,7 +249,22 @@ class PKSampler(BuiltinSampler):
         super().__init__()

     def create(self):
-        return cde.PKSampler(self.num_val, self.shuffle)
+        c_sampler = cde.PKSampler(self.num_val, self.shuffle)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
+    def is_shuffled(self):
+        if self.child_sampler is None:
+            return self.shuffle
+
+        return self.child_sampler.is_shuffled()
+
+    def is_sharded(self):
+        if self.child_sampler is None:
+            return False
+
+        return self.child_sampler.is_sharded()

     def _create_for_minddataset(self):
         if not self.class_column or not isinstance(self.class_column, str):
@@ -226,15 +304,31 @@ class RandomSampler(BuiltinSampler):
             raise ValueError("num_samples should be a positive integer "
                              "value, but got num_samples={}".format(num_samples))

         self.deterministic = False
         self.replacement = replacement
         self.num_samples = num_samples
+        self.reshuffle_each_epoch = True
         super().__init__()

     def create(self):
-        # If num_samples is not specified, then call constructor #2
+        c_sampler = None
         if self.num_samples is None:
-            return cde.RandomSampler(self.replacement)
-        return cde.RandomSampler(self.replacement, self.num_samples)
+            c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch)
+        else:
+            c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples)
+
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
+    def is_shuffled(self):
+        return True
+
+    def is_sharded(self):
+        if self.child_sampler is None:
+            return False
+
+        return self.child_sampler.is_sharded()


 class SequentialSampler(BuiltinSampler):
@@ -252,7 +346,80 @@ class SequentialSampler(BuiltinSampler):
     """

     def create(self):
-        return cde.SequentialSampler()
+        c_sampler = cde.SequentialSampler()
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
+    def is_shuffled(self):
+        if self.child_sampler is None:
+            return False
+
+        return self.child_sampler.is_shuffled()
+
+    def is_sharded(self):
+        if self.child_sampler is None:
+            return False
+
+        return self.child_sampler.is_sharded()
+
+
+class SubsetSampler(BuiltinSampler):
+    """
+    Samples a subset of elements consecutively from a given index.
+
+    Args:
+        start_index (int): Index to start sampling at.
+        subset_size (int): How many samples to include in this subset.
+
+    Examples:
+        >>> import mindspore.dataset as ds
+        >>>
+        >>> dataset_dir = "path/to/imagefolder_directory"
+        >>>
+        >>> # creates a SubsetSampler, will sample the next 5 images from the 100th image.
+        >>> sampler = ds.SubsetSampler(100, 5)
+        >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
+
+    Raises:
+        ValueError: If start_index is not a positive int.
+        ValueError: If subset_size is not a positive int.
+    """
+
+    def __init__(self, start_index, subset_size):
+        if not isinstance(start_index, int):
+            raise ValueError("start_index should be an int.")
+
+        if start_index < 0:
+            raise ValueError("start_index should not be negative.")
+
+        if not isinstance(subset_size, int):
+            raise ValueError("subset_size should be an int.")
+
+        if subset_size < 0:
+            raise ValueError("subset_size should not be negative.")
+
+        self.start_index = start_index
+        self.subset_size = subset_size
+        super().__init__()
+
+    def create(self):
+        c_sampler = cde.SubsetSampler(self.start_index, self.subset_size)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
+    def is_shuffled(self):
+        if self.child_sampler is None:
+            return False
+
+        return self.child_sampler.is_shuffled()
+
+    def is_sharded(self):
+        if self.child_sampler is None:
+            return False
+
+        return self.child_sampler.is_sharded()


 class SubsetRandomSampler(BuiltinSampler):
@@ -282,7 +449,19 @@ class SubsetRandomSampler(BuiltinSampler):
         super().__init__()

     def create(self):
-        return cde.SubsetRandomSampler(self.indices)
+        c_sampler = cde.SubsetRandomSampler(self.indices)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
+    def is_shuffled(self):
+        return True
+
+    def is_sharded(self):
+        if self.child_sampler is None:
+            return False
+
+        return self.child_sampler.is_sharded()

     def _create_for_minddataset(self):
         return cde.MindrecordSubsetRandomSampler(self.indices)
@@ -330,4 +509,16 @@ class WeightedRandomSampler(BuiltinSampler):
         super().__init__()

     def create(self):
-        return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
+        c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
+    def is_shuffled(self):
+        return True
+
+    def is_sharded(self):
+        if self.child_sampler is None:
+            return False
+
+        return self.child_sampler.is_sharded()
```
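On the Python side, every create() now builds its C-layer sampler, then recursively builds and attaches its child's, and the is_shuffled()/is_sharded() queries delegate down the chain. A short usage sketch of the API added above:

```python
import mindspore.dataset as ds

outer = ds.DistributedSampler(2, 0, shuffle=False)
outer.add_child(ds.SequentialSampler())

assert outer.get_child() is not None
# With a child attached, is_shuffled() is answered by the child:
# SequentialSampler has no child of its own, so it reports False.
assert outer.is_shuffled() is False
```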
mindspore/dataset/engine/validators.py (+41 −0)

```diff
@@ -1031,3 +1031,44 @@ def check_textfiledataset(method):
         return method(*args, **kwargs)

     return new_method
+
+
+def check_split(method):
+    """check the input arguments of split."""
+
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        nreq_param_list = ['sizes']
+        nreq_param_bool = ['randomize']
+        check_param_type(nreq_param_list, param_dict, list)
+        check_param_type(nreq_param_bool, param_dict, bool)
+
+        # check sizes: must be list of float or list of int
+        sizes = param_dict.get('sizes')
+
+        if not sizes:
+            raise ValueError("sizes cannot be empty.")
+        all_int = all(isinstance(item, int) for item in sizes)
+        all_float = all(isinstance(item, float) for item in sizes)
+
+        if not (all_int or all_float):
+            raise ValueError("sizes should be list of int or list of float.")
+
+        if all_int:
+            all_positive = all(item > 0 for item in sizes)
+            if not all_positive:
+                raise ValueError("sizes is a list of int, but there should be no negative numbers.")
+
+        if all_float:
+            all_valid_percentages = all(0 < item <= 1 for item in sizes)
+            if not all_valid_percentages:
+                raise ValueError("sizes is a list of float, but there should be no numbers outside the range [0, 1].")
+
+            epsilon = 0.00001
+            if not abs(sum(sizes) - 1) < epsilon:
+                raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")
+
+        return method(*args, **kwargs)
+
+    return new_method
```
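In short, check_split accepts either all-int absolute row counts (all positive) or all-float fractions in (0, 1] that sum to 1 within epsilon = 1e-5. A standalone restatement of those rules for illustration (this is not the library function):

```python
def validate_sizes(sizes):
    """Re-implement the checks from check_split above, for illustration."""
    if not sizes:
        raise ValueError("sizes cannot be empty.")
    all_int = all(isinstance(x, int) for x in sizes)
    all_float = all(isinstance(x, float) for x in sizes)
    if not (all_int or all_float):
        raise ValueError("sizes should be list of int or list of float.")
    if all_int and not all(x > 0 for x in sizes):
        raise ValueError("there should be no negative numbers.")
    if all_float:
        if not all(0 < x <= 1 for x in sizes):
            raise ValueError("there should be no numbers outside the range [0, 1].")
        if not abs(sum(sizes) - 1) < 0.00001:
            raise ValueError("the percentages do not sum up to 1.")

validate_sizes([4, 1])      # ok: absolute row counts
validate_sizes([0.8, 0.2])  # ok: fractions summing to 1
```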
tests/ut/cpp/dataset/cifar_op_test.cc (+1 −1)

```diff
@@ -92,7 +92,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
 TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
   uint32_t original_seed = GlobalContext::config_manager()->seed();
   GlobalContext::config_manager()->set_seed(0);
-  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
+  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
   auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)});
   tree->Prepare();
```
tests/ut/cpp/dataset/image_folder_op_test.cc (+1 −1)

```diff
@@ -138,7 +138,7 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) {
 TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
   int32_t original_seed = GlobalContext::config_manager()->seed();
   GlobalContext::config_manager()->set_seed(0);
-  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
+  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
   int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1};  // ground truth label
   std::string folder_path = datasets_root_path_ + "/testPK/data";
   auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
```
tests/ut/python/dataset/test_sampler.py (+27 −0)

```diff
@@ -164,9 +164,36 @@ def test_python_sampler():
     assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]


+def test_sampler_chain():
+    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
+    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+
+    def test_config(num_shards, shard_id):
+        sampler = ds.DistributedSampler(num_shards, shard_id, False)
+        child_sampler = ds.SequentialSampler()
+        sampler.add_child(child_sampler)
+
+        data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler)
+
+        res = []
+        for item in data1.create_dict_iterator():
+            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
+                        .format(item["image"].shape[0], item["label"].item()))
+            res.append(map[(item["image"].shape[0], item["label"].item())])
+
+        return res
+
+    assert test_config(2, 0) == [0, 2, 4]
+    assert test_config(2, 1) == [1, 3, 0]
+    assert test_config(5, 0) == [0]
+    assert test_config(5, 1) == [1]
+    assert test_config(5, 2) == [2]
+    assert test_config(5, 3) == [3]
+    assert test_config(5, 4) == [4]
+
+
 if __name__ == '__main__':
     test_sequential_sampler(True)
     test_random_sampler(True)
     test_random_sampler_multi_iter(True)
     test_sampler_py_api()
     test_python_sampler()
+    test_sampler_chain()
```
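The expected values in test_config() follow from the chaining arithmetic: the DistributedSampler shards the child SequentialSampler's output [0..4], taking ceil(5 / num_shards) ids per shard and wrapping modulo the row count, which is why shard 1 of 2 ends with a wrapped 0:

```python
child = list(range(5))                                # SequentialSampler output
shard0 = [child[(2 * c + 0) % 5] for c in range(3)]   # ceil(5 / 2) = 3 ids
shard1 = [child[(2 * c + 1) % 5] for c in range(3)]
assert shard0 == [0, 2, 4] and shard1 == [1, 3, 0]    # matches the asserts above
```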
tests/ut/python/dataset/test_split.py (new file, +342 −0)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds

# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
# the label of each image is [0,0,0,1,1] each image can be uniquely identified
# via the following lookup table (dict){(83554, 0): 0, (54214, 0): 1, (54214, 1): 2, (65512, 0): 3, (64631, 1): 4}
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}


def split_with_invalid_inputs(d):
    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([])
    assert "sizes cannot be empty" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([5, 0.6])
    assert "sizes should be list of int or list of float" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([-1, 6])
    assert "there should be no negative numbers" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([3, 1])
    assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([5, 1])
    assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
    assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([-0.5, 0.5])
    assert "there should be no numbers outside the range [0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([1.5, 0.5])
    assert "there should be no numbers outside the range [0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([0.5, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([0.3, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([0.05, 0.95])
    assert "percentage 0.05 is too small" in str(info.value)


def test_unmappable_invalid_input():
    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
    d = ds.TextFileDataset(text_file_dataset_path)
    split_with_invalid_inputs(d)

    d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([4, 1])
    assert "dataset should not be sharded before split" in str(info.value)


def test_unmappable_split():
    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
    text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
                      "End of file.", "Good luck to everyone."]

    ds.config.set_num_parallel_workers(4)
    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:2]
    assert s2_output == text_file_data[2:]


def test_mappable_invalid_input():
    d = ds.ManifestDataset(manifest_file)
    split_with_invalid_inputs(d)

    d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([4, 1])
    assert "dataset should not be sharded before split" in str(info.value)


def test_mappable_split_general():
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    d = d.take(5)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]


def test_mappable_split_optimized():
    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]


def test_mappable_randomize_deterministic():
    # set arbitrary seed for shard after split
    # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    for _ in range(10):
        s1_output = []
        for item in s1.create_dict_iterator():
            s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        s2_output = []
        for item in s2.create_dict_iterator():
            s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        # note no overlap
        assert s1_output == [0, 1, 3, 4]
        assert s2_output == [2]


def test_mappable_randomize_repeatable():
    # set arbitrary seed for shard after split
    # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    num_epochs = 5
    s1 = s1.repeat(num_epochs)
    s2 = s2.repeat(num_epochs)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # note no overlap
    assert s1_output == [0, 1, 3, 4] * num_epochs
    assert s2_output == [2] * num_epochs


def test_mappable_sharding():
    # set arbitrary seed for repeatability for shard after split
    # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
    ds.config.set_seed(53)
    num_epochs = 5
    first_split_num_rows = 4

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([first_split_num_rows, 1])

    distributed_sampler = ds.DistributedSampler(2, 0)
    s1.use_sampler(distributed_sampler)
    s1 = s1.repeat(num_epochs)

    # testing sharding, second dataset to simulate another instance
    d2 = ds.ManifestDataset(manifest_file, shuffle=False)
    d2s1, d2s2 = d2.split([first_split_num_rows, 1])

    distributed_sampler = ds.DistributedSampler(2, 1)
    d2s1.use_sampler(distributed_sampler)
    d2s1 = d2s1.repeat(num_epochs)

    # shard 0
    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # shard 1
    d2s1_output = []
    for item in d2s1.create_dict_iterator():
        d2s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    rows_per_shard_per_epoch = 2
    assert len(s1_output) == rows_per_shard_per_epoch * num_epochs
    assert len(d2s1_output) == rows_per_shard_per_epoch * num_epochs

    # verify each epoch that
    #   1. shards contain no common elements
    #   2. the data was split the same way, and that the union of shards equal the split
    correct_sorted_split_result = [0, 1, 3, 4]
    for i in range(num_epochs):
        combined_data = []
        for j in range(rows_per_shard_per_epoch):
            combined_data.append(s1_output[i * rows_per_shard_per_epoch + j])
            combined_data.append(d2s1_output[i * rows_per_shard_per_epoch + j])

        assert sorted(combined_data) == correct_sorted_split_result

    # test other split
    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    d2s2_output = []
    for item in d2s2.create_dict_iterator():
        d2s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s2_output == [2]
    assert d2s2_output == [2]


if __name__ == '__main__':
    test_unmappable_invalid_input()
    test_unmappable_split()
    test_mappable_invalid_input()
    test_mappable_split_general()
    test_mappable_split_optimized()
    test_mappable_randomize_deterministic()
    test_mappable_randomize_repeatable()
    test_mappable_sharding()
```
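The "fuzzy percentages" cases above pin down how fractional sizes resolve to row counts: each fraction is scaled by the dataset size and rounded to a whole number of rows, so [0.33, 0.67] over 5 rows becomes [2, 3]. A quick check of that arithmetic (the exact rounding/adjustment rule is internal to the split implementation; this is inferred from the observed test outputs):

```python
sizes, num_rows = [0.33, 0.67], 5
rows = [round(s * num_rows) for s in sizes]  # 1.65 -> 2, 3.35 -> 3
assert rows == [2, 3] and sum(rows) == num_rows
```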