提交 948ff63a 编写于 作者: Y yanzhenxiang2020 提交者: 高东海

fix mindrecord ut long time

上级 d75745bc
......@@ -77,20 +77,20 @@ class MnistToMR:
self.mnist_schema_json = {"label": {"type": "int64"}, "data": {"type": "bytes"}}
def _extract_images(self, filename, num_images):
def _extract_images(self, filename):
"""Extract the images into a 4D tensor [image index, y, x, channels]."""
with gzip.open(filename) as bytestream:
bytestream.read(16)
buf = bytestream.read(self.image_size * self.image_size * num_images * self.num_channels)
buf = bytestream.read()
data = np.frombuffer(buf, dtype=np.uint8)
data = data.reshape(num_images, self.image_size, self.image_size, self.num_channels)
data = data.reshape(-1, self.image_size, self.image_size, self.num_channels)
return data
def _extract_labels(self, filename, num_images):
def _extract_labels(self, filename):
"""Extract the labels into a vector of int64 label IDs."""
with gzip.open(filename) as bytestream:
bytestream.read(8)
buf = bytestream.read(1 * num_images)
buf = bytestream.read()
labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
return labels
......@@ -101,8 +101,8 @@ class MnistToMR:
Yields:
data (dict of list): mnist data list which contains dict.
"""
train_data = self._extract_images(self.train_data_filename_, 60000)
train_labels = self._extract_labels(self.train_labels_filename_, 60000)
train_data = self._extract_images(self.train_data_filename_)
train_labels = self._extract_labels(self.train_labels_filename_)
for data, label in zip(train_data, train_labels):
_, img = cv2.imencode(".jpeg", data)
yield {"label": int(label), "data": img.tobytes()}
......@@ -114,8 +114,8 @@ class MnistToMR:
Yields:
data (dict of list): mnist data list which contains dict.
"""
test_data = self._extract_images(self.test_data_filename_, 10000)
test_labels = self._extract_labels(self.test_labels_filename_, 10000)
test_data = self._extract_images(self.test_data_filename_)
test_labels = self._extract_labels(self.test_labels_filename_)
for data, label in zip(test_data, test_labels):
_, img = cv2.imencode(".jpeg", data)
yield {"label": int(label), "data": img.tobytes()}
......
......@@ -203,9 +203,9 @@ def test_nlp_page_reader_tutorial():
os.remove("{}".format(x))
os.remove("{}.db".format(x))
def test_cv_file_writer_shard_num_1000():
"""test file writer when shard num equals 1000."""
writer = FileWriter(CV_FILE_NAME, 1000)
def test_cv_file_writer_shard_num_10():
"""test file writer when shard num equals 10."""
writer = FileWriter(CV_FILE_NAME, 10)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
......@@ -214,8 +214,8 @@ def test_cv_file_writer_shard_num_1000():
writer.write_raw_data(data)
writer.commit()
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(3, '0'))
for x in range(1000)]
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(10)]
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
......
......@@ -37,7 +37,7 @@ def read(train_name, test_name):
count = count + 1
if count == 1:
logger.info("data: {}".format(x))
assert count == 60000
assert count == 20
reader.close()
count = 0
......@@ -47,7 +47,7 @@ def read(train_name, test_name):
count = count + 1
if count == 1:
logger.info("data: {}".format(x))
assert count == 10000
assert count == 10
reader.close()
......@@ -102,10 +102,10 @@ def test_mnist_to_mindrecord_compare_data():
't10k-images-idx3-ubyte.gz')
test_labels_filename_ = os.path.join(MNIST_DIR,
't10k-labels-idx1-ubyte.gz')
train_data = _extract_images(train_data_filename_, 60000)
train_labels = _extract_labels(train_labels_filename_, 60000)
test_data = _extract_images(test_data_filename_, 10000)
test_labels = _extract_labels(test_labels_filename_, 10000)
train_data = _extract_images(train_data_filename_, 20)
train_labels = _extract_labels(train_labels_filename_, 20)
test_data = _extract_images(test_data_filename_, 10)
test_labels = _extract_labels(test_labels_filename_, 10)
reader = FileReader(train_name)
for x, data, label in zip(reader.get_next(), train_data, train_labels):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册