Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
49ef53f1
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
49ef53f1
编写于
5月 11, 2020
作者:
C
Cathy Wong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cleanup dataset UT: util.py internals
上级
2af6ee24
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
96 addition
and
120 deletion
+96
-120
tests/ut/data/dataset/golden/HWC2CHW_01_result.npz
tests/ut/data/dataset/golden/HWC2CHW_01_result.npz
+0
-0
tests/ut/data/dataset/golden/center_crop_01_result.npz
tests/ut/data/dataset/golden/center_crop_01_result.npz
+0
-0
tests/ut/data/dataset/golden/repeat_list_result.npz
tests/ut/data/dataset/golden/repeat_list_result.npz
+0
-0
tests/ut/python/dataset/test_HWC2CHW.py
tests/ut/python/dataset/test_HWC2CHW.py
+4
-4
tests/ut/python/dataset/test_center_crop.py
tests/ut/python/dataset/test_center_crop.py
+18
-19
tests/ut/python/dataset/test_general.py
tests/ut/python/dataset/test_general.py
+0
-41
tests/ut/python/dataset/test_project.py
tests/ut/python/dataset/test_project.py
+11
-12
tests/ut/python/dataset/test_repeat.py
tests/ut/python/dataset/test_repeat.py
+43
-24
tests/ut/python/dataset/util.py
tests/ut/python/dataset/util.py
+20
-20
未找到文件。
tests/ut/data/dataset/golden/
test_
HWC2CHW_01_result.npz
→
tests/ut/data/dataset/golden/HWC2CHW_01_result.npz
浏览文件 @
49ef53f1
文件已移动
tests/ut/data/dataset/golden/
test_
center_crop_01_result.npz
→
tests/ut/data/dataset/golden/center_crop_01_result.npz
浏览文件 @
49ef53f1
文件已移动
tests/ut/data/dataset/golden/
columns
_list_result.npz
→
tests/ut/data/dataset/golden/
repeat
_list_result.npz
浏览文件 @
49ef53f1
文件已移动
tests/ut/python/dataset/test_HWC2CHW.py
浏览文件 @
49ef53f1
...
...
@@ -69,8 +69,8 @@ def test_HWC2CHW_md5():
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
decode_op
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
hwc2chw_op
)
# expected md5 from images
filename
=
"
test_
HWC2CHW_01_result.npz"
#
Compare with
expected md5 from images
filename
=
"HWC2CHW_01_result.npz"
save_and_check_md5
(
data1
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
...
...
@@ -103,9 +103,9 @@ def test_HWC2CHW_comp(plot=False):
c_image
=
item1
[
"image"
]
py_image
=
(
item2
[
"image"
].
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
)
#
c
ompare images between that applying c_transform and py_transform
#
C
ompare images between that applying c_transform and py_transform
mse
=
diff_mse
(
py_image
,
c_image
)
#
t
he images aren't exactly the same due to rounding error
#
Note: T
he images aren't exactly the same due to rounding error
assert
mse
<
0.001
image_c_transposed
.
append
(
item1
[
"image"
].
copy
())
...
...
tests/ut/python/dataset/test_center_crop.py
浏览文件 @
49ef53f1
...
...
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
numpy
as
np
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
import
mindspore.dataset.transforms.vision.py_transforms
as
py_vision
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
from
util
import
diff_mse
,
visualize
,
save_and_check_md5
...
...
@@ -60,15 +59,14 @@ def test_center_crop_md5(height=375, width=375):
logger
.
info
(
"Test CenterCrop"
)
# First dataset
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
[
"image"
],
shuffle
=
False
)
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
[
"image"
],
shuffle
=
False
)
decode_op
=
vision
.
Decode
()
# 3 images [375, 500] [600, 500] [512, 512]
center_crop_op
=
vision
.
CenterCrop
([
height
,
width
])
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
decode_op
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
center_crop_op
)
# expected md5 from images
filename
=
"test_center_crop_01_result.npz"
# Compare with expected md5 from images
filename
=
"center_crop_01_result.npz"
save_and_check_md5
(
data1
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
...
...
@@ -100,8 +98,8 @@ def test_center_crop_comp(height=375, width=375, plot=False):
for
item1
,
item2
in
zip
(
data1
.
create_dict_iterator
(),
data2
.
create_dict_iterator
()):
c_image
=
item1
[
"image"
]
py_image
=
(
item2
[
"image"
].
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
)
#
the images aren't exactly the same due to rouding error
assert
(
diff_mse
(
py_image
,
c_image
)
<
0.001
)
#
Note: The images aren't exactly the same due to rounding error
assert
diff_mse
(
py_image
,
c_image
)
<
0.001
image_cropped
.
append
(
item1
[
"image"
].
copy
())
image
.
append
(
item2
[
"image"
].
copy
())
if
plot
:
...
...
@@ -112,6 +110,7 @@ def test_crop_grayscale(height=375, width=375):
"""
Test that centercrop works with pad and grayscale images
"""
def
channel_swap
(
image
):
"""
Py func hack for our pytransforms to work with c transforms
...
...
@@ -129,15 +128,15 @@ def test_crop_grayscale(height=375, width=375):
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
[
"image"
],
shuffle
=
False
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
transform
())
#
if input is grayscale, the output dimensions should be single channel
#
If input is grayscale, the output dimensions should be single channel
crop_gray
=
vision
.
CenterCrop
([
height
,
width
])
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
crop_gray
)
for
item1
in
data1
.
create_dict_iterator
():
c_image
=
item1
[
"image"
]
#
c
heck that the image is grayscale
assert
(
len
(
c_image
.
shape
)
==
3
and
c_image
.
shape
[
2
]
==
1
)
#
C
heck that the image is grayscale
assert
(
c_image
.
ndim
==
3
and
c_image
.
shape
[
2
]
==
1
)
if
__name__
==
"__main__"
:
...
...
tests/ut/python/dataset/test_general.py
已删除
100644 → 0
浏览文件 @
2af6ee24
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
util
import
save_and_check
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
DATA_DIR
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
SCHEMA_DIR
=
"../data/dataset/testTFTestAllTypes/datasetSchema.json"
COLUMNS
=
[
"col_1d"
,
"col_2d"
,
"col_3d"
,
"col_binary"
,
"col_float"
,
"col_sint16"
,
"col_sint32"
,
"col_sint64"
]
GENERATE_GOLDEN
=
False
def
test_case_columns_list
():
"""
a simple repeat operation.
"""
logger
.
info
(
"Test Simple Repeat"
)
# define parameters
repeat_count
=
2
parameters
=
{
"params"
:
{
'repeat_count'
:
repeat_count
}}
columns_list
=
[
"col_sint64"
,
"col_sint32"
]
# apply dataset operations
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
columns_list
,
shuffle
=
False
)
data1
=
data1
.
repeat
(
repeat_count
)
filename
=
"columns_list_result.npz"
save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
tests/ut/python/dataset/test_project.py
浏览文件 @
49ef53f1
...
...
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
mindspore.dataset
.transforms.vision.c_transforms
as
vision
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.c_transforms
as
C
from
mindspore.common
import
dtype
as
mstype
from
util
import
ordered_save_and_check
from
util
import
save_and_check_tuple
import
mindspore.dataset
as
ds
DATA_DIR_TF
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
SCHEMA_DIR_TF
=
"../data/dataset/testTFTestAllTypes/datasetSchema.json"
...
...
@@ -32,7 +31,7 @@ def test_case_project_single_column():
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_single_column_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_project_multiple_columns_in_order
():
...
...
@@ -43,7 +42,7 @@ def test_case_project_multiple_columns_in_order():
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_multiple_columns_in_order_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_project_multiple_columns_out_of_order
():
...
...
@@ -54,7 +53,7 @@ def test_case_project_multiple_columns_out_of_order():
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_multiple_columns_out_of_order_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_project_map
():
...
...
@@ -68,7 +67,7 @@ def test_case_project_map():
data1
=
data1
.
map
(
input_columns
=
[
"col_3d"
],
operations
=
type_cast_op
)
filename
=
"project_map_after_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_map_project
():
...
...
@@ -83,7 +82,7 @@ def test_case_map_project():
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_map_before_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_project_between_maps
():
...
...
@@ -107,7 +106,7 @@ def test_case_project_between_maps():
data1
=
data1
.
map
(
input_columns
=
[
"col_3d"
],
operations
=
type_cast_op
)
filename
=
"project_between_maps_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_project_repeat
():
...
...
@@ -121,7 +120,7 @@ def test_case_project_repeat():
data1
=
data1
.
repeat
(
repeat_count
)
filename
=
"project_before_repeat_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_repeat_project
():
...
...
@@ -136,7 +135,7 @@ def test_case_repeat_project():
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_after_repeat_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_map_project_map_project
():
...
...
@@ -155,4 +154,4 @@ def test_case_map_project_map_project():
data1
=
data1
.
project
(
columns
=
columns
)
filename
=
"project_alternate_parallel_inline_result.npz"
ordered_save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check_tuple
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
tests/ut/python/dataset/test_repeat.py
浏览文件 @
49ef53f1
...
...
@@ -13,11 +13,10 @@
# limitations under the License.
# ==============================================================================
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
from
util
import
save_and_check
import
mindspore.dataset
as
ds
import
numpy
as
np
from
mindspore
import
log
as
logger
from
util
import
save_and_check
DATA_DIR_TF
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
SCHEMA_DIR_TF
=
"../data/dataset/testTFTestAllTypes/datasetSchema.json"
...
...
@@ -25,13 +24,6 @@ COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16"
,
"col_sint32"
,
"col_sint64"
]
GENERATE_GOLDEN
=
False
# Data for CIFAR and MNIST are not part of build tree
# They need to be downloaded directly
# prep_data.py can be exuted or code below
# import sys
# sys.path.insert(0,"../../data")
# import prep_data
# prep_data.download_all_for_test("../../data")
IMG_DATA_DIR
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
IMG_SCHEMA_DIR
=
"../data/dataset/test_tf_file_3_images/datasetSchema.json"
...
...
@@ -41,7 +33,7 @@ SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def
test_tf_repeat_01
():
"""
a simple repeat operation.
Test
a simple repeat operation.
"""
logger
.
info
(
"Test Simple Repeat"
)
# define parameters
...
...
@@ -58,7 +50,7 @@ def test_tf_repeat_01():
def
test_tf_repeat_02
():
"""
a simple repeat operation to tes infinite
Test Infinite Repeat.
"""
logger
.
info
(
"Test Infinite Repeat"
)
# define parameters
...
...
@@ -77,7 +69,10 @@ def test_tf_repeat_02():
def
test_tf_repeat_03
():
'''repeat and batch '''
"""
Test Repeat then Batch.
"""
logger
.
info
(
"Test Repeat then Batch"
)
data1
=
ds
.
TFRecordDataset
(
DATA_DIR_TF2
,
SCHEMA_DIR_TF2
,
shuffle
=
False
)
batch_size
=
32
...
...
@@ -90,15 +85,32 @@ def test_tf_repeat_03():
data1
=
data1
.
batch
(
batch_size
,
drop_remainder
=
True
)
num_iter
=
0
for
item
in
data1
.
create_dict_iterator
():
for
_
in
data1
.
create_dict_iterator
():
num_iter
+=
1
logger
.
info
(
"Number of tf data in data1: {}"
.
format
(
num_iter
))
assert
num_iter
==
2
def
test_tf_repeat_04
():
"""
Test a simple repeat operation with column list.
"""
logger
.
info
(
"Test Simple Repeat Column List"
)
# define parameters
repeat_count
=
2
parameters
=
{
"params"
:
{
'repeat_count'
:
repeat_count
}}
columns_list
=
[
"col_sint64"
,
"col_sint32"
]
# apply dataset operations
data1
=
ds
.
TFRecordDataset
(
DATA_DIR_TF
,
SCHEMA_DIR_TF
,
columns_list
=
columns_list
,
shuffle
=
False
)
data1
=
data1
.
repeat
(
repeat_count
)
filename
=
"repeat_list_result.npz"
save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
generator
():
for
i
in
range
(
3
):
yield
np
.
array
([
i
]),
(
yield
np
.
array
([
i
]),)
def
test_nested_repeat1
():
...
...
@@ -151,7 +163,7 @@ def test_nested_repeat5():
data
=
data
.
repeat
(
2
)
data
=
data
.
repeat
(
3
)
for
i
,
d
in
enumerate
(
data
):
for
_
,
d
in
enumerate
(
data
):
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
sum
([
1
for
_
in
data
])
==
6
...
...
@@ -163,7 +175,7 @@ def test_nested_repeat6():
data
=
data
.
batch
(
3
)
data
=
data
.
repeat
(
3
)
for
i
,
d
in
enumerate
(
data
):
for
_
,
d
in
enumerate
(
data
):
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
sum
([
1
for
_
in
data
])
==
6
...
...
@@ -175,7 +187,7 @@ def test_nested_repeat7():
data
=
data
.
repeat
(
3
)
data
=
data
.
batch
(
3
)
for
i
,
d
in
enumerate
(
data
):
for
_
,
d
in
enumerate
(
data
):
assert
np
.
array_equal
(
d
[
0
],
np
.
asarray
([[
0
],
[
1
],
[
2
]]))
assert
sum
([
1
for
_
in
data
])
==
6
...
...
@@ -232,11 +244,18 @@ def test_nested_repeat11():
if
__name__
==
"__main__"
:
logger
.
info
(
"--------test tf repeat 01---------"
)
# test_repeat_01()
logger
.
info
(
"--------test tf repeat 02---------"
)
# test_repeat_02()
logger
.
info
(
"--------test tf repeat 03---------"
)
test_tf_repeat_01
()
test_tf_repeat_02
()
test_tf_repeat_03
()
test_tf_repeat_04
()
test_nested_repeat1
()
test_nested_repeat2
()
test_nested_repeat3
()
test_nested_repeat4
()
test_nested_repeat5
()
test_nested_repeat6
()
test_nested_repeat7
()
test_nested_repeat8
()
test_nested_repeat9
()
test_nested_repeat10
()
test_nested_repeat11
()
tests/ut/python/dataset/util.py
浏览文件 @
49ef53f1
...
...
@@ -21,12 +21,13 @@ import matplotlib.pyplot as plt
#import jsbeautifier
from
mindspore
import
log
as
logger
# These are the column names defined in the testTFTestAllTypes dataset
COLUMNS
=
[
"col_1d"
,
"col_2d"
,
"col_3d"
,
"col_binary"
,
"col_float"
,
"col_sint16"
,
"col_sint32"
,
"col_sint64"
]
SAVE_JSON
=
False
def
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
):
def
_
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
):
"""
Save the dictionary values as the golden result in .npz file
"""
...
...
@@ -35,7 +36,7 @@ def save_golden(cur_dir, golden_ref_dir, result_dict):
np
.
savez
(
golden_ref_dir
,
np
.
array
(
list
(
result_dict
.
values
())))
def
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
):
def
_
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
):
"""
Save the dictionary (both keys and values) as the golden result in .npz file
"""
...
...
@@ -44,7 +45,7 @@ def save_golden_dict(cur_dir, golden_ref_dir, result_dict):
np
.
savez
(
golden_ref_dir
,
np
.
array
(
list
(
result_dict
.
items
())))
def
compare_to_golden
(
golden_ref_dir
,
result_dict
):
def
_
compare_to_golden
(
golden_ref_dir
,
result_dict
):
"""
Compare as numpy arrays the test result to the golden result
"""
...
...
@@ -53,16 +54,15 @@ def compare_to_golden(golden_ref_dir, result_dict):
assert
np
.
array_equal
(
test_array
,
golden_array
)
def
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
):
def
_
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
):
"""
Compare as dictionaries the test result to the golden result
"""
golden_array
=
np
.
load
(
golden_ref_dir
,
allow_pickle
=
True
)[
'arr_0'
]
np
.
testing
.
assert_equal
(
result_dict
,
dict
(
golden_array
))
# assert result_dict == dict(golden_array)
def
save_json
(
filename
,
parameters
,
result_dict
):
def
_
save_json
(
filename
,
parameters
,
result_dict
):
"""
Save the result dictionary in json file
"""
...
...
@@ -78,6 +78,7 @@ def save_and_check(data, parameters, filename, generate_golden=False):
"""
Save the dataset dictionary and compare (as numpy array) with golden file.
Use create_dict_iterator to access the dataset.
Note: save_and_check() is deprecated; use save_and_check_dict().
"""
num_iter
=
0
result_dict
=
{}
...
...
@@ -97,13 +98,13 @@ def save_and_check(data, parameters, filename, generate_golden=False):
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
if
generate_golden
:
# Save as the golden result
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
)
_
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
)
compare_to_golden
(
golden_ref_dir
,
result_dict
)
_
compare_to_golden
(
golden_ref_dir
,
result_dict
)
if
SAVE_JSON
:
# Save result to a json file for inspection
save_json
(
filename
,
parameters
,
result_dict
)
_
save_json
(
filename
,
parameters
,
result_dict
)
def
save_and_check_dict
(
data
,
filename
,
generate_golden
=
False
):
...
...
@@ -127,14 +128,14 @@ def save_and_check_dict(data, filename, generate_golden=False):
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
if
generate_golden
:
# Save as the golden result
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
)
_
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
)
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
)
_
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
)
if
SAVE_JSON
:
# Save result to a json file for inspection
parameters
=
{
"params"
:
{}}
save_json
(
filename
,
parameters
,
result_dict
)
_
save_json
(
filename
,
parameters
,
result_dict
)
def
save_and_check_md5
(
data
,
filename
,
generate_golden
=
False
):
...
...
@@ -159,22 +160,21 @@ def save_and_check_md5(data, filename, generate_golden=False):
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
if
generate_golden
:
# Save as the golden result
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
)
_
save_golden_dict
(
cur_dir
,
golden_ref_dir
,
result_dict
)
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
)
_
compare_to_golden_dict
(
golden_ref_dir
,
result_dict
)
def
ordered_save_and_check
(
data
,
parameters
,
filename
,
generate_golden
=
False
):
def
save_and_check_tuple
(
data
,
parameters
,
filename
,
generate_golden
=
False
):
"""
Save the dataset dictionary and compare (as numpy array) with golden file.
Use create_tuple_iterator to access the dataset.
"""
num_iter
=
0
result_dict
=
{}
for
item
in
data
.
create_tuple_iterator
():
# each data is a dictionary
for
data_key
in
range
(
0
,
len
(
item
)
):
for
data_key
,
_
in
enumerate
(
item
):
if
data_key
not
in
result_dict
:
result_dict
[
data_key
]
=
[]
result_dict
[
data_key
].
append
(
item
[
data_key
].
tolist
())
...
...
@@ -186,13 +186,13 @@ def ordered_save_and_check(data, parameters, filename, generate_golden=False):
golden_ref_dir
=
os
.
path
.
join
(
cur_dir
,
"../../data/dataset"
,
'golden'
,
filename
)
if
generate_golden
:
# Save as the golden result
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
)
_
save_golden
(
cur_dir
,
golden_ref_dir
,
result_dict
)
compare_to_golden
(
golden_ref_dir
,
result_dict
)
_
compare_to_golden
(
golden_ref_dir
,
result_dict
)
if
SAVE_JSON
:
# Save result to a json file for inspection
save_json
(
filename
,
parameters
,
result_dict
)
_
save_json
(
filename
,
parameters
,
result_dict
)
def
diff_mse
(
in1
,
in2
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录