提交 772e6c14 编写于 作者: C Cathy Wong

Cleanup dataset UT: test_batch, save_and_check support

上级 651e9081
...@@ -37,6 +37,7 @@ def test_batch_01(): ...@@ -37,6 +37,7 @@ def test_batch_01():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder) data1 = data1.batch(batch_size, drop_remainder)
assert sum([1 for _ in data1]) == 6
filename = "batch_01_result.npz" filename = "batch_01_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -56,6 +57,7 @@ def test_batch_02(): ...@@ -56,6 +57,7 @@ def test_batch_02():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder=drop_remainder) data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
assert sum([1 for _ in data1]) == 2
filename = "batch_02_result.npz" filename = "batch_02_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -75,6 +77,7 @@ def test_batch_03(): ...@@ -75,6 +77,7 @@ def test_batch_03():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size=batch_size, drop_remainder=drop_remainder) data1 = data1.batch(batch_size=batch_size, drop_remainder=drop_remainder)
assert sum([1 for _ in data1]) == 4
filename = "batch_03_result.npz" filename = "batch_03_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -94,6 +97,7 @@ def test_batch_04(): ...@@ -94,6 +97,7 @@ def test_batch_04():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder) data1 = data1.batch(batch_size, drop_remainder)
assert sum([1 for _ in data1]) == 2
filename = "batch_04_result.npz" filename = "batch_04_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -111,6 +115,7 @@ def test_batch_05(): ...@@ -111,6 +115,7 @@ def test_batch_05():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size) data1 = data1.batch(batch_size)
assert sum([1 for _ in data1]) == 12
filename = "batch_05_result.npz" filename = "batch_05_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -130,6 +135,7 @@ def test_batch_06(): ...@@ -130,6 +135,7 @@ def test_batch_06():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(drop_remainder=drop_remainder, batch_size=batch_size) data1 = data1.batch(drop_remainder=drop_remainder, batch_size=batch_size)
assert sum([1 for _ in data1]) == 1
filename = "batch_06_result.npz" filename = "batch_06_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -152,6 +158,7 @@ def test_batch_07(): ...@@ -152,6 +158,7 @@ def test_batch_07():
data1 = data1.batch(num_parallel_workers=num_parallel_workers, drop_remainder=drop_remainder, data1 = data1.batch(num_parallel_workers=num_parallel_workers, drop_remainder=drop_remainder,
batch_size=batch_size) batch_size=batch_size)
assert sum([1 for _ in data1]) == 3
filename = "batch_07_result.npz" filename = "batch_07_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -171,6 +178,7 @@ def test_batch_08(): ...@@ -171,6 +178,7 @@ def test_batch_08():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, num_parallel_workers=num_parallel_workers) data1 = data1.batch(batch_size, num_parallel_workers=num_parallel_workers)
assert sum([1 for _ in data1]) == 2
filename = "batch_08_result.npz" filename = "batch_08_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -190,6 +198,7 @@ def test_batch_09(): ...@@ -190,6 +198,7 @@ def test_batch_09():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder=drop_remainder) data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
assert sum([1 for _ in data1]) == 1
filename = "batch_09_result.npz" filename = "batch_09_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -209,6 +218,7 @@ def test_batch_10(): ...@@ -209,6 +218,7 @@ def test_batch_10():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder=drop_remainder) data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
assert sum([1 for _ in data1]) == 0
filename = "batch_10_result.npz" filename = "batch_10_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
...@@ -228,10 +238,30 @@ def test_batch_11(): ...@@ -228,10 +238,30 @@ def test_batch_11():
data1 = ds.TFRecordDataset(DATA_DIR, schema_file) data1 = ds.TFRecordDataset(DATA_DIR, schema_file)
data1 = data1.batch(batch_size) data1 = data1.batch(batch_size)
assert sum([1 for _ in data1]) == 1
filename = "batch_11_result.npz" filename = "batch_11_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_12():
"""
Test batch: batch_size boolean value True, treated as valid value 1
"""
logger.info("test_batch_12")
# define parameters
batch_size = True
parameters = {"params": {'batch_size': batch_size}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size=batch_size)
assert sum([1 for _ in data1]) == 12
filename = "batch_12_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_exception_01(): def test_batch_exception_01():
""" """
Test batch exception: num_parallel_workers=0 Test batch exception: num_parallel_workers=0
...@@ -302,7 +332,7 @@ def test_batch_exception_04(): ...@@ -302,7 +332,7 @@ def test_batch_exception_04():
def test_batch_exception_05(): def test_batch_exception_05():
""" """
Test batch exception: batch_size wrong type, boolean value False Test batch exception: batch_size boolean value False, treated as invalid value 0
""" """
logger.info("test_batch_exception_05") logger.info("test_batch_exception_05")
...@@ -317,23 +347,6 @@ def test_batch_exception_05(): ...@@ -317,23 +347,6 @@ def test_batch_exception_05():
assert "batch_size" in str(e) assert "batch_size" in str(e)
def skip_test_batch_exception_06():
"""
Test batch exception: batch_size wrong type, boolean value True
"""
logger.info("test_batch_exception_06")
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
try:
data1 = data1.batch(batch_size=True)
sum([1 for _ in data1])
except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "batch_size" in str(e)
def test_batch_exception_07(): def test_batch_exception_07():
""" """
Test batch exception: drop_remainder wrong type Test batch exception: drop_remainder wrong type
...@@ -473,12 +486,12 @@ if __name__ == '__main__': ...@@ -473,12 +486,12 @@ if __name__ == '__main__':
test_batch_09() test_batch_09()
test_batch_10() test_batch_10()
test_batch_11() test_batch_11()
test_batch_12()
test_batch_exception_01() test_batch_exception_01()
test_batch_exception_02() test_batch_exception_02()
test_batch_exception_03() test_batch_exception_03()
test_batch_exception_04() test_batch_exception_04()
test_batch_exception_05() test_batch_exception_05()
skip_test_batch_exception_06()
test_batch_exception_07() test_batch_exception_07()
test_batch_exception_08() test_batch_exception_08()
test_batch_exception_09() test_batch_exception_09()
......
...@@ -28,7 +28,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" ...@@ -28,7 +28,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_center_crop_op(height=375, width=375, plot=False): def test_center_crop_op(height=375, width=375, plot=False):
""" """
Test random_vertical Test CenterCrop
""" """
logger.info("Test CenterCrop") logger.info("Test CenterCrop")
...@@ -55,7 +55,7 @@ def test_center_crop_op(height=375, width=375, plot=False): ...@@ -55,7 +55,7 @@ def test_center_crop_op(height=375, width=375, plot=False):
def test_center_crop_md5(height=375, width=375): def test_center_crop_md5(height=375, width=375):
""" """
Test random_vertical Test CenterCrop
""" """
logger.info("Test CenterCrop") logger.info("Test CenterCrop")
...@@ -69,13 +69,12 @@ def test_center_crop_md5(height=375, width=375): ...@@ -69,13 +69,12 @@ def test_center_crop_md5(height=375, width=375):
# expected md5 from images # expected md5 from images
filename = "test_center_crop_01_result.npz" filename = "test_center_crop_01_result.npz"
parameters = {"params": {}} save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_md5(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_center_crop_comp(height=375, width=375, plot=False): def test_center_crop_comp(height=375, width=375, plot=False):
""" """
Test random_vertical between python and c image augmentation Test CenterCrop between python and c image augmentation
""" """
logger.info("Test CenterCrop") logger.info("Test CenterCrop")
...@@ -114,4 +113,4 @@ if __name__ == "__main__": ...@@ -114,4 +113,4 @@ if __name__ == "__main__":
test_center_crop_op(300, 600) test_center_crop_op(300, 600)
test_center_crop_op(600, 300) test_center_crop_op(600, 300)
test_center_crop_md5(600, 600) test_center_crop_md5(600, 600)
test_center_crop_comp() test_center_crop_comp()
\ No newline at end of file
...@@ -18,9 +18,7 @@ Testing Decode op in DE ...@@ -18,9 +18,7 @@ Testing Decode op in DE
import cv2 import cv2
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
import numpy as np import numpy as np
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
""" """
Testing RandomRotation op in DE Testing RandomColorAdjust op in DE
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import mindspore.dataset.transforms.vision.c_transforms as c_vision import mindspore.dataset.transforms.vision.c_transforms as c_vision
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
""" """
Testing RandomRotation op in DE Testing RandomErasing op in DE
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
......
...@@ -17,9 +17,7 @@ Testing the resize op in DE ...@@ -17,9 +17,7 @@ Testing the resize op in DE
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
""" """
Testing RandomRotation op in DE Testing TypeCast op in DE
""" """
import mindspore.dataset.transforms.vision.c_transforms as c_vision import mindspore.dataset.transforms.vision.c_transforms as c_vision
import mindspore.dataset.transforms.vision.py_transforms as py_vision import mindspore.dataset.transforms.vision.py_transforms as py_vision
...@@ -31,9 +31,9 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" ...@@ -31,9 +31,9 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_type_cast(): def test_type_cast():
""" """
Test type_cast_op Test TypeCast op
""" """
logger.info("test_type_cast_op") logger.info("test_type_cast")
# First dataset # First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
...@@ -71,9 +71,9 @@ def test_type_cast(): ...@@ -71,9 +71,9 @@ def test_type_cast():
def test_type_cast_string(): def test_type_cast_string():
""" """
Test type_cast_op Test TypeCast op
""" """
logger.info("test_type_cast_op") logger.info("test_type_cast_string")
# First dataset # First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
......
...@@ -44,8 +44,7 @@ def test_zip_01(): ...@@ -44,8 +44,7 @@ def test_zip_01():
dataz = ds.zip((data1, data2)) dataz = ds.zip((data1, data2))
# Note: zipped dataset has 5 rows and 7 columns # Note: zipped dataset has 5 rows and 7 columns
filename = "zip_01_result.npz" filename = "zip_01_result.npz"
parameters = {"params": {}} save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_02(): def test_zip_02():
...@@ -59,8 +58,7 @@ def test_zip_02(): ...@@ -59,8 +58,7 @@ def test_zip_02():
dataz = ds.zip((data1, data2)) dataz = ds.zip((data1, data2))
# Note: zipped dataset has 3 rows and 4 columns # Note: zipped dataset has 3 rows and 4 columns
filename = "zip_02_result.npz" filename = "zip_02_result.npz"
parameters = {"params": {}} save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_03(): def test_zip_03():
...@@ -74,8 +72,7 @@ def test_zip_03(): ...@@ -74,8 +72,7 @@ def test_zip_03():
dataz = ds.zip((data1, data2)) dataz = ds.zip((data1, data2))
# Note: zipped dataset has 3 rows and 7 columns # Note: zipped dataset has 3 rows and 7 columns
filename = "zip_03_result.npz" filename = "zip_03_result.npz"
parameters = {"params": {}} save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_04(): def test_zip_04():
...@@ -90,8 +87,7 @@ def test_zip_04(): ...@@ -90,8 +87,7 @@ def test_zip_04():
dataz = ds.zip((data1, data2, data3)) dataz = ds.zip((data1, data2, data3))
# Note: zipped dataset has 3 rows and 9 columns # Note: zipped dataset has 3 rows and 9 columns
filename = "zip_04_result.npz" filename = "zip_04_result.npz"
parameters = {"params": {}} save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_05(): def test_zip_05():
...@@ -109,8 +105,7 @@ def test_zip_05(): ...@@ -109,8 +105,7 @@ def test_zip_05():
dataz = ds.zip((data1, data2)) dataz = ds.zip((data1, data2))
# Note: zipped dataset has 5 rows and 9 columns # Note: zipped dataset has 5 rows and 9 columns
filename = "zip_05_result.npz" filename = "zip_05_result.npz"
parameters = {"params": {}} save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_06(): def test_zip_06():
...@@ -129,8 +124,7 @@ def test_zip_06(): ...@@ -129,8 +124,7 @@ def test_zip_06():
dataz = dataz.repeat(2) dataz = dataz.repeat(2)
# Note: resultant dataset has 10 rows and 9 columns # Note: resultant dataset has 10 rows and 9 columns
filename = "zip_06_result.npz" filename = "zip_06_result.npz"
parameters = {"params": {}} save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_exception_01(): def test_zip_exception_01():
......
...@@ -15,15 +15,15 @@ ...@@ -15,15 +15,15 @@
import json import json
import os import os
import hashlib
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import hashlib
#import jsbeautifier #import jsbeautifier
from mindspore import log as logger from mindspore import log as logger
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"] "col_sint16", "col_sint32", "col_sint64"]
SAVE_JSON = False
def save_golden(cur_dir, golden_ref_dir, result_dict): def save_golden(cur_dir, golden_ref_dir, result_dict):
...@@ -44,15 +44,6 @@ def save_golden_dict(cur_dir, golden_ref_dir, result_dict): ...@@ -44,15 +44,6 @@ def save_golden_dict(cur_dir, golden_ref_dir, result_dict):
np.savez(golden_ref_dir, np.array(list(result_dict.items()))) np.savez(golden_ref_dir, np.array(list(result_dict.items())))
def save_golden_md5(cur_dir, golden_ref_dir, result_dict):
"""
Save the dictionary (both keys and values) as the golden result in .npz file
"""
logger.info("cur_dir is {}".format(cur_dir))
logger.info("golden_ref_dir is {}".format(golden_ref_dir))
np.savez(golden_ref_dir, np.array(list(result_dict.items())))
def compare_to_golden(golden_ref_dir, result_dict): def compare_to_golden(golden_ref_dir, result_dict):
""" """
Compare as numpy arrays the test result to the golden result Compare as numpy arrays the test result to the golden result
...@@ -67,7 +58,7 @@ def compare_to_golden_dict(golden_ref_dir, result_dict): ...@@ -67,7 +58,7 @@ def compare_to_golden_dict(golden_ref_dir, result_dict):
Compare as dictionaries the test result to the golden result Compare as dictionaries the test result to the golden result
""" """
golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0'] golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
np.testing.assert_equal (result_dict, dict(golden_array)) np.testing.assert_equal(result_dict, dict(golden_array))
# assert result_dict == dict(golden_array) # assert result_dict == dict(golden_array)
...@@ -83,7 +74,6 @@ def save_json(filename, parameters, result_dict): ...@@ -83,7 +74,6 @@ def save_json(filename, parameters, result_dict):
fout.write(jsbeautifier.beautify(json.dumps(out_dict), options)) fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))
def save_and_check(data, parameters, filename, generate_golden=False): def save_and_check(data, parameters, filename, generate_golden=False):
""" """
Save the dataset dictionary and compare (as numpy array) with golden file. Save the dataset dictionary and compare (as numpy array) with golden file.
...@@ -111,11 +101,12 @@ def save_and_check(data, parameters, filename, generate_golden=False): ...@@ -111,11 +101,12 @@ def save_and_check(data, parameters, filename, generate_golden=False):
compare_to_golden(golden_ref_dir, result_dict) compare_to_golden(golden_ref_dir, result_dict)
# Save to a json file for inspection if SAVE_JSON:
# save_json(filename, parameters, result_dict) # Save result to a json file for inspection
save_json(filename, parameters, result_dict)
def save_and_check_dict(data, parameters, filename, generate_golden=False): def save_and_check_dict(data, filename, generate_golden=False):
""" """
Save the dataset dictionary and compare (as dictionary) with golden file. Save the dataset dictionary and compare (as dictionary) with golden file.
Use create_dict_iterator to access the dataset. Use create_dict_iterator to access the dataset.
...@@ -140,11 +131,13 @@ def save_and_check_dict(data, parameters, filename, generate_golden=False): ...@@ -140,11 +131,13 @@ def save_and_check_dict(data, parameters, filename, generate_golden=False):
compare_to_golden_dict(golden_ref_dir, result_dict) compare_to_golden_dict(golden_ref_dir, result_dict)
# Save to a json file for inspection if SAVE_JSON:
# save_json(filename, parameters, result_dict) # Save result to a json file for inspection
parameters = {"params": {}}
save_json(filename, parameters, result_dict)
def save_and_check_md5(data, parameters, filename, generate_golden=False): def save_and_check_md5(data, filename, generate_golden=False):
""" """
Save the dataset dictionary and compare (as dictionary) with golden file (md5). Save the dataset dictionary and compare (as dictionary) with golden file (md5).
Use create_dict_iterator to access the dataset. Use create_dict_iterator to access the dataset.
...@@ -197,8 +190,9 @@ def ordered_save_and_check(data, parameters, filename, generate_golden=False): ...@@ -197,8 +190,9 @@ def ordered_save_and_check(data, parameters, filename, generate_golden=False):
compare_to_golden(golden_ref_dir, result_dict) compare_to_golden(golden_ref_dir, result_dict)
# Save to a json file for inspection if SAVE_JSON:
# save_json(filename, parameters, result_dict) # Save result to a json file for inspection
save_json(filename, parameters, result_dict)
def diff_mse(in1, in2): def diff_mse(in1, in2):
...@@ -211,24 +205,18 @@ def diff_me(in1, in2): ...@@ -211,24 +205,18 @@ def diff_me(in1, in2):
return mse / 255 * 100 return mse / 255 * 100
def diff_ssim(in1, in2):
from skimage.measure import compare_ssim as ssim
val = ssim(in1, in2, multichannel=True)
return (1 - val) * 100
def visualize(image_original, image_transformed): def visualize(image_original, image_transformed):
""" """
visualizes the image using DE op and Numpy op visualizes the image using DE op and Numpy op
""" """
num = len(image_cropped) num = len(image_transformed)
for i in range(num): for i in range(num):
plt.subplot(2, num, i + 1) plt.subplot(2, num, i + 1)
plt.imshow(image_original[i]) plt.imshow(image_original[i])
plt.title("Original image") plt.title("Original image")
plt.subplot(2, num, i + num + 1) plt.subplot(2, num, i + num + 1)
plt.imshow(image_cropped[i]) plt.imshow(image_transformed[i])
plt.title("Transformed image") plt.title("Transformed image")
plt.show() plt.show()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册