diff --git a/python/paddle/dataset/cifar.py b/python/paddle/dataset/cifar.py index 0e5bbfc45a8eeb40692abab456e9d961f06e5815..0d07462e684b094f35dc373fb144730fc9cc54ec 100644 --- a/python/paddle/dataset/cifar.py +++ b/python/paddle/dataset/cifar.py @@ -32,7 +32,7 @@ import itertools import numpy import paddle.dataset.common import tarfile -from six.moves import zip +import six from six.moves import cPickle as pickle __all__ = ['train100', 'test100', 'train10', 'test10', 'convert'] @@ -46,25 +46,22 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' def reader_creator(filename, sub_name, cycle=False): def read_batch(batch): - data = batch['data'] - labels = batch.get('labels', batch.get('fine_labels', None)) + data = batch[six.b('data')] + labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None)) assert labels is not None - for sample, label in zip(data, labels): + for sample, label in six.moves.zip(data, labels): yield (sample / 255.0).astype(numpy.float32), int(label) def reader(): with tarfile.open(filename, mode='r') as f: - names = (each_item.name for each_item in f - if sub_name in each_item.name) + names = [each_item.name for each_item in f if sub_name in each_item.name] while True: for name in names: - import sys - print(name) - sys.stdout.flush() - print(f.extractfile(name)) - sys.stdout.flush() - batch = pickle.load(f.extractfile(name)) + if six.PY2: + batch = pickle.load(f.extractfile(name)) + else: + batch = pickle.load(f.extractfile(name), encoding='bytes') for item in read_batch(batch): yield item if not cycle: diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 8abb4d2790da2bae4a800ffb2f75c4748136e5bf..07e6b199c003fbd89b3f6543e0a5a2d808f9fb75 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -87,20 +87,14 @@ def download(url, module_name, md5sum, save_name=None): if total_length is None: with open(filename, 'wb') as f: - import sys - print("write follow block") - sys.stdout.flush() - shutil.copyfileobj(cpt.to_bytes(r.raw), f) + shutil.copyfileobj(r.raw, f) else: with open(filename, 'wb') as f: - import sys - print("write follow length") - sys.stdout.flush() dl = 0 total_length = int(total_length) for data in r.iter_content(chunk_size=4096): dl += len(data) - f.write(cpt.to_bytes(data)) + f.write(data) done = int(50 * dl / total_length) sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done))) diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py index 05633583ccb60af22fdb9b991c352a06475d3771..fe23a5929ad42abb7606f3e953e1dd4c389c0485 100644 --- a/python/paddle/fluid/compat.py +++ b/python/paddle/fluid/compat.py @@ -27,7 +27,7 @@ def to_literal_str(obj): def _to_literal_str(obj): if isinstance(obj, six.binary_type): - return obj.decode('latin-1') + return obj.decode('utf-8') elif isinstance(obj, six.text_type): return obj else: @@ -45,7 +45,7 @@ def to_bytes(obj): def _to_bytes(obj): if isinstance(obj, six.text_type): - return obj.encode('latin-1') + return obj.encode('utf-8') elif isinstance(obj, six.binary_type): return obj else: diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py index e7b709f31be4df2417b0d429e43e77a68f00d63d..9afac4143e3fd8d5fdc2fe7bafbc36f58fbb470d 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/cifar10_small_test_set.py @@ -32,8 +32,8 @@ import itertools import numpy import paddle.dataset.common import tarfile +import six from six.moves import cPickle as pickle -from six.moves import zip __all__ = ['train10'] @@ -44,20 +44,23 @@ CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' def reader_creator(filename, sub_name, batch_size=None): def read_batch(batch): - data = batch['data'] - labels = batch.get('labels', batch.get('fine_labels', None)) + data = batch[six.b('data')] + labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None)) assert labels is not None - for sample, label in zip(data, labels): + for sample, label in six.moves.zip(data, labels): yield (sample / 255.0).astype(numpy.float32), int(label) def reader(): with tarfile.open(filename, mode='r') as f: - names = (each_item.name for each_item in f - if sub_name in each_item.name) + names = [each_item.name for each_item in f + if sub_name in each_item.name] batch_count = 0 for name in names: - batch = pickle.load(f.extractfile(name)) + if six.PY2: + batch = pickle.load(f.extractfile(name)) + else: + batch = pickle.load(f.extractfile(name), encoding='bytes') for item in read_batch(batch): if isinstance(batch_size, int) and batch_count > batch_size: break