diff --git a/README.md b/README.md index 8b437e4115abe80073866f52f3d7e387e2a554d3..d0a35332d474e32691b595df8c4c2d0e780bd344 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ We provide [English](https://www.paddlepaddle.org.cn/documentation/docs/en/guide ## Communication - [Github Issues](https://github.com/PaddlePaddle/Paddle/issues): bug reports, feature requests, install issues, usage issues, etc. -- QQ discussion group: 778260830 (PaddlePaddle). +- QQ discussion group: 793866180 (PaddlePaddle). - [Forums](https://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc. ## Copyright and License diff --git a/README_cn.md b/README_cn.md index 7a10cba2845498d2299fc516f5804eb1a84e4ecc..2be8be3df6e7b2f99cda1f34c9359a45e51ae5ea 100644 --- a/README_cn.md +++ b/README_cn.md @@ -83,7 +83,7 @@ PaddlePaddle用户可领取**免费Tesla V100在线算力资源**,训练模型 ## 交流与反馈 - 欢迎您通过[Github Issues](https://github.com/PaddlePaddle/Paddle/issues)来提交问题、报告与建议 -- QQ群: 778260830 (PaddlePaddle) +- QQ群: 793866180 (PaddlePaddle) - [论坛](https://ai.baidu.com/forum/topic/list/168): 欢迎大家在PaddlePaddle论坛分享在使用PaddlePaddle中遇到的问题和经验, 营造良好的论坛氛围 ## 版权和许可证 diff --git a/python/paddle/vision/datasets/flowers.py b/python/paddle/vision/datasets/flowers.py index 448d6efb52beca953de7981312e8f9131e6fb05d..65c0b604efd5d719cf9df313592b6e3561b5958a 100644 --- a/python/paddle/vision/datasets/flowers.py +++ b/python/paddle/vision/datasets/flowers.py @@ -93,62 +93,44 @@ class Flowers(Dataset): .format(backend)) self.backend = backend - self.flag = MODE_FLAG_MAP[mode.lower()] + flag = MODE_FLAG_MAP[mode.lower()] - self.data_file = data_file - if self.data_file is None: + if not data_file: assert download, "data_file is not set and downloading automatically is disabled" - self.data_file = _check_exists_and_download( + data_file = _check_exists_and_download( data_file, DATA_URL, DATA_MD5, 'flowers', download) - self.label_file = label_file - if self.label_file is None: + if not label_file: assert download, "label_file is not set and downloading automatically is disabled" - self.label_file = _check_exists_and_download( + label_file = _check_exists_and_download( label_file, LABEL_URL, LABEL_MD5, 'flowers', download) - self.setid_file = setid_file - if self.setid_file is None: + if not setid_file: assert download, "setid_file is not set and downloading automatically is disabled" - self.setid_file = _check_exists_and_download( + setid_file = _check_exists_and_download( setid_file, SETID_URL, SETID_MD5, 'flowers', download) self.transform = transform - # read dataset into memory - self._load_anno() - - self.dtype = paddle.get_default_dtype() - - def _load_anno(self): - self.name2mem = {} - self.data_tar = tarfile.open(self.data_file) - for ele in self.data_tar.getmembers(): - self.name2mem[ele.name] = ele + data_tar = tarfile.open(data_file) + self.data_path = data_file.replace(".tgz", "/") + if not os.path.exists(self.data_path): + os.mkdir(self.data_path) + data_tar.extractall(self.data_path) scio = try_import('scipy.io') - - # double check data download - self.label_file = _check_exists_and_download(self.label_file, LABEL_URL, - LABEL_MD5, 'flowers', True) - - self.setid_file = _check_exists_and_download(self.setid_file, SETID_URL, - SETID_MD5, 'flowers', True) - - self.labels = scio.loadmat(self.label_file)['labels'][0] - self.indexes = scio.loadmat(self.setid_file)[self.flag][0] + self.labels = scio.loadmat(label_file)['labels'][0] + self.indexes = scio.loadmat(setid_file)[flag][0] def __getitem__(self, idx): index = self.indexes[idx] label = np.array([self.labels[index - 1]]) img_name = "jpg/image_%05d.jpg" % index - img_ele = self.name2mem[img_name] - image = self.data_tar.extractfile(img_ele).read() - + image = os.path.join(self.data_path, img_name) if self.backend == 'pil': - image = Image.open(io.BytesIO(image)) + image = Image.open(image) elif self.backend == 'cv2': - image = np.array(Image.open(io.BytesIO(image))) + image = np.array(Image.open(image)) if self.transform is not None: image = self.transform(image) @@ -156,7 +138,7 @@ class Flowers(Dataset): if self.backend == 'pil': return image, label.astype('int64') - return image.astype(self.dtype), label.astype('int64') + return image.astype(paddle.get_default_dtype()), label.astype('int64') def __len__(self): return len(self.indexes)