未验证 提交 e6891a50 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update falsr_c (#1989)

* update falsr_c

* update version

* add clean func

* update falsr_c
上级 7a54b07f
......@@ -68,8 +68,7 @@
- ### 3、API
- ```python
def reconstruct(self,
images=None,
def reconstruct(images=None,
paths=None,
use_gpu=False,
visualization=False,
......@@ -93,21 +92,14 @@
* data (numpy.ndarray): 超分辨后图像。
- ```python
def save_inference_model(self,
dirname='falsr_c_save_model',
model_filename=None,
params_filename=None,
combined=False)
def save_inference_model(dirname)
```
- 将模型保存到指定路径。
- **参数**
* dirname: 存在模型的目录名称
* model\_filename: 模型文件名称,默认为\_\_model\_\_
* params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
* combined: 是否将参数保存到统一的一个文件中
* dirname: 模型保存路径
......@@ -166,3 +158,11 @@
初始发布
* 1.1.0
移除 fluid API
```shell
$ hub install falsr_c == 1.1.0
```
......@@ -71,8 +71,7 @@
- ### 3、API
- ```python
def reconstruct(self,
images=None,
def reconstruct(images=None,
paths=None,
use_gpu=False,
visualization=False,
......@@ -95,21 +94,14 @@
* data (numpy.ndarray): Result of super resolution.
- ```python
def save_inference_model(self,
dirname='falsr_c_save_model',
model_filename=None,
params_filename=None,
combined=False)
def save_inference_model(dirname)
```
- Save the model to the specified path.
- **Parameters**
* dirname: Save path.
* model\_filename: Model file name,defalt is \_\_model\_\_
* params\_filename: Parameter file name,defalt is \_\_params\_\_(Only takes effect when `combined` is True)
* combined: Whether to save the parameters to a unified file.
* dirname: Model save path.
......@@ -170,4 +162,11 @@
First release
- 1.1.0
Remove Fluid API
```shell
$ hub install falsr_c == 1.1.0
```
......@@ -5,7 +5,7 @@ from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image
__all__ = ['reader']
......
......@@ -18,13 +18,14 @@ import os
import argparse
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
import paddle
import paddle.jit
import paddle.static
from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from falsr_c.data_feed import reader
from falsr_c.processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir
from .data_feed import reader
from .processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir
@moduleinfo(
......@@ -33,21 +34,22 @@ from falsr_c.processor import postprocess, base64_to_cv2, cv2_to_base64, check_d
author="paddlepaddle",
author_email="",
summary="falsr_c is a super resolution model.",
version="1.0.0")
class Falsr_C(hub.Module):
def _initialize(self):
self.default_pretrained_model_path = os.path.join(self.directory, "falsr_c_model")
version="1.1.0")
class Falsr_C:
def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "falsr_c_model", "model")
self._set_config()
def _set_config(self):
"""
predictor config setting
"""
self.model_file_path = self.default_pretrained_model_path
cpu_config = AnalysisConfig(self.model_file_path)
model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
self.cpu_predictor = create_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
......@@ -56,10 +58,10 @@ class Falsr_C(hub.Module):
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.model_file_path)
gpu_config = Config(model, params)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
self.gpu_predictor = create_predictor(gpu_config)
def reconstruct(self, images=None, paths=None, use_gpu=False, visualization=False, output_dir="falsr_c_output"):
"""
......@@ -96,11 +98,18 @@ class Falsr_C(hub.Module):
for i in range(total_num):
image_y = np.array([all_data[i]['img_y']])
image_scale_pbpr = np.array([all_data[i]['img_scale_pbpr']])
image_y = PaddleTensor(image_y.copy())
image_scale_pbpr = PaddleTensor(image_scale_pbpr.copy())
output = self.gpu_predictor.run([image_y, image_scale_pbpr]) if use_gpu else self.cpu_predictor.run(
[image_y, image_scale_pbpr])
output = np.expand_dims(output[0].as_ndarray(), axis=1)
predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(image_y.copy())
input_handle = predictor.get_input_handle(input_names[1])
input_handle.copy_from_cpu(image_scale_pbpr.copy())
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output = np.expand_dims(output_handle.copy_to_cpu(), axis=1)
out = postprocess(
data_out=output,
org_im=all_data[i]['org_im'],
......@@ -111,29 +120,6 @@ class Falsr_C(hub.Module):
res.append(out)
return res
def save_inference_model(self,
dirname='falsr_c_save_model',
model_filename=None,
params_filename=None,
combined=False):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving
def serving_method(self, images, **kwargs):
"""
......
......@@ -52,7 +52,6 @@ def postprocess(data_out, org_im, org_im_shape, org_im_path, output_dir, visuali
result['data'] = sr
else:
result['data'] = sr
print("result['data'] shape", result['data'].shape)
return result
......
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/1sLIu1XKQrY/download?ixid=MnwxMjA3fDB8MXxhbGx8MTJ8fHx8fHwyfHwxNjYyMzQxNDUx&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="falsr_c")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('falsr_c_output')
def test_reconstruct1(self):
results = self.module.reconstruct(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_reconstruct2(self):
results = self.module.reconstruct(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_reconstruct3(self):
results = self.module.reconstruct(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_reconstruct4(self):
results = self.module.reconstruct(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_reconstruct5(self):
self.assertRaises(
AssertionError,
self.module.reconstruct,
paths=['no.jpg']
)
def test_reconstruct6(self):
self.assertRaises(
AttributeError,
self.module.reconstruct,
images=['test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册