From db7b3d1c969a386ab3fb242119224d92f4e21614 Mon Sep 17 00:00:00 2001 From: zhouzj <41366441+zzjjay@users.noreply.github.com> Date: Thu, 24 Jun 2021 18:44:16 +0800 Subject: [PATCH] update_dataset_flowers (#33738) --- python/paddle/dataset/flowers.py | 45 +++++++++++++------------------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index 45a4c36f42e..8ca948b49bc 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -114,35 +114,26 @@ def reader_creator(data_file, :return: data reader :rtype: callable ''' - scio = try_import('scipy.io') - - labels = scio.loadmat(label_file)['labels'][0] - indexes = scio.loadmat(setid_file)[dataset_name][0] - - img2label = {} - for i in indexes: - img = "jpg/image_%05d.jpg" % i - img2label[img] = labels[i - 1] - file_list = batch_images_from_tar(data_file, dataset_name, img2label) def reader(): - while True: - with open(file_list, 'r') as f_list: - for file in f_list: - file = file.strip() - batch = None - with open(file, 'rb') as f: - batch = pickle.load(f, encoding='bytes') - - if six.PY3: - batch = cpt.to_text(batch) - data_batch = batch['data'] - labels_batch = batch['label'] - for sample, label in six.moves.zip(data_batch, - labels_batch): - yield sample, int(label) - 1 - if not cycle: - break + scio = try_import('scipy.io') + + labels = scio.loadmat(label_file)['labels'][0] + indexes = scio.loadmat(setid_file)[dataset_name][0] + + img2label = {} + for i in indexes: + img = "jpg/image_%05d.jpg" % i + img2label[img] = labels[i - 1] + + tf = tarfile.open(data_file) + mems = tf.getmembers() + file_id = 0 + for mem in mems: + if mem.name in img2label: + image = tf.extractfile(mem).read() + label = img2label[mem.name] + yield image, int(label) - 1 if use_xmap: return xmap_readers(mapper, reader, min(4, cpu_count()), buffered_size) -- GitLab