提交 722c2d58 编写于 作者: H haoyuying

revise conflict

......@@ -2,9 +2,8 @@ import paddle
import paddlehub as hub
import paddle.nn as nn
if __name__ == '__main__':
paddle.disable_static()
model = hub.Module(directory='user_guided_colorization')
model = hub.Module(name='user_guided_colorization')
model.eval()
result = model.predict(images='sea.jpg')
\ No newline at end of file
result = model.predict(images='house.png')
......@@ -6,15 +6,21 @@ from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.colorizedataset import Colorizedataset
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
if __name__ == '__main__':
is_train = True
paddle.disable_static()
model = hub.Module(directory='user_guided_colorization')
transform = Compose([Resize((256,256),interp="RANDOM"),RandomPaddingCrop(crop_size=176), ConvertColorSpace(mode='RGB2LAB'), ColorizePreprocess(ab_thresh=0, p=1)], stay_rgb=True)
color_set = Colorizedataset(transform=transform, mode=is_train)
model = hub.Module(name='user_guided_colorization')
transform = Compose([
Resize((256, 256), interp='NEAREST'),
RandomPaddingCrop(crop_size=176),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train),
],
stay_rgb=True,
is_permute=False)
color_set = Colorizedataset(transform=transform, mode='train')
if is_train:
model.train()
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
trainer.train(color_set, epochs=3, batch_size=1, eval_dataset=color_set, save_interval=1)
trainer.train(color_set, epochs=101, batch_size=5, eval_dataset=color_set, log_interval=10, save_interval=10)
import paddle.fluid as fluid
import paddle
import paddlehub as hub
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import ParallelEnv
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.flowers import Flowers
from paddlehub.process.transforms import Compose, Resize, Normalize
from paddlehub.module.cv_module import ImageClassifierModule
if __name__ == '__main__':
with fluid.dygraph.guard(fluid.CUDAPlace(ParallelEnv().dev_id)):
paddle.disable_static(paddle.CUDAPlace(ParallelEnv().dev_id))
transforms = Compose([Resize((224, 224)), Normalize()])
flowers = Flowers(transforms)
flowers_validate = Flowers(transforms, mode='val')
model = hub.Module(directory='mobilenet_v2_animals', class_dim=flowers.num_classes)
# model = hub.Module(name='mobilenet_v2_animals', class_dim=flowers.num_classes)
model = hub.Module(name='mobilenet_v2_imagenet', class_dim=flowers.num_classes)
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=0.001, parameter_list=model.parameters())
optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
trainer.train(flowers, epochs=100, batch_size=32, eval_dataset=flowers_validate, save_interval=1)
import paddle
import paddlehub as hub
if __name__ == '__main__':
    # Run in dynamic-graph mode on the first GPU. The place object was
    # previously created but never handed to Paddle.
    place = paddle.CUDAPlace(0)
    paddle.disable_static(place)

    # Load the style-transfer module and switch it to inference mode.
    model = hub.Module(name='msgnet')
    model.eval()

    # First argument is the content image, second the style image.
    result = model.predict("venice-boat.jpg", "candy.jpg")
import paddle
import paddlehub as hub
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.styletransfer import StyleTransferData
from paddlehub.process.transforms import Compose, Resize, CenterCrop, SetType
if __name__ == "__main__":
    # Run in dynamic-graph mode on the first GPU. The place object was
    # previously created but never handed to Paddle.
    place = paddle.CUDAPlace(0)
    paddle.disable_static(place)

    model = hub.Module(name='msgnet')

    # SetType has to live inside the transform list: given as a second
    # positional argument it binds to Compose's `to_rgb` flag and is never
    # applied to the image.
    transform = Compose([
        Resize((256, 256), interp='LINEAR'),
        CenterCrop(crop_size=256),
        SetType(datatype='float32'),
    ])
    styledata = StyleTransferData(transform)

    model.train()
    optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
    trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
    # The same dataset is used for training and evaluation in this demo.
    trainer.train(styledata, epochs=5, batch_size=1, eval_dataset=styledata, log_interval=1, save_interval=1)
......@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import numpy
import paddle.nn as nn
from paddle.nn import Conv2d, ConvTranspose2d
from paddlehub.module.module import moduleinfo
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
from paddlehub.module.cv_module import ImageColorizeModule
......@@ -178,24 +179,31 @@ class UserGuidedColorization(nn.Layer):
if load_checkpoint is not None:
model_dict = paddle.load(load_checkpoint)[0]
self.set_dict(model_dict)
print("load pretrained model success")
print("load custom checkpoint success")
else:
checkpoint = os.path.join(self.directory, 'user_guided.pdparams')
model_dict = paddle.load(checkpoint)[0]
self.set_dict(model_dict)
print("load pretrained checkpoint success")
def transforms(self, images: str, is_train: bool = True) -> callable:
if is_train:
transform = Compose([
Resize((256, 256), interp="RANDOM"),
Resize((256, 256), interp='NEAREST'),
RandomPaddingCrop(crop_size=176),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train)
],
stay_rgb=True)
stay_rgb=True,
is_permute=False)
else:
transform = Compose([
Resize((256, 256), interp="RANDOM"),
Resize((256, 256), interp='NEAREST'),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train)
],
stay_rgb=True)
stay_rgb=True,
is_permute=False)
return transform(images)
def forward(self,
......
import os
import paddle
import paddle.nn as nn
import numpy as np
import paddle.nn.functional as F
from paddlehub.module.module import moduleinfo
from paddlehub.process.transforms import Compose, Resize, CenterCrop, SetType
from paddlehub.module.cv_module import StyleTransferModule
class GramMatrix(nn.Layer):
    """Calculate the normalised gram matrix of a batch of feature maps."""

    def forward(self, y):
        """
        Args:
            y(paddle.Tensor): 4-D feature map, unpacked as (b, ch, h, w).

        Return:
            gram(paddle.Tensor): Batched product of the flattened features
                with their transpose, divided by ch * h * w.
        """
        # Read the dimensions from `.shape`; the previous `y.size()` call
        # does not yield a 4-tuple on current Paddle tensors (NOTE(review):
        # confirm against the Paddle version in use).
        b, ch, h, w = y.shape
        features = y.reshape((b, ch, w * h))
        features_t = features.transpose((0, 2, 1))
        return features.bmm(features_t) / (ch * h * w)
class ConvLayer(nn.Layer):
    """
    Convolution preceded by reflection padding.

    Args:
        in_channels(int): Number of input channels.
        out_channels(int): Number of output channels.
        kernel_size(int): Kernel size; half of it (rounded down) is padded on every side.
        stride(int): Stride of the convolution.

    Return:
        img(paddle.Tensor): ConvLayer output.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int):
        super(ConvLayer, self).__init__()
        half_kernel = int(np.floor(kernel_size / 2))
        # Same amount of padding on all four sides.
        self.reflection_pad = nn.ReflectionPad2d([half_kernel] * 4)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x: paddle.Tensor):
        padded = self.reflection_pad(x)
        return self.conv2d(padded)
class UpsampleConvLayer(nn.Layer):
    """
    Optionally upsample the input, reflection-pad it, then convolve.
    Upsampling followed by convolution gives better results than ConvTranspose2d.
    ref: http://distill.pub/2016/deconv-checkerboard/

    Args:
        in_channels(int): Number of input channels.
        out_channels(int): Number of output channels.
        kernel_size(int): Kernel size of the convolution.
        stride(int): Stride of the convolution.
        upsample(int): Scale factor for the upsample layer, default is None (no upsampling).

    Return:
        img(paddle.Tensor): UpsampleConvLayer output.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        # The upsample layer only exists when a scale factor was given.
        if upsample:
            self.upsample_layer = nn.UpSample(scale_factor=upsample)

        self.pad = int(np.floor(kernel_size / 2))
        # A kernel size of 1 needs no padding, so no padding layer is created.
        if self.pad != 0:
            self.reflection_pad = nn.ReflectionPad2d([self.pad] * 4)

        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        resized = self.upsample_layer(x) if self.upsample else x
        padded = self.reflection_pad(resized) if self.pad != 0 else resized
        return self.conv2d(padded)
class Bottleneck(nn.Layer):
    """ Pre-activation residual block
    Identity Mapping in Deep Residual Networks
    ref https://arxiv.org/abs/1603.05027

    Args:
        inplanes(int): Number of input channels.
        planes(int): Number of channels inside the block; the output has planes * 4 channels.
        stride(int): Stride of the 3x3 convolution and of the shortcut convolution, default is 1.
        downsample(int): When not None, the shortcut goes through a 1x1 convolution
            instead of being the identity, default is None.
        norm_layer(nn.Layer): Normalisation layer class, default is nn.BatchNorm2d.

    Return:
        img(paddle.Tensor): Bottleneck output.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 downsample: int = None,
                 norm_layer: nn.Layer = nn.BatchNorm2d):
        super(Bottleneck, self).__init__()
        self.expansion = 4
        self.downsample = downsample
        if self.downsample is not None:
            # Projection shortcut so the residual matches the block output.
            self.residual_layer = nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1, stride=stride)

        # norm -> relu -> conv, three times (pre-activation ordering).
        conv_block = (
            norm_layer(inplanes),
            nn.ReLU(),
            nn.Conv2d(inplanes, planes, kernel_size=1, stride=1),
            norm_layer(planes),
            nn.ReLU(),
            ConvLayer(planes, planes, kernel_size=3, stride=stride),
            norm_layer(planes),
            nn.ReLU(),
            nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1),
        )
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x: paddle.Tensor):
        if self.downsample is not None:
            residual = self.residual_layer(x)
        else:
            residual = x
        # Run the convolution branch exactly once; it used to be evaluated
        # twice, with the first result discarded.
        return residual + self.conv_block(x)
class UpBottleneck(nn.Layer):
    """ Up-sample residual block (from MSG-Net paper)
    Enables passing identity all the way through the generator
    ref https://arxiv.org/abs/1703.06953

    Args:
        inplanes(int): Number of input channels.
        planes(int): Number of channels inside the block; the output has planes * 4 channels.
        stride(int): Upsampling factor used by both branches, default is 2.
        norm_layer(nn.Layer): Normalisation layer class, default is nn.BatchNorm2d.

    Return:
        img(paddle.Tensor): UpBottleneck output.
    """

    def __init__(self, inplanes: int, planes: int, stride: int = 2, norm_layer: nn.Layer = nn.BatchNorm2d):
        super(UpBottleneck, self).__init__()
        self.expansion = 4
        out_planes = planes * self.expansion

        # Shortcut branch: upsample, then a 1x1 convolution.
        self.residual_layer = UpsampleConvLayer(inplanes, out_planes, kernel_size=1, stride=1, upsample=stride)

        # Main branch: norm -> relu -> conv three times, upsampling in the middle stage.
        self.conv_block = nn.Sequential(
            norm_layer(inplanes),
            nn.ReLU(),
            nn.Conv2d(inplanes, planes, kernel_size=1, stride=1),
            norm_layer(planes),
            nn.ReLU(),
            UpsampleConvLayer(planes, planes, kernel_size=3, stride=1, upsample=stride),
            norm_layer(planes),
            nn.ReLU(),
            nn.Conv2d(planes, out_planes, kernel_size=1, stride=1),
        )

    def forward(self, x: paddle.Tensor):
        shortcut = self.residual_layer(x)
        return shortcut + self.conv_block(x)
class Inspiration(nn.Layer):
    """ Inspiration Layer (from MSG-Net paper)
    tuning the featuremap with target Gram Matrix
    ref https://arxiv.org/abs/1703.06953

    Args:
        C(int): Number of input channels.
        B(int): B is equal to 1 or input mini_batch, default is 1.

    Return:
        img(paddle.Tensor): Inspiration output, same shape as the input.
    """

    def __init__(self, C: int, B: int = 1):
        super(Inspiration, self).__init__()
        self.weight = paddle.create_parameter(shape=[1, C, C], dtype='float32')
        # Non-parameter buffer holding the target gram matrix. It starts out
        # random and is replaced through setTarget(). Created as float32 so
        # it has the same dtype as `weight` when they are multiplied.
        self.G = paddle.to_tensor(np.random.rand(B, C, C).astype('float32'))
        self.C = C

    def setTarget(self, target: paddle.Tensor):
        """Store the gram matrix that subsequent forward passes are tuned towards."""
        self.G = target

    def forward(self, X: paddle.Tensor):
        # P = weight @ G, with weight expanded to the shape of G.
        self.P = paddle.bmm(self.weight.expand_as(self.G), self.G)

        batch = X.shape[0]
        # Flatten everything after the channel axis so bmm can be applied,
        # then restore the original shape.
        flat = X.reshape((batch, X.shape[1], -1))
        tuned = paddle.bmm(self.P.transpose((0, 2, 1)).expand((batch, self.C, self.C)), flat)
        return tuned.reshape(X.shape)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'N x ' + str(self.C) + ')'
class Vgg16(nn.Layer):
    """ First four layers from Vgg16.

    All thirteen VGG16 convolutions are created so the pretrained weights can
    be loaded, but forward() only runs up to conv4_3.
    """

    def __init__(self):
        super(Vgg16, self).__init__()
        # (attribute name, input channels, output channels), in creation order.
        conv_specs = (
            ('conv1_1', 3, 64),
            ('conv1_2', 64, 64),
            ('conv2_1', 64, 128),
            ('conv2_2', 128, 128),
            ('conv3_1', 128, 256),
            ('conv3_2', 256, 256),
            ('conv3_3', 256, 256),
            ('conv4_1', 256, 512),
            ('conv4_2', 512, 512),
            ('conv4_3', 512, 512),
            ('conv5_1', 512, 512),
            ('conv5_2', 512, 512),
            ('conv5_3', 512, 512),
        )
        for attr_name, in_ch, out_ch in conv_specs:
            setattr(self, attr_name, nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1))

        # This layer is built directly rather than through hub.Module, so
        # nothing sets `directory` on it. Use it when present, otherwise
        # store the weights next to this source file.
        directory = getattr(self, 'directory', None) or os.path.dirname(os.path.abspath(__file__))
        checkpoint = os.path.join(directory, 'vgg16.pdparams')
        if not os.path.exists(checkpoint):
            os.system('wget https://bj.bcebos.com/paddlehub/model/image/image_editing/vgg_paddle.pdparams -O ' +
                      checkpoint)
        model_dict = paddle.load(checkpoint)[0]
        self.set_dict(model_dict)
        print("load pretrained vgg16 checkpoint success")

    def forward(self, X):
        """Return the activations [relu1_2, relu2_2, relu3_3, relu4_3] for input X."""
        h = F.relu(self.conv1_1(X))
        h = F.relu(self.conv1_2(h))
        relu1_2 = h

        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.relu(self.conv2_1(h))
        h = F.relu(self.conv2_2(h))
        relu2_2 = h

        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.relu(self.conv3_1(h))
        h = F.relu(self.conv3_2(h))
        h = F.relu(self.conv3_3(h))
        relu3_3 = h

        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.relu(self.conv4_1(h))
        h = F.relu(self.conv4_2(h))
        h = F.relu(self.conv4_3(h))
        relu4_3 = h

        return [relu1_2, relu2_2, relu3_3, relu4_3]
@moduleinfo(
    name="msgnet",
    type="CV/image_editing",
    author="paddlepaddle",
    author_email="",
    summary="Msgnet is a image colorization style transfer model, this module is trained with COCO2014 dataset.",
    version="1.0.0",
    meta=StyleTransferModule)
class MSGNet(nn.Layer):
    """ MSGNet (from MSG-Net paper)
    Enables passing identity all the way through the generator
    ref https://arxiv.org/abs/1703.06953

    Args:
        input_nc(int): Number of input channels, default is 3.
        output_nc(int): Number of output channels, default is 3.
        ngf(int): Number of input channel for middle layer, default is 128.
        n_blocks(int): Block number, default is 6.
        norm_layer(nn.Layer): Normalisation layer class, default is nn.InstanceNorm2d.
        load_checkpoint(str): Pretrained checkpoint path, default is None.

    Return:
        img(paddle.Tensor): MSGNet output.
    """

    def __init__(self,
                 input_nc=3,
                 output_nc=3,
                 ngf=128,
                 n_blocks=6,
                 norm_layer=nn.InstanceNorm2d,
                 load_checkpoint=None):
        super(MSGNet, self).__init__()
        self.gram = GramMatrix()

        block = Bottleneck
        upblock = UpBottleneck
        expansion = 4

        # Encoder. Its layer objects are reused as the head of the full
        # generator below, and on their own to compute style features.
        model1 = [
            ConvLayer(input_nc, 64, kernel_size=7, stride=1),
            norm_layer(64),
            nn.ReLU(),
            block(64, 32, 2, 1, norm_layer),
            block(32 * expansion, ngf, 2, 1, norm_layer)
        ]
        self.model1 = nn.Sequential(*tuple(model1))

        model = []
        model += model1

        # Inspiration layer, kept as an attribute so setTarget() can reach it.
        self.ins = Inspiration(ngf * expansion)
        model.append(self.ins)

        for i in range(n_blocks):
            model += [block(ngf * expansion, ngf, 1, None, norm_layer)]

        # Decoder.
        model += [
            upblock(ngf * expansion, 32, 2, norm_layer),
            upblock(32 * expansion, 16, 2, norm_layer),
            norm_layer(16 * expansion),
            nn.ReLU(),
            ConvLayer(16 * expansion, output_nc, kernel_size=7, stride=1)
        ]
        model = tuple(model)
        self.model = nn.Sequential(*model)

        if load_checkpoint is not None:
            model_dict = paddle.load(load_checkpoint)[0]
            self.set_dict(model_dict)
            print("load custom checkpoint success")
        else:
            checkpoint = os.path.join(self.directory, 'style_paddle.pdparams')
            if not os.path.exists(checkpoint):
                os.system('wget https://bj.bcebos.com/paddlehub/model/image/image_editing/style_paddle.pdparams -O ' +
                          checkpoint)
            model_dict = paddle.load(checkpoint)[0]
            # Iterate over a copy because the dict is modified in the loop.
            model_dict_clone = model_dict.copy()
            for key in model_dict_clone:
                # Every "*.scale" entry is reset to ones and its sibling
                # "*.bias" entry to zeros before the weights are applied.
                if key.endswith("scale"):
                    name = key.rsplit('.', 1)[0] + '.bias'
                    model_dict[name] = paddle.zeros(shape=model_dict[name].shape, dtype='float32')
                    model_dict[key] = paddle.ones(shape=model_dict[key].shape, dtype='float32')
            self.set_dict(model_dict)
            print("load pretrained checkpoint success")

        # The VGG feature extractor is created lazily in getFeature().
        self._vgg = None

    def transform(self, path: str):
        """Load the image at `path` and apply resize, centre crop and float32 cast."""
        # SetType has to live inside the transform list: given as a second
        # positional argument it binds to Compose's `to_rgb` flag and is
        # never applied to the image.
        transform = Compose([
            Resize((256, 256), interp='LINEAR'),
            CenterCrop(crop_size=256),
            SetType(datatype='float32'),
        ])
        return transform(path)

    def setTarget(self, Xs: paddle.Tensor):
        """Calculate feature gram matrix"""
        # Local name chosen so it does not shadow the module-level `F` import.
        features = self.model1(Xs)
        G = self.gram(features)
        self.ins.setTarget(G)

    def getFeature(self, input: paddle.Tensor):
        """Return the Vgg16 feature list for `input`, creating the network on first use."""
        if self._vgg is None:
            self._vgg = Vgg16()
        return self._vgg(input)

    def forward(self, input: paddle.Tensor):
        return self.model(input)
......@@ -36,5 +36,8 @@ class InstallCommand:
elif os.path.exists(_arg) and xarfile.is_xarfile(_arg):
manager.install(archive=_arg)
else:
manager.install(name=_arg)
_arg = _arg.split('==')
name = _arg[0]
version = None if len(_arg) == 1 else _arg[1]
manager.install(name=name, version=version)
return True
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
from typing import Any, List
from paddlehub.compat.module.module_v1 import ModuleV1
from paddlehub.commands import register
from paddlehub.module.manager import LocalModuleManager
from paddlehub.module.module import Module, InvalidHubModule
@register(name='hub.run', description='Run the specific module.')
class RunCommand:
    """Implementation of the `hub run` command."""

    def execute(self, argv: List) -> bool:
        """
        Load the module named by argv[0] and run it with the remaining arguments.

        Args:
            argv(List): Command line arguments; the first entry is a module
                name or a path to a module directory.

        Returns:
            bool: True when the module ran, False on any reported error.
        """
        if not argv:
            print('ERROR: You must give one module to run.')
            return False
        module_name = argv[0]

        # An existing directory is loaded from disk, anything else by name.
        if os.path.isdir(module_name):
            try:
                module = Module.load(module_name)
            except InvalidHubModule:
                print('{} is not a valid HubModule'.format(module_name))
                return False
            except Exception:
                # Narrowed from a bare `except:` so that KeyboardInterrupt
                # and SystemExit are no longer swallowed here.
                print('Some exception occurred while loading the {}'.format(module_name))
                return False
        else:
            module = Module(name=module_name)

        if not module.is_runnable:
            print('ERROR! Module {} is not executable.'.format(module_name))
            return False

        # V1 modules need their arguments parsed here; newer modules are
        # handed the raw arguments.
        if isinstance(module, ModuleV1):
            result = self.run_module_v1(module, argv[1:])
        else:
            result = module._run_func(argv[1:])

        print(result)
        return True

    def run_module_v1(self, module, argv: List) -> Any:
        """
        Parse `argv` for a ModuleV1 and call the module's default signature.

        Args:
            module: The loaded ModuleV1 instance.
            argv(List): Arguments following the module name.

        Returns:
            Any: Whatever the module call returns.
        """
        parser = argparse.ArgumentParser(prog='hub run {}'.format(module.name), add_help=False)

        arg_input_group = parser.add_argument_group(title='Input options', description='Data feed into the module.')
        arg_config_group = parser.add_argument_group(
            title='Config options', description='Run configuration for controlling module behavior, optional.')

        arg_config_group.add_argument(
            '--use_gpu', type=ast.literal_eval, default=False, help='whether use GPU for prediction')
        arg_config_group.add_argument('--batch_size', type=int, default=1, help='batch size for prediction')

        # Modules whose type starts with "cv" take a path, all others take text.
        module_type = module.type.lower()
        is_cv_module = module_type.startswith('cv')
        if is_cv_module:
            arg_input_group.add_argument(
                '--input_path', type=str, default=None, help='path of image/video to predict', required=True)
        else:
            arg_input_group.add_argument('--input_text', type=str, default=None, help='text to predict', required=True)

        args = parser.parse_args(argv)

        # The first key of the signature's data format is used as the feed name.
        expected_data_format = module.processor.data_format(module.default_signature)
        key = list(expected_data_format.keys())[0]
        input_data = {key: [args.input_path] if is_cv_module else [args.input_text]}

        return module(
            sign_name=module.default_signature, data=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size)
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from typing import List
from paddlehub.commands import register
from paddlehub.module.manager import LocalModuleManager
from paddlehub.server.server import module_server
from paddlehub.utils import log, platform
@register(name='hub.search', description='Search PaddleHub pretrained model through model keywords.')
class SearchCommand:
    """Implementation of the `hub search` command."""

    def execute(self, argv: List) -> bool:
        """
        Print a table of the modules whose name matches the given keyword.

        Args:
            argv(List): Command line arguments; the first entry is the search
                keyword. Without arguments the pattern '.*' is used.

        Returns:
            bool: Always True.
        """
        keyword = argv[0] if argv else '.*'

        # Narrower columns are used on Windows.
        if platform.is_windows():
            widths = [20, 8, 30]
        else:
            widths = [30, 8, 40]

        table = log.Table(widths=widths)
        header = ['ModuleName', 'Version', 'Summary']
        table.append(*header, aligns=['^', '^', '^'], colors=["blue", "blue", "blue"])

        for item in module_server.search_module(name=keyword):
            table.append(item['name'], item['version'], item['summary'])

        print(table)
        return True
......@@ -47,7 +47,7 @@ class ShowCommand:
widths = [15, 40] if platform.is_windows else [15, 50]
aligns = ['^', '<']
colors = ['yellow', '']
colors = ['cyan', '']
table = log.Table(widths=widths, colors=colors, aligns=aligns)
table.append('ModuleName', module.name)
......
......@@ -37,6 +37,7 @@ class ModuleV1(object):
self.desc = module_v1_utils.convert_module_desc(desc_file)
self.helper = self
self.signatures = self.desc.signatures
self.default_signature = self.desc.default_signature
self.directory = directory
self._load_model()
......@@ -185,6 +186,7 @@ class ModuleV1(object):
cls.type = module_info.type
cls.summary = module_info.summary
cls.version = utils.Version(module_info.version)
cls.directory = directory
return cls
@classmethod
......@@ -195,3 +197,7 @@ class ModuleV1(object):
def assets_path(self):
return os.path.join(self.directory, 'assets')
@property
def is_runnable(self):
return self.default_signature != None
......@@ -22,6 +22,7 @@ from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME
from typing import Callable
class Colorizedataset(paddle.io.Dataset):
"""
Dataset for colorization.
......@@ -39,8 +40,6 @@ class Colorizedataset(paddle.io.Dataset):
self.file = 'train'
elif self.mode == 'test':
self.file = 'test'
else:
self.file = 'validation'
self.file = os.path.join(DATA_HOME, 'canvas', self.file)
self.data = get_img_file(self.file)
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable
import paddle
from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME
class StyleTransferData(paddle.io.Dataset):
    """
    Dataset for Style transfer.

    Args:
       transform(callmethod) : The method of preprocess images.
       mode(str): The mode for preparing dataset, either 'train' or 'test'.

    Returns:
        DataSet: An iterable object for data iterating

    Raises:
        ValueError: If `mode` is neither 'train' nor 'test'.
    """

    def __init__(self, transform: Callable, mode: str = 'train'):
        # Fail early with a clear message; an unknown mode used to surface
        # as an AttributeError on `self.file` further down.
        if mode not in ('train', 'test'):
            raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))
        self.mode = mode
        self.transform = transform

        # Content images are read from <DATA_HOME>/minicoco/<mode>, style
        # images from <DATA_HOME>/minicoco/21styles.
        self.file = os.path.join(DATA_HOME, 'minicoco', mode)
        self.style_file = os.path.join(DATA_HOME, 'minicoco', '21styles')
        self.data = get_img_file(self.file)
        self.style = get_img_file(self.style_file)

    def __getitem__(self, idx: int):
        """Return the transformed (content image, style image) pair for `idx`."""
        im = self.transform(self.data[idx])
        # Style images are cycled through, so any index maps to one of them.
        style_path = self.style[idx % len(self.style)]
        style = self.transform(style_path)
        return im, style

    def __len__(self):
        return len(self.data)
......@@ -18,6 +18,7 @@ import os
from typing import List
from collections import OrderedDict
import cv2
import numpy as np
import paddle
import paddle.nn as nn
......@@ -137,6 +138,7 @@ class ImageColorizeModule(RunModule, ImageServing):
psnrs = []
lab2rgb = T.ConvertColorSpace(mode='LAB2RGB')
process = T.ColorPostprocess()
for i in range(batch[0].numpy().shape[0]):
real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i]
visual_ret['real'] = process(real)
......@@ -146,6 +148,7 @@ class ImageColorizeModule(RunModule, ImageServing):
psnr_value = 20 * np.log10(255. / np.sqrt(mse))
psnrs.append(psnr_value)
psnr = paddle.to_variable(np.array(psnrs))
return {'loss': loss, 'metrics': {'psnr': psnr}}
def predict(self, images: str, visualization: bool = True, save_path: str = 'result'):
......@@ -309,3 +312,87 @@ class Yolov3Module(RunModule, ImageServing):
Func.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5)
return boxes, scores, labels
class StyleTransferModule(RunModule, ImageServing):
    """Training, validation and prediction logic for style-transfer modules."""

    def training_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
        '''
        One step for training, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]): The one batch data, which contains content images and style images.
            batch_idx(int): The index of batch.

        Returns:
            results(dict) : The model outputs, such as loss and metrics.
        '''
        return self.validation_step(batch, batch_idx)

    def validation_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
        '''
        One step for validation, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]): The one batch data, which contains content images and style images.
            batch_idx(int): The index of batch.

        Returns:
            results(dict) : The model outputs, such as loss and metrics.
        '''
        mse_loss = nn.MSELoss()
        content = batch[0]
        N = content.shape[0]

        # Only the first style image of the batch is used as the target. It
        # is kept in a local so the caller's batch is left untouched.
        style = batch[1][0].unsqueeze(0)
        self.setTarget(style)

        y = self(content)
        xc = paddle.to_tensor(content.numpy().copy())
        y = Func.subtract_imagenet_mean_batch(y)
        xc = Func.subtract_imagenet_mean_batch(xc)
        features_y = self.getFeature(y)
        features_xc = self.getFeature(xc)

        # Content loss compares the second feature level; the content
        # features are rebuilt from numpy so they carry no gradient.
        f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True)
        content_loss = mse_loss(features_y[1], f_xc_c)

        style = Func.subtract_imagenet_mean_batch(style)
        features_style = self.getFeature(style)
        gram_style = [Func.gram_matrix(feature) for feature in features_style]

        style_loss = 0.
        for feature_y, gram in zip(features_y, gram_style):
            gram_y = Func.gram_matrix(feature_y)
            # Repeat the single style gram matrix along the batch axis so it
            # has the same 3-D shape as gram_y.
            gram_s = paddle.to_tensor(np.tile(gram.numpy(), (N, 1, 1)))
            style_loss += mse_loss(gram_y, gram_s)

        loss = content_loss + style_loss

        return {'loss': loss, 'metrics': {'content gap': content_loss, 'style gap': style_loss}}

    def predict(self, origin_path: str, style_path: str, visualization: bool = True, save_path: str = 'result'):
        '''
        Transfer the style of one image onto another.

        Args:
            origin_path(str): Content image path .
            style_path(str): Style image path.
            visualization(bool): Whether to save the stylised image.
            save_path(str) : Directory to save the stylised image in.

        Returns:
            output(np.ndarray) : The style transformed images with bgr mode.
        '''
        content = paddle.to_tensor(self.transform(origin_path)).unsqueeze(0)
        style = paddle.to_tensor(self.transform(style_path)).unsqueeze(0)

        self.setTarget(style)
        output = self(content)
        # Move channels last and clamp to the valid pixel range.
        output = paddle.clip(output[0].transpose((1, 2, 0)), 0, 255).numpy()

        if visualization:
            output = output.astype(np.uint8)
            style_name = "style_" + str(time.time()) + ".png"
            # makedirs also creates missing parent directories and does not
            # fail when the directory already exists.
            os.makedirs(save_path, exist_ok=True)
            path = os.path.join(save_path, style_name)
            cv2.imwrite(path, output)
        return output
......@@ -58,7 +58,7 @@ class HubModuleNotFoundError(Exception):
hub_version = 'Any' if not info['hub_version'] else ', '.join(info['hub_version'])
table.append(self.name, _ver, paddle_version, hub_version, aligns=['^', '^', '^', '^'])
tips += ', \n{}'.format(table)
tips += ':\n{}'.format(table)
return tips
......@@ -104,7 +104,7 @@ class EnvironmentMismatchError(Exception):
table.append(self.name, _ver, paddle_version, hub_version, aligns=['^', '^', '^', '^'])
tips += ', \n{}'.format(table)
tips += ':\n{}'.format(table)
return tips
......@@ -238,7 +238,12 @@ class LocalModuleManager(object):
return self._local_modules[name]
result = module_server.search_module(name=name, version=version, source=source)
if not result:
for item in result:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
if source or 'source' in item:
return self._install_from_source(result)
return self._install_from_url(item['url'])
module_infos = module_server.get_module_info(name=name, source=source)
# The HubModule with the specified name cannot be found
if not module_infos:
......@@ -250,17 +255,13 @@ class LocalModuleManager(object):
if utils.Version(_ver).match(version):
valid_infos[_ver] = _info
else:
valid_infos = list(module_infos.keys())
valid_infos = module_infos.copy()
# Cannot find a HubModule that meets the version
if valid_infos:
raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source)
if source or 'source' in result:
return self._install_from_source(result)
return self._install_from_url(result['url'])
def _install_from_source(self, source: str) -> HubModule:
'''Install a HubModule from Git Repo'''
name = source['name']
......
......@@ -148,6 +148,10 @@ class Module(object):
user_module = user_module_cls(directory=directory)
user_module._initialize(**kwargs)
return user_module
if user_module_cls == ModuleV1:
return user_module_cls(directory=directory, **kwargs)
user_module_cls.directory = directory
return user_module_cls(**kwargs)
......@@ -166,6 +170,10 @@ class Module(object):
user_module = user_module_cls(directory=directory)
user_module._initialize(**kwargs)
return user_module
if user_module_cls == ModuleV1:
return user_module_cls(directory=directory, **kwargs)
user_module_cls.directory = directory
return user_module_cls(**kwargs)
......
......@@ -119,7 +119,6 @@ def get_img_file(dir_name: str) -> list:
if not is_image_file(filename):
continue
img_path = os.path.join(parent, filename)
print(img_path)
images.append(img_path)
images.sort()
return images
......@@ -246,3 +245,22 @@ def get_label_infos(file_list: str):
for category in categories:
label_names.append(category['name'])
return label_names
def subtract_imagenet_mean_batch(batch: paddle.Tensor) -> paddle.Tensor:
    """Subtract ImageNet mean pixel-wise from a BGR image."""
    # One mean value per channel, for channels 0, 1 and 2 of axis 1.
    channel_means = (103.939, 116.779, 123.680)
    mean = np.zeros(shape=batch.shape, dtype='float32')
    for channel, value in enumerate(channel_means):
        mean[:, channel, :, :] = value
    return batch - paddle.to_tensor(mean)
def gram_matrix(data: paddle.Tensor) -> paddle.Tensor:
    """Get gram matrix"""
    batch_size, channels, height, width = data.shape
    # Flatten the two trailing axes, multiply by the transpose, then
    # normalise by the number of elements per sample.
    flat = data.reshape((batch_size, channels, width * height))
    flat_t = flat.transpose((0, 2, 1))
    return flat.bmm(flat_t) / (channels * height * width)
......@@ -26,7 +26,7 @@ from paddlehub.process.functional import *
class Compose:
def __init__(self, transforms, to_rgb=True, stay_rgb=False):
def __init__(self, transforms, to_rgb=True, stay_rgb=False, is_permute=True):
if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!')
if len(transforms) < 1:
......@@ -35,6 +35,7 @@ class Compose:
self.transforms = transforms
self.to_rgb = to_rgb
self.stay_rgb = stay_rgb
self.is_permute = is_permute
def __call__(self, im):
if isinstance(im, str):
......@@ -49,6 +50,9 @@ class Compose:
im = op(im)
if not self.stay_rgb:
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
if self.is_permute:
im = permute(im)
return im
......@@ -570,17 +574,7 @@ class ColorizeHint:
self.use_avg = use_avg
def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
sample_Ps = [
1,
2,
3,
4,
5,
6,
7,
8,
9,
]
sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9]
self.data = data
self.hint = hint
self.mask = mask
......@@ -591,7 +585,7 @@ class ColorizeHint:
while cont_cond:
if self.num_points is None: # draw from geometric
# embed()
cont_cond = np.random.rand() < (1 - self.percent)
cont_cond = np.random.rand() > (1 - self.percent)
else: # add certain number of points
cont_cond = pp < self.num_points
if not cont_cond: # skip out of loop if condition not met
......@@ -659,7 +653,7 @@ class ColorizePreprocess:
"""
def __init__(self,
ab_thresh: float = 0.,
p: float = .125,
p: float = 0.,
num_points: int = None,
samp: str = 'normal',
use_avg: bool = True,
......@@ -733,3 +727,41 @@ class ColorPostprocess:
img = np.clip(img, 0, 1) * 255
img = img.astype(self.type)
return img
class CenterCrop:
    """
    Crop the middle part of the image to the specified size.

    The crop is taken along the first two axes of the array; any trailing
    axes (e.g. channels) are kept untouched, so both 2-D and 3-D images are
    accepted.

    Args:
        crop_size(int): Side length of the square crop.

    Return:
        img(np.ndarray): Cropped image. If ``crop_size`` exceeds the image
            along an axis, that axis is returned in full.
    """

    def __init__(self, crop_size: int):
        self.crop_size = crop_size

    def __call__(self, img: np.ndarray):
        # Only the two leading axes are needed, which also lets grayscale
        # (2-D) arrays through instead of failing on a 3-way unpack.
        rows, cols = img.shape[:2]
        # Clamp at 0: a negative start would wrap around and silently return
        # a sliver from the end of the axis when the image is too small.
        row_start = max((rows - self.crop_size) // 2, 0)
        col_start = max((cols - self.crop_size) // 2, 0)
        return img[row_start:row_start + self.crop_size, col_start:col_start + self.crop_size]
class SetType:
    """
    Cast an image array to a given data type.

    Args:
        datatype(type): Target dtype handed to ``np.ndarray.astype``.
            Defaults to 'float32'.

    Return:
        img(np.ndarray): Converted image.
    """

    def __init__(self, datatype: type = 'float32'):
        self.type = datatype

    def __call__(self, img: np.ndarray):
        return img.astype(self.type)
......@@ -18,6 +18,7 @@ import importlib
import os
import sys
from collections import OrderedDict
from typing import List
from urllib.parse import urlparse
import git
......@@ -74,7 +75,7 @@ class GitSource(object):
log.logger.warning('An error occurred while loading {}'.format(self.path))
sys.path.remove(self.path)
def search_module(self, name: str, version: str = None) -> dict:
def search_module(self, name: str, version: str = None) -> List[dict]:
'''
Search PaddleHub module
......@@ -84,7 +85,7 @@ class GitSource(object):
'''
return self.search_resource(type='module', name=name, version=version)
def search_resource(self, type: str, name: str, version: str = None) -> dict:
def search_resource(self, type: str, name: str, version: str = None) -> List[dict]:
'''
Search PaddleHub Resource
......@@ -95,13 +96,13 @@ class GitSource(object):
'''
module = self.hub_modules.get(name, None)
if module and module.version.match(version):
return {
return [{
'version': module.version,
'name': module.name,
'path': self.path,
'class': module.__name__,
'source': self.url
}
}]
return None
@classmethod
......
......@@ -14,6 +14,7 @@
# limitations under the License.
from collections import OrderedDict
from typing import List
from paddlehub.server import ServerSource, GitSource
......@@ -44,7 +45,7 @@ class HubServer(object):
'''Remove a module source'''
self.sources.pop(key)
def search_module(self, name: str, version: str = None, source: str = None) -> dict:
def search_module(self, name: str, version: str = None, source: str = None) -> List[dict]:
'''
Search PaddleHub module
......@@ -54,7 +55,7 @@ class HubServer(object):
'''
return self.search_resource(type='module', name=name, version=version, source=source)
def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> dict:
def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> List[dict]:
'''
Search PaddleHub Resource
......@@ -68,7 +69,7 @@ class HubServer(object):
result = source.search_resource(name=name, type=type, version=version)
if result:
return result
return {}
return []
def get_module_info(self, name: str, source: str = None) -> dict:
'''
......
......@@ -43,7 +43,7 @@ class ServerSource(object):
self._url = url
self._timeout = timeout
def search_module(self, name: str, version: str = None) -> dict:
def search_module(self, name: str, version: str = None) -> List[dict]:
'''
Search PaddleHub module
......@@ -53,7 +53,7 @@ class ServerSource(object):
'''
return self.search_resource(type='module', name=name, version=version)
def search_resource(self, type: str, name: str, version: str = None) -> dict:
def search_resource(self, type: str, name: str, version: str = None) -> List[dict]:
'''
Search PaddleHub Resource
......@@ -76,9 +76,7 @@ class ServerSource(object):
result = self.request(path='search', params=params)
if result['status'] == 0 and len(result['data']) > 0:
for item in result['data']:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return item
return result['data']
return None
def get_module_info(self, name: str) -> dict:
......
......@@ -41,11 +41,17 @@ def download(name: str, save_path: str, version: str = None):
if os.path.exists(file):
return
resource = module_server.search_resouce(name=name, version=version, type='Model')
if not resource:
resources = module_server.search_resouce(name=name, version=version, type='Model')
if not resources:
raise ResourceNotFoundError(name, version)
for item in resources:
if item['name'] == name and utils.Version(item['version']).match(version):
url = item['url']
break
else:
raise ResourceNotFoundError(name, version)
url = resource['url']
with utils.generate_tempdir() as _dir:
if not os.path.exists(save_path):
os.makedirs(save_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册