You need to sign in or sign up before continuing.
未验证 提交 fd66dddd 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update realsr (#1991)

* update realsr

* add clean func

* update README
上级 dbd005b4
......@@ -63,7 +63,9 @@
import paddlehub as hub
model = hub.Module(name='realsr')
model.predict('/PATH/TO/IMAGE/OR/VIDEO')
model.predict('/PATH/TO/IMAGE')
# model.predict('/PATH/TO/VIDEO')
```
- ### 2、API
......@@ -174,3 +176,10 @@
适配paddlehub2.0版本
* 1.1.0
更新代码格式
```shell
$ hub install realsr == 1.1.0
```
\ No newline at end of file
......@@ -60,7 +60,9 @@
import paddlehub as hub
model = hub.Module(name='realsr')
model.predict('/PATH/TO/IMAGE/OR/VIDEO')
model.predict('/PATH/TO/IMAGE')
# model.predict('/PATH/TO/VIDEO')
```
- ### 2、API
......@@ -172,3 +174,10 @@
Support paddlehub2.0
* 1.1.0
Update code format
```shell
$ hub install realsr == 1.1.0
```
\ No newline at end of file
......@@ -21,22 +21,21 @@ import numpy as np
from PIL import Image
import paddle
import paddle.nn as nn
from paddlehub.module.module import moduleinfo, serving, Module
from realsr.rrdb import RRDBNet
import realsr.utils as U
@moduleinfo(
name="realsr",
type="CV/image_editing",
author="paddlepaddle",
author_email="",
summary="realsr is a super resolution model",
version="1.0.0")
class RealSRPredictor(Module):
def _initialize(self, output='output', weight_path=None, load_checkpoint: str = None):
#super(RealSRPredictor, self).__init__()
from paddlehub.module.module import moduleinfo, serving
from .rrdb import RRDBNet
from . import utils as U
@moduleinfo(name="realsr",
type="CV/image_editing",
author="paddlepaddle",
author_email="",
summary="realsr is a super resolution model",
version="1.1.0")
class RealSRPredictor(nn.Layer):
def __init__(self, output='output', weight_path=None, load_checkpoint: str = None):
super(RealSRPredictor, self).__init__()
self.input = input
self.output = os.path.join(output, 'RealSR')
self.model = RRDBNet(3, 3, 64, 23)
......@@ -48,6 +47,8 @@ class RealSRPredictor(Module):
else:
checkpoint = os.path.join(self.directory, 'DF2K_JPEG.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams -O ' + checkpoint)
state_dict = paddle.load(checkpoint)
self.model.load_dict(state_dict)
print("load pretrained checkpoint success")
......@@ -66,17 +67,17 @@ class RealSRPredictor(Module):
if isinstance(img, str):
ori_img = Image.open(img).convert('RGB')
elif isinstance(img, np.ndarray):
# ori_img = Image.fromarray(img).convert('RGB')
ori_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
elif isinstance(img, Image.Image):
ori_img = img
img = self.norm(ori_img)
x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x)
with paddle.no_grad():
x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x)
pred_img = self.denorm(out.numpy()[0])
# pred_img = Image.fromarray(pred_img)
pred_img = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR)
return pred_img
......@@ -108,7 +109,8 @@ class RealSRPredictor(Module):
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(output_path, '{}_realsr_out.mp4'.format(base_name))
vid_out_path = os.path.join(output_path,
'{}_realsr_out.mp4'.format(base_name))
U.frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
print("save result at {}".format(vid_out_path))
......
......@@ -30,8 +30,9 @@ class Registry(object):
self._obj_map = {}
def _do_register(self, name, obj):
assert (name not in self._obj_map), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name)
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(name, self._name)
self._obj_map[name] = obj
def register(self, obj=None, name=None):
......@@ -84,7 +85,6 @@ class ResidualDenseBlock_5C(nn.Layer):
class RRDB(nn.Layer):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
......@@ -104,10 +104,8 @@ def make_layer(block, n_layers):
layers.append(block())
return nn.Sequential(*layers)
GENERATORS = Registry("GENERATOR")
@GENERATORS.register()
class RRDBNet(nn.Layer):
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
......@@ -130,8 +128,10 @@ class RRDBNet(nn.Layer):
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(
self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(
self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/1sLIu1XKQrY/download?ixid=MnwxMjA3fDB8MXxhbGx8MTJ8fHx8fHwyfHwxNjYyMzQxNDUx&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="realsr")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('output')
def test_run_image1(self):
results = self.module.run_image(
img='tests/test.jpg'
)
self.assertIsInstance(results, np.ndarray)
def test_run_image2(self):
results = self.module.run_image(
img=cv2.imread('tests/test.jpg')
)
self.assertIsInstance(results, np.ndarray)
def test_run_image3(self):
self.assertRaises(
FileNotFoundError,
self.module.run_image,
img='no.jpg'
)
def test_predict1(self):
pred_img, out_path = self.module.predict(
input='tests/test.jpg'
)
self.assertIsInstance(pred_img, np.ndarray)
self.assertIsInstance(out_path, str)
def test_predict2(self):
self.assertRaises(
RuntimeError,
self.module.predict,
input='no.jpg'
)
if __name__ == "__main__":
unittest.main()
......@@ -38,7 +38,9 @@ def video2frames(video_path, outpath, **kargs):
def frames2video(frame_path, video_path, r):
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -pix_fmt ', ' yuv420p ', video_path]
cmd = ffmpeg + [
' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -pix_fmt ', ' yuv420p ', video_path
]
cmd = ''.join(cmd)
if os.system(cmd) != 0:
......@@ -54,7 +56,7 @@ def is_image(input):
return True
except:
return False
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
......@@ -65,4 +67,4 @@ def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
return data
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册