提交 07ba056d 编写于 作者: Z Zhi Tian

onnx done.

上级 f9f4817a
......@@ -58,9 +58,9 @@ class FPN(nn.Module):
continue
# inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest")
inner_lateral = getattr(self, inner_block)(feature)
# TODO use size instead of scale to make it robust to different sizes
inner_top_down = F.upsample(
last_inner, size=(int(inner_lateral.shape[-2]) + 1 - 1, int(inner_lateral.shape[-1])), mode='nearest'
inner_top_down = F.interpolate(
last_inner, size=(int(inner_lateral.shape[-2]), int(inner_lateral.shape[-1])),
mode='nearest'
)
last_inner = inner_lateral + inner_top_down
results.insert(0, getattr(self, layer_block)(last_inner))
......
......@@ -9,7 +9,6 @@ from fcos_core.structures.bounding_box import BoxList
from fcos_core.structures.boxlist_ops import cat_boxlist
from fcos_core.structures.boxlist_ops import boxlist_ml_nms
from fcos_core.structures.boxlist_ops import remove_small_boxes
from torchvision.ops.boxes import batched_nms
class FCOSPostProcessor(torch.nn.Module):
......@@ -148,13 +147,7 @@ class FCOSPostProcessor(torch.nn.Module):
results = []
for i in range(num_images):
# multiclass nms
keep = batched_nms(
boxlists[i].bbox,
boxlists[i].get_field("scores"),
boxlists[i].get_field("labels"),
self.nms_thresh
)
result = boxlists[i][keep]
result = boxlist_ml_nms(boxlists[i], self.nms_thresh)
number_of_detections = len(result)
# Limit to max_per_image detections **over all classes**
......
"""
Please make sure you are using pytorch >= 1.4.0.
A working example to export the R-50 based FCOS model:
python tools/export_model_to_onnx.py --config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml MODEL.WEIGHT FCOS_imprv_R_50_FPN_1x.pth
python onnx/export_model_to_onnx.py \
--config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml \
MODEL.WEIGHT FCOS_imprv_R_50_FPN_1x.pth
"""
from fcos_core.utils.env import setup_environment # noqa F401 isort:skip
......@@ -25,7 +28,7 @@ def main():
parser = argparse.ArgumentParser(description="Export model to the onnx format")
parser.add_argument(
"--config-file",
default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
default="configs/fcos/fcos_imprv_R_50_FPN_1x.yaml",
metavar="FILE",
help="path to config file",
)
......
"""
An example:
python tools/test_fcos_onnx_model.py --onnx-model fcos_imprv_R_50_FPN_1x.onnx --config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml TEST.IMS_PER_BATCH 1 DATALOADER.NUM_WORKERS 0
Please make sure you are using torchvision >= 0.5.0.
wget https://cloudstor.aarnet.edu.au/plus/s/38fQAdi2HBkn274/download -O fcos_imprv_R_50_FPN_1x.onnx
python onnx/test_fcos_onnx_model.py \
--onnx-model fcos_imprv_R_50_FPN_1x.onnx \
--config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml \
TEST.IMS_PER_BATCH 1 \
INPUT.MIN_SIZE_TEST 800
If you encounter an out of memory error, please try to reduce INPUT.MIN_SIZE_TEST.
"""
from fcos_core.utils.env import setup_environment # noqa F401 isort:skip
......@@ -14,8 +22,6 @@ import onnx
from fcos_core.config import cfg
from fcos_core.data import make_data_loader
from fcos_core.engine.inference import inference
from fcos_core.modeling.detector import build_detection_model
from fcos_core.utils.checkpoint import DetectronCheckpointer
from fcos_core.utils.collect_env import collect_env_info
from fcos_core.utils.comm import synchronize, get_rank
from fcos_core.utils.logger import setup_logger
......@@ -23,7 +29,166 @@ from fcos_core.utils.miscellaneous import mkdir
from fcos_core.modeling.rpn.fcos.inference import make_fcos_postprocessor
import caffe2.python.onnx.backend as backend
import numpy as np
from fcos_core.structures.image_list import to_image_list
from fcos_core.structures.bounding_box import BoxList
from fcos_core.structures.boxlist_ops import cat_boxlist
from fcos_core.structures.boxlist_ops import remove_small_boxes
from torchvision.ops.boxes import batched_nms
class FCOSPostProcessor(torch.nn.Module):
"""
Performs post-processing on the outputs of the RetinaNet boxes.
This is only used in the testing.
"""
def __init__(
self,
pre_nms_thresh,
pre_nms_top_n,
nms_thresh,
fpn_post_nms_top_n,
min_size,
num_classes
):
"""
Arguments:
pre_nms_thresh (float)
pre_nms_top_n (int)
nms_thresh (float)
fpn_post_nms_top_n (int)
min_size (int)
num_classes (int)
box_coder (BoxCoder)
"""
super(FCOSPostProcessor, self).__init__()
self.pre_nms_thresh = pre_nms_thresh
self.pre_nms_top_n = pre_nms_top_n
self.nms_thresh = nms_thresh
self.fpn_post_nms_top_n = fpn_post_nms_top_n
self.min_size = min_size
self.num_classes = num_classes
def forward_for_single_feature_map(
self, locations, box_cls,
box_regression, centerness,
image_sizes):
"""
Arguments:
anchors: list[BoxList]
box_cls: tensor of size N, A * C, H, W
box_regression: tensor of size N, A * 4, H, W
"""
N, C, H, W = box_cls.shape
# put in the same format as locations
box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1)
box_cls = box_cls.reshape(N, -1, C).sigmoid()
box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1)
box_regression = box_regression.reshape(N, -1, 4)
centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1)
centerness = centerness.reshape(N, -1).sigmoid()
candidate_inds = box_cls > self.pre_nms_thresh
pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)
# multiply the classification scores with centerness scores
box_cls = box_cls * centerness[:, :, None]
results = []
for i in range(N):
per_box_cls = box_cls[i]
per_candidate_inds = candidate_inds[i]
per_box_cls = per_box_cls[per_candidate_inds]
per_candidate_nonzeros = per_candidate_inds.nonzero()
per_box_loc = per_candidate_nonzeros[:, 0]
per_class = per_candidate_nonzeros[:, 1] + 1
per_box_regression = box_regression[i]
per_box_regression = per_box_regression[per_box_loc]
per_locations = locations[per_box_loc]
per_pre_nms_top_n = pre_nms_top_n[i]
if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
per_box_cls, top_k_indices = \
per_box_cls.topk(per_pre_nms_top_n, sorted=False)
per_class = per_class[top_k_indices]
per_box_regression = per_box_regression[top_k_indices]
per_locations = per_locations[top_k_indices]
detections = torch.stack([
per_locations[:, 0] - per_box_regression[:, 0],
per_locations[:, 1] - per_box_regression[:, 1],
per_locations[:, 0] + per_box_regression[:, 2],
per_locations[:, 1] + per_box_regression[:, 3],
], dim=1)
h, w = image_sizes[i]
boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy")
boxlist.add_field("labels", per_class)
boxlist.add_field("scores", torch.sqrt(per_box_cls))
boxlist = boxlist.clip_to_image(remove_empty=False)
boxlist = remove_small_boxes(boxlist, self.min_size)
results.append(boxlist)
return results
def forward(self, locations, box_cls, box_regression, centerness, image_sizes):
"""
Arguments:
anchors: list[list[BoxList]]
box_cls: list[tensor]
box_regression: list[tensor]
image_sizes: list[(h, w)]
Returns:
boxlists (list[BoxList]): the post-processed anchors, after
applying box decoding and NMS
"""
sampled_boxes = []
for _, (l, o, b, c) in enumerate(zip(locations, box_cls, box_regression, centerness)):
sampled_boxes.append(
self.forward_for_single_feature_map(
l, o, b, c, image_sizes
)
)
boxlists = list(zip(*sampled_boxes))
boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
boxlists = self.select_over_all_levels(boxlists)
return boxlists
# TODO very similar to filter_results from PostProcessor
# but filter_results is per image
# TODO Yang: solve this issue in the future. No good solution
# right now.
def select_over_all_levels(self, boxlists):
num_images = len(boxlists)
results = []
for i in range(num_images):
# multiclass nms
keep = batched_nms(
boxlists[i].bbox,
boxlists[i].get_field("scores"),
boxlists[i].get_field("labels"),
self.nms_thresh
)
result = boxlists[i][keep]
number_of_detections = len(result)
# Limit to max_per_image detections **over all classes**
if number_of_detections > self.fpn_post_nms_top_n > 0:
cls_scores = result.get_field("scores")
image_thresh, _ = torch.kthvalue(
cls_scores.cpu(),
number_of_detections - self.fpn_post_nms_top_n + 1
)
keep = cls_scores >= image_thresh.item()
keep = torch.nonzero(keep).squeeze(1)
result = result[keep]
results.append(result)
return results
class ONNX_FCOS(nn.Module):
......@@ -33,7 +198,15 @@ class ONNX_FCOS(nn.Module):
onnx.load(onnx_model_path),
device=cfg.MODEL.DEVICE.upper()
)
self.postprocessing = make_fcos_postprocessor(cfg)
# Note that we still use PyTorch for postprocessing
self.postprocessing = FCOSPostProcessor(
pre_nms_thresh=cfg.MODEL.FCOS.INFERENCE_TH,
pre_nms_top_n=cfg.MODEL.FCOS.PRE_NMS_TOP_N,
nms_thresh=cfg.MODEL.FCOS.NMS_TH,
fpn_post_nms_top_n=cfg.TEST.DETECTIONS_PER_IMG,
min_size=0,
num_classes=cfg.MODEL.FCOS.NUM_CLASSES
)
self.cfg = cfg
self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
......@@ -101,6 +274,10 @@ def main():
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
# The onnx model can only be used with DATALOADER.NUM_WORKERS = 0
cfg.DATALOADER.NUM_WORKERS = 0
cfg.freeze()
save_dir = ""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册