diff --git a/applications/tools/video-enhance.py b/applications/tools/video-enhance.py index 23f6fc728e2399f47f34789e7b99db4d40b4256a..05fdacbf58a85868b96a9e0542fd4877f079f868 100644 --- a/applications/tools/video-enhance.py +++ b/applications/tools/video-enhance.py @@ -67,6 +67,10 @@ parser.add_argument('--mindim', default=360, help='Length of minimum image edges') # DeOldify args +parser.add_argument('--artistic', + action='store_true', + default=False, + help='whether to use artistic DeOldify Model') parser.add_argument('--render_factor', type=int, default=32, @@ -107,6 +111,7 @@ if __name__ == "__main__": elif order == 'DeOldify': predictor = DeOldifyPredictor(args.output, weight_path=args.DeOldify_weight, + artistic=args.artistic, render_factor=args.render_factor) frames_path, temp_video_path = predictor.run(temp_video_path) elif order == 'RealSR': diff --git a/docs/en_US/tutorials/video_restore.md b/docs/en_US/tutorials/video_restore.md new file mode 120000 index 0000000000000000000000000000000000000000..6043d42afd6af0b019f09b636318cbddc9d2a913 --- /dev/null +++ b/docs/en_US/tutorials/video_restore.md @@ -0,0 +1 @@ +../../zh_CN/tutorials/video_restore.md \ No newline at end of file diff --git a/docs/zh_CN/apis/apps.md b/docs/zh_CN/apis/apps.md index be2f119cc9a65b14d6fc6e7866927ba3097756d3..c59ad969fa603e3259c788dd623f289cef497964 100644 --- a/docs/zh_CN/apis/apps.md +++ b/docs/zh_CN/apis/apps.md @@ -39,6 +39,7 @@ ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32 > > > - output (str): 设置输出图片的保存路径,默认是output。注意,保存路径为设置output/DeOldify。 > > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。 +> > - artistic (bool): 是否使用偏"艺术性"的模型。"艺术性"的模型有可能产生一些有趣的颜色,但是毛刺比较多。 > > - render_factor (int): 图片渲染上色时的缩放因子,图片会缩放到边长为16xrender_factor的正方形, 再上色,例如render_factor默认值为32,输入图片先缩放到(16x32=512) 512x512大小的图片。通常来说,render_factor越小,计算速度越快,颜色看起来也更鲜活。较旧和较低质量的图像通常会因降低渲染因子而受益。渲染因子越高,图像质量越好,但颜色可能会稍微褪色。 ### run diff --git a/docs/zh_CN/tutorials/video_restore.md b/docs/zh_CN/tutorials/video_restore.md index a1553db12a5d4e59afb79358e47f66caa51a69e7..06dcd1e93c431004cfc9495167651b16b3243821 100644 --- a/docs/zh_CN/tutorials/video_restore.md +++ b/docs/zh_CN/tutorials/video_restore.md @@ -63,6 +63,7 @@ ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32 - `output (str,可选的)`: 输出的文件夹路径,默认值:`output`. - `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`。 +- `artistic (bool)`: 是否使用偏"艺术性"的模型。"艺术性"的模型有可能产生一些有趣的颜色,但是毛刺比较多。 - `render_factor (int)`: 会将该参数乘以16后作为输入帧的resize的值,如果该值设置为32, 则输入帧会resize到(32 * 16, 32 * 16)的尺寸再输入到网络中。 diff --git a/ppgan/apps/base_predictor.py b/ppgan/apps/base_predictor.py index d92573b3e68ad2037c374fb7759a71172d541607..388f841a170aa318480f688cde8a507859c03e89 100644 --- a/ppgan/apps/base_predictor.py +++ b/ppgan/apps/base_predictor.py @@ -14,6 +14,7 @@ import os import cv2 +import numpy as np from PIL import Image import paddle @@ -64,9 +65,16 @@ class BasePredictor(object): def is_image(self, input): try: - img = Image.open(input) - _ = img.size - return True + if isinstance(input, (np.ndarray, Image.Image)): + return True + elif isinstance(input, str): + if not os.path.isfile(input): + raise ValueError('input must be a file') + img = Image.open(input) + _ = img.size + return True + else: + return False except: return False diff --git a/ppgan/apps/deoldify_predictor.py b/ppgan/apps/deoldify_predictor.py index c1bab4068630cc6d61a927ef431911c4f447c3d2..6379d4d3e6de099f0631b18a33e442f269ebba51 100644 --- a/ppgan/apps/deoldify_predictor.py +++ b/ppgan/apps/deoldify_predictor.py @@ -26,17 +26,27 @@ from ppgan.models.generators.deoldify import build_model from .base_predictor import BasePredictor -DEOLDIFY_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams' +DEOLDIFY_STABLE_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams' +DEOLDIFY_ART_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_art.pdparams' class DeOldifyPredictor(BasePredictor): - def __init__(self, output='output', weight_path=None, render_factor=32): - # self.input = input + def __init__(self, + output='output', + weight_path=None, + artistic=False, + render_factor=32): self.output = os.path.join(output, 'DeOldify') + if not os.path.exists(self.output): + os.makedirs(self.output) self.render_factor = render_factor - self.model = build_model() + self.model = build_model( + model_type='artistic' if artistic else 'stable') if weight_path is None: - weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL) + if artistic: + weight_path = get_path_from_url(DEOLDIFY_ART_WEIGHT_URL) + else: + weight_path = get_path_from_url(DEOLDIFY_STABLE_WEIGHT_URL) state_dict = paddle.load(weight_path) self.model.load_dict(state_dict) @@ -134,7 +144,10 @@ class DeOldifyPredictor(BasePredictor): out_path = None if self.output: - base_name = os.path.splitext(os.path.basename(input))[0] + try: + base_name = os.path.splitext(os.path.basename(input))[0] + except: + base_name = 'result' out_path = os.path.join(self.output, base_name + '.png') pred_img.save(out_path) diff --git a/ppgan/apps/realsr_predictor.py b/ppgan/apps/realsr_predictor.py index 58f3b61537cff92c994fe3bfb13affef0e911915..3f471d44a7cac06e1949c52eff96a0f38ef31ffb 100644 --- a/ppgan/apps/realsr_predictor.py +++ b/ppgan/apps/realsr_predictor.py @@ -107,7 +107,10 @@ class RealSRPredictor(BasePredictor): out_path = None if self.output: - base_name = os.path.splitext(os.path.basename(input))[0] + try: + base_name = os.path.splitext(os.path.basename(input))[0] + except: + base_name = 'result' out_path = os.path.join(self.output, base_name + '.png') pred_img.save(out_path) diff --git a/ppgan/models/generators/deoldify.py b/ppgan/models/generators/deoldify.py index 909ce38612c547a89829795a00d28091ea155535..b04f39df7e3bbfc22b8baa85e0ff2b5d8ce4b4df 100644 --- a/ppgan/models/generators/deoldify.py +++ b/ppgan/models/generators/deoldify.py @@ -16,7 +16,7 @@ import numpy as np import paddle import paddle.nn as nn import paddle.nn.functional as F -from paddle.vision.models import resnet101 +from paddle.vision.models import resnet34, resnet101 from .hook import hook_outputs, model_sizes, dummy_eval from ...modules.nn import Spectralnorm @@ -57,6 +57,7 @@ class Deoldify(SequentialEx): def __init__(self, encoder, n_classes, + model_type='stable', blur=False, blur_final=True, self_attention=False, @@ -95,18 +96,34 @@ class Deoldify(SequentialEx): do_blur = blur and (not_final or blur_final) sa = self_attention and (i == len(sfs_idxs) - 3) - n_out = nf if not_final else nf // 2 - - unet_block = UnetBlockWide(up_in_c, - x_in_c, - n_out, - self.sfs[i], - final_div=not_final, - blur=blur, - self_attention=sa, - norm_type=norm_type, - extra_bn=extra_bn, - **kwargs) + if model_type == 'stable': + n_out = nf if not_final else nf // 2 + unet_block = UnetBlockWide(up_in_c, + x_in_c, + n_out, + self.sfs[i], + final_div=not_final, + blur=blur, + self_attention=sa, + norm_type=norm_type, + extra_bn=extra_bn, + **kwargs) + elif model_type == 'artistic': + unet_block = UnetBlockDeep(up_in_c, + x_in_c, + self.sfs[i], + final_div=not_final, + blur=blur, + self_attention=sa, + norm_type=norm_type, + extra_bn=extra_bn, + nf_factor=nf_factor, + **kwargs) + else: + raise ValueError( + 'Expected model_type in [stable, artistic], but got {}'. + format(model_type)) + unet_block.eval() layers.append(unet_block) x = unet_block(x) @@ -151,7 +168,7 @@ def custom_conv_layer(ni: int, bn = norm_type in ('Batch', 'Batchzero') or extra_bn == True if bias is None: bias = not bn - conv_func = nn.Conv2DTranspose if transpose else nn.Conv1d if is_1d else nn.Conv2D + conv_func = nn.Conv2DTranspose if transpose else nn.Conv1D if is_1d else nn.Conv2D conv = conv_func(ni, nf, @@ -222,19 +239,18 @@ class UnetBlockWide(nn.Layer): class UnetBlockDeep(nn.Layer): "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`." - def __init__( - self, - up_in_c: int, - x_in_c: int, - # hook: Hook, - final_div: bool = True, - blur: bool = False, - leaky: float = None, - self_attention: bool = False, - nf_factor: float = 1.0, - **kwargs): + def __init__(self, + up_in_c: int, + x_in_c: int, + hook, + final_div: bool = True, + blur: bool = False, + leaky: float = None, + self_attention: bool = False, + nf_factor: float = 1.0, + **kwargs): super().__init__() - + self.hook = hook self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_in_c // 2, blur=blur, @@ -312,7 +328,7 @@ def conv_layer(ni: int, if padding is None: padding = (ks - 1) // 2 if not transpose else 0 bn = norm_type in ('Batch', 'BatchZero') if bias is None: bias = not bn - conv_func = nn.Conv2DTranspose if transpose else nn.Conv1d if is_1d else nn.Conv2D + conv_func = nn.Conv2DTranspose if transpose else nn.Conv1D if is_1d else nn.Conv2D conv = conv_func(ni, nf, @@ -472,16 +488,27 @@ def _get_sfs_idxs(sizes): return sfs_idxs -def build_model(): - backbone = resnet101() +def build_model(model_type='stable'): + if model_type == 'stable': + backbone = resnet101() + nf_factor = 2 + elif model_type == 'artistic': + backbone = resnet34() + nf_factor = 1.5 + else: + raise ValueError( + 'Expected model_type in [stable, artistic], but got {}'.format( + model_type)) + cut = -2 encoder = nn.Sequential(*list(backbone.children())[:cut]) model = Deoldify(encoder, 3, + model_type=model_type, blur=True, y_range=(-3, 3), norm_type='Spectral', self_attention=True, - nf_factor=2) + nf_factor=nf_factor) return model diff --git a/ppgan/models/pix2pix_model.py b/ppgan/models/pix2pix_model.py index 80b31578871417008355d0c50d5e6d9820eff30c..99253fbec7ac65798949c87bfe13553118f2c4dc 100644 --- a/ppgan/models/pix2pix_model.py +++ b/ppgan/models/pix2pix_model.py @@ -80,11 +80,8 @@ class Pix2PixModel(BaseModel): AtoB = self.cfg.dataset.train.direction == 'AtoB' - # TODO: replace to_varialbe with to_tensor - self.real_A = paddle.fluid.dygraph.to_variable( - input['A' if AtoB else 'B']) - self.real_B = paddle.fluid.dygraph.to_variable( - input['B' if AtoB else 'A']) + self.real_A = paddle.to_tensor(input['A' if AtoB else 'B']) + self.real_B = paddle.to_tensor(input['B' if AtoB else 'A']) self.image_paths = input['A_paths' if AtoB else 'B_paths']