Unverified commit b01a79ab, authored by littletomatodonkey, committed by GitHub

fix swin (#2004)

Parent 794af8c0
...@@ -157,6 +157,7 @@ class WindowAttention(nn.Layer):
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
         self.register_buffer("relative_position_index",
                              relative_position_index)
...@@ -168,6 +169,23 @@ class WindowAttention(nn.Layer):
         trunc_normal_(self.relative_position_bias_table)
         self.softmax = nn.Softmax(axis=-1)
+
+    def eval(self, ):
+        # this is used to re-param swin for model export
+        relative_position_bias_table = self.relative_position_bias_table
+        window_size = self.window_size
+        index = self.relative_position_index.reshape([-1])
+
+        relative_position_bias = paddle.index_select(
+            relative_position_bias_table, index)
+        relative_position_bias = relative_position_bias.reshape([
+            window_size[0] * window_size[1], window_size[0] * window_size[1],
+            -1
+        ])  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.transpose(
+            [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
+        relative_position_bias = relative_position_bias.unsqueeze(0)
+        self.register_buffer("relative_position_bias", relative_position_bias)
+
     def forward(self, x, mask=None):
         """
         Args:
...@@ -183,18 +201,21 @@ class WindowAttention(nn.Layer):
         q = q * self.scale
         attn = paddle.mm(q, k.transpose([0, 1, 3, 2]))
-        index = self.relative_position_index.reshape([-1])
-        relative_position_bias = paddle.index_select(
-            self.relative_position_bias_table, index)
-        relative_position_bias = relative_position_bias.reshape([
-            self.window_size[0] * self.window_size[1],
-            self.window_size[0] * self.window_size[1], -1
-        ])  # Wh*Ww,Wh*Ww,nH
-        relative_position_bias = relative_position_bias.transpose(
-            [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
-        attn = attn + relative_position_bias.unsqueeze(0)
+        if self.training or not hasattr(self, "relative_position_bias"):
+            index = self.relative_position_index.reshape([-1])
+            relative_position_bias = paddle.index_select(
+                self.relative_position_bias_table, index)
+            relative_position_bias = relative_position_bias.reshape([
+                self.window_size[0] * self.window_size[1],
+                self.window_size[0] * self.window_size[1], -1
+            ])  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.transpose(
+                [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
+            attn = attn + relative_position_bias.unsqueeze(0)
+        else:
+            attn = attn + self.relative_position_bias
         if mask is not None:
             nW = mask.shape[0]
...
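In short, the commit adds an eval() override that re-parameterizes WindowAttention for model export: the relative position bias is gathered once (index_select, reshape to Wh*Ww x Wh*Ww x nH, transpose to nH x Wh*Ww x Wh*Ww, unsqueeze to a leading batch axis) and registered as a buffer, and forward() then adds that cached tensor instead of repeating the gather on every call. The `self.training or not hasattr(...)` guard keeps the training path unchanged (the bias table is still learnable) and keeps forward() working even if eval() was never called. Below is a minimal, self-contained sketch of the same caching pattern; it is not code from this commit, and ToyWindowBias and _gather_bias are hypothetical names covering only the bias path of the attention.

import paddle
import paddle.nn as nn


class ToyWindowBias(nn.Layer):
    # Hypothetical stripped-down layer: only the relative-position-bias
    # path of WindowAttention is kept, to show the export-time caching.
    def __init__(self, window_size=(7, 7), num_heads=3):
        super().__init__()
        self.window_size = window_size
        num_rel = (2 * window_size[0] - 1) * (2 * window_size[1] - 1)
        self.relative_position_bias_table = self.create_parameter(
            shape=[num_rel, num_heads],
            default_initializer=nn.initializer.TruncatedNormal(std=.02))
        # same relative_position_index construction as in the diff above
        coords = paddle.stack(paddle.meshgrid(
            [paddle.arange(window_size[0]), paddle.arange(window_size[1])]))
        coords = paddle.flatten(coords, 1)  # 2, Wh*Ww
        rel = coords.unsqueeze(2) - coords.unsqueeze(1)  # 2, Wh*Ww, Wh*Ww
        rel = rel.transpose([1, 2, 0])
        rel[:, :, 0] += window_size[0] - 1
        rel[:, :, 1] += window_size[1] - 1
        rel[:, :, 0] *= 2 * window_size[1] - 1
        self.register_buffer("relative_position_index", rel.sum(-1))

    def _gather_bias(self):
        # the gather that the commit moves out of forward() at eval time
        index = self.relative_position_index.reshape([-1])
        bias = paddle.index_select(self.relative_position_bias_table, index)
        bias = bias.reshape([
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1], -1
        ]).transpose([2, 0, 1])  # nH, Wh*Ww, Wh*Ww
        return bias.unsqueeze(0)  # 1, nH, Wh*Ww, Wh*Ww

    def eval(self):
        # re-param: bake the gathered bias into a buffer, then switch mode
        self.register_buffer("relative_position_bias", self._gather_bias())
        super().eval()

    def forward(self, attn):
        if self.training or not hasattr(self, "relative_position_bias"):
            return attn + self._gather_bias()  # table may still be training
        return attn + self.relative_position_bias  # cached, gather-free path

Usage would look roughly like this; after eval(), forward() adds the cached buffer with no index_select:

layer = ToyWindowBias()
attn = paddle.zeros([1, 3, 49, 49])  # B*nW, nH, Wh*Ww, Wh*Ww
layer.eval()       # bakes relative_position_bias in for export
out = layer(attn)  # uses the precomputed buffer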