提交 e4a3e1bb 编写于 作者: G gaotingquan 提交者: Wei Shengyu

backbone_zoo -> backbone

上级 0e28a39d
...@@ -19,7 +19,7 @@ import paddle.nn as nn ...@@ -19,7 +19,7 @@ import paddle.nn as nn
from paddle.jit import to_static from paddle.jit import to_static
from paddle.static import InputSpec from paddle.static import InputSpec
from . import backbone as backbone_zoo from . import backbone
from .gears import build_gear from .gears import build_gear
from .utils import * from .utils import *
from .backbone.base.theseus_layer import TheseusLayer from .backbone.base.theseus_layer import TheseusLayer
...@@ -36,7 +36,7 @@ def build_model(config, mode="train"): ...@@ -36,7 +36,7 @@ def build_model(config, mode="train"):
model_type = arch_config.pop("name") model_type = arch_config.pop("name")
use_sync_bn = arch_config.pop("use_sync_bn", False) use_sync_bn = arch_config.pop("use_sync_bn", False)
if hasattr(backbone_zoo, model_type): if hasattr(backbone, model_type):
model = ClassModel(model_type, **arch_config) model = ClassModel(model_type, **arch_config)
else: else:
model = getattr(sys.modules[__name__], model_type)("ClassModel", model = getattr(sys.modules[__name__], model_type)("ClassModel",
...@@ -83,7 +83,7 @@ class ClassModel(TheseusLayer): ...@@ -83,7 +83,7 @@ class ClassModel(TheseusLayer):
else: else:
backbone_name = model_type backbone_name = model_type
backbone_config = config backbone_config = config
self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config) self.backbone = getattr(backbone, backbone_name)(**backbone_config)
def forward(self, batch): def forward(self, batch):
x, label = batch[0], batch[1] x, label = batch[0], batch[1]
...@@ -95,7 +95,7 @@ class RecModel(TheseusLayer): ...@@ -95,7 +95,7 @@ class RecModel(TheseusLayer):
super().__init__() super().__init__()
backbone_config = config["Backbone"] backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name") backbone_name = backbone_config.pop("name")
self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config) self.backbone = getattr(backbone, backbone_name)(**backbone_config)
self.head_feature_from = config.get('head_feature_from', 'neck') self.head_feature_from = config.get('head_feature_from', 'neck')
if "BackboneStopLayer" in config: if "BackboneStopLayer" in config:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册