提交 e102e5dd 编写于 作者: M minqiyang

Fix cifar10 decompress problem

上级 3ec6d60c
......@@ -32,7 +32,7 @@ import itertools
import numpy
import paddle.dataset.common
import tarfile
from six.moves import zip
import six
from six.moves import cPickle as pickle
__all__ = ['train100', 'test100', 'train10', 'test10', 'convert']
......@@ -46,25 +46,22 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def reader_creator(filename, sub_name, cycle=False):
def read_batch(batch):
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
data = batch[six.b('data')]
labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in zip(data, labels):
for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label)
def reader():
with tarfile.open(filename, mode='r') as f:
names = (each_item.name for each_item in f
if sub_name in each_item.name)
names = [each_item.name for each_item in f if sub_name in each_item.name]
while True:
for name in names:
import sys
print(name)
sys.stdout.flush()
print(f.extractfile(name))
sys.stdout.flush()
batch = pickle.load(f.extractfile(name))
if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(f.extractfile(name), encoding='bytes')
for item in read_batch(batch):
yield item
if not cycle:
......
......@@ -87,20 +87,14 @@ def download(url, module_name, md5sum, save_name=None):
if total_length is None:
with open(filename, 'wb') as f:
import sys
print("write follow block")
sys.stdout.flush()
shutil.copyfileobj(cpt.to_bytes(r.raw), f)
shutil.copyfileobj(r.raw, f)
else:
with open(filename, 'wb') as f:
import sys
print("write follow length")
sys.stdout.flush()
dl = 0
total_length = int(total_length)
for data in r.iter_content(chunk_size=4096):
dl += len(data)
f.write(cpt.to_bytes(data))
f.write(data)
done = int(50 * dl / total_length)
sys.stdout.write("\r[%s%s]" % ('=' * done,
' ' * (50 - done)))
......
......@@ -27,7 +27,7 @@ def to_literal_str(obj):
def _to_literal_str(obj):
if isinstance(obj, six.binary_type):
return obj.decode('latin-1')
return obj.decode('utf-8')
elif isinstance(obj, six.text_type):
return obj
else:
......@@ -45,7 +45,7 @@ def to_bytes(obj):
def _to_bytes(obj):
if isinstance(obj, six.text_type):
return obj.encode('latin-1')
return obj.encode('utf-8')
elif isinstance(obj, six.binary_type):
return obj
else:
......
......@@ -32,8 +32,8 @@ import itertools
import numpy
import paddle.dataset.common
import tarfile
import six
from six.moves import cPickle as pickle
from six.moves import zip
__all__ = ['train10']
......@@ -44,20 +44,23 @@ CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
def reader_creator(filename, sub_name, batch_size=None):
def read_batch(batch):
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
data = batch[six.b('data')]
labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in zip(data, labels):
for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label)
def reader():
with tarfile.open(filename, mode='r') as f:
names = (each_item.name for each_item in f
if sub_name in each_item.name)
names = [each_item.name for each_item in f
if sub_name in each_item.name]
batch_count = 0
for name in names:
batch = pickle.load(f.extractfile(name))
if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(f.extractfile(name), encoding='bytes')
for item in read_batch(batch):
if isinstance(batch_size, int) and batch_count > batch_size:
break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册