提交 600067f4 编写于 作者: W weishengyu

dbg

上级 e9162595
......@@ -33,11 +33,8 @@ __all__ = ["build_model", "RecModel", "DistillationModel"]
def build_model(config):
config = copy.deepcopy(config)
model_type = config.pop("name")
return_patterns = config.pop("return_patterns", None)
mod = importlib.import_module(__name__)
arch = getattr(mod, model_type)(**config)
if return_patterns is not None and isinstance(arch, TheseusLayer):
arch.update_res(return_patterns=return_patterns, return_dict=True)
return arch
......@@ -59,10 +56,7 @@ class RecModel(nn.Layer):
super().__init__()
backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name")
return_patterns = config.pop("return_patterns", None)
self.backbone = eval(backbone_name)(**backbone_config)
if return_patterns is not None and isinstance(self.backbone, TheseusLayer):
self.backbone.update_res(return_patterns=return_patterns, return_dict=True)
if "BackboneStopLayer" in config:
backbone_stop_layer = config["BackboneStopLayer"]["name"]
self.backbone.stop_after(backbone_stop_layer)
......
......@@ -333,6 +333,8 @@ class Engine(object):
out = self.model(batch_tensor)
if isinstance(out, list):
out = out[0]
if isinstance(out, dict):
out = out["output"]
result = self.postprocess_func(out, image_file_list)
print(result)
batch_data.clear()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册