提交 fc9c59c4 编写于 作者: G gaotingquan 提交者: cuicheng01

update pretrained url

上级 fe692cb8
...@@ -22,11 +22,16 @@ from paddle.nn.initializer import XavierUniform, TruncatedNormal, Constant ...@@ -22,11 +22,16 @@ from paddle.nn.initializer import XavierUniform, TruncatedNormal, Constant
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = { MODEL_URLS = {
"CvT_13_224": "", # TODO "CvT_13_224":
"CvT_13_384": "", # TODO "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/CvT_13_224_pretrained.pdparams",
"CvT_21_224": "", # TODO "CvT_13_384":
"CvT_21_384": "", # TODO "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/CvT_13_384_pretrained.pdparams",
"CvT_W24_384": "", # TODO "CvT_21_224":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/CvT_21_224_pretrained.pdparams",
"CvT_21_384":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/CvT_21_384_pretrained.pdparams",
"CvT_W24_384":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/CvT_W24_384_22kto1k_pretrained.pdparams",
} }
__all__ = list(MODEL_URLS.keys()) __all__ = list(MODEL_URLS.keys())
...@@ -509,11 +514,19 @@ class ConvolutionalVisionTransformer(nn.Layer): ...@@ -509,11 +514,19 @@ class ConvolutionalVisionTransformer(nn.Layer):
return x return x
def _load_pretrained(pretrained, model, model_url, use_ssld=False): def _load_pretrained(pretrained,
model,
model_url,
use_ssld=False,
use_imagenet22kto1k_pretrained=False):
if pretrained is False: if pretrained is False:
pass pass
elif pretrained is True: elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) load_dygraph_pretrain_from_url(
model,
model_url,
use_ssld=use_ssld,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
elif isinstance(pretrained, str): elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained) load_dygraph_pretrain(model, pretrained)
else: else:
...@@ -556,7 +569,10 @@ def CvT_13_224(pretrained=False, use_ssld=False, **kwargs): ...@@ -556,7 +569,10 @@ def CvT_13_224(pretrained=False, use_ssld=False, **kwargs):
return model return model
def CvT_13_384(pretrained=False, use_ssld=False, **kwargs): def CvT_13_384(pretrained=False,
use_ssld=False,
use_imagenet22kto1k_pretrained=False,
**kwargs):
msvit_spec = dict( msvit_spec = dict(
INIT='trunc_norm', INIT='trunc_norm',
NUM_STAGES=3, NUM_STAGES=3,
...@@ -586,7 +602,11 @@ def CvT_13_384(pretrained=False, use_ssld=False, **kwargs): ...@@ -586,7 +602,11 @@ def CvT_13_384(pretrained=False, use_ssld=False, **kwargs):
spec=msvit_spec, spec=msvit_spec,
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["CvT_13_384"], use_ssld=use_ssld) pretrained,
model,
MODEL_URLS["CvT_13_384"],
use_ssld=use_ssld,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
return model return model
...@@ -624,7 +644,10 @@ def CvT_21_224(pretrained=False, use_ssld=False, **kwargs): ...@@ -624,7 +644,10 @@ def CvT_21_224(pretrained=False, use_ssld=False, **kwargs):
return model return model
def CvT_21_384(pretrained=False, use_ssld=False, **kwargs): def CvT_21_384(pretrained=False,
use_ssld=False,
use_imagenet22kto1k_pretrained=False,
**kwargs):
msvit_spec = dict( msvit_spec = dict(
INIT='trunc_norm', INIT='trunc_norm',
NUM_STAGES=3, NUM_STAGES=3,
...@@ -654,7 +677,11 @@ def CvT_21_384(pretrained=False, use_ssld=False, **kwargs): ...@@ -654,7 +677,11 @@ def CvT_21_384(pretrained=False, use_ssld=False, **kwargs):
spec=msvit_spec, spec=msvit_spec,
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["CvT_21_384"], use_ssld=use_ssld) pretrained,
model,
MODEL_URLS["CvT_21_384"],
use_ssld=use_ssld,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
return model return model
...@@ -688,5 +715,9 @@ def CvT_W24_384(pretrained=False, use_ssld=False, **kwargs): ...@@ -688,5 +715,9 @@ def CvT_W24_384(pretrained=False, use_ssld=False, **kwargs):
spec=msvit_spec, spec=msvit_spec,
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["CvT_W24_384"], use_ssld=use_ssld) pretrained,
model,
MODEL_URLS["CvT_W24_384"],
use_ssld=use_ssld,
use_imagenet22kto1k_pretrained=True)
return model return model
...@@ -23,10 +23,14 @@ import paddle.nn as nn ...@@ -23,10 +23,14 @@ import paddle.nn as nn
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = { MODEL_URLS = {
"MicroNet_M0": "", # TODO "MicroNet_M0":
"MicroNet_M1": "", # TODO "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MicroNet_M0_pretrained.pdparams",
"MicroNet_M2": "", # TODO "MicroNet_M1":
"MicroNet_M3": "", # TODO "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MicroNet_M1_pretrained.pdparams",
"MicroNet_M2":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MicroNet_M2_pretrained.pdparams",
"MicroNet_M3":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MicroNet_M3_pretrained.pdparams",
} }
__all__ = MODEL_URLS.keys() __all__ = MODEL_URLS.keys()
......
...@@ -24,7 +24,8 @@ MODEL_URLS = { ...@@ -24,7 +24,8 @@ MODEL_URLS = {
"MobileNeXt_x0_35": "", # TODO "MobileNeXt_x0_35": "", # TODO
"MobileNeXt_x0_5": "", # TODO "MobileNeXt_x0_5": "", # TODO
"MobileNeXt_x0_75": "", # TODO "MobileNeXt_x0_75": "", # TODO
"MobileNeXt_x1_0": "", # TODO "MobileNeXt_x1_0":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNeXt_x1_0_pretrained.pdparams",
"MobileNeXt_x1_4": "", # TODO "MobileNeXt_x1_4": "", # TODO
} }
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Code was based on https://github.com/micronDLA/MobileViTv3/blob/main/MobileViTv3-v1/cvnets/models/classification/mobilevit.py # Code was based on https://github.com/micronDLA/MobileViTV3/blob/main/MobileViTv3-v1/cvnets/models/classification/mobilevit.py
# reference: https://arxiv.org/abs/2209.15159 # reference: https://arxiv.org/abs/2209.15159
import math import math
...@@ -26,15 +26,24 @@ import paddle.nn.functional as F ...@@ -26,15 +26,24 @@ import paddle.nn.functional as F
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = { MODEL_URLS = {
"MobileViTv3_XXS": "", "MobileViTV3_XXS":
"MobileViTv3_XS": "", "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViTV3_XXS_pretrained.pdparams",
"MobileViTv3_S": "", "MobileViTV3_XS":
"MobileViTv3_XXS_L2": "", "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViTV3_XS_pretrained.pdparams",
"MobileViTv3_XS_L2": "", "MobileViTV3_S":
"MobileViTv3_S_L2": "", "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViTV3_S_pretrained.pdparams",
"MobileViTv3_x0_5": "", "MobileViTV3_XXS_L2":
"MobileViTv3_x0_75": "", "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViTV3_XXS_L2_pretrained.pdparams",
"MobileViTv3_x1_0": "", "MobileViTV3_XS_L2":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViTV3_XS_L2_pretrained.pdparams",
"MobileViTV3_S_L2":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViTV3_S_L2_pretrained.pdparams",
"MobileViTV3_x0_5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViTV3_x0_5_pretrained.pdparams",
"MobileViTV3_x0_75":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViTV3_x0_75_pretrained.pdparams",
"MobileViTV3_x1_0":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileViTV3_x1_0_pretrained.pdparams",
} }
layer_norm_2d = partial(nn.GroupNorm, num_groups=1) layer_norm_2d = partial(nn.GroupNorm, num_groups=1)
...@@ -185,9 +194,9 @@ class TransformerEncoder(nn.Layer): ...@@ -185,9 +194,9 @@ class TransformerEncoder(nn.Layer):
return x return x
class MobileViTv3Block(nn.Layer): class MobileViTV3Block(nn.Layer):
""" """
MobileViTv3 block MobileViTV3 block
""" """
def __init__(self, def __init__(self,
...@@ -207,7 +216,7 @@ class MobileViTv3Block(nn.Layer): ...@@ -207,7 +216,7 @@ class MobileViTv3Block(nn.Layer):
var_ffn: Optional[bool]=False, var_ffn: Optional[bool]=False,
no_fusion: Optional[bool]=False): no_fusion: Optional[bool]=False):
# For MobileViTv3: Normal 3x3 convolution --> Depthwise 3x3 convolution # For MobileViTV3: Normal 3x3 convolution --> Depthwise 3x3 convolution
padding = (conv_ksize - 1) // 2 * dilation padding = (conv_ksize - 1) // 2 * dilation
conv_3x3_in = nn.Sequential( conv_3x3_in = nn.Sequential(
('conv', nn.Conv2D( ('conv', nn.Conv2D(
...@@ -228,7 +237,7 @@ class MobileViTv3Block(nn.Layer): ...@@ -228,7 +237,7 @@ class MobileViTv3Block(nn.Layer):
('norm', nn.BatchNorm2D(in_channels)), ('act', nn.Silu())) ('norm', nn.BatchNorm2D(in_channels)), ('act', nn.Silu()))
conv_3x3_out = None conv_3x3_out = None
# For MobileViTv3: input+global --> local+global # For MobileViTV3: input+global --> local+global
if not no_fusion: if not no_fusion:
#input_ch = tr_dim + in_ch #input_ch = tr_dim + in_ch
conv_3x3_out = nn.Sequential( conv_3x3_out = nn.Sequential(
...@@ -375,7 +384,7 @@ class MobileViTv3Block(nn.Layer): ...@@ -375,7 +384,7 @@ class MobileViTv3Block(nn.Layer):
def forward(self, x): def forward(self, x):
res = x res = x
# For MobileViTv3: Normal 3x3 convolution --> Depthwise 3x3 convolution # For MobileViTV3: Normal 3x3 convolution --> Depthwise 3x3 convolution
fm_conv = self.local_rep(x) fm_conv = self.local_rep(x)
# convert feature map to patches # convert feature map to patches
...@@ -390,10 +399,10 @@ class MobileViTv3Block(nn.Layer): ...@@ -390,10 +399,10 @@ class MobileViTv3Block(nn.Layer):
fm = self.conv_proj(fm) fm = self.conv_proj(fm)
if self.fusion is not None: if self.fusion is not None:
# For MobileViTv3: input+global --> local+global # For MobileViTV3: input+global --> local+global
fm = self.fusion(paddle.concat((fm_conv, fm), axis=1)) fm = self.fusion(paddle.concat((fm_conv, fm), axis=1))
# For MobileViTv3: Skip connection # For MobileViTV3: Skip connection
fm = fm + res fm = fm + res
return fm return fm
...@@ -470,9 +479,9 @@ class LinearAttnFFN(nn.Layer): ...@@ -470,9 +479,9 @@ class LinearAttnFFN(nn.Layer):
return x return x
class MobileViTv3Block_v2(nn.Layer): class MobileViTV3Block_v2(nn.Layer):
""" """
This class defines the `MobileViTv3 block` This class defines the `MobileViTV3 block`
""" """
def __init__(self, def __init__(self,
...@@ -516,7 +525,7 @@ class MobileViTv3Block_v2(nn.Layer): ...@@ -516,7 +525,7 @@ class MobileViTv3Block_v2(nn.Layer):
ffn_dropout=ffn_dropout, ffn_dropout=ffn_dropout,
attn_norm_layer=attn_norm_layer) attn_norm_layer=attn_norm_layer)
# MobileViTv3: input changed from just global to local+global # MobileViTV3: input changed from just global to local+global
self.conv_proj = nn.Sequential( self.conv_proj = nn.Sequential(
('conv', nn.Conv2D( ('conv', nn.Conv2D(
2 * cnn_out_dim, in_channels, 1, bias_attr=False)), 2 * cnn_out_dim, in_channels, 1, bias_attr=False)),
...@@ -590,18 +599,18 @@ class MobileViTv3Block_v2(nn.Layer): ...@@ -590,18 +599,18 @@ class MobileViTv3Block_v2(nn.Layer):
# [B x Patch x Patches x C] --> [B x C x Patches x Patch] # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
fm = self.folding(patches=patches, output_size=output_size) fm = self.folding(patches=patches, output_size=output_size)
# MobileViTv3: local+global instead of only global # MobileViTV3: local+global instead of only global
fm = self.conv_proj(paddle.concat((fm, fm_conv), axis=1)) fm = self.conv_proj(paddle.concat((fm, fm_conv), axis=1))
# MobileViTv3: skip connection # MobileViTV3: skip connection
fm = fm + x fm = fm + x
return fm return fm
class MobileViTv3(nn.Layer): class MobileViTV3(nn.Layer):
""" """
MobileViTv3: MobileViTV3:
""" """
def __init__(self, def __init__(self,
...@@ -740,7 +749,7 @@ class MobileViTv3(nn.Layer): ...@@ -740,7 +749,7 @@ class MobileViTv3(nn.Layer):
if self.mobilevit_v2_based: if self.mobilevit_v2_based:
block.append( block.append(
MobileViTv3Block_v2( MobileViTV3Block_v2(
in_channels=input_channel, in_channels=input_channel,
attn_unit_dim=cfg["attn_unit_dim"], attn_unit_dim=cfg["attn_unit_dim"],
ffn_multiplier=cfg.get("ffn_multiplier"), ffn_multiplier=cfg.get("ffn_multiplier"),
...@@ -765,7 +774,7 @@ class MobileViTv3(nn.Layer): ...@@ -765,7 +774,7 @@ class MobileViTv3(nn.Layer):
"Got {} and {}.".format(transformer_dim, head_dim)) "Got {} and {}.".format(transformer_dim, head_dim))
block.append( block.append(
MobileViTv3Block( MobileViTV3Block(
in_channels=input_channel, in_channels=input_channel,
transformer_dim=transformer_dim, transformer_dim=transformer_dim,
ffn_dim=ffn_dim, ffn_dim=ffn_dim,
...@@ -827,7 +836,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False): ...@@ -827,7 +836,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
) )
def MobileViTv3_S(pretrained=False, use_ssld=False, **kwargs): def MobileViTV3_S(pretrained=False, use_ssld=False, **kwargs):
mv2_exp_mult = 4 mv2_exp_mult = 4
mobilevit_config = { mobilevit_config = {
"layer0": { "layer0": {
...@@ -890,14 +899,14 @@ def MobileViTv3_S(pretrained=False, use_ssld=False, **kwargs): ...@@ -890,14 +899,14 @@ def MobileViTv3_S(pretrained=False, use_ssld=False, **kwargs):
"last_layer_exp_factor": 4 "last_layer_exp_factor": 4
} }
model = MobileViTv3(mobilevit_config, **kwargs) model = MobileViTV3(mobilevit_config, **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["MobileViTv3_S"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["MobileViTV3_S"], use_ssld=use_ssld)
return model return model
def MobileViTv3_XS(pretrained=False, use_ssld=False, **kwargs): def MobileViTV3_XS(pretrained=False, use_ssld=False, **kwargs):
mv2_exp_mult = 4 mv2_exp_mult = 4
mobilevit_config = { mobilevit_config = {
"layer0": { "layer0": {
...@@ -960,14 +969,14 @@ def MobileViTv3_XS(pretrained=False, use_ssld=False, **kwargs): ...@@ -960,14 +969,14 @@ def MobileViTv3_XS(pretrained=False, use_ssld=False, **kwargs):
"last_layer_exp_factor": 4 "last_layer_exp_factor": 4
} }
model = MobileViTv3(mobilevit_config, **kwargs) model = MobileViTV3(mobilevit_config, **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["MobileViTv3_XS"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["MobileViTV3_XS"], use_ssld=use_ssld)
return model return model
def MobileViTv3_XXS(pretrained=False, use_ssld=False, **kwargs): def MobileViTV3_XXS(pretrained=False, use_ssld=False, **kwargs):
mv2_exp_mult = 2 mv2_exp_mult = 2
mobilevit_config = { mobilevit_config = {
"layer0": { "layer0": {
...@@ -1030,14 +1039,14 @@ def MobileViTv3_XXS(pretrained=False, use_ssld=False, **kwargs): ...@@ -1030,14 +1039,14 @@ def MobileViTv3_XXS(pretrained=False, use_ssld=False, **kwargs):
"last_layer_exp_factor": 4 "last_layer_exp_factor": 4
} }
model = MobileViTv3(mobilevit_config, **kwargs) model = MobileViTV3(mobilevit_config, **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["MobileViTv3_XXS"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["MobileViTV3_XXS"], use_ssld=use_ssld)
return model return model
def MobileViTv3_S_L2(pretrained=False, use_ssld=False, **kwargs): def MobileViTV3_S_L2(pretrained=False, use_ssld=False, **kwargs):
mv2_exp_mult = 4 mv2_exp_mult = 4
mobilevit_config = { mobilevit_config = {
"layer0": { "layer0": {
...@@ -1100,14 +1109,14 @@ def MobileViTv3_S_L2(pretrained=False, use_ssld=False, **kwargs): ...@@ -1100,14 +1109,14 @@ def MobileViTv3_S_L2(pretrained=False, use_ssld=False, **kwargs):
"last_layer_exp_factor": 4 "last_layer_exp_factor": 4
} }
model = MobileViTv3(mobilevit_config, **kwargs) model = MobileViTV3(mobilevit_config, **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["MobileViTv3_S_L2"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["MobileViTV3_S_L2"], use_ssld=use_ssld)
return model return model
def MobileViTv3_XS_L2(pretrained=False, use_ssld=False, **kwargs): def MobileViTV3_XS_L2(pretrained=False, use_ssld=False, **kwargs):
mv2_exp_mult = 4 mv2_exp_mult = 4
mobilevit_config = { mobilevit_config = {
"layer0": { "layer0": {
...@@ -1170,14 +1179,14 @@ def MobileViTv3_XS_L2(pretrained=False, use_ssld=False, **kwargs): ...@@ -1170,14 +1179,14 @@ def MobileViTv3_XS_L2(pretrained=False, use_ssld=False, **kwargs):
"last_layer_exp_factor": 4 "last_layer_exp_factor": 4
} }
model = MobileViTv3(mobilevit_config, **kwargs) model = MobileViTV3(mobilevit_config, **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["MobileViTv3_XS_L2"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["MobileViTV3_XS_L2"], use_ssld=use_ssld)
return model return model
def MobileViTv3_XXS_L2(pretrained=False, use_ssld=False, **kwargs): def MobileViTV3_XXS_L2(pretrained=False, use_ssld=False, **kwargs):
mv2_exp_mult = 2 mv2_exp_mult = 2
mobilevit_config = { mobilevit_config = {
"layer0": { "layer0": {
...@@ -1240,14 +1249,14 @@ def MobileViTv3_XXS_L2(pretrained=False, use_ssld=False, **kwargs): ...@@ -1240,14 +1249,14 @@ def MobileViTv3_XXS_L2(pretrained=False, use_ssld=False, **kwargs):
"last_layer_exp_factor": 4 "last_layer_exp_factor": 4
} }
model = MobileViTv3(mobilevit_config, **kwargs) model = MobileViTV3(mobilevit_config, **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["MobileViTv3_XXS_L2"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["MobileViTV3_XXS_L2"], use_ssld=use_ssld)
return model return model
def MobileViTv3_x1_0(pretrained=False, use_ssld=False, **kwargs): def MobileViTV3_x1_0(pretrained=False, use_ssld=False, **kwargs):
mobilevit_config = { mobilevit_config = {
"layer0": { "layer0": {
"img_channels": 3, "img_channels": 3,
...@@ -1303,14 +1312,14 @@ def MobileViTv3_x1_0(pretrained=False, use_ssld=False, **kwargs): ...@@ -1303,14 +1312,14 @@ def MobileViTv3_x1_0(pretrained=False, use_ssld=False, **kwargs):
"last_layer_exp_factor": 4, "last_layer_exp_factor": 4,
} }
model = MobileViTv3(mobilevit_config, mobilevit_v2_based=True, **kwargs) model = MobileViTV3(mobilevit_config, mobilevit_v2_based=True, **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["MobileViTv3_x1_0"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["MobileViTV3_x1_0"], use_ssld=use_ssld)
return model return model
def MobileViTv3_x0_75(pretrained=False, use_ssld=False, **kwargs): def MobileViTV3_x0_75(pretrained=False, use_ssld=False, **kwargs):
mobilevit_config = { mobilevit_config = {
"layer0": { "layer0": {
"img_channels": 3, "img_channels": 3,
...@@ -1366,14 +1375,14 @@ def MobileViTv3_x0_75(pretrained=False, use_ssld=False, **kwargs): ...@@ -1366,14 +1375,14 @@ def MobileViTv3_x0_75(pretrained=False, use_ssld=False, **kwargs):
"last_layer_exp_factor": 4, "last_layer_exp_factor": 4,
} }
model = MobileViTv3(mobilevit_config, mobilevit_v2_based=True, **kwargs) model = MobileViTV3(mobilevit_config, mobilevit_v2_based=True, **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["MobileViTv3_x0_75"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["MobileViTV3_x0_75"], use_ssld=use_ssld)
return model return model
def MobileViTv3_x0_5(pretrained=False, use_ssld=False, **kwargs): def MobileViTV3_x0_5(pretrained=False, use_ssld=False, **kwargs):
mobilevit_config = { mobilevit_config = {
"layer0": { "layer0": {
"img_channels": 3, "img_channels": 3,
...@@ -1429,8 +1438,8 @@ def MobileViTv3_x0_5(pretrained=False, use_ssld=False, **kwargs): ...@@ -1429,8 +1438,8 @@ def MobileViTv3_x0_5(pretrained=False, use_ssld=False, **kwargs):
"last_layer_exp_factor": 4, "last_layer_exp_factor": 4,
} }
model = MobileViTv3(mobilevit_config, mobilevit_v2_based=True, **kwargs) model = MobileViTV3(mobilevit_config, mobilevit_v2_based=True, **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["MobileViTv3_x0_5"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["MobileViTV3_x0_5"], use_ssld=use_ssld)
return model return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册