未验证 提交 0ea0f8e8 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update user_guided_colorization (#1994)

* update user_guided_colorization

* add clean func
上级 9d830b93
......@@ -201,4 +201,8 @@
初始发布
- ```shell
$ hub install user_guided_colorization==1.0.0
```
......@@ -203,3 +203,8 @@
* 1.0.0
First release
- ```shell
$ hub install user_guided_colorization==1.0.0
```
......@@ -20,7 +20,7 @@ from paddle.nn import Conv2D, Conv2DTranspose
from paddlehub.module.module import moduleinfo
import paddlehub.vision.transforms as T
from paddlehub.module.cv_module import ImageColorizeModule
from user_guided_colorization.data_feed import ColorizePreprocess
from .data_feed import ColorizePreprocess
@moduleinfo(
......
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/1sLIu1XKQrY/download?ixid=MnwxMjA3fDB8MXxhbGx8MTJ8fHx8fHwyfHwxNjYyMzQxNDUx&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="user_guided_colorization")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('colorization')
def test_predict1(self):
results = self.module.predict(
images=['tests/test.jpg'],
visualization=False
)
gray = results[0]['gray']
hint = results[0]['hint']
real = results[0]['real']
fake_reg = results[0]['fake_reg']
self.assertIsInstance(gray, np.ndarray)
self.assertIsInstance(hint, np.ndarray)
self.assertIsInstance(real, np.ndarray)
self.assertIsInstance(fake_reg, np.ndarray)
def test_predict2(self):
results = self.module.predict(
images=[cv2.imread('tests/test.jpg')],
visualization=False
)
gray = results[0]['gray']
hint = results[0]['hint']
real = results[0]['real']
fake_reg = results[0]['fake_reg']
self.assertIsInstance(gray, np.ndarray)
self.assertIsInstance(hint, np.ndarray)
self.assertIsInstance(real, np.ndarray)
self.assertIsInstance(fake_reg, np.ndarray)
def test_predict3(self):
results = self.module.predict(
images=[cv2.imread('tests/test.jpg')],
visualization=True
)
gray = results[0]['gray']
hint = results[0]['hint']
real = results[0]['real']
fake_reg = results[0]['fake_reg']
self.assertIsInstance(gray, np.ndarray)
self.assertIsInstance(hint, np.ndarray)
self.assertIsInstance(real, np.ndarray)
self.assertIsInstance(fake_reg, np.ndarray)
def test_predict4(self):
self.assertRaises(
IndexError,
self.module.predict,
images=['no.jpg'],
visualization=False
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册