未验证 提交 48417da3 编写于 作者: W Wenyu 提交者: GitHub

fix freeze (#5883)

上级 bb4fbe84
...@@ -688,10 +688,10 @@ class SwinTransformer(nn.Layer): ...@@ -688,10 +688,10 @@ class SwinTransformer(nn.Layer):
if self.frozen_stages >= 0: if self.frozen_stages >= 0:
self.patch_embed.eval() self.patch_embed.eval()
for param in self.patch_embed.parameters(): for param in self.patch_embed.parameters():
param.requires_grad = False param.stop_gradient = True
if self.frozen_stages >= 1 and self.ape: if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False self.absolute_pos_embed.stop_gradient = True
if self.frozen_stages >= 2: if self.frozen_stages >= 2:
self.pos_drop.eval() self.pos_drop.eval()
...@@ -699,7 +699,7 @@ class SwinTransformer(nn.Layer): ...@@ -699,7 +699,7 @@ class SwinTransformer(nn.Layer):
m = self.layers[i] m = self.layers[i]
m.eval() m.eval()
for param in m.parameters(): for param in m.parameters():
param.requires_grad = False param.stop_gradient = True
def _init_weights(self, m): def _init_weights(self, m):
if isinstance(m, nn.Linear): if isinstance(m, nn.Linear):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册