diff --git a/tools/export_model.py b/tools/export_model.py index 5d6b338dbcd7071ccda93471c6f9531246927eeb..299bac564e3563fd77c9be7021bb21abfcbe6441 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -37,15 +37,17 @@ def parse_args(): parser.add_argument("--class_dim", type=int, default=1000) parser.add_argument("--load_static_weights", type=str2bool, default=False) parser.add_argument("--img_size", type=int, default=224) + parser.add_argument("--multilabel", type=str2bool, default=False) return parser.parse_args() class Net(paddle.nn.Layer): - def __init__(self, net, class_dim, model): + def __init__(self, net, class_dim, model, multilabel): super(Net, self).__init__() self.pre_net = net(class_dim=class_dim) self.model = model + self.multilabel = multilabel def eval(self): self.training = False @@ -57,7 +59,7 @@ class Net(paddle.nn.Layer): x = self.pre_net(inputs) if self.model == "GoogLeNet": x = x[0] - x = F.softmax(x) + x = F.softmax(x) if not self.multilabel else F.sigmoid(x) return x @@ -65,7 +67,7 @@ def main(): args = parse_args() net = architectures.__dict__[args.model] - model = Net(net, args.class_dim, args.model) + model = Net(net, args.class_dim, args.model, args.multilabel) load_dygraph_pretrain( model.pre_net, path=args.pretrained_model,