未验证 提交 fd66dddd 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update realsr (#1991)

* update realsr

* add clean func

* update README
上级 dbd005b4
......@@ -63,7 +63,9 @@
import paddlehub as hub
model = hub.Module(name='realsr')
model.predict('/PATH/TO/IMAGE/OR/VIDEO')
model.predict('/PATH/TO/IMAGE')
# model.predict('/PATH/TO/VIDEO')
```
- ### 2、API
......@@ -174,3 +176,10 @@
适配paddlehub2.0版本
* 1.1.0
更新代码格式
```shell
$ hub install realsr == 1.1.0
```
\ No newline at end of file
......@@ -60,7 +60,9 @@
import paddlehub as hub
model = hub.Module(name='realsr')
model.predict('/PATH/TO/IMAGE/OR/VIDEO')
model.predict('/PATH/TO/IMAGE')
# model.predict('/PATH/TO/VIDEO')
```
- ### 2、API
......@@ -172,3 +174,10 @@
Support paddlehub2.0
* 1.1.0
Update code format
```shell
$ hub install realsr == 1.1.0
```
\ No newline at end of file
......@@ -21,22 +21,21 @@ import numpy as np
from PIL import Image
import paddle
import paddle.nn as nn
from paddlehub.module.module import moduleinfo, serving, Module
from paddlehub.module.module import moduleinfo, serving
from realsr.rrdb import RRDBNet
import realsr.utils as U
from .rrdb import RRDBNet
from . import utils as U
@moduleinfo(
name="realsr",
@moduleinfo(name="realsr",
type="CV/image_editing",
author="paddlepaddle",
author_email="",
summary="realsr is a super resolution model",
version="1.0.0")
class RealSRPredictor(Module):
def _initialize(self, output='output', weight_path=None, load_checkpoint: str = None):
#super(RealSRPredictor, self).__init__()
version="1.1.0")
class RealSRPredictor(nn.Layer):
def __init__(self, output='output', weight_path=None, load_checkpoint: str = None):
super(RealSRPredictor, self).__init__()
self.input = input
self.output = os.path.join(output, 'RealSR')
self.model = RRDBNet(3, 3, 64, 23)
......@@ -48,6 +47,8 @@ class RealSRPredictor(Module):
else:
checkpoint = os.path.join(self.directory, 'DF2K_JPEG.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams -O ' + checkpoint)
state_dict = paddle.load(checkpoint)
self.model.load_dict(state_dict)
print("load pretrained checkpoint success")
......@@ -66,17 +67,17 @@ class RealSRPredictor(Module):
if isinstance(img, str):
ori_img = Image.open(img).convert('RGB')
elif isinstance(img, np.ndarray):
# ori_img = Image.fromarray(img).convert('RGB')
ori_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
elif isinstance(img, Image.Image):
ori_img = img
img = self.norm(ori_img)
with paddle.no_grad():
x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x)
pred_img = self.denorm(out.numpy()[0])
# pred_img = Image.fromarray(pred_img)
pred_img = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR)
return pred_img
......@@ -108,7 +109,8 @@ class RealSRPredictor(Module):
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(output_path, '{}_realsr_out.mp4'.format(base_name))
vid_out_path = os.path.join(output_path,
'{}_realsr_out.mp4'.format(base_name))
U.frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
print("save result at {}".format(vid_out_path))
......
......@@ -30,8 +30,9 @@ class Registry(object):
self._obj_map = {}
def _do_register(self, name, obj):
assert (name not in self._obj_map), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name)
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(name, self._name)
self._obj_map[name] = obj
def register(self, obj=None, name=None):
......@@ -84,7 +85,6 @@ class ResidualDenseBlock_5C(nn.Layer):
class RRDB(nn.Layer):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
......@@ -104,10 +104,8 @@ def make_layer(block, n_layers):
layers.append(block())
return nn.Sequential(*layers)
GENERATORS = Registry("GENERATOR")
@GENERATORS.register()
class RRDBNet(nn.Layer):
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
......@@ -130,8 +128,10 @@ class RRDBNet(nn.Layer):
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(
self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(
self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/1sLIu1XKQrY/download?ixid=MnwxMjA3fDB8MXxhbGx8MTJ8fHx8fHwyfHwxNjYyMzQxNDUx&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="realsr")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('output')
def test_run_image1(self):
results = self.module.run_image(
img='tests/test.jpg'
)
self.assertIsInstance(results, np.ndarray)
def test_run_image2(self):
results = self.module.run_image(
img=cv2.imread('tests/test.jpg')
)
self.assertIsInstance(results, np.ndarray)
def test_run_image3(self):
self.assertRaises(
FileNotFoundError,
self.module.run_image,
img='no.jpg'
)
def test_predict1(self):
pred_img, out_path = self.module.predict(
input='tests/test.jpg'
)
self.assertIsInstance(pred_img, np.ndarray)
self.assertIsInstance(out_path, str)
def test_predict2(self):
self.assertRaises(
RuntimeError,
self.module.predict,
input='no.jpg'
)
if __name__ == "__main__":
unittest.main()
......@@ -38,7 +38,9 @@ def video2frames(video_path, outpath, **kargs):
def frames2video(frame_path, video_path, r):
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -pix_fmt ', ' yuv420p ', video_path]
cmd = ffmpeg + [
' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -pix_fmt ', ' yuv420p ', video_path
]
cmd = ''.join(cmd)
if os.system(cmd) != 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册