From 40cabbde151c738de78af7a5cd3ed475b101c42d Mon Sep 17 00:00:00 2001 From: cnn Date: Thu, 20 May 2021 17:37:28 +0800 Subject: [PATCH] [dev] fix infer bug of s2anet (#3080) * fix infer bug of s2anet --- configs/dota/README.md | 2 +- ppdet/engine/export_utils.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/configs/dota/README.md b/configs/dota/README.md index f11d441e6..371b1e39c 100644 --- a/configs/dota/README.md +++ b/configs/dota/README.md @@ -97,7 +97,7 @@ python3.7 tools/infer.py -c configs/dota/s2anet_1x_dota.yml -o weights=./weights ## 预测部署 -Paddle中`multiclass_nms`算子的输入支持四边形输入,因此部署时可以不不需要依赖旋转框IOU计算算子。 +Paddle中`multiclass_nms`算子的输入支持四边形输入,因此部署时可以不需要依赖旋转框IOU计算算子。 ```bash # 预测 diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 87f6e2499..a317d45e0 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -112,6 +112,11 @@ def _dump_infer_config(config, path, image_shape, model): config['TestReader'], config['TestDataset'], config['metric'], label_arch, image_shape) + if infer_arch == 'S2ANet': + # TODO: move background to num_classes + if infer_cfg['label_list'][0] != 'background': + infer_cfg['label_list'].insert(0, 'background') + yaml.dump(infer_cfg, open(path, 'w')) logger.info("Export inference config file to {}".format(os.path.join(path))) return image_shape -- GitLab