Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
49ef53f1
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
49ef53f1
编写于
5月 11, 2020
作者:
C
Cathy Wong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cleanup dataset UT: util.py internals
上级
2af6ee24
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
96 addition
and
120 deletion
+96
-120
tests/ut/data/dataset/golden/HWC2CHW_01_result.npz
tests/ut/data/dataset/golden/HWC2CHW_01_result.npz
+0
-0
tests/ut/data/dataset/golden/center_crop_01_result.npz
tests/ut/data/dataset/golden/center_crop_01_result.npz
+0
-0
tests/ut/data/dataset/golden/repeat_list_result.npz
tests/ut/data/dataset/golden/repeat_list_result.npz
+0
-0
tests/ut/python/dataset/test_HWC2CHW.py
tests/ut/python/dataset/test_HWC2CHW.py
+4
-4
tests/ut/python/dataset/test_center_crop.py
tests/ut/python/dataset/test_center_crop.py
+18
-19
tests/ut/python/dataset/test_general.py
tests/ut/python/dataset/test_general.py
+0
-41
tests/ut/python/dataset/test_project.py
tests/ut/python/dataset/test_project.py
+11
-12
tests/ut/python/dataset/test_repeat.py
tests/ut/python/dataset/test_repeat.py
+43
-24
tests/ut/python/dataset/util.py
tests/ut/python/dataset/util.py
+20
-20
未找到文件。
tests/ut/data/dataset/golden/
test_
HWC2CHW_01_result.npz
→
tests/ut/data/dataset/golden/HWC2CHW_01_result.npz
浏览文件 @
49ef53f1
文件已移动
tests/ut/data/dataset/golden/
test_
center_crop_01_result.npz
→
tests/ut/data/dataset/golden/center_crop_01_result.npz
浏览文件 @
49ef53f1
文件已移动
tests/ut/data/dataset/golden/
columns
_list_result.npz
→
tests/ut/data/dataset/golden/
repeat
_list_result.npz
浏览文件 @
49ef53f1
文件已移动
tests/ut/python/dataset/test_HWC2CHW.py
浏览文件 @
49ef53f1
...
@@ -69,8 +69,8 @@ def test_HWC2CHW_md5():
...
@@ -69,8 +69,8 @@ def test_HWC2CHW_md5():
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
decode_op
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
decode_op
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
hwc2chw_op
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
hwc2chw_op
)
# expected md5 from images
#
Compare with
expected md5 from images
filename
=
"
test_
HWC2CHW_01_result.npz"
filename
=
"HWC2CHW_01_result.npz"
save_and_check_md5
(
data1
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_md5
(
data1
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
...
@@ -103,9 +103,9 @@ def test_HWC2CHW_comp(plot=False):
...
@@ -103,9 +103,9 @@ def test_HWC2CHW_comp(plot=False):
c_image
=
item1
[
"image"
]
c_image
=
item1
[
"image"
]
py_image
=
(
item2
[
"image"
].
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
)
py_image
=
(
item2
[
"image"
].
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
)
#
c
ompare images between that applying c_transform and py_transform
#
C
ompare images between that applying c_transform and py_transform
mse
=
diff_mse
(
py_image
,
c_image
)
mse
=
diff_mse
(
py_image
,
c_image
)
#
t
he images aren't exactly the same due to rounding error
#
Note: T
he images aren't exactly the same due to rounding error
assert
mse
<
0.001
assert
mse
<
0.001
image_c_transposed
.
append
(
item1
[
"image"
].
copy
())
image_c_transposed
.
append
(
item1
[
"image"
].
copy
())
...
...
tests/ut/python/dataset/test_center_crop.py
浏览文件 @
49ef53f1
...
@@ -12,10 +12,9 @@
...
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
import
numpy
as
np
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
import
mindspore.dataset.transforms.vision.py_transforms
as
py_vision
import
mindspore.dataset.transforms.vision.py_transforms
as
py_vision
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
mindspore.dataset
as
ds
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
util
import
diff_mse
,
visualize
,
save_and_check_md5
from
util
import
diff_mse
,
visualize
,
save_and_check_md5
...
@@ -60,15 +59,14 @@ def test_center_crop_md5(height=375, width=375):
...
@@ -60,15 +59,14 @@ def test_center_crop_md5(height=375, width=375):
logger
.
info
(
"Test CenterCrop"
)
logger
.
info
(
"Test CenterCrop"
)
# First dataset
# First dataset
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
[
"image"
],
shuffle
=
False
)
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
[
"image"
],
shuffle
=
False
)
decode_op
=
vision
.
Decode
()
decode_op
=
vision
.
Decode
()
# 3 images [375, 500] [600, 500] [512, 512]
# 3 images [375, 500] [600, 500] [512, 512]
center_crop_op
=
vision
.
CenterCrop
([
height
,
width
])
center_crop_op
=
vision
.
CenterCrop
([
height
,
width
])
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
decode_op
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
decode_op
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
center_crop_op
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
center_crop_op
)
# expected md5 from images
# Compare with expected md5 from images
filename
=
"center_crop_01_result.npz"
filename
=
"test_center_crop_01_result.npz"
save_and_check_md5
(
data1
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_md5
(
data1
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
...
@@ -100,8 +98,8 @@ def test_center_crop_comp(height=375, width=375, plot=False):
...
@@ -100,8 +98,8 @@ def test_center_crop_comp(height=375, width=375, plot=False):
for
item1
,
item2
in
zip
(
data1
.
create_dict_iterator
(),
data2
.
create_dict_iterator
()):
for
item1
,
item2
in
zip
(
data1
.
create_dict_iterator
(),
data2
.
create_dict_iterator
()):
c_image
=
item1
[
"image"
]
c_image
=
item1
[
"image"
]
py_image
=
(
item2
[
"image"
].
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
)
py_image
=
(
item2
[
"image"
].
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
)
#
the images aren't exactly the same due to rouding error
#
Note: The images aren't exactly the same due to rounding error
assert
(
diff_mse
(
py_image
,
c_image
)
<
0.001
)
assert
diff_mse
(
py_image
,
c_image
)
<
0.001
image_cropped
.
append
(
item1
[
"image"
].
copy
())
image_cropped
.
append
(
item1
[
"image"
].
copy
())
image
.
append
(
item2
[
"image"
].
copy
())
image
.
append
(
item2
[
"image"
].
copy
())
if
plot
:
if
plot
:
...
@@ -112,6 +110,7 @@ def test_crop_grayscale(height=375, width=375):
...
@@ -112,6 +110,7 @@ def test_crop_grayscale(height=375, width=375):
"""
"""
Test that centercrop works with pad and grayscale images
Test that centercrop works with pad and grayscale images
"""
"""
def
channel_swap
(
image
):
def
channel_swap
(
image
):
"""
"""
Py func hack for our pytransforms to work with c transforms
Py func hack for our pytransforms to work with c transforms
...
@@ -129,15 +128,15 @@ def test_crop_grayscale(height=375, width=375):
...
@@ -129,15 +128,15 @@ def test_crop_grayscale(height=375, width=375):
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
[
"image"
],
shuffle
=
False
)
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
[
"image"
],
shuffle
=
False
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
transform
())
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
transform
())
#
if input is grayscale, the output dimensions should be single channel
#
If input is grayscale, the output dimensions should be single channel
crop_gray
=
vision
.
CenterCrop
([
height
,
width
])
crop_gray
=
vision
.
CenterCrop
([
height
,
width
])
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
crop_gray
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
crop_gray
)
for
item1
in
data1
.
create_dict_iterator
():
for
item1
in
data1
.
create_dict_iterator
():
c_image
=
item1
[
"image"
]
c_image
=
item1
[
"image"
]
#
c
heck that the image is grayscale
#
C
heck that the image is grayscale
assert
(
len
(
c_image
.
shape
)
==
3
and
c_image
.
shape
[
2
]
==
1
)
assert
(
c_image
.
ndim
==
3
and
c_image
.
shape
[
2
]
==
1
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/ut/python/dataset/test_general.py
已删除
100644 → 0
浏览文件 @
2af6ee24
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
util
import
save_and_check
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
DATA_DIR
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
SCHEMA_DIR
=
"../data/dataset/testTFTestAllTypes/datasetSchema.json"
COLUMNS
=
[
"col_1d"
,
"col_2d"
,
"col_3d"
,
"col_binary"
,
"col_float"
,
"col_sint16"
,
"col_sint32"
,
"col_sint64"
]
GENERATE_GOLDEN
=
False
def
test_case_columns_list
():
"""
a simple repeat operation.
"""
logger
.
info
(
"Test Simple Repeat"
)
# define parameters
repeat_count
=
2
parameters
=
{
"params"
:
{
'repeat_count'
:
repeat_count
}}
columns_list
=
[
"col_sint64"
,
"col_sint32"
]
# apply dataset operations
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
columns_list
,
shuffle
=
False
)
data1
=
data1
.
repeat
(
repeat_count
)
filename
=
"columns_list_result.npz"
save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
tests/ut/python/dataset/test_project.py
浏览文件 @
49ef53f1
...
@@ -12,12 +12,11 @@
...
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
import
mindspore.dataset
.transforms.vision.c_transforms
as
vision
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.c_transforms
as
C
import
mindspore.dataset.transforms.c_transforms
as
C
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
dtype
as
mstype
from
util
import
ordered_save_and_check
from
util
import
save_and_check_tuple
import
mindspore.dataset
as
ds
DATA_DIR_TF
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
DATA_DIR_TF
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
SCHEMA_DIR_TF
=
"../data/dataset/testTFTestAllTypes/datasetSchema.json"
SCHEMA_DIR_TF
=
"../data/dataset/testTFTestAllTypes/datasetSchema.json"
...
@@ -32,7 +31,7 @@ def test_case_project_single_column():
...
@@ -32,7 +31,7 @@ def test_case_project_single_column():
data1
=
data1
.
project
(
columns
=
columns
)
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_single_column_result.npz"
filename
=
"project_single_column_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_project_multiple_columns_in_order
():
def
test_case_project_multiple_columns_in_order
():
...
@@ -43,7 +42,7 @@ def test_case_project_multiple_columns_in_order():
...
@@ -43,7 +42,7 @@ def test_case_project_multiple_columns_in_order():
data1
=
data1
.
project
(
columns
=
columns
)
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_multiple_columns_in_order_result.npz"
filename
=
"project_multiple_columns_in_order_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_project_multiple_columns_out_of_order
():
def
test_case_project_multiple_columns_out_of_order
():
...
@@ -54,7 +53,7 @@ def test_case_project_multiple_columns_out_of_order():
...
@@ -54,7 +53,7 @@ def test_case_project_multiple_columns_out_of_order():
data1
=
data1
.
project
(
columns
=
columns
)
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_multiple_columns_out_of_order_result.npz"
filename
=
"project_multiple_columns_out_of_order_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_project_map
():
def
test_case_project_map
():
...
@@ -68,7 +67,7 @@ def test_case_project_map():
...
@@ -68,7 +67,7 @@ def test_case_project_map():
data1
=
data1
.
map
(
input_columns
=
[
"col_3d"
],
operations
=
type_cast_op
)
data1
=
data1
.
map
(
input_columns
=
[
"col_3d"
],
operations
=
type_cast_op
)
filename
=
"project_map_after_result.npz"
filename
=
"project_map_after_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_map_project
():
def
test_case_map_project
():
...
@@ -83,7 +82,7 @@ def test_case_map_project():
...
@@ -83,7 +82,7 @@ def test_case_map_project():
data1
=
data1
.
project
(
columns
=
columns
)
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_map_before_result.npz"
filename
=
"project_map_before_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_project_between_maps
():
def
test_case_project_between_maps
():
...
@@ -107,7 +106,7 @@ def test_case_project_between_maps():
...
@@ -107,7 +106,7 @@ def test_case_project_between_maps():
data1
=
data1
.
map
(
input_columns
=
[
"col_3d"
],
operations
=
type_cast_op
)
data1
=
data1
.
map
(
input_columns
=
[
"col_3d"
],
operations
=
type_cast_op
)
filename
=
"project_between_maps_result.npz"
filename
=
"project_between_maps_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_project_repeat
():
def
test_case_project_repeat
():
...
@@ -121,7 +120,7 @@ def test_case_project_repeat():
...
@@ -121,7 +120,7 @@ def test_case_project_repeat():
data1
=
data1
.
repeat
(
repeat_count
)
data1
=
data1
.
repeat
(
repeat_count
)
filename
=
"project_before_repeat_result.npz"
filename
=
"project_before_repeat_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_repeat_project
():
def
test_case_repeat_project
():
...
@@ -136,7 +135,7 @@ def test_case_repeat_project():
...
@@ -136,7 +135,7 @@ def test_case_repeat_project():
data1
=
data1
.
project
(
columns
=
columns
)
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_after_repeat_result.npz"
filename
=
"project_after_repeat_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_map_project_map_project
():
def
test_case_map_project_map_project
():
...
@@ -155,4 +154,4 @@ def test_case_map_project_map_project():
...
@@ -155,4 +154,4 @@ def test_case_map_project_map_project():
data1
=
data1
.
project
(
columns
=
columns
)
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_alternate_parallel_inline_result.npz"
filename
=
"project_alternate_parallel_inline_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
tests/ut/python/dataset/test_repeat.py
浏览文件 @
49ef53f1
...
@@ -13,11 +13,10 @@
...
@@ -13,11 +13,10 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
from
util
import
save_and_check
import
mindspore.dataset
as
ds
import
mindspore.dataset
as
ds
import
numpy
as
np
import
numpy
as
np
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
util
import
save_and_check
DATA_DIR_TF
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
DATA_DIR_TF
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
SCHEMA_DIR_TF
=
"../data/dataset/testTFTestAllTypes/datasetSchema.json"
SCHEMA_DIR_TF
=
"../data/dataset/testTFTestAllTypes/datasetSchema.json"
...
@@ -25,13 +24,6 @@ COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
...
@@ -25,13 +24,6 @@ COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16"
,
"col_sint32"
,
"col_sint64"
]
"col_sint16"
,
"col_sint32"
,
"col_sint64"
]
GENERATE_GOLDEN
=
False
GENERATE_GOLDEN
=
False
# Data for CIFAR and MNIST are not part of build tree
# They need to be downloaded directly
# prep_data.py can be exuted or code below
# import sys
# sys.path.insert(0,"../../data")
# import prep_data
# prep_data.download_all_for_test("../../data")
IMG_DATA_DIR
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
IMG_DATA_DIR
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
IMG_SCHEMA_DIR
=
"../data/dataset/test_tf_file_3_images/datasetSchema.json"
IMG_SCHEMA_DIR
=
"../data/dataset/test_tf_file_3_images/datasetSchema.json"
...
@@ -41,7 +33,7 @@ SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
...
@@ -41,7 +33,7 @@ SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def
test_tf_repeat_01
():
def
test_tf_repeat_01
():
"""
"""
a simple repeat operation.
Test
a simple repeat operation.
"""
"""
logger
.
info
(
"Test Simple Repeat"
)
logger
.
info
(
"Test Simple Repeat"
)
# define parameters
# define parameters
...
@@ -58,7 +50,7 @@ def test_tf_repeat_01():
...
@@ -58,7 +50,7 @@ def test_tf_repeat_01():
def
test_tf_repeat_02
():
def
test_tf_repeat_02
():
"""
"""
a simple repeat operation to tes infinite
Test Infinite Repeat.
"""
"""
logger
.
info
(
"Test Infinite Repeat"
)
logger
.
info
(
"Test Infinite Repeat"
)
# define parameters
# define parameters
...
@@ -77,7 +69,10 @@ def test_tf_repeat_02():
...
@@ -77,7 +69,10 @@ def test_tf_repeat_02():
def
test_tf_repeat_03
():
def
test_tf_repeat_03
():
'''repeat and batch '''
"""
Test Repeat then Batch.
"""
logger
.
info
(
"Test Repeat then Batch"
)
data1
=
ds
.
TFRecordDataset
(
DATA_DIR_TF2
,
SCHEMA_DIR_TF2
,
shuffle
=
False
)
data1
=
ds
.
TFRecordDataset
(
DATA_DIR_TF2
,
SCHEMA_DIR_TF2
,
shuffle
=
False
)
batch_size
=
32
batch_size
=
32
...
@@ -90,15 +85,32 @@ def test_tf_repeat_03():
...
@@ -90,15 +85,32 @@ def test_tf_repeat_03():
data1
=
data1
.
batch
(
batch_size
,
drop_remainder
=
True
)
data1
=
data1
.
batch
(
batch_size
,
drop_remainder
=
True
)
num_iter
=
0
num_iter
=
0
for
item
in
data1
.
create_dict_iterator
():
for
_
in
data1
.
create_dict_iterator
():
num_iter
+=
1
num_iter
+=
1
logger
.
info
(
"Number of tf data in data1: {}"
.
format
(
num_iter
))
logger
.
info
(
"Number of tf data in data1: {}"
.
format
(
num_iter
))
assert
num_iter
==
2
assert
num_iter
==
2
def
test_tf_repeat_04
():
"""
Test a simple repeat operation with column list.
"""
logger
.
info
(
"Test Simple Repeat Column List"
)
# define parameters
repeat_count
=
2
parameters
=
{
"params"
:
{
'repeat_count'
:
repeat_count
}}
columns_list
=
[
"col_sint64"
,
"col_sint32"
]
# apply dataset operations
data1
=
ds
.
TFRecordDataset
(
DATA_DIR_TF
,
SCHEMA_DIR_TF
,
columns_list
=
columns_list
,
shuffle
=
False
)
data1
=
data1
.
repeat
(
repeat_count
)
filename
=
"repeat_list_result.npz"
save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
generator
():
def
generator
():
for
i
in
range
(
3
):
for
i
in
range
(
3
):
yield
np
.
array
([
i
]),
(
yield
np
.
array
([
i
]),)
def
test_nested_repeat1
():
def
test_nested_repeat1
():
...
@@ -151,7 +163,7 @@ def test_nested_repeat5():
...
@@ -151,7 +163,7 @@ def test_nested_repeat5():
data
=
data
.
repeat
(
2
)
data
=
data
.
repeat
(
2
)
data
=
data
.
repeat
(
3
)
data
=
data
.
repeat
(
3
)
for
i
,
d
in
enumerate
(
data
):
for
_
,
d
in
enumerate
(
data
):
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
sum
([
1
for
_
in
data
])
==
6
assert
sum
([
1
for
_
in
data
])
==
6
...
@@ -163,7 +175,7 @@ def test_nested_repeat6():
...
@@ -163,7 +175,7 @@ def test_nested_repeat6():
data
=
data
.
batch
(
3
)
data
=
data
.
batch
(
3
)
data
=
data
.
repeat
(
3
)
data
=
data
.
repeat
(
3
)
for
i
,
d
in
enumerate
(
data
):
for
_
,
d
in
enumerate
(
data
):
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
sum
([
1
for
_
in
data
])
==
6
assert
sum
([
1
for
_
in
data
])
==
6
...
@@ -175,7 +187,7 @@ def test_nested_repeat7():
...
@@ -175,7 +187,7 @@ def test_nested_repeat7():
data
=
data
.
repeat
(
3
)
data
=
data
.
repeat
(
3
)
data
=
data
.
batch
(
3
)
data
=
data
.
batch
(
3
)
for
i
,
d
in
enumerate
(
data
):
for
_
,
d
in
enumerate
(
data
):
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
sum
([
1
for
_
in
data
])
==
6
assert
sum
([
1
for
_
in
data
])
==
6
...
@@ -232,11 +244,18 @@ def test_nested_repeat11():
...
@@ -232,11 +244,18 @@ def test_nested_repeat11():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
logger
.
info
(
"--------test tf repeat 01---------"
)
test_tf_repeat_01
()
# test_repeat_01()
test_tf_repeat_02
()
logger
.
info
(
"--------test tf repeat 02---------"
)
# test_repeat_02()
logger
.
info
(
"--------test tf repeat 03---------"
)
test_tf_repeat_03
()
test_tf_repeat_03
()
test_tf_repeat_04
()
test_nested_repeat1
()
test_nested_repeat2
()
test_nested_repeat3
()
test_nested_repeat4
()
test_nested_repeat5
()
test_nested_repeat6
()
test_nested_repeat7
()
test_nested_repeat8
()
test_nested_repeat9
()
test_nested_repeat10
()
test_nested_repeat11
()
tests/ut/python/dataset/util.py
浏览文件 @
49ef53f1
...
@@ -21,12 +21,13 @@ import matplotlib.pyplot as plt
...
@@ -21,12 +21,13 @@ import matplotlib.pyplot as plt
#import jsbeautifier
#import jsbeautifier
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
# These are the column names defined in the testTFTestAllTypes dataset
COLUMNS
=
[
"col_1d"
,
"col_2d"
,
"col_3d"
,
"col_binary"
,
"col_float"
,
COLUMNS
=
[
"col_1d"
,
"col_2d"
,
"col_3d"
,
"col_binary"
,
"col_float"
,
"col_sint16"
,
"col_sint32"
,
"col_sint64"
]
"col_sint16"
,
"col_sint32"
,
"col_sint64"
]
SAVE_JSON
=
False
SAVE_JSON
=
False
def
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
):
def
_
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
):
"""
"""
Save the dictionary values as the golden result in .npz file
Save the dictionary values as the golden result in .npz file
"""
"""
...
@@ -35,7 +36,7 @@ def save_golden(cur_dir, golden_ref_dir, result_dict):
...
@@ -35,7 +36,7 @@ def save_golden(cur_dir, golden_ref_dir, result_dict):
np
.
savez
(
golden_ref_dir
,
np
.
array
(
list
(
result_dict
.
values
())))
np
.
savez
(
golden_ref_dir
,
np
.
array
(
list
(
result_dict
.
values
())))
def
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
):
def
_
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
):
"""
"""
Save the dictionary (both keys and values) as the golden result in .npz file
Save the dictionary (both keys and values) as the golden result in .npz file
"""
"""
...
@@ -44,7 +45,7 @@ def save_golden_dict(cur_dir, golden_ref_dir, result_dict):
...
@@ -44,7 +45,7 @@ def save_golden_dict(cur_dir, golden_ref_dir, result_dict):
np
.
savez
(
golden_ref_dir
,
np
.
array
(
list
(
result_dict
.
items
())))
np
.
savez
(
golden_ref_dir
,
np
.
array
(
list
(
result_dict
.
items
())))
def
compare_to_golden
(
golden_ref_dir
,
result_dict
):
def
_
compare_to_golden
(
golden_ref_dir
,
result_dict
):
"""
"""
Compare as numpy arrays the test result to the golden result
Compare as numpy arrays the test result to the golden result
"""
"""
...
@@ -53,16 +54,15 @@ def compare_to_golden(golden_ref_dir, result_dict):
...
@@ -53,16 +54,15 @@ def compare_to_golden(golden_ref_dir, result_dict):
assert
np
.
array_equal
(
test_array
,
golden_array
)
assert
np
.
array_equal
(
test_array
,
golden_array
)
def
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
):
def
_
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
):
"""
"""
Compare as dictionaries the test result to the golden result
Compare as dictionaries the test result to the golden result
"""
"""
golden_array
=
np
.
load
(
golden_ref_dir
,
allow_pickle
=
True
)[
'arr_0'
]
golden_array
=
np
.
load
(
golden_ref_dir
,
allow_pickle
=
True
)[
'arr_0'
]
np
.
testing
.
assert_equal
(
result_dict
,
dict
(
golden_array
))
np
.
testing
.
assert_equal
(
result_dict
,
dict
(
golden_array
))
# assert result_dict == dict(golden_array)
def
save_json
(
filename
,
parameters
,
result_dict
):
def
_
save_json
(
filename
,
parameters
,
result_dict
):
"""
"""
Save the result dictionary in json file
Save the result dictionary in json file
"""
"""
...
@@ -78,6 +78,7 @@ def save_and_check(data, parameters, filename, generate_golden=False):
...
@@ -78,6 +78,7 @@ def save_and_check(data, parameters, filename, generate_golden=False):
"""
"""
Save the dataset dictionary and compare (as numpy array) with golden file.
Save the dataset dictionary and compare (as numpy array) with golden file.
Use create_dict_iterator to access the dataset.
Use create_dict_iterator to access the dataset.
Note: save_and_check() is deprecated; use save_and_check_dict().
"""
"""
num_iter
=
0
num_iter
=
0
result_dict
=
{}
result_dict
=
{}
...
@@ -97,13 +98,13 @@ def save_and_check(data, parameters, filename, generate_golden=False):
...
@@ -97,13 +98,13 @@ def save_and_check(data, parameters, filename, generate_golden=False):
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
if
generate_golden
:
if
generate_golden
:
# Save as the golden result
# Save as the golden result
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
)
_
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
)
compare_to_golden
(
golden_ref_dir
,
result_dict
)
_
compare_to_golden
(
golden_ref_dir
,
result_dict
)
if
SAVE_JSON
:
if
SAVE_JSON
:
# Save result to a json file for inspection
# Save result to a json file for inspection
save_json
(
filename
,
parameters
,
result_dict
)
_
save_json
(
filename
,
parameters
,
result_dict
)
def
save_and_check_dict
(
data
,
filename
,
generate_golden
=
False
):
def
save_and_check_dict
(
data
,
filename
,
generate_golden
=
False
):
...
@@ -127,14 +128,14 @@ def save_and_check_dict(data, filename, generate_golden=False):
...
@@ -127,14 +128,14 @@ def save_and_check_dict(data, filename, generate_golden=False):
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
if
generate_golden
:
if
generate_golden
:
# Save as the golden result
# Save as the golden result
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
)
_
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
)
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
)
_
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
)
if
SAVE_JSON
:
if
SAVE_JSON
:
# Save result to a json file for inspection
# Save result to a json file for inspection
parameters
=
{
"params"
:
{}}
parameters
=
{
"params"
:
{}}
save_json
(
filename
,
parameters
,
result_dict
)
_
save_json
(
filename
,
parameters
,
result_dict
)
def
save_and_check_md5
(
data
,
filename
,
generate_golden
=
False
):
def
save_and_check_md5
(
data
,
filename
,
generate_golden
=
False
):
...
@@ -159,22 +160,21 @@ def save_and_check_md5(data, filename, generate_golden=False):
...
@@ -159,22 +160,21 @@ def save_and_check_md5(data, filename, generate_golden=False):
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
if
generate_golden
:
if
generate_golden
:
# Save as the golden result
# Save as the golden result
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
)
_
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
)
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
)
_
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
)
def
ordered_save_and_check
(
data
,
parameters
,
filename
,
generate_golden
=
False
):
def
save_and_check_tuple
(
data
,
parameters
,
filename
,
generate_golden
=
False
):
"""
"""
Save the dataset dictionary and compare (as numpy array) with golden file.
Save the dataset dictionary and compare (as numpy array) with golden file.
Use create_tuple_iterator to access the dataset.
Use create_tuple_iterator to access the dataset.
"""
"""
num_iter
=
0
num_iter
=
0
result_dict
=
{}
result_dict
=
{}
for
item
in
data
.
create_tuple_iterator
():
# each data is a dictionary
for
item
in
data
.
create_tuple_iterator
():
# each data is a dictionary
for
data_key
in
range
(
0
,
len
(
item
)
):
for
data_key
,
_
in
enumerate
(
item
):
if
data_key
not
in
result_dict
:
if
data_key
not
in
result_dict
:
result_dict
[
data_key
]
=
[]
result_dict
[
data_key
]
=
[]
result_dict
[
data_key
].
append
(
item
[
data_key
].
tolist
())
result_dict
[
data_key
].
append
(
item
[
data_key
].
tolist
())
...
@@ -186,13 +186,13 @@ def ordered_save_and_check(data, parameters, filename, generate_golden=False):
...
@@ -186,13 +186,13 @@ def ordered_save_and_check(data, parameters, filename, generate_golden=False):
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
if
generate_golden
:
if
generate_golden
:
# Save as the golden result
# Save as the golden result
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
)
_
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
)
compare_to_golden
(
golden_ref_dir
,
result_dict
)
_
compare_to_golden
(
golden_ref_dir
,
result_dict
)
if
SAVE_JSON
:
if
SAVE_JSON
:
# Save result to a json file for inspection
# Save result to a json file for inspection
save_json
(
filename
,
parameters
,
result_dict
)
_
save_json
(
filename
,
parameters
,
result_dict
)
def
diff_mse
(
in1
,
in2
):
def
diff_mse
(
in1
,
in2
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录