magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit b78894e0
Authored May 19, 2020 by Cathy Wong
Parent: a3b9c238

Cleanup dataset UT: unskip and enhance TFRecord sharding tests
Showing 4 changed files with 99 additions and 10 deletions (+99 −10)
tests/ut/python/dataset/test_concat.py              +4   −5
tests/ut/python/dataset/test_datasets_sharding.py   +85  −0
tests/ut/python/dataset/test_five_crop.py           +1   −1
tests/ut/python/dataset/test_tfreader_op.py         +9   −4
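The new sharding tests center on one call pattern. Below is a condensed sketch assembled from the diff that follows; the relative paths and the "scalars" column name come from the test data layout the suite assumes:

```python
import mindspore.dataset as ds

# Four TFRecord files of 10 rows each, 40 rows total (paths as in the tests).
tf_files = ["../data/dataset/tf_file_dataset/test%d.data" % i for i in range(1, 5)]

# Read shard 0 of 2. The assertions in the diff imply sharding happens at
# file granularity here: shuffle=Shuffle.FILES shuffles whole files, so rows
# within a file keep their order.
data = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0,
                          shuffle=ds.Shuffle.FILES, num_parallel_workers=1)
rows = [item["scalars"][0] for item in data.create_dict_iterator()]
assert len(rows) == 20  # half of the 40-row dataset
```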
tests/ut/python/dataset/test_concat.py

@@ -21,19 +21,18 @@ import mindspore.dataset.transforms.vision.py_transforms as F
 from mindspore import log as logger
 
-# In generator dataset: Number of rows is 3, its value is 0, 1, 2
+# In generator dataset: Number of rows is 3; its values are 0, 1, 2
 def generator():
     for i in range(3):
         yield np.array([i]),
 
-# In generator_10 dataset: Number of rows is 7, its value is 3, 4, 5 ... 10
+# In generator_10 dataset: Number of rows is 7; its values are 3, 4, 5 ... 9
 def generator_10():
     for i in range(3, 10):
         yield np.array([i]),
 
-# In generator_20 dataset: Number of rows is 10, its value is 10, 11, 12 ... 20
+# In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19
 def generator_20():
     for i in range(10, 20):
         yield np.array([i]),
@@ -135,7 +134,7 @@ def test_concat_05():
 def test_concat_06():
     """
-    Test concat: test concat muti datasets in one time
+    Test concat: test concat multi datasets in one time
     """
     logger.info("test_concat_06")
     data1 = ds.GeneratorDataset(generator, ["col1"])
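The three comment fixes above all correct the same off-by-one reading of Python's half-open ranges; a quick check of the arithmetic:

```python
# range(a, b) excludes b, so the corrected row counts and maxima follow:
assert list(range(3)) == [0, 1, 2]          # 3 rows: 0, 1, 2
assert len(list(range(3, 10))) == 7         # 7 rows: 3 ... 9, not ... 10
assert list(range(10, 20))[-1] == 19        # 10 rows: 10 ... 19, not ... 20
```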
tests/ut/python/dataset/test_datasets_sharding.py

@@ -35,6 +35,9 @@ def test_imagefolder_shardings(print_res=False):
     assert (sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1])  # 5 rows
     assert (sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3])  # 11 rows
     assert (sharding_config(4, 3, None, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])  # 11 rows
+    assert (sharding_config(1, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])  # 44 rows
+    assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])  # 22 rows
+    assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3])  # 22 rows
     # total 22 in dataset rows because of class indexing which takes only 2 folders
     assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6)
     assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
@@ -44,6 +47,86 @@ def test_imagefolder_shardings(print_res=False):
     assert (len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20)
 
+
+def test_tfrecord_shardings1(print_res=False):
+    """ Test TFRecordDataset sharding with num_parallel_workers=1 """
+    # total 40 rows in dataset
+    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
+                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
+
+    def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1):
+        data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
+                                   shuffle=ds.Shuffle.FILES, num_parallel_workers=1)
+        data1 = data1.repeat(repeat_cnt)
+        res = []
+        for item in data1.create_dict_iterator():  # each data is a dictionary
+            res.append(item["scalars"][0])
+        if print_res:
+            logger.info("scalars of dataset: {}".format(res))
+        return res
+
+    assert sharding_config(2, 0, None, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
+    assert sharding_config(2, 1, None, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
+    assert sharding_config(2, 0, 3, 1) == [11, 12, 13]  # 3 rows
+    assert sharding_config(2, 1, 3, 1) == [1, 2, 3]  # 3 rows
+    assert sharding_config(2, 0, 40, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
+    assert sharding_config(2, 1, 40, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
+    assert sharding_config(2, 0, 55, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
+    assert sharding_config(2, 1, 55, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
+    assert sharding_config(3, 0, 8, 1) == [11, 12, 13, 14, 15, 16, 17, 18]  # 8 rows
+    assert sharding_config(3, 1, 8, 1) == [1, 2, 3, 4, 5, 6, 7, 8]  # 8 rows
+    assert sharding_config(3, 2, 8, 1) == [21, 22, 23, 24, 25, 26, 27, 28]  # 8 rows
+    assert sharding_config(4, 0, 2, 1) == [11, 12]  # 2 rows
+    assert sharding_config(4, 1, 2, 1) == [1, 2]  # 2 rows
+    assert sharding_config(4, 2, 2, 1) == [21, 22]  # 2 rows
+    assert sharding_config(4, 3, 2, 1) == [31, 32]  # 2 rows
+    assert sharding_config(3, 0, 4, 2) == [11, 12, 13, 14, 21, 22, 23, 24]  # 8 rows
+    assert sharding_config(3, 1, 4, 2) == [1, 2, 3, 4, 11, 12, 13, 14]  # 8 rows
+    assert sharding_config(3, 2, 4, 2) == [21, 22, 23, 24, 31, 32, 33, 34]  # 8 rows
+
+
+def test_tfrecord_shardings4(print_res=False):
+    """ Test TFRecordDataset sharding with num_parallel_workers=4 """
+    # total 40 rows in dataset
+    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
+                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
+
+    def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1):
+        data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
+                                   shuffle=ds.Shuffle.FILES, num_parallel_workers=4)
+        data1 = data1.repeat(repeat_cnt)
+        res = []
+        for item in data1.create_dict_iterator():  # each data is a dictionary
+            res.append(item["scalars"][0])
+        if print_res:
+            logger.info("scalars of dataset: {}".format(res))
+        return res
+
+    def check_result(result_list, expect_length, expect_set):
+        assert len(result_list) == expect_length
+        assert set(result_list) == expect_set
+
+    check_result(sharding_config(2, 0, None, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
+    check_result(sharding_config(2, 1, None, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
+    check_result(sharding_config(2, 0, 3, 1), 3, {11, 12, 21})
+    check_result(sharding_config(2, 1, 3, 1), 3, {1, 2, 31})
+    check_result(sharding_config(2, 0, 40, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
+    check_result(sharding_config(2, 1, 40, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
+    check_result(sharding_config(2, 0, 55, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
+    check_result(sharding_config(2, 1, 55, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
+    check_result(sharding_config(3, 0, 8, 1), 8, {32, 33, 34, 11, 12, 13, 14, 31})
+    check_result(sharding_config(3, 1, 8, 1), 8, {1, 2, 3, 4, 5, 6, 7, 8})
+    check_result(sharding_config(3, 2, 8, 1), 8, {21, 22, 23, 24, 25, 26, 27, 28})
+    check_result(sharding_config(4, 0, 2, 1), 2, {11, 12})
+    check_result(sharding_config(4, 1, 2, 1), 2, {1, 2})
+    check_result(sharding_config(4, 2, 2, 1), 2, {21, 22})
+    check_result(sharding_config(4, 3, 2, 1), 2, {31, 32})
+    check_result(sharding_config(3, 0, 4, 2), 8, {32, 1, 2, 11, 12, 21, 22, 31})
+    check_result(sharding_config(3, 1, 4, 2), 8, {1, 2, 3, 4, 11, 12, 13, 14})
+    check_result(sharding_config(3, 2, 4, 2), 8, {32, 33, 34, 21, 22, 23, 24, 31})
+
+
 def test_manifest_shardings(print_res=False):
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
@@ -157,6 +240,8 @@ def test_mnist_shardings(print_res=False):
 if __name__ == '__main__':
     test_imagefolder_shardings(True)
+    test_tfrecord_shardings1(True)
+    test_tfrecord_shardings4(True)
     test_manifest_shardings(True)
     test_voc_shardings(True)
     test_cifar10_shardings(True)
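A design note on the two new tests: test_tfrecord_shardings1 pins exact row order because a single worker reads deterministically, while test_tfrecord_shardings4 runs four parallel workers whose output interleaves nondeterministically, so it checks only row count and membership. The check_result helper from the diff, with an order-insensitive usage example:

```python
def check_result(result_list, expect_length, expect_set):
    # Only the length and the set of values are stable across runs when
    # multiple workers interleave rows from different files.
    assert len(result_list) == expect_length
    assert set(result_list) == expect_set

check_result([12, 11, 21], 3, {11, 12, 21})  # passes in any order
```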
tests/ut/python/dataset/test_five_crop.py

@@ -43,7 +43,7 @@ def visualize(image_1, image_2):
     plt.show()
 
-def skip_test_five_crop_op():
+def test_five_crop_op():
     """
     Test FiveCrop
     """
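The rename works because pytest only collects functions whose names match its default test* pattern; the skip_ prefix hid the test from collection entirely. A marker-based alternative (a hypothetical illustration, not what this commit does) would keep a disabled test visible in reports:

```python
import pytest

# Hypothetical alternative to the skip_ prefix: the test stays collected
# and is reported as skipped instead of silently disappearing.
@pytest.mark.skip(reason="temporarily disabled")
def test_five_crop_op():
    ...
```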
tests/ut/python/dataset/test_tfreader_op.py

@@ -153,7 +153,7 @@ def test_tf_record_shuffle():
     assert np.array_equal(t1, t2)
 
-def skip_test_tf_record_shard():
+def test_tf_record_shard():
     tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
@@ -171,12 +171,14 @@ def skip_test_tf_record_shard():
     # 2. with enough epochs, both workers will get the entire dataset (e,g. ep1_wrkr1: f1&f3, ep2,_wrkr1 f2&f4)
     worker1_res = get_res(0, 16)
     worker2_res = get_res(1, 16)
+    # Confirm each worker gets 3x16=48 rows
+    assert len(worker1_res) == 48
+    assert len(worker1_res) == len(worker2_res)
     # check criteria 1
     for i in range(len(worker1_res)):
         assert (worker1_res[i] != worker2_res[i])
     # check criteria 2
     assert (set(worker2_res) == set(worker1_res))
-    assert (len(set(worker2_res)) == 12)
 
 def test_tf_shard_equal_rows():
@@ -198,7 +200,10 @@ def test_tf_shard_equal_rows():
     for i in range(len(worker1_res)):
         assert (worker1_res[i] != worker2_res[i])
         assert (worker2_res[i] != worker3_res[i])
-    assert (len(worker1_res) == 28)
+    # Confirm each worker gets same number of rows
+    assert len(worker1_res) == 28
+    assert len(worker1_res) == len(worker2_res)
+    assert len(worker2_res) == len(worker3_res)
 
     worker4_res = get_res(1, 0, 1)
     assert (len(worker4_res) == 40)
@@ -272,7 +277,7 @@ if __name__ == '__main__':
     test_tf_files()
     test_tf_record_schema()
     test_tf_record_shuffle()
-    # test_tf_record_shard()
+    test_tf_record_shard()
     test_tf_shard_equal_rows()
     test_case_tf_file_no_schema_columns_list()
     test_tf_record_schema_columns_list()
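Where the magic numbers in test_tf_shard_equal_rows plausibly come from (an inference from the asserted values, not stated in the diff): 40 total rows split into 3 equal-rows shards gives ceil(40/3) = 14 rows per shard, two epochs yield the asserted 28, and a single shard over one epoch yields all 40 rows (worker4_res):

```python
import math

total_rows = 40                             # 4 files x 10 rows (inferred)
per_shard = math.ceil(total_rows / 3)       # equal-rows sharding pads short shards
assert per_shard == 14                      # every shard yields 14 rows per epoch
assert per_shard * 2 == 28                  # 2 epochs -> len(worker1_res) == 28
assert math.ceil(total_rows / 1) * 1 == 40  # 1 shard, 1 epoch -> len(worker4_res) == 40
```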