未验证 提交 9c0f0496 编写于 作者: L littletomatodonkey 提交者: GitHub

fix export model eval (#710)

* adapt to net.eval for the framework just contains training flag setting
* fix bug when export swin transformer
上级 dd70cb1b
......@@ -63,7 +63,7 @@ def window_partition(x, window_size):
return windows
def window_reverse(windows, window_size, H, W):
def window_reverse(windows, window_size, H, W, C):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
......@@ -74,10 +74,9 @@ def window_reverse(windows, window_size, H, W):
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.reshape(
[B, H // window_size, W // window_size, window_size, window_size, -1])
x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1])
[-1, H // window_size, W // window_size, window_size, window_size, C])
x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, H, W, C])
return x
......@@ -334,8 +333,8 @@ class SwinTransformerBlock(nn.Layer):
# merge windows
attn_windows = attn_windows.reshape(
[-1, self.window_size, self.window_size, C])
shifted_x = window_reverse(attn_windows, self.window_size, H,
W) # B H' W' C
shifted_x = window_reverse(attn_windows, self.window_size, H, W,
C) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
......@@ -406,7 +405,7 @@ class PatchMerging(nn.Layer):
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.reshape([B, -1, 4 * C]) # B H/2*W/2 4*C
x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
......@@ -551,10 +550,8 @@ class PatchEmbed(nn.Layer):
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
"Input image size ({H}*{W}) doesn't match model ({}*{}).".format(
H, W, self.img_size[0], self.img_size[1])
# TODO (littletomatodonkey), uncomment the line will cause failure of jit.save
# assert [H, W] == self.img_size[:2], "Input image size ({H}*{W}) doesn't match model ({}*{}).".format(H, W, self.img_size[0], self.img_size[1])
x = self.proj(x)
x = x.flatten(2).transpose([0, 2, 1]) # B Ph*Pw C
......
......@@ -47,6 +47,12 @@ class Net(paddle.nn.Layer):
self.pre_net = net(class_dim=class_dim)
self.model = model
def eval(self):
self.training = False
for layer in self.sublayers():
layer.training = False
layer.eval()
def forward(self, inputs):
x = self.pre_net(inputs)
if self.model == "GoogLeNet":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册