Commit cdd3c3a0 authored by Yang Nie, committed by Tingquan Gao

clear type hint

Parent f6ac4a61
@@ -52,12 +52,12 @@ class InvertedResidual(nn.Layer):
     """

     def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 stride: int,
-                 expand_ratio: Union[int, float],
-                 dilation: int=1,
-                 skip_connection: Optional[bool]=True) -> None:
+                 in_channels,
+                 out_channels,
+                 stride,
+                 expand_ratio,
+                 dilation=1,
+                 skip_connection=True):
         super().__init__()
         assert stride in [1, 2]
         self.stride = stride
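For readers skimming the hunk, the arguments losing their hints here follow the usual MobileNetV2-style inverted residual convention: expand_ratio widens the block's hidden channel count, and the residual path is only taken when stride and channel counts allow it. The block body is not part of this hunk, so the sketch below (including the names hidden_dim and use_res_connect) is an illustrative assumption, not the file's exact code.

# Illustrative sketch, not the committed implementation.
def inverted_residual_plan(in_channels, out_channels, stride, expand_ratio,
                           skip_connection=True):
    hidden_dim = int(round(in_channels * expand_ratio))  # expanded (hidden) width
    use_res_connect = (stride == 1 and in_channels == out_channels and
                       skip_connection)                   # residual only when shapes match
    return hidden_dim, use_res_connect

print(inverted_residual_plan(64, 64, 1, 2))   # (128, True)
print(inverted_residual_plan(64, 128, 2, 2))  # (128, False)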
@@ -148,12 +148,12 @@ class LinearSelfAttention(nn.Layer):
 class LinearAttnFFN(nn.Layer):
     def __init__(self,
-                 embed_dim: int,
-                 ffn_latent_dim: int,
-                 attn_dropout: Optional[float]=0.0,
-                 dropout: Optional[float]=0.1,
-                 ffn_dropout: Optional[float]=0.0,
-                 norm_layer: Optional[str]=layer_norm_2d) -> None:
+                 embed_dim,
+                 ffn_latent_dim,
+                 attn_dropout=0.0,
+                 dropout=0.1,
+                 ffn_dropout=0.0,
+                 norm_layer=layer_norm_2d) -> None:
         super().__init__()
         attn_unit = LinearSelfAttention(
             embed_dim=embed_dim, attn_dropout=attn_dropout, bias=True)
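LinearAttnFFN wraps LinearSelfAttention, the separable (linear-complexity) attention introduced by the MobileViTv2 paper: a single-channel query is softmax-normalized over the tokens to form context scores, the keys are pooled into one context vector, and that vector re-weights a ReLU-gated value tensor, so no N x N attention matrix is built. The numpy sketch below uses a simplified, assumed layout (2-D tokens, no biases, no output projection) and is not the file's LinearSelfAttention code.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def separable_self_attention(x, w_q, w_k, w_v):
    # x: [d, N] tokens; w_q: [d]; w_k, w_v: [d, d] (simplified projections)
    context_scores = softmax(w_q @ x)                 # [N], one scalar score per token
    key = w_k @ x                                     # [d, N]
    context_vector = (key * context_scores).sum(axis=-1, keepdims=True)  # [d, 1]
    value = np.maximum(w_v @ x, 0.0)                  # ReLU-gated values, [d, N]
    return value * context_vector                     # [d, N]; cost is O(d*N), not O(N^2)

d, N = 8, 16
out = separable_self_attention(np.random.randn(d, N), np.random.randn(d),
                               np.random.randn(d, d), np.random.randn(d, d))
print(out.shape)  # (8, 16)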
@@ -185,18 +185,18 @@ class MobileViTV2Block(nn.Layer):
     """

     def __init__(self,
-                 in_channels: int,
-                 attn_unit_dim: int,
-                 ffn_multiplier: float=2.0,
-                 n_attn_blocks: Optional[int]=2,
-                 attn_dropout: Optional[float]=0.0,
-                 dropout: Optional[float]=0.0,
-                 ffn_dropout: Optional[float]=0.0,
-                 patch_h: Optional[int]=8,
-                 patch_w: Optional[int]=8,
-                 conv_ksize: Optional[int]=3,
-                 dilation: Optional[int]=1,
-                 attn_norm_layer: Optional[str]=layer_norm_2d):
+                 in_channels,
+                 attn_unit_dim,
+                 ffn_multiplier=2.0,
+                 n_attn_blocks=2,
+                 attn_dropout=0.0,
+                 dropout=0.0,
+                 ffn_dropout=0.0,
+                 patch_h=8,
+                 patch_w=8,
+                 conv_ksize=3,
+                 dilation=1,
+                 attn_norm_layer=layer_norm_2d):
         super().__init__()
         cnn_out_dim = attn_unit_dim
         padding = (conv_ksize - 1) // 2 * dilation
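The padding line kept as context is the standard same-size padding for an odd dilated kernel: at stride 1 the block's local conv_ksize x conv_ksize convolution preserves H and W. A quick check of the arithmetic:

for conv_ksize, dilation in [(3, 1), (3, 2)]:
    padding = (conv_ksize - 1) // 2 * dilation
    print(conv_ksize, dilation, padding)  # 3 1 1 and 3 2 2: output stays H x W at stride 1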
@@ -232,15 +232,8 @@ class MobileViTV2Block(nn.Layer):
         self.patch_h = patch_h
         self.patch_w = patch_w

-    def _build_attn_layer(self,
-                          d_model: int,
-                          ffn_mult: float,
-                          n_layers: int,
-                          attn_dropout: float,
-                          dropout: float,
-                          ffn_dropout: float,
-                          attn_norm_layer: nn.Layer):
+    def _build_attn_layer(self, d_model, ffn_mult, n_layers, attn_dropout,
+                          dropout, ffn_dropout, attn_norm_layer):

         # ensure that dims are multiple of 16
         ffn_dims = [ffn_mult * d_model // 16 * 16] * n_layers
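The context lines show how the FFN widths are derived: ffn_mult * d_model floored to a multiple of 16 and repeated once per attention block. A worked example of that expression (the values stay floats here; any cast to int happens outside this hunk):

d_model, ffn_mult, n_layers = 100, 2.0, 2
ffn_dims = [ffn_mult * d_model // 16 * 16] * n_layers
print(ffn_dims)  # [192.0, 192.0] -- 200 is floored to the nearest lower multiple of 16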
@@ -271,7 +264,7 @@ class MobileViTV2Block(nn.Layer):
         return patches, (img_h, img_w)

-    def folding(self, patches, output_size: Tuple[int, int]):
+    def folding(self, patches, output_size):
         batch_size, in_dim, patch_size, n_patches = patches.shape

         # [B, C, P, N]
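folding is the inverse of the patch unfolding performed before the attention layers: it maps patches laid out as [B, C, P, N] (P = patch_h * patch_w pixels per patch, N = number of patches) back to a [B, C, H, W] feature map of the given output_size. Since the body is not shown in this hunk, the numpy sketch below only illustrates the shape bookkeeping under the assumption of non-overlapping patch_h x patch_w patches.

import numpy as np

def unfold(x, ph, pw):
    # x: [B, C, H, W] -> patches [B, C, P, N], P = ph*pw, N = (H//ph)*(W//pw)
    B, C, H, W = x.shape
    nh, nw = H // ph, W // pw
    p = x.reshape(B, C, nh, ph, nw, pw).transpose(0, 1, 3, 5, 2, 4)
    return p.reshape(B, C, ph * pw, nh * nw), (H, W)

def fold(patches, output_size, ph, pw):
    # patches [B, C, P, N] -> feature map [B, C, H, W]
    B, C, P, N = patches.shape
    H, W = output_size
    nh, nw = H // ph, W // pw
    p = patches.reshape(B, C, ph, pw, nh, nw).transpose(0, 1, 4, 2, 5, 3)
    return p.reshape(B, C, H, W)

x = np.random.randn(2, 8, 32, 32)
patches, size = unfold(x, 8, 8)
print(patches.shape)                              # (2, 8, 64, 16)
print(np.allclose(fold(patches, size, 8, 8), x))  # True: folding undoes unfolding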
@@ -306,10 +299,7 @@ class MobileViTV2(nn.Layer):
     MobileViTV2
     """

-    def __init__(self,
-                 mobilevit_config: Dict,
-                 class_num=1000,
-                 output_stride=None):
+    def __init__(self, mobilevit_config, class_num=1000, output_stride=None):
         super().__init__()
         self.round_nearest = 8
         self.dilation = 1
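The self.round_nearest = 8 context line indicates that channel widths scaled by the width multiplier are rounded to multiples of 8, the usual MobileNet-style divisibility rule. The hunk does not show the rounding helper itself, so the following is a sketch of the conventional make_divisible routine (name and exact rule assumed):

def make_divisible(v, divisor=8, min_value=None):
    # Round v to a multiple of divisor without dropping more than ~10% of it.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

print(make_divisible(32 * 0.75))  # 24
print(make_divisible(32 * 0.35))  # 16 -- 11.2 rounds up rather than losing more than 10%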
@@ -475,7 +465,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
     )


-def get_configuration(width_multiplier) -> Dict:
+def get_configuration(width_multiplier):
     ffn_multiplier = 2
     mv2_exp_mult = 2  # max(1.0, min(2.0, 2.0 * width_multiplier))
...
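The comment next to mv2_exp_mult preserves an alternative in which the MobileNetV2 expansion ratio would itself scale with the width multiplier, clamped to [1.0, 2.0]; the committed code keeps it fixed at 2. Evaluating the commented expression for a few multipliers makes the difference concrete:

for width_multiplier in (0.5, 0.75, 1.0, 2.0):
    print(width_multiplier, max(1.0, min(2.0, 2.0 * width_multiplier)))
    # 0.5 -> 1.0, 0.75 -> 1.5, 1.0 -> 2.0, 2.0 -> 2.0 (the code uses 2 in all cases)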