未验证 提交 b0fa8dce 编写于 作者: L littletomatodonkey 提交者: GitHub

fix log when running vit in static mode (#654)

* fix log when running vit in statich mode

* fix comment
上级 23104797
......@@ -253,11 +253,9 @@ class VisionTransformer(nn.Layer):
self.head = nn.Linear(embed_dim,
class_dim) if class_dim > 0 else Identity()
# TODO(littletomatodonkey): same init in static mode
if paddle.in_dynamic_mode():
trunc_normal_(self.pos_embed)
trunc_normal_(self.cls_token)
self.apply(self._init_weights)
trunc_normal_(self.pos_embed)
trunc_normal_(self.cls_token)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
......
......@@ -30,6 +30,13 @@ def get_architectures():
return names
def get_blacklist_model_in_static_mode():
from ppcls.modeling.architectures import distilled_vision_transformer
from ppcls.modeling.architectures import vision_transformer
blacklist = distilled_vision_transformer.__all__ + vision_transformer.__all__
return blacklist
def similar_architectures(name='', names=[], thresh=0.1, topk=10):
"""
inferred similar architectures
......
......@@ -24,6 +24,7 @@ from paddle import is_compiled_with_cuda
from ppcls.modeling import get_architectures
from ppcls.modeling import similar_architectures
from ppcls.modeling import get_blacklist_model_in_static_mode
from ppcls.utils import logger
......@@ -79,6 +80,19 @@ def check_architecture(architecture):
sys.exit(1)
def check_model_with_running_mode(architecture):
"""
check whether the model is consistent with the operating mode
"""
# some model are not supported in the static mode
blacklist = get_blacklist_model_in_static_mode()
if not paddle.in_dynamic_mode() and architecture["name"] in blacklist:
logger.error("Model: {} is not supported in the staic mode.".format(
architecture["name"]))
sys.exit(1)
return
def check_mix(architecture, use_mix=False):
"""
check mix parameter
......
......@@ -104,6 +104,7 @@ def check_config(config):
architecture = config.get('ARCHITECTURE')
check.check_architecture(architecture)
check.check_model_with_running_mode(architecture)
use_mix = config.get('use_mix', False)
check.check_mix(architecture, use_mix)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册