提交 a4e1da66 编写于 作者: Z zhiboniu

modify attr export model

上级 9cf1abae
...@@ -287,7 +287,8 @@ class ResNet(TheseusLayer): ...@@ -287,7 +287,8 @@ class ResNet(TheseusLayer):
data_format="NCHW", data_format="NCHW",
input_image_channel=3, input_image_channel=3,
return_patterns=None, return_patterns=None,
return_stages=None): return_stages=None,
**kargs):
super().__init__() super().__init__()
self.cfg = config self.cfg = config
......
...@@ -20,6 +20,7 @@ Arch: ...@@ -20,6 +20,7 @@ Arch:
name: "ResNet50" name: "ResNet50"
pretrained: True pretrained: True
class_num: 26 class_num: 26
infer_add_softmax: False
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
...@@ -110,5 +111,3 @@ DataLoader: ...@@ -110,5 +111,3 @@ DataLoader:
Metric: Metric:
Eval: Eval:
- ATTRMetric: - ATTRMetric:
...@@ -457,7 +457,9 @@ class Engine(object): ...@@ -457,7 +457,9 @@ class Engine(object):
def export(self): def export(self):
assert self.mode == "export" assert self.mode == "export"
use_multilabel = self.config["Global"].get("use_multilabel", False) use_multilabel = self.config["Global"].get(
"use_multilabel",
False) and not "ATTRMetric" in self.config["Metric"]["Eval"][0]
model = ExportModel(self.config["Arch"], self.model, use_multilabel) model = ExportModel(self.config["Arch"], self.model, use_multilabel)
if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(model.base_model, load_dygraph_pretrain(model.base_model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册