未验证 提交 ee2be095 编写于 作者: H haoyuying 提交者: GitHub

add download datast

上级 73e06533
......@@ -2,7 +2,7 @@ import paddle
import paddlehub as hub
import paddlehub.process.transforms as T
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.colorizedataset import Colorizedataset
from paddlehub.datasets.Canvas import Canvas
if __name__ == '__main__':
......@@ -13,7 +13,7 @@ if __name__ == '__main__':
stay_rgb=True,
is_permute=False)
color_set = Colorizedataset(transform=transform, mode='train')
color_set = Canvas(transform=transform, mode='train')
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='img_colorization_ckpt')
trainer.train(color_set, epochs=101, batch_size=2, eval_dataset=color_set, log_interval=10, save_interval=10)
......@@ -2,14 +2,14 @@ import paddle
import paddlehub as hub
import paddlehub.process.transforms as T
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.styletransfer import StyleTransferData
from paddlehub.datasets.MiniCOCO import MiniCOCO
if __name__ == "__main__":
model = hub.Module(name='msgnet')
transform = T.Compose([T.Resize(
(256, 256), interp='LINEAR'), T.CenterCrop(crop_size=256)], T.SetType(datatype='float32'))
styledata = StyleTransferData(transform)
styledata = MiniCOCO(transform)
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='img_style_transfer_ckpt')
trainer.train(styledata, epochs=5, batch_size=16, eval_dataset=styledata, log_interval=1, save_interval=1)
......@@ -21,11 +21,14 @@ import paddle
from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME
from typing import Callable
from paddlehub.utils.download import download_data
class Colorizedataset(paddle.io.Dataset):
@download_data(url='https://paddlehub.bj.bcebos.com/dygraph/datasets/canvas.tar.gz')
class Canvas(paddle.io.Dataset):
"""
Dataset for colorization.
Dataset for colorization. It contains 1193 and 400 pictures for Monet and Vango paintings style, respectively.
We collected data from https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/.
Args:
transform(callmethod) : The method of preprocess images.
......@@ -34,6 +37,7 @@ class Colorizedataset(paddle.io.Dataset):
Returns:
DataSet: An iterable object for data iterating
"""
def __init__(self, transform: Callable, mode: str = 'train'):
self.mode = mode
self.transform = transform
......
......@@ -19,11 +19,14 @@ from typing import Callable
import paddle
from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME
from paddlehub.utils.download import download_data
class StyleTransferData(paddle.io.Dataset):
@download_data(url='https://paddlehub.bj.bcebos.com/dygraph/datasets/minicoco.tar.gz')
class MiniCOCO(paddle.io.Dataset):
"""
Dataset for Style transfer.
Dataset for Style transfer. The dataset contains 2001 images for training set and 200 images for testing set.
They are derived form COCO2014. Meanwhile, it contains 21 different style pictures in file "21styles".
Args:
transform(callmethod) : The method of preprocess images.
......@@ -32,6 +35,7 @@ class StyleTransferData(paddle.io.Dataset):
Returns:
DataSet: An iterable object for data iterating
"""
def __init__(self, transform: Callable, mode: str = 'train'):
self.mode = mode
self.transform = transform
......
import os
from paddlehub.env import DATA_HOME
from paddle.utils.download import get_path_from_url
def download_data(url):
save_name = os.path.basename(url).split('.')[0]
output_path = os.path.join(DATA_HOME, save_name)
if not os.path.exists(output_path):
get_path_from_url(url, DATA_HOME)
def _wrapper(Dataset):
return Dataset
return _wrapper
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册