未验证 提交 ee2be095 编写于 作者: H haoyuying 提交者: GitHub

add download datast

上级 73e06533
...@@ -2,7 +2,7 @@ import paddle ...@@ -2,7 +2,7 @@ import paddle
import paddlehub as hub import paddlehub as hub
import paddlehub.process.transforms as T import paddlehub.process.transforms as T
from paddlehub.finetune.trainer import Trainer from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.colorizedataset import Colorizedataset from paddlehub.datasets.Canvas import Canvas
if __name__ == '__main__': if __name__ == '__main__':
...@@ -13,7 +13,7 @@ if __name__ == '__main__': ...@@ -13,7 +13,7 @@ if __name__ == '__main__':
stay_rgb=True, stay_rgb=True,
is_permute=False) is_permute=False)
color_set = Colorizedataset(transform=transform, mode='train') color_set = Canvas(transform=transform, mode='train')
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters()) optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='img_colorization_ckpt') trainer = Trainer(model, optimizer, checkpoint_dir='img_colorization_ckpt')
trainer.train(color_set, epochs=101, batch_size=2, eval_dataset=color_set, log_interval=10, save_interval=10) trainer.train(color_set, epochs=101, batch_size=2, eval_dataset=color_set, log_interval=10, save_interval=10)
...@@ -2,14 +2,14 @@ import paddle ...@@ -2,14 +2,14 @@ import paddle
import paddlehub as hub import paddlehub as hub
import paddlehub.process.transforms as T import paddlehub.process.transforms as T
from paddlehub.finetune.trainer import Trainer from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.styletransfer import StyleTransferData from paddlehub.datasets.MiniCOCO import MiniCOCO
if __name__ == "__main__": if __name__ == "__main__":
model = hub.Module(name='msgnet') model = hub.Module(name='msgnet')
transform = T.Compose([T.Resize( transform = T.Compose([T.Resize(
(256, 256), interp='LINEAR'), T.CenterCrop(crop_size=256)], T.SetType(datatype='float32')) (256, 256), interp='LINEAR'), T.CenterCrop(crop_size=256)], T.SetType(datatype='float32'))
styledata = StyleTransferData(transform) styledata = MiniCOCO(transform)
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters()) optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='img_style_transfer_ckpt') trainer = Trainer(model, optimizer, checkpoint_dir='img_style_transfer_ckpt')
trainer.train(styledata, epochs=5, batch_size=16, eval_dataset=styledata, log_interval=1, save_interval=1) trainer.train(styledata, epochs=5, batch_size=16, eval_dataset=styledata, log_interval=1, save_interval=1)
...@@ -21,11 +21,14 @@ import paddle ...@@ -21,11 +21,14 @@ import paddle
from paddlehub.process.functional import get_img_file from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME from paddlehub.env import DATA_HOME
from typing import Callable from typing import Callable
from paddlehub.utils.download import download_data
class Colorizedataset(paddle.io.Dataset): @download_data(url='https://paddlehub.bj.bcebos.com/dygraph/datasets/canvas.tar.gz')
class Canvas(paddle.io.Dataset):
""" """
Dataset for colorization. Dataset for colorization. It contains 1193 and 400 pictures for Monet and Vango paintings style, respectively.
We collected data from https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/.
Args: Args:
transform(callmethod) : The method of preprocess images. transform(callmethod) : The method of preprocess images.
...@@ -34,6 +37,7 @@ class Colorizedataset(paddle.io.Dataset): ...@@ -34,6 +37,7 @@ class Colorizedataset(paddle.io.Dataset):
Returns: Returns:
DataSet: An iterable object for data iterating DataSet: An iterable object for data iterating
""" """
def __init__(self, transform: Callable, mode: str = 'train'): def __init__(self, transform: Callable, mode: str = 'train'):
self.mode = mode self.mode = mode
self.transform = transform self.transform = transform
......
...@@ -19,11 +19,14 @@ from typing import Callable ...@@ -19,11 +19,14 @@ from typing import Callable
import paddle import paddle
from paddlehub.process.functional import get_img_file from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME from paddlehub.env import DATA_HOME
from paddlehub.utils.download import download_data
class StyleTransferData(paddle.io.Dataset): @download_data(url='https://paddlehub.bj.bcebos.com/dygraph/datasets/minicoco.tar.gz')
class MiniCOCO(paddle.io.Dataset):
""" """
Dataset for Style transfer. Dataset for Style transfer. The dataset contains 2001 images for training set and 200 images for testing set.
They are derived form COCO2014. Meanwhile, it contains 21 different style pictures in file "21styles".
Args: Args:
transform(callmethod) : The method of preprocess images. transform(callmethod) : The method of preprocess images.
...@@ -32,6 +35,7 @@ class StyleTransferData(paddle.io.Dataset): ...@@ -32,6 +35,7 @@ class StyleTransferData(paddle.io.Dataset):
Returns: Returns:
DataSet: An iterable object for data iterating DataSet: An iterable object for data iterating
""" """
def __init__(self, transform: Callable, mode: str = 'train'): def __init__(self, transform: Callable, mode: str = 'train'):
self.mode = mode self.mode = mode
self.transform = transform self.transform = transform
......
import os
from paddlehub.env import DATA_HOME
from paddle.utils.download import get_path_from_url
def download_data(url):
save_name = os.path.basename(url).split('.')[0]
output_path = os.path.join(DATA_HOME, save_name)
if not os.path.exists(output_path):
get_path_from_url(url, DATA_HOME)
def _wrapper(Dataset):
return Dataset
return _wrapper
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册