未验证 提交 e6d4d2bc 编写于 作者: W Wenyu 提交者: GitHub

fix export_model for swin (#6399)

上级 c3cda7a8
...@@ -30,14 +30,12 @@ EvalReader: ...@@ -30,14 +30,12 @@ EvalReader:
TestReader: TestReader:
inputs_def: inputs_def:
image_shape: [1, 3, 640, 640] image_shape: [-1, 3, 640, 640]
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [640, 640], keep_ratio: True} - LetterBoxResize: {target_size: 640}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1 batch_size: 1
shuffle: false shuffle: false
drop_last: false drop_last: false
...@@ -20,7 +20,6 @@ MIT License [see LICENSE for details] ...@@ -20,7 +20,6 @@ MIT License [see LICENSE for details]
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.nn.initializer import TruncatedNormal, Constant, Assign
from ppdet.modeling.shape_spec import ShapeSpec from ppdet.modeling.shape_spec import ShapeSpec
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
import numpy as np import numpy as np
...@@ -64,7 +63,7 @@ def window_partition(x, window_size): ...@@ -64,7 +63,7 @@ def window_partition(x, window_size):
""" """
B, H, W, C = x.shape B, H, W, C = x.shape
x = x.reshape( x = x.reshape(
[B, H // window_size, window_size, W // window_size, window_size, C]) [-1, H // window_size, window_size, W // window_size, window_size, C])
windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape( windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape(
[-1, window_size, window_size, C]) [-1, window_size, window_size, C])
return windows return windows
...@@ -80,10 +79,11 @@ def window_reverse(windows, window_size, H, W): ...@@ -80,10 +79,11 @@ def window_reverse(windows, window_size, H, W):
Returns: Returns:
x: (B, H, W, C) x: (B, H, W, C)
""" """
_, _, _, C = windows.shape
B = int(windows.shape[0] / (H * W / window_size / window_size)) B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.reshape( x = windows.reshape(
[B, H // window_size, W // window_size, window_size, window_size, -1]) [-1, H // window_size, W // window_size, window_size, window_size, C])
x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1]) x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, H, W, C])
return x return x
...@@ -158,14 +158,14 @@ class WindowAttention(nn.Layer): ...@@ -158,14 +158,14 @@ class WindowAttention(nn.Layer):
""" """
B_, N, C = x.shape B_, N, C = x.shape
qkv = self.qkv(x).reshape( qkv = self.qkv(x).reshape(
[B_, N, 3, self.num_heads, C // self.num_heads]).transpose( [-1, N, 3, self.num_heads, C // self.num_heads]).transpose(
[2, 0, 3, 1, 4]) [2, 0, 3, 1, 4])
q, k, v = qkv[0], qkv[1], qkv[2] q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale q = q * self.scale
attn = paddle.mm(q, k.transpose([0, 1, 3, 2])) attn = paddle.mm(q, k.transpose([0, 1, 3, 2]))
index = self.relative_position_index.reshape([-1]) index = self.relative_position_index.flatten()
relative_position_bias = paddle.index_select( relative_position_bias = paddle.index_select(
self.relative_position_bias_table, index) self.relative_position_bias_table, index)
...@@ -179,7 +179,7 @@ class WindowAttention(nn.Layer): ...@@ -179,7 +179,7 @@ class WindowAttention(nn.Layer):
if mask is not None: if mask is not None:
nW = mask.shape[0] nW = mask.shape[0]
attn = attn.reshape([B_ // nW, nW, self.num_heads, N, N attn = attn.reshape([-1, nW, self.num_heads, N, N
]) + mask.unsqueeze(1).unsqueeze(0) ]) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.reshape([-1, self.num_heads, N, N]) attn = attn.reshape([-1, self.num_heads, N, N])
attn = self.softmax(attn) attn = self.softmax(attn)
...@@ -189,7 +189,7 @@ class WindowAttention(nn.Layer): ...@@ -189,7 +189,7 @@ class WindowAttention(nn.Layer):
attn = self.attn_drop(attn) attn = self.attn_drop(attn)
# x = (attn @ v).transpose(1, 2).reshape([B_, N, C]) # x = (attn @ v).transpose(1, 2).reshape([B_, N, C])
x = paddle.mm(attn, v).transpose([0, 2, 1, 3]).reshape([B_, N, C]) x = paddle.mm(attn, v).transpose([0, 2, 1, 3]).reshape([-1, N, C])
x = self.proj(x) x = self.proj(x)
x = self.proj_drop(x) x = self.proj_drop(x)
return x return x
...@@ -267,7 +267,7 @@ class SwinTransformerBlock(nn.Layer): ...@@ -267,7 +267,7 @@ class SwinTransformerBlock(nn.Layer):
shortcut = x shortcut = x
x = self.norm1(x) x = self.norm1(x)
x = x.reshape([B, H, W, C]) x = x.reshape([-1, H, W, C])
# pad feature maps to multiples of window size # pad feature maps to multiples of window size
pad_l = pad_t = 0 pad_l = pad_t = 0
...@@ -289,7 +289,7 @@ class SwinTransformerBlock(nn.Layer): ...@@ -289,7 +289,7 @@ class SwinTransformerBlock(nn.Layer):
x_windows = window_partition( x_windows = window_partition(
shifted_x, self.window_size) # nW*B, window_size, window_size, C shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.reshape( x_windows = x_windows.reshape(
[-1, self.window_size * self.window_size, [x_windows.shape[0], self.window_size * self.window_size,
C]) # nW*B, window_size*window_size, C C]) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA # W-MSA/SW-MSA
...@@ -298,7 +298,7 @@ class SwinTransformerBlock(nn.Layer): ...@@ -298,7 +298,7 @@ class SwinTransformerBlock(nn.Layer):
# merge windows # merge windows
attn_windows = attn_windows.reshape( attn_windows = attn_windows.reshape(
[-1, self.window_size, self.window_size, C]) [x_windows.shape[0], self.window_size, self.window_size, C])
shifted_x = window_reverse(attn_windows, self.window_size, Hp, shifted_x = window_reverse(attn_windows, self.window_size, Hp,
Wp) # B H' W' C Wp) # B H' W' C
...@@ -314,7 +314,7 @@ class SwinTransformerBlock(nn.Layer): ...@@ -314,7 +314,7 @@ class SwinTransformerBlock(nn.Layer):
if pad_r > 0 or pad_b > 0: if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :] x = x[:, :H, :W, :]
x = x.reshape([B, H * W, C]) x = x.reshape([-1, H * W, C])
# FFN # FFN
x = shortcut + self.drop_path(x) x = shortcut + self.drop_path(x)
...@@ -345,7 +345,7 @@ class PatchMerging(nn.Layer): ...@@ -345,7 +345,7 @@ class PatchMerging(nn.Layer):
B, L, C = x.shape B, L, C = x.shape
assert L == H * W, "input feature has wrong size" assert L == H * W, "input feature has wrong size"
x = x.reshape([B, H, W, C]) x = x.reshape([-1, H, W, C])
# padding # padding
pad_input = (H % 2 == 1) or (W % 2 == 1) pad_input = (H % 2 == 1) or (W % 2 == 1)
...@@ -357,7 +357,7 @@ class PatchMerging(nn.Layer): ...@@ -357,7 +357,7 @@ class PatchMerging(nn.Layer):
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C x = x.reshape([-1, H * W // 4, 4 * C]) # B H/2*W/2 4*C
x = self.norm(x) x = self.norm(x)
x = self.reduction(x) x = self.reduction(x)
...@@ -664,7 +664,7 @@ class SwinTransformer(nn.Layer): ...@@ -664,7 +664,7 @@ class SwinTransformer(nn.Layer):
def forward(self, x): def forward(self, x):
"""Forward function.""" """Forward function."""
x = self.patch_embed(x['image']) x = self.patch_embed(x['image'])
_, _, Wh, Ww = x.shape B, _, Wh, Ww = x.shape
if self.ape: if self.ape:
# interpolate the position embedding to the corresponding size # interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate( absolute_pos_embed = F.interpolate(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册