未验证 提交 ab5b0151 编写于 作者: F Feng Ni 提交者: GitHub

[cherry-pick] add tqdm eval and infer (#5588)

上级 d244753b
......@@ -141,10 +141,6 @@ class LogPrinter(Callback):
dtime=str(data_time),
ips=ips)
logger.info(fmt)
if mode == 'eval':
step_id = status['step_id']
if step_id % 100 == 0:
logger.info("Eval iter: {}".format(step_id))
def on_epoch_end(self, status):
if dist.get_world_size() < 2 or dist.get_rank() == 0:
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import os
import cv2
import glob
import re
import paddle
import numpy as np
import os.path as osp
from tqdm import tqdm
from collections import defaultdict
from ppdet.core.workspace import create
......@@ -31,8 +30,7 @@ from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_bo
from ppdet.modeling.mot.utils import MOTTimer, load_det_results, write_mot_results, save_vis_results
from ppdet.modeling.mot.tracker import JDETracker, DeepSORTTracker
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric
from ppdet.metrics import MCMOTMetric
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric, MCMOTMetric
import ppdet.utils.stats as stats
from .callbacks import Callback, ComposeCallback
......@@ -142,11 +140,8 @@ class Tracker(object):
self.model.eval()
results = defaultdict(list) # support single class and multi classes
for step_id, data in enumerate(dataloader):
for step_id, data in enumerate(tqdm(dataloader)):
self.status['step_id'] = step_id
if frame_id % 40 == 0:
logger.info('Processing frame {} ({:.2f} fps)'.format(
frame_id, 1. / max(1e-5, timer.average_time)))
# forward
timer.tic()
pred_dets, pred_embs = self.model(data)
......@@ -210,12 +205,8 @@ class Tracker(object):
det_file))
tracker = self.model.tracker
for step_id, data in enumerate(dataloader):
for step_id, data in enumerate(tqdm(dataloader)):
self.status['step_id'] = step_id
if frame_id % 40 == 0:
logger.info('Processing frame {} ({:.2f} fps)'.format(
frame_id, 1. / max(1e-5, timer.average_time)))
ori_image = data['ori_image'] # [bs, H, W, 3]
ori_image_shape = data['ori_image'].shape[1:3]
# ori_image_shape: [H, W]
......@@ -339,8 +330,8 @@ class Tracker(object):
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes)
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes)
elif isinstance(tracker, JDETracker):
# trick hyperparams only used for MOTChallenge (MOT17, MOT20) Test-set
......@@ -366,12 +357,12 @@ class Tracker(object):
online_scores[cls_id].append(tscore)
# save results
results[cls_id].append(
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id]))
(frame_id + 1, online_tlwhs[cls_id],
online_scores[cls_id], online_ids[cls_id]))
timer.toc()
save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes)
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes)
frame_id += 1
......@@ -417,7 +408,7 @@ class Tracker(object):
save_dir = os.path.join(output_dir, 'mot_outputs',
seq) if save_images or save_videos else None
logger.info('start seq: {}'.format(seq))
logger.info('Evaluate seq: {}'.format(seq))
self.dataset.set_images(self.get_infer_images(infer_dir))
dataloader = create('EvalMOTReader')(self.dataset, 0)
......@@ -458,7 +449,6 @@ class Tracker(object):
os.system(cmd_str)
logger.info('Save video in {}.'.format(output_video_path))
logger.info('Evaluate seq: {}'.format(seq))
# update metrics
for metric in self._metrics:
metric.update(data_root, seq, data_type, result_root,
......@@ -582,6 +572,7 @@ class Tracker(object):
write_mot_results(result_filename, results, data_type,
self.cfg.num_classes)
def get_trick_hyperparams(video_name, ori_buffer, ori_thresh):
if video_name[:3] != 'MOT':
# only used for MOTChallenge (MOT17, MOT20) Test-set
......@@ -610,5 +601,5 @@ def get_trick_hyperparams(video_name, ori_buffer, ori_thresh):
track_thresh = 0.3
else:
track_thresh = ori_thresh
return track_buffer, ori_thresh
......@@ -20,6 +20,7 @@ import os
import sys
import copy
import time
from tqdm import tqdm
import numpy as np
import typing
......@@ -500,7 +501,7 @@ class Trainer(object):
flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
self._flops(flops_loader)
for step_id, data in enumerate(loader):
for step_id, data in enumerate(tqdm(loader)):
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
# forward
......@@ -553,7 +554,7 @@ class Trainer(object):
flops_loader = create('TestReader')(self.dataset, 0)
self._flops(flops_loader)
results = []
for step_id, data in enumerate(loader):
for step_id, data in enumerate(tqdm(loader)):
self.status['step_id'] = step_id
# forward
outs = self.model(data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册