diff --git a/mindspore/mindrecord/__init__.py b/mindspore/mindrecord/__init__.py index 31fb801c46b59d978e0df56e6a72ed619a78af81..ba686c6c1832f5ca651fd71ff77377055978fa51 100644 --- a/mindspore/mindrecord/__init__.py +++ b/mindspore/mindrecord/__init__.py @@ -31,7 +31,8 @@ from .tools.cifar10_to_mr import Cifar10ToMR from .tools.cifar100_to_mr import Cifar100ToMR from .tools.imagenet_to_mr import ImageNetToMR from .tools.mnist_to_mr import MnistToMR +from .tools.tfrecord_to_mr import TFRecordToMR __all__ = ['FileWriter', 'FileReader', 'MindPage', - 'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR', + 'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR', 'TFRecordToMR', 'SUCCESS', 'FAILED'] diff --git a/mindspore/mindrecord/tools/tfrecord_to_mr.py b/mindspore/mindrecord/tools/tfrecord_to_mr.py new file mode 100644 index 0000000000000000000000000000000000000000..aba3b74729f98411c2a3d89627ba48dd727f6c15 --- /dev/null +++ b/mindspore/mindrecord/tools/tfrecord_to_mr.py @@ -0,0 +1,268 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +TFRecord convert tool for MindRecord +""" + +from importlib import import_module +from string import punctuation +import numpy as np + +from mindspore import log as logger +from ..filewriter import FileWriter +from ..shardutils import check_filename + +try: + tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord +except ModuleNotFoundError: + tf = None + +__all__ = ['TFRecordToMR'] + +SupportedTensorFlowVersion = '2.1.0' + +def _cast_type(value): + """ + Cast complex data type to basic datatype for MindRecord to recognize. + + Args: + value: the TFRecord data type + + Returns: + str, which is MindRecord field type. + """ + tf_type_to_mr_type = {tf.string: "string", + tf.int8: "int32", + tf.int16: "int32", + tf.int32: "int32", + tf.int64: "int64", + tf.uint8: "int32", + tf.uint16: "int32", + tf.uint32: "int64", + tf.uint64: "int64", + tf.float16: "float32", + tf.float32: "float32", + tf.float64: "float64", + tf.double: "float64", + tf.bool: "int32"} + unsupport_tf_type_to_mr_type = {tf.complex64: "None", + tf.complex128: "None"} + + if value in tf_type_to_mr_type: + return tf_type_to_mr_type[value] + + raise ValueError("Type " + value + " is not supported in MindRecord.") + +def _cast_string_type_to_np_type(value): + """Cast string type like: int32/int64/float32/float64 to np.int32/np.int64/np.float32/np.float64""" + string_type_to_np_type = {"int32": np.int32, + "int64": np.int64, + "float32": np.float32, + "float64": np.float64} + + if value in string_type_to_np_type: + return string_type_to_np_type[value] + + raise ValueError("Type " + value + " is not supported cast to numpy type in MindRecord.") + +def _cast_name(key): + """ + Cast schema names which containing special characters to valid names. + + Here special characters means any characters in + '!"#$%&\'()*+,./:;<=>?@[\\]^`{|}~ + Valid names can only contain a-z, A-Z, and 0-9 and _ + + Args: + key (str): original key that might contains special characters. + + Returns: + str, casted key that replace the special characters with "_". i.e. if + key is "a b" then returns "a_b". + """ + special_symbols = set('{}{}'.format(punctuation, ' ')) + special_symbols.remove('_') + new_key = ['_' if x in special_symbols else x for x in key] + casted_key = ''.join(new_key) + return casted_key + +class TFRecordToMR: + """ + Class is for tranformation from TFRecord to MindRecord. + + Args: + source (str): the TFRecord file to be transformed. + destination (str): the MindRecord file path to tranform into. + feature_dict (dict): a dictionary than states the feature type, i.e. + feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), + "yyyy": tf.io.FixedLenFeature([], tf.int64)} + ****** follow case which uses VarLenFeature not support ****** + feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string), + "yyyy": tf.io.VarLenFeature(tf.int64)}, + "sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}} + bytes_fields (list): the bytes fields which are in feature_dict. + + Rasies: + ValueError, when: + 1) parameter TFRecord is not string. + 2) parameter MindRecord is not string. + 3) feature_dict is not FixedLenFeature. + 4) parameter bytes_field is not list(str) or not in feature_dict + Exception, when tensorflow module not found or version is not correct. + """ + def __init__(self, source, destination, feature_dict, bytes_fields=None): + if not tf: + raise Exception("Module tensorflow is not found, please use pip install it.") + + if tf.__version__ < SupportedTensorFlowVersion: + raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion)) + + if not isinstance(source, str): + raise ValueError("Parameter source must be string.") + check_filename(source) + + if not isinstance(destination, str): + raise ValueError("Parameter destination must be string.") + check_filename(destination) + + self.source = source + self.destination = destination + + if feature_dict is None or not isinstance(feature_dict, dict): + raise ValueError("Parameter feature_dict is None or not dict.") + + for key, val in feature_dict.items(): + if not isinstance(val, tf.io.FixedLenFeature): + raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict)) + + self.feature_dict = feature_dict + + bytes_fields_list = [] + if bytes_fields: + if not isinstance(bytes_fields, list): + raise ValueError("Parameter bytes_fields: {} must be list(str).".format(bytes_fields)) + for item in bytes_fields: + if not isinstance(item, str): + raise ValueError("Parameter bytes_fields's item: {} is not str.".format(item)) + + if item not in self.feature_dict: + raise ValueError("Parameter bytes_fields's item: {} is not in feature_dict: {}." + .format(item, self.feature_dict)) + + if not isinstance(self.feature_dict[item].shape, list): + raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item)) + + casted_bytes_field = _cast_name(item) + bytes_fields_list.append(casted_bytes_field) + + self.bytes_fields_list = bytes_fields_list + self.scalar_set = set() + self.list_set = set() + + mindrecord_schema = {} + for key, val in self.feature_dict.items(): + if not val.shape: + self.scalar_set.add(_cast_name(key)) + if key in self.bytes_fields_list: + mindrecord_schema[_cast_name(key)] = {"type": "bytes"} + else: + mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype)} + else: + if len(val.shape) != 1: + raise ValueError("Parameter len(feature_dict[{}].shape) should be 1.") + if val.shape[0] < 1: + raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key)) + if val.dtype == tf.string: + raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] \ + is not None. It is not supported.".format(key)) + self.list_set.add(_cast_name(key)) + mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype), "shape": [val.shape[0]]} + self.mindrecord_schema = mindrecord_schema + + def _parse_record(self, example): + """Returns features for a single example""" + features = tf.io.parse_single_example(example, features=self.feature_dict) + return features + + def _get_data_when_scalar_field(self, ms_dict, cast_key, key, val): + """put data in ms_dict when field type is string""" + if isinstance(val.numpy(), (np.ndarray, list)): + raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val)) + if self.feature_dict[key].dtype == tf.string: + if cast_key in self.bytes_fields_list: + ms_dict[cast_key] = val.numpy() + else: + ms_dict[cast_key] = str(val.numpy(), encoding="utf-8") + elif _cast_type(self.feature_dict[key].dtype).startswith("int"): + ms_dict[cast_key] = int(val.numpy()) + else: + ms_dict[cast_key] = float(val.numpy()) + + def tfrecord_iterator(self): + """Yield a dict with key to be fields in schema, and value to be data.""" + dataset = tf.data.TFRecordDataset(self.source) + dataset = dataset.map(self._parse_record) + iterator = dataset.__iter__() + index_id = 0 + try: + for features in iterator: + ms_dict = {} + index_id = index_id + 1 + for key, val in features.items(): + cast_key = _cast_name(key) + if key in self.scalar_set: + self._get_data_when_scalar_field(ms_dict, cast_key, key, val) + else: + if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list): + raise ValueError("he response key: {}, value: {} from TFRecord should be a ndarray or list." + .format(key, val)) + # list set + ms_dict[cast_key] = \ + np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) + yield ms_dict + except tf.errors.InvalidArgumentError: + raise ValueError("TFRecord feature_dict parameter error.") + + def transform(self): + """ + Executes transform from TFRecord to MindRecord. + + Returns: + SUCCESS/FAILED, whether successfuly written into MindRecord. + """ + writer = FileWriter(self.destination) + logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}" + .format(self.mindrecord_schema, self.feature_dict)) + + writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord") + + tf_iter = self.tfrecord_iterator() + batch_size = 256 + + transform_count = 0 + while True: + data_list = [] + try: + for _ in range(batch_size): + data_list.append(tf_iter.__next__()) + transform_count += 1 + + writer.write_raw_data(data_list) + logger.info("Transformed {} records...".format(transform_count)) + except StopIteration: + if data_list: + writer.write_raw_data(data_list) + logger.info("Transformed {} records...".format(transform_count)) + break + return writer.commit() diff --git a/tests/ut/data/mindrecord/testTFRecordData/README.md b/tests/ut/data/mindrecord/testTFRecordData/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f358cd564b5bd96205869f67690f23f64b5e31aa --- /dev/null +++ b/tests/ut/data/mindrecord/testTFRecordData/README.md @@ -0,0 +1 @@ +## tfrecord file dir diff --git a/tests/ut/python/mindrecord/test_tfrecord_to_mr.py b/tests/ut/python/mindrecord/test_tfrecord_to_mr.py new file mode 100644 index 0000000000000000000000000000000000000000..87c002aa3b675c02715b1c61fecb47896095690b --- /dev/null +++ b/tests/ut/python/mindrecord/test_tfrecord_to_mr.py @@ -0,0 +1,400 @@ +# Copyright 2020 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""test tfrecord to mindrecord tool""" +import collections +from importlib import import_module +import os + +import numpy as np +import pytest +from mindspore import log as logger +from mindspore.mindrecord import FileReader +from mindspore.mindrecord import TFRecordToMR + +SupportedTensorFlowVersion = '2.1.0' + +try: + tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord +except ModuleNotFoundError: + logger.warning("tensorflow module not found.") + tf = None + +TFRECORD_DATA_DIR = "../data/mindrecord/testTFRecordData" +TFRECORD_FILE_NAME = "test.tfrecord" +MINDRECORD_FILE_NAME = "test.mindrecord" +PARTITION_NUM = 1 + +def verify_data(transformer, reader): + """Verify the data by read from mindrecord""" + tf_iter = transformer.tfrecord_iterator() + mr_iter = reader.get_next() + + count = 0 + for tf_item, mr_item in zip(tf_iter, mr_iter): + count = count + 1 + assert len(tf_item) == 6 + assert len(mr_item) == 6 + for key, value in tf_item.items(): + logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value, mr_item[key])) + if isinstance(value, np.ndarray): + assert (value == mr_item[key]).all() + else: + assert value == mr_item[key] + assert count == 10 + +def generate_tfrecord(): + def create_int_feature(values): + if isinstance(values, list): + feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) # values: [int, int, int] + else: + feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values])) # values: int + return feature + + def create_float_feature(values): + if isinstance(values, list): + feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) # values: [float, float] + else: + feature = tf.train.Feature(float_list=tf.train.FloatList(value=[values])) # values: float + return feature + + def create_bytes_feature(values): + if isinstance(values, bytes): + feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) # values: bytes + else: + # values: string + feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')])) + return feature + + writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + example_count = 0 + for i in range(10): + file_name = "000" + str(i) + ".jpg" + image_bytes = bytes(str("aaaabbbbcccc" + str(i)), encoding="utf-8") + int64_scalar = i + float_scalar = float(i) + int64_list = [i, i+1, i+2, i+3, i+4, i+1234567890] + float_list = [float(i), float(i+1), float(i+2.8), float(i+3.2), + float(i+4.4), float(i+123456.9), float(i+98765432.1)] + + features = collections.OrderedDict() + features["file_name"] = create_bytes_feature(file_name) + features["image_bytes"] = create_bytes_feature(image_bytes) + features["int64_scalar"] = create_int_feature(int64_scalar) + features["float_scalar"] = create_float_feature(float_scalar) + features["int64_list"] = create_int_feature(int64_list) + features["float_list"] = create_float_feature(float_list) + + tf_example = tf.train.Example(features=tf.train.Features(feature=features)) + writer.write(tf_example.SerializeToString()) + example_count += 1 + writer.close() + logger.info("Write {} rows in tfrecord.".format(example_count)) + +def test_tfrecord_to_mindrecord(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), + "image_bytes": tf.io.FixedLenFeature([], tf.string), + "int64_scalar": tf.io.FixedLenFeature([], tf.int64), + "float_scalar": tf.io.FixedLenFeature([], tf.float32), + "int64_list": tf.io.FixedLenFeature([6], tf.int64), + "float_list": tf.io.FixedLenFeature([7], tf.float32), + } + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"]) + tfrecord_transformer.transform() + + assert os.path.exists(MINDRECORD_FILE_NAME) + assert os.path.exists(MINDRECORD_FILE_NAME + ".db") + + fr_mindrecord = FileReader(MINDRECORD_FILE_NAME) + verify_data(tfrecord_transformer, fr_mindrecord) + + os.remove(MINDRECORD_FILE_NAME) + os.remove(MINDRECORD_FILE_NAME + ".db") + + os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + +def test_tfrecord_to_mindrecord_scalar_with_1(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), + "image_bytes": tf.io.FixedLenFeature([], tf.string), + "int64_scalar": tf.io.FixedLenFeature([1], tf.int64), + "float_scalar": tf.io.FixedLenFeature([1], tf.float32), + "int64_list": tf.io.FixedLenFeature([6], tf.int64), + "float_list": tf.io.FixedLenFeature([7], tf.float32), + } + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"]) + tfrecord_transformer.transform() + + assert os.path.exists(MINDRECORD_FILE_NAME) + assert os.path.exists(MINDRECORD_FILE_NAME + ".db") + + fr_mindrecord = FileReader(MINDRECORD_FILE_NAME) + verify_data(tfrecord_transformer, fr_mindrecord) + + os.remove(MINDRECORD_FILE_NAME) + os.remove(MINDRECORD_FILE_NAME + ".db") + + os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + +def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), + "image_bytes": tf.io.FixedLenFeature([], tf.string), + "int64_scalar": tf.io.FixedLenFeature([1], tf.int64), + "float_scalar": tf.io.FixedLenFeature([1], tf.float32), + "int64_list": tf.io.FixedLenFeature([6], tf.int64), + "float_list": tf.io.FixedLenFeature([2], tf.float32), + } + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + with pytest.raises(ValueError): + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"]) + tfrecord_transformer.transform() + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + +def test_tfrecord_to_mindrecord_list_with_diff_type_exception(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), + "image_bytes": tf.io.FixedLenFeature([], tf.string), + "int64_scalar": tf.io.FixedLenFeature([1], tf.int64), + "float_scalar": tf.io.FixedLenFeature([1], tf.float32), + "int64_list": tf.io.FixedLenFeature([6], tf.float32), + "float_list": tf.io.FixedLenFeature([7], tf.float32), + } + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + with pytest.raises(ValueError): + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"]) + tfrecord_transformer.transform() + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + +def test_tfrecord_to_mindrecord_list_without_bytes_type(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), + "image_bytes": tf.io.FixedLenFeature([], tf.string), + "int64_scalar": tf.io.FixedLenFeature([1], tf.int64), + "float_scalar": tf.io.FixedLenFeature([1], tf.float32), + "int64_list": tf.io.FixedLenFeature([6], tf.int64), + "float_list": tf.io.FixedLenFeature([7], tf.float32), + } + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict) + tfrecord_transformer.transform() + + assert os.path.exists(MINDRECORD_FILE_NAME) + assert os.path.exists(MINDRECORD_FILE_NAME + ".db") + + fr_mindrecord = FileReader(MINDRECORD_FILE_NAME) + verify_data(tfrecord_transformer, fr_mindrecord) + + os.remove(MINDRECORD_FILE_NAME) + os.remove(MINDRECORD_FILE_NAME + ".db") + + os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + +def test_tfrecord_to_mindrecord_scalar_with_2_exception(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), + "image_bytes": tf.io.FixedLenFeature([], tf.string), + "int64_scalar": tf.io.FixedLenFeature([2], tf.int64), + "float_scalar": tf.io.FixedLenFeature([1], tf.float32), + "int64_list": tf.io.FixedLenFeature([6], tf.int64), + "float_list": tf.io.FixedLenFeature([7], tf.float32), + } + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"]) + with pytest.raises(ValueError): + tfrecord_transformer.transform() + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + +def test_tfrecord_to_mindrecord_scalar_string_with_1_exception(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"file_name": tf.io.FixedLenFeature([1], tf.string), + "image_bytes": tf.io.FixedLenFeature([], tf.string), + "int64_scalar": tf.io.FixedLenFeature([1], tf.int64), + "float_scalar": tf.io.FixedLenFeature([1], tf.float32), + "int64_list": tf.io.FixedLenFeature([6], tf.int64), + "float_list": tf.io.FixedLenFeature([7], tf.float32), + } + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + with pytest.raises(ValueError): + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"]) + tfrecord_transformer.transform() + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + +def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), + "image_bytes": tf.io.FixedLenFeature([10], tf.string), + "int64_scalar": tf.io.FixedLenFeature([1], tf.int64), + "float_scalar": tf.io.FixedLenFeature([1], tf.float32), + "int64_list": tf.io.FixedLenFeature([6], tf.int64), + "float_list": tf.io.FixedLenFeature([7], tf.float32), + } + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + with pytest.raises(ValueError): + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"]) + tfrecord_transformer.transform() + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))