From b01a79aba7d9e2128824cfd40a2beb7aba93ec55 Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Thu, 9 Jun 2022 15:08:45 +0800
Subject: [PATCH] fix swin (#2004)

---
 .../legendary_models/swin_transformer.py     | 41 ++++++++++++++-----
 1 file changed, 31 insertions(+), 10 deletions(-)

diff --git a/ppcls/arch/backbone/legendary_models/swin_transformer.py b/ppcls/arch/backbone/legendary_models/swin_transformer.py
index 2a3401b2..c9511501 100644
--- a/ppcls/arch/backbone/legendary_models/swin_transformer.py
+++ b/ppcls/arch/backbone/legendary_models/swin_transformer.py
@@ -157,6 +157,7 @@ class WindowAttention(nn.Layer):
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
 
@@ -168,6 +169,23 @@ class WindowAttention(nn.Layer):
         trunc_normal_(self.relative_position_bias_table)
         self.softmax = nn.Softmax(axis=-1)
 
+    def eval(self, ):
+        # this is used to re-param swin for model export
+        relative_position_bias_table = self.relative_position_bias_table
+        window_size = self.window_size
+        index = self.relative_position_index.reshape([-1])
+
+        relative_position_bias = paddle.index_select(
+            relative_position_bias_table, index)
+        relative_position_bias = relative_position_bias.reshape([
+            window_size[0] * window_size[1], window_size[0] * window_size[1],
+            -1
+        ])  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.transpose(
+            [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
+        relative_position_bias = relative_position_bias.unsqueeze(0)
+        self.register_buffer("relative_position_bias", relative_position_bias)
+
     def forward(self, x, mask=None):
         """
         Args:
@@ -183,18 +201,21 @@ class WindowAttention(nn.Layer):
         q = q * self.scale
         attn = paddle.mm(q, k.transpose([0, 1, 3, 2]))
 
-        index = self.relative_position_index.reshape([-1])
+        if self.training or not hasattr(self, "relative_position_bias"):
+            index = self.relative_position_index.reshape([-1])
 
-        relative_position_bias = paddle.index_select(
-            self.relative_position_bias_table, index)
-        relative_position_bias = relative_position_bias.reshape([
-            self.window_size[0] * self.window_size[1],
-            self.window_size[0] * self.window_size[1], -1
-        ])  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = paddle.index_select(
+                self.relative_position_bias_table, index)
+            relative_position_bias = relative_position_bias.reshape([
+                self.window_size[0] * self.window_size[1],
+                self.window_size[0] * self.window_size[1], -1
+            ])  # Wh*Ww,Wh*Ww,nH
 
-        relative_position_bias = relative_position_bias.transpose(
-            [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
-        attn = attn + relative_position_bias.unsqueeze(0)
+            relative_position_bias = relative_position_bias.transpose(
+                [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
+            attn = attn + relative_position_bias.unsqueeze(0)
+        else:
+            attn = attn + self.relative_position_bias
 
         if mask is not None:
             nW = mask.shape[0]
--
GitLab
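
Not part of the patch above: a minimal standalone sketch of the same re-param idea the commit applies to WindowAttention, i.e. precomputing an index_select lookup once when switching to eval mode so that forward can reuse a registered buffer at inference/export time. ToyBias, its shapes, and the buffer names are illustrative assumptions, not code from PaddleClas; it assumes a recent Paddle 2.x release.

import paddle
import paddle.nn as nn


class ToyBias(nn.Layer):
    """Gathers a learnable bias table through fixed indices, analogous to
    the relative position bias lookup in WindowAttention."""

    def __init__(self, num_entries=9, num_heads=2, num_positions=4):
        super().__init__()
        self.table = self.create_parameter([num_entries, num_heads])
        # Fixed lookup indices, registered as a buffer as in the patch.
        self.register_buffer(
            "index", paddle.randint(0, num_entries, [num_positions]))

    def eval(self):
        # Bake the gathered bias into a buffer once, so the exported /
        # inference graph skips the per-forward index_select.
        bias = paddle.index_select(self.table, self.index)
        self.register_buffer("bias", bias)
        super().eval()

    def forward(self, x):
        if self.training or not hasattr(self, "bias"):
            # Training path: gather from the table on every call so
            # gradients flow into the learnable table.
            bias = paddle.index_select(self.table, self.index)
        else:
            # Inference path after eval(): reuse the precomputed buffer.
            bias = self.bias
        return x + bias


layer = ToyBias()
layer.eval()  # triggers the one-off precomputation
out = layer(paddle.zeros([4, 2]))

Unlike the patch, this sketch calls super().eval() inside the override so the training flag is still flipped; the conditional in forward then mirrors the patched WindowAttention.forward, falling back to the full lookup whenever the buffer has not been baked in yet.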