提交 da6b6b42 编写于 作者: C chenjian

modify according to review

上级 245cc67b
...@@ -13,6 +13,19 @@ ...@@ -13,6 +13,19 @@
## 一、模型基本信息 ## 一、模型基本信息
- ### 应用效果展示
- 样例结果示例:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/142962370-a957d7b3-8050-4f5a-8462-3d6e49facb33.png" width = "450" height = "300" hspace='10'/>
<br />
输入图像
<br />
<img src="https://user-images.githubusercontent.com/22424850/142962460-4a1b31ef-0eec-423b-ab3d-8622f3e8261a.png" width = "450" height = "300" hspace='10'/>
<br />
输出图像
<br />
</p>
- ### 模型介绍 - ### 模型介绍
- 通过大量暗光条件下短曝光和长曝光组成的图像对,以RAW图像为输入,RGB图像为参照进行训练,该模型实现端到端直接将暗光下的RAW图像处理得到可见的RGB图像。 - 通过大量暗光条件下短曝光和长曝光组成的图像对,以RAW图像为输入,RGB图像为参照进行训练,该模型实现端到端直接将暗光下的RAW图像处理得到可见的RGB图像。
...@@ -25,7 +38,6 @@ ...@@ -25,7 +38,6 @@
- ### 1、环境依赖 - ### 1、环境依赖
- rawpy - rawpy
- pillow
- ### 2、安装 - ### 2、安装
...@@ -53,26 +65,64 @@ ...@@ -53,26 +65,64 @@
denoiser = hub.Module(name="seeinthedark") denoiser = hub.Module(name="seeinthedark")
input_path = "/PATH/TO/IMAGE" input_path = "/PATH/TO/IMAGE"
# Read from a raw file # Read from a raw file
denoiser.denoising(input_path, output_path='./denoising_result.png', use_gpu=True) denoiser.denoising(paths=[input_path], output_path='./denoising_result.png', use_gpu=True)
``` ```
- ### 3、API - ### 3、API
- ```python - ```python
def denoising(input_path, output_path='./denoising_result.png', use_gpu=False) def denoising(images=None, paths=None, output_dir='./denoising_result/', use_gpu=False, visualization=True)
``` ```
- 暗光增强API,完成对暗光RAW图像的降噪并处理生成RGB图像。 - 暗光增强API,完成对暗光RAW图像的降噪并处理生成RGB图像。
- **参数** - **参数**
- images (list\[numpy.ndarray\]): 输入的图像,单通道的马赛克图像; <br/>
- input\_path (str): 暗光图像文件的路径,Sony的RAW格式; <br/> - paths (list\[str\]): 暗光图像文件的路径,Sony的RAW格式;<br/>
- output\_path (str): 结果保存的路径, 需要指定输出文件名; <br/> - output\_dir (str): 结果保存的路径; <br/>
- use\_gpu (bool): 是否使用 GPU;<br/> - use\_gpu (bool): 是否使用 GPU;<br/>
- visualization(bool): 是否保存结果到本地文件夹
## 四、服务部署
- PaddleHub Serving可以部署一个在线图像风格转换服务。
- ### 第一步:启动PaddleHub Serving
- 运行启动命令:
- ```shell
$ hub serving start -m seeinthedark
```
- 这样就完成了一个图像风格转换的在线服务API的部署,默认端口号为8866。
- **NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
- ### 第二步:发送预测请求
- 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
- ```python
import requests
import json
import rawpy
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# 发送HTTP请求
data = {'images':[cv2_to_base64(rawpy.imread("/PATH/TO/IMAGE").raw_image_visible)]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/seeinthedark/"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 打印预测结果
print(r.json()["results"])
## 、更新历史 ## 、更新历史
* 1.0.0 * 1.0.0
......
...@@ -17,15 +17,19 @@ import argparse ...@@ -17,15 +17,19 @@ import argparse
import paddle import paddle
import paddlehub as hub import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable from paddlehub.module.module import moduleinfo, runnable, serving
import numpy as np import numpy as np
import rawpy import rawpy
from PIL import Image import cv2
from .util import base64_to_cv2
def pack_raw(raw): def pack_raw(raw):
# pack Bayer image to 4 channels # pack Bayer image to 4 channels
im = raw.raw_image_visible.astype(np.float32) im = raw
if not isinstance(raw, np.ndarray):
im = raw.raw_image_visible.astype(np.float32)
im = np.maximum(im - 512, 0) / (16383 - 512) # subtract the black level im = np.maximum(im - 512, 0) / (16383 - 512) # subtract the black level
im = np.expand_dims(im, axis=2) im = np.expand_dims(im, axis=2)
...@@ -42,47 +46,98 @@ def pack_raw(raw): ...@@ -42,47 +46,98 @@ def pack_raw(raw):
class LearningToSeeInDark: class LearningToSeeInDark:
def __init__(self): def __init__(self):
self.pretrained_model = os.path.join(self.directory, "pd_model/inference_model") self.pretrained_model = os.path.join(self.directory, "pd_model/inference_model")
self.cpu_have_loaded = False
self.gpu_have_loaded = False
def set_device(self, use_gpu=False):
if use_gpu == False:
if not self.cpu_have_loaded:
exe = paddle.static.Executor(paddle.CPUPlace())
[prog, inputs, outputs] = paddle.static.load_inference_model(
path_prefix=self.pretrained_model,
executor=exe,
model_filename="model.pdmodel",
params_filename="model.pdiparams")
self.cpuexec, self.cpuprog, self.cpuinputs, self.cpuoutputs = exe, prog, inputs, outputs
self.cpu_have_loaded = True
return self.cpuexec, self.cpuprog, self.cpuinputs, self.cpuoutputs
def denoising(self, input_path, output_path='./denoising_result.png', use_gpu=False): else:
if not self.gpu_have_loaded:
exe = paddle.static.Executor(paddle.CUDAPlace(0))
[prog, inputs, outputs] = paddle.static.load_inference_model(
path_prefix=self.pretrained_model,
executor=exe,
model_filename="model.pdmodel",
params_filename="model.pdiparams")
self.gpuexec, self.gpuprog, self.gpuinputs, self.gpuoutputs = exe, prog, inputs, outputs
self.gpu_have_loaded = True
return self.gpuexec, self.gpuprog, self.gpuinputs, self.gpuoutputs
def denoising(self, images=None, paths=None, output_dir='./enlightening_result/', use_gpu=False,
visualization=True):
''' '''
Denoise a raw image in the low-light scene. Denoise a raw image in the low-light scene.
input_path: the raw image path images (list[numpy.ndarray]): data of images, shape of each is [H, W], must be sing-channel image captured by camera.
output_path: the path to save the results paths (list[str]): paths to images
output_dir: the dir to save the results
use_gpu: if True, use gpu to perform the computation, otherwise cpu. use_gpu: if True, use gpu to perform the computation, otherwise cpu.
visualization: if True, save results in output_dir.
''' '''
results = []
paddle.enable_static() paddle.enable_static()
if use_gpu == False: exe, prog, inputs, outputs = self.set_device(use_gpu)
exe = paddle.static.Executor(paddle.CPUPlace())
else: if images != None:
exe = paddle.static.Executor(paddle.CUDAPlace(0)) for raw in images:
[prog, inputs, outputs] = paddle.static.load_inference_model( input_full = np.expand_dims(pack_raw(raw), axis=0) * 300
path_prefix=self.pretrained_model, px = input_full.shape[1] // 512
executor=exe, py = input_full.shape[2] // 512
model_filename="model.pdmodel", rx, ry = px * 512, py * 512
params_filename="model.pdiparams") input_full = input_full[:, :rx, :ry, :]
raw = rawpy.imread(input_path) output = np.random.randn(rx * 2, ry * 2, 3)
input_full = np.expand_dims(pack_raw(raw), axis=0) * 300 input_full = np.minimum(input_full, 1.0)
px = input_full.shape[1] // 512 for i in range(px):
py = input_full.shape[2] // 512 for j in range(py):
rx, ry = px * 512, py * 512 input_patch = input_full[:, i * 512:i * 512 + 512, j * 512:j * 512 + 512, :]
input_full = input_full[:, :rx, :ry, :] result = exe.run(prog, feed={inputs[0]: input_patch}, fetch_list=outputs)
output = np.random.randn(rx * 2, ry * 2, 3) output[i * 512 * 2:i * 512 * 2 + 512 * 2, j * 512 * 2:j * 512 * 2 + 512 * 2, :] = result[0][0]
input_full = np.minimum(input_full, 1.0) output = np.minimum(np.maximum(output, 0), 1)
for i in range(px): output = output * 255
for j in range(py): output = np.clip(output, 0, 255)
input_patch = input_full[:, i * 512:i * 512 + 512, j * 512:j * 512 + 512, :] output = output.astype('uint8')
result = exe.run(prog, feed={inputs[0]: input_patch}, fetch_list=outputs) results.append(output)
output[i * 512 * 2:i * 512 * 2 + 512 * 2, j * 512 * 2:j * 512 * 2 + 512 * 2, :] = result[0][0] if paths != None:
output = np.minimum(np.maximum(output, 0), 1) for path in paths:
raw = rawpy.imread(path)
print('Denoising Over.') input_full = np.expand_dims(pack_raw(raw), axis=0) * 300
try: px = input_full.shape[1] // 512
Image.fromarray(np.uint8(output * 255)).save(os.path.join(output_path)) py = input_full.shape[2] // 512
print('Image saved in {}'.format(output_path)) rx, ry = px * 512, py * 512
except: input_full = input_full[:, :rx, :ry, :]
print('Save image failed. Please check the output_path, should\ output = np.random.randn(rx * 2, ry * 2, 3)
be image format ext, e.g. png. current output path {}'.format(output_path)) input_full = np.minimum(input_full, 1.0)
for i in range(px):
for j in range(py):
input_patch = input_full[:, i * 512:i * 512 + 512, j * 512:j * 512 + 512, :]
result = exe.run(prog, feed={inputs[0]: input_patch}, fetch_list=outputs)
output[i * 512 * 2:i * 512 * 2 + 512 * 2, j * 512 * 2:j * 512 * 2 + 512 * 2, :] = result[0][0]
output = np.minimum(np.maximum(output, 0), 1)
output = output * 255
output = np.clip(output, 0, 255)
output = output.astype('uint8')
results.append(output)
if visualization == True:
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
for i, out in enumerate(results):
cv2.imwrite(os.path.join(output_dir, 'output_{}.png'.format(i)), out[:, :, ::-1])
return results
@runnable @runnable
def run_cmd(self, argvs: list): def run_cmd(self, argvs: list):
...@@ -101,7 +156,21 @@ class LearningToSeeInDark: ...@@ -101,7 +156,21 @@ class LearningToSeeInDark:
self.add_module_config_arg() self.add_module_config_arg()
self.add_module_input_arg() self.add_module_input_arg()
self.args = self.parser.parse_args(argvs) self.args = self.parser.parse_args(argvs)
self.denoising(input_path=self.args.input_path, output_path=self.args.output_path, use_gpu=self.args.use_gpu) self.denoising(
paths=[self.args.input_path],
output_dir=self.args.output_dir,
use_gpu=self.args.use_gpu,
visualization=self.args.visualization)
@serving
def serving_method(self, images, **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
results = self.denoising(images=images_decode, **kwargs)
tolist = [result.tolist() for result in results]
return tolist
def add_module_config_arg(self): def add_module_config_arg(self):
""" """
...@@ -110,7 +179,8 @@ class LearningToSeeInDark: ...@@ -110,7 +179,8 @@ class LearningToSeeInDark:
self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not") self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
self.arg_config_group.add_argument( self.arg_config_group.add_argument(
'--output_path', type=str, default='denoising_result.png', help='output path for saving result.') '--output_dir', type=str, default='denoising_result', help='output directory for saving result.')
self.arg_config_group.add_argument('--visualization', type=bool, default=False, help='save results or not.')
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册