提交 827270b7 编写于 作者: C chenjian

reduce

上级 4b4d8cf8
import os
from typing import Union
import numpy as np
from onnxruntime import InferenceSession
def get_relative_path(root, *args):
return os.path.join(os.path.dirname(root), *args)
class EnlightenOnnxModel:
def __init__(self, model: Union[bytes, str, None] = None):
self.graph = InferenceSession(model or get_relative_path(__file__, 'enlighten.onnx'))
def __repr__(self):
return '<EnlightenGAN OnnxModel> {}'.format(id(self))
def _pad(self, img):
h, w, _ = img.shape
block_size = 16
min_height = (h // block_size + 1) * block_size
min_width = (w // block_size + 1) * block_size
img = np.pad(img, ((0, min_height - h), (0, min_width - w), (0, 0)), mode='constant', constant_values=0)
return img, (h, w)
def _preprocess(self, img):
if len(img.shape) != 3:
raise ValueError('Incorrect shape: expected 3, got {}'.format(img.shape))
return np.expand_dims(np.transpose(img, (2, 0, 1)).astype(np.float32) / 255., 0)
def predict(self, img):
padded, (h, w) = self._pad(img)
image_numpy, = self.graph.run(['output'], {'input': self._preprocess(padded)})
image_numpy = (np.transpose(image_numpy[0], (1, 2, 0)) + 1) / 2.0 * 255.0
image_numpy = np.clip(image_numpy, 0, 255)
return image_numpy.astype('uint8')[:h, :w, :]
import paddle import paddle
import math import math
from x2paddle.op_mapper.onnx2paddle import onnx_custom_layer as x2paddle_nn
class ONNXModel(paddle.nn.Layer): class ONNXModel(paddle.nn.Layer):
......
...@@ -21,7 +21,6 @@ from paddlehub.module.module import moduleinfo, runnable, serving ...@@ -21,7 +21,6 @@ from paddlehub.module.module import moduleinfo, runnable, serving
import numpy as np import numpy as np
import cv2 import cv2
from .enlighten_inference import EnlightenOnnxModel
from .enlighten_inference.pd_model.x2paddle_code import ONNXModel from .enlighten_inference.pd_model.x2paddle_code import ONNXModel
from .util import base64_to_cv2 from .util import base64_to_cv2
...@@ -36,19 +35,19 @@ class EnlightenGAN: ...@@ -36,19 +35,19 @@ class EnlightenGAN:
self.model.set_dict(params, use_structured_name=True) self.model.set_dict(params, use_structured_name=True)
def enlightening(self, def enlightening(self,
images=None, images:list=None,
paths=None, paths:list=None,
output_dir='./enlightening_result/', output_dir:str='./enlightening_result/',
use_gpu=False, use_gpu:bool=False,
visualization=True): visualization:bool=True):
''' '''
enlighten images in the low-light scene. enlighten images in the low-light scene.
images (list[numpy.ndarray]): data of images, shape of each is [H, W, C], color space must be BGR(read by cv2). images (list[numpy.ndarray]): data of images, shape of each is [H, W, C], color space must be BGR(read by cv2).
paths (list[str]): paths to images paths (list[str]): paths to images
output_dir: the dir to save the results output_dir (str): the dir to save the results
use_gpu: if True, use gpu to perform the computation, otherwise cpu. use_gpu (bool): if True, use gpu to perform the computation, otherwise cpu.
visualization: if True, save results in output_dir. visualization (bool): if True, save results in output_dir.
''' '''
results = [] results = []
paddle.disable_static() paddle.disable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册