From cee1f68e92f20f7661c560ce76b3adf99564d2b4 Mon Sep 17 00:00:00 2001 From: cuicheng01 Date: Thu, 12 Aug 2021 02:35:56 +0000 Subject: [PATCH] fix ml export_model --- tools/export_model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tools/export_model.py b/tools/export_model.py index 5d6b338d..299bac56 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -37,15 +37,17 @@ def parse_args(): parser.add_argument("--class_dim", type=int, default=1000) parser.add_argument("--load_static_weights", type=str2bool, default=False) parser.add_argument("--img_size", type=int, default=224) + parser.add_argument("--multilabel", type=str2bool, default=False) return parser.parse_args() class Net(paddle.nn.Layer): - def __init__(self, net, class_dim, model): + def __init__(self, net, class_dim, model, multilabel): super(Net, self).__init__() self.pre_net = net(class_dim=class_dim) self.model = model + self.multilabel = multilabel def eval(self): self.training = False @@ -57,7 +59,7 @@ class Net(paddle.nn.Layer): x = self.pre_net(inputs) if self.model == "GoogLeNet": x = x[0] - x = F.softmax(x) + x = F.softmax(x) if not self.multilabel else F.sigmoid(x) return x @@ -65,7 +67,7 @@ def main(): args = parse_args() net = architectures.__dict__[args.model] - model = Net(net, args.class_dim, args.model) + model = Net(net, args.class_dim, args.model, args.multilabel) load_dygraph_pretrain( model.pre_net, path=args.pretrained_model, -- GitLab