提交 2fa2f866 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2135 add tfrecord to mindrecord tool

Merge pull request !2135 from guozhijian/add_tools_tfrecord_to_mindrecord
......@@ -31,7 +31,8 @@ from .tools.cifar10_to_mr import Cifar10ToMR
from .tools.cifar100_to_mr import Cifar100ToMR
from .tools.imagenet_to_mr import ImageNetToMR
from .tools.mnist_to_mr import MnistToMR
from .tools.tfrecord_to_mr import TFRecordToMR
__all__ = ['FileWriter', 'FileReader', 'MindPage',
'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR',
'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR', 'TFRecordToMR',
'SUCCESS', 'FAILED']
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TFRecord convert tool for MindRecord
"""
from importlib import import_module
from string import punctuation
import numpy as np
from mindspore import log as logger
from ..filewriter import FileWriter
from ..shardutils import check_filename
try:
tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
except ModuleNotFoundError:
tf = None
__all__ = ['TFRecordToMR']
SupportedTensorFlowVersion = '2.1.0'
def _cast_type(value):
"""
Cast complex data type to basic datatype for MindRecord to recognize.
Args:
value: the TFRecord data type
Returns:
str, which is MindRecord field type.
"""
tf_type_to_mr_type = {tf.string: "string",
tf.int8: "int32",
tf.int16: "int32",
tf.int32: "int32",
tf.int64: "int64",
tf.uint8: "int32",
tf.uint16: "int32",
tf.uint32: "int64",
tf.uint64: "int64",
tf.float16: "float32",
tf.float32: "float32",
tf.float64: "float64",
tf.double: "float64",
tf.bool: "int32"}
unsupport_tf_type_to_mr_type = {tf.complex64: "None",
tf.complex128: "None"}
if value in tf_type_to_mr_type:
return tf_type_to_mr_type[value]
raise ValueError("Type " + value + " is not supported in MindRecord.")
def _cast_string_type_to_np_type(value):
"""Cast string type like: int32/int64/float32/float64 to np.int32/np.int64/np.float32/np.float64"""
string_type_to_np_type = {"int32": np.int32,
"int64": np.int64,
"float32": np.float32,
"float64": np.float64}
if value in string_type_to_np_type:
return string_type_to_np_type[value]
raise ValueError("Type " + value + " is not supported cast to numpy type in MindRecord.")
def _cast_name(key):
"""
Cast schema names which containing special characters to valid names.
Here special characters means any characters in
'!"#$%&\'()*+,./:;<=>?@[\\]^`{|}~
Valid names can only contain a-z, A-Z, and 0-9 and _
Args:
key (str): original key that might contains special characters.
Returns:
str, casted key that replace the special characters with "_". i.e. if
key is "a b" then returns "a_b".
"""
special_symbols = set('{}{}'.format(punctuation, ' '))
special_symbols.remove('_')
new_key = ['_' if x in special_symbols else x for x in key]
casted_key = ''.join(new_key)
return casted_key
class TFRecordToMR:
"""
Class is for tranformation from TFRecord to MindRecord.
Args:
source (str): the TFRecord file to be transformed.
destination (str): the MindRecord file path to tranform into.
feature_dict (dict): a dictionary than states the feature type, i.e.
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string),
"yyyy": tf.io.FixedLenFeature([], tf.int64)}
****** follow case which uses VarLenFeature not support ******
feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string),
"yyyy": tf.io.VarLenFeature(tf.int64)},
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
bytes_fields (list): the bytes fields which are in feature_dict.
Rasies:
ValueError, when:
1) parameter TFRecord is not string.
2) parameter MindRecord is not string.
3) feature_dict is not FixedLenFeature.
4) parameter bytes_field is not list(str) or not in feature_dict
Exception, when tensorflow module not found or version is not correct.
"""
def __init__(self, source, destination, feature_dict, bytes_fields=None):
if not tf:
raise Exception("Module tensorflow is not found, please use pip install it.")
if tf.__version__ < SupportedTensorFlowVersion:
raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion))
if not isinstance(source, str):
raise ValueError("Parameter source must be string.")
check_filename(source)
if not isinstance(destination, str):
raise ValueError("Parameter destination must be string.")
check_filename(destination)
self.source = source
self.destination = destination
if feature_dict is None or not isinstance(feature_dict, dict):
raise ValueError("Parameter feature_dict is None or not dict.")
for key, val in feature_dict.items():
if not isinstance(val, tf.io.FixedLenFeature):
raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict))
self.feature_dict = feature_dict
bytes_fields_list = []
if bytes_fields:
if not isinstance(bytes_fields, list):
raise ValueError("Parameter bytes_fields: {} must be list(str).".format(bytes_fields))
for item in bytes_fields:
if not isinstance(item, str):
raise ValueError("Parameter bytes_fields's item: {} is not str.".format(item))
if item not in self.feature_dict:
raise ValueError("Parameter bytes_fields's item: {} is not in feature_dict: {}."
.format(item, self.feature_dict))
if not isinstance(self.feature_dict[item].shape, list):
raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item))
casted_bytes_field = _cast_name(item)
bytes_fields_list.append(casted_bytes_field)
self.bytes_fields_list = bytes_fields_list
self.scalar_set = set()
self.list_set = set()
mindrecord_schema = {}
for key, val in self.feature_dict.items():
if not val.shape:
self.scalar_set.add(_cast_name(key))
if key in self.bytes_fields_list:
mindrecord_schema[_cast_name(key)] = {"type": "bytes"}
else:
mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype)}
else:
if len(val.shape) != 1:
raise ValueError("Parameter len(feature_dict[{}].shape) should be 1.")
if val.shape[0] < 1:
raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key))
if val.dtype == tf.string:
raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] \
is not None. It is not supported.".format(key))
self.list_set.add(_cast_name(key))
mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype), "shape": [val.shape[0]]}
self.mindrecord_schema = mindrecord_schema
def _parse_record(self, example):
"""Returns features for a single example"""
features = tf.io.parse_single_example(example, features=self.feature_dict)
return features
def _get_data_when_scalar_field(self, ms_dict, cast_key, key, val):
"""put data in ms_dict when field type is string"""
if isinstance(val.numpy(), (np.ndarray, list)):
raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))
if self.feature_dict[key].dtype == tf.string:
if cast_key in self.bytes_fields_list:
ms_dict[cast_key] = val.numpy()
else:
ms_dict[cast_key] = str(val.numpy(), encoding="utf-8")
elif _cast_type(self.feature_dict[key].dtype).startswith("int"):
ms_dict[cast_key] = int(val.numpy())
else:
ms_dict[cast_key] = float(val.numpy())
def tfrecord_iterator(self):
"""Yield a dict with key to be fields in schema, and value to be data."""
dataset = tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record)
iterator = dataset.__iter__()
index_id = 0
try:
for features in iterator:
ms_dict = {}
index_id = index_id + 1
for key, val in features.items():
cast_key = _cast_name(key)
if key in self.scalar_set:
self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
else:
if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list):
raise ValueError("he response key: {}, value: {} from TFRecord should be a ndarray or list."
.format(key, val))
# list set
ms_dict[cast_key] = \
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
yield ms_dict
except tf.errors.InvalidArgumentError:
raise ValueError("TFRecord feature_dict parameter error.")
def transform(self):
"""
Executes transform from TFRecord to MindRecord.
Returns:
SUCCESS/FAILED, whether successfuly written into MindRecord.
"""
writer = FileWriter(self.destination)
logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"
.format(self.mindrecord_schema, self.feature_dict))
writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord")
tf_iter = self.tfrecord_iterator()
batch_size = 256
transform_count = 0
while True:
data_list = []
try:
for _ in range(batch_size):
data_list.append(tf_iter.__next__())
transform_count += 1
writer.write_raw_data(data_list)
logger.info("Transformed {} records...".format(transform_count))
except StopIteration:
if data_list:
writer.write_raw_data(data_list)
logger.info("Transformed {} records...".format(transform_count))
break
return writer.commit()
# Copyright 2020 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test tfrecord to mindrecord tool"""
import collections
from importlib import import_module
import os
import numpy as np
import pytest
from mindspore import log as logger
from mindspore.mindrecord import FileReader
from mindspore.mindrecord import TFRecordToMR
SupportedTensorFlowVersion = '2.1.0'
try:
tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
except ModuleNotFoundError:
logger.warning("tensorflow module not found.")
tf = None
TFRECORD_DATA_DIR = "../data/mindrecord/testTFRecordData"
TFRECORD_FILE_NAME = "test.tfrecord"
MINDRECORD_FILE_NAME = "test.mindrecord"
PARTITION_NUM = 1
def verify_data(transformer, reader):
"""Verify the data by read from mindrecord"""
tf_iter = transformer.tfrecord_iterator()
mr_iter = reader.get_next()
count = 0
for tf_item, mr_item in zip(tf_iter, mr_iter):
count = count + 1
assert len(tf_item) == 6
assert len(mr_item) == 6
for key, value in tf_item.items():
logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value, mr_item[key]))
if isinstance(value, np.ndarray):
assert (value == mr_item[key]).all()
else:
assert value == mr_item[key]
assert count == 10
def generate_tfrecord():
def create_int_feature(values):
if isinstance(values, list):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) # values: [int, int, int]
else:
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values])) # values: int
return feature
def create_float_feature(values):
if isinstance(values, list):
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) # values: [float, float]
else:
feature = tf.train.Feature(float_list=tf.train.FloatList(value=[values])) # values: float
return feature
def create_bytes_feature(values):
if isinstance(values, bytes):
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) # values: bytes
else:
# values: string
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))
return feature
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
example_count = 0
for i in range(10):
file_name = "000" + str(i) + ".jpg"
image_bytes = bytes(str("aaaabbbbcccc" + str(i)), encoding="utf-8")
int64_scalar = i
float_scalar = float(i)
int64_list = [i, i+1, i+2, i+3, i+4, i+1234567890]
float_list = [float(i), float(i+1), float(i+2.8), float(i+3.2),
float(i+4.4), float(i+123456.9), float(i+98765432.1)]
features = collections.OrderedDict()
features["file_name"] = create_bytes_feature(file_name)
features["image_bytes"] = create_bytes_feature(image_bytes)
features["int64_scalar"] = create_int_feature(int64_scalar)
features["float_scalar"] = create_float_feature(float_scalar)
features["int64_list"] = create_int_feature(int64_list)
features["float_list"] = create_float_feature(float_list)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
example_count += 1
writer.close()
logger.info("Write {} rows in tfrecord.".format(example_count))
def test_tfrecord_to_mindrecord():
"""test transform tfrecord to mindrecord."""
if not tf or tf.__version__ < SupportedTensorFlowVersion:
# skip the test
logger.warning("Module tensorflow is not found or version wrong, \
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
"int64_scalar": tf.io.FixedLenFeature([], tf.int64),
"float_scalar": tf.io.FixedLenFeature([], tf.float32),
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
assert os.path.exists(MINDRECORD_FILE_NAME)
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
verify_data(tfrecord_transformer, fr_mindrecord)
os.remove(MINDRECORD_FILE_NAME)
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
def test_tfrecord_to_mindrecord_scalar_with_1():
"""test transform tfrecord to mindrecord."""
if not tf or tf.__version__ < SupportedTensorFlowVersion:
# skip the test
logger.warning("Module tensorflow is not found or version wrong, \
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
assert os.path.exists(MINDRECORD_FILE_NAME)
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
verify_data(tfrecord_transformer, fr_mindrecord)
os.remove(MINDRECORD_FILE_NAME)
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception():
"""test transform tfrecord to mindrecord."""
if not tf or tf.__version__ < SupportedTensorFlowVersion:
# skip the test
logger.warning("Module tensorflow is not found or version wrong, \
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
"float_list": tf.io.FixedLenFeature([2], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
with pytest.raises(ValueError):
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
def test_tfrecord_to_mindrecord_list_with_diff_type_exception():
"""test transform tfrecord to mindrecord."""
if not tf or tf.__version__ < SupportedTensorFlowVersion:
# skip the test
logger.warning("Module tensorflow is not found or version wrong, \
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
"int64_list": tf.io.FixedLenFeature([6], tf.float32),
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
with pytest.raises(ValueError):
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
def test_tfrecord_to_mindrecord_list_without_bytes_type():
"""test transform tfrecord to mindrecord."""
if not tf or tf.__version__ < SupportedTensorFlowVersion:
# skip the test
logger.warning("Module tensorflow is not found or version wrong, \
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict)
tfrecord_transformer.transform()
assert os.path.exists(MINDRECORD_FILE_NAME)
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
verify_data(tfrecord_transformer, fr_mindrecord)
os.remove(MINDRECORD_FILE_NAME)
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
def test_tfrecord_to_mindrecord_scalar_with_2_exception():
"""test transform tfrecord to mindrecord."""
if not tf or tf.__version__ < SupportedTensorFlowVersion:
# skip the test
logger.warning("Module tensorflow is not found or version wrong, \
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
"int64_scalar": tf.io.FixedLenFeature([2], tf.int64),
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
with pytest.raises(ValueError):
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
def test_tfrecord_to_mindrecord_scalar_string_with_1_exception():
"""test transform tfrecord to mindrecord."""
if not tf or tf.__version__ < SupportedTensorFlowVersion:
# skip the test
logger.warning("Module tensorflow is not found or version wrong, \
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
feature_dict = {"file_name": tf.io.FixedLenFeature([1], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
with pytest.raises(ValueError):
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
"""test transform tfrecord to mindrecord."""
if not tf or tf.__version__ < SupportedTensorFlowVersion:
# skip the test
logger.warning("Module tensorflow is not found or version wrong, \
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([10], tf.string),
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
with pytest.raises(ValueError):
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册