提交 819b102e 编写于 作者: J jonyguo

add performance test for mindrecord

上级 420ef2a3
......@@ -118,5 +118,8 @@ def mindrecord_dict_data(task_id):
image_file = open(file_name, "rb")
image_bytes = image_file.read()
image_file.close()
if not image_bytes:
print("The image file: {} is invalid.".format(file_name))
continue
data["data"] = image_bytes
yield data
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""use ImageNetToMR tool generate mindrecord"""
import os
from mindspore.mindrecord import ImageNetToMR
IMAGENET_MAP_FILE = "../../../ut/data/mindrecord/testImageNetDataWhole/labels_map.txt"
IMAGENET_IMAGE_DIR = "../../../ut/data/mindrecord/testImageNetDataWhole/images"
MINDRECORD_FILE = "./imagenet.mindrecord"
PARTITION_NUMBER = 16
def imagenet_to_mindrecord():
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE,
IMAGENET_IMAGE_DIR,
MINDRECORD_FILE,
PARTITION_NUMBER)
imagenet_transformer.transform()
if __name__ == '__main__':
imagenet_to_mindrecord()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate tfrecord"""
import collections
import os
import tensorflow as tf
IMAGENET_MAP_FILE = "../../../ut/data/mindrecord/testImageNetDataWhole/labels_map.txt"
IMAGENET_IMAGE_DIR = "../../../ut/data/mindrecord/testImageNetDataWhole/images"
TFRECORD_FILE = "./imagenet.tfrecord"
PARTITION_NUMBER = 16
def get_imagenet_filename_label_pic(map_file, image_dir):
"""
Get data from imagenet.
Yields:
filename, label, image_bytes
"""
if not os.path.exists(map_file):
raise IOError("map file {} not exists".format(map_file))
label_dict = {}
with open(map_file) as fp:
line = fp.readline()
while line:
labels = line.split(" ")
label_dict[labels[1]] = labels[0]
line = fp.readline()
# get all the dir which are n02087046, n02094114, n02109525
dir_paths = {}
for item in label_dict:
real_path = os.path.join(image_dir, label_dict[item])
if not os.path.isdir(real_path):
print("{} dir is not exist".format(real_path))
continue
dir_paths[item] = real_path
if not dir_paths:
raise PathNotExistsError("not valid image dir in {}".format(image_dir))
# get the filename, label and image binary as a dict
for label in dir_paths:
for item in os.listdir(dir_paths[label]):
file_name = os.path.join(dir_paths[label], item)
if not item.endswith("JPEG") and not item.endswith("jpg"):
print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name))
continue
# get the image data
image_file = open(file_name, "rb")
image_bytes = image_file.read()
image_file.close()
if not image_bytes:
print("The image file: {} is invalid.".format(file_name))
continue
yield str(file_name), int(label), image_bytes
def create_int_feature(values):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values]))
return feature
def create_string_feature(values):
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))
return feature
def create_bytes_feature(values):
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
return feature
def imagenet_to_tfrecord():
writers = []
for i in range(PARTITION_NUMBER):
output_file = TFRECORD_FILE + str(i).rjust(2, '0')
writers.append(tf.io.TFRecordWriter(output_file))
writer_index = 0
total_written = 0
for file_name, label, image_bytes in get_imagenet_filename_label_pic(IMAGENET_MAP_FILE,
IMAGENET_IMAGE_DIR):
features = collections.OrderedDict()
features["file_name"] = create_string_feature(file_name)
features["label"] = create_int_feature(label)
features["data"] = create_bytes_feature(image_bytes)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writers[writer_index].write(tf_example.SerializeToString())
writer_index = (writer_index + 1) % len(writers)
total_written += 1
for writer in writers:
writer.close()
print("Write {} total examples".format(total_written))
if __name__ == '__main__':
imagenet_to_tfrecord()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test dataset performance about mindspore.MindDataset, mindspore.TFRecordDataset, tf.data.TFRecordDataset"""
import time
import mindspore.dataset as ds
from mindspore.mindrecord import FileReader
import tensorflow as tf
print_step = 5000
def print_log(count):
if count % print_step == 0:
print("Read {} rows ...".format(count))
def use_filereader(mindrecord):
start = time.time()
columns_list = ["data", "label"]
reader = FileReader(file_name=mindrecord,
num_consumer=4,
columns=columns_list)
num_iter = 0
for index, item in enumerate(reader.get_next()):
num_iter += 1
print_log(num_iter)
end = time.time()
print("Read by FileReader - total rows: {}, cost time: {}s".format(num_iter, end - start))
def use_minddataset(mindrecord):
start = time.time()
columns_list = ["data", "label"]
data_set = ds.MindDataset(dataset_file=mindrecord,
columns_list=columns_list,
num_parallel_workers=4)
num_iter = 0
for item in data_set.create_dict_iterator():
num_iter += 1
print_log(num_iter)
end = time.time()
print("Read by MindDataset - total rows: {}, cost time: {}s".format(num_iter, end - start))
def use_tfrecorddataset(tfrecord):
start = time.time()
columns_list = ["data", "label"]
data_set = ds.TFRecordDataset(dataset_files=tfrecord,
columns_list=columns_list,
num_parallel_workers=4,
shuffle=ds.Shuffle.GLOBAL)
data_set = data_set.shuffle(10000)
num_iter = 0
for item in data_set.create_dict_iterator():
num_iter += 1
print_log(num_iter)
end = time.time()
print("Read by TFRecordDataset - total rows: {}, cost time: {}s".format(num_iter, end - start))
def use_tensorflow_tfrecorddataset(tfrecord):
start = time.time()
def _parse_record(example_photo):
features = {
'file_name': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([1], tf.int64),
'data': tf.io.FixedLenFeature([], tf.string)}
parsed_features = tf.io.parse_single_example(example_photo, features=features)
return parsed_features
data_set = tf.data.TFRecordDataset(filenames=tfrecord,
buffer_size=100000,
num_parallel_reads=4)
data_set = data_set.map(_parse_record, num_parallel_calls=4)
num_iter = 0
for item in data_set.__iter__():
num_iter += 1
print_log(num_iter)
end = time.time()
print("Read by TensorFlow TFRecordDataset - total rows: {}, cost time: {}s".format(num_iter, end - start))
if __name__ == '__main__':
# use MindDataset
mindrecord = './imagenet.mindrecord00'
use_minddataset(mindrecord)
# use TFRecordDataset
tfrecord = ['imagenet.tfrecord00', 'imagenet.tfrecord01', 'imagenet.tfrecord02', 'imagenet.tfrecord03',
'imagenet.tfrecord04', 'imagenet.tfrecord05', 'imagenet.tfrecord06', 'imagenet.tfrecord07',
'imagenet.tfrecord08', 'imagenet.tfrecord09', 'imagenet.tfrecord10', 'imagenet.tfrecord11',
'imagenet.tfrecord12', 'imagenet.tfrecord13', 'imagenet.tfrecord14', 'imagenet.tfrecord15']
use_tfrecorddataset(tfrecord)
# use TensorFlow TFRecordDataset
use_tensorflow_tfrecorddataset(tfrecord)
# use FileReader
# use_filereader(mindrecord)
{
"datasetType": "TF",
"numRows": 930059,
"columns": {
"file_name": {
"type": "uint8",
"rank": 0
},
"label": {
"type": "int64",
"rank": 0
},
"data": {
"type": "uint8",
"rank": 0
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册