From 212adf7bd070547de3d5945d8d208b7ab156c0a8 Mon Sep 17 00:00:00 2001 From: linjintao Date: Sat, 15 Feb 2020 23:54:58 +0800 Subject: [PATCH] take AdaptiveAvgPoolnd to replace origin one --- config/i3d_rgb_32x2x1_r50_3d_kinetics400_100e.py | 2 -- config/tsn_rgb_1x1x3_r50_2d_kinetics400_100e.py | 1 - mmaction/models/heads/i3d_head.py | 11 ++--------- mmaction/models/heads/tsn_head.py | 13 ++++--------- 4 files changed, 6 insertions(+), 21 deletions(-) diff --git a/config/i3d_rgb_32x2x1_r50_3d_kinetics400_100e.py b/config/i3d_rgb_32x2x1_r50_3d_kinetics400_100e.py index 2155a39..aed67a3 100644 --- a/config/i3d_rgb_32x2x1_r50_3d_kinetics400_100e.py +++ b/config/i3d_rgb_32x2x1_r50_3d_kinetics400_100e.py @@ -14,8 +14,6 @@ model = dict( num_classes=400, in_channels=2048, spatial_type='avg', - spatial_size=7, - temporal_size=4, dropout_ratio=0.5, init_std=0.01)) # model training and testing settings diff --git a/config/tsn_rgb_1x1x3_r50_2d_kinetics400_100e.py b/config/tsn_rgb_1x1x3_r50_2d_kinetics400_100e.py index 16e9e74..117e97d 100644 --- a/config/tsn_rgb_1x1x3_r50_2d_kinetics400_100e.py +++ b/config/tsn_rgb_1x1x3_r50_2d_kinetics400_100e.py @@ -11,7 +11,6 @@ model = dict( num_classes=400, in_channels=2048, spatial_type='avg', - spatial_size=7, consensus=dict(type='AvgConsensus', dim=1), dropout_ratio=0.4, init_std=0.01)) diff --git a/mmaction/models/heads/i3d_head.py b/mmaction/models/heads/i3d_head.py index 90eb528..57d8c58 100644 --- a/mmaction/models/heads/i3d_head.py +++ b/mmaction/models/heads/i3d_head.py @@ -1,7 +1,5 @@ -import mmcv import torch.nn as nn from mmcv.cnn.weight_init import normal_init -from torch.nn.modules.utils import _pair from ..registry import HEADS from .base import BaseHead @@ -27,17 +25,11 @@ class I3DHead(BaseHead): num_classes, in_channels=2048, spatial_type='avg', - spatial_size=7, - temporal_size=4, dropout_ratio=0.5, init_std=0.01): super(I3DHead, self).__init__(num_classes, in_channels) - self.spatial_size = _pair(spatial_size) - assert mmcv.is_tuple_of(self.spatial_size, int) self.spatial_type = spatial_type - self.temporal_size = temporal_size - self.pool_size = (self.temporal_size, ) + self.spatial_size self.dropout_ratio = dropout_ratio self.init_std = init_std if self.dropout_ratio != 0: @@ -47,7 +39,8 @@ class I3DHead(BaseHead): self.fc_cls = nn.Linear(self.in_channels, self.num_classes) if self.spatial_type == 'avg': - self.avg_pool = nn.AvgPool3d(self.pool_size, stride=1, padding=0) + # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels. + self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) else: self.avg_pool = None diff --git a/mmaction/models/heads/tsn_head.py b/mmaction/models/heads/tsn_head.py index c316dd6..ecdb8b3 100644 --- a/mmaction/models/heads/tsn_head.py +++ b/mmaction/models/heads/tsn_head.py @@ -1,7 +1,5 @@ -import mmcv import torch.nn as nn from mmcv.cnn.weight_init import normal_init -from torch.nn.modules.utils import _pair from ..registry import HEADS from .base import BaseHead @@ -41,13 +39,10 @@ class TSNHead(BaseHead): num_classes, in_channels=2048, spatial_type='avg', - spatial_size=7, consensus=dict(type='AvgConsensus', dim=1), dropout_ratio=0.4, init_std=0.01): super(TSNHead, self).__init__(num_classes, in_channels) - self.spatial_size = _pair(spatial_size) - assert mmcv.is_tuple_of(self.spatial_size, int) self.spatial_type = spatial_type self.dropout_ratio = dropout_ratio @@ -60,10 +55,10 @@ class TSNHead(BaseHead): self.consensus = None if self.spatial_type == 'avg': - self.avg_pool2d = nn.AvgPool2d( - self.spatial_size, stride=1, padding=0) + # use `nn.AdaptiveAvgPool2d` to adaptively match the in_channels. + self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) else: - self.avg_pool2d = None + self.avg_pool = None if self.dropout_ratio != 0: self.dropout = nn.Dropout(p=self.dropout_ratio) @@ -76,7 +71,7 @@ class TSNHead(BaseHead): def forward(self, x, num_segs): # [N * num_segs, in_channels, 7, 7] - x = self.avg_pool2d(x) + x = self.avg_pool(x) # [N * num_segs, in_channels, 1, 1] x = x.reshape((-1, num_segs) + x.shape[1:]) # [N, num_segs, in_channels, 1, 1] -- GitLab