diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index 45a4c36f42ecd5929354b3a23933e08f5f80168b..8ca948b49bc4a74885e8cf3496d7b6c7c50b5865 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -114,35 +114,26 @@ def reader_creator(data_file, :return: data reader :rtype: callable ''' - scio = try_import('scipy.io') - - labels = scio.loadmat(label_file)['labels'][0] - indexes = scio.loadmat(setid_file)[dataset_name][0] - - img2label = {} - for i in indexes: - img = "jpg/image_%05d.jpg" % i - img2label[img] = labels[i - 1] - file_list = batch_images_from_tar(data_file, dataset_name, img2label) def reader(): - while True: - with open(file_list, 'r') as f_list: - for file in f_list: - file = file.strip() - batch = None - with open(file, 'rb') as f: - batch = pickle.load(f, encoding='bytes') - - if six.PY3: - batch = cpt.to_text(batch) - data_batch = batch['data'] - labels_batch = batch['label'] - for sample, label in six.moves.zip(data_batch, - labels_batch): - yield sample, int(label) - 1 - if not cycle: - break + scio = try_import('scipy.io') + + labels = scio.loadmat(label_file)['labels'][0] + indexes = scio.loadmat(setid_file)[dataset_name][0] + + img2label = {} + for i in indexes: + img = "jpg/image_%05d.jpg" % i + img2label[img] = labels[i - 1] + + tf = tarfile.open(data_file) + mems = tf.getmembers() + file_id = 0 + for mem in mems: + if mem.name in img2label: + image = tf.extractfile(mem).read() + label = img2label[mem.name] + yield image, int(label) - 1 if use_xmap: return xmap_readers(mapper, reader, min(4, cpu_count()), buffered_size)