提交 0e63a246 编写于 作者: C cuicheng01

fix infer class_num bugs in release/2.0-beta branch

上级 82e7a90b
......@@ -39,6 +39,7 @@ def parse_args():
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--class_num", type=int, default=1000)
parser.add_argument(
"--load_static_weights",
type=str2bool,
......@@ -122,7 +123,7 @@ def main():
paddle.disable_static(place)
net = architectures.__dict__[args.model]()
net = architectures.__dict__[args.model](class_dim=args.class_num)
load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
image_list = get_image_list(args.image_file)
for idx, filename in enumerate(image_list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册