Commit 722c2d58 authored by H haoyuying

revise conflict

@@ -2,9 +2,8 @@ import paddle
import paddlehub as hub
import paddle.nn as nn
if __name__ == '__main__':
    paddle.disable_static()
-    model = hub.Module(directory='user_guided_colorization')
+    model = hub.Module(name='user_guided_colorization')
    model.eval()
-    result = model.predict(images='sea.jpg')
+    result = model.predict(images='house.png')
\ No newline at end of file
@@ -3,18 +3,24 @@ import paddlehub as hub
import paddle.nn as nn
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.colorizedataset import Colorizedataset
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
if __name__ == '__main__':
    is_train = True
    paddle.disable_static()
-    model = hub.Module(directory='user_guided_colorization')
+    model = hub.Module(name='user_guided_colorization')
-    transform = Compose([Resize((256,256),interp="RANDOM"),RandomPaddingCrop(crop_size=176), ConvertColorSpace(mode='RGB2LAB'), ColorizePreprocess(ab_thresh=0, p=1)], stay_rgb=True)
-    color_set = Colorizedataset(transform=transform, mode=is_train)
+    transform = Compose([
+        Resize((256, 256), interp='NEAREST'),
+        RandomPaddingCrop(crop_size=176),
+        ConvertColorSpace(mode='RGB2LAB'),
+        ColorizePreprocess(ab_thresh=0, is_train=is_train),
+    ],
+        stay_rgb=True,
+        is_permute=False)
+    color_set = Colorizedataset(transform=transform, mode='train')
    if is_train:
        model.train()
        optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
        trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
-        trainer.train(color_set, epochs=3, batch_size=1, eval_dataset=color_set, save_interval=1)
+        trainer.train(color_set, epochs=101, batch_size=5, eval_dataset=color_set, log_interval=10, save_interval=10)
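Once finetuning finishes, the saved weights can be loaded back for prediction. The snippet below is a minimal sketch rather than part of this commit: it assumes `hub.Module` forwards the `load_checkpoint` keyword to the module's constructor (which, as shown further down, loads a custom `.pdparams` file when one is given), and the checkpoint file name under `test_ckpt_img_cls` is hypothetical.

import paddle
import paddlehub as hub

if __name__ == '__main__':
    paddle.disable_static()
    # Hypothetical path: use whatever .pdparams file the Trainer actually wrote
    # under checkpoint_dir='test_ckpt_img_cls'.
    ckpt = 'test_ckpt_img_cls/epoch_100/model.pdparams'
    model = hub.Module(name='user_guided_colorization', load_checkpoint=ckpt)
    model.eval()
    result = model.predict(images='house.png')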
-import paddle.fluid as fluid
+import paddle
import paddlehub as hub
-from paddle.fluid.dygraph.parallel import ParallelEnv
+from paddle.distributed import ParallelEnv
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.flowers import Flowers
from paddlehub.process.transforms import Compose, Resize, Normalize
from paddlehub.module.cv_module import ImageClassifierModule
if __name__ == '__main__':
-    with fluid.dygraph.guard(fluid.CUDAPlace(ParallelEnv().dev_id)):
+    paddle.disable_static(paddle.CUDAPlace(ParallelEnv().dev_id))
    transforms = Compose([Resize((224, 224)), Normalize()])
    flowers = Flowers(transforms)
    flowers_validate = Flowers(transforms, mode='val')
-    model = hub.Module(directory='mobilenet_v2_animals', class_dim=flowers.num_classes)
+    model = hub.Module(name='mobilenet_v2_imagenet', class_dim=flowers.num_classes)
+    # model = hub.Module(name='mobilenet_v2_animals', class_dim=flowers.num_classes)
-    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=0.001, parameter_list=model.parameters())
+    optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
    trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
    trainer.train(flowers, epochs=100, batch_size=32, eval_dataset=flowers_validate, save_interval=1)
import paddle
import paddlehub as hub

if __name__ == '__main__':
    place = paddle.CUDAPlace(0)
    paddle.disable_static()
    model = hub.Module(name='msgnet')
    model.eval()
    result = model.predict("venice-boat.jpg", "candy.jpg")
import paddle
import paddlehub as hub
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.styletransfer import StyleTransferData
from paddlehub.process.transforms import Compose, Resize, CenterCrop, SetType

if __name__ == "__main__":
    place = paddle.CUDAPlace(0)
    paddle.disable_static()
    model = hub.Module(name='msgnet')
    transform = Compose([Resize((256, 256), interp='LINEAR'), CenterCrop(crop_size=256)], SetType(datatype='float32'))
    styledata = StyleTransferData(transform)
    model.train()
    optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
    trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
    trainer.train(styledata, epochs=5, batch_size=1, eval_dataset=styledata, log_interval=1, save_interval=1)
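For reference, `msgnet.predict` returns the stylized image as a NumPy array in BGR order and, when `visualization=True` (the default), also writes it under `save_path` (see `StyleTransferModule.predict` further down). A minimal sketch of consuming the return value directly, assuming OpenCV is installed; the output file name is arbitrary:

import cv2
import paddle
import paddlehub as hub

if __name__ == '__main__':
    paddle.disable_static()
    model = hub.Module(name='msgnet')
    model.eval()
    # Returns the stylized image (BGR, uint8); it is also saved under 'result/' by default.
    output = model.predict("venice-boat.jpg", "candy.jpg")
    cv2.imwrite('stylized.jpg', output)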
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import paddle
import numpy
import paddle.nn as nn
from paddle.nn import Conv2d, ConvTranspose2d
from paddlehub.module.module import moduleinfo
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
from paddlehub.module.cv_module import ImageColorizeModule
@@ -178,24 +179,31 @@ class UserGuidedColorization(nn.Layer):
        if load_checkpoint is not None:
            model_dict = paddle.load(load_checkpoint)[0]
            self.set_dict(model_dict)
-            print("load pretrained model success")
+            print("load custom checkpoint success")
+        else:
+            checkpoint = os.path.join(self.directory, 'user_guided.pdparams')
+            model_dict = paddle.load(checkpoint)[0]
+            self.set_dict(model_dict)
+            print("load pretrained checkpoint success")

    def transforms(self, images: str, is_train: bool = True) -> callable:
        if is_train:
            transform = Compose([
-                Resize((256, 256), interp="RANDOM"),
+                Resize((256, 256), interp='NEAREST'),
                RandomPaddingCrop(crop_size=176),
                ConvertColorSpace(mode='RGB2LAB'),
                ColorizePreprocess(ab_thresh=0, is_train=is_train)
            ],
-                stay_rgb=True)
+                stay_rgb=True,
+                is_permute=False)
        else:
            transform = Compose([
-                Resize((256, 256), interp="RANDOM"),
+                Resize((256, 256), interp='NEAREST'),
                ConvertColorSpace(mode='RGB2LAB'),
                ColorizePreprocess(ab_thresh=0, is_train=is_train)
            ],
-                stay_rgb=True)
+                stay_rgb=True,
+                is_permute=False)
        return transform(images)

    def forward(self,
...
import os
import paddle
import paddle.nn as nn
import numpy as np
import paddle.nn.functional as F
from paddlehub.module.module import moduleinfo
from paddlehub.process.transforms import Compose, Resize, CenterCrop, SetType
from paddlehub.module.cv_module import StyleTransferModule
class GramMatrix(nn.Layer):
"""Calculate gram matrix"""
def forward(self, y):
(b, ch, h, w) = y.size()
features = y.reshape((b, ch, w * h))
features_t = features.transpose((0, 2, 1))
gram = features.bmm(features_t) / (ch * h * w)
return gram
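For reference, the Gram matrix computed by this forward pass is the standard second-order style statistic: with the feature map reshaped to F of shape (B, C, H*W), the code evaluates

    G = \frac{F F^{\top}}{C \cdot H \cdot W}, \qquad G_{b,i,j} = \frac{1}{C H W} \sum_{k=1}^{HW} F_{b,i,k} F_{b,j,k}

so each entry is the normalized inner product between two channels of the same feature map.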
class ConvLayer(nn.Layer):
"""Basic conv layer with reflection padding layer"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int):
super(ConvLayer, self).__init__()
pad = int(np.floor(kernel_size / 2))
self.reflection_pad = nn.ReflectionPad2d([pad, pad, pad, pad])
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x: paddle.Tensor):
out = self.reflection_pad(x)
out = self.conv2d(out)
return out
class UpsampleConvLayer(nn.Layer):
"""
Upsamples the input and then does a convolution. This method gives better results compared to ConvTranspose2d.
ref: http://distill.pub/2016/deconv-checkerboard/
Args:
in_channels(int): Number of input channels.
out_channels(int): Number of output channels.
kernel_size(int): Size of the convolution kernel.
stride(int): Stride of the convolution.
upsample(int): Scale factor for upsample layer, default is None.
Return:
img(paddle.Tensor): UpsampleConvLayer output.
"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, upsample=None):
super(UpsampleConvLayer, self).__init__()
self.upsample = upsample
if upsample:
self.upsample_layer = nn.UpSample(scale_factor=upsample)
self.pad = int(np.floor(kernel_size / 2))
if self.pad != 0:
self.reflection_pad = nn.ReflectionPad2d([self.pad, self.pad, self.pad, self.pad])
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x):
if self.upsample:
x = self.upsample_layer(x)
if self.pad != 0:
x = self.reflection_pad(x)
out = self.conv2d(x)
return out
class Bottleneck(nn.Layer):
""" Pre-activation residual block
Identity Mapping in Deep Residual Networks
ref https://arxiv.org/abs/1603.05027
Args:
inplanes(int): Number of input channels.
planes(int): Number of output channels.
stride(int): Stride of the convolution, default is 1.
downsample(int): Scale factor for the downsample layer, default is None.
norm_layer(nn.Layer): Batch norm layer, default is nn.BatchNorm2d.
Return:
img(paddle.Tensor): Bottleneck output.
"""
def __init__(self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: int = None,
norm_layer: nn.Layer = nn.BatchNorm2d):
super(Bottleneck, self).__init__()
self.expansion = 4
self.downsample = downsample
if self.downsample is not None:
self.residual_layer = nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1, stride=stride)
conv_block = (norm_layer(inplanes), nn.ReLU(), nn.Conv2d(inplanes, planes, kernel_size=1, stride=1),
norm_layer(planes), nn.ReLU(), ConvLayer(planes, planes, kernel_size=3, stride=stride),
norm_layer(planes), nn.ReLU(), nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
stride=1))
self.conv_block = nn.Sequential(*conv_block)
def forward(self, x: paddle.Tensor):
if self.downsample is not None:
residual = self.residual_layer(x)
else:
residual = x
out = self.conv_block(x)
return residual + out
class UpBottleneck(nn.Layer):
""" Up-sample residual block (from MSG-Net paper)
Enables passing identity all the way through the generator
ref https://arxiv.org/abs/1703.06953
Args:
inplanes(int): Number of input channels.
planes(int): Number of output channels.
stride(int): Stride of the convolution, default is 2.
norm_layer(nn.Layer): Batch norm layer, default is nn.BatchNorm2d.
Return:
img(paddle.Tensor): UpBottleneck output.
"""
def __init__(self, inplanes: int, planes: int, stride: int = 2, norm_layer: nn.Layer = nn.BatchNorm2d):
super(UpBottleneck, self).__init__()
self.expansion = 4
self.residual_layer = UpsampleConvLayer(inplanes,
planes * self.expansion,
kernel_size=1,
stride=1,
upsample=stride)
conv_block = []
conv_block += [norm_layer(inplanes), nn.ReLU(), nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)]
conv_block += [
norm_layer(planes),
nn.ReLU(),
UpsampleConvLayer(planes, planes, kernel_size=3, stride=1, upsample=stride)
]
conv_block += [
norm_layer(planes),
nn.ReLU(),
nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1)
]
self.conv_block = nn.Sequential(*conv_block)
def forward(self, x: paddle.Tensor):
return self.residual_layer(x) + self.conv_block(x)
class Inspiration(nn.Layer):
""" Inspiration Layer (from MSG-Net paper)
tuning the featuremap with target Gram Matrix
ref https://arxiv.org/abs/1703.06953
Args:
C(int): Number of input channels.
B(int): Batch size; 1 or the input mini-batch size, default is 1.
Return:
img(paddle.Tensor): Inspiration output.
"""
def __init__(self, C: int, B: int = 1):
super(Inspiration, self).__init__()
self.weight = paddle.create_parameter(shape=[1, C, C], dtype='float32')
# non-parameter buffer
self.G = paddle.to_tensor(np.random.rand(B, C, C))
self.C = C
def setTarget(self, target: paddle.Tensor):
self.G = target
def forward(self, X: paddle.Tensor):
# input X is a 3D feature map
self.P = paddle.bmm(self.weight.expand_as(self.G), self.G)
x = paddle.bmm(
self.P.transpose((0, 2, 1)).expand((X.shape[0], self.C, self.C)), X.reshape(
(X.shape[0], X.shape[1], -1))).reshape(X.shape)
return x
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'N x ' + str(self.C) + ')'
class Vgg16(nn.Layer):
""" First four layers from Vgg16."""
def __init__(self):
super(Vgg16, self).__init__()
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
checkpoint = os.path.join(self.directory, 'vgg16.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://bj.bcebos.com/paddlehub/model/image/image_editing/vgg_paddle.pdparams -O ' +
checkpoint)
model_dict = paddle.load(checkpoint)[0]
self.set_dict(model_dict)
print("load pretrained vgg16 checkpoint success")
def forward(self, X):
h = F.relu(self.conv1_1(X))
h = F.relu(self.conv1_2(h))
relu1_2 = h
h = F.max_pool2d(h, kernel_size=2, stride=2)
h = F.relu(self.conv2_1(h))
h = F.relu(self.conv2_2(h))
relu2_2 = h
h = F.max_pool2d(h, kernel_size=2, stride=2)
h = F.relu(self.conv3_1(h))
h = F.relu(self.conv3_2(h))
h = F.relu(self.conv3_3(h))
relu3_3 = h
h = F.max_pool2d(h, kernel_size=2, stride=2)
h = F.relu(self.conv4_1(h))
h = F.relu(self.conv4_2(h))
h = F.relu(self.conv4_3(h))
relu4_3 = h
return [relu1_2, relu2_2, relu3_3, relu4_3]
@moduleinfo(
name="msgnet",
type="CV/image_editing",
author="paddlepaddle",
author_email="",
summary="Msgnet is a image colorization style transfer model, this module is trained with COCO2014 dataset.",
version="1.0.0",
meta=StyleTransferModule)
class MSGNet(nn.Layer):
""" MSGNet (from MSG-Net paper)
Enables passing identity all the way through the generator
ref https://arxiv.org/abs/1703.06953
Args:
input_nc(int): Number of input channels, default is 3.
output_nc(int): Number of output channels, default is 3.
ngf(int): Number of input channel for middle layer, default is 128.
n_blocks(int): Block number, default is 6.
norm_layer(nn.Layer): Batch norm layer, default is nn.InstanceNorm2d.
load_checkpoint(str): Pretrained checkpoint path, default is None.
Return:
img(paddle.Tensor): MSGNet output.
"""
def __init__(self,
input_nc=3,
output_nc=3,
ngf=128,
n_blocks=6,
norm_layer=nn.InstanceNorm2d,
load_checkpoint=None):
super(MSGNet, self).__init__()
self.gram = GramMatrix()
block = Bottleneck
upblock = UpBottleneck
expansion = 4
model1 = [
ConvLayer(input_nc, 64, kernel_size=7, stride=1),
norm_layer(64),
nn.ReLU(),
block(64, 32, 2, 1, norm_layer),
block(32 * expansion, ngf, 2, 1, norm_layer)
]
self.model1 = nn.Sequential(*tuple(model1))
model = []
model += model1
self.ins = Inspiration(ngf * expansion)
model.append(self.ins)
for i in range(n_blocks):
model += [block(ngf * expansion, ngf, 1, None, norm_layer)]
model += [
upblock(ngf * expansion, 32, 2, norm_layer),
upblock(32 * expansion, 16, 2, norm_layer),
norm_layer(16 * expansion),
nn.ReLU(),
ConvLayer(16 * expansion, output_nc, kernel_size=7, stride=1)
]
model = tuple(model)
self.model = nn.Sequential(*model)
if load_checkpoint is not None:
model_dict = paddle.load(load_checkpoint)[0]
self.set_dict(model_dict)
print("load custom checkpoint success")
else:
checkpoint = os.path.join(self.directory, 'style_paddle.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://bj.bcebos.com/paddlehub/model/image/image_editing/style_paddle.pdparams -O ' +
checkpoint)
model_dict = paddle.load(checkpoint)[0]
model_dict_clone = model_dict.copy()
for key, value in model_dict_clone.items():
if key.endswith(("scale")):
name = key.rsplit('.', 1)[0] + '.bias'
model_dict[name] = paddle.zeros(shape=model_dict[name].shape, dtype='float32')
model_dict[key] = paddle.ones(shape=model_dict[key].shape, dtype='float32')
self.set_dict(model_dict)
print("load pretrained checkpoint success")
self._vgg = None
def transform(self, path: str):
transform = Compose([Resize(
(256, 256), interp='LINEAR'), CenterCrop(crop_size=256)], SetType(datatype='float32'))
return transform(path)
def setTarget(self, Xs: paddle.Tensor):
"""Calculate feature gram matrix"""
F = self.model1(Xs)
G = self.gram(F)
self.ins.setTarget(G)
def getFeature(self, input: paddle.Tensor):
if not self._vgg:
self._vgg = Vgg16()
return self._vgg(input)
def forward(self, input: paddle.Tensor):
return self.model(input)
@@ -36,5 +36,8 @@ class InstallCommand:
            elif os.path.exists(_arg) and xarfile.is_xarfile(_arg):
                manager.install(archive=_arg)
            else:
-                manager.install(name=_arg)
+                _arg = _arg.split('==')
+                name = _arg[0]
+                version = None if len(_arg) == 1 else _arg[1]
+                manager.install(name=name, version=version)
        return True
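With this change, `hub install` also accepts a pinned version in the familiar name==version form: everything after '==' is forwarded to the manager as `version`, and a bare name leaves `version` as None. A tiny standalone illustration of the parsing above (the module name and version are only examples):

for arg in ['msgnet', 'msgnet==1.0.0']:
    parts = arg.split('==')
    name = parts[0]
    version = None if len(parts) == 1 else parts[1]
    print(name, version)  # prints "msgnet None", then "msgnet 1.0.0"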
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
from typing import Any, List
from paddlehub.compat.module.module_v1 import ModuleV1
from paddlehub.commands import register
from paddlehub.module.manager import LocalModuleManager
from paddlehub.module.module import Module, InvalidHubModule
@register(name='hub.run', description='Run the specific module.')
class RunCommand:
def execute(self, argv: List) -> bool:
if not argv:
print('ERROR: You must give one module to run.')
return False
module_name = argv[0]
if os.path.exists(module_name) and os.path.isdir(module_name):
try:
module = Module.load(module_name)
except InvalidHubModule:
print('{} is not a valid HubModule'.format(module_name))
return False
except:
print('Some exception occurred while loading the {}'.format(module_name))
return False
else:
module = Module(name=module_name)
if not module.is_runnable:
print('ERROR! Module {} is not executable.'.format(module_name))
return False
if isinstance(module, ModuleV1):
result = self.run_module_v1(module, argv[1:])
else:
result = module._run_func(argv[1:])
print(result)
return True
def run_module_v1(self, module, argv: List) -> Any:
parser = argparse.ArgumentParser(prog='hub run {}'.format(module.name), add_help=False)
arg_input_group = parser.add_argument_group(title='Input options', description='Data feed into the module.')
arg_config_group = parser.add_argument_group(
title='Config options', description='Run configuration for controlling module behavior, optional.')
arg_config_group.add_argument(
'--use_gpu', type=ast.literal_eval, default=False, help='whether use GPU for prediction')
arg_config_group.add_argument('--batch_size', type=int, default=1, help='batch size for prediction')
module_type = module.type.lower()
if module_type.startswith('cv'):
arg_input_group.add_argument(
'--input_path', type=str, default=None, help='path of image/video to predict', required=True)
else:
arg_input_group.add_argument('--input_text', type=str, default=None, help='text to predict', required=True)
args = parser.parse_args(argv)
except_data_format = module.processor.data_format(module.default_signature)
key = list(except_data_format.keys())[0]
input_data = {key: [args.input_path] if module_type.startswith('cv') else [args.input_text]}
return module(
sign_name=module.default_signature, data=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size)
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from typing import List
from paddlehub.commands import register
from paddlehub.module.manager import LocalModuleManager
from paddlehub.server.server import module_server
from paddlehub.utils import log, platform
@register(name='hub.search', description='Search PaddleHub pretrained model through model keywords.')
class SearchCommand:
def execute(self, argv: List) -> bool:
argv = '.*' if not argv else argv[0]
widths = [20, 8, 30] if platform.is_windows() else [30, 8, 40]
table = log.Table(widths=widths)
table.append(*['ModuleName', 'Version', 'Summary'], aligns=['^', '^', '^'], colors=["blue", "blue", "blue"])
results = module_server.search_module(name=argv)
for result in results:
table.append(result['name'], result['version'], result['summary'])
print(table)
return True
@@ -47,7 +47,7 @@ class ShowCommand:
        widths = [15, 40] if platform.is_windows else [15, 50]
        aligns = ['^', '<']
-        colors = ['yellow', '']
+        colors = ['cyan', '']
        table = log.Table(widths=widths, colors=colors, aligns=aligns)
        table.append('ModuleName', module.name)
...
@@ -37,6 +37,7 @@ class ModuleV1(object):
        self.desc = module_v1_utils.convert_module_desc(desc_file)
        self.helper = self
        self.signatures = self.desc.signatures
+        self.default_signature = self.desc.default_signature
        self.directory = directory
        self._load_model()
@@ -185,6 +186,7 @@ class ModuleV1(object):
        cls.type = module_info.type
        cls.summary = module_info.summary
        cls.version = utils.Version(module_info.version)
+        cls.directory = directory
        return cls

    @classmethod
@@ -195,3 +197,7 @@ class ModuleV1(object):
    def assets_path(self):
        return os.path.join(self.directory, 'assets')

+    @property
+    def is_runnable(self):
+        return self.default_signature != None
@@ -22,6 +22,7 @@ from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME
from typing import Callable

class Colorizedataset(paddle.io.Dataset):
    """
    Dataset for colorization.
@@ -34,14 +35,12 @@ class Colorizedataset(paddle.io.Dataset):
    def __init__(self, transform: Callable, mode: str = 'train'):
        self.mode = mode
        self.transform = transform
        if self.mode == 'train':
            self.file = 'train'
        elif self.mode == 'test':
            self.file = 'test'
-        else:
-            self.file = 'validation'
        self.file = os.path.join(DATA_HOME, 'canvas', self.file)
        self.data = get_img_file(self.file)
@@ -51,4 +50,4 @@ class Colorizedataset(paddle.io.Dataset):
        return im['A'], im['hint_B'], im['mask_B'], im['B'], im['real_B_enc']

    def __len__(self):
        return len(self.data)
\ No newline at end of file
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable
import paddle
from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME
class StyleTransferData(paddle.io.Dataset):
"""
Dataset for Style transfer.
Args:
transform(Callable): The method used to preprocess images.
mode(str): The mode for preparing the dataset.
Returns:
DataSet: An iterable object for iterating over the data.
"""
def __init__(self, transform: Callable, mode: str = 'train'):
self.mode = mode
self.transform = transform
if self.mode == 'train':
self.file = 'train'
elif self.mode == 'test':
self.file = 'test'
self.file = os.path.join(DATA_HOME, 'minicoco', self.file)
self.style_file = os.path.join(DATA_HOME, 'minicoco', '21styles')
self.data = get_img_file(self.file)
self.style = get_img_file(self.style_file)
def __getitem__(self, idx: int):
img_path = self.data[idx]
im = self.transform(img_path)
style_idx = idx % len(self.style)
style_path = self.style[style_idx]
style = self.transform(style_path)
return im, style
def __len__(self):
return len(self.data)
@@ -18,6 +18,7 @@ import os
from typing import List
from collections import OrderedDict
+import cv2
import numpy as np
import paddle
import paddle.nn as nn
@@ -137,6 +138,7 @@ class ImageColorizeModule(RunModule, ImageServing):
        psnrs = []
        lab2rgb = T.ConvertColorSpace(mode='LAB2RGB')
        process = T.ColorPostprocess()
        for i in range(batch[0].numpy().shape[0]):
            real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i]
            visual_ret['real'] = process(real)
@@ -146,6 +148,7 @@ class ImageColorizeModule(RunModule, ImageServing):
            psnr_value = 20 * np.log10(255. / np.sqrt(mse))
            psnrs.append(psnr_value)
        psnr = paddle.to_variable(np.array(psnrs))
        return {'loss': loss, 'metrics': {'psnr': psnr}}

    def predict(self, images: str, visualization: bool = True, save_path: str = 'result'):
@@ -309,3 +312,87 @@ class Yolov3Module(RunModule, ImageServing):
            Func.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5)
        return boxes, scores, labels
class StyleTransferModule(RunModule, ImageServing):
def training_step(self, batch: int, batch_idx: int) -> dict:
'''
One step for training, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): One batch of data, containing images and labels.
batch_idx(int): The index of the batch.
Returns:
results(dict): The model outputs, such as loss and metrics.
'''
return self.validation_step(batch, batch_idx)
def validation_step(self, batch: int, batch_idx: int) -> dict:
'''
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): One batch of data, containing images and labels.
batch_idx(int): The index of the batch.
Returns:
results(dict): The model outputs, such as metrics.
'''
mse_loss = nn.MSELoss()
N, C, H, W = batch[0].shape
batch[1] = batch[1][0].unsqueeze(0)
self.setTarget(batch[1])
y = self(batch[0])
xc = paddle.to_tensor(batch[0].numpy().copy())
y = Func.subtract_imagenet_mean_batch(y)
xc = Func.subtract_imagenet_mean_batch(xc)
features_y = self.getFeature(y)
features_xc = self.getFeature(xc)
f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True)
content_loss = mse_loss(features_y[1], f_xc_c)
batch[1] = Func.subtract_imagenet_mean_batch(batch[1])
features_style = self.getFeature(batch[1])
gram_style = [Func.gram_matrix(y) for y in features_style]
style_loss = 0.
for m in range(len(features_y)):
gram_y = Func.gram_matrix(features_y[m])
gram_s = paddle.to_tensor(np.tile(gram_style[m].numpy(), (N, 1, 1, 1)))
style_loss += mse_loss(gram_y, gram_s[:N, :, :])
loss = content_loss + style_loss
return {'loss': loss, 'metrics': {'content gap': content_loss, 'style gap': style_loss}}
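In other words, the objective assembled above is the usual perceptual style-transfer loss with equal weight on its two terms: an MSE content term on the relu2_2 features (index 1 of the VGG16 outputs) plus an MSE style term over the Gram matrices of all four feature levels, with the style Gram tiled across the batch of size N:

    \mathcal{L} = \mathrm{MSE}\big(\phi_2(y), \phi_2(x_c)\big) + \sum_{m=1}^{4} \mathrm{MSE}\big(G(\phi_m(y)), G(\phi_m(x_s))\big)

where the phi_m are the four feature maps returned by Vgg16.forward and G is the Gram matrix from `gram_matrix`.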
def predict(self, origin_path: str, style_path: str, visualization: bool = True, save_path: str = 'result'):
'''
Perform style transfer on an image.
Args:
origin_path(str): Content image path.
style_path(str): Style image path.
visualization(bool): Whether to save the stylized image.
save_path(str): Path to save the stylized image.
Returns:
output(np.ndarray): The style-transferred image in BGR mode.
'''
content = paddle.to_tensor(self.transform(origin_path))
style = paddle.to_tensor(self.transform(style_path))
content = content.unsqueeze(0)
style = style.unsqueeze(0)
self.setTarget(style)
output = self(content)
output = paddle.clip(output[0].transpose((1, 2, 0)), 0, 255).numpy()
if visualization:
output = output.astype(np.uint8)
style_name = "style_" + str(time.time()) + ".png"
if not os.path.exists(save_path):
os.mkdir(save_path)
path = os.path.join(save_path, style_name)
cv2.imwrite(path, output)
return output
@@ -58,7 +58,7 @@ class HubModuleNotFoundError(Exception):
            hub_version = 'Any' if not info['hub_version'] else ', '.join(info['hub_version'])
            table.append(self.name, _ver, paddle_version, hub_version, aligns=['^', '^', '^', '^'])
-        tips += ', \n{}'.format(table)
+        tips += ':\n{}'.format(table)
        return tips
@@ -104,7 +104,7 @@ class EnvironmentMismatchError(Exception):
            table.append(self.name, _ver, paddle_version, hub_version, aligns=['^', '^', '^', '^'])
-        tips += ', \n{}'.format(table)
+        tips += ':\n{}'.format(table)
        return tips
@@ -238,28 +238,29 @@ class LocalModuleManager(object):
            return self._local_modules[name]
        result = module_server.search_module(name=name, version=version, source=source)
-        if not result:
-            module_infos = module_server.get_module_info(name=name, source=source)
-            # The HubModule with the specified name cannot be found
-            if not module_infos:
-                raise HubModuleNotFoundError(name=name, version=version, source=source)
-            valid_infos = {}
-            if version:
-                for _ver, _info in module_infos.items():
-                    if utils.Version(_ver).match(version):
-                        valid_infos[_ver] = _info
-            else:
-                valid_infos = list(module_infos.keys())
-            # Cannot find a HubModule that meets the version
-            if valid_infos:
-                raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
-            raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source)
-        if source or 'source' in result:
-            return self._install_from_source(result)
-        return self._install_from_url(result['url'])
+        for item in result:
+            if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
+                if source or 'source' in item:
+                    return self._install_from_source(result)
+                return self._install_from_url(item['url'])
+        module_infos = module_server.get_module_info(name=name, source=source)
+        # The HubModule with the specified name cannot be found
+        if not module_infos:
+            raise HubModuleNotFoundError(name=name, version=version, source=source)
+        valid_infos = {}
+        if version:
+            for _ver, _info in module_infos.items():
+                if utils.Version(_ver).match(version):
+                    valid_infos[_ver] = _info
+        else:
+            valid_infos = module_infos.copy()
+        # Cannot find a HubModule that meets the version
+        if valid_infos:
+            raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
+        raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source)

    def _install_from_source(self, source: str) -> HubModule:
        '''Install a HubModule from Git Repo'''
...
@@ -148,6 +148,10 @@
            user_module = user_module_cls(directory=directory)
            user_module._initialize(**kwargs)
            return user_module
+        if user_module_cls == ModuleV1:
+            return user_module_cls(directory=directory, **kwargs)
        user_module_cls.directory = directory
        return user_module_cls(**kwargs)
@@ -166,6 +170,10 @@
            user_module = user_module_cls(directory=directory)
            user_module._initialize(**kwargs)
            return user_module
+        if user_module_cls == ModuleV1:
+            return user_module_cls(directory=directory, **kwargs)
        user_module_cls.directory = directory
        return user_module_cls(**kwargs)
...
@@ -119,7 +119,6 @@ def get_img_file(dir_name: str) -> list:
            if not is_image_file(filename):
                continue
            img_path = os.path.join(parent, filename)
-            print(img_path)
            images.append(img_path)
    images.sort()
    return images
@@ -246,3 +245,22 @@ def get_label_infos(file_list: str):
    for category in categories:
        label_names.append(category['name'])
    return label_names
def subtract_imagenet_mean_batch(batch: paddle.Tensor) -> paddle.Tensor:
"""Subtract ImageNet mean pixel-wise from a BGR image."""
mean = np.zeros(shape=batch.shape, dtype='float32')
mean[:, 0, :, :] = 103.939
mean[:, 1, :, :] = 116.779
mean[:, 2, :, :] = 123.680
mean = paddle.to_tensor(mean)
return batch - mean
def gram_matrix(data: paddle.Tensor) -> paddle.Tensor:
"""Get gram matrix"""
b, ch, h, w = data.shape
features = data.reshape((b, ch, w * h))
features_t = features.transpose((0, 2, 1))
gram = features.bmm(features_t) / (ch * h * w)
return gram
@@ -26,15 +26,16 @@ from paddlehub.process.functional import *
class Compose:
-    def __init__(self, transforms, to_rgb=True, stay_rgb=False):
+    def __init__(self, transforms, to_rgb=True, stay_rgb=False, is_permute=True):
        if not isinstance(transforms, list):
            raise TypeError('The transforms must be a list!')
        if len(transforms) < 1:
            raise ValueError('The length of transforms ' + \
                             'must be equal or larger than 1!')
        self.transforms = transforms
        self.to_rgb = to_rgb
        self.stay_rgb = stay_rgb
+        self.is_permute = is_permute

    def __call__(self, im):
        if isinstance(im, str):
@@ -49,6 +50,9 @@ class Compose:
            im = op(im)
        if not self.stay_rgb:
+            im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+        if self.is_permute:
            im = permute(im)
        return im
@@ -570,17 +574,7 @@ class ColorizeHint:
        self.use_avg = use_avg

    def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
-        sample_Ps = [
-            1,
-            2,
-            3,
-            4,
-            5,
-            6,
-            7,
-            8,
-            9,
-        ]
+        sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9]
        self.data = data
        self.hint = hint
        self.mask = mask
@@ -591,7 +585,7 @@ class ColorizeHint:
        while cont_cond:
            if self.num_points is None:  # draw from geometric
                # embed()
-                cont_cond = np.random.rand() < (1 - self.percent)
+                cont_cond = np.random.rand() > (1 - self.percent)
            else:  # add certain number of points
                cont_cond = pp < self.num_points
            if not cont_cond:  # skip out of loop if condition not met
@@ -659,7 +653,7 @@ class ColorizePreprocess:
    """

    def __init__(self,
                 ab_thresh: float = 0.,
-                 p: float = .125,
+                 p: float = 0.,
                 num_points: int = None,
                 samp: str = 'normal',
                 use_avg: bool = True,
@@ -733,3 +727,41 @@ class ColorPostprocess:
        img = np.clip(img, 0, 1) * 255
        img = img.astype(self.type)
        return img
class CenterCrop:
"""
Crop the middle part of the image to the specified size.
Args:
crop_size(int): Crop size.
Return:
img(np.ndarray): Cropped image.
"""
def __init__(self, crop_size: int):
self.crop_size = crop_size
def __call__(self, img: np.ndarray):
img_height, img_width, channel = img.shape
crop_top = int((img_height - self.crop_size) / 2.)
crop_left = int((img_width - self.crop_size) / 2.)
return img[crop_top:crop_top + self.crop_size, crop_left:crop_left + self.crop_size, :]
class SetType:
"""
Set image type.
Args:
datatype(type): Target data type of the image values.
Return:
img(np.ndarray): Transformed image.
"""
def __init__(self, datatype: type = 'float32'):
self.type = datatype
def __call__(self, img: np.ndarray):
img = img.astype(self.type)
return img
@@ -18,6 +18,7 @@ import importlib
import os
import sys
from collections import OrderedDict
+from typing import List
from urllib.parse import urlparse

import git
@@ -74,7 +75,7 @@ class GitSource(object):
            log.logger.warning('An error occurred while loading {}'.format(self.path))
            sys.path.remove(self.path)

-    def search_module(self, name: str, version: str = None) -> dict:
+    def search_module(self, name: str, version: str = None) -> List[dict]:
        '''
        Search PaddleHub module
@@ -84,7 +85,7 @@ class GitSource(object):
        '''
        return self.search_resource(type='module', name=name, version=version)

-    def search_resource(self, type: str, name: str, version: str = None) -> dict:
+    def search_resource(self, type: str, name: str, version: str = None) -> List[dict]:
        '''
        Search PaddleHub Resource
@@ -95,13 +96,13 @@ class GitSource(object):
        '''
        module = self.hub_modules.get(name, None)
        if module and module.version.match(version):
-            return {
+            return [{
                'version': module.version,
                'name': module.name,
                'path': self.path,
                'class': module.__name__,
                'source': self.url
-            }
+            }]
        return None

    @classmethod
...
@@ -14,6 +14,7 @@
# limitations under the License.
from collections import OrderedDict
+from typing import List
from paddlehub.server import ServerSource, GitSource
@@ -44,7 +45,7 @@ class HubServer(object):
        '''Remove a module source'''
        self.sources.pop(key)

-    def search_module(self, name: str, version: str = None, source: str = None) -> dict:
+    def search_module(self, name: str, version: str = None, source: str = None) -> List[dict]:
        '''
        Search PaddleHub module
@@ -54,7 +55,7 @@ class HubServer(object):
        '''
        return self.search_resource(type='module', name=name, version=version, source=source)

-    def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> dict:
+    def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> List[dict]:
        '''
        Search PaddleHub Resource
@@ -68,7 +69,7 @@ class HubServer(object):
            result = source.search_resource(name=name, type=type, version=version)
            if result:
                return result
-        return {}
+        return []

    def get_module_info(self, name: str, source: str = None) -> dict:
        '''
...
@@ -43,7 +43,7 @@ class ServerSource(object):
        self._url = url
        self._timeout = timeout

-    def search_module(self, name: str, version: str = None) -> dict:
+    def search_module(self, name: str, version: str = None) -> List[dict]:
        '''
        Search PaddleHub module
@@ -53,7 +53,7 @@ class ServerSource(object):
        '''
        return self.search_resource(type='module', name=name, version=version)

-    def search_resource(self, type: str, name: str, version: str = None) -> dict:
+    def search_resource(self, type: str, name: str, version: str = None) -> List[dict]:
        '''
        Search PaddleHub Resource
@@ -76,9 +76,7 @@ class ServerSource(object):
        result = self.request(path='search', params=params)
        if result['status'] == 0 and len(result['data']) > 0:
-            for item in result['data']:
-                if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
-                    return item
+            return result['data']
        return None

    def get_module_info(self, name: str) -> dict:
...
@@ -41,11 +41,17 @@ def download(name: str, save_path: str, version: str = None):
    if os.path.exists(file):
        return
-    resource = module_server.search_resouce(name=name, version=version, type='Model')
-    if not resource:
-        raise ResourceNotFoundError(name, version)
-    url = resource['url']
+    resources = module_server.search_resouce(name=name, version=version, type='Model')
+    if not resources:
+        raise ResourceNotFoundError(name, version)
+    for item in resources:
+        if item['name'] == name and utils.Version(item['version']).match(version):
+            url = item['url']
+            break
+    else:
+        raise ResourceNotFoundError(name, version)
    with utils.generate_tempdir() as _dir:
        if not os.path.exists(save_path):
            os.makedirs(save_path)
...