未验证 提交 67c2700f 编写于 作者: G GT-Zhang 提交者: GitHub

Optimize 102Flowers dataset reading speed (#31408)

* Fix slow data reading, In the old version, one epoch read time of this data set was about 5371 seconds(MacBook Pro Retina, 13-inch, Early 2015 2.7 GHz), and a batch took 211 seconds, It's too painful to use. Now decompress the data in advance (about 10 seconds). Each epoch of reading takes about 3 seconds(MacBook Pro Retina, 13-inch, Early 2015 2.7 GHz), and a batch takes 0.017 seconds more.

* Run CI, test=allcase

* fix qq group number. test=document_fix

 fix qq group number. test=document_fix

* fix qq group number. test=document_fix 

fix qq group number. test=document_fix
上级 f0b2f598
......@@ -86,7 +86,7 @@ We provide [English](https://www.paddlepaddle.org.cn/documentation/docs/en/guide
## Communication
- [Github Issues](https://github.com/PaddlePaddle/Paddle/issues): bug reports, feature requests, install issues, usage issues, etc.
- QQ discussion group: 778260830 (PaddlePaddle).
- QQ discussion group: 793866180 (PaddlePaddle).
- [Forums](https://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## Copyright and License
......
......@@ -83,7 +83,7 @@ PaddlePaddle用户可领取**免费Tesla V100在线算力资源**,训练模型
## 交流与反馈
- 欢迎您通过[Github Issues](https://github.com/PaddlePaddle/Paddle/issues)来提交问题、报告与建议
- QQ群: 778260830 (PaddlePaddle)
- QQ群: 793866180 (PaddlePaddle)
- [论坛](https://ai.baidu.com/forum/topic/list/168): 欢迎大家在PaddlePaddle论坛分享在使用PaddlePaddle中遇到的问题和经验, 营造良好的论坛氛围
## 版权和许可证
......
......@@ -93,62 +93,44 @@ class Flowers(Dataset):
.format(backend))
self.backend = backend
self.flag = MODE_FLAG_MAP[mode.lower()]
flag = MODE_FLAG_MAP[mode.lower()]
self.data_file = data_file
if self.data_file is None:
if not data_file:
assert download, "data_file is not set and downloading automatically is disabled"
self.data_file = _check_exists_and_download(
data_file = _check_exists_and_download(
data_file, DATA_URL, DATA_MD5, 'flowers', download)
self.label_file = label_file
if self.label_file is None:
if not label_file:
assert download, "label_file is not set and downloading automatically is disabled"
self.label_file = _check_exists_and_download(
label_file = _check_exists_and_download(
label_file, LABEL_URL, LABEL_MD5, 'flowers', download)
self.setid_file = setid_file
if self.setid_file is None:
if not setid_file:
assert download, "setid_file is not set and downloading automatically is disabled"
self.setid_file = _check_exists_and_download(
setid_file = _check_exists_and_download(
setid_file, SETID_URL, SETID_MD5, 'flowers', download)
self.transform = transform
# read dataset into memory
self._load_anno()
self.dtype = paddle.get_default_dtype()
def _load_anno(self):
self.name2mem = {}
self.data_tar = tarfile.open(self.data_file)
for ele in self.data_tar.getmembers():
self.name2mem[ele.name] = ele
data_tar = tarfile.open(data_file)
self.data_path = data_file.replace(".tgz", "/")
if not os.path.exists(self.data_path):
os.mkdir(self.data_path)
data_tar.extractall(self.data_path)
scio = try_import('scipy.io')
# double check data download
self.label_file = _check_exists_and_download(self.label_file, LABEL_URL,
LABEL_MD5, 'flowers', True)
self.setid_file = _check_exists_and_download(self.setid_file, SETID_URL,
SETID_MD5, 'flowers', True)
self.labels = scio.loadmat(self.label_file)['labels'][0]
self.indexes = scio.loadmat(self.setid_file)[self.flag][0]
self.labels = scio.loadmat(label_file)['labels'][0]
self.indexes = scio.loadmat(setid_file)[flag][0]
def __getitem__(self, idx):
index = self.indexes[idx]
label = np.array([self.labels[index - 1]])
img_name = "jpg/image_%05d.jpg" % index
img_ele = self.name2mem[img_name]
image = self.data_tar.extractfile(img_ele).read()
image = os.path.join(self.data_path, img_name)
if self.backend == 'pil':
image = Image.open(io.BytesIO(image))
image = Image.open(image)
elif self.backend == 'cv2':
image = np.array(Image.open(io.BytesIO(image)))
image = np.array(Image.open(image))
if self.transform is not None:
image = self.transform(image)
......@@ -156,7 +138,7 @@ class Flowers(Dataset):
if self.backend == 'pil':
return image, label.astype('int64')
return image.astype(self.dtype), label.astype('int64')
return image.astype(paddle.get_default_dtype()), label.astype('int64')
def __len__(self):
return len(self.indexes)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册