提交 2953d108 编写于 作者: Y Yang Zhang 提交者: GitHub

Provide a way to avoid duplicated config keys, e.g., `num_classes` (#2764)

* Clean up a bit

* Add `__shared__` annotation for shared variables between modules

* Dedup `num_classes` configurations

* Document `__shared__` usage

* Remove some unused variable

* Improve docstring and add comments for `__shared__`
上级 a653012d
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
weights: output/cascade_rcnn_r50_fpn_1x/model_final
metric: COCO
num_classes: 81
CascadeRCNN:
backbone: ResNet
......@@ -74,7 +75,6 @@ CascadeBBoxAssigner:
bg_thresh_hi: [0.5, 0.6, 0.7]
fg_thresh: [0.5, 0.6, 0.7]
fg_fraction: 0.25
num_classes: 81
CascadeBBoxHead:
head: FC6FC7Head
......@@ -82,7 +82,6 @@ CascadeBBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
FC6FC7Head:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ snapshot_iter: 10000
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar
metric: COCO
weights: output/faster_rcnn_r101_1x/model_final
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -64,7 +65,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: ResNetC5
......@@ -72,7 +72,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
LearningRate:
base_lr: 0.01
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar
weights: output/faster_rcnn_r101_fpn_1x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -73,7 +74,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: TwoFCHead
......@@ -81,7 +81,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar
weights: output/faster_rcnn_r101_fpn_2x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -73,7 +74,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: TwoFCHead
......@@ -81,7 +81,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar
weights: output/faster_rcnn_r101_vd_fpn_1x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -74,7 +75,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: TwoFCHead
......@@ -82,7 +82,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar
weights: output/faster_rcnn_r101_vd_fpn_2x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -74,7 +75,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: TwoFCHead
......@@ -82,7 +82,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ snapshot_iter: 10000
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/faster_rcnn_r50_1x/model_final
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -64,7 +65,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: ResNetC5
......@@ -72,7 +72,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
LearningRate:
base_lr: 0.01
......
......@@ -10,6 +10,7 @@ snapshot_iter: 10000
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/faster_rcnn_r50_2x/model_final
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -64,7 +65,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: ResNetC5
......@@ -72,7 +72,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
LearningRate:
base_lr: 0.01
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/fpn/faster_rcnn_r50_fpn_1x/model_final
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -74,7 +75,6 @@ BBoxAssigner:
bg_thresh_hi: 0.5
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: TwoFCHead
......@@ -82,7 +82,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/faster_rcnn_r50_fpn_2x/model_final
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -74,7 +75,6 @@ BBoxAssigner:
bg_thresh_hi: 0.5
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: TwoFCHead
......@@ -82,7 +82,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ snapshot_iter: 10000
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
metric: COCO
weights: output/faster_rcnn_r50_vd_1x/model_final
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -66,7 +67,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: ResNetC5
......@@ -74,7 +74,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
LearningRate:
base_lr: 0.01
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/faster_rcnn_r50_vd_fpn_2x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
......@@ -74,7 +75,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: TwoFCHead
......@@ -82,7 +82,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/SE154_vd_pretrained.tar
weights: output/faster_rcnn_se154_vd_fpn_s1x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: SENet
......@@ -76,7 +77,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: TwoFCHead
......@@ -84,7 +84,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar
weights: output/faster_rcnn_x101_vd_64x4d_fpn_1x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNeXt
......@@ -76,7 +77,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: TwoFCHead
......@@ -84,7 +84,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar
metric: COCO
weights: output/mask_rcnn_r101_fpn_1x/model_final/
num_classes: 81
MaskRCNN:
backbone: ResNet
......@@ -68,7 +69,6 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
num_classes: 81
num_convs: 4
resolution: 28
......@@ -79,7 +79,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
MaskAssigner:
resolution: 28
......@@ -90,7 +89,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/mask_rcnn_r50_1x/model_final
num_classes: 81
MaskRCNN:
backbone: ResNet
......@@ -66,12 +67,10 @@ BBoxHead:
nms_threshold: 0.5
normalized: false
score_threshold: 0.05
num_classes: 81
MaskHead:
dilation: 1
num_chan_reduced: 256
num_classes: 81
resolution: 14
BBoxAssigner:
......@@ -81,10 +80,8 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
MaskAssigner:
num_classes: 81
resolution: 14
LearningRate:
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/mask_rcnn_r50_2x/model_final/
num_classes: 81
MaskRCNN:
backbone: ResNet
......@@ -67,12 +68,10 @@ BBoxHead:
nms_threshold: 0.5
normalized: false
score_threshold: 0.05
num_classes: 81
MaskHead:
dilation: 1
num_chan_reduced: 256
num_classes: 81
resolution: 14
BBoxAssigner:
......@@ -82,10 +81,8 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
MaskAssigner:
num_classes: 81
resolution: 14
LearningRate:
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/mask_rcnn_r50_fpn_1x/model_final/
num_classes: 81
MaskRCNN:
backbone: ResNet
......@@ -68,7 +69,6 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
num_classes: 81
num_convs: 4
resolution: 28
......@@ -79,7 +79,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
MaskAssigner:
resolution: 28
......@@ -90,7 +89,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
weights: output/mask_rcnn_r50_fpn_2x/model_final/
metric: COCO
num_classes: 81
MaskRCNN:
backbone: ResNet
......@@ -68,7 +69,6 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
num_classes: 81
num_convs: 4
resolution: 28
......@@ -79,7 +79,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
MaskAssigner:
resolution: 28
......@@ -90,7 +89,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
metric: COCO
weights: output/mask_rcnn_r50_vd_fpn_2x/model_final/
num_classes: 81
MaskRCNN:
backbone: ResNet
......@@ -69,7 +70,6 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
num_classes: 81
num_convs: 4
resolution: 28
......@@ -80,7 +80,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
MaskAssigner:
resolution: 28
......@@ -91,7 +90,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -10,6 +10,7 @@ save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/SE154_vd_pretrained.tar
weights: output/mask_rcnn_se154_vd_fpn_s1x/model_final/
metric: COCO
num_classes: 81
MaskRCNN:
backbone: SENet
......@@ -71,7 +72,6 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
num_classes: 81
num_convs: 4
resolution: 28
......@@ -82,7 +82,6 @@ BBoxAssigner:
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
MaskAssigner:
resolution: 28
......@@ -93,7 +92,6 @@ BBoxHead:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
......
......@@ -188,3 +188,14 @@ A small utility (`tools/configure.py`) is included to simplify the configuration
```shell
python tools/configure.py --minimal generate FasterRCNN BBoxHead
```
# FAQ
**Q:** There are some configuration options that are used by multiple modules (e.g., `num_classes`), how do I avoid duplication in config files?
**A:** We provided a `__shared__` annotation for exactly this purpose, simply annotate like this `__shared__ = ['num_classes']`. It works as follows:
1. if `num_classes` is configured for a module in config file, it takes precedence.
2. if `num_classes` is not configured for a module but is present in the config file as a global key, its value will be used.
3. otherwise, the default value (`81`) will be used.
......@@ -180,3 +180,14 @@ pip install typeguard http://github.com/willthefrog/docstring_parser/tarball/mas
```shell
python tools/configure.py --minimal generate FasterRCNN BBoxHead
```
# FAQ
**Q:** 某些配置项会在多个模块中用到(如 `num_classes`),如何避免在配置文件中多次重复设置?
**A:** 框架提供了 `__shared__` 标记来实现配置的共享,用户可以标记参数,如 `__shared__ = ['num_classes']` ,配置数值作用规则如下:
1. 如果模块配置中提供了 `num_classes` ,会优先使用其数值。
2. 如果模块配置中未提供 `num_classes` ,但配置文件中存在全局键值,那么会使用全局键值。
3. 两者均为配置的情况下,将使用默认值(`81`)。
......@@ -43,13 +43,14 @@ except Exception:
if not check_type.__warning_sent__:
from ppdet.utils.cli import ColorTTY
color_tty = ColorTTY()
message = "typeguard is not installed, type checking is not available"
message = "typeguard is not installed," \
+ "type checking is not available"
print(color_tty.yellow(message))
check_type.__warning_sent__ = True
check_type.__warning_sent__ = False
__all__ = ['SchemaValue', 'SchemaDict', 'extract_schema']
__all__ = ['SchemaValue', 'SchemaDict', 'SharedConfig', 'extract_schema']
class SchemaValue(object):
......@@ -160,6 +161,27 @@ class SchemaDict(dict):
self.name, ", ".join(mismatch_keys)))
class SharedConfig(object):
"""
Representation class for `__shared__` annotations, which work as follows:
- if `key` is set for the module in config file, its value will take
precedence
- if `key` is not set for the module but present in the config file, its
value will be used
- otherwise, use the provided `default_value` as fallback
Args:
key: config[key] will be injected
default_value: fallback value
"""
def __init__(self, key, default_value=None):
super(SharedConfig, self).__init__()
self.key = key
self.default_value = default_value
def extract_schema(cls):
"""
Extract schema from a given class
......@@ -216,6 +238,7 @@ def extract_schema(cls):
schema.strict = not has_kwargs
schema.pymodule = importlib.import_module(cls.__module__)
schema.inject = getattr(cls, '__inject__', [])
schema.shared = getattr(cls, '__shared__', [])
for idx, name in enumerate(names):
comment = name in comments and comments[name] or name
if name in schema.inject:
......@@ -223,8 +246,13 @@ def extract_schema(cls):
else:
type_ = name in annotations and annotations[name] or None
value_schema = SchemaValue(name, comment, type_)
if idx >= num_required:
value_schema.set_default(defaults[idx - num_required])
if name in schema.shared:
assert idx >= num_required, "shared config must have default value"
default = defaults[idx - num_required]
value_schema.set_default(SharedConfig(name, default))
elif idx >= num_required:
default = defaults[idx - num_required]
value_schema.set_default(default)
schema.set_schema(name, value_schema)
return schema
......@@ -16,6 +16,7 @@ import importlib
import inspect
import yaml
from .schema import SharedConfig
__all__ = ['serializable', 'Callable']
......@@ -59,7 +60,8 @@ def _make_python_representer(cls):
def serializable(cls):
"""
Add loader and dumper for given class, which must be "trivially serializable"
Add loader and dumper for given class, which must be
"trivially serializable"
Args:
cls: class to be serialized
......@@ -72,6 +74,10 @@ def serializable(cls):
return cls
yaml.add_representer(SharedConfig,
lambda d, o: d.represent_data(o.default_value))
@serializable
class Callable(object):
"""
......
......@@ -23,7 +23,7 @@ import sys
import yaml
import copy
from .config.schema import SchemaDict, extract_schema
from .config.schema import SchemaDict, SharedConfig, extract_schema
from .config.yaml_helpers import serializable
__all__ = [
......@@ -136,7 +136,8 @@ def create(cls_or_name, **kwargs):
assert type(cls_or_name) in [type, str
], "should be a class or name of a class"
name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
assert name in global_config and isinstance(global_config[name], SchemaDict), \
assert name in global_config and \
isinstance(global_config[name], SchemaDict), \
"the module {} is not registered".format(name)
config = global_config[name]
config.update(kwargs)
......@@ -145,9 +146,26 @@ def create(cls_or_name, **kwargs):
kwargs = {}
kwargs.update(global_config[name])
# parse `shared` annoation of registered modules
if getattr(config, 'shared', None):
for k in config.shared:
target_key = config[k]
shared_conf = config.schema[k].default
assert isinstance(shared_conf, SharedConfig)
if target_key is not None and not isinstance(
target_key, SharedConfig):
continue # value is given for the module
elif shared_conf.key in global_config:
# `key` is present in config
kwargs[k] = global_config[shared_conf.key]
else:
kwargs[k] = shared_conf.default_value
# parse `inject` annoation of registered modules
if getattr(config, 'inject', None):
for k in config.inject:
target_key = global_config[name][k]
target_key = config[k]
# optional dependency
if target_key is None:
continue
......
......@@ -181,7 +181,6 @@ class DataSet(object):
Args:
annotation (str): annotation file path
image_dir (str): directory where image files are stored
num_classes (int): number of classes
shuffle (bool): shuffle samples
"""
__source__ = 'RoiDbSource'
......
......@@ -25,7 +25,7 @@ from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import (AnchorGenerator, RetinaTargetAssign,
RetinaOutputDecoder)
from ppdet.core.workspace import register, serializable
from ppdet.core.workspace import register
__all__ = ['RetinaHead']
......@@ -52,6 +52,7 @@ class RetinaHead(object):
sigma (float): The parameter in smooth l1 loss
"""
__inject__ = ['anchor_generator', 'target_assign', 'output_decoder']
__shared__ = ['num_classes']
def __init__(self,
anchor_generator=AnchorGenerator().__dict__,
......@@ -333,7 +334,6 @@ class RetinaHead(object):
cls_pred_reshape_list = output['cls_pred']
bbox_pred_reshape_list = output['bbox_pred']
anchor_reshape_list = output['anchor']
anchor_var_reshape_list = output['anchor_var']
for i in range(self.max_level - self.min_level + 1):
cls_pred_reshape_list[i] = fluid.layers.sigmoid(
cls_pred_reshape_list[i])
......
......@@ -88,6 +88,7 @@ class GenerateProposals(object):
class MaskAssigner(object):
__op__ = fluid.layers.generate_mask_labels
__append_doc__ = True
__shared__ = ['num_classes']
def __init__(self, num_classes=81, resolution=14):
super(MaskAssigner, self).__init__()
......@@ -123,6 +124,7 @@ class MultiClassNMS(object):
class BBoxAssigner(object):
__op__ = fluid.layers.generate_proposal_labels
__append_doc__ = True
__shared__ = ['num_classes']
def __init__(self,
batch_size_per_im=512,
......
......@@ -92,12 +92,13 @@ class BBoxHead(object):
RCNN bbox head
Args:
head (object): the head module instance, e.g., `ResNetC5` or `TwoFCHead`
head (object): the head module instance, e.g., `ResNetC5`, `TwoFCHead`
box_coder (object): `BoxCoder` instance
nms (object): `MultiClassNMS` instance
num_classes: number of output classes
"""
__inject__ = ['head', 'box_coder', 'nms']
__shared__ = ['num_classes']
def __init__(self,
head,
......
......@@ -37,6 +37,7 @@ class CascadeBBoxHead(object):
num_classes: number of output classes
"""
__inject__ = ['head', 'nms']
__shared__ = ['num_classes']
def __init__(self, head, nms=MultiClassNMS().__dict__, num_classes=81):
super(CascadeBBoxHead, self).__init__()
......
......@@ -38,6 +38,8 @@ class MaskHead(object):
num_classes (int): number of output classes
"""
__shared__ = ['num_classes']
def __init__(self,
num_convs=0,
num_chan_reduced=256,
......
......@@ -26,6 +26,8 @@ __all__ = ['BBoxAssigner', 'MaskAssigner', 'CascadeBBoxAssigner']
@register
class CascadeBBoxAssigner(object):
__shared__ = ['num_classes']
def __init__(self,
batch_size_per_im=512,
fg_fraction=.25,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册