提交 cee1f68e 编写于 作者: C cuicheng01

fix ml export_model

上级 94b1b920
...@@ -37,15 +37,17 @@ def parse_args(): ...@@ -37,15 +37,17 @@ def parse_args():
parser.add_argument("--class_dim", type=int, default=1000) parser.add_argument("--class_dim", type=int, default=1000)
parser.add_argument("--load_static_weights", type=str2bool, default=False) parser.add_argument("--load_static_weights", type=str2bool, default=False)
parser.add_argument("--img_size", type=int, default=224) parser.add_argument("--img_size", type=int, default=224)
parser.add_argument("--multilabel", type=str2bool, default=False)
return parser.parse_args() return parser.parse_args()
class Net(paddle.nn.Layer): class Net(paddle.nn.Layer):
def __init__(self, net, class_dim, model): def __init__(self, net, class_dim, model, multilabel):
super(Net, self).__init__() super(Net, self).__init__()
self.pre_net = net(class_dim=class_dim) self.pre_net = net(class_dim=class_dim)
self.model = model self.model = model
self.multilabel = multilabel
def eval(self): def eval(self):
self.training = False self.training = False
...@@ -57,7 +59,7 @@ class Net(paddle.nn.Layer): ...@@ -57,7 +59,7 @@ class Net(paddle.nn.Layer):
x = self.pre_net(inputs) x = self.pre_net(inputs)
if self.model == "GoogLeNet": if self.model == "GoogLeNet":
x = x[0] x = x[0]
x = F.softmax(x) x = F.softmax(x) if not self.multilabel else F.sigmoid(x)
return x return x
...@@ -65,7 +67,7 @@ def main(): ...@@ -65,7 +67,7 @@ def main():
args = parse_args() args = parse_args()
net = architectures.__dict__[args.model] net = architectures.__dict__[args.model]
model = Net(net, args.class_dim, args.model) model = Net(net, args.class_dim, args.model, args.multilabel)
load_dygraph_pretrain( load_dygraph_pretrain(
model.pre_net, model.pre_net,
path=args.pretrained_model, path=args.pretrained_model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册