diff --git a/modules/image/Image_editing/super_resolution/falsr_c/README.md b/modules/image/Image_editing/super_resolution/falsr_c/README.md index 2e7d35bbea7cc2eff7ab40af558942a826412a3f..405b73970ffaad2313f62a4da9c635a9959fc782 100644 --- a/modules/image/Image_editing/super_resolution/falsr_c/README.md +++ b/modules/image/Image_editing/super_resolution/falsr_c/README.md @@ -68,12 +68,11 @@ - ### 3、API - ```python - def reconstruct(self, - images=None, - paths=None, - use_gpu=False, - visualization=False, - output_dir="falsr_c_output") + def reconstruct(images=None, + paths=None, + use_gpu=False, + visualization=False, + output_dir="falsr_c_output") ``` - 预测API,用于图像超分辨率。 @@ -93,21 +92,14 @@ * data (numpy.ndarray): 超分辨后图像。 - ```python - def save_inference_model(self, - dirname='falsr_c_save_model', - model_filename=None, - params_filename=None, - combined=False) + def save_inference_model(dirname) ``` - 将模型保存到指定路径。 - **参数** - * dirname: 存在模型的目录名称 - * model\_filename: 模型文件名称,默认为\_\_model\_\_ - * params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效) - * combined: 是否将参数保存到统一的一个文件中 + * dirname: 模型保存路径 @@ -166,3 +158,11 @@ 初始发布 + +* 1.1.0 + + 移除 fluid API + + ```shell + $ hub install falsr_c == 1.1.0 + ``` diff --git a/modules/image/Image_editing/super_resolution/falsr_c/README_en.md b/modules/image/Image_editing/super_resolution/falsr_c/README_en.md index 5e651a7ea9393c68af8e24a9bb34a741287ffd46..c7e1d8a209af538333d330713219076263359965 100644 --- a/modules/image/Image_editing/super_resolution/falsr_c/README_en.md +++ b/modules/image/Image_editing/super_resolution/falsr_c/README_en.md @@ -71,12 +71,11 @@ - ### 3、API - ```python - def reconstruct(self, - images=None, - paths=None, - use_gpu=False, - visualization=False, - output_dir="falsr_c_output") + def reconstruct(images=None, + paths=None, + use_gpu=False, + visualization=False, + output_dir="falsr_c_output") ``` - Prediction API. @@ -95,21 +94,14 @@ * data (numpy.ndarray): Result of super resolution. - ```python - def save_inference_model(self, - dirname='falsr_c_save_model', - model_filename=None, - params_filename=None, - combined=False) + def save_inference_model(dirname) ``` - Save the model to the specified path. - **Parameters** - * dirname: Save path. - * model\_filename: Model file name,defalt is \_\_model\_\_ - * params\_filename: Parameter file name,defalt is \_\_params\_\_(Only takes effect when `combined` is True) - * combined: Whether to save the parameters to a unified file. + * dirname: Model save path. @@ -170,4 +162,11 @@ First release +- 1.1.0 + Remove Fluid API + + + ```shell + $ hub install falsr_c == 1.1.0 + ``` diff --git a/modules/image/Image_editing/super_resolution/falsr_c/data_feed.py b/modules/image/Image_editing/super_resolution/falsr_c/data_feed.py index 8aa6514b04caf5a705b8c82c25f2ad69d3e2fcb0..c64ffa078a902ed6edc0825301a16b60a971fc55 100644 --- a/modules/image/Image_editing/super_resolution/falsr_c/data_feed.py +++ b/modules/image/Image_editing/super_resolution/falsr_c/data_feed.py @@ -5,7 +5,7 @@ from collections import OrderedDict import cv2 import numpy as np -from PIL import Image + __all__ = ['reader'] diff --git a/modules/image/Image_editing/super_resolution/falsr_c/module.py b/modules/image/Image_editing/super_resolution/falsr_c/module.py index 8a8f25997aef8a5c7aab65d7b62798b2595abbce..b1d8a8a355f286d8febdb69f0c6f3b95bf8f229d 100644 --- a/modules/image/Image_editing/super_resolution/falsr_c/module.py +++ b/modules/image/Image_editing/super_resolution/falsr_c/module.py @@ -18,13 +18,14 @@ import os import argparse import numpy as np -import paddle.fluid as fluid -import paddlehub as hub -from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor +import paddle +import paddle.jit +import paddle.static +from paddle.inference import Config, create_predictor from paddlehub.module.module import moduleinfo, runnable, serving -from falsr_c.data_feed import reader -from falsr_c.processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir +from .data_feed import reader +from .processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir @moduleinfo( @@ -33,21 +34,22 @@ from falsr_c.processor import postprocess, base64_to_cv2, cv2_to_base64, check_d author="paddlepaddle", author_email="", summary="falsr_c is a super resolution model.", - version="1.0.0") -class Falsr_C(hub.Module): - def _initialize(self): - self.default_pretrained_model_path = os.path.join(self.directory, "falsr_c_model") + version="1.1.0") +class Falsr_C: + def __init__(self): + self.default_pretrained_model_path = os.path.join(self.directory, "falsr_c_model", "model") self._set_config() def _set_config(self): """ predictor config setting """ - self.model_file_path = self.default_pretrained_model_path - cpu_config = AnalysisConfig(self.model_file_path) + model = self.default_pretrained_model_path+'.pdmodel' + params = self.default_pretrained_model_path+'.pdiparams' + cpu_config = Config(model, params) cpu_config.disable_glog_info() cpu_config.disable_gpu() - self.cpu_predictor = create_paddle_predictor(cpu_config) + self.cpu_predictor = create_predictor(cpu_config) try: _places = os.environ["CUDA_VISIBLE_DEVICES"] @@ -56,10 +58,10 @@ class Falsr_C(hub.Module): except: use_gpu = False if use_gpu: - gpu_config = AnalysisConfig(self.model_file_path) + gpu_config = Config(model, params) gpu_config.disable_glog_info() gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0) - self.gpu_predictor = create_paddle_predictor(gpu_config) + self.gpu_predictor = create_predictor(gpu_config) def reconstruct(self, images=None, paths=None, use_gpu=False, visualization=False, output_dir="falsr_c_output"): """ @@ -96,11 +98,18 @@ class Falsr_C(hub.Module): for i in range(total_num): image_y = np.array([all_data[i]['img_y']]) image_scale_pbpr = np.array([all_data[i]['img_scale_pbpr']]) - image_y = PaddleTensor(image_y.copy()) - image_scale_pbpr = PaddleTensor(image_scale_pbpr.copy()) - output = self.gpu_predictor.run([image_y, image_scale_pbpr]) if use_gpu else self.cpu_predictor.run( - [image_y, image_scale_pbpr]) - output = np.expand_dims(output[0].as_ndarray(), axis=1) + + predictor = self.gpu_predictor if use_gpu else self.cpu_predictor + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + input_handle.copy_from_cpu(image_y.copy()) + input_handle = predictor.get_input_handle(input_names[1]) + input_handle.copy_from_cpu(image_scale_pbpr.copy()) + + predictor.run() + output_names = predictor.get_output_names() + output_handle = predictor.get_output_handle(output_names[0]) + output = np.expand_dims(output_handle.copy_to_cpu(), axis=1) out = postprocess( data_out=output, org_im=all_data[i]['org_im'], @@ -111,29 +120,6 @@ class Falsr_C(hub.Module): res.append(out) return res - def save_inference_model(self, - dirname='falsr_c_save_model', - model_filename=None, - params_filename=None, - combined=False): - if combined: - model_filename = "__model__" if not model_filename else model_filename - params_filename = "__params__" if not params_filename else params_filename - place = fluid.CPUPlace() - exe = fluid.Executor(place) - - program, feeded_var_names, target_vars = fluid.io.load_inference_model( - dirname=self.default_pretrained_model_path, executor=exe) - - fluid.io.save_inference_model( - dirname=dirname, - main_program=program, - executor=exe, - feeded_var_names=feeded_var_names, - target_vars=target_vars, - model_filename=model_filename, - params_filename=params_filename) - @serving def serving_method(self, images, **kwargs): """ diff --git a/modules/image/Image_editing/super_resolution/falsr_c/processor.py b/modules/image/Image_editing/super_resolution/falsr_c/processor.py index fe451116a20d18e5d0c091033daa82ee64de0464..805ada4d613c0efd1f09f165db064ecad79401c8 100644 --- a/modules/image/Image_editing/super_resolution/falsr_c/processor.py +++ b/modules/image/Image_editing/super_resolution/falsr_c/processor.py @@ -52,7 +52,6 @@ def postprocess(data_out, org_im, org_im_shape, org_im_path, output_dir, visuali result['data'] = sr else: result['data'] = sr - print("result['data'] shape", result['data'].shape) return result diff --git a/modules/image/Image_editing/super_resolution/falsr_c/test.py b/modules/image/Image_editing/super_resolution/falsr_c/test.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2ef673423b6ba095a34407269fc21957d810c2 --- /dev/null +++ b/modules/image/Image_editing/super_resolution/falsr_c/test.py @@ -0,0 +1,86 @@ +import os +import shutil +import unittest + +import cv2 +import requests +import numpy as np +import paddlehub as hub + + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +class TestHubModule(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + img_url = 'https://unsplash.com/photos/1sLIu1XKQrY/download?ixid=MnwxMjA3fDB8MXxhbGx8MTJ8fHx8fHwyfHwxNjYyMzQxNDUx&force=true&w=640' + if not os.path.exists('tests'): + os.makedirs('tests') + response = requests.get(img_url) + assert response.status_code == 200, 'Network Error.' + with open('tests/test.jpg', 'wb') as f: + f.write(response.content) + cls.module = hub.Module(name="falsr_c") + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree('tests') + shutil.rmtree('inference') + shutil.rmtree('falsr_c_output') + + def test_reconstruct1(self): + results = self.module.reconstruct( + paths=['tests/test.jpg'], + use_gpu=False, + visualization=False + ) + self.assertIsInstance(results[0]['data'], np.ndarray) + + def test_reconstruct2(self): + results = self.module.reconstruct( + images=[cv2.imread('tests/test.jpg')], + use_gpu=False, + visualization=False + ) + self.assertIsInstance(results[0]['data'], np.ndarray) + + def test_reconstruct3(self): + results = self.module.reconstruct( + images=[cv2.imread('tests/test.jpg')], + use_gpu=False, + visualization=True + ) + self.assertIsInstance(results[0]['data'], np.ndarray) + + def test_reconstruct4(self): + results = self.module.reconstruct( + images=[cv2.imread('tests/test.jpg')], + use_gpu=True, + visualization=False + ) + self.assertIsInstance(results[0]['data'], np.ndarray) + + def test_reconstruct5(self): + self.assertRaises( + AssertionError, + self.module.reconstruct, + paths=['no.jpg'] + ) + + def test_reconstruct6(self): + self.assertRaises( + AttributeError, + self.module.reconstruct, + images=['test.jpg'] + ) + + def test_save_inference_model(self): + self.module.save_inference_model('./inference/model') + + self.assertTrue(os.path.exists('./inference/model.pdmodel')) + self.assertTrue(os.path.exists('./inference/model.pdiparams')) + + +if __name__ == "__main__": + unittest.main()