未验证 提交 057ef8bd 编写于 作者: F Feng Ni 提交者: GitHub

add tqdm eval and infer (#5587)

上级 ae54352f
...@@ -141,10 +141,6 @@ class LogPrinter(Callback): ...@@ -141,10 +141,6 @@ class LogPrinter(Callback):
dtime=str(data_time), dtime=str(data_time),
ips=ips) ips=ips)
logger.info(fmt) logger.info(fmt)
if mode == 'eval':
step_id = status['step_id']
if step_id % 100 == 0:
logger.info("Eval iter: {}".format(step_id))
def on_epoch_end(self, status): def on_epoch_end(self, status):
if dist.get_world_size() < 2 or dist.get_rank() == 0: if dist.get_world_size() < 2 or dist.get_rank() == 0:
......
...@@ -17,12 +17,11 @@ from __future__ import division ...@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import cv2
import glob import glob
import re import re
import paddle import paddle
import numpy as np import numpy as np
import os.path as osp from tqdm import tqdm
from collections import defaultdict from collections import defaultdict
from ppdet.core.workspace import create from ppdet.core.workspace import create
...@@ -31,8 +30,7 @@ from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_bo ...@@ -31,8 +30,7 @@ from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_bo
from ppdet.modeling.mot.utils import MOTTimer, load_det_results, write_mot_results, save_vis_results from ppdet.modeling.mot.utils import MOTTimer, load_det_results, write_mot_results, save_vis_results
from ppdet.modeling.mot.tracker import JDETracker, DeepSORTTracker from ppdet.modeling.mot.tracker import JDETracker, DeepSORTTracker
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric, MCMOTMetric
from ppdet.metrics import MCMOTMetric
import ppdet.utils.stats as stats import ppdet.utils.stats as stats
from .callbacks import Callback, ComposeCallback from .callbacks import Callback, ComposeCallback
...@@ -142,11 +140,8 @@ class Tracker(object): ...@@ -142,11 +140,8 @@ class Tracker(object):
self.model.eval() self.model.eval()
results = defaultdict(list) # support single class and multi classes results = defaultdict(list) # support single class and multi classes
for step_id, data in enumerate(dataloader): for step_id, data in enumerate(tqdm(dataloader)):
self.status['step_id'] = step_id self.status['step_id'] = step_id
if frame_id % 40 == 0:
logger.info('Processing frame {} ({:.2f} fps)'.format(
frame_id, 1. / max(1e-5, timer.average_time)))
# forward # forward
timer.tic() timer.tic()
pred_dets, pred_embs = self.model(data) pred_dets, pred_embs = self.model(data)
...@@ -210,12 +205,8 @@ class Tracker(object): ...@@ -210,12 +205,8 @@ class Tracker(object):
det_file)) det_file))
tracker = self.model.tracker tracker = self.model.tracker
for step_id, data in enumerate(dataloader): for step_id, data in enumerate(tqdm(dataloader)):
self.status['step_id'] = step_id self.status['step_id'] = step_id
if frame_id % 40 == 0:
logger.info('Processing frame {} ({:.2f} fps)'.format(
frame_id, 1. / max(1e-5, timer.average_time)))
ori_image = data['ori_image'] # [bs, H, W, 3] ori_image = data['ori_image'] # [bs, H, W, 3]
ori_image_shape = data['ori_image'].shape[1:3] ori_image_shape = data['ori_image'].shape[1:3]
# ori_image_shape: [H, W] # ori_image_shape: [H, W]
...@@ -339,8 +330,8 @@ class Tracker(object): ...@@ -339,8 +330,8 @@ class Tracker(object):
results[0].append( results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids)) (frame_id + 1, online_tlwhs, online_scores, online_ids))
save_vis_results(data, frame_id, online_ids, online_tlwhs, save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image, online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes) save_dir, self.cfg.num_classes)
elif isinstance(tracker, JDETracker): elif isinstance(tracker, JDETracker):
# trick hyperparams only used for MOTChallenge (MOT17, MOT20) Test-set # trick hyperparams only used for MOTChallenge (MOT17, MOT20) Test-set
...@@ -366,12 +357,12 @@ class Tracker(object): ...@@ -366,12 +357,12 @@ class Tracker(object):
online_scores[cls_id].append(tscore) online_scores[cls_id].append(tscore)
# save results # save results
results[cls_id].append( results[cls_id].append(
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id], (frame_id + 1, online_tlwhs[cls_id],
online_ids[cls_id])) online_scores[cls_id], online_ids[cls_id]))
timer.toc() timer.toc()
save_vis_results(data, frame_id, online_ids, online_tlwhs, save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image, online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes) save_dir, self.cfg.num_classes)
frame_id += 1 frame_id += 1
...@@ -417,7 +408,7 @@ class Tracker(object): ...@@ -417,7 +408,7 @@ class Tracker(object):
save_dir = os.path.join(output_dir, 'mot_outputs', save_dir = os.path.join(output_dir, 'mot_outputs',
seq) if save_images or save_videos else None seq) if save_images or save_videos else None
logger.info('start seq: {}'.format(seq)) logger.info('Evaluate seq: {}'.format(seq))
self.dataset.set_images(self.get_infer_images(infer_dir)) self.dataset.set_images(self.get_infer_images(infer_dir))
dataloader = create('EvalMOTReader')(self.dataset, 0) dataloader = create('EvalMOTReader')(self.dataset, 0)
...@@ -458,7 +449,6 @@ class Tracker(object): ...@@ -458,7 +449,6 @@ class Tracker(object):
os.system(cmd_str) os.system(cmd_str)
logger.info('Save video in {}.'.format(output_video_path)) logger.info('Save video in {}.'.format(output_video_path))
logger.info('Evaluate seq: {}'.format(seq))
# update metrics # update metrics
for metric in self._metrics: for metric in self._metrics:
metric.update(data_root, seq, data_type, result_root, metric.update(data_root, seq, data_type, result_root,
...@@ -582,6 +572,7 @@ class Tracker(object): ...@@ -582,6 +572,7 @@ class Tracker(object):
write_mot_results(result_filename, results, data_type, write_mot_results(result_filename, results, data_type,
self.cfg.num_classes) self.cfg.num_classes)
def get_trick_hyperparams(video_name, ori_buffer, ori_thresh): def get_trick_hyperparams(video_name, ori_buffer, ori_thresh):
if video_name[:3] != 'MOT': if video_name[:3] != 'MOT':
# only used for MOTChallenge (MOT17, MOT20) Test-set # only used for MOTChallenge (MOT17, MOT20) Test-set
...@@ -610,5 +601,5 @@ def get_trick_hyperparams(video_name, ori_buffer, ori_thresh): ...@@ -610,5 +601,5 @@ def get_trick_hyperparams(video_name, ori_buffer, ori_thresh):
track_thresh = 0.3 track_thresh = 0.3
else: else:
track_thresh = ori_thresh track_thresh = ori_thresh
return track_buffer, ori_thresh return track_buffer, ori_thresh
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import sys import sys
import copy import copy
import time import time
from tqdm import tqdm
import numpy as np import numpy as np
import typing import typing
...@@ -500,7 +501,7 @@ class Trainer(object): ...@@ -500,7 +501,7 @@ class Trainer(object):
flops_loader = create('{}Reader'.format(self.mode.capitalize()))( flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num, self._eval_batch_sampler) self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
self._flops(flops_loader) self._flops(flops_loader)
for step_id, data in enumerate(loader): for step_id, data in enumerate(tqdm(loader)):
self.status['step_id'] = step_id self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status) self._compose_callback.on_step_begin(self.status)
# forward # forward
...@@ -553,7 +554,7 @@ class Trainer(object): ...@@ -553,7 +554,7 @@ class Trainer(object):
flops_loader = create('TestReader')(self.dataset, 0) flops_loader = create('TestReader')(self.dataset, 0)
self._flops(flops_loader) self._flops(flops_loader)
results = [] results = []
for step_id, data in enumerate(loader): for step_id, data in enumerate(tqdm(loader)):
self.status['step_id'] = step_id self.status['step_id'] = step_id
# forward # forward
outs = self.model(data) outs = self.model(data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册