Commit 01f1314f authored by cuicheng01

fix googlenet infer

Parent c492e1b2
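
GoogLeNet in PaddleClas is built with auxiliary classification heads, so its forward pass returns a list of logits tensors rather than a single tensor; the inference and export code changed below therefore unwraps the first (primary) element before post-processing. The following is a minimal illustrative sketch of that pattern, using a hypothetical `ToyAuxHeadModel` rather than the real PaddleClas GoogLeNet:

```python
import paddle
import paddle.nn as nn

class ToyAuxHeadModel(nn.Layer):
    """Hypothetical stand-in for a network with auxiliary heads (GoogLeNet-style)."""

    def __init__(self, feat_dim=1024, class_num=1000):
        super().__init__()
        self.head = nn.Linear(feat_dim, class_num)      # primary classifier
        self.aux_head = nn.Linear(feat_dim, class_num)  # auxiliary classifier

    def forward(self, feat):
        # Returns a list: [primary logits, auxiliary logits]
        return [self.head(feat), self.aux_head(feat)]

model = ToyAuxHeadModel()
out = model(paddle.randn([4, 1024]))
# The guard added in this commit: keep only the primary output for inference.
if isinstance(out, list):
    out = out[0]
print(out.shape)  # [4, 1000]
```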
@@ -555,6 +555,8 @@ class Trainer(object):
        if len(batch_data) >= batch_size or idx == len(image_list) - 1:
            batch_tensor = paddle.to_tensor(batch_data)
            out = self.model(batch_tensor)
            if isinstance(out, list):
                out = out[0]
            result = postprocess_func(out, image_file_list)
            print(result)
            batch_data.clear()
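
The unwrapped tensor is what `postprocess_func` expects; feeding it the raw list would fail inside the softmax/top-k step. As a rough sketch (the real post-processor lives elsewhere in PaddleClas; the function below is a hypothetical stand-in):

```python
import paddle
import paddle.nn.functional as F

def simple_topk_postprocess(out, image_file_list, topk=5):
    """Hypothetical stand-in for postprocess_func: softmax scores plus top-k class ids."""
    probs = F.softmax(out, axis=-1)  # requires a single tensor, not a list
    scores, ids = paddle.topk(probs, k=topk, axis=-1)
    return [
        {
            "file_name": name,
            "class_ids": ids[i].numpy().tolist(),
            "scores": scores[i].numpy().tolist(),
        }
        for i, name in enumerate(image_file_list)
    ]
```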
@@ -38,6 +38,7 @@ class ExportModel(nn.Layer):
    def __init__(self, config):
        super().__init__()
        print(config)
        self.base_model = build_model(config)
        # we should choose a final model to export
@@ -63,6 +64,8 @@ class ExportModel(nn.Layer):
    def forward(self, x):
        x = self.base_model(x)
        if isinstance(x, list):
            x = x[0]
        if self.infer_model_name is not None:
            x = x[self.infer_model_name]
        if self.infer_output_key is not None:
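
Read together, `ExportModel.forward` now normalizes three output shapes: a list (GoogLeNet-style, keep the first element), a dict indexed by a sub-model name, and a dict indexed by an output key; the last branch is cut off in the hunk above, so the sketch below fills it in by analogy and is an assumption, not the PaddleClas code:

```python
def normalize_output(x, infer_model_name=None, infer_output_key=None):
    # Illustrative restatement of the forward() unwrapping chain.
    if isinstance(x, list):
        x = x[0]                    # GoogLeNet-style list: keep the primary logits
    if infer_model_name is not None:
        x = x[infer_model_name]     # pick one sub-model's output from a dict
    if infer_output_key is not None:
        x = x[infer_output_key]     # assumed: pick one named tensor from a dict
    return x
```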
@@ -76,7 +79,6 @@ if __name__ == "__main__":
    args = config.parse_args()
    config = config.get_config(
        args.config, overrides=args.override, show=False)
    log_file = os.path.join(config['Global']['output_dir'],
                            config["Arch"]["name"], "export.log")
    init_logger(name='root', log_file=log_file)
@@ -86,7 +88,6 @@ if __name__ == "__main__":
    assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
    device = paddle.set_device(config["Global"]["device"])
    model = ExportModel(config["Arch"])
    if config["Global"]["pretrained_model"] is not None:
        load_dygraph_pretrain(model.base_model,
                              config["Global"]["pretrained_model"])
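
The hunks above stop before the actual export call. For orientation, a typical dynamic-to-static export of such a wrapper in Paddle 2.x looks roughly like the sketch below; the input shape, save path, and `paddle.jit.to_static` usage are assumptions for illustration, not lines from this commit:

```python
import os
import paddle
from paddle.static import InputSpec

def export_inference_model(model, save_dir):
    model.eval()
    # Trace the dygraph wrapper with a symbolic NCHW input.
    static_model = paddle.jit.to_static(
        model,
        input_spec=[InputSpec(shape=[None, 3, 224, 224], dtype="float32")])
    # Writes inference.pdmodel / inference.pdiparams under save_dir.
    paddle.jit.save(static_model, os.path.join(save_dir, "inference"))
```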