diff --git a/hub_module/modules/image/colorization/deoldify/README.md b/hub_module/modules/image/colorization/deoldify/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ac847cb7b297f02473b75585cc74695a79db51f6
--- /dev/null
+++ b/hub_module/modules/image/colorization/deoldify/README.md
@@ -0,0 +1,121 @@
+
+## Model Overview
+
+DeOldify is a colorization model for images and videos that restores natural color to black-and-white photos and footage.
+
+## API Reference
+
+```python
+def predict(self, input):
+```
+
+Colorization API that returns the colorized image or video.
+
+**Parameters**
+
+* input (str): path of the input image or video.
+
+**Returns**
+
+If the input is an image:
+* pred_img (np.ndarray): image data in BGR format;
+* out_path (str): path where the image is saved.
+
+If the input is a video:
+* frame_pattern_combined (str): path pattern of the colorized video frames;
+* vid_out_path (str): path where the video is saved.
+
+```python
+def run_image(self, img):
+```
+
+Image colorization API that returns the colorized image.
+
+**Parameters**
+
+* img (str|np.ndarray): image path or image data in BGR format.
+
+**Returns**
+
+* pred_img (np.ndarray): image data in BGR format.
+
+```python
+def run_video(self, video):
+```
+
+Video colorization API that returns the colorized video.
+
+**Parameters**
+
+* video (str): path of the video to process.
+
+**Returns**
+
+* frame_pattern_combined (str): path pattern of the colorized video frames;
+* vid_out_path (str): path where the video is saved.
+
+## Prediction Example
+
+```python
+import paddlehub as hub
+
+model = hub.Module('deoldify')
+model.predict('/PATH/TO/IMAGE/OR/VIDEO')
+```
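+
+`run_image` can also be called directly on in-memory image data. A minimal sketch (the image and save paths are illustrative placeholders):
+
+```python
+import cv2
+import paddlehub as hub
+
+model = hub.Module('deoldify')
+
+# run_image accepts either a file path or a BGR ndarray (as read by cv2)
+im = cv2.imread('/PATH/TO/IMAGE')
+pred_img = model.run_image(im)  # colorized result, BGR ndarray
+cv2.imwrite('/PATH/TO/SAVE/IMAGE', pred_img)
+```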
+
+## Server Deployment
+
+PaddleHub Serving can deploy an online image colorization service.
+
+## Step 1: Start the PaddleHub Serving service
+
+Run the start command:
+```shell
+$ hub serving start -m deoldify
+```
+
+This deploys an online image colorization API service; the default port is 8866.
+
+**NOTE:** To predict on GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; otherwise it is not needed.
+
+## Step 2: Send a prediction request
+
+With the server configured, the following few lines of code send a prediction request and fetch the result:
+
+```python
+import requests
+import json
+import base64
+
+import cv2
+import numpy as np
+
+def cv2_to_base64(image):
+    data = cv2.imencode('.jpg', image)[1]
+    return base64.b64encode(data.tostring()).decode('utf8')
+
+def base64_to_cv2(b64str):
+    data = base64.b64decode(b64str.encode('utf8'))
+    data = np.fromstring(data, np.uint8)
+    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+    return data
+
+# Send the HTTP request
+org_im = cv2.imread('/PATH/TO/ORIGIN/IMAGE')
+data = {'images':cv2_to_base64(org_im)}
+headers = {"Content-type": "application/json"}
+url = "http://127.0.0.1:8866/predict/deoldify"
+r = requests.post(url=url, headers=headers, data=json.dumps(data))
+img = base64_to_cv2(r.json()["results"])
+cv2.imwrite('/PATH/TO/SAVE/IMAGE', img)
+```
+
+
+## Model Information
+
+### Model code
+
+https://github.com/jantic/DeOldify
+
+### Dependencies
+
+paddlepaddle >= 2.0.0rc
+
+paddlehub >= 1.8.3
\ No newline at end of file
diff --git a/hub_module/modules/image/colorization/deoldify/base_module.py b/hub_module/modules/image/colorization/deoldify/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..da4b1c86ce737f4dcc7f6e15a6e75a998eb87108
--- /dev/null
+++ b/hub_module/modules/image/colorization/deoldify/base_module.py
@@ -0,0 +1,474 @@
+import paddle
+import numpy as np
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.vision.models import resnet101
+
+import deoldify.utils as U
+
+
+class SequentialEx(nn.Layer):
+    "Like `nn.Sequential`, but with ModuleList semantics, and can access module input"
+
+    def __init__(self, *layers):
+        super().__init__()
+        self.layers = nn.LayerList(layers)
+
+    def forward(self, x):
+        res = x
+        for l in self.layers:
+            if isinstance(l, MergeLayer):
+                l.orig = x
+            nres = l(res)
+            # We have to remove res.orig to avoid hanging refs and therefore memory leaks
+            # l.orig = None
+            res = nres
+        return res
+
+    def __getitem__(self, i):
+        return self.layers[i]
+
+    def append(self, l):
+        return self.layers.append(l)
+
+    def extend(self, l):
+        return self.layers.extend(l)
+
+    def insert(self, i, l):
+        return self.layers.insert(i, l)
+
+
+class Deoldify(SequentialEx):
+    def __init__(self,
+                 encoder,
+                 n_classes,
+                 blur=False,
+                 blur_final=True,
+                 self_attention=False,
+                 y_range=None,
+                 last_cross=True,
+                 bottle=False,
+                 norm_type='Batch',
+                 nf_factor=1,
+                 **kwargs):
+
+        imsize = (256, 256)
+        sfs_szs = U.model_sizes(encoder, size=imsize)
+        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
+        self.sfs = U.hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
+        x = U.dummy_eval(encoder, imsize).detach()
+
+        nf = 512 * nf_factor
+        extra_bn = norm_type == 'Spectral'
+        ni = sfs_szs[-1][1]
+        middle_conv = nn.Sequential(
+            custom_conv_layer(ni,
+                              ni * 2,
+                              norm_type=norm_type,
+                              extra_bn=extra_bn),
+            custom_conv_layer(ni * 2,
+                              ni,
+                              norm_type=norm_type,
+                              extra_bn=extra_bn),
+        )
+
+        layers = [encoder, nn.BatchNorm(ni), nn.ReLU(), middle_conv]
+
+        for i, idx in enumerate(sfs_idxs):
+            not_final = i != len(sfs_idxs) - 1
+            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
+            do_blur = blur and (not_final or blur_final)
+            sa = self_attention and (i == len(sfs_idxs) - 3)
+
+            n_out = nf if not_final else nf // 2
+
+            unet_block = UnetBlockWide(up_in_c,
+                                       x_in_c,
+                                       n_out,
+                                       self.sfs[i],
+                                       final_div=not_final,
+                                       blur=do_blur,
+                                       self_attention=sa,
+                                       norm_type=norm_type,
+                                       extra_bn=extra_bn,
+                                       **kwargs)
+            unet_block.eval()
+            layers.append(unet_block)
+            x = unet_block(x)
+
+        ni = x.shape[1]
+        if imsize != sfs_szs[0][-2:]:
+            layers.append(PixelShuffle_ICNR(ni, **kwargs))
+        if last_cross:
+            layers.append(MergeLayer(dense=True))
+            ni += 3
+            layers.append(
+                res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
+        layers += [
+            custom_conv_layer(ni,
+                              n_classes,
+                              ks=1,
+                              use_activ=False,
+                              norm_type=norm_type)
+        ]
+        if y_range is not None:
+            layers.append(SigmoidRange(*y_range))
+        super().__init__(*layers)
+
+
+def custom_conv_layer(ni: int,
+                      nf: int,
+                      ks: int = 3,
+                      stride: int = 1,
+                      padding: int = None,
+                      bias: bool = None,
+                      is_1d: bool = False,
+                      norm_type='Batch',
+                      use_activ: bool = True,
+                      leaky: float = None,
+                      transpose: bool = False,
+                      self_attention: bool = False,
+                      extra_bn: bool = False,
+                      **kwargs):
+    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
+    if padding is None:
+        padding = (ks - 1) // 2 if not transpose else 0
+    bn = norm_type in ('Batch', 'BatchZero') or extra_bn
+    if bias is None:
+        bias = not bn
+    conv_func = nn.Conv2DTranspose if transpose else nn.Conv1D if is_1d else nn.Conv2D
+
+    conv = conv_func(ni,
+                     nf,
+                     kernel_size=ks,
+                     bias_attr=bias,
+                     stride=stride,
+                     padding=padding)
+    if norm_type == 'Weight':
+        conv = nn.utils.weight_norm(conv)
+    elif norm_type == 'Spectral':
+        conv = U.Spectralnorm(conv)
+    layers = [conv]
+    if use_activ:
+        layers.append(relu(True, leaky=leaky))
+    if bn:
+        layers.append((nn.BatchNorm1D if is_1d else nn.BatchNorm)(nf))
+    if self_attention:
+        layers.append(SelfAttention(nf))
+
+    return nn.Sequential(*layers)
+
+
+def relu(inplace: bool = False, leaky: float = None):
+    "Return a ReLU activation, `leaky` if requested (`inplace` is kept for API compatibility)."
+    return nn.LeakyReLU(leaky) if leaky is not None else nn.ReLU()
+
+
+class UnetBlockWide(nn.Layer):
+    "A quasi-UNet block, using `PixelShuffle_ICNR` upsampling."
+
+    def __init__(self,
+                 up_in_c: int,
+                 x_in_c: int,
+                 n_out: int,
+                 hook,
+                 final_div: bool = True,
+                 blur: bool = False,
+                 leaky: float = None,
+                 self_attention: bool = False,
+                 **kwargs):
+        super().__init__()
+        self.hook = hook
+        up_out = x_out = n_out // 2
+        self.shuf = CustomPixelShuffle_ICNR(up_in_c,
+                                            up_out,
+                                            blur=blur,
+                                            leaky=leaky,
+                                            **kwargs)
+        self.bn = nn.BatchNorm(x_in_c)
+        ni = up_out + x_in_c
+        self.conv = custom_conv_layer(ni,
+                                      x_out,
+                                      leaky=leaky,
+                                      self_attention=self_attention,
+                                      **kwargs)
+        self.relu = relu(leaky=leaky)
+
+    def forward(self, up_in):
+        s = self.hook.stored
+        up_out = self.shuf(up_in)
+        ssh = s.shape[-2:]
+        if ssh != up_out.shape[-2:]:
+            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
+        cat_x = self.relu(paddle.concat([up_out, self.bn(s)], axis=1))
+        return self.conv(cat_x)
+
+
+class UnetBlockDeep(nn.Layer):
+    "A quasi-UNet block, using `PixelShuffle_ICNR` upsampling."
+
+    def __init__(
+            self,
+            up_in_c: int,
+            x_in_c: int,
+            hook,
+            final_div: bool = True,
+            blur: bool = False,
+            leaky: float = None,
+            self_attention: bool = False,
+            nf_factor: float = 1.0,
+            **kwargs):
+        super().__init__()
+        # `hook` holds the stored encoder activation used in `forward`
+        self.hook = hook
+        self.shuf = CustomPixelShuffle_ICNR(up_in_c,
+                                            up_in_c // 2,
+                                            blur=blur,
+                                            leaky=leaky,
+                                            **kwargs)
+        self.bn = nn.BatchNorm(x_in_c)
+        ni = up_in_c // 2 + x_in_c
+        nf = int((ni if final_div else ni // 2) * nf_factor)
+        self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
+        self.conv2 = custom_conv_layer(nf,
+                                       nf,
+                                       leaky=leaky,
+                                       self_attention=self_attention,
+                                       **kwargs)
+        self.relu = relu(leaky=leaky)
+
+    def forward(self, up_in):
+        s = self.hook.stored
+        up_out = self.shuf(up_in)
+        ssh = s.shape[-2:]
+        if ssh != up_out.shape[-2:]:
+            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
+        cat_x = self.relu(paddle.concat([up_out, self.bn(s)], axis=1))
+        return self.conv2(self.conv1(cat_x))
+
+
+def ifnone(a, b):
+    "`a` if `a` is not None, otherwise `b`."
+    return b if a is None else a
+
+
+class PixelShuffle_ICNR(nn.Layer):
+    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, \
+    `icnr` init, and `weight_norm`."
+
+    def __init__(self,
+                 ni: int,
+                 nf: int = None,
+                 scale: int = 2,
+                 blur: bool = False,
+                 norm_type='Weight',
+                 leaky: float = None):
+        super().__init__()
+        nf = ifnone(nf, ni)
+        self.conv = conv_layer(ni,
+                               nf * (scale ** 2),
+                               ks=1,
+                               norm_type=norm_type,
+                               use_activ=False)
+
+        self.shuf = PixelShuffle(scale)
+
+        self.pad = ReplicationPad2d([1, 0, 1, 0])
+        # keep the boolean switch separate from the AvgPool layer so the
+        # `blur` flag actually takes effect in `forward`
+        self.do_blur = blur
+        self.blur = nn.AvgPool2D(2, stride=1)
+        self.relu = relu(True, leaky=leaky)
+
+    def forward(self, x):
+        x = self.shuf(self.relu(self.conv(x)))
+        return self.blur(self.pad(x)) if self.do_blur else x
+
+
+def conv_layer(ni: int,
+               nf: int,
+               ks: int = 3,
+               stride: int = 1,
+               padding: int = None,
+               bias: bool = None,
+               is_1d: bool = False,
+               norm_type='Batch',
+               use_activ: bool = True,
+               leaky: float = None,
+               transpose: bool = False,
+               init=None,
+               self_attention: bool = False):
+    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
+    if padding is None: padding = (ks - 1) // 2 if not transpose else 0
+    bn = norm_type in ('Batch', 'BatchZero')
+    if bias is None: bias = not bn
+    conv_func = nn.Conv2DTranspose if transpose else nn.Conv1D if is_1d else nn.Conv2D
+
+    conv = conv_func(ni,
+                     nf,
+                     kernel_size=ks,
+                     bias_attr=bias,
+                     stride=stride,
+                     padding=padding)
+    if norm_type == 'Weight':
+        conv = nn.utils.weight_norm(conv)
+    elif norm_type == 'Spectral':
+        conv = U.Spectralnorm(conv)
+
+    layers = [conv]
+    if use_activ: layers.append(relu(True, leaky=leaky))
+    if bn: layers.append((nn.BatchNorm1D if is_1d else nn.BatchNorm)(nf))
+    if self_attention: layers.append(SelfAttention(nf))
+    return nn.Sequential(*layers)
+
+
+class CustomPixelShuffle_ICNR(nn.Layer):
+    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, \
+    and `weight_norm`."
+
+    def __init__(self,
+                 ni: int,
+                 nf: int = None,
+                 scale: int = 2,
+                 blur: bool = False,
+                 leaky: float = None,
+                 **kwargs):
+        super().__init__()
+        nf = ifnone(nf, ni)
+        self.conv = custom_conv_layer(ni,
+                                      nf * (scale ** 2),
+                                      ks=1,
+                                      use_activ=False,
+                                      **kwargs)
+
+        self.shuf = PixelShuffle(scale)
+
+        self.pad = ReplicationPad2d([1, 0, 1, 0])
+        # as in PixelShuffle_ICNR, gate the blur on the boolean flag
+        self.do_blur = blur
+        self.blur = paddle.nn.AvgPool2D(2, stride=1)
+        self.relu = nn.LeakyReLU(
+            leaky) if leaky is not None else nn.ReLU()  # relu(True, leaky=leaky)
+
+    def forward(self, x):
+        x = self.shuf(self.relu(self.conv(x)))
+        return self.blur(self.pad(x)) if self.do_blur else x
+
+
+class MergeLayer(nn.Layer):
+    "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
+
+    def __init__(self, dense: bool = False):
+        super().__init__()
+        self.dense = dense
+        self.orig = None
+
+    def forward(self, x):
+        out = paddle.concat([x, self.orig],
+                            axis=1) if self.dense else (x + self.orig)
+        self.orig = None
+        return out
+
+
+def res_block(nf,
+              dense: bool = False,
+              norm_type='Batch',
+              bottle: bool = False,
+              **conv_kwargs):
+    "Resnet block of `nf` features. `conv_kwargs` are passed to `conv_layer`."
+    norm2 = norm_type
+    if not dense and (norm_type == 'Batch'): norm2 = 'BatchZero'
+    nf_inner = nf // 2 if bottle else nf
+    return SequentialEx(
+        conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs),
+        conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs),
+        MergeLayer(dense))
+
+
+class SigmoidRange(nn.Layer):
+    "Sigmoid module with range `(low, high)`"
+
+    def __init__(self, low, high):
+        super().__init__()
+        self.low, self.high = low, high
+
+    def forward(self, x):
+        return sigmoid_range(x, self.low, self.high)
+
+
+def sigmoid_range(x, low, high):
+    "Sigmoid function with range `(low, high)`"
+    return F.sigmoid(x) * (high - low) + low
+
+
+class PixelShuffle(nn.Layer):
+    def __init__(self, upscale_factor):
+        super(PixelShuffle, self).__init__()
+        self.upscale_factor = upscale_factor
+
+    def forward(self, x):
+        return F.pixel_shuffle(x, self.upscale_factor)
+
+
+class ReplicationPad2d(nn.Layer):
+    def __init__(self, size):
+        super(ReplicationPad2d, self).__init__()
+        self.size = size
+
+    def forward(self, x):
+        return F.pad(x, self.size, mode="replicate")
+
+
+def conv1d(ni: int,
+           no: int,
+           ks: int = 1,
+           stride: int = 1,
+           padding: int = 0,
+           bias: bool = False):
+    "Create and initialize a `nn.Conv1D` layer with spectral normalization."
+    conv = nn.Conv1D(ni, no, ks, stride=stride, padding=padding, bias_attr=bias)
+    return U.Spectralnorm(conv)
+
+
+class SelfAttention(nn.Layer):
+    "Self attention layer for nd."
+
+    def __init__(self, n_channels):
+        super().__init__()
+        self.query = conv1d(n_channels, n_channels // 8)
+        self.key = conv1d(n_channels, n_channels // 8)
+        self.value = conv1d(n_channels, n_channels)
+        self.gamma = self.create_parameter(
+            shape=[1], default_initializer=paddle.nn.initializer.Constant(
+                0.0))  # nn.Parameter(tensor([0.]))
+
+    def forward(self, x):
+        # Notation from https://arxiv.org/pdf/1805.08318.pdf
+        size = x.shape
+        x = paddle.reshape(x, list(size[:2]) + [-1])
+        f, g, h = self.query(x), self.key(x), self.value(x)
+
+        beta = paddle.nn.functional.softmax(paddle.bmm(
+            paddle.transpose(f, [0, 2, 1]), g),
+                                            axis=1)
+        o = self.gamma * paddle.bmm(h, beta) + x
+        return paddle.reshape(o, size)
+
+
+def _get_sfs_idxs(sizes):
+    "Get the indexes of the layers where the size of the activation changes."
+    feature_szs = [size[-1] for size in sizes]
+    sfs_idxs = list(
+        np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
+    if feature_szs[0] != feature_szs[1]:
+        sfs_idxs = [0] + sfs_idxs
+    return sfs_idxs
+
+
+def build_model():
+    backbone = resnet101()
+    cut = -2
+    encoder = nn.Sequential(*list(backbone.children())[:cut])
+
+    model = Deoldify(encoder,
+                     3,
+                     blur=True,
+                     y_range=(-3, 3),
+                     norm_type='Spectral',
+                     self_attention=True,
+                     nf_factor=2)
+    return model
diff --git a/hub_module/modules/image/colorization/deoldify/module.py b/hub_module/modules/image/colorization/deoldify/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5dd5b3ba2369d6542727d02976d45ec6bde2cd
--- /dev/null
+++ b/hub_module/modules/image/colorization/deoldify/module.py
@@ -0,0 +1,166 @@
+# coding:utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
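+
+# Pipeline overview (added commentary): norm() resizes the input to a
+# render_factor * 16 square and applies ImageNet mean/std normalization;
+# the U-Net outputs a colorized image which denorm() maps back to uint8;
+# post_process() keeps the original luminance (Y) and takes only the
+# chroma (UV) channels from the model output, preserving input detail.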
+
+import os
+import glob
+
+import cv2
+import paddle
+import paddle.nn as nn
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+import deoldify.utils as U
+from paddlehub.module.module import moduleinfo, serving, Module
+from deoldify.base_module import build_model
+
+
+@moduleinfo(name="deoldify",
+            type="CV/image_editing",
+            author="paddlepaddle",
+            author_email="",
+            summary="Deoldify is a colorization model",
+            version="1.0.0")
+class DeOldifyPredictor(Module):
+    def _initialize(self, render_factor: int = 32, output_path: str = 'result', load_checkpoint: str = None):
+        #super(DeOldifyPredictor, self).__init__()
+        self.model = build_model()
+        self.render_factor = render_factor
+        self.output = os.path.join(output_path, 'DeOldify')
+        if not os.path.exists(self.output):
+            os.makedirs(self.output)
+        if load_checkpoint is not None:
+            state_dict = paddle.load(load_checkpoint)
+            self.model.load_dict(state_dict)
+            print("load custom checkpoint success")
+
+        else:
+            checkpoint = os.path.join(self.directory, 'DeOldify_stable.pdparams')
+            state_dict = paddle.load(checkpoint)
+            self.model.load_dict(state_dict)
+            print("load pretrained checkpoint success")
+
+    def norm(self, img, render_factor=32, render_base=16):
+        target_size = render_factor * render_base
+        img = img.resize((target_size, target_size), resample=Image.BILINEAR)
+
+        img = np.array(img).transpose([2, 0, 1]).astype('float32') / 255.0
+
+        img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
+        img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
+
+        img -= img_mean
+        img /= img_std
+        return img.astype('float32')
+
+    def denorm(self, img):
+        img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
+        img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
+
+        img *= img_std
+        img += img_mean
+        img = img.transpose((1, 2, 0))
+
+        return (img * 255).clip(0, 255).astype('uint8')
+
+    def post_process(self, raw_color, orig):
+        color_np = np.asarray(raw_color)
+        orig_np = np.asarray(orig)
+        color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
+        orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
+        hires = np.copy(orig_yuv)
+        hires[:, :, 1:3] = color_yuv[:, :, 1:3]
+        final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
+        return final
+
+    def run_image(self, img):
+        if isinstance(img, str):
+            ori_img = Image.open(img).convert('LA').convert('RGB')
+        elif isinstance(img, np.ndarray):
+            ori_img = Image.fromarray(img).convert('LA').convert('RGB')
+        elif isinstance(img, Image.Image):
+            ori_img = img
+
+        img = self.norm(ori_img, self.render_factor)
+        x = paddle.to_tensor(img[np.newaxis, ...])
+        out = self.model(x)
+
+        pred_img = self.denorm(out.numpy()[0])
+        pred_img = Image.fromarray(pred_img)
+        pred_img = pred_img.resize(ori_img.size, resample=Image.BILINEAR)
+        pred_img = self.post_process(pred_img, ori_img)
+        pred_img = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR)
+        return pred_img
+
+    def run_video(self, video):
+        base_name = os.path.basename(video).split('.')[0]
+        output_path = os.path.join(self.output, base_name)
+        pred_frame_path = os.path.join(output_path, 'frames_pred')
+
+        if not os.path.exists(output_path):
+            os.makedirs(output_path)
+
+        if not os.path.exists(pred_frame_path):
+            os.makedirs(pred_frame_path)
+
+        cap = cv2.VideoCapture(video)
+        fps = cap.get(cv2.CAP_PROP_FPS)
+
+        out_path = U.video2frames(video, output_path)
+
+        frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
+
+        for frame in tqdm(frames):
+            pred_img = self.run_image(frame)
+            pred_img = cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB)
+ pred_img = Image.fromarray(pred_img) + frame_name = os.path.basename(frame) + pred_img.save(os.path.join(pred_frame_path, frame_name)) + + frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') + + vid_out_path = os.path.join(output_path, + '{}_deoldify_out.mp4'.format(base_name)) + U.frames2video(frame_pattern_combined, vid_out_path, str(int(fps))) + print('Save video result at {}.'.format(vid_out_path)) + + return frame_pattern_combined, vid_out_path + + def predict(self, input): + if not os.path.exists(self.output): + os.makedirs(self.output) + + if not U.is_image(input): + return self.run_video(input) + else: + pred_img = self.run_image(input) + + if self.output: + base_name = os.path.splitext(os.path.basename(input))[0] + out_path = os.path.join(self.output, base_name + '.png') + cv2.imwrite(out_path, pred_img) + return pred_img, out_path + + @serving + def serving_method(self, images, **kwargs): + """ + Run as a service. + """ + images_decode = U.base64_to_cv2(images) + results = self.run_image(img=images_decode) + results = U.cv2_to_base64(results) + return results \ No newline at end of file diff --git a/hub_module/modules/image/colorization/deoldify/resnet.py b/hub_module/modules/image/colorization/deoldify/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..806a62346edab77be13d274472371d0a771e81fc --- /dev/null +++ b/hub_module/modules/image/colorization/deoldify/resnet.py @@ -0,0 +1,383 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import division +from __future__ import print_function + +import math +import paddle.fluid as fluid + +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear +from paddle.fluid.dygraph.container import Sequential + +from paddle.utils.download import get_weights_path_from_url + +__all__ = [ + 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152' +] + +model_urls = { + 'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams', + '0ba53eea9bc970962d0ef96f7b94057e'), + 'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams', + '46bc9f7c3dd2e55b7866285bee91eff3'), + 'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams', + '5ce890a9ad386df17cf7fe2313dca0a1'), + 'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams', + 'fb07a451df331e4b0bb861ed97c3a9b9'), + 'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams', + 'f9c700f26d3644bb76ad2226ed5f5713'), +} + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act=None): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False) + + self._batch_norm = BatchNorm(num_filters, act=act) + + def forward(self, inputs): + x = self._conv(inputs) + x = self._batch_norm(x) + + return x + + +class BasicBlock(fluid.dygraph.Layer): + """residual block of resnet18 and resnet34 + """ + expansion = 1 + + def __init__(self, num_channels, num_filters, stride, shortcut=True): + super(BasicBlock, self).__init__() + + self.conv0 = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=3, + act='relu') + self.conv1 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu') + + if not shortcut: + self.short = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=1, + stride=stride) + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + + y = short + conv1 + + return fluid.layers.relu(y) + + +class BottleneckBlock(fluid.dygraph.Layer): + """residual block of resnet50, resnet101 amd resnet152 + """ + + expansion = 4 + + def __init__(self, num_channels, num_filters, stride, shortcut=True): + super(BottleneckBlock, self).__init__() + + self.conv0 = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=1, + act='relu') + self.conv1 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu') + self.conv2 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters * self.expansion, + filter_size=1, + act=None) + + if not shortcut: + self.short = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters * self.expansion, + filter_size=1, + stride=stride) + + self.shortcut = shortcut + + self._num_channels_out = num_filters * self.expansion + + def forward(self, inputs): + x = self.conv0(inputs) + conv1 = self.conv1(x) + conv2 = self.conv2(conv1) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + + x = fluid.layers.elementwise_add(x=short, y=conv2) + + return fluid.layers.relu(x) + + +class 
ResNet(fluid.dygraph.Layer): + """ResNet model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + Block (BasicBlock|BottleneckBlock): block module of model. + depth (int): layers of resnet, default: 50. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. + with_pool (bool): use pool before the last fc layer or not. Default: True. + classifier_activation (str): activation for the last fc layer. Default: 'softmax'. + + Examples: + .. code-block:: python + + from paddle.vision.models import ResNet + from paddle.vision.models.resnet import BottleneckBlock, BasicBlock + + resnet50 = ResNet(BottleneckBlock, 50) + + resnet18 = ResNet(BasicBlock, 18) + + """ + + def __init__(self, + Block, + depth=50, + num_classes=1000, + with_pool=True, + classifier_activation='softmax'): + super(ResNet, self).__init__() + + self.num_classes = num_classes + self.with_pool = with_pool + + layer_config = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + } + assert depth in layer_config.keys(), \ + "supported depth are {} but input layer is {}".format( + layer_config.keys(), depth) + + layers = layer_config[depth] + + in_channels = 64 + out_channels = [64, 128, 256, 512] + + self.conv = ConvBNLayer( + num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu') + self.pool = Pool2D( + pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') + + self.layers = [] + for idx, num_blocks in enumerate(layers): + blocks = [] + shortcut = False + for b in range(num_blocks): + if b == 1: + in_channels = out_channels[idx] * Block.expansion + block = Block( + num_channels=in_channels, + num_filters=out_channels[idx], + stride=2 if b == 0 and idx != 0 else 1, + shortcut=shortcut) + blocks.append(block) + shortcut = True + layer = self.add_sublayer("layer_{}".format(idx), + Sequential(*blocks)) + self.layers.append(layer) + + if with_pool: + self.global_pool = Pool2D( + pool_size=7, pool_type='avg', global_pooling=True) + + if num_classes > 0: + stdv = 1.0 / math.sqrt(out_channels[-1] * Block.expansion * 1.0) + self.fc_input_dim = out_channels[-1] * Block.expansion * 1 * 1 + self.fc = Linear( + self.fc_input_dim, + num_classes, + act=classifier_activation, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv))) + + def forward(self, inputs): + x = self.conv(inputs) + x = self.pool(x) + for layer in self.layers: + x = layer(x) + + if self.with_pool: + x = self.global_pool(x) + + if self.num_classes > -1: + x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim]) + x = self.fc(x) + return x + + +def _resnet(arch, Block, depth, pretrained, **kwargs): + model = ResNet(Block, depth, **kwargs) + if pretrained: + assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( + arch) + weight_path = get_weights_path_from_url(model_urls[arch][0], + model_urls[arch][1]) + assert weight_path.endswith( + '.pdparams'), "suffix of weight must be .pdparams" + param, _ = fluid.load_dygraph(weight_path) + model.set_dict(param) + + return model + + +def resnet18(pretrained=False, **kwargs): + """ResNet 18-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. 
code-block:: python + + from paddle.vision.models import resnet18 + + # build model + model = resnet18() + + # build model and load imagenet pretrained weight + # model = resnet18(pretrained=True) + """ + return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs) + + +def resnet34(pretrained=False, **kwargs): + """ResNet 34-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet34 + + # build model + model = resnet34() + + # build model and load imagenet pretrained weight + # model = resnet34(pretrained=True) + """ + return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs) + + +def resnet50(pretrained=False, **kwargs): + """ResNet 50-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet50 + + # build model + model = resnet50() + + # build model and load imagenet pretrained weight + # model = resnet50(pretrained=True) + """ + return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs) + + +def resnet101(pretrained=False, **kwargs): + """ResNet 101-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet101 + + # build model + model = resnet101() + + # build model and load imagenet pretrained weight + # model = resnet101(pretrained=True) + """ + return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs) + + +def resnet152(pretrained=False, **kwargs): + """ResNet 152-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet152 + + # build model + model = resnet152() + + # build model and load imagenet pretrained weight + # model = resnet152(pretrained=True) + """ + return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs) diff --git a/hub_module/modules/image/colorization/deoldify/utils.py b/hub_module/modules/image/colorization/deoldify/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a619a280d129df2e705367711694aaa7166d4899 --- /dev/null +++ b/hub_module/modules/image/colorization/deoldify/utils.py @@ -0,0 +1,231 @@ +import os +import sys +import base64 + +import cv2 +import numpy as np +import paddle +import paddle.nn as nn +from PIL import Image + + +def cv2_to_base64(image): + data = cv2.imencode('.jpg', image)[1] + return base64.b64encode(data.tostring()).decode('utf8') + + +def base64_to_cv2(b64str): + data = base64.b64decode(b64str.encode('utf8')) + data = np.fromstring(data, np.uint8) + data = cv2.imdecode(data, cv2.IMREAD_COLOR) + return data + + +def is_listy(x): + return isinstance(x, (tuple, list)) + + +class Hook(): + "Create a hook on `m` with `hook_func`." + + def __init__(self, m, hook_func, is_forward=True, detach=True): + self.hook_func, self.detach, self.stored = hook_func, detach, None + f = m.register_forward_post_hook if is_forward else m.register_backward_hook + self.hook = f(self.hook_fn) + self.removed = False + + def hook_fn(self, module, input, output): + "Applies `hook_func` to `module`, `input`, `output`." 
+ if self.detach: + input = (o.detach() + for o in input) if is_listy(input) else input.detach() + output = (o.detach() + for o in output) if is_listy(output) else output.detach() + self.stored = self.hook_func(module, input, output) + + def remove(self): + "Remove the hook from the model." + if not self.removed: + self.hook.remove() + self.removed = True + + def __enter__(self, *args): + return self + + def __exit__(self, *args): + self.remove() + + +class Hooks(): + "Create several hooks on the modules in `ms` with `hook_func`." + + def __init__(self, ms, hook_func, is_forward=True, detach=True): + self.hooks = [] + try: + for m in ms: + self.hooks.append(Hook(m, hook_func, is_forward, detach)) + except Exception as e: + pass + + def __getitem__(self, i: int) -> Hook: + return self.hooks[i] + + def __len__(self) -> int: + return len(self.hooks) + + def __iter__(self): + return iter(self.hooks) + + @property + def stored(self): + return [o.stored for o in self] + + def remove(self): + "Remove the hooks from the model." + for h in self.hooks: + h.remove() + + def __enter__(self, *args): + return self + + def __exit__(self, *args): + self.remove() + + +def _hook_inner(m, i, o): + return o if isinstance( + o, paddle.fluid.framework.Variable) else o if is_listy(o) else list(o) + + +def hook_output(module, detach=True, grad=False): + "Return a `Hook` that stores activations of `module` in `self.stored`" + return Hook(module, _hook_inner, detach=detach, is_forward=not grad) + + +def hook_outputs(modules, detach=True, grad=False): + "Return `Hooks` that store activations of all `modules` in `self.stored`" + return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad) + + +def model_sizes(m, size=(64, 64)): + "Pass a dummy input through the model `m` to get the various sizes of activations." + with hook_outputs(m) as hooks: + x = dummy_eval(m, size) + return [o.stored.shape for o in hooks] + + +def dummy_eval(m, size=(64, 64)): + "Pass a `dummy_batch` in evaluation mode in `m` with `size`." + m.eval() + return m(dummy_batch(size)) + + +def dummy_batch(size=(64, 64), ch_in=3): + "Create a dummy batch to go through `m` with `size`." 
+    arr = np.random.rand(1, ch_in, *size).astype('float32') * 2 - 1
+    return paddle.to_tensor(arr)
+
+
+class _SpectralNorm(nn.SpectralNorm):
+    def __init__(self,
+                 weight_shape,
+                 dim=0,
+                 power_iters=1,
+                 eps=1e-12,
+                 dtype='float32'):
+        super(_SpectralNorm, self).__init__(weight_shape, dim, power_iters, eps,
+                                            dtype)
+
+    def forward(self, weight):
+        inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
+        out = self._helper.create_variable_for_type_inference(self._dtype)
+        _power_iters = self._power_iters if self.training else 0
+        self._helper.append_op(type="spectral_norm",
+                               inputs=inputs,
+                               outputs={
+                                   "Out": out,
+                               },
+                               attrs={
+                                   "dim": self._dim,
+                                   "power_iters": _power_iters,
+                                   "eps": self._eps,
+                               })
+
+        return out
+
+
+class Spectralnorm(paddle.nn.Layer):
+    def __init__(self, layer, dim=0, power_iters=1, eps=1e-12, dtype='float32'):
+        super(Spectralnorm, self).__init__()
+        self.spectral_norm = _SpectralNorm(layer.weight.shape, dim, power_iters,
+                                           eps, dtype)
+        self.dim = dim
+        self.power_iters = power_iters
+        self.eps = eps
+        self.layer = layer
+        weight = layer._parameters['weight']
+        del layer._parameters['weight']
+        self.weight_orig = self.create_parameter(weight.shape,
+                                                 dtype=weight.dtype)
+        self.weight_orig.set_value(weight)
+
+    def forward(self, x):
+        weight = self.spectral_norm(self.weight_orig)
+        self.layer.weight = weight
+        out = self.layer(x)
+        return out
+
+
+def video2frames(video_path, outpath, **kargs):
+    def _dict2str(kargs):
+        cmd_str = ''
+        for k, v in kargs.items():
+            cmd_str += (' ' + str(k) + ' ' + str(v))
+        return cmd_str
+
+    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
+    vid_name = video_path.split('/')[-1].split('.')[0]
+    out_full_path = os.path.join(outpath, vid_name)
+
+    if not os.path.exists(out_full_path):
+        os.makedirs(out_full_path)
+
+    # video file name
+    outformat = out_full_path + '/%08d.png'
+
+    cmd = ffmpeg
+    cmd = ffmpeg + [' -i ', video_path, ' -start_number ', ' 0 ', outformat]
+
+    cmd = ''.join(cmd) + _dict2str(kargs)
+
+    if os.system(cmd) != 0:
+        raise RuntimeError('ffmpeg process video: {} error'.format(vid_name))
+
+    sys.stdout.flush()
+    return out_full_path
+
+
+def frames2video(frame_path, video_path, r):
+
+    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
+    cmd = ffmpeg + [
+        ' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -vcodec ',
+        ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', video_path
+    ]
+    cmd = ''.join(cmd)
+
+    if os.system(cmd) != 0:
+        raise RuntimeError('ffmpeg process video: {} error'.format(video_path))
+
+    sys.stdout.flush()
+
+
+def is_image(input):
+    try:
+        img = Image.open(input)
+        _ = img.size
+
+        return True
+    except Exception:
+        return False
\ No newline at end of file
diff --git a/hub_module/modules/image/colorization/photo_restoration/README.md b/hub_module/modules/image/colorization/photo_restoration/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4ba39c57b7dc20f88fbb13a20bc3953154dd0420
--- /dev/null
+++ b/hub_module/modules/image/colorization/photo_restoration/README.md
@@ -0,0 +1,98 @@
+## Model Overview
+
+photo_restoration is a model for restoring old photos. It consists of two stages, colorization and super resolution: colorization is based on deoldify, and super resolution is based on realsr. You can choose to colorize and/or upscale an image as needed, so please install the deoldify and realsr modules before using this one.
+
+
+## API
+
+```python
+def run_image(self,
+              input,
+              model_select= ['Colorization', 'SuperResolution'],
+              save_path = 'photo_restoration'):
+```
+
+Prediction API for photo restoration.
+
+**Parameters**
+
+* input (numpy.ndarray|str): image data as numpy.ndarray or str. An ndarray has shape \[H, W, C\] in BGR format; a str is the image path.
+
+* model_select (list\[str\]): operations to apply to the image. \['Colorization'\] only colorizes the image, \['SuperResolution'\] only upscales it;
+the default is \['Colorization', 'SuperResolution'\].
+
+* save_path (str): path for saving the result, default is 'photo_restoration'.
+
+**Returns**
+
+* output (numpy.ndarray): restoration result with shape \[H, W, C\] in BGR format.
+
+
+
+## Code Example
+
+Photo restoration example:
+
+```python
+import cv2
+import paddlehub as hub
+
+model = hub.Module('photo_restoration', visualization=True)
+im = cv2.imread('/PATH/TO/IMAGE')
+res = model.run_image(im)
+
+```
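+
+`model_select` controls which stages run. A minimal sketch for colorization only (the input path is an illustrative placeholder):
+
+```python
+import cv2
+import paddlehub as hub
+
+model = hub.Module('photo_restoration')
+im = cv2.imread('/PATH/TO/IMAGE')
+# run only the deoldify colorization stage, skip super resolution
+res = model.run_image(im, model_select=['Colorization'])
+```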
+
+## Server Deployment
+
+PaddleHub Serving can deploy an online photo restoration service.
+
+## Step 1: Start the PaddleHub Serving service
+
+Run the start command:
+
+```shell
+$ hub serving start -m photo_restoration
+```
+
+This deploys a photo restoration API service; the default port is 8866.
+
+**NOTE:** To predict on GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; otherwise it is not needed.
+
+## Step 2: Send a prediction request
+
+With the server configured, the following few lines of code send a prediction request and fetch the result:
+
+```python
+import requests
+import json
+import base64
+
+import cv2
+import numpy as np
+
+def cv2_to_base64(image):
+    data = cv2.imencode('.jpg', image)[1]
+    return base64.b64encode(data.tostring()).decode('utf8')
+
+def base64_to_cv2(b64str):
+    data = base64.b64decode(b64str.encode('utf8'))
+    data = np.fromstring(data, np.uint8)
+    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+    return data
+
+# Send the HTTP request
+org_im = cv2.imread('PATH/TO/IMAGE')
+data = {'images':cv2_to_base64(org_im), 'model_select': ['Colorization', 'SuperResolution']}
+headers = {"Content-type": "application/json"}
+url = "http://127.0.0.1:8866/predict/photo_restoration"
+r = requests.post(url=url, headers=headers, data=json.dumps(data))
+img = base64_to_cv2(r.json()["results"])
+cv2.imwrite('PATH/TO/SAVE/IMAGE', img)
+```
+
+### Dependencies
+
+paddlepaddle >= 2.0.0rc
+
+paddlehub >= 1.8.2
diff --git a/hub_module/modules/image/colorization/photo_restoration/module.py b/hub_module/modules/image/colorization/photo_restoration/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..916f900de5131e2320447cdf45643eef95cff48b
--- /dev/null
+++ b/hub_module/modules/image/colorization/photo_restoration/module.py
@@ -0,0 +1,81 @@
+# coding:utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+
+import cv2
+import paddle.nn as nn
+import paddlehub as hub
+from paddlehub.module.module import moduleinfo, serving, Module
+
+import photo_restoration.utils as U
+
+
+@moduleinfo(name="photo_restoration",
+            type="CV/image_editing",
+            author="paddlepaddle",
+            author_email="",
+            summary="photo_restoration is a photo restoration model based on deoldify and realsr.",
+            version="1.0.0")
+class PhotoRestoreModel(Module):
+    """
+    PhotoRestoreModel
+
+    Args:
+        visualization (bool): Whether to save the restoration result. Default is False.
+    """
+    def _initialize(self, visualization: bool = False):
+        #super(PhotoRestoreModel, self).__init__()
+        self.deoldify = hub.Module(name='deoldify')
+        self.realsr = hub.Module(name='realsr')
+        self.visualization = visualization
+
+    def run_image(self, input, model_select: list = ['Colorization', 'SuperResolution'], save_path: str = 'photo_restoration'):
+        self.models = []
+        for model in model_select:
+            print('\n {} model process start..'.format(model))
+            if model == 'Colorization':
+                self.deoldify.eval()
+                self.models.append(self.deoldify)
+            if model == 'SuperResolution':
+                self.realsr.eval()
+                self.models.append(self.realsr)
+
+        for model in self.models:
+            output = model.run_image(input)
+            input = output
+        if self.visualization:
+            if not os.path.exists(save_path):
+                os.mkdir(save_path)
+            img_name = str(time.time()) + '.png'
+            save_img = os.path.join(save_path, img_name)
+            cv2.imwrite(save_img, output)
+            print("save result at: ", save_img)
+
+        return output
+
+    @serving
+    def serving_method(self, images, model_select):
+        """
+        Run as a service.
+        """
+        images_decode = U.base64_to_cv2(images)
+        results = self.run_image(input=images_decode, model_select=model_select)
+        results = U.cv2_to_base64(results)
+        return results
diff --git a/hub_module/modules/image/colorization/photo_restoration/utils.py b/hub_module/modules/image/colorization/photo_restoration/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3756451b9689b48ca6ce2093d1117e94e34373f
--- /dev/null
+++ b/hub_module/modules/image/colorization/photo_restoration/utils.py
@@ -0,0 +1,15 @@
+import base64
+import cv2
+import numpy as np
+
+
+def cv2_to_base64(image):
+    data = cv2.imencode('.jpg', image)[1]
+    return base64.b64encode(data.tostring()).decode('utf8')
+
+
+def base64_to_cv2(b64str):
+    data = base64.b64decode(b64str.encode('utf8'))
+    data = np.fromstring(data, np.uint8)
+    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+    return data
diff --git a/hub_module/modules/image/super_resolution/realsr/README.md b/hub_module/modules/image/super_resolution/realsr/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9e854f421e783563233e11e5ea33b6cc85080b52
--- /dev/null
+++ b/hub_module/modules/image/super_resolution/realsr/README.md
@@ -0,0 +1,121 @@
+
+## Model Overview
+
+RealSR is a super-resolution model for images and videos. Based on "Toward Real-World Single Image Super-Resolution: A New Benchmark and A New Model", it upscales input images and videos by a factor of four.
+
+## API Reference
+
+```python
+def predict(self, input):
+```
+
+Super-resolution API that returns the upscaled image or video.
+
+**Parameters**
+
+* input (str): path of the input image or video.
+
+**Returns**
+
+If the input is an image:
+* pred_img (np.ndarray): image data in BGR format;
+* out_path (str): path where the image is saved.
+
+If the input is a video:
+* frame_pattern_combined (str): path pattern of the upscaled video frames;
+* vid_out_path (str): path where the video is saved.
+
+```python
+def run_image(self, img):
+```
+
+Image super-resolution API that returns the upscaled image.
+
+**Parameters**
+
+* img (str|np.ndarray): image path or image data in BGR format.
+
+**Returns**
+
+* pred_img (np.ndarray): image data in BGR format.
+
+```python
+def run_video(self, video):
+```
+
+Video super-resolution API that returns the upscaled video.
+
+**Parameters**
+
+* video (str): path of the video to process.
+
+**Returns**
+
+* frame_pattern_combined (str): path pattern of the upscaled video frames;
+* vid_out_path (str): path where the video is saved.
+
+## Prediction Example
+
+```python
+import paddlehub as hub
+
+model = hub.Module('realsr')
+model.predict('/PATH/TO/IMAGE/OR/VIDEO')
+```
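+
+`run_image` can also be used on in-memory image data. A minimal sketch (paths are illustrative placeholders):
+
+```python
+import cv2
+import paddlehub as hub
+
+model = hub.Module('realsr')
+
+im = cv2.imread('/PATH/TO/IMAGE')  # BGR ndarray
+pred_img = model.run_image(im)     # 4x upscaled result, BGR ndarray
+cv2.imwrite('/PATH/TO/SAVE/IMAGE', pred_img)
+```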
+
+## Server Deployment
+
+PaddleHub Serving can deploy an online super-resolution service.
+
+## Step 1: Start the PaddleHub Serving service
+
+Run the start command:
+```shell
+$ hub serving start -m realsr
+```
+
+This deploys an online super-resolution API service; the default port is 8866.
+
+**NOTE:** To predict on GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; otherwise it is not needed.
+
+## Step 2: Send a prediction request
+
+With the server configured, the following few lines of code send a prediction request and fetch the result:
+
+```python
+import requests
+import json
+import base64
+
+import cv2
+import numpy as np
+
+def cv2_to_base64(image):
+    data = cv2.imencode('.jpg', image)[1]
+    return base64.b64encode(data.tostring()).decode('utf8')
+
+def base64_to_cv2(b64str):
+    data = base64.b64decode(b64str.encode('utf8'))
+    data = np.fromstring(data, np.uint8)
+    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+    return data
+
+# Send the HTTP request
+org_im = cv2.imread('/PATH/TO/IMAGE')
+data = {'images':cv2_to_base64(org_im)}
+headers = {"Content-type": "application/json"}
+url = "http://127.0.0.1:8866/predict/realsr"
+r = requests.post(url=url, headers=headers, data=json.dumps(data))
+img = base64_to_cv2(r.json()["results"])
+cv2.imwrite('/PATH/TO/SAVE/IMAGE', img)
+
+```
+
+## Model Information
+
+### Model code
+
+https://github.com/csjcai/RealSR
+
+### Dependencies
+
+paddlepaddle >= 2.0.0rc
+
+paddlehub >= 1.8.3
\ No newline at end of file
diff --git a/hub_module/modules/image/super_resolution/realsr/module.py b/hub_module/modules/image/super_resolution/realsr/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..16098edd88e40cae8e7ff793378eaeed96d1798a
--- /dev/null
+++ b/hub_module/modules/image/super_resolution/realsr/module.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
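+
+# Pipeline overview (added commentary): the RRDBNet generator
+# (23 residual-in-residual dense blocks, 64 base filters) upscales 4x via
+# two nearest-neighbour interpolations, each followed by a conv layer.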
+ +import os +import cv2 +import glob + +from tqdm import tqdm +import numpy as np +from PIL import Image +import paddle +import paddle.nn as nn +from paddlehub.module.module import moduleinfo, serving, Module + +from realsr.rrdb import RRDBNet +import realsr.utils as U + + +@moduleinfo(name="realsr", + type="CV/image_editing", + author="paddlepaddle", + author_email="", + summary="realsr is a super resolution model", + version="1.0.0") +class RealSRPredictor(Module): + def _initialize(self, output='output', weight_path=None, load_checkpoint: str = None): + #super(RealSRPredictor, self).__init__() + self.input = input + self.output = os.path.join(output, 'RealSR') + self.model = RRDBNet(3, 3, 64, 23) + + if load_checkpoint is not None: + state_dict = paddle.load(load_checkpoint) + self.model.load_dict(state_dict) + print("load custom checkpoint success") + + else: + checkpoint = os.path.join(self.directory, 'DF2K_JPEG.pdparams') + state_dict = paddle.load(checkpoint) + self.model.load_dict(state_dict) + print("load pretrained checkpoint success") + + self.model.eval() + + def norm(self, img): + img = np.array(img).transpose([2, 0, 1]).astype('float32') / 255.0 + return img.astype('float32') + + def denorm(self, img): + img = img.transpose((1, 2, 0)) + return (img * 255).clip(0, 255).astype('uint8') + + def run_image(self, img): + if isinstance(img, str): + ori_img = Image.open(img).convert('RGB') + elif isinstance(img, np.ndarray): + # ori_img = Image.fromarray(img).convert('RGB') + ori_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + elif isinstance(img, Image.Image): + ori_img = img + + img = self.norm(ori_img) + x = paddle.to_tensor(img[np.newaxis, ...]) + out = self.model(x) + + pred_img = self.denorm(out.numpy()[0]) + # pred_img = Image.fromarray(pred_img) + pred_img = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR) + + return pred_img + + def run_video(self, video): + base_name = os.path.basename(video).split('.')[0] + output_path = os.path.join(self.output, base_name) + pred_frame_path = os.path.join(output_path, 'frames_pred') + + if not os.path.exists(output_path): + os.makedirs(output_path) + + if not os.path.exists(pred_frame_path): + os.makedirs(pred_frame_path) + + cap = cv2.VideoCapture(video) + fps = cap.get(cv2.CAP_PROP_FPS) + + out_path = U.video2frames(video, output_path) + + frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) + + for frame in tqdm(frames): + pred_img = self.run_image(frame) + pred_img = cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB) + pred_img = Image.fromarray(pred_img) + frame_name = os.path.basename(frame) + pred_img.save(os.path.join(pred_frame_path, frame_name)) + + frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') + + vid_out_path = os.path.join(output_path, + '{}_realsr_out.mp4'.format(base_name)) + U.frames2video(frame_pattern_combined, vid_out_path, str(int(fps))) + print("save result at {}".format(vid_out_path)) + + return frame_pattern_combined, vid_out_path + + def predict(self, input): + if not os.path.exists(self.output): + os.makedirs(self.output) + + if not U.is_image(input): + return self.run_video(input) + else: + pred_img = self.run_image(input) + + out_path = None + if self.output: + final = cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB) + final = Image.fromarray(final) + base_name = os.path.splitext(os.path.basename(input))[0] + out_path = os.path.join(self.output, base_name + '.png') + final.save(out_path) + print('save result at {}'.format(out_path)) + + return pred_img, out_path + + @serving + def 
serving_method(self, images, **kwargs): + """ + Run as a service. + """ + images_decode = U.base64_to_cv2(images) + results = self.run_image(img=images_decode) + results = U.cv2_to_base64(results) + return results diff --git a/hub_module/modules/image/super_resolution/realsr/rrdb.py b/hub_module/modules/image/super_resolution/realsr/rrdb.py new file mode 100644 index 0000000000000000000000000000000000000000..64f5fff9bf3623fadda11acc8b031e43f916c837 --- /dev/null +++ b/hub_module/modules/image/super_resolution/realsr/rrdb.py @@ -0,0 +1,137 @@ +import functools +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class Registry(object): + """ + The registry that provides name -> object mapping, to support third-party users' custom modules. + To create a registry (inside segmentron): + .. code-block:: python + BACKBONE_REGISTRY = Registry('BACKBONE') + To register an object: + .. code-block:: python + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + Or: + .. code-block:: python + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + + self._obj_map = {} + + def _do_register(self, name, obj): + assert ( + name not in self._obj_map + ), "An object named '{}' was already registered in '{}' registry!".format(name, self._name) + self._obj_map[name] = obj + + def register(self, obj=None, name=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class, name=name): + if name is None: + name = func_or_class.__name__ + self._do_register(name, func_or_class) + return func_or_class + + return deco + + # used as a function call + if name is None: + name = obj.__name__ + self._do_register(name, obj) + + def get(self, name): + ret = self._obj_map.get(name) + if ret is None: + raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name)) + + return ret + + +class ResidualDenseBlock_5C(nn.Layer): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. 
intermediate channels + self.conv1 = nn.Conv2D(nf, gc, 3, 1, 1, bias_attr=bias) + self.conv2 = nn.Conv2D(nf + gc, gc, 3, 1, 1, bias_attr=bias) + self.conv3 = nn.Conv2D(nf + 2 * gc, gc, 3, 1, 1, bias_attr=bias) + self.conv4 = nn.Conv2D(nf + 3 * gc, gc, 3, 1, 1, bias_attr=bias) + self.conv5 = nn.Conv2D(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(paddle.concat((x, x1), 1))) + x3 = self.lrelu(self.conv3(paddle.concat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(paddle.concat((x, x1, x2, x3), 1))) + x5 = self.conv5(paddle.concat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Layer): + '''Residual in Residual Dense Block''' + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + +GENERATORS = Registry("GENERATOR") + +@GENERATORS.register() +class RRDBNet(nn.Layer): + def __init__(self, in_nc, out_nc, nf, nb, gc=32): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + + self.conv_first = nn.Conv2D(in_nc, nf, 3, 1, 1, bias_attr=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True) + #### upsampling + self.upconv1 = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True) + self.upconv2 = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True) + self.HRconv = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True) + self.conv_last = nn.Conv2D(nf, out_nc, 3, 1, 1, bias_attr=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu( + self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + fea = self.lrelu( + self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out diff --git a/hub_module/modules/image/super_resolution/realsr/utils.py b/hub_module/modules/image/super_resolution/realsr/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a345bb8ddc4f5f03b6c3f738503ac72daacb1e8a --- /dev/null +++ b/hub_module/modules/image/super_resolution/realsr/utils.py @@ -0,0 +1,69 @@ +import os +import sys +import base64 + +import cv2 +from PIL import Image +import numpy as np + +def video2frames(video_path, outpath, **kargs): + def _dict2str(kargs): + cmd_str = '' + for k, v in kargs.items(): + cmd_str += (' ' + str(k) + ' ' + str(v)) + return cmd_str + + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] + vid_name = video_path.split('/')[-1].split('.')[0] + out_full_path = os.path.join(outpath, vid_name) + + if not os.path.exists(out_full_path): + os.makedirs(out_full_path) + + # video file name + outformat = out_full_path + '/%08d.png' + + cmd = ffmpeg + cmd = ffmpeg + [' -i ', video_path, ' -start_number ', ' 0 ', outformat] + + cmd = ''.join(cmd) + _dict2str(kargs) + + if os.system(cmd) != 0: + raise RuntimeError('ffmpeg process video: {} error'.format(vid_name)) + + sys.stdout.flush() + return out_full_path + + +def frames2video(frame_path, video_path, r): + 
    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
    cmd = ffmpeg + [
        ' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -pix_fmt ',
        ' yuv420p ', video_path
    ]
    cmd = ''.join(cmd)

    if os.system(cmd) != 0:
        raise RuntimeError('ffmpeg process video: {} error'.format(video_path))

    sys.stdout.flush()


def is_image(input):
    try:
        img = Image.open(input)
        _ = img.size
        return True
    except Exception:
        return False


def cv2_to_base64(image):
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tostring()).decode('utf8')


def base64_to_cv2(b64str):
    data = base64.b64decode(b64str.encode('utf8'))
    data = np.fromstring(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data
\ No newline at end of file
diff --git a/hub_module/modules/video/video_restore/dain/README.md b/hub_module/modules/video/video_restore/dain/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0ce279125850fb77c5e273171fae72183fdedbb5
--- /dev/null
+++ b/hub_module/modules/video/video_restore/dain/README.md
@@ -0,0 +1,40 @@
+## Model Overview
+
+DAIN is a video frame interpolation model. Based on "Depth-Aware Video Frame Interpolation", it can interpolate frames in old footage to make playback smoother.
+
+## API Reference
+
+```python
+def predict(self, video_path):
+```
+
+Frame interpolation API that returns the interpolated video.
+
+**Parameters**
+
+* video_path (str): path of the original video.
+
+**Returns**
+
+* frame_pattern_combined (str): path pattern of the interpolated video frames;
+* vid_out_path (str): path where the video is saved.
+
+
+## Prediction Example
+
+```python
+import paddlehub as hub
+
+model = hub.Module('dain')
+model.predict('/PATH/TO/VIDEO')
+```
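+
+The interpolation factor is controlled by `time_step` (0.5 doubles the frame rate, 0.25 quadruples it). A minimal sketch, assuming module keyword arguments are forwarded to `_initialize` as for the other modules in this PR:
+
+```python
+import paddlehub as hub
+
+# time_step=0.25 inserts three frames between each input pair, i.e. 4x frame rate
+model = hub.Module('dain', time_step=0.25)
+frames, video = model.predict('/PATH/TO/VIDEO')
+print(video)  # path of the interpolated video
+```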
+
+import os
+import cv2
+import glob
+import shutil
+import numpy as np
+from tqdm import tqdm
+from imageio import imread, imsave
+
+import paddle
+import paddle.fluid as fluid
+from paddlehub.module.module import moduleinfo, serving, Module
+
+import dain.utils as U
+
+
+@moduleinfo(name="dain",
+            type="CV/image_editing",
+            author="paddlepaddle",
+            author_email="",
+            summary="Dain is a model for video frame interpolation",
+            version="1.0.0")
+class DAINPredictor(Module):
+    def _initialize(self,
+                    output_path='output',
+                    weight_path=None,
+                    time_step=0.5,
+                    use_gpu=True,
+                    key_frame_thread=0.,
+                    remove_duplicates=True):
+        paddle.enable_static()
+        self.output_path = os.path.join(output_path, 'DAIN')
+
+        # fall back to the bundled weights when no explicit path is given
+        if weight_path is None:
+            cur_path = os.path.abspath(os.path.dirname(__file__))
+            weight_path = os.path.join(cur_path, 'DAIN_weight')
+        self.weight_path = weight_path
+
+        self.time_step = time_step
+        self.key_frame_thread = key_frame_thread
+        self.remove_duplicates = remove_duplicates
+
+        self.build_inference_model()
+
+    def build_inference_model(self):
+        if paddle.in_dynamic_mode():
+            # todo self.model = build_model(self.cfg)
+            pass
+        else:
+            place = paddle.fluid.framework._current_expected_place()
+            self.exe = paddle.fluid.Executor(place)
+            file_names = os.listdir(self.weight_path)
+            for file_name in file_names:
+                if file_name.find('model') > -1:
+                    model_file = file_name
+                elif file_name.find('param') > -1:
+                    param_file = file_name
+
+            self.program, self.feed_names, self.fetch_targets = paddle.static.load_inference_model(
+                dirname=self.weight_path,
+                executor=self.exe,
+                model_filename=model_file,
+                params_filename=param_file)
+
+    def base_forward(self, inputs):
+        if paddle.in_dynamic_mode():
+            out = self.model(inputs)
+        else:
+            feed_dict = {}
+            if isinstance(inputs, dict):
+                feed_dict = inputs
+            elif isinstance(inputs, (list, tuple)):
+                for i, feed_name in enumerate(self.feed_names):
+                    feed_dict[feed_name] = inputs[i]
+            else:
+                feed_dict[self.feed_names[0]] = inputs
+
+            out = self.exe.run(self.program,
+                               fetch_list=self.fetch_targets,
+                               feed=feed_dict)
+
+        return out
+
+    def predict(self, video_path):
+        frame_path_input = os.path.join(self.output_path, 'frames-input')
+        frame_path_interpolated = os.path.join(self.output_path,
+                                               'frames-interpolated')
+        frame_path_combined = os.path.join(self.output_path, 'frames-combined')
+        video_path_output = os.path.join(self.output_path, 'videos-output')
+
+        if not os.path.exists(self.output_path):
+            os.makedirs(self.output_path)
+        if not os.path.exists(frame_path_input):
+            os.makedirs(frame_path_input)
+        if not os.path.exists(frame_path_interpolated):
+            os.makedirs(frame_path_interpolated)
+        if not os.path.exists(frame_path_combined):
+            os.makedirs(frame_path_combined)
+        if not os.path.exists(video_path_output):
+            os.makedirs(video_path_output)
+
+        timestep = self.time_step
+        num_frames = int(1.0 / timestep) - 1
+
+        cap = cv2.VideoCapture(video_path)
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        print("Old fps (frame rate): ", fps)
+
+        times_interp = int(1.0 / timestep)
+        r2 = str(int(fps) * times_interp)
+        print("New fps (frame rate): ", r2)
+
+        out_path = U.video2frames(video_path, frame_path_input)
+
+        vidname = video_path.split('/')[-1].split('.')[0]
+
+        frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
+        orig_frames = len(frames)
+        need_frames = orig_frames * times_interp
+
+        if self.remove_duplicates:
+            frames = self.remove_duplicate_frames(out_path)
+            left_frames = len(frames)
+            timestep = left_frames / need_frames
+            num_frames = int(1.0 / timestep) - 1
+
+        img = imread(frames[0])
+
+        int_width = img.shape[1]
+        int_height = img.shape[0]
+        channel = img.shape[2]
+        if not channel == 3:
+            return
+
+        # pad both sides to a multiple of 128, as the network expects
+        if int_width != ((int_width >> 7) << 7):
+            int_width_pad = (((int_width >> 7) + 1) << 7)  # more than necessary
+            padding_left = int((int_width_pad - int_width) / 2)
+            padding_right = int_width_pad - int_width - padding_left
+        else:
+            int_width_pad = int_width
+            padding_left = 32
+            padding_right = 32
+
+        if int_height != ((int_height >> 7) << 7):
+            int_height_pad = (((int_height >> 7) + 1) << 7)  # more than necessary
+            padding_top = int((int_height_pad - int_height) / 2)
+            padding_bottom = int_height_pad - int_height - padding_top
+        else:
+            int_height_pad = int_height
+            padding_top = 32
+            padding_bottom = 32
+
+        frame_num = len(frames)
+
+        if not os.path.exists(os.path.join(frame_path_interpolated, vidname)):
+            os.makedirs(os.path.join(frame_path_interpolated, vidname))
+        if not os.path.exists(os.path.join(frame_path_combined, vidname)):
+            os.makedirs(os.path.join(frame_path_combined, vidname))
+
+        for i in tqdm(range(frame_num - 1)):
+            first = frames[i]
+            second = frames[i + 1]
+
+            img_first = imread(first)
+            img_second = imread(second)
+            # -------- frame change (scene cut) test --------
+            img_first_gray = np.dot(img_first[..., :3], [0.299, 0.587, 0.114])
+            img_second_gray = np.dot(img_second[..., :3], [0.299, 0.587, 0.114])
+
+            img_first_gray = img_first_gray.flatten(order='C')
+            img_second_gray = img_second_gray.flatten(order='C')
+            corr = np.corrcoef(img_first_gray, img_second_gray)[0, 1]
+            key_frame = False
+            if corr < self.key_frame_thread:
+                key_frame = True
+            # key_frame is currently informational only
+            # ------------------------------------------------
+
+            X0 = img_first.astype('float32').transpose((2, 0, 1)) / 255
+            X1 = img_second.astype('float32').transpose((2, 0, 1)) / 255
+
+            assert (X0.shape[1] == X1.shape[1])
+            assert (X0.shape[2] == X1.shape[2])
+
+            X0 = np.pad(X0, ((0, 0), (padding_top, padding_bottom),
+                             (padding_left, padding_right)), mode='edge')
+            X1 = np.pad(X1, ((0, 0), (padding_top, padding_bottom),
+                             (padding_left, padding_right)), mode='edge')
+
+            X0 = np.expand_dims(X0, axis=0)
+            X1 = np.expand_dims(X1, axis=0)
+
+            X0 = np.expand_dims(X0, axis=0)
+            X1 = np.expand_dims(X1, axis=0)
+
+            X = np.concatenate((X0, X1), axis=0)
+
+            o = self.base_forward(X)
+
+            y_ = o[0]
+
+            y_ = [
+                np.transpose(
+                    255.0 * item.clip(
+                        0, 1.0)[0, :, padding_top:padding_top + int_height,
+                                padding_left:padding_left + int_width],
+                    (1, 2, 0)) for item in y_
+            ]
+            time_offsets = [kk * timestep for kk in range(1, 1 + num_frames, 1)]
+
+            count = 1
+            for item, time_offset in zip(y_, time_offsets):
+                out_dir = os.path.join(frame_path_interpolated, vidname,
+                                       "{:0>6d}_{:0>4d}.png".format(i, count))
+                count = count + 1
+                imsave(out_dir, np.round(item).astype(np.uint8))
+
+        num_frames = int(1.0 / timestep) - 1
+
+        input_dir = os.path.join(frame_path_input, vidname)
+        interpolated_dir = os.path.join(frame_path_interpolated, vidname)
+        combined_dir = os.path.join(frame_path_combined, vidname)
+        self.combine_frames(input_dir, interpolated_dir, combined_dir,
+                            num_frames)
+
+        frame_pattern_combined = os.path.join(frame_path_combined, vidname,
+                                              '%08d.png')
+        video_pattern_output = os.path.join(video_path_output, vidname + '.mp4')
+        if os.path.exists(video_pattern_output):
+            os.remove(video_pattern_output)
+        U.frames2video(frame_pattern_combined, video_pattern_output, r2)
+        print('Save result at {}.'.format(video_pattern_output))
+
+        return frame_pattern_combined, video_pattern_output
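+
+    # combine_frames() below interleaves each original frame with the
+    # interpolated frames that follow it, renumbering the output as %08d.png;
+    # remove_duplicate_frames() drops dhash-identical frames before interpolation.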
+
+    def combine_frames(self, input, interpolated, combined, num_frames):
+        frames1 = sorted(glob.glob(os.path.join(input, '*.png')))
+        frames2 = sorted(glob.glob(os.path.join(interpolated, '*.png')))
+        num1 = len(frames1)
+        num2 = len(frames2)
+
+        for i in range(num1):
+            src = frames1[i]
+            imgname = int(src.split('/')[-1].split('.')[-2])
+            assert i == imgname
+            dst = os.path.join(combined,
+                               '{:08d}.png'.format(i * (num_frames + 1)))
+            shutil.copy2(src, dst)
+            if i < num1 - 1:
+                try:
+                    for k in range(num_frames):
+                        src = frames2[i * num_frames + k]
+                        dst = os.path.join(
+                            combined,
+                            '{:08d}.png'.format(i * (num_frames + 1) + k + 1))
+                        shutil.copy2(src, dst)
+                except Exception as e:
+                    print(e)
+
+    def remove_duplicate_frames(self, paths):
+        def dhash(image, hash_size=8):
+            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+            resized = cv2.resize(gray, (hash_size + 1, hash_size))
+            diff = resized[:, 1:] > resized[:, :-1]
+            return sum([2 ** i for (i, v) in enumerate(diff.flatten()) if v])
+
+        hashes = {}
+        image_paths = sorted(glob.glob(os.path.join(paths, '*.png')))
+        for image_path in image_paths:
+            image = cv2.imread(image_path)
+            h = dhash(image)
+            p = hashes.get(h, [])
+            p.append(image_path)
+            hashes[h] = p
+
+        for (h, hashed_paths) in hashes.items():
+            if len(hashed_paths) > 1:
+                for p in hashed_paths[1:]:
+                    os.remove(p)
+
+        frames = sorted(glob.glob(os.path.join(paths, '*.png')))
+        for fid, frame in enumerate(frames):
+            new_name = '{:08d}'.format(fid) + '.png'
+            new_name = os.path.join(paths, new_name)
+            os.rename(frame, new_name)
+
+        frames = sorted(glob.glob(os.path.join(paths, '*.png')))
+        return frames
diff --git a/hub_module/modules/video/video_restore/dain/utils.py b/hub_module/modules/video/video_restore/dain/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2c6b042938a870f2329ae9e57a13c43ee5bc319
--- /dev/null
+++ b/hub_module/modules/video/video_restore/dain/utils.py
@@ -0,0 +1,59 @@
+import os
+import sys
+import base64
+
+import cv2
+import numpy as np
+
+def video2frames(video_path, outpath, **kargs):
+    def _dict2str(kargs):
+        cmd_str = ''
+        for k, v in kargs.items():
+            cmd_str += (' ' + str(k) + ' ' + str(v))
+        return cmd_str
+
+    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
+    vid_name = video_path.split('/')[-1].split('.')[0]
+    out_full_path = os.path.join(outpath, vid_name)
+
+    if not os.path.exists(out_full_path):
+        os.makedirs(out_full_path)
+
+    # frame filename pattern inside the per-video folder
+    outformat = out_full_path + '/%08d.png'
+
+    cmd = ffmpeg + [' -i ', video_path, ' -start_number ', ' 0 ', outformat]
+
+    cmd = ''.join(cmd) + _dict2str(kargs)
+
+    if os.system(cmd) != 0:
+        raise RuntimeError('ffmpeg process video: {} error'.format(vid_name))
+
+    sys.stdout.flush()
+    return out_full_path
+
+
+def frames2video(frame_path, video_path, r):
+    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
+    cmd = ffmpeg + [
+        ' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -vcodec ',
+        ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', video_path
+    ]
+    cmd = ''.join(cmd)
+
+    if os.system(cmd) != 0:
+        raise RuntimeError('ffmpeg process video: {} error'.format(video_path))
+
+    sys.stdout.flush()
+
+
+def cv2_to_base64(image):
+    data = cv2.imencode('.jpg', image)[1]
+    return base64.b64encode(data.tostring()).decode('utf8')
+
+
+def base64_to_cv2(b64str):
+    data = base64.b64decode(b64str.encode('utf8'))
+    data = np.fromstring(data, np.uint8)
+    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+    return data
\ No newline at end of file
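+
+
+# A minimal round-trip sketch (assumes ffmpeg is on PATH; paths are placeholders):
+#
+#     frames_dir = video2frames('/PATH/TO/VIDEO.mp4', '/tmp/frames')
+#     frames2video(frames_dir + '/%08d.png', '/tmp/restored.mp4', r='25')
+#
+# video2frames returns the directory holding the extracted %08d.png frames,
+# and frames2video expects that same image2 pattern plus the target frame rate.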
diff --git a/hub_module/modules/video/video_restore/edvr/README.md b/hub_module/modules/video/video_restore/edvr/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4f298b0af0fd966fb2b7bf70c5943e1757b1db43
--- /dev/null
+++ b/hub_module/modules/video/video_restore/edvr/README.md
@@ -0,0 +1,39 @@
+## Model Overview
+edvr is a video super-resolution model based on Video Restoration with Enhanced Deformable Convolutional Networks. It can upscale old, low-resolution videos to improve their quality.
+
+## API Reference
+
+```python
+def predict(self, video_path):
+```
+
+Super-resolution API; returns the upscaled video.
+
+**Parameters**
+
+* video_path (str): path to the input video.
+
+**Returns**
+
+* frame_pattern_combined (str): path pattern of the upscaled frames;
+* vid_out_path (str): path of the output video.
+
+
+## Prediction Example
+
+```python
+import paddlehub as hub
+
+model = hub.Module('edvr')
+model.predict('/PATH/TO/VIDEO')
+```
+
+## Model Information
+
+### Model code
+https://github.com/xinntao/EDVR
+
+### Dependencies
+
+paddlepaddle >= 2.0.0rc
+
+paddlehub >= 1.8.3
\ No newline at end of file
diff --git a/hub_module/modules/video/video_restore/edvr/module.py b/hub_module/modules/video/video_restore/edvr/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ed578d48473b444dcd44f55ec9dbaaca9b3984b
--- /dev/null
+++ b/hub_module/modules/video/video_restore/edvr/module.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
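+"""PaddleHub module wrapping the EDVR video super-resolution predictor.
+
+A minimal usage sketch (the video path is a placeholder):
+
+    import paddlehub as hub
+
+    model = hub.Module(name='edvr')
+    frame_pattern, video_out = model.predict('/PATH/TO/VIDEO')
+"""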
+
+import os
+import glob
+import time
+
+from tqdm import tqdm
+import cv2
+import numpy as np
+import paddle
+from PIL import Image
+from paddlehub.module.module import moduleinfo, serving, Module
+
+import edvr.utils as U
+
+
+class EDVRDataset:
+    def __init__(self, frame_paths):
+        self.frames = frame_paths
+
+    def __getitem__(self, index):
+        indexs = U.get_test_neighbor_frames(index, 5, len(self.frames))
+        frame_list = []
+        for i in indexs:
+            img = U.read_img(self.frames[i])
+            frame_list.append(img)
+
+        img_LQs = np.stack(frame_list, axis=0)
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_LQs = img_LQs[:, :, :, [2, 1, 0]]
+        img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
+
+        return img_LQs, self.frames[index]
+
+    def __len__(self):
+        return len(self.frames)
+
+
+@moduleinfo(name="edvr",
+            type="CV/image_editing",
+            author="paddlepaddle",
+            author_email="",
+            summary="EDVR is a super resolution model",
+            version="1.0.0")
+class EDVRPredictor(Module):
+    def _initialize(self, output_path='output', weight_path=None):
+        paddle.enable_static()
+        self.output = os.path.join(output_path, 'EDVR')
+
+        if weight_path is None:
+            cur_path = os.path.abspath(os.path.dirname(__file__))
+            weight_path = os.path.join(cur_path, 'edvr_infer_model')
+
+        self.weight_path = weight_path
+
+        self.build_inference_model()
+
+    def build_inference_model(self):
+        if paddle.in_dynamic_mode():
+            # todo self.model = build_model(self.cfg)
+            pass
+        else:
+            place = paddle.fluid.framework._current_expected_place()
+            self.exe = paddle.fluid.Executor(place)
+            file_names = os.listdir(self.weight_path)
+            for file_name in file_names:
+                if file_name.find('model') > -1:
+                    model_file = file_name
+                elif file_name.find('param') > -1:
+                    param_file = file_name
+
+            self.program, self.feed_names, self.fetch_targets = paddle.static.load_inference_model(
+                dirname=self.weight_path,
+                executor=self.exe,
+                model_filename=model_file,
+                params_filename=param_file)
+
+    def base_forward(self, inputs):
+        if paddle.in_dynamic_mode():
+            out = self.model(inputs)
+        else:
+            feed_dict = {}
+            if isinstance(inputs, dict):
+                feed_dict = inputs
+            elif isinstance(inputs, (list, tuple)):
+                for i, feed_name in enumerate(self.feed_names):
+                    feed_dict[feed_name] = inputs[i]
+            else:
+                feed_dict[self.feed_names[0]] = inputs
+
+            out = self.exe.run(self.program,
+                               fetch_list=self.fetch_targets,
+                               feed=feed_dict)
+
+        return out
+
+    def is_image(self, input):
+        try:
+            img = Image.open(input)
+            _ = img.size
+            return True
+        except Exception:
+            return False
+
+    def predict(self, video_path):
+        vid = video_path
+        base_name = os.path.basename(vid).split('.')[0]
+        output_path = os.path.join(self.output, base_name)
+        pred_frame_path = os.path.join(output_path, 'frames_pred')
+
+        if not os.path.exists(output_path):
+            os.makedirs(output_path)
+
+        if not os.path.exists(pred_frame_path):
+            os.makedirs(pred_frame_path)
+
+        cap = cv2.VideoCapture(vid)
+        fps = cap.get(cv2.CAP_PROP_FPS)
+
+        out_path = U.video2frames(vid, output_path)
+
+        frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
+
+        dataset = EDVRDataset(frames)
+
+        periods = []
+        cur_time = time.time()
+        for infer_iter, data in enumerate(tqdm(dataset)):
+            data_feed_in = [data[0]]
+
+            outs = self.base_forward(np.array(data_feed_in))
+
+            infer_result_list = [item for item in outs]
+
+            frame_path = data[1]
+
+            img_i = U.get_img(infer_result_list[0])
+            U.save_img(
+                img_i,
+                os.path.join(pred_frame_path, os.path.basename(frame_path)))
+
+            prev_time = cur_time
+            cur_time = time.time()
+            period = cur_time - prev_time
+            periods.append(period)
+
+        frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
+        vid_out_path = os.path.join(self.output,
+                                    '{}_edvr_out.mp4'.format(base_name))
+        U.frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
+        print('save video result at ', vid_out_path)
+
+        return frame_pattern_combined, vid_out_path
\ No newline at end of file
diff --git a/hub_module/modules/video/video_restore/edvr/utils.py b/hub_module/modules/video/video_restore/edvr/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..78e58a5a3f7ecbef0eebbc06ecb50460d1071074
--- /dev/null
+++ b/hub_module/modules/video/video_restore/edvr/utils.py
@@ -0,0 +1,144 @@
+import os
+import sys
+import base64
+
+import cv2
+import numpy as np
+
+
+def video2frames(video_path, outpath, **kargs):
+    def _dict2str(kargs):
+        cmd_str = ''
+        for k, v in kargs.items():
+            cmd_str += (' ' + str(k) + ' ' + str(v))
+        return cmd_str
+
+    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
+    vid_name = video_path.split('/')[-1].split('.')[0]
+    out_full_path = os.path.join(outpath, vid_name)
+
+    if not os.path.exists(out_full_path):
+        os.makedirs(out_full_path)
+
+    # frame filename pattern inside the per-video folder
+    outformat = out_full_path + '/%08d.png'
+
+    cmd = ffmpeg + [' -i ', video_path, ' -start_number ', ' 0 ', outformat]
+
+    cmd = ''.join(cmd) + _dict2str(kargs)
+
+    if os.system(cmd) != 0:
+        raise RuntimeError('ffmpeg process video: {} error'.format(vid_name))
+
+    sys.stdout.flush()
+    return out_full_path
+
+
+def frames2video(frame_path, video_path, r):
+    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
+    cmd = ffmpeg + [
+        ' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -vcodec ',
+        ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', video_path
+    ]
+    cmd = ''.join(cmd)
+
+    if os.system(cmd) != 0:
+        raise RuntimeError('ffmpeg process video: {} error'.format(video_path))
+
+    sys.stdout.flush()
+
+
+def cv2_to_base64(image):
+    data = cv2.imencode('.jpg', image)[1]
+    return base64.b64encode(data.tostring()).decode('utf8')
+
+
+def base64_to_cv2(b64str):
+    data = base64.b64decode(b64str.encode('utf8'))
+    data = np.fromstring(data, np.uint8)
+    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+    return data
+
+
+def get_img(pred):
+    pred = pred.squeeze()
+    pred = np.clip(pred, a_min=0., a_max=1.0)
+    pred = pred * 255
+    pred = pred.round()
+    pred = pred.astype('uint8')
+    pred = np.transpose(pred, (1, 2, 0))  # chw -> hwc
+    pred = pred[:, :, ::-1]  # rgb -> bgr
+    return pred
+
+
+def save_img(img, framename):
+    dirname = os.path.dirname(framename)
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+    cv2.imwrite(framename, img)
+
+
+def read_img(path, size=None, is_gt=False):
+    """read image by cv2
+    return: Numpy float32, HWC, BGR, [0,1]"""
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+
+    if img.shape[2] > 3:
+        img = img[:, :, :3]
+    return img
+
+
+def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'):
+    """Generate an index list for reading N frames from a sequence of images.
+
+    Args:
+        crt_i (int): current center index
+        max_n (int): total number of images in the sequence (counted from 1)
+        N (int): number of frames to read
+        padding (str): padding mode, one of replicate | reflection | new_info | circle
+            Example: crt_i = 0, N = 5
+            replicate: [0, 0, 0, 1, 2]
+            reflection: [2, 1, 0, 1, 2]
+            new_info: [4, 3, 0, 1, 2]
+            circle: [3, 4, 0, 1, 2]
+
+    Returns:
+        return_l (list[int]): a list of indexes
+    """
+    max_n = max_n - 1
+    n_pad = N // 2
+    return_l = []
+
+    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
+        if i < 0:
+            if padding == 'replicate':
+                add_idx = 0
+            elif padding == 'reflection':
+                add_idx = -i
+            elif padding == 'new_info':
+                add_idx = (crt_i + n_pad) + (-i)
+            elif padding == 'circle':
+                add_idx = N + i
+            else:
+                raise ValueError('Wrong padding mode')
+        elif i > max_n:
+            if padding == 'replicate':
+                add_idx = max_n
+            elif padding == 'reflection':
+                add_idx = max_n * 2 - i
+            elif padding == 'new_info':
+                add_idx = (crt_i - n_pad) - (i - max_n)
+            elif padding == 'circle':
+                add_idx = i - N
+            else:
+                raise ValueError('Wrong padding mode')
+        else:
+            add_idx = i
+        return_l.append(add_idx)
+
+    return return_l
\ No newline at end of file
diff --git a/hub_module/modules/video/video_restore/video_restoration/README.md b/hub_module/modules/video/video_restore/video_restoration/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f8b585919aeafcec6734a8cbf20ea1ea866b2e66
--- /dev/null
+++ b/hub_module/modules/video/video_restore/video_restoration/README.md
@@ -0,0 +1,45 @@
+## Model Overview
+
+video_restoration is a model for restoring old videos. It chains three stages: frame interpolation, colorization, and super resolution. Interpolation is based on the dain model, colorization on the deoldify model, and super resolution on the edvr model. Users can pick whichever subset of the three stages they need. Please install the dain, deoldify, and edvr modules before using this model.
+
+
+## API
+
+```python
+def predict(self,
+            input_video_path,
+            model_select=['Interpolation', 'Colorization', 'SuperResolution']):
+```
+
+Prediction API for video restoration.
+
+**Parameters**
+
+* input_video_path (str): path to the input video.
+
+* model_select (list\[str\]): the operations to apply to the video. \['Interpolation'\] only interpolates frames, \['Colorization'\] only colorizes, and \['SuperResolution'\] only upscales;
+the default is \['Interpolation', 'Colorization', 'SuperResolution'\].
+
+**Returns**
+
+* temp_video_path (str): path where the processed video is saved.
+
+
+
+## Code Example
+
+Video restoration example:
+
+```python
+import paddlehub as hub
+
+model = hub.Module('video_restoration')
+model.predict('/PATH/TO/VIDEO')
+
+```
+
+### Dependencies
+
+paddlepaddle >= 2.0.0rc
+
+paddlehub >= 1.8.3
diff --git a/hub_module/modules/video/video_restore/video_restoration/module.py b/hub_module/modules/video/video_restore/video_restoration/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6b103ad36f2a9f7f54bc028af556ecef19aa15f
--- /dev/null
+++ b/hub_module/modules/video/video_restore/video_restoration/module.py
@@ -0,0 +1,75 @@
+# coding:utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddlehub as hub
+from paddlehub.module.module import moduleinfo, serving, Module
+
+
+@moduleinfo(name="video_restoration",
+            type="CV/image_editing",
+            author="paddlepaddle",
+            author_email="",
+            summary="video_restoration is a video restoration model based on dain, deoldify and edvr.",
+            version="1.0.0")
+class PhotoRestoreModel(Module):
+    """
+    PhotoRestoreModel: restore old videos by chaining the dain (interpolation),
+    deoldify (colorization) and edvr (super resolution) modules.
+
+    Args:
+        output_path(str): Path to save results.
+    """
+    def _initialize(self, output_path='output'):
+        self.output_path = output_path
+        # dain and edvr run as static-graph predictors, deoldify as dynamic graph
+        paddle.enable_static()
+        self.dain = hub.Module(name='dain', output_path=self.output_path)
+        self.edvr = hub.Module(name='edvr', output_path=self.output_path)
+        paddle.disable_static()
+        self.deoldify = hub.Module(name='deoldify', output_path=self.output_path)
+
+    def predict(self,
+                input_video_path,
+                model_select=['Interpolation', 'Colorization', 'SuperResolution']):
+        temp_video_path = None
+        for model in model_select:
+            print('\n{} model process start...'.format(model))
+            if model == 'Interpolation':
+                paddle.enable_static()
+                print('dain input:', input_video_path)
+                frames_path, temp_video_path = self.dain.predict(input_video_path)
+                input_video_path = temp_video_path
+                paddle.disable_static()
+
+            if model == 'Colorization':
+                self.deoldify.eval()
+                print('deoldify input:', input_video_path)
+                frames_path, temp_video_path = self.deoldify.predict(input_video_path)
+                input_video_path = temp_video_path
+
+            if model == 'SuperResolution':
+                paddle.enable_static()
+                print('edvr input:', input_video_path)
+                frames_path, temp_video_path = self.edvr.predict(input_video_path)
+                input_video_path = temp_video_path
+                paddle.disable_static()
+
+        return temp_video_path
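+
+
+if __name__ == '__main__':
+    # Minimal smoke test (a sketch, not shipped behavior): '/PATH/TO/VIDEO' is a
+    # placeholder, and the dain, deoldify and edvr modules must be installed
+    # for predict() to run.
+    model = hub.Module(name='video_restoration')
+    result_path = model.predict('/PATH/TO/VIDEO', model_select=['Colorization'])
+    print('restored video saved at', result_path)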