未验证 提交 ca09b195 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update deoldify (#1992)

* update deoldify

* add clean func

* update README

* update format
上级 ae6fcc6a
# deoldify # deoldify
|模型名称|deoldify| |模型名称|deoldify|
| :--- | :---: | | :--- | :---: |
|类别|图像-图像编辑| |类别|图像-图像编辑|
|网络|NoGAN| |网络|NoGAN|
|数据集|ILSVRC 2012| |数据集|ILSVRC 2012|
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
## 一、模型基本信息 ## 一、模型基本信息
- ### 应用效果展示 - ### 应用效果展示
- 样例结果示例(左为原图,右为效果图): - 样例结果示例(左为原图,右为效果图):
<p align="center"> <p align="center">
<img src="https://user-images.githubusercontent.com/35907364/130886749-668dfa38-42ed-4a09-8d4a-b18af0475375.jpg" width = "450" height = "300" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/130886685-76221736-839a-46a2-8415-e5e0dd3b345e.png" width = "450" height = "300" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/130886749-668dfa38-42ed-4a09-8d4a-b18af0475375.jpg" width = "450" height = "300" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/130886685-76221736-839a-46a2-8415-e5e0dd3b345e.png" width = "450" height = "300" hspace='10'/>
...@@ -45,7 +45,7 @@ ...@@ -45,7 +45,7 @@
- ```shell - ```shell
$ hub install deoldify $ hub install deoldify
``` ```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
...@@ -59,7 +59,9 @@ ...@@ -59,7 +59,9 @@
import paddlehub as hub import paddlehub as hub
model = hub.Module(name='deoldify') model = hub.Module(name='deoldify')
model.predict('/PATH/TO/IMAGE/OR/VIDEO') model.predict('/PATH/TO/IMAGE')
# model.predict('/PATH/TO/VIDEO')
``` ```
- ### 2、API - ### 2、API
...@@ -170,3 +172,11 @@ ...@@ -170,3 +172,11 @@
* 1.0.1 * 1.0.1
适配paddlehub2.0版本 适配paddlehub2.0版本
* 1.1.0
移除 Fluid API
```shell
$ hub install deoldify == 1.1.0
```
# deoldify # deoldify
| Module Name |deoldify| | Module Name |deoldify|
| :--- | :---: | | :--- | :---: |
|Category|Image editing| |Category|Image editing|
|Network |NoGAN| |Network |NoGAN|
|Dataset|ILSVRC 2012| |Dataset|ILSVRC 2012|
...@@ -11,10 +11,10 @@ ...@@ -11,10 +11,10 @@
|Latest update date |2021-04-13| |Latest update date |2021-04-13|
## I. Basic Information ## I. Basic Information
- ### Application Effect Display - ### Application Effect Display
- Sample results: - Sample results:
<p align="center"> <p align="center">
<img src="https://user-images.githubusercontent.com/35907364/130886749-668dfa38-42ed-4a09-8d4a-b18af0475375.jpg" width = "450" height = "300" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/130886685-76221736-839a-46a2-8415-e5e0dd3b345e.png" width = "450" height = "300" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/130886749-668dfa38-42ed-4a09-8d4a-b18af0475375.jpg" width = "450" height = "300" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/130886685-76221736-839a-46a2-8415-e5e0dd3b345e.png" width = "450" height = "300" hspace='10'/>
...@@ -45,7 +45,7 @@ ...@@ -45,7 +45,7 @@
- ```shell - ```shell
$ hub install deoldify $ hub install deoldify
``` ```
- In case of any problems during installation, please refer to:[Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md) - In case of any problems during installation, please refer to:[Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md)
| [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md) | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)
...@@ -58,7 +58,9 @@ ...@@ -58,7 +58,9 @@
import paddlehub as hub import paddlehub as hub
model = hub.Module(name='deoldify') model = hub.Module(name='deoldify')
model.predict('/PATH/TO/IMAGE/OR/VIDEO') model.predict('/PATH/TO/IMAGE')
# model.predict('/PATH/TO/VIDEO')
``` ```
- ### 2、API - ### 2、API
...@@ -169,3 +171,11 @@ ...@@ -169,3 +171,11 @@
- 1.0.1 - 1.0.1
Adapt to paddlehub2.0 Adapt to paddlehub2.0
* 1.1.0
Remove Fluid API
```shell
$ hub install deoldify == 1.1.0
```
import paddle
import numpy as np import numpy as np
import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.vision.models import resnet101 from paddle.vision.models import resnet101
import deoldify.utils as U from . import utils as U
class SequentialEx(nn.Layer): class SequentialEx(nn.Layer):
...@@ -39,6 +39,7 @@ class SequentialEx(nn.Layer): ...@@ -39,6 +39,7 @@ class SequentialEx(nn.Layer):
class Deoldify(SequentialEx): class Deoldify(SequentialEx):
def __init__(self, def __init__(self,
encoder, encoder,
n_classes, n_classes,
...@@ -76,17 +77,16 @@ class Deoldify(SequentialEx): ...@@ -76,17 +77,16 @@ class Deoldify(SequentialEx):
n_out = nf if not_final else nf // 2 n_out = nf if not_final else nf // 2
unet_block = UnetBlockWide( unet_block = UnetBlockWide(up_in_c,
up_in_c, x_in_c,
x_in_c, n_out,
n_out, self.sfs[i],
self.sfs[i], final_div=not_final,
final_div=not_final, blur=blur,
blur=blur, self_attention=sa,
self_attention=sa, norm_type=norm_type,
norm_type=norm_type, extra_bn=extra_bn,
extra_bn=extra_bn, **kwargs)
**kwargs)
unet_block.eval() unet_block.eval()
layers.append(unet_block) layers.append(unet_block)
x = unet_block(x) x = unet_block(x)
...@@ -288,7 +288,7 @@ class CustomPixelShuffle_ICNR(nn.Layer): ...@@ -288,7 +288,7 @@ class CustomPixelShuffle_ICNR(nn.Layer):
self.shuf = PixelShuffle(scale) self.shuf = PixelShuffle(scale)
self.pad = ReplicationPad2d([1, 0, 1, 0]) self.pad = ReplicationPad2d([1, 0, 1, 0])
self.blur = paddle.nn.AvgPool2D(2, stride=1) self.blur = nn.AvgPool2D(2, stride=1)
self.relu = nn.LeakyReLU(leaky) if leaky is not None else nn.ReLU() # relu(True, leaky=leaky) self.relu = nn.LeakyReLU(leaky) if leaky is not None else nn.ReLU() # relu(True, leaky=leaky)
def forward(self, x): def forward(self, x):
...@@ -315,9 +315,8 @@ def res_block(nf, dense: bool = False, norm_type='Batch', bottle: bool = False, ...@@ -315,9 +315,8 @@ def res_block(nf, dense: bool = False, norm_type='Batch', bottle: bool = False,
norm2 = norm_type norm2 = norm_type
if not dense and (norm_type == 'Batch'): norm2 = 'BatchZero' if not dense and (norm_type == 'Batch'): norm2 = 'BatchZero'
nf_inner = nf // 2 if bottle else nf nf_inner = nf // 2 if bottle else nf
return SequentialEx( return SequentialEx(conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs),
conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs), conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs), MergeLayer(dense))
conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs), MergeLayer(dense))
class SigmoidRange(nn.Layer): class SigmoidRange(nn.Layer):
...@@ -337,6 +336,7 @@ def sigmoid_range(x, low, high): ...@@ -337,6 +336,7 @@ def sigmoid_range(x, low, high):
class PixelShuffle(nn.Layer): class PixelShuffle(nn.Layer):
def __init__(self, upscale_factor): def __init__(self, upscale_factor):
super(PixelShuffle, self).__init__() super(PixelShuffle, self).__init__()
self.upscale_factor = upscale_factor self.upscale_factor = upscale_factor
...@@ -346,6 +346,7 @@ class PixelShuffle(nn.Layer): ...@@ -346,6 +346,7 @@ class PixelShuffle(nn.Layer):
class ReplicationPad2d(nn.Layer): class ReplicationPad2d(nn.Layer):
def __init__(self, size): def __init__(self, size):
super(ReplicationPad2d, self).__init__() super(ReplicationPad2d, self).__init__()
self.size = size self.size = size
......
...@@ -12,32 +12,32 @@ ...@@ -12,32 +12,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import glob import glob
import os
import cv2 import cv2
import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
import deoldify.utils as U from . import utils as U
from paddlehub.module.module import moduleinfo, serving, Module from .base_module import build_model
from deoldify.base_module import build_model from paddlehub.module.module import moduleinfo
from paddlehub.module.module import serving
@moduleinfo(
name="deoldify", @moduleinfo(name="deoldify",
type="CV/image_editing", type="CV/image_editing",
author="paddlepaddle", author="paddlepaddle",
author_email="", author_email="",
summary="Deoldify is a colorizaton model", summary="Deoldify is a colorizaton model",
version="1.0.0") version="1.1.0")
class DeOldifyPredictor(Module): class DeOldifyPredictor(nn.Layer):
def _initialize(self, render_factor: int = 32, output_path: int = 'result', load_checkpoint: str = None):
#super(DeOldifyPredictor, self).__init__() def __init__(self, render_factor: int = 32, output_path: int = 'output', load_checkpoint: str = None):
super(DeOldifyPredictor, self).__init__()
self.model = build_model() self.model = build_model()
self.render_factor = render_factor self.render_factor = render_factor
self.output = os.path.join(output_path, 'DeOldify') self.output = os.path.join(output_path, 'DeOldify')
...@@ -50,6 +50,8 @@ class DeOldifyPredictor(Module): ...@@ -50,6 +50,8 @@ class DeOldifyPredictor(Module):
else: else:
checkpoint = os.path.join(self.directory, 'DeOldify_stable.pdparams') checkpoint = os.path.join(self.directory, 'DeOldify_stable.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams -O ' + checkpoint)
state_dict = paddle.load(checkpoint) state_dict = paddle.load(checkpoint)
self.model.load_dict(state_dict) self.model.load_dict(state_dict)
print("load pretrained checkpoint success") print("load pretrained checkpoint success")
...@@ -140,8 +142,6 @@ class DeOldifyPredictor(Module): ...@@ -140,8 +142,6 @@ class DeOldifyPredictor(Module):
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
def predict(self, input): def predict(self, input):
if not os.path.exists(self.output):
os.makedirs(self.output)
if not U.is_image(input): if not U.is_image(input):
return self.run_video(input) return self.run_video(input)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from paddle.utils.download import get_weights_path_from_url
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
model_urls = {
'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams', '0ba53eea9bc970962d0ef96f7b94057e'),
'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams', '46bc9f7c3dd2e55b7866285bee91eff3'),
'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams', '5ce890a9ad386df17cf7fe2313dca0a1'),
'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams', 'fb07a451df331e4b0bb861ed97c3a9b9'),
'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams', 'f9c700f26d3644bb76ad2226ed5f5713'),
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters, filter_size, stride=1, groups=1, act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
self._batch_norm = BatchNorm(num_filters, act=act)
def forward(self, inputs):
x = self._conv(inputs)
x = self._batch_norm(x)
return x
class BasicBlock(fluid.dygraph.Layer):
"""residual block of resnet18 and resnet34
"""
expansion = 1
def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BasicBlock, self).__init__()
self.conv0 = ConvBNLayer(num_channels=num_channels, num_filters=num_filters, filter_size=3, act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters, num_filters=num_filters, filter_size=3, stride=stride, act='relu')
if not shortcut:
self.short = ConvBNLayer(num_channels=num_channels, num_filters=num_filters, filter_size=1, stride=stride)
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = short + conv1
return fluid.layers.relu(y)
class BottleneckBlock(fluid.dygraph.Layer):
"""residual block of resnet50, resnet101 amd resnet152
"""
expansion = 4
def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(num_channels=num_channels, num_filters=num_filters, filter_size=1, act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters, num_filters=num_filters, filter_size=3, stride=stride, act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters, num_filters=num_filters * self.expansion, filter_size=1, act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels, num_filters=num_filters * self.expansion, filter_size=1, stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * self.expansion
def forward(self, inputs):
x = self.conv0(inputs)
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
x = fluid.layers.elementwise_add(x=short, y=conv2)
return fluid.layers.relu(x)
class ResNet(fluid.dygraph.Layer):
"""ResNet model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Block (BasicBlock|BottleneckBlock): block module of model.
depth (int): layers of resnet, default: 50.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from paddle.vision.models import ResNet
from paddle.vision.models.resnet import BottleneckBlock, BasicBlock
resnet50 = ResNet(BottleneckBlock, 50)
resnet18 = ResNet(BasicBlock, 18)
"""
def __init__(self, Block, depth=50, num_classes=1000, with_pool=True, classifier_activation='softmax'):
super(ResNet, self).__init__()
self.num_classes = num_classes
self.with_pool = with_pool
layer_config = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}
assert depth in layer_config.keys(), \
"supported depth are {} but input layer is {}".format(
layer_config.keys(), depth)
layers = layer_config[depth]
in_channels = 64
out_channels = [64, 128, 256, 512]
self.conv = ConvBNLayer(num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu')
self.pool = Pool2D(pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
self.layers = []
for idx, num_blocks in enumerate(layers):
blocks = []
shortcut = False
for b in range(num_blocks):
if b == 1:
in_channels = out_channels[idx] * Block.expansion
block = Block(
num_channels=in_channels,
num_filters=out_channels[idx],
stride=2 if b == 0 and idx != 0 else 1,
shortcut=shortcut)
blocks.append(block)
shortcut = True
layer = self.add_sublayer("layer_{}".format(idx), Sequential(*blocks))
self.layers.append(layer)
if with_pool:
self.global_pool = Pool2D(pool_size=7, pool_type='avg', global_pooling=True)
if num_classes > 0:
stdv = 1.0 / math.sqrt(out_channels[-1] * Block.expansion * 1.0)
self.fc_input_dim = out_channels[-1] * Block.expansion * 1 * 1
self.fc = Linear(
self.fc_input_dim,
num_classes,
act=classifier_activation,
param_attr=fluid.param_attr.ParamAttr(initializer=fluid.initializer.Uniform(-stdv, stdv)))
def forward(self, inputs):
x = self.conv(inputs)
x = self.pool(x)
for layer in self.layers:
x = layer(x)
if self.with_pool:
x = self.global_pool(x)
if self.num_classes > -1:
x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
x = self.fc(x)
return x
def _resnet(arch, Block, depth, pretrained, **kwargs):
model = ResNet(Block, depth, **kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0], model_urls[arch][1])
assert weight_path.endswith('.pdparams'), "suffix of weight must be .pdparams"
param, _ = fluid.load_dygraph(weight_path)
model.set_dict(param)
return model
def resnet18(pretrained=False, **kwargs):
"""ResNet 18-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet18
# build model
model = resnet18()
# build model and load imagenet pretrained weight
# model = resnet18(pretrained=True)
"""
return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)
def resnet34(pretrained=False, **kwargs):
"""ResNet 34-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet34
# build model
model = resnet34()
# build model and load imagenet pretrained weight
# model = resnet34(pretrained=True)
"""
return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)
def resnet50(pretrained=False, **kwargs):
"""ResNet 50-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet50
# build model
model = resnet50()
# build model and load imagenet pretrained weight
# model = resnet50(pretrained=True)
"""
return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)
def resnet101(pretrained=False, **kwargs):
"""ResNet 101-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet101
# build model
model = resnet101()
# build model and load imagenet pretrained weight
# model = resnet101(pretrained=True)
"""
return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)
def resnet152(pretrained=False, **kwargs):
"""ResNet 152-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet152
# build model
model = resnet152()
# build model and load imagenet pretrained weight
# model = resnet152(pretrained=True)
"""
return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)
import os
import shutil
import unittest
import cv2
import numpy as np
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/1sLIu1XKQrY/download?ixid=MnwxMjA3fDB8MXxhbGx8MTJ8fHx8fHwyfHwxNjYyMzQxNDUx&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="deoldify")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('output')
def test_run_image1(self):
results = self.module.run_image(img='tests/test.jpg')
self.assertIsInstance(results, np.ndarray)
def test_run_image2(self):
results = self.module.run_image(img=cv2.imread('tests/test.jpg'))
self.assertIsInstance(results, np.ndarray)
def test_run_image3(self):
self.assertRaises(FileNotFoundError, self.module.run_image, img='no.jpg')
def test_predict1(self):
pred_img, out_path = self.module.predict(input='tests/test.jpg')
self.assertIsInstance(pred_img, np.ndarray)
self.assertIsInstance(out_path, str)
def test_predict2(self):
self.assertRaises(RuntimeError, self.module.predict, input='no.jpg')
if __name__ == "__main__":
unittest.main()
import base64
import os import os
import sys import sys
import base64
import cv2 import cv2
import numpy as np import numpy as np
...@@ -91,7 +91,7 @@ class Hooks(): ...@@ -91,7 +91,7 @@ class Hooks():
def _hook_inner(m, i, o): def _hook_inner(m, i, o):
return o if isinstance(o, paddle.fluid.framework.Variable) else o if is_listy(o) else list(o) return o if isinstance(o, paddle.Tensor) else o if is_listy(o) else list(o)
def hook_output(module, detach=True, grad=False): def hook_output(module, detach=True, grad=False):
...@@ -124,6 +124,7 @@ def dummy_batch(size=(64, 64), ch_in=3): ...@@ -124,6 +124,7 @@ def dummy_batch(size=(64, 64), ch_in=3):
class _SpectralNorm(nn.SpectralNorm): class _SpectralNorm(nn.SpectralNorm):
def __init__(self, weight_shape, dim=0, power_iters=1, eps=1e-12, dtype='float32'): def __init__(self, weight_shape, dim=0, power_iters=1, eps=1e-12, dtype='float32'):
super(_SpectralNorm, self).__init__(weight_shape, dim, power_iters, eps, dtype) super(_SpectralNorm, self).__init__(weight_shape, dim, power_iters, eps, dtype)
...@@ -131,22 +132,22 @@ class _SpectralNorm(nn.SpectralNorm): ...@@ -131,22 +132,22 @@ class _SpectralNorm(nn.SpectralNorm):
inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v} inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
out = self._helper.create_variable_for_type_inference(self._dtype) out = self._helper.create_variable_for_type_inference(self._dtype)
_power_iters = self._power_iters if self.training else 0 _power_iters = self._power_iters if self.training else 0
self._helper.append_op( self._helper.append_op(type="spectral_norm",
type="spectral_norm", inputs=inputs,
inputs=inputs, outputs={
outputs={ "Out": out,
"Out": out, },
}, attrs={
attrs={ "dim": self._dim,
"dim": self._dim, "power_iters": _power_iters,
"power_iters": _power_iters, "eps": self._eps,
"eps": self._eps, })
})
return out return out
class Spectralnorm(paddle.nn.Layer): class Spectralnorm(paddle.nn.Layer):
def __init__(self, layer, dim=0, power_iters=1, eps=1e-12, dtype='float32'): def __init__(self, layer, dim=0, power_iters=1, eps=1e-12, dtype='float32'):
super(Spectralnorm, self).__init__() super(Spectralnorm, self).__init__()
self.spectral_norm = _SpectralNorm(layer.weight.shape, dim, power_iters, eps, dtype) self.spectral_norm = _SpectralNorm(layer.weight.shape, dim, power_iters, eps, dtype)
...@@ -167,6 +168,7 @@ class Spectralnorm(paddle.nn.Layer): ...@@ -167,6 +168,7 @@ class Spectralnorm(paddle.nn.Layer):
def video2frames(video_path, outpath, **kargs): def video2frames(video_path, outpath, **kargs):
def _dict2str(kargs): def _dict2str(kargs):
cmd_str = '' cmd_str = ''
for k, v in kargs.items(): for k, v in kargs.items():
...@@ -196,12 +198,8 @@ def video2frames(video_path, outpath, **kargs): ...@@ -196,12 +198,8 @@ def video2frames(video_path, outpath, **kargs):
def frames2video(frame_path, video_path, r): def frames2video(frame_path, video_path, r):
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [ cmd = ffmpeg + [' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -pix_fmt ', ' yuv420p ', video_path]
' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -vcodec ', ' libx264 ', ' -pix_fmt ', ' yuv420p ',
' -crf ', ' 16 ', video_path
]
cmd = ''.join(cmd) cmd = ''.join(cmd)
if os.system(cmd) != 0: if os.system(cmd) != 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册