magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 79f087c2
Authored May 01, 2020 by mindspore-ci-bot; committed via Gitee on May 01, 2020.

!909 Adding fix for set seed

Merge pull request !909 from EricZ/test_random

Parents: 200b3f07, 26cb3e8a
Showing 7 changed files with 282 additions and 33 deletions (+282, -33).
mindspore/dataset/core/configuration.py                +8    -1
tests/ut/python/dataset/test_config.py                 +222  -6
tests/ut/python/dataset/test_datasets_textfileop.py    +1    -0
tests/ut/python/dataset/test_random_color_adjust.py    +21   -21
tests/ut/python/dataset/test_random_crop.py            +4    -3
tests/ut/python/dataset/test_rename.py                 +3    -2
tests/ut/python/dataset/test_shuffle.py                +23   -0
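In short, set_seed now also seeds Python's random module, so Python-side (py_transforms) augmentations become reproducible as well. A minimal sketch of the new behaviour (the seed value 5 is illustrative, not taken from this commit):

import random
import mindspore.dataset as ds

# After this change, ds.config.set_seed() forwards the seed to both the
# C++ data engine and Python's random module.
ds.config.set_seed(5)
first = random.random()

# Re-seeding with the same value reproduces the Python-side sequence,
# which the new tests in test_config.py rely on.
ds.config.set_seed(5)
assert random.random() == first
assert ds.config.get_seed() == 5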
mindspore/dataset/core/configuration.py

@@ -15,7 +15,7 @@
"""
The configuration manager.
"""
import random
import mindspore._c_dataengine as cde

INT32_MAX = 2147483647

@@ -32,6 +32,12 @@ class ConfigurationManager:
        """
        Set the seed to be used in any random generator. This is used to produce deterministic results.

        Note:
            This set_seed function sets the seed in the python random library function for deterministic
            python augmentations using randomness. This set_seed function should be called with every
            iterator created to reset the random seed. In our pipeline this does not guarantee
            deterministic results with num_parallel_workers > 1.

        Args:
            seed(int): seed to be set

@@ -47,6 +53,7 @@ class ConfigurationManager:
        if seed < 0 or seed > UINT32_MAX:
            raise ValueError("Seed given is not within the required range")
        self.config.set_seed(seed)
        random.seed(seed)

    def get_seed(self):
        """
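The Note above translates into a usage pattern: call ds.config.set_seed() again before creating each iterator so the Python augmentations are re-seeded, and keep num_parallel_workers at 1 when exact determinism matters. A rough sketch of that pattern, modelled on test_deterministic_python_seed below (the data paths are placeholders, not part of this commit):

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.py_transforms as py_vision

# Placeholder inputs; substitute a real TFRecord file and schema.
DATA_DIR = ["path/to/data.tfrecord"]
SCHEMA_DIR = "path/to/datasetSchema.json"

def run_pipeline():
    # Re-seed before every iterator; with num_parallel_workers > 1 the
    # pipeline does not guarantee deterministic results.
    ds.config.set_seed(0)
    ds.config.set_num_parallel_workers(1)

    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transform = py_vision.ComposeOp([py_vision.Decode(),
                                     py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
                                     py_vision.ToTensor()])
    data = data.map(input_columns=["image"], operations=transform())
    return [item["image"] for item in data.create_dict_iterator()]

# Two identically re-seeded runs should produce identical images.
np.testing.assert_equal(run_pipeline(), run_pipeline())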
tests/ut/python/dataset/test_config.py

@@ -13,14 +13,19 @@
# limitations under the License.
# ==============================================================================
"""
Testing configuration manager
"""
import filecmp
import glob
import numpy as np
import os

from mindspore import log as logger
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset.transforms.vision.py_transforms as py_vision

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

@@ -46,9 +51,17 @@ def test_basic():
    assert ds.config.get_prefetch_size() == 4
    assert ds.config.get_seed() == 5


def test_get_seed():
    """
    This gets the seed value without explicitly setting a default, expect int.
    """
    assert isinstance(ds.config.get_seed(), int)


def test_pipeline():
    """
-   Test that our configuration pipeline works when we set parameters at dataset interval
+   Test that our configuration pipeline works when we set parameters at different locations in dataset code
    """
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_num_parallel_workers(2)

@@ -60,12 +73,12 @@ def test_pipeline():
    data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)])
    ds.serialize(data2, "testpipeline2.json")

    # check that the generated output is different
    assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json'))

    # this test passes currently because our num_parallel_workers don't get updated.

    # remove generated jason files
    file_list = glob.glob('*.json')
    for f in file_list:
        try:

@@ -74,6 +87,209 @@ def test_pipeline():
            logger.info("Error while deleting: {}".format(f))


def test_deterministic_run_fail():
    """
    Test RandomCrop with seed, expected to fail
    """
    logger.info("test_deterministic_run_fail")

    # when we set the seed all operations within our dataset should be deterministic
    ds.config.set_seed(0)
    ds.config.set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    # Assuming we get the same seed on calling constructor, if this op is re-used then result won't be
    # the same in between the two datasets. For example, RandomCrop constructor takes seed (0)
    # outputs a deterministic series of numbers, e,g "a" = [1, 2, 3, 4, 5, 6] <- pretend these are random
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = vision.Decode()
    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=random_crop_op)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(input_columns=["image"], operations=decode_op)
    # If seed is set up on constructor
    data2 = data2.map(input_columns=["image"], operations=random_crop_op)

    try:
        for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
            np.testing.assert_equal(item1["image"], item2["image"])
    except BaseException as e:
        # two datasets split the number out of the sequence a
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Array" in str(e)


def test_deterministic_run_pass():
    """
    Test deterministic run with with setting the seed
    """
    logger.info("test_deterministic_run_pass")

    ds.config.set_seed(0)
    ds.config.set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    # We get the seed when constructor is called
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = vision.Decode()
    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=random_crop_op)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(input_columns=["image"], operations=decode_op)
    # Since seed is set up on constructor, so the two ops output deterministic sequence.
    # Assume the generated random sequence "a" = [1, 2, 3, 4, 5, 6] <- pretend these are random
    random_crop_op2 = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data2 = data2.map(input_columns=["image"], operations=random_crop_op2)

    try:
        for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
            np.testing.assert_equal(item1["image"], item2["image"])
    except BaseException as e:
        # two datasets both use numbers from the generated sequence "a"
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Array" in str(e)


def test_seed_undeterministic():
    """
    Test seed with num parallel workers in c, this test is expected to fail some of the time
    """
    logger.info("test_seed_undeterministic")

    ds.config.set_seed(0)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    # seed will be read in during constructor call
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = vision.Decode()
    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=random_crop_op)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(input_columns=["image"], operations=decode_op)
    # If seed is set up on constructor, so the two ops output deterministic sequence
    random_crop_op2 = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data2 = data2.map(input_columns=["image"], operations=random_crop_op2)

    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        np.testing.assert_equal(item1["image"], item2["image"])


def test_deterministic_run_distribution():
    """
    Test deterministic run with with setting the seed being used in a distribution
    """
    logger.info("test_deterministic_run_distribution")

    # when we set the seed all operations within our dataset should be deterministic
    ds.config.set_seed(0)
    ds.config.set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    random_crop_op = vision.RandomHorizontalFlip(0.1)
    decode_op = vision.Decode()
    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=random_crop_op)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(input_columns=["image"], operations=decode_op)
    # If seed is set up on constructor, so the two ops output deterministic sequence
    random_crop_op2 = vision.RandomHorizontalFlip(0.1)
    data2 = data2.map(input_columns=["image"], operations=random_crop_op2)

    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        np.testing.assert_equal(item1["image"], item2["image"])


def test_deterministic_python_seed():
    """
    Test deterministic execution with seed in python
    """
    logger.info("deterministic_random_crop_op_python_2")

    ds.config.set_seed(0)
    ds.config.set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms = [
        py_vision.Decode(),
        py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
        py_vision.ToTensor(),
    ]
    transform = py_vision.ComposeOp(transforms)
    data1 = data1.map(input_columns=["image"], operations=transform())
    data1_output = []
    # config.set_seed() calls random.seed()
    for data_one in data1.create_dict_iterator():
        data1_output.append(data_one["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(input_columns=["image"], operations=transform())
    # config.set_seed() calls random.seed(), resets seed for next dataset iterator
    ds.config.set_seed(0)
    data2_output = []
    for data_two in data2.create_dict_iterator():
        data2_output.append(data_two["image"])

    np.testing.assert_equal(data1_output, data2_output)


def test_deterministic_python_seed_multi_thread():
    """
    Test deterministic execution with seed in python, this fails with multi-thread pyfunc run
    """
    logger.info("deterministic_random_crop_op_python_2")

    ds.config.set_seed(0)
    # when we set the seed all operations within our dataset should be deterministic
    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms = [
        py_vision.Decode(),
        py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
        py_vision.ToTensor(),
    ]
    transform = py_vision.ComposeOp(transforms)
    data1 = data1.map(input_columns=["image"], operations=transform(), python_multiprocessing=True)
    data1_output = []
    # config.set_seed() calls random.seed()
    for data_one in data1.create_dict_iterator():
        data1_output.append(data_one["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    # If seed is set up on constructor
    data2 = data2.map(input_columns=["image"], operations=transform(), python_multiprocessing=True)
    # config.set_seed() calls random.seed()
    ds.config.set_seed(0)
    data2_output = []
    for data_two in data2.create_dict_iterator():
        data2_output.append(data_two["image"])

    try:
        np.testing.assert_equal(data1_output, data2_output)
    except BaseException as e:
        # expect output to not match during multi-threaded excution
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Array" in str(e)


if __name__ == '__main__':
    test_basic()
    test_pipeline()
    test_deterministic_run_pass()
    test_deterministic_run_distribution()
    test_deterministic_run_fail()
    test_deterministic_python_seed()
    test_seed_undeterministic()
    test_get_seed()
tests/ut/python/dataset/test_datasets_textfileop.py

@@ -36,6 +36,7 @@ def test_textline_dataset_all_file():
    assert (count == 5)


def test_textline_dataset_totext():
    ds.config.set_num_parallel_workers(4)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
    count = 0
    line = ["This is a text file.", "Another file.", "Be happy every day.", "End of file.", "Good luck to everyone."]
tests/ut/python/dataset/test_random_color_adjust.py

@@ -37,7 +37,7 @@ def visualize(first, mse, second):
    plt.subplot(142)
    plt.imshow(second)
-   plt.title("py random_color_jitter image")
+   plt.title("py random_color_adjust image")

    plt.subplot(143)
    plt.imshow(first - second)

@@ -50,20 +50,20 @@ def diff_mse(in1, in2):
    return mse * 100


-def test_random_color_jitter_op_brightness():
+def test_random_color_adjust_op_brightness():
    """
    Test RandomColorAdjust op
    """
-   logger.info("test_random_color_jitter_op")
+   logger.info("test_random_color_adjust_op")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = c_vision.Decode()

-   random_jitter_op = c_vision.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))
+   random_adjust_op = c_vision.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))

    ctrans = [decode_op,
-             random_jitter_op,
+             random_adjust_op,
              ]

    data1 = data1.map(input_columns=["image"], operations=ctrans)

@@ -100,20 +100,20 @@ def test_random_color_jitter_op_brightness():
    # visualize(c_image, mse, py_image)


-def test_random_color_jitter_op_contrast():
+def test_random_color_adjust_op_contrast():
    """
    Test RandomColorAdjust op
    """
-   logger.info("test_random_color_jitter_op")
+   logger.info("test_random_color_adjust_op")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = c_vision.Decode()

-   random_jitter_op = c_vision.RandomColorAdjust((1, 1), (0.5, 0.5), (1, 1), (0, 0))
+   random_adjust_op = c_vision.RandomColorAdjust((1, 1), (0.5, 0.5), (1, 1), (0, 0))

    ctrans = [decode_op,
-             random_jitter_op
+             random_adjust_op
              ]

    data1 = data1.map(input_columns=["image"], operations=ctrans)

@@ -156,20 +156,20 @@ def test_random_color_jitter_op_contrast():
    # visualize(c_image, mse, py_image)


-def test_random_color_jitter_op_saturation():
+def test_random_color_adjust_op_saturation():
    """
    Test RandomColorAdjust op
    """
-   logger.info("test_random_color_jitter_op")
+   logger.info("test_random_color_adjust_op")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = c_vision.Decode()

-   random_jitter_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (0.5, 0.5), (0, 0))
+   random_adjust_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (0.5, 0.5), (0, 0))

    ctrans = [decode_op,
-             random_jitter_op
+             random_adjust_op
              ]

    data1 = data1.map(input_columns=["image"], operations=ctrans)

@@ -209,20 +209,20 @@ def test_random_color_jitter_op_saturation():
    # visualize(c_image, mse, py_image)


-def test_random_color_jitter_op_hue():
+def test_random_color_adjust_op_hue():
    """
    Test RandomColorAdjust op
    """
-   logger.info("test_random_color_jitter_op")
+   logger.info("test_random_color_adjust_op")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = c_vision.Decode()

-   random_jitter_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2))
+   random_adjust_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2))

    ctrans = [decode_op,
-             random_jitter_op,
+             random_adjust_op,
              ]

    data1 = data1.map(input_columns=["image"], operations=ctrans)

@@ -264,7 +264,7 @@ def test_random_color_jitter_op_hue():
if __name__ == "__main__":
-   test_random_color_jitter_op_brightness()
-   test_random_color_jitter_op_contrast()
-   test_random_color_jitter_op_saturation()
-   test_random_color_jitter_op_hue()
+   test_random_color_adjust_op_brightness()
+   test_random_color_adjust_op_contrast()
+   test_random_color_adjust_op_saturation()
+   test_random_color_adjust_op_hue()
tests/ut/python/dataset/test_random_crop.py

@@ -17,8 +17,8 @@ Testing RandomCropAndResize op in DE
"""
import matplotlib.pyplot as plt
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
import mindspore.dataset as ds

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

@@ -45,9 +45,9 @@ def visualize(a, mse, original):
def test_random_crop_op():
    """
-   Test RandomCropAndResize op
+   Test RandomCropOp
    """
-   logger.info("test_random_crop_and_resize_op")
+   logger.info("test_random_crop_op")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

@@ -67,3 +67,4 @@ def test_random_crop_op():
if __name__ == "__main__":
    test_random_crop_op()
tests/ut/python/dataset/test_rename.py

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger

@@ -34,9 +35,9 @@ def test_rename():
    for i, item in enumerate(data.create_dict_iterator()):
        logger.info("item[mask] is {}".format(item["masks"]))
-       assert item["masks"].all() == item["input_ids"].all()
+       np.testing.assert_equal(item["masks"], item["input_ids"])
        logger.info("item[seg_ids] is {}".format(item["seg_ids"]))
-       assert item["segment_ids"].all() == item["seg_ids"].all()
+       np.testing.assert_equal(item["segment_ids"], item["seg_ids"])
        # need to consume the data in the buffer
        num_iter += 1
    logger.info("Number of data in data: {}".format(num_iter))
tests/ut/python/dataset/test_shuffle.py

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
from util import save_and_check

import mindspore.dataset as ds

@@ -117,6 +118,27 @@ def test_shuffle_05():
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_shuffle_06():
    """
    Test shuffle: with set seed, both datasets
    """
    logger.info("test_shuffle_06")
    # define parameters
    buffer_size = 13
    seed = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    data2 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(buffer_size=buffer_size)

    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        np.testing.assert_equal(item1, item2)


def test_shuffle_exception_01():
    """
    Test shuffle exception: buffer_size<0

@@ -231,6 +253,7 @@ if __name__ == '__main__':
    test_shuffle_03()
    test_shuffle_04()
    test_shuffle_05()
    test_shuffle_06()
    test_shuffle_exception_01()
    test_shuffle_exception_02()
    test_shuffle_exception_03()