Unverified commit b3a80bee, authored by Feng Wang, committed by GitHub

feat(tools): add assignment visualizer (#1616)

Parent 4f8f1d79
......@@ -10,7 +10,8 @@ This repo is an implementation of PyTorch version YOLOX, there is also a [MegEng
<img src="assets/git_fig.png" width="1000" >
## Updates!!
* 【2023/02/28】 We add an assignment visualization tool; see the doc [here](./docs/assignment_visualization.md).
* 【2022/04/14】 We support jit compile op.
* 【2021/08/19】 We optimize the training process with **2x** faster training and **~1%** higher performance! See [notes](docs/updates_note.md) for more details.
* 【2021/08/05】 We release [MegEngine version YOLOX](https://github.com/MegEngine/YOLOX).
* 【2021/07/28】 We fix the fatal error of [memory leak](https://github.com/Megvii-BaseDetection/YOLOX/issues/103)
......@@ -206,6 +207,7 @@ python -m yolox.tools.eval -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --f
* [Training on custom data](docs/train_custom_data.md)
* [Caching for custom data](docs/cache.md)
* [Manipulating training image size](docs/manipulate_training_image_size.md)
* [Assignment visualization](docs/assignment_visualization.md)
* [Freezing model](docs/freeze_module.md)
</details>
......@@ -243,8 +245,8 @@ If you use YOLOX in your research, please cite our work by using the following B
}
```
## In memory of Dr. Jian Sun
Without the guidance of [Dr. Jian Sun](http://www.jiansun.org/), YOLOX would not have been released and open sourced to the community.
The passing away of Dr. Sun is a huge loss to the Computer Vision field. We have added this section to express our remembrance and condolences to our captain, Dr. Sun.
We hope that every AI practitioner in the world will stick to the concept of "continuous innovation to expand cognitive boundaries, and extraordinary technology to achieve product value" and move forward all the way.
<div align="center"><img src="assets/sunjian.png" width="200"></div>
......
# Visualize label assignment
This tutorial explains how to visualize your label assignment results when training with YOLOX.
## 1. Visualization command
We provide a visualization tool to help you check your label assignment results. You can find it at [`tools/visualize_assign.py`](../tools/visualize_assign.py).
Here is an example command for visualizing your label assignment results:
```shell
python3 tools/visualize_assign.py -f /path/to/your/exp.py yolox-s -d 1 -b 8 --max-batch 2
```
`--max-batch` sets the maximum number of batches to visualize. The default value is 1, which means the tool only visualizes the first batch.
Note that mosaic augmentation is enabled in the default dataloader, so you will also see mosaic results in the output.
After running the command, the logger will print the location where each visualization result is saved. Open that directory and continue with step 2.
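With the default settings, each image is saved as `assign_vis_<batch>_<image-index>.png` inside a `vis` folder under the experiment output directory. A hypothetical listing (the actual root depends on your `output_dir` and experiment name):

```shell
# illustrative paths only; the real root depends on your experiment settings
YOLOX_outputs/yolox_s/vis/assign_vis_0_0.png   # batch 0, image 0
YOLOX_outputs/yolox_s/vis/assign_vis_0_1.png   # batch 0, image 1
YOLOX_outputs/yolox_s/vis/assign_vis_1_0.png   # batch 1, image 0 (with --max-batch 2)
```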
## 2. Check the visualization result
Here is an example of visualization result:
<div align="center"><img src="../assets/assignment.png" width="640"></div>
The dots inside a box are the anchors matched to that ground-truth (gt) box. **The color of the dots is the same as the color of the box**, which helps you determine which object each anchor is assigned to. Note that boxes and dots are **instance-level** visualizations, so different instances of the same class may have different colors.
**If a gt box doesn't match any anchor, the box is drawn in red and the red text "unmatched" is written above it**.
Please feel free to open an issue if you have any questions.
#!/usr/bin/env python3
# Copyright (c) Megvii, Inc. and its affiliates.
import os
import sys
import random
import time
import warnings
from loguru import logger
import torch
import torch.backends.cudnn as cudnn
from yolox.exp import Exp, get_exp
from yolox.core import Trainer
from yolox.utils import configure_module, configure_omp
from yolox.tools.train import make_parser
class AssignVisualizer(Trainer):

    def __init__(self, exp: Exp, args):
        super().__init__(exp, args)
        self.batch_cnt = 0
        self.vis_dir = os.path.join(self.file_name, "vis")
        os.makedirs(self.vis_dir, exist_ok=True)

    def train_one_iter(self):
        iter_start_time = time.time()

        inps, targets = self.prefetcher.next()
        inps = inps.to(self.data_type)
        targets = targets.to(self.data_type)
        targets.requires_grad = False
        inps, targets = self.exp.preprocess(inps, targets, self.input_size)
        data_end_time = time.time()

        with torch.cuda.amp.autocast(enabled=self.amp_training):
            path_prefix = os.path.join(self.vis_dir, f"assign_vis_{self.batch_cnt}_")
            self.model.visualize(inps, targets, path_prefix)

        if self.use_model_ema:
            self.ema_model.update(self.model)

        iter_end_time = time.time()
        self.meter.update(
            iter_time=iter_end_time - iter_start_time,
            data_time=data_end_time - iter_start_time,
        )
        self.batch_cnt += 1
        if self.batch_cnt >= self.args.max_batch:
            sys.exit(0)

    def after_train(self):
        logger.info("Finish visualize assignment, exit...")


def assign_vis_parser():
    parser = make_parser()
    parser.add_argument("--max-batch", type=int, default=1, help="max batch of images to visualize")
    return parser


@logger.catch
def main(exp: Exp, args):
    if exp.seed is not None:
        random.seed(exp.seed)
        torch.manual_seed(exp.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! You may see unexpected behavior "
            "when restarting from checkpoints."
        )

    # set environment variables for distributed training
    configure_omp()
    cudnn.benchmark = True

    visualizer = AssignVisualizer(exp, args)
    visualizer.train()


if __name__ == "__main__":
    configure_module()
    args = assign_vis_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    main(exp, args)
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import datetime
......
......@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from yolox.utils import bboxes_iou, cxcywh2xyxy, meshgrid, visualize_assign
from .losses import IOUloss
from .network_blocks import BaseConv, DWConv
......@@ -511,11 +511,7 @@ class YOLOXHead(nn.Module):
)
    def get_geometry_constraint(
        self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
    ):
        """
        Calculate whether the center of an object is located in a fixed range of
......@@ -546,8 +542,6 @@ class YOLOXHead(nn.Module):
        return anchor_filter, geometry_relation

    def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
        n_candidate_k = min(10, pair_wise_ious.size(1))
......@@ -580,3 +574,68 @@ class YOLOXHead(nn.Module):
            fg_mask_inboxes
        ]
        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
    def visualize_assign_result(self, xin, labels=None, imgs=None, save_prefix="assign_vis_"):
        # original forward logic
        outputs, x_shifts, y_shifts, expanded_strides = [], [], [], []
        # TODO: use forward logic here.

        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
            zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)
            cls_x = x
            reg_x = x

            cls_feat = cls_conv(cls_x)
            cls_output = self.cls_preds[k](cls_feat)
            reg_feat = reg_conv(reg_x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

            output = torch.cat([reg_output, obj_output, cls_output], 1)
            output, grid = self.get_output_and_grid(output, k, stride_this_level, xin[0].type())
            x_shifts.append(grid[:, :, 0])
            y_shifts.append(grid[:, :, 1])
            expanded_strides.append(
                torch.full((1, grid.shape[1]), stride_this_level).type_as(xin[0])
            )
            outputs.append(output)

        outputs = torch.cat(outputs, 1)
        bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
        cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]

        # calculate targets
        total_num_anchors = outputs.shape[1]
        x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]
        y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]
        expanded_strides = torch.cat(expanded_strides, 1)

        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects
        for batch_idx, (img, num_gt, label) in enumerate(zip(imgs, nlabel, labels)):
            img = imgs[batch_idx].permute(1, 2, 0).to(torch.uint8)
            num_gt = int(num_gt)
            if num_gt == 0:
                fg_mask = outputs.new_zeros(total_num_anchors).bool()
            else:
                gt_bboxes_per_image = label[:num_gt, 1:5]
                gt_classes = label[:num_gt, 0]
                bboxes_preds_per_image = bbox_preds[batch_idx]
                _, fg_mask, _, matched_gt_inds, _ = self.get_assignments(  # noqa
                    batch_idx, num_gt, gt_bboxes_per_image, gt_classes,
                    bboxes_preds_per_image, expanded_strides, x_shifts,
                    y_shifts, cls_preds, obj_preds,
                )

            img = img.cpu().numpy().copy()  # copy is crucial here
            coords = torch.stack([
                ((x_shifts + 0.5) * expanded_strides).flatten()[fg_mask],
                ((y_shifts + 0.5) * expanded_strides).flatten()[fg_mask],
            ], 1)

            xyxy_boxes = cxcywh2xyxy(gt_bboxes_per_image)
            save_name = save_prefix + str(batch_idx) + ".png"
            img = visualize_assign(img, xyxy_boxes, coords, matched_gt_inds, save_name)
            logger.info(f"save img to {save_name}")
......@@ -46,3 +46,7 @@ class YOLOX(nn.Module):
        outputs = self.head(fpn_outs)
        return outputs

    def visualize(self, x, targets, save_prefix="assign_vis_"):
        fpn_outs = self.backbone(x)
        self.head.visualize_assign_result(fpn_outs, targets, x, save_prefix)
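A hypothetical way to call this wrapper directly, mirroring how `AssignVisualizer.train_one_iter` uses it; the exp name, tensor shapes, and label values below are assumptions for illustration:

```python
import torch
from yolox.exp import get_exp

# hypothetical setup; "yolox-s" and all shapes are illustrative assumptions
exp = get_exp(None, "yolox-s")
model = exp.get_model()

inps = torch.randn(2, 3, 640, 640)  # batch of 2 images
targets = torch.zeros(2, 50, 5)     # zero-padded labels: [class, cx, cy, w, h]
targets[:, 0, :] = torch.tensor([0.0, 320.0, 320.0, 100.0, 100.0])  # one gt per image
model.visualize(inps, targets, save_prefix="assign_vis_0_")  # writes assign_vis_0_{0,1}.png
```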
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
from .allreduce_norm import *
......
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import numpy as np
......@@ -15,6 +14,7 @@ __all__ = [
"adjust_box_anns",
"xyxy2xywh",
"xyxy2cxcywh",
"cxcywh2xyxy",
]
......@@ -133,3 +133,11 @@ def xyxy2cxcywh(bboxes):
    bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
    bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
    return bboxes


def cxcywh2xyxy(bboxes):
    # in-place conversion from (cx, cy, w, h) to (x1, y1, x2, y2)
    bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] * 0.5  # x1 = cx - w/2
    bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] * 0.5  # y1 = cy - h/2
    bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]  # x2 = x1 + w (col 2 still holds w here)
    bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]  # y2 = y1 + h
    return bboxes
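A quick sanity check of the new helper (illustrative values; note that the conversion mutates its input in place):

```python
import numpy as np
from yolox.utils import cxcywh2xyxy  # exported by this commit

boxes = np.array([[50.0, 50.0, 10.0, 10.0]])  # one box: cx=50, cy=50, w=10, h=10
print(cxcywh2xyxy(boxes.copy()))  # -> [[45. 45. 55. 55.]]; .copy() keeps `boxes` intact
```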
......@@ -2,10 +2,51 @@
# Copyright (c) Megvii Inc. All rights reserved.
import os
import random
import cv2
import numpy as np
__all__ = ["mkdir", "nms", "multiclass_nms", "demo_postprocess"]
__all__ = [
"mkdir", "nms", "multiclass_nms", "demo_postprocess", "random_color", "visualize_assign"
]
def random_color():
    return random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)


def visualize_assign(img, boxes, coords, match_results, save_name=None) -> np.ndarray:
    """visualize label assign result.

    Args:
        img: img to visualize
        boxes: gt boxes in xyxy format
        coords: coords of matched anchors
        match_results: match results of each gt box and coord.
        save_name: name of save image, if None, image will not be saved. Default: None.
    """
    for box_id, box in enumerate(boxes):
        x1, y1, x2, y2 = box
        color = random_color()
        assign_coords = coords[match_results == box_id]
        if assign_coords.numel() == 0:
            # unmatched boxes are red
            color = (0, 0, 255)
            cv2.putText(
                img, "unmatched", (int(x1), int(y1) - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 1
            )
        else:
            for coord in assign_coords:
                # draw assigned anchor
                cv2.circle(img, (int(coord[0]), int(coord[1])), 3, color, -1)
        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

    if save_name is not None:
        cv2.imwrite(save_name, img)

    return img
def mkdir(path):
......
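A hypothetical end-to-end call of `visualize_assign`; the blank image, gt box, anchor centers, and match results below are invented for illustration:

```python
import numpy as np
import torch
from yolox.utils import visualize_assign  # exported by this commit

img = np.zeros((640, 640, 3), dtype=np.uint8)             # blank canvas
boxes = torch.tensor([[100.0, 100.0, 200.0, 200.0]])      # one gt box in xyxy
coords = torch.tensor([[150.0, 150.0], [160.0, 140.0]])   # matched anchor centers
match_results = torch.tensor([0, 0])                      # both anchors assigned to box 0
vis = visualize_assign(img, boxes, coords, match_results, save_name="demo.png")
```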