未验证 提交 a518e453 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #1132 from cuicheng01/release/2.1

fix ml export_model
......@@ -37,15 +37,17 @@ def parse_args():
parser.add_argument("--class_dim", type=int, default=1000)
parser.add_argument("--load_static_weights", type=str2bool, default=False)
parser.add_argument("--img_size", type=int, default=224)
parser.add_argument("--multilabel", type=str2bool, default=False)
return parser.parse_args()
class Net(paddle.nn.Layer):
def __init__(self, net, class_dim, model):
def __init__(self, net, class_dim, model, multilabel):
super(Net, self).__init__()
self.pre_net = net(class_dim=class_dim)
self.model = model
self.multilabel = multilabel
def eval(self):
self.training = False
......@@ -57,7 +59,7 @@ class Net(paddle.nn.Layer):
x = self.pre_net(inputs)
if self.model == "GoogLeNet":
x = x[0]
x = F.softmax(x)
x = F.softmax(x) if not self.multilabel else F.sigmoid(x)
return x
......@@ -65,7 +67,7 @@ def main():
args = parse_args()
net = architectures.__dict__[args.model]
model = Net(net, args.class_dim, args.model)
model = Net(net, args.class_dim, args.model, args.multilabel)
load_dygraph_pretrain(
model.pre_net,
path=args.pretrained_model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册