未验证 提交 71f3755a 编写于 作者: L LielinJiang 提交者: GitHub

Add art model of deoldify (#71)

* add art model of deoldify
上级 f02d52a7
...@@ -67,6 +67,10 @@ parser.add_argument('--mindim', ...@@ -67,6 +67,10 @@ parser.add_argument('--mindim',
default=360, default=360,
help='Length of minimum image edges') help='Length of minimum image edges')
# DeOldify args # DeOldify args
parser.add_argument('--artistic',
action='store_true',
default=False,
help='whether to use artistic DeOldify Model')
parser.add_argument('--render_factor', parser.add_argument('--render_factor',
type=int, type=int,
default=32, default=32,
...@@ -107,6 +111,7 @@ if __name__ == "__main__": ...@@ -107,6 +111,7 @@ if __name__ == "__main__":
elif order == 'DeOldify': elif order == 'DeOldify':
predictor = DeOldifyPredictor(args.output, predictor = DeOldifyPredictor(args.output,
weight_path=args.DeOldify_weight, weight_path=args.DeOldify_weight,
artistic=args.artistic,
render_factor=args.render_factor) render_factor=args.render_factor)
frames_path, temp_video_path = predictor.run(temp_video_path) frames_path, temp_video_path = predictor.run(temp_video_path)
elif order == 'RealSR': elif order == 'RealSR':
......
../../zh_CN/tutorials/video_restore.md
\ No newline at end of file
...@@ -39,6 +39,7 @@ ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32 ...@@ -39,6 +39,7 @@ ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32
> >
> > - output (str): 设置输出图片的保存路径,默认是output。注意,实际保存路径为output/DeOldify。 > > - output (str): 设置输出图片的保存路径,默认是output。注意,实际保存路径为output/DeOldify。
> > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。 > > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。
> > - artistic (bool): 是否使用偏"艺术性"的模型,默认是False。"艺术性"的模型有可能产生一些有趣的颜色,但是毛刺比较多。
> > - render_factor (int): 图片渲染上色时的缩放因子,图片会缩放到边长为16xrender_factor的正方形, 再上色,例如render_factor默认值为32,输入图片先缩放到(16x32=512) 512x512大小的图片。通常来说,render_factor越小,计算速度越快,颜色看起来也更鲜活。较旧和较低质量的图像通常会因降低渲染因子而受益。渲染因子越高,图像质量越好,但颜色可能会稍微褪色。 > > - render_factor (int): 图片渲染上色时的缩放因子,图片会缩放到边长为16xrender_factor的正方形, 再上色,例如render_factor默认值为32,输入图片先缩放到(16x32=512) 512x512大小的图片。通常来说,render_factor越小,计算速度越快,颜色看起来也更鲜活。较旧和较低质量的图像通常会因降低渲染因子而受益。渲染因子越高,图像质量越好,但颜色可能会稍微褪色。
### run ### run
......
...@@ -63,6 +63,7 @@ ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32 ...@@ -63,6 +63,7 @@ ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`. - `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None` - `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
- `artistic (bool,可选的)`: 是否使用偏"艺术性"的模型。"艺术性"的模型有可能产生一些有趣的颜色,但是毛刺比较多。默认值:`False`
- `render_factor (int)`: 会将该参数乘以16后作为输入帧的resize的值,如果该值设置为32, - `render_factor (int)`: 会将该参数乘以16后作为输入帧的resize的值,如果该值设置为32,
则输入帧会resize到(32 * 16, 32 * 16)的尺寸再输入到网络中。 则输入帧会resize到(32 * 16, 32 * 16)的尺寸再输入到网络中。
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import os import os
import cv2 import cv2
import numpy as np
from PIL import Image from PIL import Image
import paddle import paddle
...@@ -64,9 +65,16 @@ class BasePredictor(object): ...@@ -64,9 +65,16 @@ class BasePredictor(object):
def is_image(self, input): def is_image(self, input):
try: try:
img = Image.open(input) if isinstance(input, (np.ndarray, Image.Image)):
_ = img.size return True
return True elif isinstance(input, str):
if not os.path.isfile(input):
raise ValueError('input must be a file')
img = Image.open(input)
_ = img.size
return True
else:
return False
except: except:
return False return False
......
...@@ -26,17 +26,27 @@ from ppgan.models.generators.deoldify import build_model ...@@ -26,17 +26,27 @@ from ppgan.models.generators.deoldify import build_model
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
DEOLDIFY_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams' DEOLDIFY_STABLE_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
DEOLDIFY_ART_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_art.pdparams'
class DeOldifyPredictor(BasePredictor): class DeOldifyPredictor(BasePredictor):
def __init__(self, output='output', weight_path=None, render_factor=32): def __init__(self,
# self.input = input output='output',
weight_path=None,
artistic=False,
render_factor=32):
self.output = os.path.join(output, 'DeOldify') self.output = os.path.join(output, 'DeOldify')
if not os.path.exists(self.output):
os.makedirs(self.output)
self.render_factor = render_factor self.render_factor = render_factor
self.model = build_model() self.model = build_model(
model_type='artistic' if artistic else 'stable')
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL) if artistic:
weight_path = get_path_from_url(DEOLDIFY_ART_WEIGHT_URL)
else:
weight_path = get_path_from_url(DEOLDIFY_STABLE_WEIGHT_URL)
state_dict = paddle.load(weight_path) state_dict = paddle.load(weight_path)
self.model.load_dict(state_dict) self.model.load_dict(state_dict)
...@@ -134,7 +144,10 @@ class DeOldifyPredictor(BasePredictor): ...@@ -134,7 +144,10 @@ class DeOldifyPredictor(BasePredictor):
out_path = None out_path = None
if self.output: if self.output:
base_name = os.path.splitext(os.path.basename(input))[0] try:
base_name = os.path.splitext(os.path.basename(input))[0]
except:
base_name = 'result'
out_path = os.path.join(self.output, base_name + '.png') out_path = os.path.join(self.output, base_name + '.png')
pred_img.save(out_path) pred_img.save(out_path)
......
...@@ -107,7 +107,10 @@ class RealSRPredictor(BasePredictor): ...@@ -107,7 +107,10 @@ class RealSRPredictor(BasePredictor):
out_path = None out_path = None
if self.output: if self.output:
base_name = os.path.splitext(os.path.basename(input))[0] try:
base_name = os.path.splitext(os.path.basename(input))[0]
except:
base_name = 'result'
out_path = os.path.join(self.output, base_name + '.png') out_path = os.path.join(self.output, base_name + '.png')
pred_img.save(out_path) pred_img.save(out_path)
......
...@@ -16,7 +16,7 @@ import numpy as np ...@@ -16,7 +16,7 @@ import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.vision.models import resnet101 from paddle.vision.models import resnet34, resnet101
from .hook import hook_outputs, model_sizes, dummy_eval from .hook import hook_outputs, model_sizes, dummy_eval
from ...modules.nn import Spectralnorm from ...modules.nn import Spectralnorm
...@@ -57,6 +57,7 @@ class Deoldify(SequentialEx): ...@@ -57,6 +57,7 @@ class Deoldify(SequentialEx):
def __init__(self, def __init__(self,
encoder, encoder,
n_classes, n_classes,
model_type='stable',
blur=False, blur=False,
blur_final=True, blur_final=True,
self_attention=False, self_attention=False,
...@@ -95,18 +96,34 @@ class Deoldify(SequentialEx): ...@@ -95,18 +96,34 @@ class Deoldify(SequentialEx):
do_blur = blur and (not_final or blur_final) do_blur = blur and (not_final or blur_final)
sa = self_attention and (i == len(sfs_idxs) - 3) sa = self_attention and (i == len(sfs_idxs) - 3)
n_out = nf if not_final else nf // 2 if model_type == 'stable':
n_out = nf if not_final else nf // 2
unet_block = UnetBlockWide(up_in_c, unet_block = UnetBlockWide(up_in_c,
x_in_c, x_in_c,
n_out, n_out,
self.sfs[i], self.sfs[i],
final_div=not_final, final_div=not_final,
blur=blur, blur=blur,
self_attention=sa, self_attention=sa,
norm_type=norm_type, norm_type=norm_type,
extra_bn=extra_bn, extra_bn=extra_bn,
**kwargs) **kwargs)
elif model_type == 'artistic':
unet_block = UnetBlockDeep(up_in_c,
x_in_c,
self.sfs[i],
final_div=not_final,
blur=blur,
self_attention=sa,
norm_type=norm_type,
extra_bn=extra_bn,
nf_factor=nf_factor,
**kwargs)
else:
raise ValueError(
'Expected model_type in [stable, artistic], but got {}'.
format(model_type))
unet_block.eval() unet_block.eval()
layers.append(unet_block) layers.append(unet_block)
x = unet_block(x) x = unet_block(x)
...@@ -151,7 +168,7 @@ def custom_conv_layer(ni: int, ...@@ -151,7 +168,7 @@ def custom_conv_layer(ni: int,
bn = norm_type in ('Batch', 'Batchzero') or extra_bn == True bn = norm_type in ('Batch', 'Batchzero') or extra_bn == True
if bias is None: if bias is None:
bias = not bn bias = not bn
conv_func = nn.Conv2DTranspose if transpose else nn.Conv1d if is_1d else nn.Conv2D conv_func = nn.Conv2DTranspose if transpose else nn.Conv1D if is_1d else nn.Conv2D
conv = conv_func(ni, conv = conv_func(ni,
nf, nf,
...@@ -222,19 +239,18 @@ class UnetBlockWide(nn.Layer): ...@@ -222,19 +239,18 @@ class UnetBlockWide(nn.Layer):
class UnetBlockDeep(nn.Layer): class UnetBlockDeep(nn.Layer):
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`." "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
def __init__( def __init__(self,
self, up_in_c: int,
up_in_c: int, x_in_c: int,
x_in_c: int, hook,
# hook: Hook, final_div: bool = True,
final_div: bool = True, blur: bool = False,
blur: bool = False, leaky: float = None,
leaky: float = None, self_attention: bool = False,
self_attention: bool = False, nf_factor: float = 1.0,
nf_factor: float = 1.0, **kwargs):
**kwargs):
super().__init__() super().__init__()
self.hook = hook
self.shuf = CustomPixelShuffle_ICNR(up_in_c, self.shuf = CustomPixelShuffle_ICNR(up_in_c,
up_in_c // 2, up_in_c // 2,
blur=blur, blur=blur,
...@@ -312,7 +328,7 @@ def conv_layer(ni: int, ...@@ -312,7 +328,7 @@ def conv_layer(ni: int,
if padding is None: padding = (ks - 1) // 2 if not transpose else 0 if padding is None: padding = (ks - 1) // 2 if not transpose else 0
bn = norm_type in ('Batch', 'BatchZero') bn = norm_type in ('Batch', 'BatchZero')
if bias is None: bias = not bn if bias is None: bias = not bn
conv_func = nn.Conv2DTranspose if transpose else nn.Conv1d if is_1d else nn.Conv2D conv_func = nn.Conv2DTranspose if transpose else nn.Conv1D if is_1d else nn.Conv2D
conv = conv_func(ni, conv = conv_func(ni,
nf, nf,
...@@ -472,16 +488,27 @@ def _get_sfs_idxs(sizes): ...@@ -472,16 +488,27 @@ def _get_sfs_idxs(sizes):
return sfs_idxs return sfs_idxs
def build_model(): def build_model(model_type='stable'):
backbone = resnet101() if model_type == 'stable':
backbone = resnet101()
nf_factor = 2
elif model_type == 'artistic':
backbone = resnet34()
nf_factor = 1.5
else:
raise ValueError(
'Expected model_type in [stable, artistic], but got {}'.format(
model_type))
cut = -2 cut = -2
encoder = nn.Sequential(*list(backbone.children())[:cut]) encoder = nn.Sequential(*list(backbone.children())[:cut])
model = Deoldify(encoder, model = Deoldify(encoder,
3, 3,
model_type=model_type,
blur=True, blur=True,
y_range=(-3, 3), y_range=(-3, 3),
norm_type='Spectral', norm_type='Spectral',
self_attention=True, self_attention=True,
nf_factor=2) nf_factor=nf_factor)
return model return model
...@@ -80,11 +80,8 @@ class Pix2PixModel(BaseModel): ...@@ -80,11 +80,8 @@ class Pix2PixModel(BaseModel):
AtoB = self.cfg.dataset.train.direction == 'AtoB' AtoB = self.cfg.dataset.train.direction == 'AtoB'
# TODO: replace to_varialbe with to_tensor self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
self.real_A = paddle.fluid.dygraph.to_variable( self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
input['A' if AtoB else 'B'])
self.real_B = paddle.fluid.dygraph.to_variable(
input['B' if AtoB else 'A'])
self.image_paths = input['A_paths' if AtoB else 'B_paths'] self.image_paths = input['A_paths' if AtoB else 'B_paths']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册