提交 3672d1f2 编写于 作者: weixin_46524038 提交者: cuicheng01

Add SwinTransformer V1 ImageNet-22k pretrained weight URLs and plumb use_imagenet22k_pretrained / use_imagenet22kto1k_pretrained flags through the model factory functions

上级 5544dbaf
...@@ -35,9 +35,9 @@ MODEL_URLS = { ...@@ -35,9 +35,9 @@ MODEL_URLS = {
"SwinTransformer_base_patch4_window12_384": "SwinTransformer_base_patch4_window12_384":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_base_patch4_window12_384_pretrained.pdparams", "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_base_patch4_window12_384_pretrained.pdparams",
"SwinTransformer_large_patch4_window7_224": "SwinTransformer_large_patch4_window7_224":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_large_patch4_window7_224_22kto1k_pretrained.pdparams", "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_large_patch4_window7_224_pretrained.pdparams",
"SwinTransformer_large_patch4_window12_384": "SwinTransformer_large_patch4_window12_384":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams", "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_large_patch4_window12_384_pretrained.pdparams",
} }
__all__ = list(MODEL_URLS.keys()) __all__ = list(MODEL_URLS.keys())
...@@ -45,13 +45,15 @@ __all__ = list(MODEL_URLS.keys()) ...@@ -45,13 +45,15 @@ __all__ = list(MODEL_URLS.keys())
# The following re-implementation of roll is inspired by # The following re-implementation of roll is inspired by
# https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py # https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py
class RollWithIndexSelect(paddle.autograd.PyLayer): class RollWithIndexSelect(paddle.autograd.PyLayer):
@staticmethod @staticmethod
def forward(ctx, input1, index_fp, index_bp): def forward(ctx, input1, index_fp, index_bp):
N, H, W, C = input1.shape N, H, W, C = input1.shape
ctx.input1 = input1 ctx.input1 = input1
ctx.index_bp = index_bp ctx.index_bp = index_bp
result = input1.reshape([N, H * W, C]).index_select(index_fp, 1).reshape([N, H, W, C]) result = input1.reshape([N, H * W, C]).index_select(
index_fp, 1).reshape([N, H, W, C])
return result return result
@staticmethod @staticmethod
...@@ -59,14 +61,15 @@ class RollWithIndexSelect(paddle.autograd.PyLayer): ...@@ -59,14 +61,15 @@ class RollWithIndexSelect(paddle.autograd.PyLayer):
input1 = ctx.input1 input1 = ctx.input1
N, H, W, C = input1.shape N, H, W, C = input1.shape
index_bp = ctx.index_bp index_bp = ctx.index_bp
grad_input = grad.reshape([N, H * W, C]).index_select(index_bp, 1).reshape([N, H, W, C]) grad_input = grad.reshape([N, H * W, C]).index_select(
index_bp, 1).reshape([N, H, W, C])
return grad_input, None, None return grad_input, None, None
def get_roll_index(H, W, shifts, place): def get_roll_index(H, W, shifts, place):
index = np.arange(0, H * W, dtype=np.int64).reshape([H, W]) index = np.arange(0, H * W, dtype=np.int64).reshape([H, W])
index_fp = np.roll(index, shift=shifts, axis=(0, 1)).reshape([-1]) index_fp = np.roll(index, shift=shifts, axis=(0, 1)).reshape([-1])
index_bp = {i:idx for idx, i in enumerate(index_fp.tolist())} index_bp = {i: idx for idx, i in enumerate(index_fp.tolist())}
index_bp = [index_bp[i] for i in range(H * W)] index_bp = [index_bp[i] for i in range(H * W)]
index_fp = paddle.to_tensor(index_fp, place=place) index_fp = paddle.to_tensor(index_fp, place=place)
index_bp = paddle.to_tensor(index_fp, dtype='int64', place=place) index_bp = paddle.to_tensor(index_fp, dtype='int64', place=place)
...@@ -97,7 +100,9 @@ class RollWrapper(object): ...@@ -97,7 +100,9 @@ class RollWrapper(object):
@staticmethod @staticmethod
def roll(x, shifts, axis): def roll(x, shifts, axis):
if RollWrapper._roll is None: if RollWrapper._roll is None:
RollWrapper._roll = NpuRollWithIndexSelect() if 'npu' in paddle.device.get_all_custom_device_type() else paddle.roll RollWrapper._roll = NpuRollWithIndexSelect(
) if 'npu' in paddle.device.get_all_custom_device_type(
) else paddle.roll
return RollWrapper._roll(x, shifts, axis) return RollWrapper._roll(x, shifts, axis)
...@@ -507,7 +512,7 @@ class PatchMerging(nn.Layer): ...@@ -507,7 +512,7 @@ class PatchMerging(nn.Layer):
# x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C # x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
# x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C # x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.reshape([B, H//2, 2, W//2, 2, C]) x = x.reshape([B, H // 2, 2, W // 2, 2, C])
x = x.transpose((0, 1, 3, 4, 2, 5)) x = x.transpose((0, 1, 3, 4, 2, 5))
x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C
...@@ -703,7 +708,7 @@ class SwinTransformer(TheseusLayer): ...@@ -703,7 +708,7 @@ class SwinTransformer(TheseusLayer):
img_size=224, img_size=224,
patch_size=4, patch_size=4,
in_chans=3, in_chans=3,
class_num=1000, class_num=5,
embed_dim=96, embed_dim=96,
depths=[2, 2, 6, 2], depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24], num_heads=[3, 6, 12, 24],
...@@ -822,11 +827,21 @@ class SwinTransformer(TheseusLayer): ...@@ -822,11 +827,21 @@ class SwinTransformer(TheseusLayer):
return flops return flops
def _load_pretrained(pretrained, model, model_url, use_ssld=False): def _load_pretrained(pretrained,
model,
model_url,
use_ssld=False,
use_imagenet22k_pretrained=False,
use_imagenet22kto1k_pretrained=False):
if pretrained is False: if pretrained is False:
pass pass
elif pretrained is True: elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) load_dygraph_pretrain_from_url(
model,
model_url,
use_ssld=use_ssld,
use_imagenet22k_pretrained=use_imagenet22k_pretrained,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
elif isinstance(pretrained, str): elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained) load_dygraph_pretrain(model, pretrained)
else: else:
...@@ -835,81 +850,105 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False): ...@@ -835,81 +850,105 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
) )
def SwinTransformer_tiny_patch4_window7_224(pretrained=False, def SwinTransformer_tiny_patch4_window7_224(
use_ssld=False, pretrained=False,
**kwargs): use_ssld=False,
use_imagenet22k_pretrained=False,
use_imagenet22kto1k_pretrained=False,
**kwargs):
model = SwinTransformer( model = SwinTransformer(
embed_dim=96, embed_dim=96,
depths=[2, 2, 6, 2], depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24], num_heads=[3, 6, 12, 24],
window_size=7, window_size=7,
drop_path_rate=0.2, drop_path_rate=0.2, # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.1
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, pretrained,
model, model,
MODEL_URLS["SwinTransformer_tiny_patch4_window7_224"], MODEL_URLS["SwinTransformer_tiny_patch4_window7_224"],
use_ssld=use_ssld) use_ssld=use_ssld,
use_imagenet22k_pretrained=use_imagenet22k_pretrained,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
return model return model
def SwinTransformer_small_patch4_window7_224(pretrained=False, def SwinTransformer_small_patch4_window7_224(
use_ssld=False, pretrained=False,
**kwargs): use_ssld=False,
use_imagenet22k_pretrained=False,
use_imagenet22kto1k_pretrained=False,
**kwargs):
model = SwinTransformer( model = SwinTransformer(
embed_dim=96, embed_dim=96,
depths=[2, 2, 18, 2], depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24], num_heads=[3, 6, 12, 24],
window_size=7, window_size=7,
drop_path_rate=0.3, # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.2
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, pretrained,
model, model,
MODEL_URLS["SwinTransformer_small_patch4_window7_224"], MODEL_URLS["SwinTransformer_small_patch4_window7_224"],
use_ssld=use_ssld) use_ssld=use_ssld,
use_imagenet22k_pretrained=use_imagenet22k_pretrained,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
return model return model
def SwinTransformer_base_patch4_window7_224(pretrained=False, def SwinTransformer_base_patch4_window7_224(
use_ssld=False, pretrained=False,
**kwargs): use_ssld=False,
use_imagenet22k_pretrained=False,
use_imagenet22kto1k_pretrained=False,
**kwargs):
model = SwinTransformer( model = SwinTransformer(
embed_dim=128, embed_dim=128,
depths=[2, 2, 18, 2], depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32], num_heads=[4, 8, 16, 32],
window_size=7, window_size=7,
drop_path_rate=0.5, drop_path_rate=0.5, # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.2
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, pretrained,
model, model,
MODEL_URLS["SwinTransformer_base_patch4_window7_224"], MODEL_URLS["SwinTransformer_base_patch4_window7_224"],
use_ssld=use_ssld) use_ssld=use_ssld,
use_imagenet22k_pretrained=use_imagenet22k_pretrained,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
return model return model
def SwinTransformer_base_patch4_window12_384(pretrained=False, def SwinTransformer_base_patch4_window12_384(
use_ssld=False, pretrained=False,
**kwargs): use_ssld=False,
use_imagenet22k_pretrained=False,
use_imagenet22kto1k_pretrained=False,
**kwargs):
model = SwinTransformer( model = SwinTransformer(
img_size=384, img_size=384,
embed_dim=128, embed_dim=128,
depths=[2, 2, 18, 2], depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32], num_heads=[4, 8, 16, 32],
window_size=12, window_size=12,
drop_path_rate=0.5, # NOTE: do not appear in offical code drop_path_rate=0.5, # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.2
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, pretrained,
model, model,
MODEL_URLS["SwinTransformer_base_patch4_window12_384"], MODEL_URLS["SwinTransformer_base_patch4_window12_384"],
use_ssld=use_ssld) use_ssld=use_ssld,
use_imagenet22k_pretrained=use_imagenet22k_pretrained,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
return model return model
def SwinTransformer_large_patch4_window7_224(pretrained=False, def SwinTransformer_large_patch4_window7_224(
use_ssld=False, pretrained=False,
**kwargs): use_ssld=False,
use_imagenet22k_pretrained=False,
use_imagenet22kto1k_pretrained=True,
**kwargs):
model = SwinTransformer( model = SwinTransformer(
embed_dim=192, embed_dim=192,
depths=[2, 2, 18, 2], depths=[2, 2, 18, 2],
...@@ -920,13 +959,18 @@ def SwinTransformer_large_patch4_window7_224(pretrained=False, ...@@ -920,13 +959,18 @@ def SwinTransformer_large_patch4_window7_224(pretrained=False,
pretrained, pretrained,
model, model,
MODEL_URLS["SwinTransformer_large_patch4_window7_224"], MODEL_URLS["SwinTransformer_large_patch4_window7_224"],
use_ssld=use_ssld) use_ssld=use_ssld,
use_imagenet22k_pretrained=use_imagenet22k_pretrained,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
return model return model
def SwinTransformer_large_patch4_window12_384(pretrained=False, def SwinTransformer_large_patch4_window12_384(
use_ssld=False, pretrained=False,
**kwargs): use_ssld=False,
use_imagenet22k_pretrained=False,
use_imagenet22kto1k_pretrained=True,
**kwargs):
model = SwinTransformer( model = SwinTransformer(
img_size=384, img_size=384,
embed_dim=192, embed_dim=192,
...@@ -938,5 +982,7 @@ def SwinTransformer_large_patch4_window12_384(pretrained=False, ...@@ -938,5 +982,7 @@ def SwinTransformer_large_patch4_window12_384(pretrained=False,
pretrained, pretrained,
model, model,
MODEL_URLS["SwinTransformer_large_patch4_window12_384"], MODEL_URLS["SwinTransformer_large_patch4_window12_384"],
use_ssld=use_ssld) use_ssld=use_ssld,
use_imagenet22k_pretrained=use_imagenet22k_pretrained,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
return model return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册