提交 6591c89f 编写于 作者: Z Zhi Tian

add multi-scale testing

上级 e308e161
......@@ -26,3 +26,10 @@ SOLVER:
MAX_ITER: 90000
IMS_PER_BATCH: 16
WARMUP_METHOD: "constant"
TEST:
BBOX_AUG:
ENABLED: False
H_FLIP: True
SCALES: (400, 500, 600, 700, 900, 1000, 1100, 1200)
MAX_SIZE: 2000
SCALE_H_FLIP: True
......@@ -455,6 +455,28 @@ _C.TEST.IMS_PER_BATCH = 8
_C.TEST.DETECTIONS_PER_IMG = 100
# ---------------------------------------------------------------------------- #
# Test-time augmentations for bounding box detection
# See configs/test_time_aug/e2e_mask_rcnn_R-50-FPN_1x.yaml for an example
# ---------------------------------------------------------------------------- #
_C.TEST.BBOX_AUG = CN()
# Enable test-time augmentation for bounding box detection if True
_C.TEST.BBOX_AUG.ENABLED = False
# Horizontal flip at the original scale (id transform)
_C.TEST.BBOX_AUG.H_FLIP = False
# Each scale is the pixel size of an image's shortest side
_C.TEST.BBOX_AUG.SCALES = ()
# Max pixel size of the longer side
_C.TEST.BBOX_AUG.MAX_SIZE = 4000
# Horizontal flip at each scale
_C.TEST.BBOX_AUG.SCALE_H_FLIP = False
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
......
......@@ -10,7 +10,7 @@ from fcos_core.utils.imports import import_file
from . import datasets as D
from . import samplers
from .collate_batch import BatchCollator
from .collate_batch import BatchCollator, BBoxAugCollator
from .transforms import build_transforms
......@@ -150,7 +150,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
DatasetCatalog = paths_catalog.DatasetCatalog
dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST
transforms = build_transforms(cfg, is_train)
# If bbox aug is enabled in testing, simply set transforms to None and we will apply transforms later
transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train)
datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)
data_loaders = []
......@@ -159,7 +160,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
batch_sampler = make_batch_data_sampler(
dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter
)
collator = BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
collator = BBoxAugCollator() if not is_train and cfg.TEST.BBOX_AUG.ENABLED else \
BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
num_workers = cfg.DATALOADER.NUM_WORKERS
data_loader = torch.utils.data.DataLoader(
dataset,
......
......@@ -18,3 +18,14 @@ class BatchCollator(object):
targets = transposed_batch[1]
img_ids = transposed_batch[2]
return images, targets, img_ids
class BBoxAugCollator(object):
"""
From a list of samples from the dataset,
returns the images and targets.
Images should be converted to batched images in `im_detect_bbox_aug`
"""
def __call__(self, batch):
return list(zip(*batch))
......@@ -54,9 +54,11 @@ class Resize(object):
return (oh, ow)
def __call__(self, image, target):
def __call__(self, image, target=None):
size = self.get_size(image.size)
image = F.resize(image, size)
if target is None:
return image
target = target.resize(image.size)
return image, target
......@@ -83,8 +85,10 @@ class Normalize(object):
self.std = std
self.to_bgr255 = to_bgr255
def __call__(self, image, target):
def __call__(self, image, target=None):
if self.to_bgr255:
image = image[[2, 1, 0]] * 255
image = F.normalize(image, mean=self.mean, std=self.std)
if target is None:
return image
return image, target
import torch
import torchvision.transforms as TT
from fcos_core.config import cfg
from fcos_core.data import transforms as T
from fcos_core.structures.image_list import to_image_list
from fcos_core.structures.bounding_box import BoxList
from fcos_core.modeling.rpn.fcos.inference import make_fcos_postprocessor
def im_detect_bbox_aug(model, images, device):
# Collect detections computed under different transformations
boxlists_ts = []
for _ in range(len(images)):
boxlists_ts.append([])
def add_preds_t(boxlists_t):
for i, boxlist_t in enumerate(boxlists_t):
if len(boxlists_ts[i]) == 0:
# The first one is identity transform, no need to resize the boxlist
boxlists_ts[i].append(boxlist_t)
else:
# Resize the boxlist as the first one
boxlists_ts[i].append(boxlist_t.resize(boxlists_ts[i][0].size))
# Compute detections for the original image (identity transform)
boxlists_i = im_detect_bbox(
model, images, cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, device
)
add_preds_t(boxlists_i)
# Perform detection on the horizontally flipped image
if cfg.TEST.BBOX_AUG.H_FLIP:
boxlists_hf = im_detect_bbox_hflip(
model, images, cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, device
)
add_preds_t(boxlists_hf)
# Compute detections at different scales
for scale in cfg.TEST.BBOX_AUG.SCALES:
max_size = cfg.TEST.BBOX_AUG.MAX_SIZE
boxlists_scl = im_detect_bbox_scale(
model, images, scale, max_size, device
)
add_preds_t(boxlists_scl)
if cfg.TEST.BBOX_AUG.SCALE_H_FLIP:
boxlists_scl_hf = im_detect_bbox_scale(
model, images, scale, max_size, device, hflip=True
)
add_preds_t(boxlists_scl_hf)
assert cfg.MODEL.FCOS_ON, "The multi-scale testing only supports FCOS detector"
# Merge boxlists detected by different bbox aug params
boxlists = []
for i, boxlist_ts in enumerate(boxlists_ts):
bbox = torch.cat([boxlist_t.bbox for boxlist_t in boxlist_ts])
scores = torch.cat([boxlist_t.get_field('scores') for boxlist_t in boxlist_ts])
labels = torch.cat([boxlist_t.get_field('labels') for boxlist_t in boxlist_ts])
boxlist = BoxList(bbox, boxlist_ts[0].size, boxlist_ts[0].mode)
boxlist.add_field('scores', scores)
boxlist.add_field('labels', labels)
boxlists.append(boxlist)
# Apply NMS and limit the final detections
post_processor = make_fcos_postprocessor(cfg)
results = post_processor.select_over_all_levels(boxlists)
return results
def im_detect_bbox(model, images, target_scale, target_max_size, device):
"""
Performs bbox detection on the original image.
"""
transform = TT.Compose([
T.Resize(target_scale, target_max_size),
TT.ToTensor(),
T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=cfg.INPUT.TO_BGR255
)
])
images = [transform(image) for image in images]
images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY)
return model(images.to(device))
def im_detect_bbox_hflip(model, images, target_scale, target_max_size, device):
"""
Performs bbox detection on the horizontally flipped image.
Function signature is the same as for im_detect_bbox.
"""
transform = TT.Compose([
T.Resize(target_scale, target_max_size),
TT.RandomHorizontalFlip(1.0),
TT.ToTensor(),
T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=cfg.INPUT.TO_BGR255
)
])
images = [transform(image) for image in images]
images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY)
boxlists = model(images.to(device))
# Invert the detections computed on the flipped image
boxlists_inv = [boxlist.transpose(0) for boxlist in boxlists]
return boxlists_inv
def im_detect_bbox_scale(model, images, target_scale, target_max_size, device, hflip=False):
"""
Computes bbox detections at the given scale.
Returns predictions in the scaled image space.
"""
if hflip:
boxlists_scl = im_detect_bbox_hflip(model, images, target_scale, target_max_size, device)
else:
boxlists_scl = im_detect_bbox(model, images, target_scale, target_max_size, device)
return boxlists_scl
......@@ -6,11 +6,13 @@ import os
import torch
from tqdm import tqdm
from fcos_core.config import cfg
from fcos_core.data.datasets.evaluation import evaluate
from ..utils.comm import is_main_process, get_world_size
from ..utils.comm import all_gather
from ..utils.comm import synchronize
from ..utils.timer import Timer, get_time_str
from .bbox_aug import im_detect_bbox_aug
def compute_on_dataset(model, data_loader, device, timer=None):
......@@ -19,11 +21,13 @@ def compute_on_dataset(model, data_loader, device, timer=None):
cpu_device = torch.device("cpu")
for _, batch in enumerate(tqdm(data_loader)):
images, targets, image_ids = batch
images = images.to(device)
with torch.no_grad():
if timer:
timer.tic()
output = model(images)
if cfg.TEST.BBOX_AUG.ENABLED:
output = im_detect_bbox_aug(model, images, device)
else:
output = model(images.to(device))
if timer:
torch.cuda.synchronize()
timer.toc()
......
......@@ -22,7 +22,8 @@ class PostProcessor(nn.Module):
nms=0.5,
detections_per_img=100,
box_coder=None,
cls_agnostic_bbox_reg=False
cls_agnostic_bbox_reg=False,
bbox_aug_enabled=False
):
"""
Arguments:
......@@ -39,6 +40,7 @@ class PostProcessor(nn.Module):
box_coder = BoxCoder(weights=(10., 10., 5., 5.))
self.box_coder = box_coder
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
self.bbox_aug_enabled = bbox_aug_enabled
def forward(self, x, boxes):
"""
......@@ -79,7 +81,8 @@ class PostProcessor(nn.Module):
):
boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
boxlist = boxlist.clip_to_image(remove_empty=False)
boxlist = self.filter_results(boxlist, num_classes)
if not self.bbox_aug_enabled: # If bbox aug is enabled, we will do it later
boxlist = self.filter_results(boxlist, num_classes)
results.append(boxlist)
return results
......@@ -156,12 +159,14 @@ def make_roi_box_post_processor(cfg):
nms_thresh = cfg.MODEL.ROI_HEADS.NMS
detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG
cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG
bbox_aug_enabled = cfg.TEST.BBOX_AUG.ENABLED
postprocessor = PostProcessor(
score_thresh,
nms_thresh,
detections_per_img,
box_coder,
cls_agnostic_bbox_reg
cls_agnostic_bbox_reg,
bbox_aug_enabled
)
return postprocessor
......@@ -24,6 +24,7 @@ class FCOSPostProcessor(torch.nn.Module):
fpn_post_nms_top_n,
min_size,
num_classes,
bbox_aug_enabled=False
):
"""
Arguments:
......@@ -42,6 +43,7 @@ class FCOSPostProcessor(torch.nn.Module):
self.fpn_post_nms_top_n = fpn_post_nms_top_n
self.min_size = min_size
self.num_classes = num_classes
self.bbox_aug_enabled = bbox_aug_enabled
def forward_for_single_feature_map(
self, locations, box_cls,
......@@ -131,7 +133,8 @@ class FCOSPostProcessor(torch.nn.Module):
boxlists = list(zip(*sampled_boxes))
boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
boxlists = self.select_over_all_levels(boxlists)
if not self.bbox_aug_enabled:
boxlists = self.select_over_all_levels(boxlists)
return boxlists
......@@ -190,6 +193,7 @@ def make_fcos_postprocessor(config):
pre_nms_top_n = config.MODEL.FCOS.PRE_NMS_TOP_N
nms_thresh = config.MODEL.FCOS.NMS_TH
fpn_post_nms_top_n = config.TEST.DETECTIONS_PER_IMG
bbox_aug_enabled = config.TEST.BBOX_AUG.ENABLED
box_selector = FCOSPostProcessor(
pre_nms_thresh=pre_nms_thresh,
......@@ -197,7 +201,8 @@ def make_fcos_postprocessor(config):
nms_thresh=nms_thresh,
fpn_post_nms_top_n=fpn_post_nms_top_n,
min_size=0,
num_classes=config.MODEL.FCOS.NUM_CLASSES
num_classes=config.MODEL.FCOS.NUM_CLASSES,
bbox_aug_enabled=bbox_aug_enabled
)
return box_selector
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册