diff --git a/paddlehub/datasets/canvas.py b/paddlehub/datasets/canvas.py index e32376d5cd13e540e9b959cf394c6c49ee057c7a..22e1aefe520812a3d01b9b0cc47d94659af88fd1 100644 --- a/paddlehub/datasets/canvas.py +++ b/paddlehub/datasets/canvas.py @@ -19,8 +19,8 @@ from typing import Callable import paddle import numpy as np +import paddlehub.env as hubenv from paddlehub.vision.utils import get_img_file -from paddlehub.env import DATA_HOME from paddlehub.utils.download import download_data @@ -47,7 +47,7 @@ class Canvas(paddle.io.Dataset): elif self.mode == 'test': self.file = 'test' - self.file = os.path.join(DATA_HOME, 'canvas', self.file) + self.file = os.path.join(hubenv.DATA_HOME, 'canvas', self.file) self.data = get_img_file(self.file) def __getitem__(self, idx: int) -> np.ndarray: diff --git a/paddlehub/datasets/flowers.py b/paddlehub/datasets/flowers.py index 5935f83dfb715e5648c7a7f4ca5725953fa40649..91452b9545002de1f755bc80836913a8a41d94f9 100644 --- a/paddlehub/datasets/flowers.py +++ b/paddlehub/datasets/flowers.py @@ -14,14 +14,18 @@ # limitations under the License. import os +from typing import Callable, Tuple import paddle +import numpy as np -from paddlehub.env import DATA_HOME +import paddlehub.env as hubenv +from paddlehub.utils.download import download_data +@download_data(url='https://bj.bcebos.com/paddlehub-dataset/flower_photos.tar.gz') class Flowers(paddle.io.Dataset): - def __init__(self, transforms=None, mode='train'): + def __init__(self, transforms: Callable, mode: str = 'train'): self.mode = mode self.transforms = transforms self.num_classes = 5 @@ -32,14 +36,14 @@ class Flowers(paddle.io.Dataset): self.file = 'test_list.txt' else: self.file = 'validate_list.txt' - self.file = os.path.join(DATA_HOME, 'flower_photos', self.file) + self.file = os.path.join(hubenv.DATA_HOME, 'flower_photos', self.file) with open(self.file, 'r') as file: self.data = file.read().split('\n') - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[np.ndarray, int]: img_path, grt = self.data[idx].split(' ') - img_path = os.path.join(DATA_HOME, 'flower_photos', img_path) + img_path = os.path.join(hubenv.DATA_HOME, 'flower_photos', img_path) im = self.transforms(img_path) return im, int(grt) diff --git a/paddlehub/datasets/minicoco.py b/paddlehub/datasets/minicoco.py index 552996c7e3c76dffac4c6de4e9043a7596f8d5e4..01d3d00dc45c0cc237ef135a5434c1f99a54f4a2 100644 --- a/paddlehub/datasets/minicoco.py +++ b/paddlehub/datasets/minicoco.py @@ -19,8 +19,8 @@ from typing import Callable import paddle import numpy as np +import paddlehub.env as hubenv from paddlehub.vision.utils import get_img_file -from paddlehub.env import DATA_HOME from paddlehub.utils.download import download_data @@ -46,8 +46,8 @@ class MiniCOCO(paddle.io.Dataset): self.file = 'train' elif self.mode == 'test': self.file = 'test' - self.file = os.path.join(DATA_HOME, 'minicoco', self.file) - self.style_file = os.path.join(DATA_HOME, 'minicoco', '21styles') + self.file = os.path.join(hubenv.DATA_HOME, 'minicoco', self.file) + self.style_file = os.path.join(hubenv.DATA_HOME, 'minicoco', '21styles') self.data = get_img_file(self.file) self.style = get_img_file(self.style_file)