提交 bda8ab01 编写于 作者: R Ross Wightman

Remove min channels for SelectiveKernel, divisor should cover cases well enough.

上级 a27f4aec
......@@ -49,7 +49,7 @@ class SelectiveKernelAttn(nn.Module):
class SelectiveKernel(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
rd_ratio=1./16, rd_channels=None, min_rd_channels=32, rd_divisor=8, keep_3x3=True, split_input=True,
rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
""" Selective Kernel Convolution Module
......@@ -68,7 +68,6 @@ class SelectiveKernel(nn.Module):
dilation (int): dilation for module as a whole, impacts dilation of each branch
groups (int): number of groups for each branch
rd_ratio (int, float): reduction factor for attention features
min_rd_channels (int): minimum attention feature channels
keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
can be viewed as grouping by path, output expands to module out_channels count
......@@ -103,8 +102,7 @@ class SelectiveKernel(nn.Module):
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)])
attn_channels = rd_channels or make_divisible(
out_channels * rd_ratio, min_value=min_rd_channels, divisor=rd_divisor)
attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
self.drop_block = drop_block
......
......@@ -153,7 +153,7 @@ def skresnet18(pretrained=False, **kwargs):
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True)
sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
......@@ -167,7 +167,7 @@ def skresnet34(pretrained=False, **kwargs):
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True)
sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
......@@ -207,7 +207,7 @@ def skresnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
the SKNet-50 model in the Select Kernel Paper
"""
sk_kwargs = dict(min_rd_channels=32, rd_ratio=1/16, split_input=False)
sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False)
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册