未验证 提交 ea73520d 编写于 作者: S su 提交者: GitHub

Fix the multi-class support in a cleaner way. (#151)

上级 cdcc5b47
......@@ -17,6 +17,8 @@ class I3DHead(BaseHead):
spatial_type (str): Pooling type in spatial dimension. Default: 'avg'.
dropout_ratio (float): Probability of dropout layer. Default: 0.5.
init_std (float): Std value for Initiation. Default: 0.01.
kwargs (dict, optional): Any keyword argument to be used to initialize
the head.
"""
def __init__(self,
......@@ -25,8 +27,9 @@ class I3DHead(BaseHead):
loss_cls=dict(type='CrossEntropyLoss'),
spatial_type='avg',
dropout_ratio=0.5,
init_std=0.01):
super().__init__(num_classes, in_channels, loss_cls)
init_std=0.01,
**kwargs):
super().__init__(num_classes, in_channels, loss_cls, **kwargs)
self.spatial_type = spatial_type
self.dropout_ratio = dropout_ratio
......
......@@ -18,6 +18,8 @@ class SlowFastHead(BaseHead):
spatial_type (str): Pooling type in spatial dimension. Default: 'avg'.
dropout_ratio (float): Probability of dropout layer. Default: 0.8.
init_std (float): Std value for Initiation. Default: 0.01.
kwargs (dict, optional): Any keyword argument to be used to initialize
the head.
"""
def __init__(self,
......@@ -26,9 +28,10 @@ class SlowFastHead(BaseHead):
loss_cls=dict(type='CrossEntropyLoss'),
spatial_type='avg',
dropout_ratio=0.8,
init_std=0.01):
init_std=0.01,
**kwargs):
super().__init__(num_classes, in_channels, loss_cls)
super().__init__(num_classes, in_channels, loss_cls, **kwargs)
self.spatial_type = spatial_type
self.dropout_ratio = dropout_ratio
self.init_std = init_std
......
......@@ -21,10 +21,10 @@ class TSMHead(BaseHead):
init_std (float): Std value for Initiation. Default: 0.01.
is_shift (bool): Indicating whether the feature is shifted.
Default: True.
multi_class (bool): Determines whether it is a multi-class
recognition task. Default: False.
temporal_pool (bool): Indicating whether feature is temporal pooled.
Default: False.
kwargs (dict, optional): Any keyword argument to be used to initialize
the head.
"""
def __init__(self,
......@@ -36,9 +36,9 @@ class TSMHead(BaseHead):
dropout_ratio=0.8,
init_std=0.001,
is_shift=True,
multi_class=False,
temporal_pool=False):
super().__init__(num_classes, in_channels, loss_cls, multi_class)
temporal_pool=False,
**kwargs):
super().__init__(num_classes, in_channels, loss_cls, **kwargs)
self.spatial_type = spatial_type
self.dropout_ratio = dropout_ratio
......
......@@ -18,10 +18,8 @@ class TSNHead(BaseHead):
consensus (dict): Consensus config dict.
dropout_ratio (float): Probability of dropout layer. Default: 0.4.
init_std (float): Std value for Initiation. Default: 0.01.
multi_class (bool): Determines whether it is a multi-class
recognition task. Default: False.
label_smooth_eps (float): Epsilon used in label smooth.
Reference: https://arxiv.org/abs/1906.02629. Default: 0.
kwargs (dict, optional): Any keyword argument to be used to initialize
the head.
"""
def __init__(self,
......@@ -32,14 +30,8 @@ class TSNHead(BaseHead):
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.4,
init_std=0.01,
multi_class=False,
label_smooth_eps=0.0):
super().__init__(
num_classes,
in_channels,
loss_cls=loss_cls,
multi_class=multi_class,
label_smooth_eps=label_smooth_eps)
**kwargs):
super().__init__(num_classes, in_channels, loss_cls=loss_cls, **kwargs)
self.spatial_type = spatial_type
self.dropout_ratio = dropout_ratio
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册