未验证 提交 fd66dddd 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update realsr (#1991)

* update realsr

* add clean func

* update README
上级 dbd005b4
...@@ -63,7 +63,9 @@ ...@@ -63,7 +63,9 @@
import paddlehub as hub import paddlehub as hub
model = hub.Module(name='realsr') model = hub.Module(name='realsr')
model.predict('/PATH/TO/IMAGE/OR/VIDEO') model.predict('/PATH/TO/IMAGE')
# model.predict('/PATH/TO/VIDEO')
``` ```
- ### 2、API - ### 2、API
...@@ -174,3 +176,10 @@ ...@@ -174,3 +176,10 @@
适配paddlehub2.0版本 适配paddlehub2.0版本
* 1.1.0
更新代码格式
```shell
$ hub install realsr == 1.1.0
```
\ No newline at end of file
...@@ -60,7 +60,9 @@ ...@@ -60,7 +60,9 @@
import paddlehub as hub import paddlehub as hub
model = hub.Module(name='realsr') model = hub.Module(name='realsr')
model.predict('/PATH/TO/IMAGE/OR/VIDEO') model.predict('/PATH/TO/IMAGE')
# model.predict('/PATH/TO/VIDEO')
``` ```
- ### 2、API - ### 2、API
...@@ -172,3 +174,10 @@ ...@@ -172,3 +174,10 @@
Support paddlehub2.0 Support paddlehub2.0
* 1.1.0
Update code format
```shell
$ hub install realsr == 1.1.0
```
\ No newline at end of file
...@@ -21,22 +21,21 @@ import numpy as np ...@@ -21,22 +21,21 @@ import numpy as np
from PIL import Image from PIL import Image
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from paddlehub.module.module import moduleinfo, serving, Module from paddlehub.module.module import moduleinfo, serving
from realsr.rrdb import RRDBNet from .rrdb import RRDBNet
import realsr.utils as U from . import utils as U
@moduleinfo( @moduleinfo(name="realsr",
name="realsr", type="CV/image_editing",
type="CV/image_editing", author="paddlepaddle",
author="paddlepaddle", author_email="",
author_email="", summary="realsr is a super resolution model",
summary="realsr is a super resolution model", version="1.1.0")
version="1.0.0") class RealSRPredictor(nn.Layer):
class RealSRPredictor(Module): def __init__(self, output='output', weight_path=None, load_checkpoint: str = None):
def _initialize(self, output='output', weight_path=None, load_checkpoint: str = None): super(RealSRPredictor, self).__init__()
#super(RealSRPredictor, self).__init__()
self.input = input self.input = input
self.output = os.path.join(output, 'RealSR') self.output = os.path.join(output, 'RealSR')
self.model = RRDBNet(3, 3, 64, 23) self.model = RRDBNet(3, 3, 64, 23)
...@@ -48,6 +47,8 @@ class RealSRPredictor(Module): ...@@ -48,6 +47,8 @@ class RealSRPredictor(Module):
else: else:
checkpoint = os.path.join(self.directory, 'DF2K_JPEG.pdparams') checkpoint = os.path.join(self.directory, 'DF2K_JPEG.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams -O ' + checkpoint)
state_dict = paddle.load(checkpoint) state_dict = paddle.load(checkpoint)
self.model.load_dict(state_dict) self.model.load_dict(state_dict)
print("load pretrained checkpoint success") print("load pretrained checkpoint success")
...@@ -66,17 +67,17 @@ class RealSRPredictor(Module): ...@@ -66,17 +67,17 @@ class RealSRPredictor(Module):
if isinstance(img, str): if isinstance(img, str):
ori_img = Image.open(img).convert('RGB') ori_img = Image.open(img).convert('RGB')
elif isinstance(img, np.ndarray): elif isinstance(img, np.ndarray):
# ori_img = Image.fromarray(img).convert('RGB')
ori_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) ori_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
elif isinstance(img, Image.Image): elif isinstance(img, Image.Image):
ori_img = img ori_img = img
img = self.norm(ori_img) img = self.norm(ori_img)
x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x) with paddle.no_grad():
x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x)
pred_img = self.denorm(out.numpy()[0]) pred_img = self.denorm(out.numpy()[0])
# pred_img = Image.fromarray(pred_img)
pred_img = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR) pred_img = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR)
return pred_img return pred_img
...@@ -108,7 +109,8 @@ class RealSRPredictor(Module): ...@@ -108,7 +109,8 @@ class RealSRPredictor(Module):
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(output_path, '{}_realsr_out.mp4'.format(base_name)) vid_out_path = os.path.join(output_path,
'{}_realsr_out.mp4'.format(base_name))
U.frames2video(frame_pattern_combined, vid_out_path, str(int(fps))) U.frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
print("save result at {}".format(vid_out_path)) print("save result at {}".format(vid_out_path))
......
...@@ -30,8 +30,9 @@ class Registry(object): ...@@ -30,8 +30,9 @@ class Registry(object):
self._obj_map = {} self._obj_map = {}
def _do_register(self, name, obj): def _do_register(self, name, obj):
assert (name not in self._obj_map), "An object named '{}' was already registered in '{}' registry!".format( assert (
name, self._name) name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(name, self._name)
self._obj_map[name] = obj self._obj_map[name] = obj
def register(self, obj=None, name=None): def register(self, obj=None, name=None):
...@@ -84,7 +85,6 @@ class ResidualDenseBlock_5C(nn.Layer): ...@@ -84,7 +85,6 @@ class ResidualDenseBlock_5C(nn.Layer):
class RRDB(nn.Layer): class RRDB(nn.Layer):
'''Residual in Residual Dense Block''' '''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32): def __init__(self, nf, gc=32):
super(RRDB, self).__init__() super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc) self.RDB1 = ResidualDenseBlock_5C(nf, gc)
...@@ -104,10 +104,8 @@ def make_layer(block, n_layers): ...@@ -104,10 +104,8 @@ def make_layer(block, n_layers):
layers.append(block()) layers.append(block())
return nn.Sequential(*layers) return nn.Sequential(*layers)
GENERATORS = Registry("GENERATOR") GENERATORS = Registry("GENERATOR")
@GENERATORS.register() @GENERATORS.register()
class RRDBNet(nn.Layer): class RRDBNet(nn.Layer):
def __init__(self, in_nc, out_nc, nf, nb, gc=32): def __init__(self, in_nc, out_nc, nf, nb, gc=32):
...@@ -130,8 +128,10 @@ class RRDBNet(nn.Layer): ...@@ -130,8 +128,10 @@ class RRDBNet(nn.Layer):
trunk = self.trunk_conv(self.RRDB_trunk(fea)) trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) fea = self.lrelu(
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(
self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea))) out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out return out
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/1sLIu1XKQrY/download?ixid=MnwxMjA3fDB8MXxhbGx8MTJ8fHx8fHwyfHwxNjYyMzQxNDUx&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="realsr")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('output')
def test_run_image1(self):
results = self.module.run_image(
img='tests/test.jpg'
)
self.assertIsInstance(results, np.ndarray)
def test_run_image2(self):
results = self.module.run_image(
img=cv2.imread('tests/test.jpg')
)
self.assertIsInstance(results, np.ndarray)
def test_run_image3(self):
self.assertRaises(
FileNotFoundError,
self.module.run_image,
img='no.jpg'
)
def test_predict1(self):
pred_img, out_path = self.module.predict(
input='tests/test.jpg'
)
self.assertIsInstance(pred_img, np.ndarray)
self.assertIsInstance(out_path, str)
def test_predict2(self):
self.assertRaises(
RuntimeError,
self.module.predict,
input='no.jpg'
)
if __name__ == "__main__":
unittest.main()
...@@ -38,7 +38,9 @@ def video2frames(video_path, outpath, **kargs): ...@@ -38,7 +38,9 @@ def video2frames(video_path, outpath, **kargs):
def frames2video(frame_path, video_path, r): def frames2video(frame_path, video_path, r):
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -pix_fmt ', ' yuv420p ', video_path] cmd = ffmpeg + [
' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -pix_fmt ', ' yuv420p ', video_path
]
cmd = ''.join(cmd) cmd = ''.join(cmd)
if os.system(cmd) != 0: if os.system(cmd) != 0:
...@@ -54,7 +56,7 @@ def is_image(input): ...@@ -54,7 +56,7 @@ def is_image(input):
return True return True
except: except:
return False return False
def cv2_to_base64(image): def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1] data = cv2.imencode('.jpg', image)[1]
...@@ -65,4 +67,4 @@ def base64_to_cv2(b64str): ...@@ -65,4 +67,4 @@ def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8')) data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR) data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data return data
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册