提交 a4e1da66 编写于 作者: Z zhiboniu

modify attr export model

上级 9cf1abae
......@@ -287,7 +287,8 @@ class ResNet(TheseusLayer):
data_format="NCHW",
input_image_channel=3,
return_patterns=None,
return_stages=None):
return_stages=None,
**kargs):
super().__init__()
self.cfg = config
......
......@@ -20,6 +20,7 @@ Arch:
name: "ResNet50"
pretrained: True
class_num: 26
infer_add_softmax: False
# loss function config for traing/eval process
Loss:
......@@ -110,5 +111,3 @@ DataLoader:
Metric:
Eval:
- ATTRMetric:
......@@ -457,7 +457,9 @@ class Engine(object):
def export(self):
assert self.mode == "export"
use_multilabel = self.config["Global"].get("use_multilabel", False)
use_multilabel = self.config["Global"].get(
"use_multilabel",
False) and not "ATTRMetric" in self.config["Metric"]["Eval"][0]
model = ExportModel(self.config["Arch"], self.model, use_multilabel)
if self.config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(model.base_model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册