diff --git a/tools/infer/infer.py b/tools/infer/infer.py index de31d98dc657f165826ff1da929e18341bd3ab93..e84a98379e538ad5f3fc52961986ceb7a439c7e0 100644 --- a/tools/infer/infer.py +++ b/tools/infer/infer.py @@ -28,6 +28,7 @@ import paddle from paddle.distributed import ParallelEnv import paddle.nn.functional as F + def parse_args(): def str2bool(v): return v.lower() in ("true", "t", "1") @@ -101,8 +102,11 @@ def main(): place = paddle.CPUPlace() paddle.disable_static(place) + if "EfficientNet" in args.model: + net = architectures.__dict__[args.model](is_test=True) + else: + net = architectures.__dict__[args.model]() - net = architectures.__dict__[args.model]() load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights) image_list = get_image_list(args.image_file) for idx, filename in enumerate(image_list): diff --git a/tools/program.py b/tools/program.py index d6f666804fdf46aeb9202f6a0fe609842b40c1e9..dcc9decb03fefe6648f28039a2337d24a77d28c0 100644 --- a/tools/program.py +++ b/tools/program.py @@ -61,7 +61,6 @@ def create_model(architecture, classes_num): Args: architecture(dict): architecture information, name(such as ResNet50) is needed - image(variable): model input variable classes_num(int): num of classes Returns: