提交 de72c8f6 编写于 作者: littletomatodonkey's avatar littletomatodonkey

support system prediction

上级 0a82f9bb
from . import utils
Global:
infer_imgs: "images/coco_000000570688.jpg"
# infer_imgs: "../docs/images/whl/demo.jpg"
det_inference_model_dir: "./ppyolov2_r50vd_dcn_365e_mainbody_infer/"
rec_inference_model_dir: "./MobileNetV1_infer/"
batch_size: 1
image_shape: [3, 640, 640]
threshold: 0.5
max_det_results: 1
labe_list:
- foreground
# inference engine config
use_gpu: False
enable_mkldnn: True
cpu_num_threads: 100
enable_benchmark: True
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
DetPreProcess:
transform_ops:
- DetResize:
interp: 2
keep_ratio: false
target_size: [640, 640]
- DetNormalizeImage:
is_scale: true
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
- DetPermute: {}
DetPostProcess: {}
RecPreProcess:
transform_ops:
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
RecPostProcess: null
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str|np.ndarray): input can be image path or np.ndarray
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
else:
im = im_file
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return im, im_info
class DetResize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(
self,
target_size,
keep_ratio=True,
interp=cv2.INTER_LINEAR, ):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
# set image_shape
im_info['input_shape'][1] = int(im_scale_y * im.shape[0])
im_info['input_shape'][2] = int(im_scale_x * im.shape[1])
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class DetNormalizeImage(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
"""
def __init__(self, mean, std, is_scale=True):
self.mean = mean
self.std = std
self.is_scale = is_scale
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
return im, im_info
class DetPermute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, ):
super().__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class DetPadStride(object):
""" padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
def det_preprocess(im, im_info, preprocess_ops):
for operator in preprocess_ops:
im, im_info = operator(im, im_info)
return im, im_info
......@@ -21,6 +21,8 @@ import paddle.nn.functional as F
def build_postprocess(config):
if config is None:
return None
config = copy.deepcopy(config)
model_name = config.pop("name")
mod = importlib.import_module(__name__)
......
......@@ -28,14 +28,33 @@ from preprocess import create_operators
from postprocess import build_postprocess
class ClsPredictor(object):
class ClsPredictor(Predictor):
def __init__(self, config):
super().__init__()
self.predictor = Predictor(config["Global"])
super().__init__(config["Global"])
self.preprocess_ops = create_operators(config["PreProcess"][
"transform_ops"])
self.postprocess = build_postprocess(config["PostProcess"])
def predict(self, images):
input_names = self.paddle_predictor.get_input_names()
input_tensor = self.paddle_predictor.get_input_handle(input_names[0])
output_names = self.paddle_predictor.get_output_names()
output_tensor = self.paddle_predictor.get_output_handle(output_names[
0])
if not isinstance(images, (list, )):
images = [images]
for idx in range(len(images)):
for ops in self.preprocess_ops:
images[idx] = ops(images[idx])
image = np.array(images)
input_tensor.copy_from_cpu(image)
self.paddle_predictor.run()
batch_output = output_tensor.copy_to_cpu()
return batch_output
def main(config):
cls_predictor = ClsPredictor(config)
......@@ -43,12 +62,8 @@ def main(config):
assert config["Global"]["batch_size"] == 1
for idx, image_file in enumerate(image_list):
batch_input = []
img = cv2.imread(image_file)[:, :, ::-1]
for ops in cls_predictor.preprocess_ops:
img = ops(img)
batch_input.append(img)
output = cls_predictor.predictor.predict(np.array(batch_input))
output = cls_predictor.predict(img)
output = cls_predictor.postprocess(output)
print(output)
return
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
from utils import logger
from utils import config
from utils.predictor import Predictor
from utils.get_image_list import get_image_list
from det_preprocess import det_preprocess
from preprocess import create_operators
import os
import argparse
import time
import yaml
import ast
from functools import reduce
import cv2
import numpy as np
import paddle
class DetPredictor(Predictor):
def __init__(self, config):
super().__init__(config["Global"],
config["Global"]["det_inference_model_dir"])
self.preprocess_ops = create_operators(config["DetPreProcess"][
"transform_ops"])
self.config = config
def preprocess(self, img):
im_info = {
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
'im_shape': np.array(
img.shape[:2], dtype=np.float32),
'input_shape': self.config["Global"]["image_shape"],
"scale_factor": np.array(
[1., 1.], dtype=np.float32)
}
im, im_info = det_preprocess(img, im_info, self.preprocess_ops)
inputs = self.create_inputs(im, im_info)
return inputs
def create_inputs(self, im, im_info):
"""generate input for different model type
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
model_arch (str): model type
Returns:
inputs (dict): input of model
"""
inputs = {}
inputs['image'] = np.array((im, )).astype('float32')
inputs['im_shape'] = np.array(
(im_info['im_shape'], )).astype('float32')
inputs['scale_factor'] = np.array(
(im_info['scale_factor'], )).astype('float32')
return inputs
def parse_det_results(self, pred, threshold, label_list):
max_det_results = self.config["Global"]["max_det_results"]
keep_indexes = pred[:, 1].argsort()[::-1][:max_det_results]
results = []
for idx in keep_indexes:
single_res = pred[idx]
class_id = int(single_res[0])
score = single_res[1]
bbox = single_res[2:]
if score < threshold:
continue
label_name = label_list[class_id]
results.append({
"class_id": class_id,
"score": score,
"bbox": bbox,
"label_name": label_name,
})
return results
def predict(self, image, threshold=0.5, run_benchmark=False):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score
Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape: [N, im_h, im_w]
'''
inputs = self.preprocess(image)
np_boxes = None
input_names = self.paddle_predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.paddle_predictor.get_input_handle(input_names[
i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
t1 = time.time()
self.paddle_predictor.run()
output_names = self.paddle_predictor.get_output_names()
boxes_tensor = self.paddle_predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
t2 = time.time()
print("Inference: {} ms per batch image".format((t2 - t1) * 1000.0))
# do not perform postprocess in benchmark mode
results = []
if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
print('[WARNNING] No object detected.')
results = np.array([])
else:
results = np_boxes
results = self.parse_det_results(results,
self.config["Global"]["threshold"],
self.config["Global"]["labe_list"])
return results
def main(config):
det_predictor = DetPredictor(config)
image_list = get_image_list(config["Global"]["infer_imgs"])
assert config["Global"]["batch_size"] == 1
for idx, image_file in enumerate(image_list):
img = cv2.imread(image_file)[:, :, ::-1]
output = det_predictor.predict(img)
print(output)
return
if __name__ == "__main__":
args = config.parse_args()
config = config.get_config(args.config, overrides=args.override, show=True)
main(config)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import cv2
import numpy as np
from utils import logger
from utils import config
from utils.predictor import Predictor
from utils.get_image_list import get_image_list
from preprocess import create_operators
from postprocess import build_postprocess
class RecPredictor(Predictor):
def __init__(self, config):
super().__init__(config["Global"],
config["Global"]["rec_inference_model_dir"])
self.preprocess_ops = create_operators(config["RecPreProcess"][
"transform_ops"])
self.postprocess = build_postprocess(config["RecPostProcess"])
def predict(self, images):
input_names = self.paddle_predictor.get_input_names()
input_tensor = self.paddle_predictor.get_input_handle(input_names[0])
output_names = self.paddle_predictor.get_output_names()
output_tensor = self.paddle_predictor.get_output_handle(output_names[
0])
if not isinstance(images, (list, )):
images = [images]
for idx in range(len(images)):
for ops in self.preprocess_ops:
images[idx] = ops(images[idx])
image = np.array(images)
input_tensor.copy_from_cpu(image)
self.paddle_predictor.run()
batch_output = output_tensor.copy_to_cpu()
return batch_output
def main(config):
rec_predictor = RecPredictor(config)
image_list = get_image_list(config["Global"]["infer_imgs"])
assert config["Global"]["batch_size"] == 1
for idx, image_file in enumerate(image_list):
batch_input = []
img = cv2.imread(image_file)[:, :, ::-1]
output = rec_predictor.predict(img)
if rec_predictor.postprocess is not None:
output = rec_predictor.postprocess(output)
print(output.shape)
return
if __name__ == "__main__":
args = config.parse_args()
config = config.get_config(args.config, overrides=args.override, show=True)
main(config)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import copy
import cv2
import numpy as np
from python.predict_rec import RecPredictor
from python.predict_det import DetPredictor
from utils import logger
from utils import config
from utils.get_image_list import get_image_list
class SystemPredictor(object):
def __init__(self, config):
self.rec_predictor = RecPredictor(config)
self.det_predictor = DetPredictor(config)
def predict(self, img):
output = []
results = self.det_predictor.predict(img)
for result in results:
print(result)
xmin, xmax, ymin, ymax = result["bbox"].astype("int")
crop_img = img[xmin:xmax, ymin:ymax, :].copy()
rec_results = self.rec_predictor.predict(crop_img)
result["featrue"] = rec_results
output.append(result)
return output
def main(config):
system_predictor = SystemPredictor(config)
image_list = get_image_list(config["Global"]["infer_imgs"])
assert config["Global"]["batch_size"] == 1
for idx, image_file in enumerate(image_list):
img = cv2.imread(image_file)[:, :, ::-1]
output = system_predictor.predict(img)
print(output)
return
if __name__ == "__main__":
args = config.parse_args()
config = config.get_config(args.config, overrides=args.override, show=True)
main(config)
......@@ -26,6 +26,8 @@ import cv2
import numpy as np
import importlib
from det_preprocess import DetNormalizeImage, DetPadStride, DetPermute, DetResize
def create_operators(params):
"""
......
# classification
python3.7 python/predict_cls.py -c configs/inference_cls.yaml
# feature extractor
# python3.7 python/predict_rec.py -c configs/inference_rec.yaml
# detection
# python3.7 python/predict_det.py -c configs/inference_rec.yaml
# mainbody detection + feature extractor
# python3.7 python/predict_system.py -c configs/inference_rec.yaml
\ No newline at end of file
......@@ -23,32 +23,22 @@ from paddle.inference import create_predictor
class Predictor(object):
def __init__(self, args):
def __init__(self, args, inference_model_dir=None):
# HALF precission predict only work when using tensorrt
if args.use_fp16 is True:
assert args.use_tensorrt is True
self.args = args
self.paddle_predictor = self.create_paddle_predictor(
args, inference_model_dir)
self.paddle_predictor = self.create_paddle_predictor(args)
input_names = self.paddle_predictor.get_input_names()
self.input_tensor = self.paddle_predictor.get_input_handle(input_names[
0])
def predict(self, image):
raise NotImplementedError
output_names = self.paddle_predictor.get_output_names()
self.output_tensor = self.paddle_predictor.get_output_handle(
output_names[0])
def predict(self, batch_input):
self.input_tensor.copy_from_cpu(batch_input)
self.paddle_predictor.run()
batch_output = self.output_tensor.copy_to_cpu()
return batch_output
def create_paddle_predictor(self, args):
params_file = os.path.join(args.inference_model_dir,
"inference.pdiparams")
model_file = os.path.join(args.inference_model_dir,
"inference.pdmodel")
def create_paddle_predictor(self, args, inference_model_dir=None):
if inference_model_dir is None:
inference_model_dir = args.inference_model_dir
params_file = os.path.join(inference_model_dir, "inference.pdiparams")
model_file = os.path.join(inference_model_dir, "inference.pdmodel")
config = Config(model_file, params_file)
if args.use_gpu:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册