From 17c66f4290148d199a9c7412b95803993e2d44ec Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Mon, 19 Oct 2020 23:06:53 +0800 Subject: [PATCH] Fix a bug about inference --- tools/infer/infer.py | 6 +++++- tools/program.py | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tools/infer/infer.py b/tools/infer/infer.py index de31d98d..e84a9837 100644 --- a/tools/infer/infer.py +++ b/tools/infer/infer.py @@ -28,6 +28,7 @@ import paddle from paddle.distributed import ParallelEnv import paddle.nn.functional as F + def parse_args(): def str2bool(v): return v.lower() in ("true", "t", "1") @@ -101,8 +102,11 @@ def main(): place = paddle.CPUPlace() paddle.disable_static(place) + if "EfficientNet" in args.model: + net = architectures.__dict__[args.model](is_test=True) + else: + net = architectures.__dict__[args.model]() - net = architectures.__dict__[args.model]() load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights) image_list = get_image_list(args.image_file) for idx, filename in enumerate(image_list): diff --git a/tools/program.py b/tools/program.py index d6f66680..dcc9decb 100644 --- a/tools/program.py +++ b/tools/program.py @@ -61,7 +61,6 @@ def create_model(architecture, classes_num): Args: architecture(dict): architecture information, name(such as ResNet50) is needed - image(variable): model input variable classes_num(int): num of classes Returns: -- GitLab