diff --git a/modules/image/image_processing/seeinthedark/README.md b/modules/image/image_processing/seeinthedark/README.md
index 786b40493a7e5ad6639b3d2281883867f0172b2d..9822f8d879331cbac87685547afa890415069ef8 100644
--- a/modules/image/image_processing/seeinthedark/README.md
+++ b/modules/image/image_processing/seeinthedark/README.md
@@ -13,6 +13,19 @@
## I. Basic Information
+- ### Sample Results
+  - Example results:
+
+    [Input image]
+
+    [Output image]
+
- ### Module Introduction
  - Trained on a large number of short-/long-exposure image pairs captured in low-light conditions, with the RAW image as input and the RGB image as reference, this model performs end-to-end conversion of a low-light RAW image into a visible RGB image.
@@ -25,7 +38,6 @@
- ### 1. Dependencies
- rawpy
- - pillow
- ### 2. Installation
@@ -53,26 +65,64 @@
denoiser = hub.Module(name="seeinthedark")
input_path = "/PATH/TO/IMAGE"
# Read from a raw file
- denoiser.denoising(input_path, output_path='./denoising_result.png', use_gpu=True)
+ denoiser.denoising(paths=[input_path], output_dir='./denoising_result/', use_gpu=True)
```
- ### 3. API
- ```python
- def denoising(input_path, output_path='./denoising_result.png', use_gpu=False)
+ def denoising(images=None, paths=None, output_dir='./denoising_result/', use_gpu=False, visualization=True)
```
- Low-light enhancement API that denoises a low-light RAW image and produces a visible RGB image (a usage sketch follows the parameter list below).
- **Parameters**
-
- - input\_path (str): path to the low-light image file, in Sony RAW format;
- - output\_path (str): path to save the result; an output file name must be included;
+  - images (list\[numpy.ndarray\]): input images, single-channel Bayer (mosaic) images;
+  - paths (list\[str\]): paths to the low-light image files, in Sony RAW format;
+  - output\_dir (str): directory in which to save the results;
- use\_gpu (bool): whether to use GPU;
+  - visualization (bool): whether to save the results to a local folder.
+
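+  - A minimal usage sketch (the image path is a placeholder) that reads the single-channel Bayer plane with rawpy and passes it through the `images` parameter instead of `paths`:
+
+  - ```python
+    import paddlehub as hub
+    import rawpy
+
+    denoiser = hub.Module(name="seeinthedark")
+    # the visible Bayer plane of the Sony RAW file, as a single-channel numpy array
+    bayer = rawpy.imread("/PATH/TO/IMAGE").raw_image_visible
+    results = denoiser.denoising(images=[bayer], output_dir='./denoising_result/', visualization=True)
+    print(results[0].shape)  # uint8 RGB array, about 2x the cropped input resolution
+    ```
+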
+## IV. Server Deployment
+
+- PaddleHub Serving can deploy an online service for low-light image denoising.
+
+- ### Step 1: Start the PaddleHub Serving service
+
+  - Run the startup command:
+ - ```shell
+ $ hub serving start -m seeinthedark
+ ```
+
+  - This deploys an online API for the low-light image denoising service; the default port is 8866.
+
+  - **NOTE:** To predict with a GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service, as in the sketch below; otherwise it is not required.
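+
+  - A minimal sketch, assuming GPU card 0 (pick whichever device you want to use):
+
+  - ```shell
+    $ export CUDA_VISIBLE_DEVICES=0
+    $ hub serving start -m seeinthedark
+    ```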
+
+- ### Step 2: Send a prediction request
+
+  - With the server configured, the following lines of code send a prediction request and fetch the result; a sketch for turning the returned list back into an image follows the request code.
+
+  - ```python
+    import requests
+    import json
+    import base64
+
+    import cv2
+    import rawpy
+
+
+    def cv2_to_base64(image):
+        # use PNG so the 16-bit Bayer data survives encoding (OpenCV's JPEG encoder is 8-bit only)
+        data = cv2.imencode('.png', image)[1]
+        return base64.b64encode(data.tobytes()).decode('utf8')
+
+
+    # send the HTTP request
+    data = {'images': [cv2_to_base64(rawpy.imread("/PATH/TO/IMAGE").raw_image_visible)]}
+    headers = {"Content-type": "application/json"}
+    url = "http://127.0.0.1:8866/predict/seeinthedark/"
+    r = requests.post(url=url, headers=headers, data=json.dumps(data))
+    # print the prediction results
+    print(r.json()["results"])
+    ```
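+
+  - The service returns each result as a nested Python list of uint8 RGB values (see `serving_method` in module.py). A minimal sketch, with an arbitrary output file name, that turns the first result back into an image and writes it with OpenCV:
+
+  - ```python
+    import numpy as np
+    import cv2
+
+    # rebuild the RGB array and flip to BGR before writing with OpenCV
+    output = np.array(r.json()["results"][0], dtype='uint8')
+    cv2.imwrite('denoising_result.png', output[:, :, ::-1])
+    ```
+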
-## IV. Release Note
+## V. Release Note
* 1.0.0
diff --git a/modules/image/image_processing/seeinthedark/module.py b/modules/image/image_processing/seeinthedark/module.py
index 1ba3c3821b24cafa4a27d6fe3dbee3187af3cdc7..434e40a79b0a8523e7a87eb1d5394aa4eaa0dbac 100644
--- a/modules/image/image_processing/seeinthedark/module.py
+++ b/modules/image/image_processing/seeinthedark/module.py
@@ -17,15 +17,19 @@ import argparse
import paddle
import paddlehub as hub
-from paddlehub.module.module import moduleinfo, runnable
+from paddlehub.module.module import moduleinfo, runnable, serving
import numpy as np
import rawpy
-from PIL import Image
+import cv2
+
+from .util import base64_to_cv2
def pack_raw(raw):
# pack Bayer image to 4 channels
- im = raw.raw_image_visible.astype(np.float32)
+    # accept either a rawpy object or an already-extracted single-channel Bayer ndarray
+    if isinstance(raw, np.ndarray):
+        im = raw.astype(np.float32)
+    else:
+        im = raw.raw_image_visible.astype(np.float32)
im = np.maximum(im - 512, 0) / (16383 - 512) # subtract the black level
im = np.expand_dims(im, axis=2)
@@ -42,47 +46,98 @@ def pack_raw(raw):
class LearningToSeeInDark:
def __init__(self):
self.pretrained_model = os.path.join(self.directory, "pd_model/inference_model")
+ self.cpu_have_loaded = False
+ self.gpu_have_loaded = False
+
+    def set_device(self, use_gpu=False):
+        # load the inference program once per device and cache it for later calls
+        if not use_gpu:
+            if not self.cpu_have_loaded:
+ exe = paddle.static.Executor(paddle.CPUPlace())
+ [prog, inputs, outputs] = paddle.static.load_inference_model(
+ path_prefix=self.pretrained_model,
+ executor=exe,
+ model_filename="model.pdmodel",
+ params_filename="model.pdiparams")
+ self.cpuexec, self.cpuprog, self.cpuinputs, self.cpuoutputs = exe, prog, inputs, outputs
+ self.cpu_have_loaded = True
+
+ return self.cpuexec, self.cpuprog, self.cpuinputs, self.cpuoutputs
- def denoising(self, input_path, output_path='./denoising_result.png', use_gpu=False):
+ else:
+ if not self.gpu_have_loaded:
+ exe = paddle.static.Executor(paddle.CUDAPlace(0))
+ [prog, inputs, outputs] = paddle.static.load_inference_model(
+ path_prefix=self.pretrained_model,
+ executor=exe,
+ model_filename="model.pdmodel",
+ params_filename="model.pdiparams")
+ self.gpuexec, self.gpuprog, self.gpuinputs, self.gpuoutputs = exe, prog, inputs, outputs
+ self.gpu_have_loaded = True
+
+ return self.gpuexec, self.gpuprog, self.gpuinputs, self.gpuoutputs
+
+    def denoising(self, images=None, paths=None, output_dir='./denoising_result/', use_gpu=False,
+ visualization=True):
'''
Denoise a raw image in the low-light scene.
- input_path: the raw image path
- output_path: the path to save the results
+        images (list[numpy.ndarray]): data of images, each with shape [H, W]; must be a single-channel Bayer image captured by the camera.
+ paths (list[str]): paths to images
+ output_dir: the dir to save the results
use_gpu: if True, use gpu to perform the computation, otherwise cpu.
+ visualization: if True, save results in output_dir.
'''
+ results = []
paddle.enable_static()
- if use_gpu == False:
- exe = paddle.static.Executor(paddle.CPUPlace())
- else:
- exe = paddle.static.Executor(paddle.CUDAPlace(0))
- [prog, inputs, outputs] = paddle.static.load_inference_model(
- path_prefix=self.pretrained_model,
- executor=exe,
- model_filename="model.pdmodel",
- params_filename="model.pdiparams")
- raw = rawpy.imread(input_path)
- input_full = np.expand_dims(pack_raw(raw), axis=0) * 300
- px = input_full.shape[1] // 512
- py = input_full.shape[2] // 512
- rx, ry = px * 512, py * 512
- input_full = input_full[:, :rx, :ry, :]
- output = np.random.randn(rx * 2, ry * 2, 3)
- input_full = np.minimum(input_full, 1.0)
- for i in range(px):
- for j in range(py):
- input_patch = input_full[:, i * 512:i * 512 + 512, j * 512:j * 512 + 512, :]
- result = exe.run(prog, feed={inputs[0]: input_patch}, fetch_list=outputs)
- output[i * 512 * 2:i * 512 * 2 + 512 * 2, j * 512 * 2:j * 512 * 2 + 512 * 2, :] = result[0][0]
- output = np.minimum(np.maximum(output, 0), 1)
-
- print('Denoising Over.')
- try:
- Image.fromarray(np.uint8(output * 255)).save(os.path.join(output_path))
- print('Image saved in {}'.format(output_path))
- except:
- print('Save image failed. Please check the output_path, should\
- be image format ext, e.g. png. current output path {}'.format(output_path))
+ exe, prog, inputs, outputs = self.set_device(use_gpu)
+
+        if images is not None:
+ for raw in images:
+                # amplify the packed RAW input by a fixed factor of 300
+                input_full = np.expand_dims(pack_raw(raw), axis=0) * 300
+                # crop to a multiple of 512 and process in 512x512 tiles; the network outputs each tile at 2x resolution
+                px = input_full.shape[1] // 512
+ py = input_full.shape[2] // 512
+ rx, ry = px * 512, py * 512
+ input_full = input_full[:, :rx, :ry, :]
+ output = np.random.randn(rx * 2, ry * 2, 3)
+ input_full = np.minimum(input_full, 1.0)
+ for i in range(px):
+ for j in range(py):
+ input_patch = input_full[:, i * 512:i * 512 + 512, j * 512:j * 512 + 512, :]
+ result = exe.run(prog, feed={inputs[0]: input_patch}, fetch_list=outputs)
+ output[i * 512 * 2:i * 512 * 2 + 512 * 2, j * 512 * 2:j * 512 * 2 + 512 * 2, :] = result[0][0]
+ output = np.minimum(np.maximum(output, 0), 1)
+ output = output * 255
+ output = np.clip(output, 0, 255)
+ output = output.astype('uint8')
+ results.append(output)
+        if paths is not None:
+ for path in paths:
+ raw = rawpy.imread(path)
+ input_full = np.expand_dims(pack_raw(raw), axis=0) * 300
+ px = input_full.shape[1] // 512
+ py = input_full.shape[2] // 512
+ rx, ry = px * 512, py * 512
+ input_full = input_full[:, :rx, :ry, :]
+ output = np.random.randn(rx * 2, ry * 2, 3)
+ input_full = np.minimum(input_full, 1.0)
+ for i in range(px):
+ for j in range(py):
+ input_patch = input_full[:, i * 512:i * 512 + 512, j * 512:j * 512 + 512, :]
+ result = exe.run(prog, feed={inputs[0]: input_patch}, fetch_list=outputs)
+ output[i * 512 * 2:i * 512 * 2 + 512 * 2, j * 512 * 2:j * 512 * 2 + 512 * 2, :] = result[0][0]
+ output = np.minimum(np.maximum(output, 0), 1)
+ output = output * 255
+ output = np.clip(output, 0, 255)
+ output = output.astype('uint8')
+ results.append(output)
+
+        if visualization:
+            os.makedirs(output_dir, exist_ok=True)
+            for i, out in enumerate(results):
+                # results are RGB; flip to BGR before writing with OpenCV
+                cv2.imwrite(os.path.join(output_dir, 'output_{}.png'.format(i)), out[:, :, ::-1])
+
+ return results
@runnable
def run_cmd(self, argvs: list):
@@ -101,7 +156,21 @@ class LearningToSeeInDark:
self.add_module_config_arg()
self.add_module_input_arg()
self.args = self.parser.parse_args(argvs)
- self.denoising(input_path=self.args.input_path, output_path=self.args.output_path, use_gpu=self.args.use_gpu)
+ self.denoising(
+ paths=[self.args.input_path],
+ output_dir=self.args.output_dir,
+ use_gpu=self.args.use_gpu,
+ visualization=self.args.visualization)
+
+ @serving
+ def serving_method(self, images, **kwargs):
+ """
+ Run as a service.
+ """
+ images_decode = [base64_to_cv2(image) for image in images]
+ results = self.denoising(images=images_decode, **kwargs)
+ tolist = [result.tolist() for result in results]
+ return tolist
def add_module_config_arg(self):
"""
@@ -110,7 +179,8 @@ class LearningToSeeInDark:
self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
self.arg_config_group.add_argument(
- '--output_path', type=str, default='denoising_result.png', help='output path for saving result.')
+ '--output_dir', type=str, default='denoising_result', help='output directory for saving result.')
+        self.arg_config_group.add_argument('--visualization', action='store_true', help='save results or not.')
def add_module_input_arg(self):
"""
diff --git a/modules/image/image_processing/seeinthedark/requirements.txt b/modules/image/image_processing/seeinthedark/requirements.txt
index b49bea0673ed1c7e3e4d23e998eaa21f3d4494b1..32c8259e1c5451cc0f2bec980fa8476ac1708771 100644
--- a/modules/image/image_processing/seeinthedark/requirements.txt
+++ b/modules/image/image_processing/seeinthedark/requirements.txt
@@ -1,2 +1 @@
rawpy
-pillow