提交 17c66f42 编写于 作者: G gaotingquan

Fix a bug about inference

上级 36de6ee3
......@@ -28,6 +28,7 @@ import paddle
from paddle.distributed import ParallelEnv
import paddle.nn.functional as F
def parse_args():
def str2bool(v):
return v.lower() in ("true", "t", "1")
......@@ -101,8 +102,11 @@ def main():
place = paddle.CPUPlace()
paddle.disable_static(place)
if "EfficientNet" in args.model:
net = architectures.__dict__[args.model](is_test=True)
else:
net = architectures.__dict__[args.model]()
net = architectures.__dict__[args.model]()
load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
image_list = get_image_list(args.image_file)
for idx, filename in enumerate(image_list):
......
......@@ -61,7 +61,6 @@ def create_model(architecture, classes_num):
Args:
architecture(dict): architecture information,
name(such as ResNet50) is needed
image(variable): model input variable
classes_num(int): num of classes
Returns:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册