Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4e8e82f2
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4e8e82f2
编写于
5月 27, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 27, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1457 fix 3 bug reports for split
Merge pull request !1457 from Peilin/splitOp-after-testing
上级
3aeb91ee
d4c93575
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
381 addition
and
38 deletion
+381
-38
example/lstm_aclImdb/train.py
example/lstm_aclImdb/train.py
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc
...ataset/engine/datasetops/source/sampler/random_sampler.cc
+8
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
...ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
+1
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc
...ataset/engine/datasetops/source/sampler/subset_sampler.cc
+3
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h
...dataset/engine/datasetops/source/sampler/subset_sampler.h
+1
-1
mindspore/dataset/__init__.py
mindspore/dataset/__init__.py
+1
-1
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+64
-16
mindspore/dataset/engine/samplers.py
mindspore/dataset/engine/samplers.py
+22
-0
tests/ut/python/dataset/test_sampler.py
tests/ut/python/dataset/test_sampler.py
+46
-0
tests/ut/python/dataset/test_split.py
tests/ut/python/dataset/test_split.py
+234
-10
未找到文件。
example/lstm_aclImdb/train.py
浏览文件 @
4e8e82f2
...
...
@@ -71,7 +71,7 @@ if __name__ == '__main__':
model
=
Model
(
network
,
loss
,
opt
,
{
'acc'
:
Accuracy
()})
print
(
"============== Starting Training =============="
)
ds_train
=
create_dataset
(
args
.
preprocess_path
,
cfg
.
batch_size
,
repeat_num
=
cfg
.
num_epochs
)
ds_train
=
create_dataset
(
args
.
preprocess_path
,
cfg
.
batch_size
,
cfg
.
num_epochs
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
save_checkpoint_steps
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"lstm"
,
directory
=
args
.
ckpt_path
,
config
=
config_ck
)
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc
浏览文件 @
4e8e82f2
...
...
@@ -70,21 +70,26 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status
RandomSampler
::
InitSampler
()
{
num_samples_
=
(
user_num_samples_
<
num_samples_
)
?
user_num_samples_
:
num_samples_
;
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
num_rows_
>
0
,
"both num_samples & num_rows need to be positive"
);
samples_per_buffer_
=
samples_per_buffer_
>
num_samples_
?
num_samples_
:
samples_per_buffer_
;
CHECK_FAIL_RETURN_UNEXPECTED
(
num_rows_
>
0
,
"num_rows needs to be positive."
);
rnd_
.
seed
(
seed_
);
if
(
replacement_
==
false
)
{
num_samples_
=
std
::
min
(
num_samples_
,
num_rows_
);
shuffled_ids_
.
reserve
(
num_rows_
);
for
(
int64_t
i
=
0
;
i
<
num_rows_
;
i
++
)
{
shuffled_ids_
.
push_back
(
i
);
}
std
::
shuffle
(
shuffled_ids_
.
begin
(),
shuffled_ids_
.
end
(),
rnd_
);
}
else
{
num_samples_
=
std
::
min
(
num_samples_
,
user_num_samples_
);
dist
=
std
::
make_unique
<
std
::
uniform_int_distribution
<
int64_t
>>
(
0
,
num_rows_
-
1
);
}
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
,
"num_samples needs to be positive."
);
samples_per_buffer_
=
samples_per_buffer_
>
num_samples_
?
num_samples_
:
samples_per_buffer_
;
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
浏览文件 @
4e8e82f2
...
...
@@ -32,10 +32,8 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
}
// Handshake and init child first.
if
(
HasChildSampler
())
{
RETURN_IF_NOT_OK
(
child_sampler
->
HandshakeRandomAccessOp
(
op
));
}
}
CHECK_FAIL_RETURN_UNEXPECTED
(
op
!=
nullptr
,
"RandomAccessOp is nullptr
\n
"
);
RETURN_IF_NOT_OK
(
op
->
GetNumSamples
(
&
num_samples_
));
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc
浏览文件 @
4e8e82f2
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...
...
@@ -28,9 +28,9 @@ SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
:
Sampler
(
subset_size
),
start_index_
(
start_index
),
subset_size_
(
subset_size
),
current_id_
(
0
)
{}
Status
SubsetSampler
::
InitSampler
()
{
CHECK_FAIL_RETURN_UNEXPECTED
(
subset_size_
>
0
,
"subset_size
_
<= 0
\n
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
subset_size_
>
0
,
"subset_size <= 0
\n
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
start_index_
>=
0
,
"start_index < 0
\n
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
start_index_
<
num_rows_
,
"start_index >= num_rows
_
\n
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
start_index_
<
num_rows_
,
"start_index >= num_rows
\n
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
start_index_
+
subset_size_
-
1
<
num_rows_
,
"Final index out of bounds.
\n
"
);
num_samples_
=
subset_size_
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h
浏览文件 @
4e8e82f2
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...
...
mindspore/dataset/__init__.py
浏览文件 @
4e8e82f2
...
...
@@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
GeneratorDataset
,
ManifestDataset
,
Cifar10Dataset
,
Cifar100Dataset
,
VOCDataset
,
CelebADataset
,
TextFileDataset
,
\
Schema
,
Shuffle
,
zip
,
RandomDataset
from
.engine.samplers
import
DistributedSampler
,
PKSampler
,
RandomSampler
,
SequentialSampler
,
SubsetRandomSampler
,
\
WeightedRandomSampler
,
Sampler
WeightedRandomSampler
,
S
ubsetSampler
,
S
ampler
from
.engine.serializer_deserializer
import
serialize
,
deserialize
,
show
from
.engine.graphdata
import
GraphData
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
4e8e82f2
...
...
@@ -633,9 +633,9 @@ class Dataset:
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool
): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
randomize (bool
, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows
from the dataset.
Note:
1. Dataset cannot be sharded if split is going to be called.
...
...
@@ -678,7 +678,8 @@ class Dataset:
ds
=
copy
.
deepcopy
(
self
)
if
randomize
:
# want to shuffle the same way every epoch before split
ds
=
ds
.
shuffle
()
# in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
ds
=
ds
.
shuffle
(
10000
)
ds
.
reshuffle_each_epoch
=
False
if
rows_to_skip
>
0
:
...
...
@@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset):
>>> new_sampler = ds.DistributedSampler(10, 2)
>>> data.use_sampler(new_sampler)
"""
if
new_sampler
is
not
None
and
not
isinstance
(
new_sampler
,
(
samplers
.
BuiltinSampler
,
samplers
.
Sampler
)):
raise
TypeError
(
"new_sampler is not an instance of a sampler."
)
self
.
sampler
=
self
.
sampler
.
child_sampler
self
.
add_sampler
(
new_sampler
)
...
...
@@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset):
def
is_sharded
(
self
):
raise
NotImplementedError
(
"MappableDataset must implement is_sharded."
)
def
_get_sampler_dataset_size
(
self
):
if
self
.
sampler
is
not
None
:
return
self
.
sampler
.
get_dataset_size
()
return
None
@
check_split
def
split
(
self
,
sizes
,
randomize
=
True
):
...
...
@@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset):
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool
): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
randomize (bool
, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows
from the dataset.
Note:
1. Dataset should not be sharded if split is going to be called. Instead, create a
...
...
@@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp):
self
.
iterator
=
TupleIterator
(
self
)
class
RangeDataset
(
MappableDataset
):
"""
A source dataset that reads and parses datasets stored on disk in a range.
...
...
@@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset):
else
:
num_samples
=
self
.
num_samples
num_rows
=
ImageFolderOp
.
get_num_rows_and_classes
(
self
.
dataset_dir
,
num_samples
)[
0
]
rows_per_shard
=
get_num_rows
(
num_rows
,
self
.
num_shards
)
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
return
get_num_rows
(
num_rows
,
self
.
num_shards
)
if
rows_from_sampler
is
None
:
return
rows_per_shard
return
min
(
rows_from_sampler
,
rows_per_shard
)
def
num_classes
(
self
):
"""
...
...
@@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset):
num_samples
=
self
.
num_samples
num_rows
=
MnistOp
.
get_num_rows
(
self
.
dataset_dir
,
num_samples
)
rows_per_shard
=
get_num_rows
(
num_rows
,
self
.
num_shards
)
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
if
rows_from_sampler
is
None
:
return
rows_per_shard
return
get_num_rows
(
num_rows
,
self
.
num_shards
)
return
min
(
rows_from_sampler
,
rows_per_shard
)
def
is_shuffled
(
self
):
if
self
.
shuffle_level
is
None
:
...
...
@@ -2926,8 +2944,13 @@ class GeneratorDataset(MappableDataset):
Return:
Number, number of batches.
"""
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
if
rows_from_sampler
is
None
:
return
self
.
_dataset_size
return
min
(
rows_from_sampler
,
self
.
_dataset_size
)
# manually set dataset_size as a temporary solution.
def
set_dataset_size
(
self
,
value
):
if
value
>=
0
:
...
...
@@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset):
class_indexing
=
self
.
class_indexing
num_rows
=
ManifestOp
.
get_num_rows_and_classes
(
self
.
dataset_file
,
num_samples
,
class_indexing
,
self
.
usage
)[
0
]
rows_per_shard
=
get_num_rows
(
num_rows
,
self
.
num_shards
)
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
return
get_num_rows
(
num_rows
,
self
.
num_shards
)
if
rows_from_sampler
is
None
:
return
rows_per_shard
return
min
(
rows_from_sampler
,
rows_per_shard
)
def
num_classes
(
self
):
"""
...
...
@@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset):
num_samples
=
self
.
num_samples
num_rows
=
CifarOp
.
get_num_rows
(
self
.
dataset_dir
,
num_samples
,
True
)
rows_per_shard
=
get_num_rows
(
num_rows
,
self
.
num_shards
)
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
if
rows_from_sampler
is
None
:
return
rows_per_shard
return
get_num_rows
(
num_rows
,
self
.
num_shards
)
return
min
(
rows_from_sampler
,
rows_per_shard
)
def
is_shuffled
(
self
):
if
self
.
shuffle_level
is
None
:
...
...
@@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset):
num_samples
=
self
.
num_samples
num_rows
=
CifarOp
.
get_num_rows
(
self
.
dataset_dir
,
num_samples
,
False
)
rows_per_shard
=
get_num_rows
(
num_rows
,
self
.
num_shards
)
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
return
get_num_rows
(
num_rows
,
self
.
num_shards
)
if
rows_from_sampler
is
None
:
return
rows_per_shard
return
min
(
rows_from_sampler
,
rows_per_shard
)
def
is_shuffled
(
self
):
if
self
.
shuffle_level
is
None
:
...
...
@@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset):
Return:
Number, number of batches.
"""
return
num_samples
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
if
rows_from_sampler
is
None
:
return
self
.
num_samples
return
min
(
rows_from_sampler
,
self
.
num_samples
)
def
is_shuffled
(
self
):
return
True
...
...
@@ -3871,8 +3914,13 @@ class VOCDataset(MappableDataset):
Return:
Number, number of batches.
"""
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
if
rows_from_sampler
is
None
:
return
self
.
num_samples
return
min
(
rows_from_sampler
,
self
.
num_samples
)
def
get_class_indexing
(
self
):
"""
Get the class index.
...
...
mindspore/dataset/engine/samplers.py
浏览文件 @
4e8e82f2
...
...
@@ -114,6 +114,9 @@ class Sampler:
return
self
.
child_sampler
.
is_sharded
()
def
get_dataset_size
(
self
):
return
self
.
_get_indices
().
size
class
BuiltinSampler
:
"""
...
...
@@ -146,6 +149,12 @@ class BuiltinSampler:
def
is_sharded
(
self
):
raise
NotImplementedError
(
"Sampler must implement is_sharded."
)
def
get_dataset_size
(
self
):
if
self
.
child_sampler
is
not
None
:
return
self
.
child_sampler
.
get_dataset_size
()
return
None
class
DistributedSampler
(
BuiltinSampler
):
"""
...
...
@@ -330,6 +339,9 @@ class RandomSampler(BuiltinSampler):
return
self
.
child_sampler
.
is_sharded
()
def
get_dataset_size
(
self
):
return
self
.
num_samples
class
SequentialSampler
(
BuiltinSampler
):
"""
...
...
@@ -421,6 +433,9 @@ class SubsetSampler(BuiltinSampler):
return
self
.
child_sampler
.
is_sharded
()
def
get_dataset_size
(
self
):
return
self
.
subset_size
class
SubsetRandomSampler
(
BuiltinSampler
):
"""
...
...
@@ -467,6 +482,10 @@ class SubsetRandomSampler(BuiltinSampler):
return
cde
.
MindrecordSubsetRandomSampler
(
self
.
indices
)
def
get_dataset_size
(
self
):
return
len
(
indices
)
class
WeightedRandomSampler
(
BuiltinSampler
):
"""
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
...
...
@@ -522,3 +541,6 @@ class WeightedRandomSampler(BuiltinSampler):
return
False
return
self
.
child_sampler
.
is_sharded
()
def
get_dataset_size
(
self
):
return
self
.
num_samples
tests/ut/python/dataset/test_sampler.py
浏览文件 @
4e8e82f2
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
...
...
@@ -164,6 +165,35 @@ def test_python_sampler():
assert
list
(
sp1
.
get_indices
())
==
[
0
,
1
,
2
,
3
,
4
]
def
test_subset_sampler
():
manifest_file
=
"../data/dataset/testManifestData/test5trainimgs.json"
map
=
{(
172876
,
0
):
0
,
(
54214
,
0
):
1
,
(
54214
,
1
):
2
,
(
173673
,
0
):
3
,
(
64631
,
1
):
4
}
def
test_config
(
num_samples
,
start_index
,
subset_size
):
sampler
=
ds
.
SubsetSampler
(
start_index
,
subset_size
)
d
=
ds
.
ManifestDataset
(
manifest_file
,
sampler
=
sampler
)
res
=
[]
for
item
in
d
.
create_dict_iterator
():
res
.
append
(
map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
return
res
with
pytest
.
raises
(
RuntimeError
)
as
info
:
test_config
(
5
,
0
,
0
)
assert
"subset_size <= 0"
in
str
(
info
.
value
)
assert
test_config
(
5
,
0
,
1
)
==
[
0
]
assert
test_config
(
5
,
0
,
2
)
==
[
0
,
1
]
assert
test_config
(
5
,
0
,
3
)
==
[
0
,
1
,
2
]
assert
test_config
(
5
,
0
,
4
)
==
[
0
,
1
,
2
,
3
]
assert
test_config
(
5
,
0
,
5
)
==
[
0
,
1
,
2
,
3
,
4
]
assert
test_config
(
5
,
1
,
1
)
==
[
1
]
assert
test_config
(
5
,
2
,
3
)
==
[
2
,
3
,
4
]
assert
test_config
(
5
,
3
,
2
)
==
[
3
,
4
]
assert
test_config
(
5
,
4
,
1
)
==
[
4
]
def
test_sampler_chain
():
manifest_file
=
"../data/dataset/testManifestData/test5trainimgs.json"
map
=
{(
172876
,
0
):
0
,
(
54214
,
0
):
1
,
(
54214
,
1
):
2
,
(
173673
,
0
):
3
,
(
64631
,
1
):
4
}
...
...
@@ -190,10 +220,26 @@ def test_sampler_chain():
assert
test_config
(
5
,
3
)
==
[
3
]
assert
test_config
(
5
,
4
)
==
[
4
]
def
test_add_sampler_invalid_input
():
manifest_file
=
"../data/dataset/testManifestData/test5trainimgs.json"
map
=
{(
172876
,
0
):
0
,
(
54214
,
0
):
1
,
(
54214
,
1
):
2
,
(
173673
,
0
):
3
,
(
64631
,
1
):
4
}
data1
=
ds
.
ManifestDataset
(
manifest_file
)
with
pytest
.
raises
(
TypeError
)
as
info
:
data1
.
use_sampler
(
1
)
assert
"not an instance of a sampler"
in
str
(
info
.
value
)
with
pytest
.
raises
(
TypeError
)
as
info
:
data1
.
use_sampler
(
"sampler"
)
assert
"not an instance of a sampler"
in
str
(
info
.
value
)
if
__name__
==
'__main__'
:
test_sequential_sampler
(
True
)
test_random_sampler
(
True
)
test_random_sampler_multi_iter
(
True
)
test_sampler_py_api
()
test_python_sampler
()
test_subset_sampler
()
test_sampler_chain
()
test_add_sampler_invalid_input
()
tests/ut/python/dataset/test_split.py
浏览文件 @
4e8e82f2
...
...
@@ -23,6 +23,10 @@ from util import config_get_set_num_parallel_workers
manifest_file
=
"../data/dataset/testManifestData/test5trainimgs.json"
manifest_map
=
{(
172876
,
0
):
0
,
(
54214
,
0
):
1
,
(
54214
,
1
):
2
,
(
173673
,
0
):
3
,
(
64631
,
1
):
4
}
text_file_dataset_path
=
"../data/dataset/testTextFileDataset/*"
text_file_data
=
[
"This is a text file."
,
"Another file."
,
"Be happy every day."
,
"End of file."
,
"Good luck to everyone."
]
def
split_with_invalid_inputs
(
d
):
with
pytest
.
raises
(
ValueError
)
as
info
:
s1
,
s2
=
d
.
split
([])
...
...
@@ -68,8 +72,8 @@ def split_with_invalid_inputs(d):
s1
,
s2
=
d
.
split
([
0.05
,
0.95
])
assert
"percentage 0.05 is too small"
in
str
(
info
.
value
)
def
test_unmappable_invalid_input
():
text_file_dataset_path
=
"../data/dataset/testTextFileDataset/*"
d
=
ds
.
TextFileDataset
(
text_file_dataset_path
)
split_with_invalid_inputs
(
d
)
...
...
@@ -78,11 +82,10 @@ def test_unmappable_invalid_input():
s1
,
s2
=
d
.
split
([
4
,
1
])
assert
"dataset should not be sharded before split"
in
str
(
info
.
value
)
def
test_unmappable_split
():
text_file_dataset_path
=
"../data/dataset/testTextFileDataset/*"
text_file_data
=
[
"This is a text file."
,
"Another file."
,
"Be happy every day."
,
"End of file."
,
"Good luck to everyone."
]
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
4
)
d
=
ds
.
TextFileDataset
(
text_file_dataset_path
,
shuffle
=
False
)
s1
,
s2
=
d
.
split
([
4
,
1
],
randomize
=
False
)
...
...
@@ -124,6 +127,142 @@ def test_unmappable_split():
assert
s1_output
==
text_file_data
[
0
:
2
]
assert
s2_output
==
text_file_data
[
2
:]
# Restore configuration num_parallel_workers
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_unmappable_randomize_deterministic
():
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
4
)
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds
.
config
.
set_seed
(
53
)
d
=
ds
.
TextFileDataset
(
text_file_dataset_path
,
shuffle
=
False
)
s1
,
s2
=
d
.
split
([
0.8
,
0.2
])
for
_
in
range
(
10
):
s1_output
=
[]
for
item
in
s1
.
create_dict_iterator
():
s1_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
s2_output
=
[]
for
item
in
s2
.
create_dict_iterator
():
s2_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
# note no overlap
assert
s1_output
==
[
text_file_data
[
0
],
text_file_data
[
2
],
text_file_data
[
1
],
text_file_data
[
4
]]
assert
s2_output
==
[
text_file_data
[
3
]]
# Restore configuration num_parallel_workers
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_unmappable_randomize_repeatable
():
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
4
)
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds
.
config
.
set_seed
(
53
)
d
=
ds
.
TextFileDataset
(
text_file_dataset_path
,
shuffle
=
False
)
s1
,
s2
=
d
.
split
([
0.8
,
0.2
])
num_epochs
=
5
s1
=
s1
.
repeat
(
num_epochs
)
s2
=
s2
.
repeat
(
num_epochs
)
s1_output
=
[]
for
item
in
s1
.
create_dict_iterator
():
s1_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
s2_output
=
[]
for
item
in
s2
.
create_dict_iterator
():
s2_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
# note no overlap
assert
s1_output
==
[
text_file_data
[
0
],
text_file_data
[
2
],
text_file_data
[
1
],
text_file_data
[
4
]]
*
num_epochs
assert
s2_output
==
[
text_file_data
[
3
]]
*
num_epochs
# Restore configuration num_parallel_workers
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_unmappable_get_dataset_size
():
d
=
ds
.
TextFileDataset
(
text_file_dataset_path
,
shuffle
=
False
)
s1
,
s2
=
d
.
split
([
0.8
,
0.2
])
assert
d
.
get_dataset_size
()
==
5
assert
s1
.
get_dataset_size
()
==
4
assert
s2
.
get_dataset_size
()
==
1
def
test_unmappable_multi_split
():
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
4
)
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds
.
config
.
set_seed
(
53
)
d
=
ds
.
TextFileDataset
(
text_file_dataset_path
,
shuffle
=
False
)
s1
,
s2
=
d
.
split
([
4
,
1
])
s1_correct_output
=
[
text_file_data
[
0
],
text_file_data
[
2
],
text_file_data
[
1
],
text_file_data
[
4
]]
s1_output
=
[]
for
item
in
s1
.
create_dict_iterator
():
s1_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
assert
s1_output
==
s1_correct_output
# no randomize in second split
s1s1
,
s1s2
,
s1s3
=
s1
.
split
([
1
,
2
,
1
],
randomize
=
False
)
s1s1_output
=
[]
for
item
in
s1s1
.
create_dict_iterator
():
s1s1_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
s1s2_output
=
[]
for
item
in
s1s2
.
create_dict_iterator
():
s1s2_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
s1s3_output
=
[]
for
item
in
s1s3
.
create_dict_iterator
():
s1s3_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
assert
s1s1_output
==
[
s1_correct_output
[
0
]]
assert
s1s2_output
==
[
s1_correct_output
[
1
],
s1_correct_output
[
2
]]
assert
s1s3_output
==
[
s1_correct_output
[
3
]]
s2_output
=
[]
for
item
in
s2
.
create_dict_iterator
():
s2_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
assert
s2_output
==
[
text_file_data
[
3
]]
# randomize in second split
# the labels outputted by the ShuffleOp for seed 53 is [2, 3, 1, 0]
shuffled_ids
=
[
2
,
3
,
1
,
0
]
s1s1
,
s1s2
,
s1s3
=
s1
.
split
([
1
,
2
,
1
])
s1s1_output
=
[]
for
item
in
s1s1
.
create_dict_iterator
():
s1s1_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
s1s2_output
=
[]
for
item
in
s1s2
.
create_dict_iterator
():
s1s2_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
s1s3_output
=
[]
for
item
in
s1s3
.
create_dict_iterator
():
s1s3_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
assert
s1s1_output
==
[
s1_correct_output
[
shuffled_ids
[
0
]]]
assert
s1s2_output
==
[
s1_correct_output
[
shuffled_ids
[
1
]],
s1_correct_output
[
shuffled_ids
[
2
]]]
assert
s1s3_output
==
[
s1_correct_output
[
shuffled_ids
[
3
]]]
s2_output
=
[]
for
item
in
s2
.
create_dict_iterator
():
s2_output
.
append
(
item
[
"text"
].
item
().
decode
(
"utf8"
))
assert
s2_output
==
[
text_file_data
[
3
]]
# Restore configuration num_parallel_workers
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
...
...
@@ -137,6 +276,7 @@ def test_mappable_invalid_input():
s1
,
s2
=
d
.
split
([
4
,
1
])
assert
"dataset should not be sharded before split"
in
str
(
info
.
value
)
def
test_mappable_split_general
():
d
=
ds
.
ManifestDataset
(
manifest_file
,
shuffle
=
False
)
d
=
d
.
take
(
5
)
...
...
@@ -183,6 +323,7 @@ def test_mappable_split_general():
assert
s1_output
==
[
0
,
1
]
assert
s2_output
==
[
2
,
3
,
4
]
def
test_mappable_split_optimized
():
d
=
ds
.
ManifestDataset
(
manifest_file
,
shuffle
=
False
)
...
...
@@ -228,9 +369,9 @@ def test_mappable_split_optimized():
assert
s1_output
==
[
0
,
1
]
assert
s2_output
==
[
2
,
3
,
4
]
def
test_mappable_randomize_deterministic
():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds
.
config
.
set_seed
(
53
)
d
=
ds
.
ManifestDataset
(
manifest_file
,
shuffle
=
False
)
...
...
@@ -249,9 +390,9 @@ def test_mappable_randomize_deterministic():
assert
s1_output
==
[
0
,
1
,
3
,
4
]
assert
s2_output
==
[
2
]
def
test_mappable_randomize_repeatable
():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds
.
config
.
set_seed
(
53
)
d
=
ds
.
ManifestDataset
(
manifest_file
,
shuffle
=
False
)
...
...
@@ -273,9 +414,10 @@ def test_mappable_randomize_repeatable():
assert
s1_output
==
[
0
,
1
,
3
,
4
]
*
num_epochs
assert
s2_output
==
[
2
]
*
num_epochs
def
test_mappable_sharding
():
# set arbitrary seed for repeatability for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4
, 2
]
ds
.
config
.
set_seed
(
53
)
num_epochs
=
5
...
...
@@ -336,12 +478,94 @@ def test_mappable_sharding():
assert
s2_output
==
[
2
]
assert
d2s2_output
==
[
2
]
def
test_mappable_get_dataset_size
():
d
=
ds
.
ManifestDataset
(
manifest_file
,
shuffle
=
False
)
s1
,
s2
=
d
.
split
([
4
,
1
])
assert
d
.
get_dataset_size
()
==
5
assert
s1
.
get_dataset_size
()
==
4
assert
s2
.
get_dataset_size
()
==
1
def
test_mappable_multi_split
():
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds
.
config
.
set_seed
(
53
)
d
=
ds
.
ManifestDataset
(
manifest_file
,
shuffle
=
False
)
s1
,
s2
=
d
.
split
([
4
,
1
])
s1_correct_output
=
[
0
,
1
,
3
,
4
]
s1_output
=
[]
for
item
in
s1
.
create_dict_iterator
():
s1_output
.
append
(
manifest_map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
assert
s1_output
==
s1_correct_output
# no randomize in second split
s1s1
,
s1s2
,
s1s3
=
s1
.
split
([
1
,
2
,
1
],
randomize
=
False
)
s1s1_output
=
[]
for
item
in
s1s1
.
create_dict_iterator
():
s1s1_output
.
append
(
manifest_map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
s1s2_output
=
[]
for
item
in
s1s2
.
create_dict_iterator
():
s1s2_output
.
append
(
manifest_map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
s1s3_output
=
[]
for
item
in
s1s3
.
create_dict_iterator
():
s1s3_output
.
append
(
manifest_map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
assert
s1s1_output
==
[
s1_correct_output
[
0
]]
assert
s1s2_output
==
[
s1_correct_output
[
1
],
s1_correct_output
[
2
]]
assert
s1s3_output
==
[
s1_correct_output
[
3
]]
s2_output
=
[]
for
item
in
s2
.
create_dict_iterator
():
s2_output
.
append
(
manifest_map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
assert
s2_output
==
[
2
]
# randomize in second split
# the labels outputted by the RandomSampler for seed 53 is [3, 1, 2, 0]
random_sampler_ids
=
[
3
,
1
,
2
,
0
]
s1s1
,
s1s2
,
s1s3
=
s1
.
split
([
1
,
2
,
1
])
s1s1_output
=
[]
for
item
in
s1s1
.
create_dict_iterator
():
s1s1_output
.
append
(
manifest_map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
s1s2_output
=
[]
for
item
in
s1s2
.
create_dict_iterator
():
s1s2_output
.
append
(
manifest_map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
s1s3_output
=
[]
for
item
in
s1s3
.
create_dict_iterator
():
s1s3_output
.
append
(
manifest_map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
assert
s1s1_output
==
[
s1_correct_output
[
random_sampler_ids
[
0
]]]
assert
s1s2_output
==
[
s1_correct_output
[
random_sampler_ids
[
1
]],
s1_correct_output
[
random_sampler_ids
[
2
]]]
assert
s1s3_output
==
[
s1_correct_output
[
random_sampler_ids
[
3
]]]
s2_output
=
[]
for
item
in
s2
.
create_dict_iterator
():
s2_output
.
append
(
manifest_map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
assert
s2_output
==
[
2
]
if
__name__
==
'__main__'
:
test_unmappable_invalid_input
()
test_unmappable_split
()
test_unmappable_randomize_deterministic
()
test_unmappable_randomize_repeatable
()
test_unmappable_get_dataset_size
()
test_unmappable_multi_split
()
test_mappable_invalid_input
()
test_mappable_split_general
()
test_mappable_split_optimized
()
test_mappable_randomize_deterministic
()
test_mappable_randomize_repeatable
()
test_mappable_sharding
()
test_mappable_get_dataset_size
()
test_mappable_multi_split
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录