Unverified commit d0355914, authored by jm_12138, committed by GitHub

update falsr_b (#2185)

Parent 41791d9b
......@@ -2,7 +2,7 @@
|Model Name|falsr_b|
| :--- | :---: |
|Category|Image editing|
|Network|falsr_b|
|Dataset|DIV2k|
......@@ -15,7 +15,7 @@
## I. Basic Information
- ### Application Effect Display
- Sample results (left: original image; right: result):
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/133558583-0b7049db-ed1f-4a16-8676-f2141fcb3dee.png" width = "450" height = "300" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/130899031-a6f8c58a-5cb7-4105-b990-8cca5ae15368.png" width = "450" height = "300" hspace='10'/>
......@@ -149,20 +149,23 @@
print("save image as falsr_b_X2.png")
```
- ### Gradio APP support
Starting from PaddleHub 2.3.1, the Gradio APP of falsr_b can be accessed in the browser at http://127.0.0.1:8866/gradio/falsr_b.
## V. Release Note
* 1.0.0
First release
* 1.1.0
Remove Fluid API
* 1.2.0
Add Gradio APP support
```shell
$ hub install falsr_b==1.1.0
$ hub install falsr_b==1.2.0
```
# falsr_b
|Module Name|falsr_b|
| :--- | :---: |
|Category |Image editing|
|Network |falsr_b|
|Dataset|DIV2k|
......@@ -11,10 +11,10 @@
|Latest update date|2021-02-26|
## I. Basic Information
- ### Application Effect Display
- Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/133558583-0b7049db-ed1f-4a16-8676-f2141fcb3dee.png" width = "450" height = "300" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/130899031-a6f8c58a-5cb7-4105-b990-8cca5ae15368.png" width = "450" height = "300" hspace='10'/>
......@@ -88,7 +88,7 @@
* output\_dir (str): Save path of images, "falsr_b_output" by default.
- **Return**
* res (list\[dict\]): The list of model results, where each element is a dict with the following fields (see the usage sketch after this list):
* save\_path (str, optional): Save path of the result, save_path is '' if no image is saved.
* data (numpy.ndarray): Result of super resolution.
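Below is a minimal usage sketch of the `reconstruct` API and its return fields. It assumes falsr_b is installed via `hub install falsr_b` and that a local image `test.jpg` exists; the file name is only illustrative.

```python
import cv2
import paddlehub as hub

# Load the installed falsr_b super resolution module.
module = hub.Module(name="falsr_b")

# Run 2x super resolution on an in-memory image; with visualization=True
# the result image is also written into output_dir.
results = module.reconstruct(images=[cv2.imread("test.jpg")],
                             use_gpu=False,
                             visualization=True,
                             output_dir="falsr_b_output")

# Each element of the returned list is a dict with 'save_path' and 'data'.
print(results[0]['save_path'])   # path of the saved image ('' if nothing was saved)
print(results[0]['data'].shape)  # numpy.ndarray holding the super resolution result
```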
......@@ -153,20 +153,23 @@
print("save image as falsr_b_X2.png")
```
- ### Gradio APP support
Starting with PaddleHub 2.3.1, the Gradio APP of falsr_b can be accessed in the browser at http://127.0.0.1:8866/gradio/falsr_b (a sketch for launching it locally follows).
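For local experimentation without starting a PaddleHub serving instance, the sketch below builds the same Gradio interface through the `create_gradio_app` method added in this commit and launches it directly. It assumes falsr_b 1.2.0 and gradio are installed; the local address used by Gradio differs from the serving link above.

```python
import paddlehub as hub

# Load falsr_b 1.2.0, which ships the Gradio interface builder.
module = hub.Module(name="falsr_b")

# Build the gr.Interface defined by the module and serve it locally.
app = module.create_gradio_app()
app.launch()  # by default Gradio listens on http://127.0.0.1:7860
```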
## V. Release Note
* 1.0.0
First release
* 1.1.0
Remove Fluid API
* 1.2.0
Add Gradio APP support.
```shell
$ hub install falsr_b==1.1.0
$ hub install falsr_b==1.2.0
```
# -*- coding:utf-8 -*-
import os
import time
from collections import OrderedDict
......@@ -6,7 +5,6 @@ from collections import OrderedDict
import cv2
import numpy as np
__all__ = ['reader']
......
# -*- coding:utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -12,30 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import argparse
import numpy as np
import paddle
import paddle.jit
import paddle.static
from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from paddle.inference import Config
from paddle.inference import create_predictor
from .data_feed import reader
from .processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir
@moduleinfo(
name="falsr_b",
type="CV/image_editing",
author="paddlepaddle",
author_email="",
summary="falsr_b is a super resolution model.",
version="1.1.0")
from .processor import base64_to_cv2
from .processor import check_dir
from .processor import cv2_to_base64
from .processor import postprocess
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
@moduleinfo(name="falsr_b",
type="CV/image_editing",
author="paddlepaddle",
author_email="",
summary="falsr_b is a super resolution model.",
version="1.2.0")
class Falsr_B:
def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "falsr_b_model", "model")
self._set_config()
......@@ -44,8 +45,8 @@ class Falsr_B:
"""
predictor config setting
"""
model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
model = self.default_pretrained_model_path + '.pdmodel'
params = self.default_pretrained_model_path + '.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
......@@ -110,13 +111,12 @@ class Falsr_B:
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output = np.expand_dims(output_handle.copy_to_cpu(), axis=1)
out = postprocess(
data_out=output,
org_im=all_data[i]['org_im'],
org_im_shape=all_data[i]['org_im_shape'],
org_im_path=all_data[i]['org_im_path'],
output_dir=output_dir,
visualization=visualization)
out = postprocess(data_out=output,
org_im=all_data[i]['org_im'],
org_im_shape=all_data[i]['org_im_shape'],
org_im_path=all_data[i]['org_im_path'],
output_dir=output_dir,
visualization=visualization)
res.append(out)
return res
......@@ -135,11 +135,10 @@ class Falsr_B:
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
......@@ -147,8 +146,10 @@ class Falsr_B:
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.reconstruct(
paths=[args.input_path], use_gpu=args.use_gpu, output_dir=args.output_dir, visualization=args.visualization)
results = self.reconstruct(paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
visualization=args.visualization)
if args.save_dir is not None:
check_dir(args.save_dir)
self.save_inference_model(args.save_dir)
......@@ -159,14 +160,22 @@ class Falsr_B:
"""
Add the command config options.
"""
self.arg_config_group.add_argument(
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not")
self.arg_config_group.add_argument(
'--output_dir', type=str, default='falsr_b_output', help="The directory to save output images.")
self.arg_config_group.add_argument(
'--save_dir', type=str, default='falsr_b_save_model', help="The directory to save model.")
self.arg_config_group.add_argument(
'--visualization', type=ast.literal_eval, default=True, help="whether to save output as images.")
self.arg_config_group.add_argument('--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not")
self.arg_config_group.add_argument('--output_dir',
type=str,
default='falsr_b_output',
help="The directory to save output images.")
self.arg_config_group.add_argument('--save_dir',
type=str,
default='falsr_b_save_model',
help="The directory to save model.")
self.arg_config_group.add_argument('--visualization',
type=ast.literal_eval,
default=True,
help="whether to save output as images.")
def add_module_input_arg(self):
"""
......@@ -174,8 +183,20 @@ class Falsr_B:
"""
self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")
if __name__ == "__main__":
module = Falsr_B()
module.reconstruct(paths=["BSD100_001.png", "BSD100_002.png", "Set5_003.png"])
module.save_inference_model()
def create_gradio_app(self):
import gradio as gr
import tempfile
import os
from PIL import Image
def inference(image, use_gpu=False):
with tempfile.TemporaryDirectory() as temp_dir:
self.reconstruct(paths=[image], use_gpu=use_gpu, visualization=True, output_dir=temp_dir)
return Image.open(os.path.join(temp_dir, os.listdir(temp_dir)[0]))
interface = gr.Interface(
inference,
[gr.inputs.Image(type="filepath"), gr.Checkbox(label='use_gpu')],
gr.outputs.Image(type="ndarray"),
title='falsr_b')
return interface
# -*- coding:utf-8 -*-
import base64
import os
import time
import base64
import cv2
import numpy as np
......
......@@ -3,18 +3,19 @@ import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/1sLIu1XKQrY/download?ixid=MnwxMjA3fDB8MXxhbGx8MTJ8fHx8fHwyfHwxNjYyMzQxNDUx&force=true&w=640'
img_url = 'https://unsplash.com/photos/1sLIu1XKQrY/download?ixid=MnwxMjA3fDB8MXxhbGx8MTJ8fHx8fHwyfHwxNjYyMzQxNDUx&force=true&w=120'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
......@@ -30,50 +31,26 @@ class TestHubModule(unittest.TestCase):
shutil.rmtree('falsr_b_output')
def test_reconstruct1(self):
results = self.module.reconstruct(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False
)
results = self.module.reconstruct(paths=['tests/test.jpg'], use_gpu=False, visualization=False)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_reconstruct2(self):
results = self.module.reconstruct(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False
)
results = self.module.reconstruct(images=[cv2.imread('tests/test.jpg')], use_gpu=False, visualization=False)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_reconstruct3(self):
results = self.module.reconstruct(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True
)
results = self.module.reconstruct(images=[cv2.imread('tests/test.jpg')], use_gpu=False, visualization=True)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_reconstruct4(self):
results = self.module.reconstruct(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False
)
results = self.module.reconstruct(images=[cv2.imread('tests/test.jpg')], use_gpu=True, visualization=False)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_reconstruct5(self):
self.assertRaises(
AssertionError,
self.module.reconstruct,
paths=['no.jpg']
)
self.assertRaises(AssertionError, self.module.reconstruct, paths=['no.jpg'])
def test_reconstruct6(self):
self.assertRaises(
AttributeError,
self.module.reconstruct,
images=['test.jpg']
)
self.assertRaises(AttributeError, self.module.reconstruct, images=['test.jpg'])
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
......