未验证 提交 a0ccc7d3 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] fix fairmot (#4614)

* fix Momentum and label list

* fix export of mcfairmot
上级 af56c3e5
......@@ -12,7 +12,6 @@ LearningRate:
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
......
......@@ -9,6 +9,13 @@ norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
# add crowdhuman
TrainDataset:
!MOTDataSet
dataset_dir: dataset/mot
image_lists: ['mot17.train', 'caltech.all', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train', 'crowdhuman.train', 'crowdhuman.val']
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
worker_num: 4
TrainReader:
inputs_def:
......
......@@ -10,7 +10,7 @@ norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
# for MOT training
# add crowdhuman
TrainDataset:
!MOTDataSet
dataset_dir: dataset/mot
......
......@@ -10,7 +10,7 @@ norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
# for MOT training
# add crowdhuman
TrainDataset:
!MOTDataSet
dataset_dir: dataset/mot
......
......@@ -10,7 +10,7 @@ norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
# for MOT training
# add crowdhuman
TrainDataset:
!MOTDataSet
dataset_dir: dataset/mot
......
......@@ -154,7 +154,7 @@ class MOTDataSet(DetDataset):
last_index += v
self.num_identities_dict = defaultdict(int)
self.num_identities_dict[0] = int(last_index + 1) # single class
self.num_identities_dict[0] = int(last_index + 1) # single class
self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
self.total_imgs = sum(self.num_imgs_each_data)
......@@ -249,6 +249,7 @@ class MCMOTDataSet(DetDataset):
└——————labels_with_ids
└——————train
"""
def __init__(self,
dataset_dir=None,
image_lists=[],
......@@ -343,22 +344,26 @@ class MCMOTDataSet(DetDataset):
# cname2cid and cid2cname
cname2cid = {}
if self.label_list:
if self.label_list is not None:
# if use label_list for multi source mix dataset,
# please make sure label_list in the first sub_dataset at least.
sub_dataset = self.image_lists[0].split('.')[0]
label_path = os.path.join(self.dataset_dir, sub_dataset,
self.label_list)
if not os.path.exists(label_path):
raise ValueError("label_list {} does not exists".format(
label_path))
with open(label_path, 'r') as fr:
label_id = 0
for line in fr.readlines():
cname2cid[line.strip()] = label_id
label_id += 1
logger.info(
"Note: label_list {} does not exists, use VisDrone 10 classes labels as default.".
format(label_path))
cname2cid = visdrone_mcmot_label()
else:
with open(label_path, 'r') as fr:
label_id = 0
for line in fr.readlines():
cname2cid[line.strip()] = label_id
label_id += 1
else:
cname2cid = visdrone_mcmot_label()
cid2cname = dict([(v, k) for (k, v) in cname2cid.items()])
logger.info('MCMOT dataset summary: ')
......
......@@ -440,13 +440,13 @@ class CenterNetPostProcess(TTFBox):
def __call__(self, hm, wh, reg, im_shape, scale_factor):
heat = self._simple_nms(hm)
scores, inds, topk_clses, ys, xs = self._topk(heat)
scores = paddle.tensor.unsqueeze(scores, [1])
clses = paddle.tensor.unsqueeze(topk_clses, [1])
scores = scores.unsqueeze(1)
clses = topk_clses.unsqueeze(1)
reg_t = paddle.transpose(reg, [0, 2, 3, 1])
# Like TTFBox, batch size is 1.
# TODO: support batch size > 1
reg = paddle.reshape(reg_t, [-1, paddle.shape(reg_t)[-1]])
reg = paddle.reshape(reg_t, [-1, reg_t.shape[-1]])
reg = paddle.gather(reg, inds)
xs = paddle.cast(xs, 'float32')
ys = paddle.cast(ys, 'float32')
......@@ -454,7 +454,7 @@ class CenterNetPostProcess(TTFBox):
ys = ys + reg[:, 1:2]
wh_t = paddle.transpose(wh, [0, 2, 3, 1])
wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
wh = paddle.reshape(wh_t, [-1, wh_t.shape[-1]])
wh = paddle.gather(wh, inds)
if self.regress_ltrb:
......@@ -486,8 +486,7 @@ class CenterNetPostProcess(TTFBox):
scale_x = scale_factor[:, 1:2]
scale_expand = paddle.concat(
[scale_x, scale_y, scale_x, scale_y], axis=1)
boxes_shape = paddle.shape(bboxes)
boxes_shape.stop_gradient = True
boxes_shape = bboxes.shape[:]
scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
bboxes = paddle.divide(bboxes, scale_expand)
if self.for_mot:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册