提交 4962f712 编写于 作者: Y Yang Nie 提交者: Tingquan Gao

remove `ChannelShuffle2`

上级 a881c7a7
...@@ -232,11 +232,6 @@ class ChannelShuffle(nn.Layer): ...@@ -232,11 +232,6 @@ class ChannelShuffle(nn.Layer):
return out return out
class ChannelShuffle2(ChannelShuffle):
pass
class SpatialSepConvSF(nn.Layer): class SpatialSepConvSF(nn.Layer):
def __init__(self, inp, oups, kernel_size, stride): def __init__(self, inp, oups, kernel_size, stride):
super().__init__() super().__init__()
...@@ -417,8 +412,7 @@ class DYMicroBlock(nn.Layer): ...@@ -417,8 +412,7 @@ class DYMicroBlock(nn.Layer):
g=gs1, g=gs1,
expansion=False) if y2 > 0 else nn.ReLU6(), expansion=False) if y2 > 0 else nn.ReLU6(),
ChannelShuffle(gs1[1]), ChannelShuffle(gs1[1]),
ChannelShuffle2(hidden_dim2 // 2) ChannelShuffle(hidden_dim2 // 2) if y2 != 0 else nn.Identity(),
if y2 != 0 else nn.Identity(),
GroupConv(hidden_dim2, oup, (g1, g2)), GroupConv(hidden_dim2, oup, (g1, g2)),
DYShiftMax( DYShiftMax(
oup, oup,
...@@ -431,7 +425,7 @@ class DYMicroBlock(nn.Layer): ...@@ -431,7 +425,7 @@ class DYMicroBlock(nn.Layer):
g=(g1, g2), g=(g1, g2),
expansion=False) if y3 > 0 else nn.Identity(), expansion=False) if y3 > 0 else nn.Identity(),
ChannelShuffle(g2), ChannelShuffle(g2),
ChannelShuffle2(oup // 2) ChannelShuffle(oup // 2)
if oup % 2 == 0 and y3 != 0 else nn.Identity(), ) if oup % 2 == 0 and y3 != 0 else nn.Identity(), )
elif g2 == 0: elif g2 == 0:
self.layers = nn.Sequential( self.layers = nn.Sequential(
...@@ -472,9 +466,9 @@ class DYMicroBlock(nn.Layer): ...@@ -472,9 +466,9 @@ class DYMicroBlock(nn.Layer):
init_b=init_b, init_b=init_b,
g=gs1, g=gs1,
expansion=True, ) if y2 > 0 else nn.ReLU6(), expansion=True, ) if y2 > 0 else nn.ReLU6(),
ChannelShuffle2(hidden_dim2 // 4) ChannelShuffle(hidden_dim2 // 4)
if y1 != 0 and y2 != 0 else nn.Identity() if y1 != 0 and y2 != 0 else nn.Identity()
if y1 == 0 and y2 == 0 else ChannelShuffle2(hidden_dim2 // 2), if y1 == 0 and y2 == 0 else ChannelShuffle(hidden_dim2 // 2),
GroupConv(hidden_dim2, oup, (g1, g2)), GroupConv(hidden_dim2, oup, (g1, g2)),
DYShiftMax( DYShiftMax(
oup, oup,
...@@ -488,7 +482,7 @@ class DYMicroBlock(nn.Layer): ...@@ -488,7 +482,7 @@ class DYMicroBlock(nn.Layer):
g=(g1, g2), g=(g1, g2),
expansion=False) if y3 > 0 else nn.Identity(), expansion=False) if y3 > 0 else nn.Identity(),
ChannelShuffle(g2), ChannelShuffle(g2),
ChannelShuffle2(oup // 2) if y3 != 0 else nn.Identity(), ) ChannelShuffle(oup // 2) if y3 != 0 else nn.Identity(), )
def forward(self, x): def forward(self, x):
out = self.layers(x) out = self.layers(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册