From 61dcfbe92fad3b15c652bf27896bf82ca6a9a254 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Wed, 4 Nov 2020 15:24:42 +0800 Subject: [PATCH] Add automatic download function to the flowers dataset --- paddlehub/datasets/canvas.py | 4 ++-- paddlehub/datasets/flowers.py | 14 +++++++++----- paddlehub/datasets/minicoco.py | 6 +++--- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/paddlehub/datasets/canvas.py b/paddlehub/datasets/canvas.py index e32376d5..22e1aefe 100644 --- a/paddlehub/datasets/canvas.py +++ b/paddlehub/datasets/canvas.py @@ -19,8 +19,8 @@ from typing import Callable import paddle import numpy as np +import paddlehub.env as hubenv from paddlehub.vision.utils import get_img_file -from paddlehub.env import DATA_HOME from paddlehub.utils.download import download_data @@ -47,7 +47,7 @@ class Canvas(paddle.io.Dataset): elif self.mode == 'test': self.file = 'test' - self.file = os.path.join(DATA_HOME, 'canvas', self.file) + self.file = os.path.join(hubenv.DATA_HOME, 'canvas', self.file) self.data = get_img_file(self.file) def __getitem__(self, idx: int) -> np.ndarray: diff --git a/paddlehub/datasets/flowers.py b/paddlehub/datasets/flowers.py index 5935f83d..91452b95 100644 --- a/paddlehub/datasets/flowers.py +++ b/paddlehub/datasets/flowers.py @@ -14,14 +14,18 @@ # limitations under the License. import os +from typing import Callable, Tuple import paddle +import numpy as np -from paddlehub.env import DATA_HOME +import paddlehub.env as hubenv +from paddlehub.utils.download import download_data +@download_data(url='https://bj.bcebos.com/paddlehub-dataset/flower_photos.tar.gz') class Flowers(paddle.io.Dataset): - def __init__(self, transforms=None, mode='train'): + def __init__(self, transforms: Callable, mode: str = 'train'): self.mode = mode self.transforms = transforms self.num_classes = 5 @@ -32,14 +36,14 @@ class Flowers(paddle.io.Dataset): self.file = 'test_list.txt' else: self.file = 'validate_list.txt' - self.file = os.path.join(DATA_HOME, 'flower_photos', self.file) + self.file = os.path.join(hubenv.DATA_HOME, 'flower_photos', self.file) with open(self.file, 'r') as file: self.data = file.read().split('\n') - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[np.ndarray, int]: img_path, grt = self.data[idx].split(' ') - img_path = os.path.join(DATA_HOME, 'flower_photos', img_path) + img_path = os.path.join(hubenv.DATA_HOME, 'flower_photos', img_path) im = self.transforms(img_path) return im, int(grt) diff --git a/paddlehub/datasets/minicoco.py b/paddlehub/datasets/minicoco.py index 552996c7..01d3d00d 100644 --- a/paddlehub/datasets/minicoco.py +++ b/paddlehub/datasets/minicoco.py @@ -19,8 +19,8 @@ from typing import Callable import paddle import numpy as np +import paddlehub.env as hubenv from paddlehub.vision.utils import get_img_file -from paddlehub.env import DATA_HOME from paddlehub.utils.download import download_data @@ -46,8 +46,8 @@ class MiniCOCO(paddle.io.Dataset): self.file = 'train' elif self.mode == 'test': self.file = 'test' - self.file = os.path.join(DATA_HOME, 'minicoco', self.file) - self.style_file = os.path.join(DATA_HOME, 'minicoco', '21styles') + self.file = os.path.join(hubenv.DATA_HOME, 'minicoco', self.file) + self.style_file = os.path.join(hubenv.DATA_HOME, 'minicoco', '21styles') self.data = get_img_file(self.file) self.style = get_img_file(self.style_file) -- GitLab