未验证 提交 16e3d740 编写于 作者: W Wenyu 提交者: GitHub

add conv mixer (#4280)

上级 0061381a
......@@ -1388,3 +1388,37 @@ class MultiHeadAttention(nn.Layer):
if self.need_weights:
outs.append(weights)
return out if len(outs) == 1 else tuple(outs)
@register
class ConvMixer(nn.Layer):
    """Stack of ConvMixer blocks applied to a feature map.

    The network is ``depth`` identical blocks chained one after another.
    Each block is a depthwise convolution (``groups=dim``) wrapped in a
    residual connection, followed by a 1x1 convolution. Every convolution
    is followed by GELU and then BatchNorm2D.

    Args:
        dim: Channel count, used as both the input and output channels of
            every convolution in the stack.
        depth: Number of blocks to stack.
        kernel_size: Kernel size of the depthwise convolution.
    """

    def __init__(self, dim, depth, kernel_size=3):
        super().__init__()
        # Keep the construction arguments available on the instance.
        self.dim = dim
        self.depth = depth
        self.kernel_size = kernel_size

        self.mixer = self.conv_mixer(dim, depth, kernel_size)

    def forward(self, x):
        """Run ``x`` through the stacked blocks and return the result."""
        return self.mixer(x)

    @staticmethod
    def conv_mixer(dim, depth, kernel_size):
        """Build the ``nn.Sequential`` holding ``depth`` ConvMixer blocks."""

        class Residual(nn.Sequential):
            """Sequential container that adds its input to sublayer 0's output."""

            def forward(self, x):
                return self[0](x) + x

        def act_bn(layer):
            # layer -> GELU -> BatchNorm2D over `dim` channels.
            return nn.Sequential(layer, nn.GELU(), nn.BatchNorm2D(dim))

        blocks = []
        for _ in range(depth):
            # Depthwise conv: one group per channel; "same" padding keeps
            # the spatial size so the residual addition lines up.
            depthwise = nn.Conv2D(
                dim, dim, kernel_size, groups=dim, padding="same")
            spatial_mix = Residual(act_bn(depthwise))
            # 1x1 conv over all channels.
            channel_mix = act_bn(nn.Conv2D(dim, dim, 1))
            blocks.append(nn.Sequential(spatial_mix, channel_mix))
        return nn.Sequential(*blocks)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册