未验证 提交 a0ccc7d3 编写于 作者: Feng Ni 提交者: GitHub

[MOT] fix fairmot (#4614)

* fix Momentum optimizer config and MCMOT label_list handling

* fix export of mcfairmot
上级 af56c3e5
...@@ -12,7 +12,6 @@ LearningRate: ...@@ -12,7 +12,6 @@ LearningRate:
OptimizerBuilder: OptimizerBuilder:
optimizer: optimizer:
momentum: 0.9
type: Momentum type: Momentum
regularizer: regularizer:
factor: 0.0001 factor: 0.0001
......
...@@ -9,6 +9,13 @@ norm_type: sync_bn ...@@ -9,6 +9,13 @@ norm_type: sync_bn
use_ema: true use_ema: true
ema_decay: 0.9998 ema_decay: 0.9998
# add crowdhuman
TrainDataset:
!MOTDataSet
dataset_dir: dataset/mot
image_lists: ['mot17.train', 'caltech.all', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train', 'crowdhuman.train', 'crowdhuman.val']
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
worker_num: 4 worker_num: 4
TrainReader: TrainReader:
inputs_def: inputs_def:
......
...@@ -10,7 +10,7 @@ norm_type: sync_bn ...@@ -10,7 +10,7 @@ norm_type: sync_bn
use_ema: true use_ema: true
ema_decay: 0.9998 ema_decay: 0.9998
# for MOT training # add crowdhuman
TrainDataset: TrainDataset:
!MOTDataSet !MOTDataSet
dataset_dir: dataset/mot dataset_dir: dataset/mot
......
...@@ -10,7 +10,7 @@ norm_type: sync_bn ...@@ -10,7 +10,7 @@ norm_type: sync_bn
use_ema: true use_ema: true
ema_decay: 0.9998 ema_decay: 0.9998
# for MOT training # add crowdhuman
TrainDataset: TrainDataset:
!MOTDataSet !MOTDataSet
dataset_dir: dataset/mot dataset_dir: dataset/mot
......
...@@ -10,7 +10,7 @@ norm_type: sync_bn ...@@ -10,7 +10,7 @@ norm_type: sync_bn
use_ema: true use_ema: true
ema_decay: 0.9998 ema_decay: 0.9998
# for MOT training # add crowdhuman
TrainDataset: TrainDataset:
!MOTDataSet !MOTDataSet
dataset_dir: dataset/mot dataset_dir: dataset/mot
......
...@@ -154,7 +154,7 @@ class MOTDataSet(DetDataset): ...@@ -154,7 +154,7 @@ class MOTDataSet(DetDataset):
last_index += v last_index += v
self.num_identities_dict = defaultdict(int) self.num_identities_dict = defaultdict(int)
self.num_identities_dict[0] = int(last_index + 1) # single class self.num_identities_dict[0] = int(last_index + 1) # single class
self.num_imgs_each_data = [len(x) for x in self.img_files.values()] self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
self.total_imgs = sum(self.num_imgs_each_data) self.total_imgs = sum(self.num_imgs_each_data)
...@@ -249,6 +249,7 @@ class MCMOTDataSet(DetDataset): ...@@ -249,6 +249,7 @@ class MCMOTDataSet(DetDataset):
└——————labels_with_ids └——————labels_with_ids
└——————train └——————train
""" """
def __init__(self, def __init__(self,
dataset_dir=None, dataset_dir=None,
image_lists=[], image_lists=[],
...@@ -343,22 +344,26 @@ class MCMOTDataSet(DetDataset): ...@@ -343,22 +344,26 @@ class MCMOTDataSet(DetDataset):
# cname2cid and cid2cname # cname2cid and cid2cname
cname2cid = {} cname2cid = {}
if self.label_list: if self.label_list is not None:
# if use label_list for multi source mix dataset, # if use label_list for multi source mix dataset,
# please make sure label_list in the first sub_dataset at least. # please make sure label_list in the first sub_dataset at least.
sub_dataset = self.image_lists[0].split('.')[0] sub_dataset = self.image_lists[0].split('.')[0]
label_path = os.path.join(self.dataset_dir, sub_dataset, label_path = os.path.join(self.dataset_dir, sub_dataset,
self.label_list) self.label_list)
if not os.path.exists(label_path): if not os.path.exists(label_path):
raise ValueError("label_list {} does not exists".format( logger.info(
label_path)) "Note: label_list {} does not exists, use VisDrone 10 classes labels as default.".
with open(label_path, 'r') as fr: format(label_path))
label_id = 0 cname2cid = visdrone_mcmot_label()
for line in fr.readlines(): else:
cname2cid[line.strip()] = label_id with open(label_path, 'r') as fr:
label_id += 1 label_id = 0
for line in fr.readlines():
cname2cid[line.strip()] = label_id
label_id += 1
else: else:
cname2cid = visdrone_mcmot_label() cname2cid = visdrone_mcmot_label()
cid2cname = dict([(v, k) for (k, v) in cname2cid.items()]) cid2cname = dict([(v, k) for (k, v) in cname2cid.items()])
logger.info('MCMOT dataset summary: ') logger.info('MCMOT dataset summary: ')
......
...@@ -440,13 +440,13 @@ class CenterNetPostProcess(TTFBox): ...@@ -440,13 +440,13 @@ class CenterNetPostProcess(TTFBox):
def __call__(self, hm, wh, reg, im_shape, scale_factor): def __call__(self, hm, wh, reg, im_shape, scale_factor):
heat = self._simple_nms(hm) heat = self._simple_nms(hm)
scores, inds, topk_clses, ys, xs = self._topk(heat) scores, inds, topk_clses, ys, xs = self._topk(heat)
scores = paddle.tensor.unsqueeze(scores, [1]) scores = scores.unsqueeze(1)
clses = paddle.tensor.unsqueeze(topk_clses, [1]) clses = topk_clses.unsqueeze(1)
reg_t = paddle.transpose(reg, [0, 2, 3, 1]) reg_t = paddle.transpose(reg, [0, 2, 3, 1])
# Like TTFBox, batch size is 1. # Like TTFBox, batch size is 1.
# TODO: support batch size > 1 # TODO: support batch size > 1
reg = paddle.reshape(reg_t, [-1, paddle.shape(reg_t)[-1]]) reg = paddle.reshape(reg_t, [-1, reg_t.shape[-1]])
reg = paddle.gather(reg, inds) reg = paddle.gather(reg, inds)
xs = paddle.cast(xs, 'float32') xs = paddle.cast(xs, 'float32')
ys = paddle.cast(ys, 'float32') ys = paddle.cast(ys, 'float32')
...@@ -454,7 +454,7 @@ class CenterNetPostProcess(TTFBox): ...@@ -454,7 +454,7 @@ class CenterNetPostProcess(TTFBox):
ys = ys + reg[:, 1:2] ys = ys + reg[:, 1:2]
wh_t = paddle.transpose(wh, [0, 2, 3, 1]) wh_t = paddle.transpose(wh, [0, 2, 3, 1])
wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]]) wh = paddle.reshape(wh_t, [-1, wh_t.shape[-1]])
wh = paddle.gather(wh, inds) wh = paddle.gather(wh, inds)
if self.regress_ltrb: if self.regress_ltrb:
...@@ -486,8 +486,7 @@ class CenterNetPostProcess(TTFBox): ...@@ -486,8 +486,7 @@ class CenterNetPostProcess(TTFBox):
scale_x = scale_factor[:, 1:2] scale_x = scale_factor[:, 1:2]
scale_expand = paddle.concat( scale_expand = paddle.concat(
[scale_x, scale_y, scale_x, scale_y], axis=1) [scale_x, scale_y, scale_x, scale_y], axis=1)
boxes_shape = paddle.shape(bboxes) boxes_shape = bboxes.shape[:]
boxes_shape.stop_gradient = True
scale_expand = paddle.expand(scale_expand, shape=boxes_shape) scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
bboxes = paddle.divide(bboxes, scale_expand) bboxes = paddle.divide(bboxes, scale_expand)
if self.for_mot: if self.for_mot:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册