提交 c4572cc5 编写于 作者: R Ross Wightman

Add Visformer-small weighs, tweak torchscript jit test img size.

上级 83487e2a
......@@ -33,7 +33,11 @@ else:
TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
TARGET_BWD_SIZE = 128
MAX_BWD_SIZE = 320
MAX_FWD_FEAT_SIZE = 448
MAX_FWD_OUT_SIZE = 448
TARGET_JIT_SIZE = 128
MAX_JIT_SIZE = 320
TARGET_FFEAT_SIZE = 96
MAX_FFEAT_SIZE = 256
def _get_input_size(model, target=None):
......@@ -109,10 +113,10 @@ def test_model_default_cfgs(model_name, batch_size):
pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size']
if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
# output sizes only checked if default res <= 448 * 448 to keep resource down
input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
input_tensor = torch.randn((batch_size, *input_size))
# test forward_features (always unpooled)
......@@ -176,8 +180,8 @@ def test_model_forward_torchscript(model_name, batch_size):
model = create_model(model_name, pretrained=False)
model.eval()
input_size = _get_input_size(model, 128)
if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
input_size = _get_input_size(model, TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
pytest.skip("Fixed input size model > limit.")
model = torch.jit.script(model)
......@@ -205,8 +209,8 @@ def test_model_forward_features(model_name, batch_size):
expected_channels = model.feature_info.channels()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
input_size = _get_input_size(model, 96) # jit compile is already a bit slow and we've tested normal res already...
if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
input_size = _get_input_size(model, TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
outputs = model(torch.randn((batch_size, *input_size)))
......
......@@ -33,7 +33,9 @@ def _cfg(url='', **kwargs):
default_cfgs = dict(
visformer_tiny=_cfg(),
visformer_small=_cfg(),
visformer_small=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth'
),
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册