提交 83849b4e 编写于 作者: W wat3rBro 提交者: Francisco Massa

Add registry for model builder functions (#153)

* adding registry to hook custom building blocks

* adding customizable rpn head

* support customizable c2 weight loading
上级 1276d20b
......@@ -134,6 +134,8 @@ _C.MODEL.RPN.MIN_SIZE = 0
# all FPN levels
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000
# Custom rpn head, empty to use default conv or separable conv
_C.MODEL.RPN.RPN_HEAD = "SingleConvRPNHead"
# ---------------------------------------------------------------------------- #
......
......@@ -3,16 +3,21 @@ from collections import OrderedDict
from torch import nn
from maskrcnn_benchmark.modeling import registry
from . import fpn as fpn_module
from . import resnet
@registry.BACKBONES.register("R-50-C4")
def build_resnet_backbone(cfg):
body = resnet.ResNet(cfg)
model = nn.Sequential(OrderedDict([("body", body)]))
return model
@registry.BACKBONES.register("R-50-FPN")
@registry.BACKBONES.register("R-101-FPN")
def build_resnet_fpn_backbone(cfg):
body = resnet.ResNet(cfg)
in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
......@@ -31,14 +36,9 @@ def build_resnet_fpn_backbone(cfg):
return model
_BACKBONES = {"resnet": build_resnet_backbone, "resnet-fpn": build_resnet_fpn_backbone}
def build_backbone(cfg):
assert cfg.MODEL.BACKBONE.CONV_BODY.startswith(
"R-"
), "Only ResNet and ResNeXt models are currently implemented"
# Models using FPN end with "-FPN"
if cfg.MODEL.BACKBONE.CONV_BODY.endswith("-FPN"):
return build_resnet_fpn_backbone(cfg)
return build_resnet_backbone(cfg)
assert cfg.MODEL.BACKBONE.CONV_BODY in registry.BACKBONES, \
"cfg.MODEL.BACKBONE.CONV_BODY: {} are not registered in registry".format(
cfg.MODEL.BACKBONE.CONV_BODY
)
return registry.BACKBONES[cfg.MODEL.BACKBONE.CONV_BODY](cfg)
......@@ -18,6 +18,7 @@ from torch import nn
from maskrcnn_benchmark.layers import FrozenBatchNorm2d
from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.utils.registry import Registry
# ResNet stage specification
......@@ -290,30 +291,15 @@ class StemWithFixedBatchNorm(nn.Module):
return x
_TRANSFORMATION_MODULES = {"BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm}
_TRANSFORMATION_MODULES = Registry({
"BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm
})
_STEM_MODULES = {"StemWithFixedBatchNorm": StemWithFixedBatchNorm}
_STEM_MODULES = Registry({"StemWithFixedBatchNorm": StemWithFixedBatchNorm})
_STAGE_SPECS = {
_STAGE_SPECS = Registry({
"R-50-C4": ResNet50StagesTo4,
"R-50-C5": ResNet50StagesTo5,
"R-50-FPN": ResNet50FPNStagesTo5,
"R-101-FPN": ResNet101FPNStagesTo5,
}
def register_transformation_module(module_name, module):
_register_generic(_TRANSFORMATION_MODULES, module_name, module)
def register_stem_module(module_name, module):
_register_generic(_STEM_MODULES, module_name, module)
def register_stage_spec(stage_spec_name, stage_spec):
_register_generic(_STAGE_SPECS, stage_spec_name, stage_spec)
def _register_generic(module_dict, module_name, module):
assert module_name not in module_dict
module_dict[module_name] = module
})
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from maskrcnn_benchmark.utils.registry import Registry
BACKBONES = Registry()
ROI_BOX_FEATURE_EXTRACTORS = Registry()
RPN_HEADS = Registry()
......@@ -2,10 +2,12 @@
from torch import nn
from torch.nn import functional as F
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.backbone import resnet
from maskrcnn_benchmark.modeling.poolers import Pooler
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("ResNet50Conv5ROIFeatureExtractor")
class ResNet50Conv5ROIFeatureExtractor(nn.Module):
def __init__(self, config):
super(ResNet50Conv5ROIFeatureExtractor, self).__init__()
......@@ -39,6 +41,7 @@ class ResNet50Conv5ROIFeatureExtractor(nn.Module):
return x
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPN2MLPFeatureExtractor")
class FPN2MLPFeatureExtractor(nn.Module):
"""
Heads for FPN for classification
......@@ -77,12 +80,8 @@ class FPN2MLPFeatureExtractor(nn.Module):
return x
_ROI_BOX_FEATURE_EXTRACTORS = {
"ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor,
"FPN2MLPFeatureExtractor": FPN2MLPFeatureExtractor,
}
def make_roi_box_feature_extractor(cfg):
func = _ROI_BOX_FEATURE_EXTRACTORS[cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR]
func = registry.ROI_BOX_FEATURE_EXTRACTORS[
cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR
]
return func(cfg)
......@@ -3,20 +3,23 @@ import torch
import torch.nn.functional as F
from torch import nn
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from .loss import make_rpn_loss_evaluator
from .anchor_generator import make_anchor_generator
from .inference import make_rpn_postprocessor
@registry.RPN_HEADS.register("SingleConvRPNHead")
class RPNHead(nn.Module):
"""
Adds a simple RPN Head with classification and regression heads
"""
def __init__(self, in_channels, num_anchors):
def __init__(self, cfg, in_channels, num_anchors):
"""
Arguments:
cfg : config
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
"""
......@@ -57,7 +60,10 @@ class RPNModule(torch.nn.Module):
anchor_generator = make_anchor_generator(cfg)
in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
head = RPNHead(in_channels, anchor_generator.num_anchors_per_location()[0])
rpn_head = registry.RPN_HEADS[cfg.MODEL.RPN.RPN_HEAD]
head = rpn_head(
cfg, in_channels, anchor_generator.num_anchors_per_location()[0]
)
rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
......
......@@ -6,6 +6,7 @@ from collections import OrderedDict
import torch
from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from maskrcnn_benchmark.utils.registry import Registry
def _rename_basic_resnet_weights(layer_keys):
......@@ -135,11 +136,20 @@ _C2_STAGE_NAMES = {
"R-101": ["1.2", "2.3", "3.22", "4.2"],
}
def load_c2_format(cfg, f):
# TODO make it support other architectures
C2_FORMAT_LOADER = Registry()
@C2_FORMAT_LOADER.register("R-50-C4")
@C2_FORMAT_LOADER.register("R-50-FPN")
@C2_FORMAT_LOADER.register("R-101-FPN")
def load_resnet_c2_format(cfg, f):
state_dict = _load_c2_pickled_weights(f)
conv_body = cfg.MODEL.BACKBONE.CONV_BODY
arch = conv_body.replace("-C4", "").replace("-FPN", "")
stages = _C2_STAGE_NAMES[arch]
state_dict = _rename_weights_for_resnet(state_dict, stages)
return dict(model=state_dict)
def load_c2_format(cfg, f):
return C2_FORMAT_LOADER[cfg.MODEL.BACKBONE.CONV_BODY](cfg, f)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
def _register_generic(module_dict, module_name, module):
assert module_name not in module_dict
module_dict[module_name] = module
class Registry(dict):
'''
A helper class for managing registering modules, it extends a dictionary
and provides a register functions.
Eg. creeting a registry:
some_registry = Registry({"default": default_module})
There're two ways of registering new modules:
1): normal way is just calling register function:
def foo():
...
some_registry.register("foo_module", foo)
2): used as decorator when declaring the module:
@some_registry.register("foo_module")
@some_registry.register("foo_modeul_nickname")
def foo():
...
Access of module is just like using a dictionary, eg:
f = some_registry["foo_modeul"]
'''
def __init__(self, *args, **kwargs):
super(Registry, self).__init__(*args, **kwargs)
def register(self, module_name, module=None):
# used as function call
if module is not None:
_register_generic(self, module_name, module)
return
# used as decorator
def register_fn(fn):
_register_generic(self, module_name, fn)
return fn
return register_fn
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册