@register
class ConvMixer(nn.Layer):
    """Stack of ``depth`` depthwise/pointwise convolution blocks.

    Each block applies, in order:

    1. a depthwise convolution of size ``kernel_size`` (``groups=dim``,
       ``padding="same"``) followed by GELU and BatchNorm2D, with the
       block input added back to the result;
    2. a 1x1 convolution followed by GELU and BatchNorm2D.

    Every convolution maps ``dim`` channels to ``dim`` channels.

    Args:
        dim: Number of input and output channels of every convolution.
        depth: Number of blocks to stack.
        kernel_size: Kernel size of the depthwise convolution.
    """

    def __init__(
            self,
            dim,
            depth,
            kernel_size=3, ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.kernel_size = kernel_size

        self.mixer = self.conv_mixer(dim, depth, kernel_size)

    def forward(self, x):
        """Run ``x`` through the stacked blocks and return the result."""
        # NOTE(review): presumably x is an NCHW tensor with `dim` channels;
        # confirm against callers.
        return self.mixer(x)

    @staticmethod
    def conv_mixer(
            dim,
            depth,
            kernel_size, ):
        """Build the ``nn.Sequential`` holding ``depth`` mixer blocks.

        Args:
            dim: Channel count used for every convolution and batch norm.
            depth: Number of blocks to create.
            kernel_size: Kernel size of the depthwise convolution.

        Returns:
            An ``nn.Sequential`` whose i-th entry is itself an
            ``nn.Sequential`` of (residual depthwise branch, pointwise
            branch).
        """

        class Residual(nn.Sequential):
            """Sequential that adds its input to its first sublayer's output."""

            def forward(self, x):
                # Only sublayer 0 is applied; its output must be addable
                # to the input.
                return self[0](x) + x

        def act_bn(layer):
            """Return ``layer`` followed by GELU and BatchNorm2D(dim)."""
            return nn.Sequential(layer, nn.GELU(), nn.BatchNorm2D(dim))

        blocks = []
        for _ in range(depth):
            # Layers are created in the same order as the nested expression
            # they replace, so sublayer indices stay
            # <block>.0.0.{0,1,2} and <block>.1.{0,1,2}.
            depthwise = nn.Conv2D(
                dim, dim, kernel_size, groups=dim, padding="same")
            residual_branch = Residual(act_bn(depthwise))
            pointwise = nn.Conv2D(dim, dim, 1)
            pointwise_branch = act_bn(pointwise)
            blocks.append(nn.Sequential(residual_branch, pointwise_branch))
        return nn.Sequential(*blocks)