提交 e102e5dd 编写于 作者: M minqiyang

Fix CIFAR-10 decompression problem

上级 3ec6d60c
...@@ -32,7 +32,7 @@ import itertools ...@@ -32,7 +32,7 @@ import itertools
import numpy import numpy
import paddle.dataset.common import paddle.dataset.common
import tarfile import tarfile
from six.moves import zip import six
from six.moves import cPickle as pickle from six.moves import cPickle as pickle
__all__ = ['train100', 'test100', 'train10', 'test10', 'convert'] __all__ = ['train100', 'test100', 'train10', 'test10', 'convert']
...@@ -46,25 +46,22 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' ...@@ -46,25 +46,22 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def reader_creator(filename, sub_name, cycle=False): def reader_creator(filename, sub_name, cycle=False):
def read_batch(batch): def read_batch(batch):
data = batch['data'] data = batch[six.b('data')]
labels = batch.get('labels', batch.get('fine_labels', None)) labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None assert labels is not None
for sample, label in zip(data, labels): for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label) yield (sample / 255.0).astype(numpy.float32), int(label)
def reader(): def reader():
with tarfile.open(filename, mode='r') as f: with tarfile.open(filename, mode='r') as f:
names = (each_item.name for each_item in f names = [each_item.name for each_item in f if sub_name in each_item.name]
if sub_name in each_item.name)
while True: while True:
for name in names: for name in names:
import sys if six.PY2:
print(name)
sys.stdout.flush()
print(f.extractfile(name))
sys.stdout.flush()
batch = pickle.load(f.extractfile(name)) batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(f.extractfile(name), encoding='bytes')
for item in read_batch(batch): for item in read_batch(batch):
yield item yield item
if not cycle: if not cycle:
......
...@@ -87,20 +87,14 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -87,20 +87,14 @@ def download(url, module_name, md5sum, save_name=None):
if total_length is None: if total_length is None:
with open(filename, 'wb') as f: with open(filename, 'wb') as f:
import sys shutil.copyfileobj(r.raw, f)
print("write follow block")
sys.stdout.flush()
shutil.copyfileobj(cpt.to_bytes(r.raw), f)
else: else:
with open(filename, 'wb') as f: with open(filename, 'wb') as f:
import sys
print("write follow length")
sys.stdout.flush()
dl = 0 dl = 0
total_length = int(total_length) total_length = int(total_length)
for data in r.iter_content(chunk_size=4096): for data in r.iter_content(chunk_size=4096):
dl += len(data) dl += len(data)
f.write(cpt.to_bytes(data)) f.write(data)
done = int(50 * dl / total_length) done = int(50 * dl / total_length)
sys.stdout.write("\r[%s%s]" % ('=' * done, sys.stdout.write("\r[%s%s]" % ('=' * done,
' ' * (50 - done))) ' ' * (50 - done)))
......
...@@ -27,7 +27,7 @@ def to_literal_str(obj): ...@@ -27,7 +27,7 @@ def to_literal_str(obj):
def _to_literal_str(obj): def _to_literal_str(obj):
if isinstance(obj, six.binary_type): if isinstance(obj, six.binary_type):
return obj.decode('latin-1') return obj.decode('utf-8')
elif isinstance(obj, six.text_type): elif isinstance(obj, six.text_type):
return obj return obj
else: else:
...@@ -45,7 +45,7 @@ def to_bytes(obj): ...@@ -45,7 +45,7 @@ def to_bytes(obj):
def _to_bytes(obj): def _to_bytes(obj):
if isinstance(obj, six.text_type): if isinstance(obj, six.text_type):
return obj.encode('latin-1') return obj.encode('utf-8')
elif isinstance(obj, six.binary_type): elif isinstance(obj, six.binary_type):
return obj return obj
else: else:
......
...@@ -32,8 +32,8 @@ import itertools ...@@ -32,8 +32,8 @@ import itertools
import numpy import numpy
import paddle.dataset.common import paddle.dataset.common
import tarfile import tarfile
import six
from six.moves import cPickle as pickle from six.moves import cPickle as pickle
from six.moves import zip
__all__ = ['train10'] __all__ = ['train10']
...@@ -44,20 +44,23 @@ CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' ...@@ -44,20 +44,23 @@ CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
def reader_creator(filename, sub_name, batch_size=None): def reader_creator(filename, sub_name, batch_size=None):
def read_batch(batch): def read_batch(batch):
data = batch['data'] data = batch[six.b('data')]
labels = batch.get('labels', batch.get('fine_labels', None)) labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None assert labels is not None
for sample, label in zip(data, labels): for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label) yield (sample / 255.0).astype(numpy.float32), int(label)
def reader(): def reader():
with tarfile.open(filename, mode='r') as f: with tarfile.open(filename, mode='r') as f:
names = (each_item.name for each_item in f names = [each_item.name for each_item in f
if sub_name in each_item.name) if sub_name in each_item.name]
batch_count = 0 batch_count = 0
for name in names: for name in names:
if six.PY2:
batch = pickle.load(f.extractfile(name)) batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(f.extractfile(name), encoding='bytes')
for item in read_batch(batch): for item in read_batch(batch):
if isinstance(batch_size, int) and batch_count > batch_size: if isinstance(batch_size, int) and batch_count > batch_size:
break break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册