未验证 提交 2f06ad8a 编写于 作者: J JYChen 提交者: GitHub

add vehicle attr model into pipeline (#6274)

* add vehicle attr inti pipeline

* fix in no-rgb in predict_video
上级 39531637
crop_thresh: 0.5
attr_thresh: 0.5
kpt_thresh: 0.2
visual: True
warmup_frame: 50
......@@ -24,12 +22,14 @@ VEHICLE_PLATE:
rec_batch_num: 6
word_dict_path: deploy/pphuman/ppvehicle/rec_word_dict.txt
basemode: "idbased"
enable: True
enable: False
ATTR:
model_dir: output_inference/strongbaseline_r50_30e/
VEHICLE_ATTR:
model_dir: output_inference/vehicle_attribute_infer/
batch_size: 8
basemode: "idbased"
color_threshold: 0.5
type_threshold: 0.5
enable: False
REID:
......
......@@ -28,7 +28,8 @@ class Result(object):
'reid': dict(),
'det_action': dict(),
'cls_action': dict(),
'vehicleplate': dict()
'vehicleplate': dict(),
'vehicle_attr': dict()
}
def update(self, res, name):
......
......@@ -156,7 +156,8 @@ class PipeTimer(Times):
'skeleton_action': Times(),
'reid': Times(),
'det_action': Times(),
'cls_action': Times()
'cls_action': Times(),
'vehicle_attr': Times()
}
self.img_num = 0
......
......@@ -27,7 +27,6 @@ from collections import Sequence
from reid import ReID
from datacollector import DataCollector, Result
from mtmct import mtmct_process
from ppvehicle.vehicle_plate import PlateRecognizer
# add deploy path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
......@@ -51,6 +50,9 @@ from pptracking.python.mot_sde_infer import SDE_Detector
from pptracking.python.mot.visualize import plot_tracking_dict
from pptracking.python.mot.utils import flow_statistic
from ppvehicle.vehicle_plate import PlateRecognizer
from ppvehicle.vehicle_attr import VehicleAttr
class Pipeline(object):
"""
......@@ -224,12 +226,12 @@ class PipePredictor(object):
# general module for pphuman and ppvehicle
self.with_mot = cfg.get('MOT', False)['enable'] if cfg.get(
'MOT', False) else False
self.with_attr = cfg.get('ATTR', False)['enable'] if cfg.get(
self.with_human_attr = cfg.get('ATTR', False)['enable'] if cfg.get(
'ATTR', False) else False
if self.with_mot:
print('Multi-Object Tracking enabled')
if self.with_attr:
print('Attribute Recognition enabled')
if self.with_human_attr:
print('Human Attribute Recognition enabled')
# only for pphuman
self.with_skeleton_action = cfg.get(
......@@ -265,6 +267,12 @@ class PipePredictor(object):
if self.with_vehicleplate:
print('Vehicle Plate Recognition enabled')
self.with_vehicle_attr = cfg.get(
'VEHICLE_ATTR', False)['enable'] if cfg.get('VEHICLE_ATTR',
False) else False
if self.with_vehicle_attr:
print('Vehicle Attribute Recognition enabled')
self.modebase = {
"framebased": False,
"videobased": False,
......@@ -294,7 +302,7 @@ class PipePredictor(object):
model_dir, device, run_mode, batch_size, trt_min_shape,
trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
enable_mkldnn)
if self.with_attr:
if self.with_human_attr:
attr_cfg = self.cfg['ATTR']
model_dir = attr_cfg['model_dir']
batch_size = attr_cfg['batch_size']
......@@ -305,8 +313,21 @@ class PipePredictor(object):
trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
enable_mkldnn)
if self.with_vehicle_attr:
vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
model_dir = vehicleattr_cfg['model_dir']
batch_size = vehicleattr_cfg['batch_size']
color_threshold = vehicleattr_cfg['color_threshold']
type_threshold = vehicleattr_cfg['type_threshold']
basemode = vehicleattr_cfg['basemode']
self.modebase[basemode] = True
self.vehicle_attr_predictor = VehicleAttr(
model_dir, device, run_mode, batch_size, trt_min_shape,
trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
enable_mkldnn, color_threshold, type_threshold)
else:
if self.with_attr:
if self.with_human_attr:
attr_cfg = self.cfg['ATTR']
model_dir = attr_cfg['model_dir']
batch_size = attr_cfg['batch_size']
......@@ -412,6 +433,19 @@ class PipePredictor(object):
basemode = vehicleplate_cfg['basemode']
self.modebase[basemode] = True
if self.with_vehicle_attr:
vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
model_dir = vehicleattr_cfg['model_dir']
batch_size = vehicleattr_cfg['batch_size']
color_threshold = vehicleattr_cfg['color_threshold']
type_threshold = vehicleattr_cfg['type_threshold']
basemode = vehicleattr_cfg['basemode']
self.modebase[basemode] = True
self.vehicle_attr_predictor = VehicleAttr(
model_dir, device, run_mode, batch_size, trt_min_shape,
trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
enable_mkldnn, color_threshold, type_threshold)
if self.with_mot or self.modebase["idbased"] or self.modebase[
"skeletonbased"]:
mot_cfg = self.cfg['MOT']
......@@ -510,7 +544,7 @@ class PipePredictor(object):
self.pipe_timer.module_time['det'].end()
self.pipeline_res.update(det_res, 'det')
if self.with_attr:
if self.with_human_attr:
crop_inputs = crop_image_with_det(batch_input, det_res)
attr_res_list = []
......@@ -528,6 +562,24 @@ class PipePredictor(object):
attr_res = {'output': attr_res_list}
self.pipeline_res.update(attr_res, 'attr')
if self.with_vehicle_attr:
crop_inputs = crop_image_with_det(batch_input, det_res)
vehicle_attr_res_list = []
if i > self.warmup_frame:
self.pipe_timer.module_time['vehicle_attr'].start()
for crop_input in crop_inputs:
attr_res = self.vehicle_attr_predictor.predict_image(
crop_input, visual=False)
vehicle_attr_res_list.extend(attr_res['output'])
if i > self.warmup_frame:
self.pipe_timer.module_time['vehicle_attr'].end()
attr_res = {'output': vehicle_attr_res_list}
self.pipeline_res.update(attr_res, 'vehicle_attr')
self.pipe_timer.img_num += len(batch_input)
if i > self.warmup_frame:
self.pipe_timer.total_time.end()
......@@ -581,13 +633,14 @@ class PipePredictor(object):
ret, frame = capture.read()
if not ret:
break
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
if self.modebase["idbased"] or self.modebase["skeletonbased"]:
if frame_id > self.warmup_frame:
self.pipe_timer.total_time.start()
self.pipe_timer.module_time['mot'].start()
res = self.mot_predictor.predict_image(
[copy.deepcopy(frame)], visual=False)
[copy.deepcopy(frame_rgb)], visual=False)
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['mot'].end()
......@@ -625,14 +678,14 @@ class PipePredictor(object):
self.pipeline_res.update(mot_res, 'mot')
crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
frame, mot_res)
frame_rgb, mot_res)
if self.with_vehicleplate:
platelicense = self.vehicleplate_detector.get_platelicense(
crop_input)
self.pipeline_res.update(platelicense, 'vehicleplate')
if self.with_attr:
if self.with_human_attr:
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['attr'].start()
attr_res = self.attr_predictor.predict_image(
......@@ -641,6 +694,15 @@ class PipePredictor(object):
self.pipe_timer.module_time['attr'].end()
self.pipeline_res.update(attr_res, 'attr')
if self.with_vehicle_attr:
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['vehicle_attr'].start()
attr_res = self.vehicle_attr_predictor.predict_image(
crop_input, visual=False)
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['vehicle_attr'].end()
self.pipeline_res.update(attr_res, 'vehicle_attr')
if self.with_idbased_detaction:
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['det_action'].start()
......@@ -708,7 +770,7 @@ class PipePredictor(object):
if self.with_mtmct and frame_id % 10 == 0:
crop_input, img_qualities, rects = self.reid_predictor.crop_image_with_mot(
frame, mot_res)
frame_rgb, mot_res)
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['reid'].start()
reid_res = self.reid_predictor.predict_batch(crop_input)
......@@ -740,7 +802,7 @@ class PipePredictor(object):
# collect frames
if frame_id % sample_freq == 0:
# Scale image
scaled_img = scale(frame)
scaled_img = scale(frame_rgb)
video_action_imgs.append(scaled_img)
# the number of collected frames is enough to predict video action
......@@ -820,11 +882,18 @@ class PipePredictor(object):
records=records,
center_traj=center_traj)
attr_res = result.get('attr')
if attr_res is not None:
human_attr_res = result.get('attr')
if human_attr_res is not None:
boxes = mot_res['boxes'][:, 1:]
human_attr_res = human_attr_res['output']
image = visualize_attr(image, human_attr_res, boxes)
image = np.array(image)
vehicle_attr_res = result.get('vehicle_attr')
if vehicle_attr_res is not None:
boxes = mot_res['boxes'][:, 1:]
attr_res = attr_res['output']
image = visualize_attr(image, attr_res, boxes)
vehicle_attr_res = vehicle_attr_res['output']
image = visualize_attr(image, vehicle_attr_res, boxes)
image = np.array(image)
vehicleplate_res = result.get('vehicleplate')
......@@ -883,7 +952,9 @@ class PipePredictor(object):
def visualize_image(self, im_files, images, result):
start_idx, boxes_num_i = 0, 0
det_res = result.get('det')
attr_res = result.get('attr')
human_attr_res = result.get('attr')
vehicle_attr_res = result.get('vehicle_attr')
for i, (im_file, im) in enumerate(zip(im_files, images)):
if det_res is not None:
det_res_i = {}
......@@ -897,10 +968,15 @@ class PipePredictor(object):
threshold=self.cfg['crop_thresh'])
im = np.ascontiguousarray(np.copy(im))
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
if attr_res is not None:
attr_res_i = attr_res['output'][start_idx:start_idx +
boxes_num_i]
im = visualize_attr(im, attr_res_i, det_res_i['boxes'])
if human_attr_res is not None:
human_attr_res_i = human_attr_res['output'][start_idx:start_idx
+ boxes_num_i]
im = visualize_attr(im, human_attr_res_i, det_res_i['boxes'])
if vehicle_attr_res is not None:
vehicle_attr_res_i = vehicle_attr_res['output'][
start_idx:start_idx + boxes_num_i]
im = visualize_attr(im, vehicle_attr_res_i, det_res_i['boxes'])
img_name = os.path.split(im_file)[-1]
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import glob
import cv2
import numpy as np
import math
import paddle
import sys
from collections import Sequence
# add deploy path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 3)))
sys.path.insert(0, parent_path)
from paddle.inference import Config, create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from benchmark_utils import PaddleInferBenchmark
from python.infer import Detector, print_arguments
from python.attr_infer import AttrDetector
class VehicleAttr(AttrDetector):
    """Vehicle attribute recognizer: predicts vehicle color and vehicle type.

    Thin wrapper around the generic ``AttrDetector`` that adds the
    color/type label lists, per-attribute score thresholds, and a
    postprocess step that turns raw scores into readable strings.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
        batch_size (int): size of pre batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        output_dir (str): directory for saving visualized results
        color_threshold (float): The threshold of score for vehicle color recognition.
        type_threshold (float): The threshold of score for vehicle type recognition.
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 color_threshold=0.5,
                 type_threshold=0.5):
        super(VehicleAttr, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir)
        self.color_threshold = color_threshold
        self.type_threshold = type_threshold
        # per-track-id result cache (filled by callers in the pipeline)
        self.result_history = {}
        # Label lists: the model's output vector is laid out as
        # [color scores..., type scores...] in exactly this order.
        self.color_list = [
            "yellow", "orange", "green", "gray", "red", "blue", "white",
            "golden", "brown", "black"
        ]
        self.type_list = [
            "sedan", "suv", "van", "hatchback", "mpv", "pickup", "bus", "truck",
            "estate"
        ]

    def postprocess(self, inputs, result):
        """Convert raw per-image score vectors into readable attribute strings.

        Args:
            inputs: preprocessed model inputs (unused here, kept for the
                AttrDetector postprocess interface).
            result (dict): {'output': iterable of score arrays}, each array
                holding color scores followed by type scores.

        Returns:
            dict: {'output': list of ["Color: X", "Type: Y"] per image};
                "Unknown" is used when the best score is below its threshold.
        """
        im_results = result['output']
        # Derive the color/type split point from the label list instead of a
        # hard-coded constant, so the layout stays in sync with color_list.
        num_colors = len(self.color_list)
        batch_res = []
        for res in im_results:
            res = res.tolist()
            attr_res = []
            color_res_str = "Color: "
            type_res_str = "Type: "
            color_idx = np.argmax(res[:num_colors])
            type_idx = np.argmax(res[num_colors:])

            if res[color_idx] >= self.color_threshold:
                color_res_str += self.color_list[color_idx]
            else:
                color_res_str += "Unknown"
            attr_res.append(color_res_str)

            # type_idx is relative to the type sub-vector; offset back into res
            if res[type_idx + num_colors] >= self.type_threshold:
                type_res_str += self.type_list[type_idx]
            else:
                type_res_str += "Unknown"
            attr_res.append(type_res_str)

            batch_res.append(attr_res)
        result = {'output': batch_res}
        return result
if __name__ == '__main__':
    # Run inference in static-graph mode.
    paddle.enable_static()
    FLAGS = argsparser().parse_args()
    print_arguments(FLAGS)

    # Normalize and validate the requested device before dispatching.
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'
                            ], "device should be CPU, GPU or XPU"
    assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"

    main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册