From 67c2700f606ae8aacbeb563e3577e258e6fcd0fe Mon Sep 17 00:00:00 2001 From: GT-Zhang <46156734+GT-ZhangAcer@users.noreply.github.com> Date: Wed, 19 May 2021 15:38:33 +0800 Subject: [PATCH] Optimize 102Flowers dataset reading speed (#31408) * Fix slow data reading. In the old version, reading one epoch of this dataset took about 5371 seconds (MacBook Pro Retina, 13-inch, Early 2015 2.7 GHz), and a batch took 211 seconds; it was too painful to use. The data is now decompressed in advance (about 10 seconds). Each epoch of reading takes about 3 seconds (MacBook Pro Retina, 13-inch, Early 2015 2.7 GHz), and a batch takes 0.017 seconds more. * Run CI, test=allcase * fix qq group number. test=document_fix fix qq group number. test=document_fix * fix qq group number. test=document_fix fix qq group number. test=document_fix --- README.md | 2 +- README_cn.md | 2 +- python/paddle/vision/datasets/flowers.py | 54 ++++++++---------------- 3 files changed, 20 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 8b437e4115a..d0a35332d47 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ We provide [English](https://www.paddlepaddle.org.cn/documentation/docs/en/guide ## Communication - [Github Issues](https://github.com/PaddlePaddle/Paddle/issues): bug reports, feature requests, install issues, usage issues, etc. -- QQ discussion group: 778260830 (PaddlePaddle). +- QQ discussion group: 793866180 (PaddlePaddle). - [Forums](https://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc. 
## Copyright and License diff --git a/README_cn.md b/README_cn.md index 7a10cba2845..2be8be3df6e 100644 --- a/README_cn.md +++ b/README_cn.md @@ -83,7 +83,7 @@ PaddlePaddle用户可领取**免费Tesla V100在线算力资源**,训练模型 ## 交流与反馈 - 欢迎您通过[Github Issues](https://github.com/PaddlePaddle/Paddle/issues)来提交问题、报告与建议 -- QQ群: 778260830 (PaddlePaddle) +- QQ群: 793866180 (PaddlePaddle) - [论坛](https://ai.baidu.com/forum/topic/list/168): 欢迎大家在PaddlePaddle论坛分享在使用PaddlePaddle中遇到的问题和经验, 营造良好的论坛氛围 ## 版权和许可证 diff --git a/python/paddle/vision/datasets/flowers.py b/python/paddle/vision/datasets/flowers.py index 448d6efb52b..65c0b604efd 100644 --- a/python/paddle/vision/datasets/flowers.py +++ b/python/paddle/vision/datasets/flowers.py @@ -93,62 +93,44 @@ class Flowers(Dataset): .format(backend)) self.backend = backend - self.flag = MODE_FLAG_MAP[mode.lower()] + flag = MODE_FLAG_MAP[mode.lower()] - self.data_file = data_file - if self.data_file is None: + if not data_file: assert download, "data_file is not set and downloading automatically is disabled" - self.data_file = _check_exists_and_download( + data_file = _check_exists_and_download( data_file, DATA_URL, DATA_MD5, 'flowers', download) - self.label_file = label_file - if self.label_file is None: + if not label_file: assert download, "label_file is not set and downloading automatically is disabled" - self.label_file = _check_exists_and_download( + label_file = _check_exists_and_download( label_file, LABEL_URL, LABEL_MD5, 'flowers', download) - self.setid_file = setid_file - if self.setid_file is None: + if not setid_file: assert download, "setid_file is not set and downloading automatically is disabled" - self.setid_file = _check_exists_and_download( + setid_file = _check_exists_and_download( setid_file, SETID_URL, SETID_MD5, 'flowers', download) self.transform = transform - # read dataset into memory - self._load_anno() - - self.dtype = paddle.get_default_dtype() - - def _load_anno(self): - self.name2mem = {} - self.data_tar = 
tarfile.open(self.data_file) - for ele in self.data_tar.getmembers(): - self.name2mem[ele.name] = ele + data_tar = tarfile.open(data_file) + self.data_path = data_file.replace(".tgz", "/") + if not os.path.exists(self.data_path): + os.mkdir(self.data_path) + data_tar.extractall(self.data_path) scio = try_import('scipy.io') - - # double check data download - self.label_file = _check_exists_and_download(self.label_file, LABEL_URL, - LABEL_MD5, 'flowers', True) - - self.setid_file = _check_exists_and_download(self.setid_file, SETID_URL, - SETID_MD5, 'flowers', True) - - self.labels = scio.loadmat(self.label_file)['labels'][0] - self.indexes = scio.loadmat(self.setid_file)[self.flag][0] + self.labels = scio.loadmat(label_file)['labels'][0] + self.indexes = scio.loadmat(setid_file)[flag][0] def __getitem__(self, idx): index = self.indexes[idx] label = np.array([self.labels[index - 1]]) img_name = "jpg/image_%05d.jpg" % index - img_ele = self.name2mem[img_name] - image = self.data_tar.extractfile(img_ele).read() - + image = os.path.join(self.data_path, img_name) if self.backend == 'pil': - image = Image.open(io.BytesIO(image)) + image = Image.open(image) elif self.backend == 'cv2': - image = np.array(Image.open(io.BytesIO(image))) + image = np.array(Image.open(image)) if self.transform is not None: image = self.transform(image) @@ -156,7 +138,7 @@ class Flowers(Dataset): if self.backend == 'pil': return image, label.astype('int64') - return image.astype(self.dtype), label.astype('int64') + return image.astype(paddle.get_default_dtype()), label.astype('int64') def __len__(self): return len(self.indexes) -- GitLab