未验证 提交 e420428b 编写于 作者: H houj04 提交者: GitHub

yolov3_darknet53_vehicles support npu and xpu. (#1609)

上级 3e471d8e
...@@ -9,7 +9,8 @@ from functools import partial ...@@ -9,7 +9,8 @@ from functools import partial
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddlehub as hub import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix from paddlehub.common.paddle_helper import add_vars_prefix
...@@ -19,40 +20,65 @@ from yolov3_darknet53_vehicles.data_feed import reader ...@@ -19,40 +20,65 @@ from yolov3_darknet53_vehicles.data_feed import reader
from yolov3_darknet53_vehicles.yolo_head import MultiClassNMS, YOLOv3Head from yolov3_darknet53_vehicles.yolo_head import MultiClassNMS, YOLOv3Head
@moduleinfo( @moduleinfo(name="yolov3_darknet53_vehicles",
name="yolov3_darknet53_vehicles", version="1.0.1",
version="1.0.1", type="CV/object_detection",
type="CV/object_detection", summary="Baidu's YOLOv3 model for vehicles detection, with backbone DarkNet53.",
summary="Baidu's YOLOv3 model for vehicles detection, with backbone DarkNet53.", author="paddlepaddle",
author="paddlepaddle", author_email="paddle-dev@baidu.com")
author_email="paddle-dev@baidu.com")
class YOLOv3DarkNet53Vehicles(hub.Module): class YOLOv3DarkNet53Vehicles(hub.Module):
def _initialize(self): def _initialize(self):
self.default_pretrained_model_path = os.path.join(self.directory, "yolov3_darknet53_vehicles_model") self.default_pretrained_model_path = os.path.join(self.directory, "yolov3_darknet53_vehicles_model")
self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt")) self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt"))
self._set_config() self._set_config()
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
def _set_config(self): def _set_config(self):
""" """
predictor config setting. predictor config setting.
""" """
cpu_config = AnalysisConfig(self.default_pretrained_model_path)
# create default cpu predictor
cpu_config = Config(self.default_pretrained_model_path)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
cpu_config.switch_ir_optim(False) self.cpu_predictor = create_predictor(cpu_config)
self.cpu_predictor = create_paddle_predictor(cpu_config)
try: # create predictors using various types of devices
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0]) # npu
use_gpu = True npu_id = self._get_device_id("FLAGS_selected_npus")
except: if npu_id != -1:
use_gpu = False # use npu
if use_gpu: npu_config = Config(self.default_pretrained_model_path)
gpu_config = AnalysisConfig(self.default_pretrained_model_path) npu_config.disable_glog_info()
npu_config.enable_npu(device_id=npu_id)
self.npu_predictor = create_predictor(npu_config)
# gpu
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1:
# use gpu
gpu_config = Config(self.default_pretrained_model_path)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0) gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
self.gpu_predictor = create_paddle_predictor(gpu_config) self.gpu_predictor = create_predictor(gpu_config)
# xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
xpu_config = Config(self.default_pretrained_model_path)
xpu_config.disable_glog_info()
xpu_config.enable_xpu(100)
self.xpu_predictor = create_predictor(xpu_config)
def context(self, trainable=True, pretrained=True, get_prediction=False): def context(self, trainable=True, pretrained=True, get_prediction=False):
""" """
...@@ -81,21 +107,19 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -81,21 +107,19 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
# im_size # im_size
im_size = fluid.layers.data(name='im_size', shape=[2], dtype='int32') im_size = fluid.layers.data(name='im_size', shape=[2], dtype='int32')
# yolo_head # yolo_head
yolo_head = YOLOv3Head( yolo_head = YOLOv3Head(anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]], anchors=[[8, 9], [10, 23], [19, 15], [23, 33], [40, 25], [54, 50], [101, 80],
anchors=[[8, 9], [10, 23], [19, 15], [23, 33], [40, 25], [54, 50], [101, 80], [139, 145], [139, 145], [253, 224]],
[253, 224]], norm_decay=0.,
norm_decay=0., num_classes=6,
num_classes=6, ignore_thresh=0.7,
ignore_thresh=0.7, label_smooth=False,
label_smooth=False, nms=MultiClassNMS(background_label=-1,
nms=MultiClassNMS( keep_top_k=100,
background_label=-1, nms_threshold=0.45,
keep_top_k=100, nms_top_k=400,
nms_threshold=0.45, normalized=False,
nms_top_k=400, score_threshold=0.005))
normalized=False,
score_threshold=0.005))
# head_features # head_features
head_features, body_features = yolo_head._get_outputs(body_feats, is_train=trainable) head_features, body_features = yolo_head._get_outputs(body_feats, is_train=trainable)
...@@ -148,7 +172,8 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -148,7 +172,8 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
use_gpu=False, use_gpu=False,
output_dir='yolov3_vehicles_detect_output', output_dir='yolov3_vehicles_detect_output',
score_thresh=0.2, score_thresh=0.2,
visualization=True): visualization=True,
use_device=None):
"""API of Object Detection. """API of Object Detection.
Args: Args:
...@@ -159,6 +184,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -159,6 +184,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
output_dir (str): The path to store output images. output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not. visualization (bool): Whether to save image or not.
score_thresh (float): threshold for object detecion. score_thresh (float): threshold for object detecion.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns: Returns:
res (list[dict]): The result of vehicles detecion. keys include 'data', 'save_path', the corresponding value is: res (list[dict]): The result of vehicles detecion. keys include 'data', 'save_path', the corresponding value is:
...@@ -171,14 +197,25 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -171,14 +197,25 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
confidence (float): The confidence of detection result. confidence (float): The confidence of detection result.
save_path (str, optional): The path to save output images. save_path (str, optional): The path to save output images.
""" """
if use_gpu:
try: # real predictor to use
_places = os.environ["CUDA_VISIBLE_DEVICES"] if use_device is not None:
int(_places[0]) if use_device == "cpu":
except: predictor = self.cpu_predictor
raise RuntimeError( elif use_device == "xpu":
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id." predictor = self.xpu_predictor
) elif use_device == "npu":
predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
paths = paths if paths else list() paths = paths if paths else list()
data_reader = partial(reader, paths, images) data_reader = partial(reader, paths, images)
...@@ -186,22 +223,31 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -186,22 +223,31 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
res = [] res = []
for iter_id, feed_data in enumerate(batch_reader()): for iter_id, feed_data in enumerate(batch_reader()):
feed_data = np.array(feed_data) feed_data = np.array(feed_data)
image_tensor = PaddleTensor(np.array(list(feed_data[:, 0])))
im_size_tensor = PaddleTensor(np.array(list(feed_data[:, 1]))) input_names = predictor.get_input_names()
if use_gpu: image_data = np.array(list(feed_data[:, 0]))
data_out = self.gpu_predictor.run([image_tensor, im_size_tensor]) image_size_data = np.array(list(feed_data[:, 1]))
else:
data_out = self.cpu_predictor.run([image_tensor, im_size_tensor]) image_tensor = predictor.get_input_handle(input_names[0])
image_tensor.reshape(image_data.shape)
output = postprocess( image_tensor.copy_from_cpu(image_data.copy())
paths=paths,
images=images, image_size_tensor = predictor.get_input_handle(input_names[1])
data_out=data_out, image_size_tensor.reshape(image_size_data.shape)
score_thresh=score_thresh, image_size_tensor.copy_from_cpu(image_size_data.copy())
label_names=self.label_names,
output_dir=output_dir, predictor.run()
handle_id=iter_id * batch_size, output_names = predictor.get_output_names()
visualization=visualization) output_handle = predictor.get_output_handle(output_names[0])
output = postprocess(paths=paths,
images=images,
data_out=output_handle,
score_thresh=score_thresh,
label_names=self.label_names,
output_dir=output_dir,
handle_id=iter_id * batch_size,
visualization=visualization)
res.extend(output) res.extend(output)
return res return res
...@@ -215,14 +261,13 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -215,14 +261,13 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
program, feeded_var_names, target_vars = fluid.io.load_inference_model( program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe) dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model( fluid.io.save_inference_model(dirname=dirname,
dirname=dirname, main_program=program,
main_program=program, executor=exe,
executor=exe, feeded_var_names=feeded_var_names,
feeded_var_names=feeded_var_names, target_vars=target_vars,
target_vars=target_vars, model_filename=model_filename,
model_filename=model_filename, params_filename=params_filename)
params_filename=params_filename)
@serving @serving
def serving_method(self, images, **kwargs): def serving_method(self, images, **kwargs):
...@@ -238,39 +283,44 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -238,39 +283,44 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
""" """
Run as a command. Run as a command.
""" """
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
description="Run the {} module.".format(self.name), prog='hub run {}'.format(self.name),
prog='hub run {}'.format(self.name), usage='%(prog)s',
usage='%(prog)s', add_help=True)
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.") title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg() self.add_module_config_arg()
self.add_module_input_arg() self.add_module_input_arg()
args = self.parser.parse_args(argvs) args = self.parser.parse_args(argvs)
results = self.object_detection( results = self.object_detection(paths=[args.input_path],
paths=[args.input_path], batch_size=args.batch_size,
batch_size=args.batch_size, use_gpu=args.use_gpu,
use_gpu=args.use_gpu, output_dir=args.output_dir,
output_dir=args.output_dir, visualization=args.visualization,
visualization=args.visualization, score_thresh=args.score_thresh,
score_thresh=args.score_thresh) use_device=args.use_device)
return results return results
def add_module_config_arg(self): def add_module_config_arg(self):
""" """
Add the command config options. Add the command config options.
""" """
self.arg_config_group.add_argument( self.arg_config_group.add_argument('--use_gpu',
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not") type=ast.literal_eval,
self.arg_config_group.add_argument( default=False,
'--output_dir', help="whether use GPU or not")
type=str, self.arg_config_group.add_argument('--output_dir',
default='yolov3_vehicles_detect_output', type=str,
help="The directory to save output images.") default='yolov3_vehicles_detect_output',
self.arg_config_group.add_argument( help="The directory to save output images.")
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.") self.arg_config_group.add_argument('--visualization',
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
self.arg_config_group.add_argument('--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
...@@ -278,5 +328,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -278,5 +328,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
""" """
self.arg_input_group.add_argument('--input_path', type=str, help="path to image.") self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")
self.arg_input_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.") self.arg_input_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
self.arg_input_group.add_argument( self.arg_input_group.add_argument('--score_thresh',
'--score_thresh', type=ast.literal_eval, default=0.2, help="threshold for object detecion.") type=ast.literal_eval,
default=0.2,
help="threshold for object detecion.")
...@@ -57,8 +57,8 @@ def draw_bounding_box_on_image(image_path, data_list, save_dir): ...@@ -57,8 +57,8 @@ def draw_bounding_box_on_image(image_path, data_list, save_dir):
if image.mode == 'RGB': if image.mode == 'RGB':
text = data['label'] + ": %.2f%%" % (100 * data['confidence']) text = data['label'] + ": %.2f%%" % (100 * data['confidence'])
textsize_width, textsize_height = draw.textsize(text=text) textsize_width, textsize_height = draw.textsize(text=text)
draw.rectangle( draw.rectangle(xy=(left, top - (textsize_height + 5), left + textsize_width + 10, top),
xy=(left, top - (textsize_height + 5), left + textsize_width + 10, top), fill=(255, 255, 255)) fill=(255, 255, 255))
draw.text(xy=(left, top - 15), text=text, fill=(0, 0, 0)) draw.text(xy=(left, top - 15), text=text, fill=(0, 0, 0))
save_name = get_save_image_name(image, save_dir, image_path) save_name = get_save_image_name(image, save_dir, image_path)
...@@ -94,8 +94,6 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, ...@@ -94,8 +94,6 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir,
paths (list[str]): The paths of images. paths (list[str]): The paths of images.
images (list(numpy.ndarray)): images data, shape of each is [H, W, C] images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
data_out (lod_tensor): data output of predictor. data_out (lod_tensor): data output of predictor.
batch_size (int): batch size.
use_gpu (bool): Whether to use gpu.
output_dir (str): The path to store output images. output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not. visualization (bool): Whether to save image or not.
score_thresh (float): the low limit of bounding box. score_thresh (float): the low limit of bounding box.
...@@ -113,9 +111,8 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, ...@@ -113,9 +111,8 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir,
confidence (float): The confidence of detection result. confidence (float): The confidence of detection result.
save_path (str): The path to save output images. save_path (str): The path to save output images.
""" """
lod_tensor = data_out[0] results = data_out.copy_to_cpu()
lod = lod_tensor.lod[0] lod = data_out.lod()[0]
results = lod_tensor.as_ndarray()
check_dir(output_dir) check_dir(output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册