提交 61dcfbe9 编写于 作者: W wuzewu

Add automatic download function to the flowers dataset

上级 a3dcc0cb
......@@ -19,8 +19,8 @@ from typing import Callable
import paddle
import numpy as np
import paddlehub.env as hubenv
from paddlehub.vision.utils import get_img_file
from paddlehub.env import DATA_HOME
from paddlehub.utils.download import download_data
......@@ -47,7 +47,7 @@ class Canvas(paddle.io.Dataset):
elif self.mode == 'test':
self.file = 'test'
self.file = os.path.join(DATA_HOME, 'canvas', self.file)
self.file = os.path.join(hubenv.DATA_HOME, 'canvas', self.file)
self.data = get_img_file(self.file)
def __getitem__(self, idx: int) -> np.ndarray:
......
......@@ -14,14 +14,18 @@
# limitations under the License.
import os
from typing import Callable, Tuple
import paddle
import numpy as np
from paddlehub.env import DATA_HOME
import paddlehub.env as hubenv
from paddlehub.utils.download import download_data
@download_data(url='https://bj.bcebos.com/paddlehub-dataset/flower_photos.tar.gz')
class Flowers(paddle.io.Dataset):
def __init__(self, transforms=None, mode='train'):
def __init__(self, transforms: Callable, mode: str = 'train'):
self.mode = mode
self.transforms = transforms
self.num_classes = 5
......@@ -32,14 +36,14 @@ class Flowers(paddle.io.Dataset):
self.file = 'test_list.txt'
else:
self.file = 'validate_list.txt'
self.file = os.path.join(DATA_HOME, 'flower_photos', self.file)
self.file = os.path.join(hubenv.DATA_HOME, 'flower_photos', self.file)
with open(self.file, 'r') as file:
self.data = file.read().split('\n')
def __getitem__(self, idx):
def __getitem__(self, idx) -> Tuple[np.ndarray, int]:
img_path, grt = self.data[idx].split(' ')
img_path = os.path.join(DATA_HOME, 'flower_photos', img_path)
img_path = os.path.join(hubenv.DATA_HOME, 'flower_photos', img_path)
im = self.transforms(img_path)
return im, int(grt)
......
......@@ -19,8 +19,8 @@ from typing import Callable
import paddle
import numpy as np
import paddlehub.env as hubenv
from paddlehub.vision.utils import get_img_file
from paddlehub.env import DATA_HOME
from paddlehub.utils.download import download_data
......@@ -46,8 +46,8 @@ class MiniCOCO(paddle.io.Dataset):
self.file = 'train'
elif self.mode == 'test':
self.file = 'test'
self.file = os.path.join(DATA_HOME, 'minicoco', self.file)
self.style_file = os.path.join(DATA_HOME, 'minicoco', '21styles')
self.file = os.path.join(hubenv.DATA_HOME, 'minicoco', self.file)
self.style_file = os.path.join(hubenv.DATA_HOME, 'minicoco', '21styles')
self.data = get_img_file(self.file)
self.style = get_img_file(self.style_file)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册