diff --git a/mindspore/mindrecord/tools/tfrecord_to_mr.py b/mindspore/mindrecord/tools/tfrecord_to_mr.py index e8c52001fdddf2967e416349307468f9db19710c..c351f4307f4e2bb8b5390dfdb724a90ee8c78736 100644 --- a/mindspore/mindrecord/tools/tfrecord_to_mr.py +++ b/mindspore/mindrecord/tools/tfrecord_to_mr.py @@ -113,7 +113,7 @@ class TFRecordToMR: feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string), \ "yyyy": tf.io.VarLenFeature(tf.int64)}, \ "sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}} - bytes_fields (list): the bytes fields which are in feature_dict. + bytes_fields (list, optional): the bytes fields which are in feature_dict and can be images bytes. Raises: ValueError: If parameter is invalid. @@ -147,7 +147,7 @@ class TFRecordToMR: self.feature_dict = feature_dict bytes_fields_list = [] - if bytes_fields: + if bytes_fields is not None: if not isinstance(bytes_fields, list): raise ValueError("Parameter bytes_fields: {} must be list(str).".format(bytes_fields)) for item in bytes_fields: @@ -161,6 +161,9 @@ class TFRecordToMR: if not isinstance(self.feature_dict[item].shape, list): raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item)) + if self.feature_dict[item].dtype != tf.string: + raise ValueError("Parameter bytes_field: {} should be tf.string in feature_dict.".format(item)) + casted_bytes_field = _cast_name(item) bytes_fields_list.append(casted_bytes_field) @@ -172,7 +175,7 @@ class TFRecordToMR: for key, val in self.feature_dict.items(): if not val.shape: self.scalar_set.add(_cast_name(key)) - if key in self.bytes_fields_list: + if _cast_name(key) in self.bytes_fields_list: mindrecord_schema[_cast_name(key)] = {"type": "bytes"} else: mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype)} @@ -182,8 +185,8 @@ class TFRecordToMR: if val.shape[0] < 1: raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key)) if val.dtype == tf.string: - raise 
ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] \ - is not None. It is not supported.".format(key)) + raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] " \ + "is not None. It is not supported.".format(key)) self.list_set.add(_cast_name(key)) mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype), "shape": [val.shape[0]]} self.mindrecord_schema = mindrecord_schema @@ -219,12 +222,12 @@ class TFRecordToMR: index_id = index_id + 1 for key, val in features.items(): cast_key = _cast_name(key) - if key in self.scalar_set: + if cast_key in self.scalar_set: self._get_data_when_scalar_field(ms_dict, cast_key, key, val) else: if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list): - raise ValueError("he response key: {}, value: {} from TFRecord should be a ndarray or list." - .format(key, val)) + raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " \ + "list.".format(key, val)) # list set ms_dict[cast_key] = \ np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) diff --git a/tests/ut/python/mindrecord/test_tfrecord_to_mr.py b/tests/ut/python/mindrecord/test_tfrecord_to_mr.py index 87c002aa3b675c02715b1c61fecb47896095690b..cfd0d53a492339c9cc98784ebe9e9d20015dc8dc 100644 --- a/tests/ut/python/mindrecord/test_tfrecord_to_mr.py +++ b/tests/ut/python/mindrecord/test_tfrecord_to_mr.py @@ -15,6 +15,7 @@ import collections from importlib import import_module import os +from string import punctuation import numpy as np import pytest @@ -35,6 +36,27 @@ TFRECORD_FILE_NAME = "test.tfrecord" MINDRECORD_FILE_NAME = "test.mindrecord" PARTITION_NUM = 1 +def cast_name(key): + """ + Cast schema names containing special characters to valid names.
+ + Here special characters mean any characters in + '!"#$%&\'()*+,./:;<=>?@[\\]^`{|}~ + Valid names can only contain a-z, A-Z, and 0-9 and _ + + Args: + key (str): original key that might contain special characters. + + Returns: + str, casted key that replaces the special characters with "_". e.g. if + key is "a b" then returns "a_b". + """ + special_symbols = set('{}{}'.format(punctuation, ' ')) + special_symbols.remove('_') + new_key = ['_' if x in special_symbols else x for x in key] + casted_key = ''.join(new_key) + return casted_key + def verify_data(transformer, reader): """Verify the data by read from mindrecord""" tf_iter = transformer.tfrecord_iterator() @@ -43,14 +65,14 @@ def verify_data(transformer, reader): count = 0 for tf_item, mr_item in zip(tf_iter, mr_iter): count = count + 1 - assert len(tf_item) == 6 - assert len(mr_item) == 6 + assert len(tf_item) == len(mr_item) for key, value in tf_item.items(): - logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value, mr_item[key])) + logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value, + mr_item[cast_name(key)])) if isinstance(value, np.ndarray): - assert (value == mr_item[key]).all() + assert (value == mr_item[cast_name(key)]).all() else: - assert value == mr_item[key] + assert value == mr_item[cast_name(key)] assert count == 10 def generate_tfrecord(): @@ -102,6 +124,39 @@ def generate_tfrecord(): writer.close() logger.info("Write {} rows in tfrecord.".format(example_count)) +def generate_tfrecord_with_special_field_name(): + def create_int_feature(values): + if isinstance(values, list): + feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) # values: [int, int, int] + else: + feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values])) # values: int + return feature + + def create_bytes_feature(values): + if isinstance(values, bytes): + feature =
tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) # values: bytes + else: + # values: string + feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')])) + return feature + + writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + example_count = 0 + for i in range(10): + label = i + image_bytes = bytes(str("aaaabbbbcccc" + str(i)), encoding="utf-8") + + features = collections.OrderedDict() + features["image/class/label"] = create_int_feature(label) + features["image/encoded"] = create_bytes_feature(image_bytes) + + tf_example = tf.train.Example(features=tf.train.Features(feature=features)) + writer.write(tf_example.SerializeToString()) + example_count += 1 + writer.close() + logger.info("Write {} rows in tfrecord.".format(example_count)) + def test_tfrecord_to_mindrecord(): """test transform tfrecord to mindrecord.""" if not tf or tf.__version__ < SupportedTensorFlowVersion: @@ -398,3 +453,110 @@ def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception(): os.remove(MINDRECORD_FILE_NAME + ".db") os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + +def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), + "image_bytes": tf.io.FixedLenFeature([], tf.string), + "int64_scalar": tf.io.FixedLenFeature([], tf.int64), + "float_scalar": tf.io.FixedLenFeature([], tf.float32), + "int64_list": tf.io.FixedLenFeature([6], tf.int64), + "float_list": tf.io.FixedLenFeature([7], tf.float32), + } + + if 
os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + with pytest.raises(ValueError): + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict, ["int64_list"]) + tfrecord_transformer.transform() + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + +def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), + "image_bytes": tf.io.FixedLenFeature([], tf.string), + "int64_scalar": tf.io.FixedLenFeature([], tf.int64), + "float_scalar": tf.io.FixedLenFeature([], tf.float32), + "int64_list": tf.io.FixedLenFeature([6], tf.int64), + "float_list": tf.io.FixedLenFeature([7], tf.float32), + } + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + with pytest.raises(ValueError): + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict, "") + tfrecord_transformer.transform() + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + 
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + +def test_tfrecord_to_mindrecord_with_special_field_name(): + """test transform tfrecord to mindrecord.""" + if not tf or tf.__version__ < SupportedTensorFlowVersion: + # skip the test + logger.warning("Module tensorflow is not found or version wrong, \ + please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) + return + + generate_tfrecord_with_special_field_name() + assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) + + feature_dict = {"image/class/label": tf.io.FixedLenFeature([], tf.int64), + "image/encoded": tf.io.FixedLenFeature([], tf.string), + } + + if os.path.exists(MINDRECORD_FILE_NAME): + os.remove(MINDRECORD_FILE_NAME) + if os.path.exists(MINDRECORD_FILE_NAME + ".db"): + os.remove(MINDRECORD_FILE_NAME + ".db") + + tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), + MINDRECORD_FILE_NAME, feature_dict, ["image/encoded"]) + tfrecord_transformer.transform() + + assert os.path.exists(MINDRECORD_FILE_NAME) + assert os.path.exists(MINDRECORD_FILE_NAME + ".db") + + fr_mindrecord = FileReader(MINDRECORD_FILE_NAME) + verify_data(tfrecord_transformer, fr_mindrecord) + + os.remove(MINDRECORD_FILE_NAME) + os.remove(MINDRECORD_FILE_NAME + ".db") + + os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))