未验证 提交 1af344a9 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update photo_restoration (#1998)

* update photo_restoration

* add clean func
上级 524a7512
...@@ -143,6 +143,10 @@ ...@@ -143,6 +143,10 @@
初始发布 初始发布
* 1.0.1 * 1.1.0
适配paddlehub2.0版本 适配paddlehub2.0版本
* ```shell
$ hub install photo_restoration==1.1.0
```
...@@ -145,7 +145,11 @@ ...@@ -145,7 +145,11 @@
First release First release
- 1.0.1 - 1.1.0
Adapt to paddlehub2.0 Adapt to paddlehub2.0
* ```shell
$ hub install photo_restoration==1.1.0
```
...@@ -18,19 +18,18 @@ import time ...@@ -18,19 +18,18 @@ import time
import cv2 import cv2
import paddle.nn as nn import paddle.nn as nn
import paddlehub as hub import paddlehub as hub
from paddlehub.module.module import moduleinfo, serving, Module from paddlehub.module.module import moduleinfo, serving
import photo_restoration.utils as U from . import utils as U
@moduleinfo( @moduleinfo(name="photo_restoration",
name="photo_restoration", type="CV/image_editing",
type="CV/image_editing", author="paddlepaddle",
author="paddlepaddle", author_email="",
author_email="", summary="photo_restoration is a photo restoration model based on deoldify and realsr.",
summary="photo_restoration is a photo restoration model based on deoldify and realsr.", version="1.1.0")
version="1.0.0") class PhotoRestoreModel(nn.Layer):
class PhotoRestoreModel(Module):
""" """
PhotoRestoreModel PhotoRestoreModel
...@@ -39,8 +38,8 @@ class PhotoRestoreModel(Module): ...@@ -39,8 +38,8 @@ class PhotoRestoreModel(Module):
visualization (bool): Whether to save the estimation result. Default is True. visualization (bool): Whether to save the estimation result. Default is True.
""" """
def _initialize(self, visualization: bool = False): def __init__(self, visualization: bool = False):
#super(PhotoRestoreModel, self).__init__() super(PhotoRestoreModel, self).__init__()
self.deoldify = hub.Module(name='deoldify') self.deoldify = hub.Module(name='deoldify')
self.realsr = hub.Module(name='realsr') self.realsr = hub.Module(name='realsr')
self.visualization = visualization self.visualization = visualization
......
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/1sLIu1XKQrY/download?ixid=MnwxMjA3fDB8MXxhbGx8MTJ8fHx8fHwyfHwxNjYyMzQxNDUx&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="photo_restoration")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('photo_restoration')
def test_run_image1(self):
results = self.module.run_image(
input='tests/test.jpg'
)
self.assertIsInstance(results, np.ndarray)
def test_run_image2(self):
results = self.module.run_image(
input=cv2.imread('tests/test.jpg')
)
self.assertIsInstance(results, np.ndarray)
def test_run_image3(self):
self.assertRaises(
FileNotFoundError,
self.module.run_image,
input='no.jpg'
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册