diff --git a/tools/infer/infer.py b/tools/infer/infer.py index 12833ac5d599ef2a15ada466d9c8c14351461cb9..37d578eef1cc6c14b7459411d0b398af0db6d584 100644 --- a/tools/infer/infer.py +++ b/tools/infer/infer.py @@ -39,6 +39,7 @@ def parse_args(): parser.add_argument("-m", "--model", type=str) parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("--use_gpu", type=str2bool, default=True) + parser.add_argument("--class_num", type=int, default=1000) parser.add_argument( "--load_static_weights", type=str2bool, @@ -122,7 +123,7 @@ def main(): paddle.disable_static(place) - net = architectures.__dict__[args.model]() + net = architectures.__dict__[args.model](class_dim=args.class_num) load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights) image_list = get_image_list(args.image_file) for idx, filename in enumerate(image_list):