未验证 提交 600eb492 编写于 作者: H houj04 提交者: GitHub

add xpu and npu support for pyramidbox_lite series. (#1618)

上级 41bee5c9
......@@ -9,7 +9,10 @@ import os
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from pyramidbox_lite_mobile.data_feed import reader
......@@ -29,26 +32,53 @@ class PyramidBoxLiteMobile(hub.Module):
self._set_config()
self.processor = self
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
def _set_config(self):
"""
predictor config setting
"""
cpu_config = AnalysisConfig(self.default_pretrained_model_path)
# create default cpu predictor
cpu_config = Config(self.default_pretrained_model_path)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path)
self.cpu_predictor = create_predictor(cpu_config)
# create predictors using various types of devices
# npu
npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
# use npu
npu_config = Config(self.default_pretrained_model_path)
npu_config.disable_glog_info()
npu_config.enable_npu(device_id=npu_id)
self.npu_predictor = create_predictor(npu_config)
# gpu
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1:
# use gpu
gpu_config = Config(self.default_pretrained_model_path)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
self.gpu_predictor = create_predictor(gpu_config)
# xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
xpu_config = Config(self.default_pretrained_model_path)
xpu_config.disable_glog_info()
xpu_config.enable_xpu(100)
self.xpu_predictor = create_predictor(xpu_config)
def face_detection(self,
images=None,
......@@ -58,7 +88,8 @@ class PyramidBoxLiteMobile(hub.Module):
output_dir='detection_result',
visualization=False,
shrink=0.5,
confs_threshold=0.6):
confs_threshold=0.6,
use_device=None):
"""
API for face detection.
......@@ -70,18 +101,29 @@ class PyramidBoxLiteMobile(hub.Module):
visualization (bool): Whether to save image or not.
shrink (float): parameter to control the resize scale in preprocess.
confs_threshold (float): confidence threshold.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The result of face detection and save path of images.
"""
# real predictor to use
if use_device is not None:
if use_device == "cpu":
predictor = self.cpu_predictor
elif use_device == "xpu":
predictor = self.xpu_predictor
elif use_device == "npu":
predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
)
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
# compatibility with older versions
if data:
......@@ -97,11 +139,19 @@ class PyramidBoxLiteMobile(hub.Module):
res = list()
# process one by one
for element in reader(images, paths, shrink):
image = np.expand_dims(element['image'], axis=0).astype('float32')
image_tensor = PaddleTensor(image.copy())
data_out = self.gpu_predictor.run([image_tensor]) if use_gpu else self.cpu_predictor.run([image_tensor])
batch_image = np.expand_dims(element['image'], axis=0).astype('float32')
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_handle(input_names[0])
input_tensor.reshape(batch_image.shape)
input_tensor.copy_from_cpu(batch_image.copy())
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
predictor_output = output_handle.copy_to_cpu()
out = postprocess(
data_out=data_out[0].as_ndarray(),
data_out=predictor_output,
org_im=element['org_im'],
org_im_path=element['org_im_path'],
image_width=element['image_width'],
......@@ -166,7 +216,8 @@ class PyramidBoxLiteMobile(hub.Module):
output_dir=args.output_dir,
visualization=args.visualization,
shrink=args.shrink,
confs_threshold=args.confs_threshold)
confs_threshold=args.confs_threshold,
use_device=args.use_device)
return results
def add_module_config_arg(self):
......@@ -179,6 +230,10 @@ class PyramidBoxLiteMobile(hub.Module):
'--output_dir', type=str, default='detection_result', help="The directory to save output images.")
self.arg_config_group.add_argument(
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
self.arg_config_group.add_argument(
'--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self):
"""
......
......@@ -101,7 +101,7 @@ def process_image(org_im, face):
return image_in
def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu, use_multi_scale):
def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu, use_multi_scale, use_device=None):
"""
Preprocess to yield image.
......@@ -113,6 +113,7 @@ def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu, use_m
paths (list[str]): paths to images.
use_gpu (bool): whether to use gpu in face_detector.
use_multi_scale (bool): whether to enable multi-scale face detection.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Yield:
element (collections.OrderedDict): info of original image, preprocessed image, contains 3 keys:
org_im (numpy.ndarray) : original image.
......@@ -149,7 +150,8 @@ def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu, use_m
use_gpu=use_gpu,
visualization=False,
shrink=scale,
confs_threshold=confs_threshold)
confs_threshold=confs_threshold,
use_device=use_device)
_s = list()
for _face in _detect_res[0]['data']:
......@@ -172,7 +174,8 @@ def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu, use_m
use_gpu=use_gpu,
visualization=False,
shrink=shrink,
confs_threshold=confs_threshold)
confs_threshold=confs_threshold,
use_device=use_device)
detect_faces = _detect_res[0]['data']
element['preprocessed'] = list()
......
......@@ -9,7 +9,10 @@ import os
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from pyramidbox_lite_mobile_mask.data_feed import reader
......@@ -38,26 +41,53 @@ class PyramidBoxLiteMobileMask(hub.Module):
self._set_config()
self.processor = self
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
def _set_config(self):
"""
predictor config setting
"""
cpu_config = AnalysisConfig(self.default_pretrained_model_path)
# create default cpu predictor
cpu_config = Config(self.default_pretrained_model_path)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
self.cpu_predictor = create_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path)
# create predictors using various types of devices
# npu
npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
# use npu
npu_config = Config(self.default_pretrained_model_path)
npu_config.disable_glog_info()
npu_config.enable_npu(device_id=npu_id)
self.npu_predictor = create_predictor(npu_config)
# gpu
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1:
# use gpu
gpu_config = Config(self.default_pretrained_model_path)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
self.gpu_predictor = create_predictor(gpu_config)
# xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
xpu_config = Config(self.default_pretrained_model_path)
xpu_config.disable_glog_info()
xpu_config.enable_xpu(100)
self.xpu_predictor = create_predictor(xpu_config)
def set_face_detector_module(self, face_detector_module):
"""
......@@ -80,7 +110,8 @@ class PyramidBoxLiteMobileMask(hub.Module):
output_dir='detection_result',
use_multi_scale=False,
shrink=0.5,
confs_threshold=0.6):
confs_threshold=0.6,
use_device=None):
"""
API for face detection.
......@@ -96,18 +127,29 @@ class PyramidBoxLiteMobileMask(hub.Module):
it reduce the prediction speed for the increase model calculation.
shrink (float): parameter to control the resize scale in preprocess.
confs_threshold (float): confidence threshold.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The result of face detection and save path of images.
"""
# real predictor to use
if use_device is not None:
if use_device == "cpu":
predictor = self.cpu_predictor
elif use_device == "xpu":
predictor = self.xpu_predictor
elif use_device == "npu":
predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
)
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
# compatibility with older versions
if data:
......@@ -122,7 +164,8 @@ class PyramidBoxLiteMobileMask(hub.Module):
# get all data
all_element = list()
for yield_data in reader(self.face_detector, shrink, confs_threshold, images, paths, use_gpu, use_multi_scale):
for yield_data in reader(self.face_detector, shrink, confs_threshold, images, paths, use_gpu, use_multi_scale,
use_device):
all_element.append(yield_data)
image_list = list()
......@@ -145,13 +188,18 @@ class PyramidBoxLiteMobileMask(hub.Module):
except:
pass
image_arr = np.squeeze(np.array(batch_data), axis=1)
image_tensor = PaddleTensor(image_arr.copy())
data_out = self.gpu_predictor.run([image_tensor]) if use_gpu else self.cpu_predictor.run([image_tensor])
# len(data_out) == 1
# data_out[0].as_ndarray().shape == (-1, 2)
data_out = data_out[0].as_ndarray()
predict_out = np.concatenate((predict_out, data_out))
batch_image = np.squeeze(np.array(batch_data), axis=1)
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_handle(input_names[0])
input_tensor.reshape(batch_image.shape)
input_tensor.copy_from_cpu(batch_image.copy())
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
predictor_output = output_handle.copy_to_cpu()
predict_out = np.concatenate((predict_out, predictor_output))
predict_out = predict_out[1:]
# postprocess one by one
......@@ -229,7 +277,8 @@ class PyramidBoxLiteMobileMask(hub.Module):
output_dir=args.output_dir,
visualization=args.visualization,
shrink=args.shrink,
confs_threshold=args.confs_threshold)
confs_threshold=args.confs_threshold,
use_device=args.use_device)
return results
def add_module_config_arg(self):
......@@ -242,6 +291,10 @@ class PyramidBoxLiteMobileMask(hub.Module):
'--output_dir', type=str, default='detection_result', help="The directory to save output images.")
self.arg_config_group.add_argument(
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
self.arg_config_group.add_argument(
'--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self):
"""
......
......@@ -9,7 +9,10 @@ import os
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from pyramidbox_lite_server.data_feed import reader
......@@ -29,26 +32,53 @@ class PyramidBoxLiteServer(hub.Module):
self._set_config()
self.processor = self
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
def _set_config(self):
"""
predictor config setting
"""
cpu_config = AnalysisConfig(self.default_pretrained_model_path)
# create default cpu predictor
cpu_config = Config(self.default_pretrained_model_path)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path)
self.cpu_predictor = create_predictor(cpu_config)
# create predictors using various types of devices
# npu
npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
# use npu
npu_config = Config(self.default_pretrained_model_path)
npu_config.disable_glog_info()
npu_config.enable_npu(device_id=npu_id)
self.npu_predictor = create_predictor(npu_config)
# gpu
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1:
# use gpu
gpu_config = Config(self.default_pretrained_model_path)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
self.gpu_predictor = create_predictor(gpu_config)
# xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
xpu_config = Config(self.default_pretrained_model_path)
xpu_config.disable_glog_info()
xpu_config.enable_xpu(100)
self.xpu_predictor = create_predictor(xpu_config)
def face_detection(self,
images=None,
......@@ -58,7 +88,8 @@ class PyramidBoxLiteServer(hub.Module):
output_dir='detection_result',
visualization=False,
shrink=0.5,
confs_threshold=0.6):
confs_threshold=0.6,
use_device=None):
"""
API for face detection.
......@@ -70,18 +101,29 @@ class PyramidBoxLiteServer(hub.Module):
visualization (bool): Whether to save image or not.
shrink (float): parameter to control the resize scale in preprocess.
confs_threshold (float): confidence threshold.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The result of face detection and save path of images.
"""
# real predictor to use
if use_device is not None:
if use_device == "cpu":
predictor = self.cpu_predictor
elif use_device == "xpu":
predictor = self.xpu_predictor
elif use_device == "npu":
predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
)
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
# compatibility with older versions
if data:
......@@ -97,11 +139,19 @@ class PyramidBoxLiteServer(hub.Module):
res = list()
# process one by one
for element in reader(images, paths, shrink):
image = np.expand_dims(element['image'], axis=0).astype('float32')
image_tensor = PaddleTensor(image.copy())
data_out = self.gpu_predictor.run([image_tensor]) if use_gpu else self.cpu_predictor.run([image_tensor])
batch_image = np.expand_dims(element['image'], axis=0).astype('float32')
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_handle(input_names[0])
input_tensor.reshape(batch_image.shape)
input_tensor.copy_from_cpu(batch_image.copy())
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
predictor_output = output_handle.copy_to_cpu()
out = postprocess(
data_out=data_out[0].as_ndarray(),
data_out=predictor_output,
org_im=element['org_im'],
org_im_path=element['org_im_path'],
image_width=element['image_width'],
......@@ -163,7 +213,8 @@ class PyramidBoxLiteServer(hub.Module):
output_dir=args.output_dir,
visualization=args.visualization,
shrink=args.shrink,
confs_threshold=args.confs_threshold)
confs_threshold=args.confs_threshold,
use_device=args.use_device)
return results
def add_module_config_arg(self):
......@@ -176,6 +227,10 @@ class PyramidBoxLiteServer(hub.Module):
'--output_dir', type=str, default='detection_result', help="The directory to save output images.")
self.arg_config_group.add_argument(
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
self.arg_config_group.add_argument(
'--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self):
"""
......
......@@ -43,8 +43,7 @@ def bbox_vote(det):
det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4))
max_score = np.max(det_accu[:, 4])
det_accu_sum = np.zeros((1, 5))
det_accu_sum[:, 0:4] = np.sum(
det_accu[:, 0:4], axis=0) / np.sum(det_accu[:, -1:])
det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4], axis=0) / np.sum(det_accu[:, -1:])
det_accu_sum[:, 4] = max_score
try:
dets = np.row_stack((dets, det_accu_sum))
......@@ -54,38 +53,26 @@ def bbox_vote(det):
return dets
def crop(image,
pts,
shift=0,
scale=1.5,
rotate=0,
res_width=128,
res_height=128):
def crop(image, pts, shift=0, scale=1.5, rotate=0, res_width=128, res_height=128):
res = (res_width, res_height)
idx1 = 0
idx2 = 1
# angle
alpha = 0
if pts[idx2, 0] != -1 and pts[idx2, 1] != -1 and pts[idx1, 0] != -1 and pts[
idx1, 1] != -1:
alpha = math.atan2(pts[idx2, 1] - pts[idx1, 1],
pts[idx2, 0] - pts[idx1, 0]) * 180 / math.pi
if pts[idx2, 0] != -1 and pts[idx2, 1] != -1 and pts[idx1, 0] != -1 and pts[idx1, 1] != -1:
alpha = math.atan2(pts[idx2, 1] - pts[idx1, 1], pts[idx2, 0] - pts[idx1, 0]) * 180 / math.pi
pts[pts == -1] = np.inf
coord_min = np.min(pts, 0)
pts[pts == np.inf] = -1
coord_max = np.max(pts, 0)
# coordinates of center point
c = np.array([
coord_max[0] - (coord_max[0] - coord_min[0]) / 2,
coord_max[1] - (coord_max[1] - coord_min[1]) / 2
]) # center
max_wh = max((coord_max[0] - coord_min[0]) / 2,
(coord_max[1] - coord_min[1]) / 2)
c = np.array([coord_max[0] - (coord_max[0] - coord_min[0]) / 2,
coord_max[1] - (coord_max[1] - coord_min[1]) / 2]) # center
max_wh = max((coord_max[0] - coord_min[0]) / 2, (coord_max[1] - coord_min[1]) / 2)
# Shift the center point, rot add eyes angle
c = c + shift * max_wh
rotate = rotate + alpha
M = cv2.getRotationMatrix2D((c[0], c[1]), rotate,
res[0] / (2 * max_wh * scale))
M = cv2.getRotationMatrix2D((c[0], c[1]), rotate, res[0] / (2 * max_wh * scale))
M[0, 2] = M[0, 2] - (c[0] - res[0] / 2.0)
M[1, 2] = M[1, 2] - (c[1] - res[0] / 2.0)
image_out = cv2.warpAffine(image, M, res)
......@@ -97,27 +84,24 @@ def color_normalize(image, mean, std=None):
image = np.repeat(image, axis=2)
h, w, c = image.shape
image = np.transpose(image, (2, 0, 1))
image = np.subtract(image.reshape(c, -1), mean[:, np.newaxis]).reshape(
-1, h, w)
image = np.subtract(image.reshape(c, -1), mean[:, np.newaxis]).reshape(-1, h, w)
image = np.transpose(image, (1, 2, 0))
return image
def process_image(org_im, face):
pts = np.array([
face['left'], face['top'], face['right'], face['top'], face['left'],
face['bottom'], face['right'], face['bottom']
face['left'], face['top'], face['right'], face['top'], face['left'], face['bottom'], face['right'],
face['bottom']
]).reshape(4, 2).astype(np.float32)
image_in, M = crop(org_im, pts)
image_in = image_in / 256.0
image_in = color_normalize(image_in, mean=np.array([0.5, 0.5, 0.5]))
image_in = image_in.astype(np.float32).transpose([2, 0, 1]).reshape(
-1, 3, 128, 128)
image_in = image_in.astype(np.float32).transpose([2, 0, 1]).reshape(-1, 3, 128, 128)
return image_in
def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu,
use_multi_scale):
def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu, use_multi_scale, use_device=None):
"""
Preprocess to yield image.
......@@ -129,6 +113,7 @@ def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu,
paths (list[str]): paths to images.
use_gpu (bool): whether to use gpu in face_detector.
use_multi_scale (bool): whether to enable multi-scale face detection.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Yield:
element (collections.OrderedDict): info of original image, preprocessed image, contains 3 keys:
org_im (numpy.ndarray) : original image.
......@@ -142,8 +127,7 @@ def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu,
assert type(paths) is list, "paths should be a list."
for im_path in paths:
each = OrderedDict()
assert os.path.isfile(
im_path), "The {} isn't a valid file path.".format(im_path)
assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path)
im = cv2.imread(im_path)
each['org_im'] = im
each['org_im_path'] = im_path
......@@ -153,8 +137,7 @@ def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu,
for im in images:
each = OrderedDict()
each['org_im'] = im
each['org_im_path'] = 'ndarray_time={}'.format(
round(time.time(), 6) * 1e6)
each['org_im_path'] = 'ndarray_time={}'.format(round(time.time(), 6) * 1e6)
component.append(each)
for element in component:
......@@ -167,31 +150,24 @@ def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu,
use_gpu=use_gpu,
visualization=False,
shrink=scale,
confs_threshold=confs_threshold)
confs_threshold=confs_threshold,
use_device=use_device)
_s = list()
for _face in _detect_res[0]['data']:
_face_list = [
_face['left'], _face['top'], _face['right'],
_face['bottom'], _face['confidence']
]
_face_list = [_face['left'], _face['top'], _face['right'], _face['bottom'], _face['confidence']]
_s.append(_face_list)
if _s:
scale_res.append(np.array(_s))
if scale_res:
scale_res = np.row_stack(scale_res)
scale_res = bbox_vote(scale_res)
keep_index = np.where(scale_res[:, 4] >= confs_threshold)[0]
scale_res = scale_res[keep_index, :]
for data in scale_res:
face = {
'left': data[0],
'top': data[1],
'right': data[2],
'bottom': data[3],
'confidence': data[4]
}
face = {'left': data[0], 'top': data[1], 'right': data[2], 'bottom': data[3], 'confidence': data[4]}
detect_faces.append(face)
else:
detect_faces = []
......@@ -201,7 +177,8 @@ def reader(face_detector, shrink, confs_threshold, images, paths, use_gpu,
use_gpu=use_gpu,
visualization=False,
shrink=shrink,
confs_threshold=confs_threshold)
confs_threshold=confs_threshold,
use_device=use_device)
detect_faces = _detect_res[0]['data']
element['preprocessed'] = list()
......
......@@ -9,7 +9,10 @@ import os
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from pyramidbox_lite_server_mask.data_feed import reader
......@@ -30,8 +33,7 @@ class PyramidBoxLiteServerMask(hub.Module):
Args:
face_detector_module (class): module to detect face.
"""
self.default_pretrained_model_path = os.path.join(
self.directory, "pyramidbox_lite_server_mask_model")
self.default_pretrained_model_path = os.path.join(self.directory, "pyramidbox_lite_server_mask_model")
if face_detector_module is None:
self.face_detector = hub.Module(name='pyramidbox_lite_server')
else:
......@@ -39,27 +41,53 @@ class PyramidBoxLiteServerMask(hub.Module):
self._set_config()
self.processor = self
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
def _set_config(self):
"""
predictor config setting
"""
cpu_config = AnalysisConfig(self.default_pretrained_model_path)
# create default cpu predictor
cpu_config = Config(self.default_pretrained_model_path)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
self.cpu_predictor = create_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path)
# create predictors using various types of devices
# npu
npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
# use npu
npu_config = Config(self.default_pretrained_model_path)
npu_config.disable_glog_info()
npu_config.enable_npu(device_id=npu_id)
self.npu_predictor = create_predictor(npu_config)
# gpu
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1:
# use gpu
gpu_config = Config(self.default_pretrained_model_path)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(
memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
self.gpu_predictor = create_predictor(gpu_config)
# xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
xpu_config = Config(self.default_pretrained_model_path)
xpu_config.disable_glog_info()
xpu_config.enable_xpu(100)
self.xpu_predictor = create_predictor(xpu_config)
def set_face_detector_module(self, face_detector_module):
"""
......@@ -82,7 +110,8 @@ class PyramidBoxLiteServerMask(hub.Module):
output_dir='detection_result',
use_multi_scale=False,
shrink=0.5,
confs_threshold=0.6):
confs_threshold=0.6,
use_device=None):
"""
API for face detection.
......@@ -97,18 +126,29 @@ class PyramidBoxLiteServerMask(hub.Module):
it reduce the prediction speed for the increase model calculation.
shrink (float): parameter to control the resize scale in preprocess.
confs_threshold (float): confidence threshold.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The result of face detection and save path of images.
"""
# real predictor to use
if use_device is not None:
if use_device == "cpu":
predictor = self.cpu_predictor
elif use_device == "xpu":
predictor = self.xpu_predictor
elif use_device == "npu":
predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Attempt to use GPU for prediction, but environment variable CUDA_VISIBLE_DEVICES was not set correctly."
)
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
# compatibility with older versions
if data:
......@@ -123,16 +163,14 @@ class PyramidBoxLiteServerMask(hub.Module):
# get all data
all_element = list()
for yield_data in reader(self.face_detector, shrink, confs_threshold,
images, paths, use_gpu, use_multi_scale):
for yield_data in reader(self.face_detector, shrink, confs_threshold, images, paths, use_gpu, use_multi_scale,
use_device):
all_element.append(yield_data)
image_list = list()
element_image_num = list()
for i in range(len(all_element)):
element_image = [
handled['image'] for handled in all_element[i]['preprocessed']
]
element_image = [handled['image'] for handled in all_element[i]['preprocessed']]
element_image_num.append(len(element_image))
image_list.extend(element_image)
......@@ -149,23 +187,24 @@ class PyramidBoxLiteServerMask(hub.Module):
except:
pass
image_arr = np.squeeze(np.array(batch_data), axis=1)
image_tensor = PaddleTensor(image_arr.copy())
data_out = self.gpu_predictor.run([
image_tensor
]) if use_gpu else self.cpu_predictor.run([image_tensor])
# len(data_out) == 1
# data_out[0].as_ndarray().shape == (-1, 2)
data_out = data_out[0].as_ndarray()
predict_out = np.concatenate((predict_out, data_out))
batch_image = np.squeeze(np.array(batch_data), axis=1)
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_handle(input_names[0])
input_tensor.reshape(batch_image.shape)
input_tensor.copy_from_cpu(batch_image.copy())
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
predictor_output = output_handle.copy_to_cpu()
predict_out = np.concatenate((predict_out, predictor_output))
predict_out = predict_out[1:]
# postprocess one by one
res = list()
for i in range(len(all_element)):
detect_faces_list = [
handled['face'] for handled in all_element[i]['preprocessed']
]
detect_faces_list = [handled['face'] for handled in all_element[i]['preprocessed']]
interval_left = sum(element_image_num[0:i])
interval_right = interval_left + element_image_num[i]
out = postprocess(
......@@ -178,31 +217,16 @@ class PyramidBoxLiteServerMask(hub.Module):
res.append(out)
return res
def save_inference_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
classifier_dir = os.path.join(dirname, 'mask_detector')
detector_dir = os.path.join(dirname, 'pyramidbox_lite')
self._save_classifier_model(classifier_dir, model_filename,
params_filename, combined)
self._save_detector_model(detector_dir, model_filename, params_filename,
combined)
def _save_detector_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
self.face_detector.save_inference_model(dirname, model_filename,
params_filename, combined)
def _save_classifier_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
self._save_classifier_model(classifier_dir, model_filename, params_filename, combined)
self._save_detector_model(detector_dir, model_filename, params_filename, combined)
def _save_detector_model(self, dirname, model_filename=None, params_filename=None, combined=True):
self.face_detector.save_inference_model(dirname, model_filename, params_filename, combined)
def _save_classifier_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
......@@ -240,12 +264,9 @@ class PyramidBoxLiteServerMask(hub.Module):
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(
title="Input options", description="Input data. Required")
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options",
description=
"Run configuration for controlling module behavior, not required.")
title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
......@@ -255,7 +276,8 @@ class PyramidBoxLiteServerMask(hub.Module):
output_dir=args.output_dir,
visualization=args.visualization,
shrink=args.shrink,
confs_threshold=args.confs_threshold)
confs_threshold=args.confs_threshold,
use_device=args.use_device)
return results
def add_module_config_arg(self):
......@@ -263,36 +285,25 @@ class PyramidBoxLiteServerMask(hub.Module):
Add the command config options.
"""
self.arg_config_group.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not")
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not")
self.arg_config_group.add_argument(
'--output_dir',
type=str,
default='detection_result',
help="The directory to save output images.")
'--output_dir', type=str, default='detection_result', help="The directory to save output images.")
self.arg_config_group.add_argument(
'--visualization',
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
self.arg_config_group.add_argument(
'--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument(
'--input_path', type=str, help="path to image.")
self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")
self.arg_input_group.add_argument(
'--shrink',
type=ast.literal_eval,
default=0.5,
help=
"resize the image to `shrink * original_shape` before feeding into network."
)
help="resize the image to `shrink * original_shape` before feeding into network.")
self.arg_input_group.add_argument(
'--confs_threshold',
type=ast.literal_eval,
default=0.6,
help="confidence threshold.")
'--confs_threshold', type=ast.literal_eval, default=0.6, help="confidence threshold.")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册