diff --git a/configs/dota/README.md b/configs/dota/README.md index f11d441e61b3621137ac9221c4429492f4ad904d..371b1e39c484e0a27c2ef7e503ee4481c0ca3d8f 100644 --- a/configs/dota/README.md +++ b/configs/dota/README.md @@ -97,7 +97,7 @@ python3.7 tools/infer.py -c configs/dota/s2anet_1x_dota.yml -o weights=./weights ## 预测部署 -Paddle中`multiclass_nms`算子的输入支持四边形输入,因此部署时可以不不需要依赖旋转框IOU计算算子。 +Paddle中`multiclass_nms`算子的输入支持四边形输入,因此部署时可以不需要依赖旋转框IOU计算算子。 ```bash # 预测 diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 87f6e24994ad756c5aae3c843a0e1f0609270332..a317d45e0f58fee999533d19e70c7f2a00c0f139 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -112,6 +112,11 @@ def _dump_infer_config(config, path, image_shape, model): config['TestReader'], config['TestDataset'], config['metric'], label_arch, image_shape) + if infer_arch == 'S2ANet': + # TODO: move background to num_classes + if infer_cfg['label_list'][0] != 'background': + infer_cfg['label_list'].insert(0, 'background') + yaml.dump(infer_cfg, open(path, 'w')) logger.info("Export inference config file to {}".format(os.path.join(path))) return image_shape