提交 49ef53f1 编写于 作者: C Cathy Wong

Cleanup dataset UT: util.py internals

上级 2af6ee24
...@@ -69,8 +69,8 @@ def test_HWC2CHW_md5(): ...@@ -69,8 +69,8 @@ def test_HWC2CHW_md5():
data1 = data1.map(input_columns=["image"], operations=decode_op) data1 = data1.map(input_columns=["image"], operations=decode_op)
data1 = data1.map(input_columns=["image"], operations=hwc2chw_op) data1 = data1.map(input_columns=["image"], operations=hwc2chw_op)
# expected md5 from images # Compare with expected md5 from images
filename = "test_HWC2CHW_01_result.npz" filename = "HWC2CHW_01_result.npz"
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
...@@ -103,9 +103,9 @@ def test_HWC2CHW_comp(plot=False): ...@@ -103,9 +103,9 @@ def test_HWC2CHW_comp(plot=False):
c_image = item1["image"] c_image = item1["image"]
py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
# compare images between that applying c_transform and py_transform # Compare images between that applying c_transform and py_transform
mse = diff_mse(py_image, c_image) mse = diff_mse(py_image, c_image)
# the images aren't exactly the same due to rounding error # Note: The images aren't exactly the same due to rounding error
assert mse < 0.001 assert mse < 0.001
image_c_transposed.append(item1["image"].copy()) image_c_transposed.append(item1["image"].copy())
......
...@@ -12,10 +12,9 @@ ...@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import numpy as np
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset.transforms.vision.py_transforms as py_vision import mindspore.dataset.transforms.vision.py_transforms as py_vision
import numpy as np
import matplotlib.pyplot as plt
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger
from util import diff_mse, visualize, save_and_check_md5 from util import diff_mse, visualize, save_and_check_md5
...@@ -60,15 +59,14 @@ def test_center_crop_md5(height=375, width=375): ...@@ -60,15 +59,14 @@ def test_center_crop_md5(height=375, width=375):
logger.info("Test CenterCrop") logger.info("Test CenterCrop")
# First dataset # First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle =False) data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = vision.Decode() decode_op = vision.Decode()
# 3 images [375, 500] [600, 500] [512, 512] # 3 images [375, 500] [600, 500] [512, 512]
center_crop_op = vision.CenterCrop([height, width]) center_crop_op = vision.CenterCrop([height, width])
data1 = data1.map(input_columns=["image"], operations=decode_op) data1 = data1.map(input_columns=["image"], operations=decode_op)
data1 = data1.map(input_columns=["image"], operations=center_crop_op) data1 = data1.map(input_columns=["image"], operations=center_crop_op)
# expected md5 from images # Compare with expected md5 from images
filename = "center_crop_01_result.npz"
filename = "test_center_crop_01_result.npz"
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
...@@ -89,7 +87,7 @@ def test_center_crop_comp(height=375, width=375, plot=False): ...@@ -89,7 +87,7 @@ def test_center_crop_comp(height=375, width=375, plot=False):
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
transforms = [ transforms = [
py_vision.Decode(), py_vision.Decode(),
py_vision.CenterCrop([height, width]), py_vision.CenterCrop([height, width]),
py_vision.ToTensor() py_vision.ToTensor()
] ]
transform = py_vision.ComposeOp(transforms) transform = py_vision.ComposeOp(transforms)
...@@ -100,27 +98,28 @@ def test_center_crop_comp(height=375, width=375, plot=False): ...@@ -100,27 +98,28 @@ def test_center_crop_comp(height=375, width=375, plot=False):
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
c_image = item1["image"] c_image = item1["image"]
py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8) py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
# the images aren't exactly the same due to rouding error # Note: The images aren't exactly the same due to rounding error
assert (diff_mse(py_image, c_image) < 0.001) assert diff_mse(py_image, c_image) < 0.001
image_cropped.append(item1["image"].copy()) image_cropped.append(item1["image"].copy())
image.append(item2["image"].copy()) image.append(item2["image"].copy())
if plot: if plot:
visualize(image, image_cropped) visualize(image, image_cropped)
def test_crop_grayscale(height=375, width=375): def test_crop_grayscale(height=375, width=375):
""" """
Test that centercrop works with pad and grayscale images Test that centercrop works with pad and grayscale images
""" """
def channel_swap(image):
def channel_swap(image):
""" """
Py func hack for our pytransforms to work with c transforms Py func hack for our pytransforms to work with c transforms
""" """
return (image.transpose(1, 2, 0) * 255).astype(np.uint8) return (image.transpose(1, 2, 0) * 255).astype(np.uint8)
transforms = [ transforms = [
py_vision.Decode(), py_vision.Decode(),
py_vision.Grayscale(1), py_vision.Grayscale(1),
py_vision.ToTensor(), py_vision.ToTensor(),
(lambda image: channel_swap(image)) (lambda image: channel_swap(image))
] ]
...@@ -129,16 +128,16 @@ def test_crop_grayscale(height=375, width=375): ...@@ -129,16 +128,16 @@ def test_crop_grayscale(height=375, width=375):
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data1 = data1.map(input_columns=["image"], operations=transform()) data1 = data1.map(input_columns=["image"], operations=transform())
# if input is grayscale, the output dimensions should be single channel # If input is grayscale, the output dimensions should be single channel
crop_gray = vision.CenterCrop([height, width]) crop_gray = vision.CenterCrop([height, width])
data1 = data1.map(input_columns=["image"], operations=crop_gray) data1 = data1.map(input_columns=["image"], operations=crop_gray)
for item1 in data1.create_dict_iterator(): for item1 in data1.create_dict_iterator():
c_image = item1["image"] c_image = item1["image"]
# check that the image is grayscale # Check that the image is grayscale
assert (len(c_image.shape) == 3 and c_image.shape[2] == 1) assert (c_image.ndim == 3 and c_image.shape[2] == 1)
if __name__ == "__main__": if __name__ == "__main__":
test_center_crop_op(600, 600) test_center_crop_op(600, 600)
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from util import save_and_check
import mindspore.dataset as ds
from mindspore import log as logger
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"]
GENERATE_GOLDEN = False
def test_case_columns_list():
"""
a simple repeat operation.
"""
logger.info("Test Simple Repeat")
# define parameters
repeat_count = 2
parameters = {"params": {'repeat_count': repeat_count}}
columns_list = ["col_sint64", "col_sint32"]
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=columns_list, shuffle=False)
data1 = data1.repeat(repeat_count)
filename = "columns_list_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -12,12 +12,11 @@ ...@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from util import ordered_save_and_check from util import save_and_check_tuple
import mindspore.dataset as ds
DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"] DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json" SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
...@@ -32,7 +31,7 @@ def test_case_project_single_column(): ...@@ -32,7 +31,7 @@ def test_case_project_single_column():
data1 = data1.project(columns=columns) data1 = data1.project(columns=columns)
filename = "project_single_column_result.npz" filename = "project_single_column_result.npz"
ordered_save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_case_project_multiple_columns_in_order(): def test_case_project_multiple_columns_in_order():
...@@ -43,7 +42,7 @@ def test_case_project_multiple_columns_in_order(): ...@@ -43,7 +42,7 @@ def test_case_project_multiple_columns_in_order():
data1 = data1.project(columns=columns) data1 = data1.project(columns=columns)
filename = "project_multiple_columns_in_order_result.npz" filename = "project_multiple_columns_in_order_result.npz"
ordered_save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_case_project_multiple_columns_out_of_order(): def test_case_project_multiple_columns_out_of_order():
...@@ -54,7 +53,7 @@ def test_case_project_multiple_columns_out_of_order(): ...@@ -54,7 +53,7 @@ def test_case_project_multiple_columns_out_of_order():
data1 = data1.project(columns=columns) data1 = data1.project(columns=columns)
filename = "project_multiple_columns_out_of_order_result.npz" filename = "project_multiple_columns_out_of_order_result.npz"
ordered_save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_case_project_map(): def test_case_project_map():
...@@ -68,7 +67,7 @@ def test_case_project_map(): ...@@ -68,7 +67,7 @@ def test_case_project_map():
data1 = data1.map(input_columns=["col_3d"], operations=type_cast_op) data1 = data1.map(input_columns=["col_3d"], operations=type_cast_op)
filename = "project_map_after_result.npz" filename = "project_map_after_result.npz"
ordered_save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_case_map_project(): def test_case_map_project():
...@@ -83,7 +82,7 @@ def test_case_map_project(): ...@@ -83,7 +82,7 @@ def test_case_map_project():
data1 = data1.project(columns=columns) data1 = data1.project(columns=columns)
filename = "project_map_before_result.npz" filename = "project_map_before_result.npz"
ordered_save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_case_project_between_maps(): def test_case_project_between_maps():
...@@ -107,7 +106,7 @@ def test_case_project_between_maps(): ...@@ -107,7 +106,7 @@ def test_case_project_between_maps():
data1 = data1.map(input_columns=["col_3d"], operations=type_cast_op) data1 = data1.map(input_columns=["col_3d"], operations=type_cast_op)
filename = "project_between_maps_result.npz" filename = "project_between_maps_result.npz"
ordered_save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_case_project_repeat(): def test_case_project_repeat():
...@@ -121,7 +120,7 @@ def test_case_project_repeat(): ...@@ -121,7 +120,7 @@ def test_case_project_repeat():
data1 = data1.repeat(repeat_count) data1 = data1.repeat(repeat_count)
filename = "project_before_repeat_result.npz" filename = "project_before_repeat_result.npz"
ordered_save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_case_repeat_project(): def test_case_repeat_project():
...@@ -136,7 +135,7 @@ def test_case_repeat_project(): ...@@ -136,7 +135,7 @@ def test_case_repeat_project():
data1 = data1.project(columns=columns) data1 = data1.project(columns=columns)
filename = "project_after_repeat_result.npz" filename = "project_after_repeat_result.npz"
ordered_save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_case_map_project_map_project(): def test_case_map_project_map_project():
...@@ -155,4 +154,4 @@ def test_case_map_project_map_project(): ...@@ -155,4 +154,4 @@ def test_case_map_project_map_project():
data1 = data1.project(columns=columns) data1 = data1.project(columns=columns)
filename = "project_alternate_parallel_inline_result.npz" filename = "project_alternate_parallel_inline_result.npz"
ordered_save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -13,11 +13,10 @@ ...@@ -13,11 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
from util import save_and_check
import mindspore.dataset as ds import mindspore.dataset as ds
import numpy as np import numpy as np
from mindspore import log as logger from mindspore import log as logger
from util import save_and_check
DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"] DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json" SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
...@@ -25,13 +24,6 @@ COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", ...@@ -25,13 +24,6 @@ COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"] "col_sint16", "col_sint32", "col_sint64"]
GENERATE_GOLDEN = False GENERATE_GOLDEN = False
# Data for CIFAR and MNIST are not part of build tree
# They need to be downloaded directly
# prep_data.py can be exuted or code below
# import sys
# sys.path.insert(0,"../../data")
# import prep_data
# prep_data.download_all_for_test("../../data")
IMG_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] IMG_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
IMG_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" IMG_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
...@@ -41,7 +33,7 @@ SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" ...@@ -41,7 +33,7 @@ SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_tf_repeat_01(): def test_tf_repeat_01():
""" """
a simple repeat operation. Test a simple repeat operation.
""" """
logger.info("Test Simple Repeat") logger.info("Test Simple Repeat")
# define parameters # define parameters
...@@ -58,7 +50,7 @@ def test_tf_repeat_01(): ...@@ -58,7 +50,7 @@ def test_tf_repeat_01():
def test_tf_repeat_02(): def test_tf_repeat_02():
""" """
a simple repeat operation to tes infinite Test Infinite Repeat.
""" """
logger.info("Test Infinite Repeat") logger.info("Test Infinite Repeat")
# define parameters # define parameters
...@@ -77,7 +69,10 @@ def test_tf_repeat_02(): ...@@ -77,7 +69,10 @@ def test_tf_repeat_02():
def test_tf_repeat_03(): def test_tf_repeat_03():
'''repeat and batch ''' """
Test Repeat then Batch.
"""
logger.info("Test Repeat then Batch")
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False) data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
batch_size = 32 batch_size = 32
...@@ -90,15 +85,32 @@ def test_tf_repeat_03(): ...@@ -90,15 +85,32 @@ def test_tf_repeat_03():
data1 = data1.batch(batch_size, drop_remainder=True) data1 = data1.batch(batch_size, drop_remainder=True)
num_iter = 0 num_iter = 0
for item in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num_iter += 1 num_iter += 1
logger.info("Number of tf data in data1: {}".format(num_iter)) logger.info("Number of tf data in data1: {}".format(num_iter))
assert num_iter == 2 assert num_iter == 2
def test_tf_repeat_04():
"""
Test a simple repeat operation with column list.
"""
logger.info("Test Simple Repeat Column List")
# define parameters
repeat_count = 2
parameters = {"params": {'repeat_count': repeat_count}}
columns_list = ["col_sint64", "col_sint32"]
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, columns_list=columns_list, shuffle=False)
data1 = data1.repeat(repeat_count)
filename = "repeat_list_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def generator(): def generator():
for i in range(3): for i in range(3):
yield np.array([i]), (yield np.array([i]),)
def test_nested_repeat1(): def test_nested_repeat1():
...@@ -151,7 +163,7 @@ def test_nested_repeat5(): ...@@ -151,7 +163,7 @@ def test_nested_repeat5():
data = data.repeat(2) data = data.repeat(2)
data = data.repeat(3) data = data.repeat(3)
for i, d in enumerate(data): for _, d in enumerate(data):
assert np.array_equal(d[0], np.asarray([[0], [1], [2]])) assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
assert sum([1 for _ in data]) == 6 assert sum([1 for _ in data]) == 6
...@@ -163,7 +175,7 @@ def test_nested_repeat6(): ...@@ -163,7 +175,7 @@ def test_nested_repeat6():
data = data.batch(3) data = data.batch(3)
data = data.repeat(3) data = data.repeat(3)
for i, d in enumerate(data): for _, d in enumerate(data):
assert np.array_equal(d[0], np.asarray([[0], [1], [2]])) assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
assert sum([1 for _ in data]) == 6 assert sum([1 for _ in data]) == 6
...@@ -175,7 +187,7 @@ def test_nested_repeat7(): ...@@ -175,7 +187,7 @@ def test_nested_repeat7():
data = data.repeat(3) data = data.repeat(3)
data = data.batch(3) data = data.batch(3)
for i, d in enumerate(data): for _, d in enumerate(data):
assert np.array_equal(d[0], np.asarray([[0], [1], [2]])) assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
assert sum([1 for _ in data]) == 6 assert sum([1 for _ in data]) == 6
...@@ -232,11 +244,18 @@ def test_nested_repeat11(): ...@@ -232,11 +244,18 @@ def test_nested_repeat11():
if __name__ == "__main__": if __name__ == "__main__":
logger.info("--------test tf repeat 01---------") test_tf_repeat_01()
# test_repeat_01() test_tf_repeat_02()
logger.info("--------test tf repeat 02---------")
# test_repeat_02()
logger.info("--------test tf repeat 03---------")
test_tf_repeat_03() test_tf_repeat_03()
test_tf_repeat_04()
test_nested_repeat1()
test_nested_repeat2()
test_nested_repeat3()
test_nested_repeat4()
test_nested_repeat5()
test_nested_repeat6()
test_nested_repeat7()
test_nested_repeat8()
test_nested_repeat9()
test_nested_repeat10()
test_nested_repeat11()
...@@ -21,12 +21,13 @@ import matplotlib.pyplot as plt ...@@ -21,12 +21,13 @@ import matplotlib.pyplot as plt
#import jsbeautifier #import jsbeautifier
from mindspore import log as logger from mindspore import log as logger
# These are the column names defined in the testTFTestAllTypes dataset
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"] "col_sint16", "col_sint32", "col_sint64"]
SAVE_JSON = False SAVE_JSON = False
def save_golden(cur_dir, golden_ref_dir, result_dict): def _save_golden(cur_dir, golden_ref_dir, result_dict):
""" """
Save the dictionary values as the golden result in .npz file Save the dictionary values as the golden result in .npz file
""" """
...@@ -35,7 +36,7 @@ def save_golden(cur_dir, golden_ref_dir, result_dict): ...@@ -35,7 +36,7 @@ def save_golden(cur_dir, golden_ref_dir, result_dict):
np.savez(golden_ref_dir, np.array(list(result_dict.values()))) np.savez(golden_ref_dir, np.array(list(result_dict.values())))
def save_golden_dict(cur_dir, golden_ref_dir, result_dict): def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
""" """
Save the dictionary (both keys and values) as the golden result in .npz file Save the dictionary (both keys and values) as the golden result in .npz file
""" """
...@@ -44,7 +45,7 @@ def save_golden_dict(cur_dir, golden_ref_dir, result_dict): ...@@ -44,7 +45,7 @@ def save_golden_dict(cur_dir, golden_ref_dir, result_dict):
np.savez(golden_ref_dir, np.array(list(result_dict.items()))) np.savez(golden_ref_dir, np.array(list(result_dict.items())))
def compare_to_golden(golden_ref_dir, result_dict): def _compare_to_golden(golden_ref_dir, result_dict):
""" """
Compare as numpy arrays the test result to the golden result Compare as numpy arrays the test result to the golden result
""" """
...@@ -53,16 +54,15 @@ def compare_to_golden(golden_ref_dir, result_dict): ...@@ -53,16 +54,15 @@ def compare_to_golden(golden_ref_dir, result_dict):
assert np.array_equal(test_array, golden_array) assert np.array_equal(test_array, golden_array)
def compare_to_golden_dict(golden_ref_dir, result_dict): def _compare_to_golden_dict(golden_ref_dir, result_dict):
""" """
Compare as dictionaries the test result to the golden result Compare as dictionaries the test result to the golden result
""" """
golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0'] golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
np.testing.assert_equal(result_dict, dict(golden_array)) np.testing.assert_equal(result_dict, dict(golden_array))
# assert result_dict == dict(golden_array)
def save_json(filename, parameters, result_dict): def _save_json(filename, parameters, result_dict):
""" """
Save the result dictionary in json file Save the result dictionary in json file
""" """
...@@ -78,6 +78,7 @@ def save_and_check(data, parameters, filename, generate_golden=False): ...@@ -78,6 +78,7 @@ def save_and_check(data, parameters, filename, generate_golden=False):
""" """
Save the dataset dictionary and compare (as numpy array) with golden file. Save the dataset dictionary and compare (as numpy array) with golden file.
Use create_dict_iterator to access the dataset. Use create_dict_iterator to access the dataset.
Note: save_and_check() is deprecated; use save_and_check_dict().
""" """
num_iter = 0 num_iter = 0
result_dict = {} result_dict = {}
...@@ -97,13 +98,13 @@ def save_and_check(data, parameters, filename, generate_golden=False): ...@@ -97,13 +98,13 @@ def save_and_check(data, parameters, filename, generate_golden=False):
golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename) golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
if generate_golden: if generate_golden:
# Save as the golden result # Save as the golden result
save_golden(cur_dir, golden_ref_dir, result_dict) _save_golden(cur_dir, golden_ref_dir, result_dict)
compare_to_golden(golden_ref_dir, result_dict) _compare_to_golden(golden_ref_dir, result_dict)
if SAVE_JSON: if SAVE_JSON:
# Save result to a json file for inspection # Save result to a json file for inspection
save_json(filename, parameters, result_dict) _save_json(filename, parameters, result_dict)
def save_and_check_dict(data, filename, generate_golden=False): def save_and_check_dict(data, filename, generate_golden=False):
...@@ -127,14 +128,14 @@ def save_and_check_dict(data, filename, generate_golden=False): ...@@ -127,14 +128,14 @@ def save_and_check_dict(data, filename, generate_golden=False):
golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename) golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
if generate_golden: if generate_golden:
# Save as the golden result # Save as the golden result
save_golden_dict(cur_dir, golden_ref_dir, result_dict) _save_golden_dict(cur_dir, golden_ref_dir, result_dict)
compare_to_golden_dict(golden_ref_dir, result_dict) _compare_to_golden_dict(golden_ref_dir, result_dict)
if SAVE_JSON: if SAVE_JSON:
# Save result to a json file for inspection # Save result to a json file for inspection
parameters = {"params": {}} parameters = {"params": {}}
save_json(filename, parameters, result_dict) _save_json(filename, parameters, result_dict)
def save_and_check_md5(data, filename, generate_golden=False): def save_and_check_md5(data, filename, generate_golden=False):
...@@ -159,22 +160,21 @@ def save_and_check_md5(data, filename, generate_golden=False): ...@@ -159,22 +160,21 @@ def save_and_check_md5(data, filename, generate_golden=False):
golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename) golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
if generate_golden: if generate_golden:
# Save as the golden result # Save as the golden result
save_golden_dict(cur_dir, golden_ref_dir, result_dict) _save_golden_dict(cur_dir, golden_ref_dir, result_dict)
compare_to_golden_dict(golden_ref_dir, result_dict) _compare_to_golden_dict(golden_ref_dir, result_dict)
def ordered_save_and_check(data, parameters, filename, generate_golden=False): def save_and_check_tuple(data, parameters, filename, generate_golden=False):
""" """
Save the dataset dictionary and compare (as numpy array) with golden file. Save the dataset dictionary and compare (as numpy array) with golden file.
Use create_tuple_iterator to access the dataset. Use create_tuple_iterator to access the dataset.
""" """
num_iter = 0 num_iter = 0
result_dict = {} result_dict = {}
for item in data.create_tuple_iterator(): # each data is a dictionary for item in data.create_tuple_iterator(): # each data is a dictionary
for data_key in range(0, len(item)): for data_key, _ in enumerate(item):
if data_key not in result_dict: if data_key not in result_dict:
result_dict[data_key] = [] result_dict[data_key] = []
result_dict[data_key].append(item[data_key].tolist()) result_dict[data_key].append(item[data_key].tolist())
...@@ -186,13 +186,13 @@ def ordered_save_and_check(data, parameters, filename, generate_golden=False): ...@@ -186,13 +186,13 @@ def ordered_save_and_check(data, parameters, filename, generate_golden=False):
golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename) golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
if generate_golden: if generate_golden:
# Save as the golden result # Save as the golden result
save_golden(cur_dir, golden_ref_dir, result_dict) _save_golden(cur_dir, golden_ref_dir, result_dict)
compare_to_golden(golden_ref_dir, result_dict) _compare_to_golden(golden_ref_dir, result_dict)
if SAVE_JSON: if SAVE_JSON:
# Save result to a json file for inspection # Save result to a json file for inspection
save_json(filename, parameters, result_dict) _save_json(filename, parameters, result_dict)
def diff_mse(in1, in2): def diff_mse(in1, in2):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册