提交 61dcfbe9 编写于 作者: W wuzewu

Add automatic download function to the flowers dataset

上级 a3dcc0cb
...@@ -19,8 +19,8 @@ from typing import Callable ...@@ -19,8 +19,8 @@ from typing import Callable
import paddle import paddle
import numpy as np import numpy as np
import paddlehub.env as hubenv
from paddlehub.vision.utils import get_img_file from paddlehub.vision.utils import get_img_file
from paddlehub.env import DATA_HOME
from paddlehub.utils.download import download_data from paddlehub.utils.download import download_data
...@@ -47,7 +47,7 @@ class Canvas(paddle.io.Dataset): ...@@ -47,7 +47,7 @@ class Canvas(paddle.io.Dataset):
elif self.mode == 'test': elif self.mode == 'test':
self.file = 'test' self.file = 'test'
self.file = os.path.join(DATA_HOME, 'canvas', self.file) self.file = os.path.join(hubenv.DATA_HOME, 'canvas', self.file)
self.data = get_img_file(self.file) self.data = get_img_file(self.file)
def __getitem__(self, idx: int) -> np.ndarray: def __getitem__(self, idx: int) -> np.ndarray:
......
...@@ -14,14 +14,18 @@ ...@@ -14,14 +14,18 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import Callable, Tuple
import paddle import paddle
import numpy as np
from paddlehub.env import DATA_HOME import paddlehub.env as hubenv
from paddlehub.utils.download import download_data
@download_data(url='https://bj.bcebos.com/paddlehub-dataset/flower_photos.tar.gz')
class Flowers(paddle.io.Dataset): class Flowers(paddle.io.Dataset):
def __init__(self, transforms=None, mode='train'): def __init__(self, transforms: Callable, mode: str = 'train'):
self.mode = mode self.mode = mode
self.transforms = transforms self.transforms = transforms
self.num_classes = 5 self.num_classes = 5
...@@ -32,14 +36,14 @@ class Flowers(paddle.io.Dataset): ...@@ -32,14 +36,14 @@ class Flowers(paddle.io.Dataset):
self.file = 'test_list.txt' self.file = 'test_list.txt'
else: else:
self.file = 'validate_list.txt' self.file = 'validate_list.txt'
self.file = os.path.join(DATA_HOME, 'flower_photos', self.file) self.file = os.path.join(hubenv.DATA_HOME, 'flower_photos', self.file)
with open(self.file, 'r') as file: with open(self.file, 'r') as file:
self.data = file.read().split('\n') self.data = file.read().split('\n')
def __getitem__(self, idx): def __getitem__(self, idx) -> Tuple[np.ndarray, int]:
img_path, grt = self.data[idx].split(' ') img_path, grt = self.data[idx].split(' ')
img_path = os.path.join(DATA_HOME, 'flower_photos', img_path) img_path = os.path.join(hubenv.DATA_HOME, 'flower_photos', img_path)
im = self.transforms(img_path) im = self.transforms(img_path)
return im, int(grt) return im, int(grt)
......
...@@ -19,8 +19,8 @@ from typing import Callable ...@@ -19,8 +19,8 @@ from typing import Callable
import paddle import paddle
import numpy as np import numpy as np
import paddlehub.env as hubenv
from paddlehub.vision.utils import get_img_file from paddlehub.vision.utils import get_img_file
from paddlehub.env import DATA_HOME
from paddlehub.utils.download import download_data from paddlehub.utils.download import download_data
...@@ -46,8 +46,8 @@ class MiniCOCO(paddle.io.Dataset): ...@@ -46,8 +46,8 @@ class MiniCOCO(paddle.io.Dataset):
self.file = 'train' self.file = 'train'
elif self.mode == 'test': elif self.mode == 'test':
self.file = 'test' self.file = 'test'
self.file = os.path.join(DATA_HOME, 'minicoco', self.file) self.file = os.path.join(hubenv.DATA_HOME, 'minicoco', self.file)
self.style_file = os.path.join(DATA_HOME, 'minicoco', '21styles') self.style_file = os.path.join(hubenv.DATA_HOME, 'minicoco', '21styles')
self.data = get_img_file(self.file) self.data = get_img_file(self.file)
self.style = get_img_file(self.style_file) self.style = get_img_file(self.style_file)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册