diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 68660601c161d2332b17b448fae089506238ba78..94ca21de8e4b8cfca82f220866df705999eb42c0 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -17,6 +17,7 @@ import hashlib import os import errno import shutil +import six import sys import importlib import paddle.dataset @@ -93,6 +94,8 @@ def download(url, module_name, md5sum, save_name=None): dl = 0 total_length = int(total_length) for data in r.iter_content(chunk_size=4096): + if six.PY2: + data = six.b(data) dl += len(data) f.write(data) done = int(50 * dl / total_length) diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index 527044b415533cc640e3cfc5837c08ab0f8b74b1..e50d0d0ace9b1d790aa644aa9b7e2893f0359603 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -28,11 +28,12 @@ Graphics and Image Processing (2008) http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}. """ -import cPickle import itertools import functools from common import download import tarfile +import six +from six.moves import cPickle as pickle import scipy.io as scio from paddle.dataset.image import * from paddle.reader import * @@ -41,10 +42,10 @@ import numpy as np from multiprocessing import cpu_count __all__ = ['train', 'test', 'valid'] -DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' -LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat' -SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' -DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118' +DATA_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/102flowers.tgz' +LABEL_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/imagelabels.mat' +SETID_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/setid.mat' +DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' # In official 'readme', tstid is the flag of test data @@ -111,8 +112,11 @@ def reader_creator(data_file, for file in open(file_list): file = file.strip() batch = None - with open(file, 'r') as f: - batch = cPickle.load(f) + with open(file, 'rb') as f: + if six.PY2: + batch = pickle.load(f) + else: + batch = pickle.load(f, encoding='bytes') data = batch['data'] labels = batch['label'] for sample, label in itertools.izip(data, batch['label']):