From 07ba056d6a02db6e146514b7234e834157a80265 Mon Sep 17 00:00:00 2001 From: Zhi Tian Date: Wed, 20 Nov 2019 19:25:02 +1030 Subject: [PATCH] onnx done. --- fcos_core/modeling/backbone/fpn.py | 6 +- fcos_core/modeling/rpn/fcos/inference.py | 9 +- {tools => onnx}/export_model_to_onnx.py | 7 +- onnx/test_fcos_onnx_model.py | 322 +++++++++++++++++++++++ tools/test_fcos_onnx_model.py | 145 ---------- 5 files changed, 331 insertions(+), 158 deletions(-) rename {tools => onnx}/export_model_to_onnx.py (91%) create mode 100644 onnx/test_fcos_onnx_model.py delete mode 100644 tools/test_fcos_onnx_model.py diff --git a/fcos_core/modeling/backbone/fpn.py b/fcos_core/modeling/backbone/fpn.py index 381efcd..ff46182 100644 --- a/fcos_core/modeling/backbone/fpn.py +++ b/fcos_core/modeling/backbone/fpn.py @@ -58,9 +58,9 @@ class FPN(nn.Module): continue # inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest") inner_lateral = getattr(self, inner_block)(feature) - # TODO use size instead of scale to make it robust to different sizes - inner_top_down = F.upsample( - last_inner, size=(int(inner_lateral.shape[-2]) + 1 - 1, int(inner_lateral.shape[-1])), mode='nearest' + inner_top_down = F.interpolate( + last_inner, size=(int(inner_lateral.shape[-2]), int(inner_lateral.shape[-1])), + mode='nearest' ) last_inner = inner_lateral + inner_top_down results.insert(0, getattr(self, layer_block)(last_inner)) diff --git a/fcos_core/modeling/rpn/fcos/inference.py b/fcos_core/modeling/rpn/fcos/inference.py index 5bc59e6..0679401 100644 --- a/fcos_core/modeling/rpn/fcos/inference.py +++ b/fcos_core/modeling/rpn/fcos/inference.py @@ -9,7 +9,6 @@ from fcos_core.structures.bounding_box import BoxList from fcos_core.structures.boxlist_ops import cat_boxlist from fcos_core.structures.boxlist_ops import boxlist_ml_nms from fcos_core.structures.boxlist_ops import remove_small_boxes -from torchvision.ops.boxes import batched_nms class FCOSPostProcessor(torch.nn.Module): @@ -148,13 +147,7 @@ class FCOSPostProcessor(torch.nn.Module): results = [] for i in range(num_images): # multiclass nms - keep = batched_nms( - boxlists[i].bbox, - boxlists[i].get_field("scores"), - boxlists[i].get_field("labels"), - self.nms_thresh - ) - result = boxlists[i][keep] + result = boxlist_ml_nms(boxlists[i], self.nms_thresh) number_of_detections = len(result) # Limit to max_per_image detections **over all classes** diff --git a/tools/export_model_to_onnx.py b/onnx/export_model_to_onnx.py similarity index 91% rename from tools/export_model_to_onnx.py rename to onnx/export_model_to_onnx.py index 4957af0..e6e8e27 100644 --- a/tools/export_model_to_onnx.py +++ b/onnx/export_model_to_onnx.py @@ -1,6 +1,9 @@ """ +Please make sure you are using pytorch >= 1.4.0. A working example to export the R-50 based FCOS model: -python tools/export_model_to_onnx.py --config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml MODEL.WEIGHT FCOS_imprv_R_50_FPN_1x.pth +python onnx/export_model_to_onnx.py \ + --config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml \ + MODEL.WEIGHT FCOS_imprv_R_50_FPN_1x.pth """ from fcos_core.utils.env import setup_environment # noqa F401 isort:skip @@ -25,7 +28,7 @@ def main(): parser = argparse.ArgumentParser(description="Export model to the onnx format") parser.add_argument( "--config-file", - default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml", + default="configs/fcos/fcos_imprv_R_50_FPN_1x.yaml", metavar="FILE", help="path to config file", ) diff --git a/onnx/test_fcos_onnx_model.py b/onnx/test_fcos_onnx_model.py new file mode 100644 index 0000000..58e17ca --- /dev/null +++ b/onnx/test_fcos_onnx_model.py @@ -0,0 +1,322 @@ +""" +An example: +Please make sure you are using torchvision >= 0.5.0. + +wget https://cloudstor.aarnet.edu.au/plus/s/38fQAdi2HBkn274/download -O fcos_imprv_R_50_FPN_1x.onnx +python onnx/test_fcos_onnx_model.py \ + --onnx-model fcos_imprv_R_50_FPN_1x.onnx \ + --config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml \ + TEST.IMS_PER_BATCH 1 \ + INPUT.MIN_SIZE_TEST 800 + +If you encounter an out of memory error, please try to reduce INPUT.MIN_SIZE_TEST. +""" +from fcos_core.utils.env import setup_environment # noqa F401 isort:skip + +import argparse +import os + +import torch +from torch import nn +import onnx +from fcos_core.config import cfg +from fcos_core.data import make_data_loader +from fcos_core.engine.inference import inference +from fcos_core.utils.collect_env import collect_env_info +from fcos_core.utils.comm import synchronize, get_rank +from fcos_core.utils.logger import setup_logger +from fcos_core.utils.miscellaneous import mkdir +from fcos_core.modeling.rpn.fcos.inference import make_fcos_postprocessor +import caffe2.python.onnx.backend as backend +import numpy as np +from fcos_core.structures.bounding_box import BoxList +from fcos_core.structures.boxlist_ops import cat_boxlist +from fcos_core.structures.boxlist_ops import remove_small_boxes +from torchvision.ops.boxes import batched_nms + + +class FCOSPostProcessor(torch.nn.Module): + """ + Performs post-processing on the outputs of the RetinaNet boxes. + This is only used in the testing. + """ + def __init__( + self, + pre_nms_thresh, + pre_nms_top_n, + nms_thresh, + fpn_post_nms_top_n, + min_size, + num_classes + ): + """ + Arguments: + pre_nms_thresh (float) + pre_nms_top_n (int) + nms_thresh (float) + fpn_post_nms_top_n (int) + min_size (int) + num_classes (int) + box_coder (BoxCoder) + """ + super(FCOSPostProcessor, self).__init__() + self.pre_nms_thresh = pre_nms_thresh + self.pre_nms_top_n = pre_nms_top_n + self.nms_thresh = nms_thresh + self.fpn_post_nms_top_n = fpn_post_nms_top_n + self.min_size = min_size + self.num_classes = num_classes + + def forward_for_single_feature_map( + self, locations, box_cls, + box_regression, centerness, + image_sizes): + """ + Arguments: + anchors: list[BoxList] + box_cls: tensor of size N, A * C, H, W + box_regression: tensor of size N, A * 4, H, W + """ + N, C, H, W = box_cls.shape + + # put in the same format as locations + box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1) + box_cls = box_cls.reshape(N, -1, C).sigmoid() + box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1) + box_regression = box_regression.reshape(N, -1, 4) + centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1) + centerness = centerness.reshape(N, -1).sigmoid() + + candidate_inds = box_cls > self.pre_nms_thresh + pre_nms_top_n = candidate_inds.view(N, -1).sum(1) + pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) + + # multiply the classification scores with centerness scores + box_cls = box_cls * centerness[:, :, None] + + results = [] + for i in range(N): + per_box_cls = box_cls[i] + per_candidate_inds = candidate_inds[i] + per_box_cls = per_box_cls[per_candidate_inds] + + per_candidate_nonzeros = per_candidate_inds.nonzero() + per_box_loc = per_candidate_nonzeros[:, 0] + per_class = per_candidate_nonzeros[:, 1] + 1 + + per_box_regression = box_regression[i] + per_box_regression = per_box_regression[per_box_loc] + per_locations = locations[per_box_loc] + + per_pre_nms_top_n = pre_nms_top_n[i] + + if per_candidate_inds.sum().item() > per_pre_nms_top_n.item(): + per_box_cls, top_k_indices = \ + per_box_cls.topk(per_pre_nms_top_n, sorted=False) + per_class = per_class[top_k_indices] + per_box_regression = per_box_regression[top_k_indices] + per_locations = per_locations[top_k_indices] + + detections = torch.stack([ + per_locations[:, 0] - per_box_regression[:, 0], + per_locations[:, 1] - per_box_regression[:, 1], + per_locations[:, 0] + per_box_regression[:, 2], + per_locations[:, 1] + per_box_regression[:, 3], + ], dim=1) + + h, w = image_sizes[i] + boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy") + boxlist.add_field("labels", per_class) + boxlist.add_field("scores", torch.sqrt(per_box_cls)) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = remove_small_boxes(boxlist, self.min_size) + results.append(boxlist) + + return results + + def forward(self, locations, box_cls, box_regression, centerness, image_sizes): + """ + Arguments: + anchors: list[list[BoxList]] + box_cls: list[tensor] + box_regression: list[tensor] + image_sizes: list[(h, w)] + Returns: + boxlists (list[BoxList]): the post-processed anchors, after + applying box decoding and NMS + """ + sampled_boxes = [] + for _, (l, o, b, c) in enumerate(zip(locations, box_cls, box_regression, centerness)): + sampled_boxes.append( + self.forward_for_single_feature_map( + l, o, b, c, image_sizes + ) + ) + + boxlists = list(zip(*sampled_boxes)) + boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] + boxlists = self.select_over_all_levels(boxlists) + + return boxlists + + # TODO very similar to filter_results from PostProcessor + # but filter_results is per image + # TODO Yang: solve this issue in the future. No good solution + # right now. + def select_over_all_levels(self, boxlists): + num_images = len(boxlists) + results = [] + for i in range(num_images): + # multiclass nms + keep = batched_nms( + boxlists[i].bbox, + boxlists[i].get_field("scores"), + boxlists[i].get_field("labels"), + self.nms_thresh + ) + result = boxlists[i][keep] + number_of_detections = len(result) + + # Limit to max_per_image detections **over all classes** + if number_of_detections > self.fpn_post_nms_top_n > 0: + cls_scores = result.get_field("scores") + image_thresh, _ = torch.kthvalue( + cls_scores.cpu(), + number_of_detections - self.fpn_post_nms_top_n + 1 + ) + keep = cls_scores >= image_thresh.item() + keep = torch.nonzero(keep).squeeze(1) + result = result[keep] + results.append(result) + return results + + +class ONNX_FCOS(nn.Module): + def __init__(self, onnx_model_path, cfg): + super(ONNX_FCOS, self).__init__() + self.onnx_model = backend.prepare( + onnx.load(onnx_model_path), + device=cfg.MODEL.DEVICE.upper() + ) + # Note that we still use PyTorch for postprocessing + self.postprocessing = FCOSPostProcessor( + pre_nms_thresh=cfg.MODEL.FCOS.INFERENCE_TH, + pre_nms_top_n=cfg.MODEL.FCOS.PRE_NMS_TOP_N, + nms_thresh=cfg.MODEL.FCOS.NMS_TH, + fpn_post_nms_top_n=cfg.TEST.DETECTIONS_PER_IMG, + min_size=0, + num_classes=cfg.MODEL.FCOS.NUM_CLASSES + ) + self.cfg = cfg + self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES + + def forward(self, images): + outputs = self.onnx_model.run(images.tensors.cpu().numpy()) + outputs = [torch.from_numpy(o).to(self.cfg.MODEL.DEVICE) for o in outputs] + num_outputs = len(outputs) // 3 + logits = outputs[:num_outputs] + bbox_reg = outputs[num_outputs:2 * num_outputs] + centerness = outputs[2 * num_outputs:] + + locations = self.compute_locations(logits) + boxes = self.postprocessing(locations, logits, bbox_reg, centerness, images.image_sizes) + return boxes + + def compute_locations(self, features): + locations = [] + for level, feature in enumerate(features): + h, w = feature.size()[-2:] + locations_per_level = self.compute_locations_per_level( + h, w, self.fpn_strides[level], + feature.device + ) + locations.append(locations_per_level) + return locations + + def compute_locations_per_level(self, h, w, stride, device): + shifts_x = torch.arange( + 0, w * stride, step=stride, + dtype=torch.float32, device=device + ) + shifts_y = torch.arange( + 0, h * stride, step=stride, + dtype=torch.float32, device=device + ) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 + return locations + + +def main(): + parser = argparse.ArgumentParser(description="Test onnx models of FCOS") + parser.add_argument( + "--config-file", + default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml", + metavar="FILE", + help="path to config file", + ) + parser.add_argument( + "--onnx-model", + default="fcos_imprv_R_50_FPN_1x.onnx", + metavar="FILE", + help="path to the onnx model", + ) + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + + args = parser.parse_args() + + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + + # The onnx model can only be used with DATALOADER.NUM_WORKERS = 0 + cfg.DATALOADER.NUM_WORKERS = 0 + + cfg.freeze() + + save_dir = "" + logger = setup_logger("fcos_core", save_dir, get_rank()) + logger.info(cfg) + + logger.info("Collecting env info (might take some time)") + logger.info("\n" + collect_env_info()) + + model = ONNX_FCOS(args.onnx_model, cfg) + model.to(cfg.MODEL.DEVICE) + + iou_types = ("bbox",) + if cfg.MODEL.MASK_ON: + iou_types = iou_types + ("segm",) + if cfg.MODEL.KEYPOINT_ON: + iou_types = iou_types + ("keypoints",) + output_folders = [None] * len(cfg.DATASETS.TEST) + dataset_names = cfg.DATASETS.TEST + if cfg.OUTPUT_DIR: + for idx, dataset_name in enumerate(dataset_names): + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name) + mkdir(output_folder) + output_folders[idx] = output_folder + data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=False) + for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val): + inference( + model, + data_loader_val, + dataset_name=dataset_name, + iou_types=iou_types, + box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY, + device=cfg.MODEL.DEVICE, + expected_results=cfg.TEST.EXPECTED_RESULTS, + expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, + output_folder=output_folder, + ) + synchronize() + + +if __name__ == "__main__": + main() diff --git a/tools/test_fcos_onnx_model.py b/tools/test_fcos_onnx_model.py deleted file mode 100644 index 909c045..0000000 --- a/tools/test_fcos_onnx_model.py +++ /dev/null @@ -1,145 +0,0 @@ -""" -An example: -python tools/test_fcos_onnx_model.py --onnx-model fcos_imprv_R_50_FPN_1x.onnx --config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml TEST.IMS_PER_BATCH 1 DATALOADER.NUM_WORKERS 0 - -""" -from fcos_core.utils.env import setup_environment # noqa F401 isort:skip - -import argparse -import os - -import torch -from torch import nn -import onnx -from fcos_core.config import cfg -from fcos_core.data import make_data_loader -from fcos_core.engine.inference import inference -from fcos_core.modeling.detector import build_detection_model -from fcos_core.utils.checkpoint import DetectronCheckpointer -from fcos_core.utils.collect_env import collect_env_info -from fcos_core.utils.comm import synchronize, get_rank -from fcos_core.utils.logger import setup_logger -from fcos_core.utils.miscellaneous import mkdir -from fcos_core.modeling.rpn.fcos.inference import make_fcos_postprocessor -import caffe2.python.onnx.backend as backend -import numpy as np -from fcos_core.structures.image_list import to_image_list - - -class ONNX_FCOS(nn.Module): - def __init__(self, onnx_model_path, cfg): - super(ONNX_FCOS, self).__init__() - self.onnx_model = backend.prepare( - onnx.load(onnx_model_path), - device=cfg.MODEL.DEVICE.upper() - ) - self.postprocessing = make_fcos_postprocessor(cfg) - self.cfg = cfg - self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES - - def forward(self, images): - outputs = self.onnx_model.run(images.tensors.cpu().numpy()) - outputs = [torch.from_numpy(o).to(self.cfg.MODEL.DEVICE) for o in outputs] - num_outputs = len(outputs) // 3 - logits = outputs[:num_outputs] - bbox_reg = outputs[num_outputs:2 * num_outputs] - centerness = outputs[2 * num_outputs:] - - locations = self.compute_locations(logits) - boxes = self.postprocessing(locations, logits, bbox_reg, centerness, images.image_sizes) - return boxes - - def compute_locations(self, features): - locations = [] - for level, feature in enumerate(features): - h, w = feature.size()[-2:] - locations_per_level = self.compute_locations_per_level( - h, w, self.fpn_strides[level], - feature.device - ) - locations.append(locations_per_level) - return locations - - def compute_locations_per_level(self, h, w, stride, device): - shifts_x = torch.arange( - 0, w * stride, step=stride, - dtype=torch.float32, device=device - ) - shifts_y = torch.arange( - 0, h * stride, step=stride, - dtype=torch.float32, device=device - ) - shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) - shift_x = shift_x.reshape(-1) - shift_y = shift_y.reshape(-1) - locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 - return locations - - -def main(): - parser = argparse.ArgumentParser(description="Test onnx models of FCOS") - parser.add_argument( - "--config-file", - default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml", - metavar="FILE", - help="path to config file", - ) - parser.add_argument( - "--onnx-model", - default="fcos_imprv_R_50_FPN_1x.onnx", - metavar="FILE", - help="path to the onnx model", - ) - parser.add_argument( - "opts", - help="Modify config options using the command-line", - default=None, - nargs=argparse.REMAINDER, - ) - - args = parser.parse_args() - - cfg.merge_from_file(args.config_file) - cfg.merge_from_list(args.opts) - cfg.freeze() - - save_dir = "" - logger = setup_logger("fcos_core", save_dir, get_rank()) - logger.info(cfg) - - logger.info("Collecting env info (might take some time)") - logger.info("\n" + collect_env_info()) - - model = ONNX_FCOS(args.onnx_model, cfg) - model.to(cfg.MODEL.DEVICE) - - iou_types = ("bbox",) - if cfg.MODEL.MASK_ON: - iou_types = iou_types + ("segm",) - if cfg.MODEL.KEYPOINT_ON: - iou_types = iou_types + ("keypoints",) - output_folders = [None] * len(cfg.DATASETS.TEST) - dataset_names = cfg.DATASETS.TEST - if cfg.OUTPUT_DIR: - for idx, dataset_name in enumerate(dataset_names): - output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name) - mkdir(output_folder) - output_folders[idx] = output_folder - data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=False) - for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val): - inference( - model, - data_loader_val, - dataset_name=dataset_name, - iou_types=iou_types, - box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY, - device=cfg.MODEL.DEVICE, - expected_results=cfg.TEST.EXPECTED_RESULTS, - expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, - output_folder=output_folder, - ) - synchronize() - - -if __name__ == "__main__": - main() -- GitLab