未验证 提交 94c19353 编写于 作者: F FlyingQianMM 提交者: GitHub

[MOT] add python implementation for flow_statistics (#4528)

* add flow_statistic for mot_jde

* raise NotImplementedError when multi-class entrance counting is chosen

* revert * 'mp4v' to *'mp4v'
上级 91f55ece
......@@ -19,14 +19,8 @@ import paddle
import numpy as np
__all__ = [
'MOTTimer',
'Detection',
'write_mot_results',
'load_det_results',
'preprocess_reid',
'get_crops',
'clip_box',
'scale_coords',
'MOTTimer', 'Detection', 'write_mot_results', 'load_det_results',
'preprocess_reid', 'get_crops', 'clip_box', 'scale_coords', 'flow_statistic'
]
......@@ -219,3 +213,81 @@ def preprocess_reid(imgs,
im_batch.append(img)
im_batch = np.concatenate(im_batch, 0)
return im_batch
def flow_statistic(result,
secs_interval,
do_entrance_counting,
video_fps,
entrance,
id_set,
interval_id_set,
in_id_list,
out_id_list,
prev_center,
records,
data_type,
num_classes=1):
# Count in and out number:
# Use horizontal center line as the entrance just for simplification.
# If a person located in the above the horizontal center line
# at the previous frame and is in the below the line at the current frame,
# the in number is increased by one.
# If a person was in the below the horizontal center line
# at the previous frame and locates in the below the line at the current frame,
# the out number is increased by one.
# TODO: if the entrance is not the horizontal center line,
# the counting method should be optimized.
if do_entrance_counting:
entrance_y = entrance[1] # xmin, ymin, xmax, ymax
frame_id, tlwhs, tscores, track_ids = result
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0: continue
if data_type == 'kitti':
frame_id -= 1
x1, y1, w, h = tlwh
center_x = x1 + w / 2.
center_y = y1 + h / 2.
if track_id in prev_center:
if prev_center[track_id][1] <= entrance_y and \
center_y > entrance_y:
in_id_list.append(track_id)
if prev_center[track_id][1] >= entrance_y and \
center_y < entrance_y:
out_id_list.append(track_id)
prev_center[track_id][0] = center_x
prev_center[track_id][1] = center_y
else:
prev_center[track_id] = [center_x, center_y]
# Count totol number, number at a manual-setting interval
frame_id, tlwhs, tscores, track_ids = result
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0: continue
id_set.add(track_id)
interval_id_set.add(track_id)
# Reset counting at the interval beginning
if frame_id % video_fps == 0 and frame_id / video_fps % secs_interval == 0:
curr_interval_count = len(interval_id_set)
interval_id_set.clear()
info = "Frame id: {}, Total count: {}".format(frame_id, len(id_set))
if do_entrance_counting:
info += ", In count: {}, Out count: {}".format(
len(in_id_list), len(out_id_list))
if frame_id % video_fps == 0 and frame_id / video_fps % secs_interval == 0:
info += ", Count during {} secs: {}".format(secs_interval,
curr_interval_count)
interval_id_set.clear()
print(info)
info += "\n"
records.append(info)
return {
"id_set": id_set,
"interval_id_set": interval_id_set,
"in_id_list": in_id_list,
"out_id_list": out_id_list,
"prev_center": prev_center,
"records": records
}
......@@ -29,7 +29,7 @@ from benchmark_utils import PaddleInferBenchmark
from visualize import plot_tracking_dict
from mot.tracker import JDETracker
from mot.utils import MOTTimer, write_mot_results
from mot.utils import MOTTimer, write_mot_results, flow_statistic
# Global dictionary
MOT_SUPPORT_MODELS = {
......@@ -220,6 +220,16 @@ def predict_video(detector, camera_id):
data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = detector.pred_config.labels
if num_classes == 1:
id_set = set()
interval_id_set = set()
in_id_list = list()
out_id_list = list()
prev_center = dict()
records = list()
entrance = [0, height / 2., width, height / 2.]
video_fps = fps
while (1):
ret, frame = capture.read()
if not ret:
......@@ -233,6 +243,25 @@ def predict_video(detector, camera_id):
results[cls_id].append((frame_id + 1, online_tlwhs[cls_id],
online_scores[cls_id], online_ids[cls_id]))
# NOTE: just implement flow statistic for one class
if num_classes == 1:
result = (frame_id + 1, online_tlwhs[0], online_scores[0],
online_ids[0])
statistic = flow_statistic(
result, FLAGS.secs_interval, FLAGS.do_entrance_counting,
video_fps, entrance, id_set, interval_id_set, in_id_list,
out_id_list, prev_center, records, data_type, num_classes)
id_set = statistic['id_set']
interval_id_set = statistic['interval_id_set']
in_id_list = statistic['in_id_list']
out_id_list = statistic['out_id_list']
prev_center = statistic['prev_center']
records = statistic['records']
elif num_classes > 1 and do_entrance_counting:
raise NotImplementedError(
'Multi-class flow counting is not implemented now!')
fps = 1. / timer.average_time
im = plot_tracking_dict(
frame,
......@@ -264,6 +293,16 @@ def predict_video(detector, camera_id):
write_mot_results(result_filename, results, data_type, num_classes)
if num_classes == 1:
result_filename = os.path.join(
FLAGS.output_dir,
video_name.split('.')[-2] + '_flow_statistic.txt')
f = open(result_filename, 'w')
for line in records:
f.write(line)
print('Flow statistic save in {}'.format(result_filename))
f.close()
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
......
......@@ -120,6 +120,17 @@ def argsparser():
type=int,
default=50,
help="max batch_size for reid model inference.")
parser.add_argument(
"--do_entrance_counting",
action='store_true',
help="Whether counting the numbers of identifiers entering "
"or getting out from the entrance. Note that only support one-class"
"counting, multi-class counting is coming soon.")
parser.add_argument(
"--secs_interval",
type=int,
default=10,
help="The seconds interval to count after tracking")
return parser
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册