Commit 55d2d6c8 authored by Zhi Tian

add FCOS

Parent b4d54657
---
name: "\U0001F41B Bug Report"
about: Submit a bug report to help us improve Mask R-CNN Benchmark
---
## 🐛 Bug
<!-- A clear and concise description of what the bug is. -->
## To Reproduce
Steps to reproduce the behavior:
1.
1.
1.
<!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
## Expected behavior
<!-- A clear and concise description of what you expected to happen. -->
## Environment
Please copy and paste the output from the
[environment collection script from PyTorch](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py)
(or fill out the checklist below manually).
You can get the script and run it with:
```
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
```
- PyTorch Version (e.g., 1.0):
- OS (e.g., Linux):
- How you installed PyTorch (`conda`, `pip`, source):
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information:
## Additional context
<!-- Add any other context about the problem here. -->
---
name: "\U0001F680Feature Request"
about: Submit a proposal/request for a new Mask R-CNN Benchmark feature
---
## 🚀 Feature
<!-- A clear and concise description of the feature proposal -->
## Motivation
<!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
## Pitch
<!-- A clear and concise description of what you want to happen. -->
## Alternatives
<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
## Additional context
<!-- Add any other context or screenshots about the feature request here. -->
---
name: "❓Questions/Help/Support"
about: Do you need support?
---
## ❓ Questions and Help
......@@ -17,13 +17,13 @@
# for that, check that `which conda`, `which pip` and `which python` points to the
# right path. From a clean conda env, this is what you need to do
conda create --name maskrcnn_benchmark
conda activate maskrcnn_benchmark
conda create --name FCOS
conda activate FCOS
# this installs the right pip and dependencies for the fresh python
conda install ipython
# maskrcnn_benchmark and coco api dependencies
# FCOS and coco api dependencies
pip install ninja yacs cython matplotlib tqdm
# follow PyTorch installation in https://pytorch.org/get-started/locally/
......@@ -40,8 +40,8 @@ python setup.py build_ext install
# install PyTorch Detection
cd $INSTALL_DIR
git clone https://github.com/facebookresearch/maskrcnn-benchmark.git
cd maskrcnn-benchmark
git clone https://github.com/tianzhi0549/FCOS.git
cd FCOS
# the following will install the lib with
# symbolic links, so that you can modify
......@@ -57,6 +57,7 @@ unset INSTALL_DIR
```
### Option 2: Docker Image (Requires CUDA, Linux only)
*The following steps are for the original maskrcnn-benchmark. Please change the repository name if needed.*
Build image with defaults (`CUDA=9.0`, `CUDNN=7`, `FORCE_CUDA=1`):
......
MIT License
FCOS for non-commercial purposes
Copyright (c) 2018 Facebook
Copyright (c) 2019 the authors
All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Faster R-CNN and Mask R-CNN in PyTorch 1.0
This project aims at providing the necessary building blocks for easily
creating detection and segmentation models using PyTorch 1.0.
![alt text](demo/demo_e2e_mask_rcnn_X_101_32x8d_FPN_1x.png "from http://cocodataset.org/#explore?id=345434")
## Highlights
- **PyTorch 1.0:** RPN, Faster R-CNN and Mask R-CNN implementations that match or exceed Detectron accuracies
- **Very fast**: up to **2x** faster than [Detectron](https://github.com/facebookresearch/Detectron) and **30%** faster than [mmdetection](https://github.com/open-mmlab/mmdetection) during training. See [MODEL_ZOO.md](MODEL_ZOO.md) for more details.
- **Memory efficient:** uses roughly 500MB less GPU memory than mmdetection during training
- **Multi-GPU training and inference**
- **Batched inference:** can perform inference using multiple images per batch per GPU
- **CPU support for inference:** runs on the CPU at inference time. See our [webcam demo](demo) for an example
- Provides pre-trained models for almost all reference Mask R-CNN and Faster R-CNN configurations with 1x schedule.
## Webcam and Jupyter notebook demo
We provide a simple webcam demo that illustrates how you can use `maskrcnn_benchmark` for inference:
```bash
cd demo
# by default, it runs on the GPU
# for best results, use min-image-size 800
python webcam.py --min-image-size 800
# can also run it on the CPU
python webcam.py --min-image-size 300 MODEL.DEVICE cpu
# or change the model that you want to use
python webcam.py --config-file ../configs/caffe2/e2e_mask_rcnn_R_101_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu
# in order to see the probability heatmaps, pass --show-mask-heatmaps
python webcam.py --min-image-size 300 --show-mask-heatmaps MODEL.DEVICE cpu
# for the keypoint demo
python webcam.py --config-file ../configs/caffe2/e2e_keypoint_rcnn_R_50_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu
```
A notebook with the demo can be found in [demo/Mask_R-CNN_demo.ipynb](demo/Mask_R-CNN_demo.ipynb).
## Installation
Check [INSTALL.md](INSTALL.md) for installation instructions.
## Model Zoo and Baselines
Pre-trained models, baselines and comparison with Detectron and mmdetection
can be found in [MODEL_ZOO.md](MODEL_ZOO.md)
## Inference in a few lines
We provide a helper class to simplify writing inference pipelines using pre-trained models.
Here is how we would do it. Run this from the `demo` folder:
```python
from maskrcnn_benchmark.config import cfg
from predictor import COCODemo
config_file = "../configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml"
# update the config options with the config file
cfg.merge_from_file(config_file)
# manual override some options
cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
coco_demo = COCODemo(
    cfg,
    min_image_size=800,
    confidence_threshold=0.7,
)
# load image and then run prediction
image = ...
predictions = coco_demo.run_on_opencv_image(image)
```
## Perform training on COCO dataset
For the following examples to work, you need to first install `maskrcnn_benchmark`.
You will also need to download the COCO dataset.
We recommend symlinking the path to the coco dataset to `datasets/` as follows.
We use `minival` and `valminusminival` sets from [Detectron](https://github.com/facebookresearch/Detectron/blob/master/detectron/datasets/data/README.md#coco-minival-annotations)
```bash
# symlink the coco dataset
cd ~/github/maskrcnn-benchmark
mkdir -p datasets/coco
ln -s /path_to_coco_dataset/annotations datasets/coco/annotations
ln -s /path_to_coco_dataset/train2014 datasets/coco/train2014
ln -s /path_to_coco_dataset/test2014 datasets/coco/test2014
ln -s /path_to_coco_dataset/val2014 datasets/coco/val2014
# or use COCO 2017 version
ln -s /path_to_coco_dataset/annotations datasets/coco/annotations
ln -s /path_to_coco_dataset/train2017 datasets/coco/train2017
ln -s /path_to_coco_dataset/test2017 datasets/coco/test2017
ln -s /path_to_coco_dataset/val2017 datasets/coco/val2017
# for pascal voc dataset:
ln -s /path_to_VOCdevkit_dir datasets/voc
```
P.S. `COCO_2017_train` = `COCO_2014_train` + `valminusminival`, `COCO_2017_val` = `minival`
You can also configure your own paths to the datasets.
For that, all you need to do is to modify `maskrcnn_benchmark/config/paths_catalog.py` to
point to the location where your dataset is stored.
You can also create a new `paths_catalog.py` file which implements the same two classes,
and pass it as a config argument `PATHS_CATALOG` during training.
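As a rough sketch of what such a file can look like (the dataset name and paths below are placeholders, and the `get()` return format mirrors what the bundled `paths_catalog.py` does for COCO-style datasets):

```python
# my_paths_catalog.py -- hypothetical replacement for paths_catalog.py
import os

class DatasetCatalog(object):
    DATA_DIR = "/data"  # placeholder root
    DATASETS = {
        "my_coco_style_train": {
            "root": "my_dataset/images",
            "ann_file": "my_dataset/annotations/train.json",
        },
    }

    @staticmethod
    def get(name):
        if name in DatasetCatalog.DATASETS:
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                root=os.path.join(DatasetCatalog.DATA_DIR, attrs["root"]),
                ann_file=os.path.join(DatasetCatalog.DATA_DIR, attrs["ann_file"]),
            )
            return dict(factory="COCODataset", args=args)
        raise RuntimeError("Dataset not available: {}".format(name))

class ModelCatalog(object):
    # only needed if you load catalog:// weights; see the bundled file
    pass
```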
### Single GPU training
Most of the configuration files that we provide assume that we are running on 8 GPUs.
In order to be able to run it on fewer GPUs, there are a few possibilities:
**1. Run the following without modifications**
```bash
python /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "/path/to/config/file.yaml"
```
This should work out of the box and is very similar to what we should do for multi-GPU training.
But the drawback is that it will use much more GPU memory. The reason is that we set in the
configuration files a global batch size that is divided over the number of GPUs. So if we only
have a single GPU, this means that the batch size for that GPU will be 8x larger, which might lead
to out-of-memory errors.
If you have a lot of memory available, this is the easiest solution.
**2. Modify the cfg parameters**
If you experience out-of-memory errors, you can reduce the global batch size. But this means that
you'll also need to change the learning rate, the number of iterations and the learning rate schedule.
Here is an example for Mask R-CNN R-50 FPN with the 1x schedule:
```bash
python tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025 SOLVER.MAX_ITER 720000 SOLVER.STEPS "(480000, 640000)" TEST.IMS_PER_BATCH 1
```
This follows the [scheduling rules from Detectron.](https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14-L30)
Note that we have multiplied the number of iterations by 8x (as well as the learning rate schedules),
and we have divided the learning rate by 8x.
We also changed the batch size during testing, but that is generally not necessary because testing
requires much less memory than training.
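As a sanity check on these numbers, here is the linear scaling rule applied to the 8-GPU defaults that the command above implies (`IMS_PER_BATCH 16`, `BASE_LR 0.02`, `MAX_ITER 90000`, `STEPS (60000, 80000)`), as a small sketch:

```python
def scale_schedule(base_lr, max_iter, steps, old_batch=16, new_batch=2):
    # Shrink the LR and stretch the schedule by the same factor
    # the global batch size was shrunk by.
    factor = old_batch // new_batch                    # 16 / 2 = 8
    return (base_lr / factor,                          # 0.02  -> 0.0025
            max_iter * factor,                         # 90000 -> 720000
            tuple(s * factor for s in steps))          # (60000, 80000) -> (480000, 640000)

print(scale_schedule(0.02, 90000, (60000, 80000)))
```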
### Multi-GPU training
We use internally `torch.distributed.launch` in order to launch
multi-gpu training. This utility function from PyTorch spawns as many
Python processes as the number of GPUs we want to use, and each Python
process will only use a single GPU.
```bash
export NGPUS=8
python -m torch.distributed.launch --nproc_per_node=$NGPUS /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "path/to/config/file.yaml"
```
## Abstractions
For more information on some of the main abstractions in our implementation, see [ABSTRACTIONS.md](ABSTRACTIONS.md).
## Adding your own dataset
This implementation adds support for COCO-style datasets.
But adding support for training on a new dataset can be done as follows:
```python
from maskrcnn_benchmark.structures.bounding_box import BoxList
class MyDataset(object):
    def __init__(self, ...):
        # as you would do normally

    def __getitem__(self, idx):
        # load the image as a PIL Image
        image = ...

        # load the bounding boxes as a list of list of boxes
        # in this case, for illustrative purposes, we use
        # x1, y1, x2, y2 order.
        boxes = [[0, 0, 10, 10], [10, 20, 50, 50]]
        # and labels
        labels = torch.tensor([10, 20])

        # create a BoxList from the boxes
        boxlist = BoxList(boxes, image.size, mode="xyxy")
        # add the labels to the boxlist
        boxlist.add_field("labels", labels)

        if self.transforms:
            image, boxlist = self.transforms(image, boxlist)

        # return the image, the boxlist and the idx in your dataset
        return image, boxlist, idx

    def get_img_info(self, idx):
        # get img_height and img_width. This is used if
        # we want to split the batches according to the aspect ratio
        # of the image, as it can be more efficient than loading the
        # image from disk
        return {"height": img_height, "width": img_width}
```
That's it. You can also add extra fields to the boxlist, such as segmentation masks
(using `structures.segmentation_mask.SegmentationMask`), or even your own instance type.
For a full example of how the `COCODataset` is implemented, check [`maskrcnn_benchmark/data/datasets/coco.py`](maskrcnn_benchmark/data/datasets/coco.py).
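For instance, a hedged sketch of attaching masks as an extra field (the polygon values are made up, and the constructor call mirrors how `COCODataset` builds its masks field):

```python
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask

# one list of polygons per box; coordinates here are made up
polygons = [[[0, 0, 10, 0, 10, 10]], [[10, 20, 50, 20, 50, 50]]]
masks = SegmentationMask(polygons, image.size)
boxlist.add_field("masks", masks)
```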
Once you have created your dataset, it needs to be added in a couple of places:
- [`maskrcnn_benchmark/data/datasets/__init__.py`](maskrcnn_benchmark/data/datasets/__init__.py): add it to `__all__`
- [`maskrcnn_benchmark/config/paths_catalog.py`](maskrcnn_benchmark/config/paths_catalog.py): `DatasetCatalog.DATASETS` and corresponding `if` clause in `DatasetCatalog.get()`
### Testing
While the aforementioned example should work for training, we leverage the
cocoApi for computing the accuracies during testing. Thus, test datasets
should follow the cocoApi format for now.
To enable your dataset for testing, add a corresponding if statement in [`maskrcnn_benchmark/data/datasets/evaluation/__init__.py`](maskrcnn_benchmark/data/datasets/evaluation/__init__.py):
```python
if isinstance(dataset, datasets.MyDataset):
    return coco_evaluation(**args)
```
## Finetuning from Detectron weights on custom datasets
Create a script `tools/trim_detectron_model.py` like [here](https://gist.github.com/wangg12/aea194aa6ab6a4de088f14ee193fd968).
You can decide which keys are removed and which are kept by modifying the script.
Then you can simply point the config file to the converted model by changing `MODEL.WEIGHT`.
For further information, please refer to [#15](https://github.com/facebookresearch/maskrcnn-benchmark/issues/15).
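A minimal sketch of such a script (modeled on the linked gist; the key prefixes to drop are illustrative and depend on which heads you are replacing):

```python
# tools/trim_detectron_model.py (sketch)
import torch

def trim(src, dst, drop_prefixes=("roi_heads.box.predictor",)):
    checkpoint = torch.load(src, map_location="cpu")
    weights = checkpoint.get("model", checkpoint)
    # keep everything except the class-specific heads you plan to retrain
    kept = {k: v for k, v in weights.items()
            if not k.startswith(drop_prefixes)}
    torch.save({"model": kept}, dst)

trim("detectron_model.pth", "trimmed_model.pth")
```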
## Troubleshooting
If you have issues running or compiling this code, we have compiled a list of common issues in
[TROUBLESHOOTING.md](TROUBLESHOOTING.md). If your issue is not present there, please feel
free to open a new issue.
## Citations
Please consider citing this project in your publications if it helps your research. The following is a BibTeX reference. The BibTeX entry requires the `url` LaTeX package.
```
@misc{massa2018mrcnn,
author = {Massa, Francisco and Girshick, Ross},
title = {{maskrcnn-benchmark: Fast, modular reference implementation of Instance Segmentation and Object Detection algorithms in PyTorch}},
year = {2018},
howpublished = {\url{https://github.com/facebookresearch/maskrcnn-benchmark}},
note = {Accessed: [Insert date here]}
}
```
## Projects using maskrcnn-benchmark
- [RetinaMask: Learning to predict masks improves state-of-the-art single-shot detection for free](https://arxiv.org/abs/1901.03353).
Cheng-Yang Fu, Mykhailo Shvets, and Alexander C. Berg.
Tech report, arXiv:1901.03353.
## License
maskrcnn-benchmark is released under the MIT license. See [LICENSE](LICENSE) for additional details.
# Faster R-CNN and Mask R-CNN in PyTorch 1.0
# FCOS: Fully Convolutional One-Stage Object Detection
This project aims at providing the necessary building blocks for easily
creating detection and segmentation models using PyTorch 1.0.
The codes are used for implementing FCOS for object detection, described in:
![alt text](demo/demo_e2e_mask_rcnn_X_101_32x8d_FPN_1x.png "from http://cocodataset.org/#explore?id=345434")
FCOS: Fully Convolutional One-Stage Object Detection,
Tian, Zhi, Chunhua Shen, Hao Chen, and Tong He,
arXiv preprint arXiv:1904.01355 (2019).
The full paper is available at: [https://arxiv.org/abs/1904.01355](https://arxiv.org/abs/1904.01355).
## Highlights
- **PyTorch 1.0:** RPN, Faster R-CNN and Mask R-CNN implementations that match or exceed Detectron accuracies
- **Very fast**: up to **2x** faster than [Detectron](https://github.com/facebookresearch/Detectron) and **30%** faster than [mmdetection](https://github.com/open-mmlab/mmdetection) during training. See [MODEL_ZOO.md](MODEL_ZOO.md) for more details.
- **Memory efficient:** uses roughly 500MB less GPU memory than mmdetection during training
- **Multi-GPU training and inference**
- **Batched inference:** can perform inference using multiple images per batch per GPU
- **CPU support for inference:** runs on the CPU at inference time. See our [webcam demo](demo) for an example
- Provides pre-trained models for almost all reference Mask R-CNN and Faster R-CNN configurations with 1x schedule.
## Webcam and Jupyter notebook demo
We provide a simple webcam demo that illustrates how you can use `maskrcnn_benchmark` for inference:
```bash
cd demo
# by default, it runs on the GPU
# for best results, use min-image-size 800
python webcam.py --min-image-size 800
# can also run it on the CPU
python webcam.py --min-image-size 300 MODEL.DEVICE cpu
# or change the model that you want to use
python webcam.py --config-file ../configs/caffe2/e2e_mask_rcnn_R_101_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu
# in order to see the probability heatmaps, pass --show-mask-heatmaps
python webcam.py --min-image-size 300 --show-mask-heatmaps MODEL.DEVICE cpu
# for the keypoint demo
python webcam.py --config-file ../configs/caffe2/e2e_keypoint_rcnn_R_50_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu
```
- **Totally anchor-free:** FCOS completely avoids the complicated computation related to anchor boxes and all hyper-parameters of anchor boxes.
- **Memory-efficient:** FCOS uses 2x less training memory footprint than its anchor-based counterpart RetinaNet.
- **Better performance:** Compared to RetinaNet, FCOS has better performance under exactly the same training and testing settings.
- **State-of-the-art performance:** Without bells and whistles, FCOS achieves state-of-the-art performances.
It achieves **41.0%** (ResNet-101-FPN) and **42.1%** (ResNeXt-32x8d-101) AP on COCO test-dev.
- **Faster:** FCOS enjoys faster training and inference speed than RetinaNet.
A notebook with the demo can be found in [demo/Mask_R-CNN_demo.ipynb](demo/Mask_R-CNN_demo.ipynb).
## Required hardware
We use 8 Nvidia V100 GPUs. \
However, since FCOS is memory-efficient, 4 1080Ti GPUs are also enough to train a fully fledged ResNet-50-FPN based FCOS.
## Installation
Check [INSTALL.md](INSTALL.md) for installation instructions.
## Model Zoo and Baselines
Pre-trained models, baselines and comparison with Detectron and mmdetection
can be found in [MODEL_ZOO.md](MODEL_ZOO.md)
## Inference in a few lines
We provide a helper class to simplify writing inference pipelines using pre-trained models.
Here is how we would do it. Run this from the `demo` folder:
```python
from maskrcnn_benchmark.config import cfg
from predictor import COCODemo
This FCOS implementation is based on [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark), so its installation is the same as for the original maskrcnn-benchmark.
config_file = "../configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml"
Please check [INSTALL.md](INSTALL.md) for installation instructions.
You may also want to see the original [README.md](MASKRCNN_README.md) of maskrcnn-benchmark.
# update the config options with the config file
cfg.merge_from_file(config_file)
# manual override some options
cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
## Inference
The inference command line on coco minival split:
coco_demo = COCODemo(
    cfg,
    min_image_size=800,
    confidence_threshold=0.7,
)
# load image and then run prediction
image = ...
predictions = coco_demo.run_on_opencv_image(image)
```
## Perform training on COCO dataset
For the following examples to work, you need to first install `maskrcnn_benchmark`.
python tools/test_net.py \
    --config-file configs/fcos/fcos_R_50_FPN_1x.yaml \
    MODEL.WEIGHT models/FCOS_R_50_FPN_1x.pth \
    TEST.IMS_PER_BATCH 4
You will also need to download the COCO dataset.
We recommend symlinking the path to the coco dataset to `datasets/` as follows.
Please note that:
1) If your model has a different name, please replace `models/FCOS_R_50_FPN_1x.pth` with it.
2) If you encounter an out-of-memory error, please try to reduce `TEST.IMS_PER_BATCH` to 1.
3) If you want to evaluate another model, please change `--config-file` to its config file (in [configs/fcos](configs/fcos)) and `MODEL.WEIGHT` to its weights file.
We use `minival` and `valminusminival` sets from [Detectron](https://github.com/facebookresearch/Detectron/blob/master/detectron/datasets/data/README.md#coco-minival-annotations)
```bash
# symlink the coco dataset
cd ~/github/maskrcnn-benchmark
mkdir -p datasets/coco
ln -s /path_to_coco_dataset/annotations datasets/coco/annotations
ln -s /path_to_coco_dataset/train2014 datasets/coco/train2014
ln -s /path_to_coco_dataset/test2014 datasets/coco/test2014
ln -s /path_to_coco_dataset/val2014 datasets/coco/val2014
# or use COCO 2017 version
ln -s /path_to_coco_dataset/annotations datasets/coco/annotations
ln -s /path_to_coco_dataset/train2017 datasets/coco/train2017
ln -s /path_to_coco_dataset/test2017 datasets/coco/test2017
ln -s /path_to_coco_dataset/val2017 datasets/coco/val2017
# for pascal voc dataset:
ln -s /path_to_VOCdevkit_dir datasets/voc
```
For your convenience, we provide the following trained models (more models are coming soon).
P.S. `COCO_2017_train` = `COCO_2014_train` + `valminusminival`, `COCO_2017_val` = `minival`
Model | Total training mem (GB) | Multi-scale training | Testing time / im | AP (minival) | AP (test-dev) | Link
--- |:---:|:---:|:---:|:---:|:--:|:---:
FCOS_R_50_FPN_1x | 29.3 | No | 71ms | 36.6 | 37.0 | [download](https://cloudstor.aarnet.edu.au/plus/s/dDeDPBLEAt19Xrl/download)
FCOS_R_101_FPN_2x | 44.1 | Yes | 74ms | 40.9 | 41.0 | [download](https://cloudstor.aarnet.edu.au/plus/s/vjL3L0AW7vnhRTo/download)
FCOS_X_101_32x8d_FPN_2x | 72.9 | Yes | 122ms | 42.0 | 42.1 | [download](https://cloudstor.aarnet.edu.au/plus/s/U5myBfGF7MviZ97/download)
You can also configure your own paths to the datasets.
For that, all you need to do is to modify `maskrcnn_benchmark/config/paths_catalog.py` to
point to the location where your dataset is stored.
You can also create a new `paths_catalog.py` file which implements the same two classes,
and pass it as a config argument `PATHS_CATALOG` during training.
[1] *1x means the model is trained for 90K iterations.* \
[2] *2x means the model is trained for 180K iterations.* \
[3] *We report total training memory footprint on all GPUs instead of the memory footprint per GPU as in maskrcnn-benchmark*.
### Single GPU training
## Training
Most of the configuration files that we provide assume that we are running on 8 GPUs.
In order to be able to run it on fewer GPUs, there are a few possibilities:
The following command line will train FCOS_R_50_FPN_1x on 8 GPUs with Synchronous Stochastic Gradient Descent (SGD):
**1. Run the following without modifications**
python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --master_port=$((RANDOM + 10000)) \
    tools/train_net.py \
    --skip-test \
    --config-file configs/fcos/fcos_R_50_FPN_1x.yaml \
    DATALOADER.NUM_WORKERS 2 \
    OUTPUT_DIR training_dir/fcos_R_50_FPN_1x
Note that:
1) If you want to use fewer GPUs, please reduce `--nproc_per_node`. The total batch size does not depend on `nproc_per_node`. If you want to change the total batch size, please change `SOLVER.IMS_PER_BATCH` in [configs/fcos/fcos_R_50_FPN_1x.yaml](configs/fcos/fcos_R_50_FPN_1x.yaml).
2) The models will be saved into `OUTPUT_DIR`.
3) If you want to train FCOS with other backbones, please change `--config-file`.
4) Sometimes you may encounter a deadlock with 100% GPU usage, which might be an NCCL problem. Please try `export NCCL_P2P_DISABLE=1` before running the training command.
```bash
python /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "/path/to/config/file.yaml"
```
This should work out of the box and is very similar to what we should do for multi-GPU training.
But the drawback is that it will use much more GPU memory. The reason is that we set in the
configuration files a global batch size that is divided over the number of GPUs. So if we only
have a single GPU, this means that the batch size for that GPU will be 8x larger, which might lead
to out-of-memory errors.
If you have a lot of memory available, this is the easiest solution.
**2. Modify the cfg parameters**
If you experience out-of-memory errors, you can reduce the global batch size. But this means that
you'll also need to change the learning rate, the number of iterations and the learning rate schedule.
Here is an example for Mask R-CNN R-50 FPN with the 1x schedule:
```bash
python tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025 SOLVER.MAX_ITER 720000 SOLVER.STEPS "(480000, 640000)" TEST.IMS_PER_BATCH 1
```
This follows the [scheduling rules from Detectron.](https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14-L30)
Note that we have multiplied the number of iterations by 8x (as well as the learning rate schedules),
and we have divided the learning rate by 8x.
We also changed the batch size during testing, but that is generally not necessary because testing
requires much less memory than training.
### Multi-GPU training
We use internally `torch.distributed.launch` in order to launch
multi-gpu training. This utility function from PyTorch spawns as many
Python processes as the number of GPUs we want to use, and each Python
process will only use a single GPU.
```bash
export NGPUS=8
python -m torch.distributed.launch --nproc_per_node=$NGPUS /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "path/to/config/file.yaml"
```
## Contributing to the project
## Abstractions
For more information on some of the main abstractions in our implementation, see [ABSTRACTIONS.md](ABSTRACTIONS.md).
## Adding your own dataset
This implementation adds support for COCO-style datasets.
But adding support for training on a new dataset can be done as follows:
```python
from maskrcnn_benchmark.structures.bounding_box import BoxList
class MyDataset(object):
    def __init__(self, ...):
        # as you would do normally

    def __getitem__(self, idx):
        # load the image as a PIL Image
        image = ...

        # load the bounding boxes as a list of list of boxes
        # in this case, for illustrative purposes, we use
        # x1, y1, x2, y2 order.
        boxes = [[0, 0, 10, 10], [10, 20, 50, 50]]
        # and labels
        labels = torch.tensor([10, 20])

        # create a BoxList from the boxes
        boxlist = BoxList(boxes, image.size, mode="xyxy")
        # add the labels to the boxlist
        boxlist.add_field("labels", labels)

        if self.transforms:
            image, boxlist = self.transforms(image, boxlist)

        # return the image, the boxlist and the idx in your dataset
        return image, boxlist, idx

    def get_img_info(self, idx):
        # get img_height and img_width. This is used if
        # we want to split the batches according to the aspect ratio
        # of the image, as it can be more efficient than loading the
        # image from disk
        return {"height": img_height, "width": img_width}
```
That's it. You can also add extra fields to the boxlist, such as segmentation masks
(using `structures.segmentation_mask.SegmentationMask`), or even your own instance type.
For a full example of how the `COCODataset` is implemented, check [`maskrcnn_benchmark/data/datasets/coco.py`](maskrcnn_benchmark/data/datasets/coco.py).
Once you have created your dataset, it needs to be added in a couple of places:
- [`maskrcnn_benchmark/data/datasets/__init__.py`](maskrcnn_benchmark/data/datasets/__init__.py): add it to `__all__`
- [`maskrcnn_benchmark/config/paths_catalog.py`](maskrcnn_benchmark/config/paths_catalog.py): `DatasetCatalog.DATASETS` and corresponding `if` clause in `DatasetCatalog.get()`
### Testing
While the aforementioned example should work for training, we leverage the
cocoApi for computing the accuracies during testing. Thus, test datasets
should follow the cocoApi format for now.
To enable your dataset for testing, add a corresponding if statement in [`maskrcnn_benchmark/data/datasets/evaluation/__init__.py`](maskrcnn_benchmark/data/datasets/evaluation/__init__.py):
```python
if isinstance(dataset, datasets.MyDataset):
    return coco_evaluation(**args)
```
## Finetuning from Detectron weights on custom datasets
Create a script `tools/trim_detectron_model.py` like [here](https://gist.github.com/wangg12/aea194aa6ab6a4de088f14ee193fd968).
You can decide which keys are removed and which are kept by modifying the script.
Then you can simply point the config file to the converted model by changing `MODEL.WEIGHT`.
For further information, please refer to [#15](https://github.com/facebookresearch/maskrcnn-benchmark/issues/15).
## Troubleshooting
If you have issues running or compiling this code, we have compiled a list of common issues in
[TROUBLESHOOTING.md](TROUBLESHOOTING.md). If your issue is not present there, please feel
free to open a new issue.
Any pull requests or issues are welcome.
## Citations
Please consider citing this project in your publications if it helps your research. The following is a BibTeX reference. The BibTeX entry requires the `url` LaTeX package.
Please consider citing our paper in your publications if the project helps your research. The following is a BibTeX reference.
```
@misc{massa2018mrcnn,
author = {Massa, Francisco and Girshick, Ross},
title = {{maskrcnn-benchmark: Fast, modular reference implementation of Instance Segmentation and Object Detection algorithms in PyTorch}},
year = {2018},
howpublished = {\url{https://github.com/facebookresearch/maskrcnn-benchmark}},
note = {Accessed: [Insert date here]}
@article{tian2019fcos,
title={FCOS: Fully Convolutional One-Stage Object Detection},
author={Tian, Zhi and Shen, Chunhua and Chen, Hao and He, Tong},
journal={arXiv preprint arXiv:1904.01355},
year={2019}
}
```
## Projects using maskrcnn-benchmark
- [RetinaMask: Learning to predict masks improves state-of-the-art single-shot detection for free](https://arxiv.org/abs/1901.03353).
Cheng-Yang Fu, Mykhailo Shvets, and Alexander C. Berg.
Tech report, arXiv:1901.03353.
## License
maskrcnn-benchmark is released under the MIT license. See [LICENSE](LICENSE) for additional details.
For academic use, this project is licensed under the 2-clause BSD License - see the LICENSE file for details. For commercial use, please contact the authors.
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-101"
  RPN_ONLY: True
  FCOS_ON: True
  BACKBONE:
    CONV_BODY: "R-101-FPN-RETINANET"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
  RETINANET:
    USE_C5: False  # FCOS uses P5 instead of C5
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
INPUT:
  MIN_SIZE_RANGE_TRAIN: (640, 800)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.01
  WEIGHT_DECAY: 0.0001
  STEPS: (120000, 160000)
  MAX_ITER: 180000
  IMS_PER_BATCH: 16
  WARMUP_METHOD: "constant"
\ No newline at end of file
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  RPN_ONLY: True
  FCOS_ON: True
  BACKBONE:
    CONV_BODY: "R-50-FPN-RETINANET"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
  RETINANET:
    USE_C5: False  # FCOS uses P5 instead of C5
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.01
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 16
  WARMUP_METHOD: "constant"
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/FAIR/20171220/X-101-32x8d"
  RPN_ONLY: True
  FCOS_ON: True
  BACKBONE:
    CONV_BODY: "R-101-FPN-RETINANET"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
    NUM_GROUPS: 32
    WIDTH_PER_GROUP: 8
  RETINANET:
    USE_C5: False  # FCOS uses P5 instead of C5
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
INPUT:
  MIN_SIZE_RANGE_TRAIN: (640, 800)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.01
  WEIGHT_DECAY: 0.0001
  STEPS: (120000, 160000)
  MAX_ITER: 180000
  IMS_PER_BATCH: 16
  WARMUP_METHOD: "constant"
......@@ -23,6 +23,7 @@ _C = CN()
_C.MODEL = CN()
_C.MODEL.RPN_ONLY = False
_C.MODEL.MASK_ON = False
_C.MODEL.FCOS_ON = True
_C.MODEL.RETINANET_ON = False
_C.MODEL.KEYPOINT_ON = False
_C.MODEL.DEVICE = "cuda"
......@@ -41,6 +42,8 @@ _C.MODEL.WEIGHT = ""
_C.INPUT = CN()
# Size of the smallest side of the image during training
_C.INPUT.MIN_SIZE_TRAIN = (800,) # (800,)
# The range of the smallest side for multi-scale training
_C.INPUT.MIN_SIZE_RANGE_TRAIN = (-1, -1) # -1 means disabled and it will use MIN_SIZE_TRAIN
# Maximum size of the side of the image during training
_C.INPUT.MAX_SIZE_TRAIN = 1333
# Size of the smallest side of the image during testing
......@@ -274,6 +277,24 @@ _C.MODEL.RESNETS.BACKBONE_OUT_CHANNELS = 256 * 4
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
# ---------------------------------------------------------------------------- #
# FCOS Options
# ---------------------------------------------------------------------------- #
_C.MODEL.FCOS = CN()
_C.MODEL.FCOS.NUM_CLASSES = 81 # the number of classes including background
_C.MODEL.FCOS.FPN_STRIDES = [8, 16, 32, 64, 128]
_C.MODEL.FCOS.PRIOR_PROB = 0.01
_C.MODEL.FCOS.INFERENCE_TH = 0.05
_C.MODEL.FCOS.NMS_TH = 0.4
_C.MODEL.FCOS.PRE_NMS_TOP_N = 1000
# Focal loss parameter: alpha
_C.MODEL.FCOS.LOSS_ALPHA = 0.25
# Focal loss parameter: gamma
_C.MODEL.FCOS.LOSS_GAMMA = 2.0
# the number of convolutions used in the cls and bbox tower
_C.MODEL.FCOS.NUM_CONVS = 4
# ---------------------------------------------------------------------------- #
# RetinaNet Options (Follow the Detectron version)
......
......@@ -4,7 +4,15 @@ from . import transforms as T
def build_transforms(cfg, is_train=True):
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        if cfg.INPUT.MIN_SIZE_RANGE_TRAIN[0] == -1:
            min_size = cfg.INPUT.MIN_SIZE_TRAIN
        else:
            assert len(cfg.INPUT.MIN_SIZE_RANGE_TRAIN) == 2, \
                "MIN_SIZE_RANGE_TRAIN must have two elements (lower bound, upper bound)"
            min_size = range(
                cfg.INPUT.MIN_SIZE_RANGE_TRAIN[0],
                cfg.INPUT.MIN_SIZE_RANGE_TRAIN[1] + 1
            )
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
        flip_prob = 0.5  # cfg.INPUT.FLIP_PROB_TRAIN
    else:
......
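For intuition, the resize transform downstream is assumed to draw one value per image from `min_size`, which works unchanged when `min_size` is a `range` object:

```python
import random

min_size = range(640, 801)      # MIN_SIZE_RANGE_TRAIN = (640, 800)
size = random.choice(min_size)  # e.g. 713; a fresh draw for every image
```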
......@@ -13,9 +13,12 @@ from .roi_pool import ROIPool
from .roi_pool import roi_pool
from .smooth_l1_loss import smooth_l1_loss
from .sigmoid_focal_loss import SigmoidFocalLoss
from .iou_loss import IOULoss
from .scale import Scale
__all__ = ["nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool",
"smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate",
"BatchNorm2d", "FrozenBatchNorm2d", "SigmoidFocalLoss"
]
"BatchNorm2d", "FrozenBatchNorm2d", "SigmoidFocalLoss", "IOULoss",
"Scale"]
import torch
from torch import nn


class IOULoss(nn.Module):
    def forward(self, pred, target, weight=None):
        # pred and target are (l, t, r, b) distances from a location
        pred_left = pred[:, 0]
        pred_top = pred[:, 1]
        pred_right = pred[:, 2]
        pred_bottom = pred[:, 3]

        target_left = target[:, 0]
        target_top = target[:, 1]
        target_right = target[:, 2]
        target_bottom = target[:, 3]

        target_area = (target_left + target_right) * \
                      (target_top + target_bottom)
        pred_area = (pred_left + pred_right) * \
                    (pred_top + pred_bottom)

        w_intersect = torch.min(pred_left, target_left) + \
                      torch.min(pred_right, target_right)
        h_intersect = torch.min(pred_bottom, target_bottom) + \
                      torch.min(pred_top, target_top)

        area_intersect = w_intersect * h_intersect
        area_union = target_area + pred_area - area_intersect

        losses = -torch.log((area_intersect + 1.0) / (area_union + 1.0))

        if weight is not None and weight.sum() > 0:
            return (losses * weight).sum() / weight.sum()
        else:
            assert losses.numel() != 0
            return losses.mean()
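A quick sanity check of `IOULoss` under made-up inputs (perfect overlap gives zero loss):

```python
import torch

pred   = torch.tensor([[10., 10., 10., 10.]])   # (l, t, r, b) distances
target = torch.tensor([[10., 10., 10., 10.]])
print(IOULoss()(pred, target))                  # IoU = 1 -> -log(1) = 0
```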
import torch
from torch import nn


class Scale(nn.Module):
    def __init__(self, init_value=1.0):
        super(Scale, self).__init__()
        self.scale = nn.Parameter(torch.FloatTensor([init_value]))

    def forward(self, input):
        return input * self.scale
import math
import torch
import torch.nn.functional as F
from torch import nn

from .inference import make_fcos_postprocessor
from .loss import make_fcos_loss_evaluator

from maskrcnn_benchmark.layers import Scale


class FCOSHead(torch.nn.Module):
    def __init__(self, cfg, in_channels):
        """
        Arguments:
            in_channels (int): number of channels of the input feature
        """
        super(FCOSHead, self).__init__()
        # TODO: Implement the sigmoid version first.
        num_classes = cfg.MODEL.FCOS.NUM_CLASSES - 1

        cls_tower = []
        bbox_tower = []
        for i in range(cfg.MODEL.FCOS.NUM_CONVS):
            cls_tower.append(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            )
            cls_tower.append(nn.GroupNorm(32, in_channels))
            cls_tower.append(nn.ReLU())
            bbox_tower.append(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            )
            bbox_tower.append(nn.GroupNorm(32, in_channels))
            bbox_tower.append(nn.ReLU())

        self.add_module('cls_tower', nn.Sequential(*cls_tower))
        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
        self.cls_logits = nn.Conv2d(
            in_channels, num_classes, kernel_size=3, stride=1,
            padding=1
        )
        self.bbox_pred = nn.Conv2d(
            in_channels, 4, kernel_size=3, stride=1,
            padding=1
        )
        self.centerness = nn.Conv2d(
            in_channels, 1, kernel_size=3, stride=1,
            padding=1
        )

        # initialization
        for modules in [self.cls_tower, self.bbox_tower,
                        self.cls_logits, self.bbox_pred,
                        self.centerness]:
            for l in modules.modules():
                if isinstance(l, nn.Conv2d):
                    torch.nn.init.normal_(l.weight, std=0.01)
                    torch.nn.init.constant_(l.bias, 0)

        # initialize the bias for focal loss
        # (about -4.59 for prior_prob=0.01, so initial scores are ~0.01)
        prior_prob = cfg.MODEL.FCOS.PRIOR_PROB
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        torch.nn.init.constant_(self.cls_logits.bias, bias_value)

        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])

    def forward(self, x):
        logits = []
        bbox_reg = []
        centerness = []
        for l, feature in enumerate(x):
            cls_tower = self.cls_tower(feature)
            logits.append(self.cls_logits(cls_tower))
            centerness.append(self.centerness(cls_tower))
            # a learnable per-level Scale before exp lets the shared bbox
            # head adapt to each FPN level; exp keeps distances positive
            bbox_reg.append(torch.exp(self.scales[l](
                self.bbox_pred(self.bbox_tower(feature))
            )))
        return logits, bbox_reg, centerness


class FCOSModule(torch.nn.Module):
    """
    Module for FCOS computation. Takes feature maps from the backbone and
    computes FCOS outputs and losses. Only tested with FPN for now.
    """

    def __init__(self, cfg, in_channels):
        super(FCOSModule, self).__init__()

        head = FCOSHead(cfg, in_channels)

        box_selector_test = make_fcos_postprocessor(cfg)

        loss_evaluator = make_fcos_loss_evaluator(cfg)
        self.head = head
        self.box_selector_test = box_selector_test
        self.loss_evaluator = loss_evaluator
        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                corresponds to a different feature level
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            boxes (list[BoxList]): the predicted boxes, one BoxList per
                image.
            losses (dict[Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        box_cls, box_regression, centerness = self.head(features)
        locations = self.compute_locations(features)

        if self.training:
            return self._forward_train(
                locations, box_cls,
                box_regression,
                centerness, targets
            )
        else:
            return self._forward_test(
                locations, box_cls, box_regression,
                centerness, images.image_sizes
            )

    def _forward_train(self, locations, box_cls, box_regression, centerness, targets):
        loss_box_cls, loss_box_reg, loss_centerness = self.loss_evaluator(
            locations, box_cls, box_regression, centerness, targets
        )
        losses = {
            "loss_cls": loss_box_cls,
            "loss_reg": loss_box_reg,
            "loss_centerness": loss_centerness
        }
        return None, losses

    def _forward_test(self, locations, box_cls, box_regression, centerness, image_sizes):
        boxes = self.box_selector_test(
            locations, box_cls, box_regression,
            centerness, image_sizes
        )
        return boxes, {}

    def compute_locations(self, features):
        locations = []
        for level, feature in enumerate(features):
            h, w = feature.size()[-2:]
            locations_per_level = self.compute_locations_per_level(
                h, w, self.fpn_strides[level],
                feature.device
            )
            locations.append(locations_per_level)
        return locations

    def compute_locations_per_level(self, h, w, stride, device):
        shifts_x = torch.arange(
            0, w * stride, step=stride,
            dtype=torch.float32, device=device
        )
        shifts_y = torch.arange(
            0, h * stride, step=stride,
            dtype=torch.float32, device=device
        )
        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
        shift_x = shift_x.reshape(-1)
        shift_y = shift_y.reshape(-1)
        locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
        return locations


def build_fcos(cfg, in_channels):
    return FCOSModule(cfg, in_channels)
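To make `compute_locations_per_level` concrete, here is its output for a hypothetical 2x3 feature map with stride 8 (the locations are the centers of the stride-8 cells, in row-major order):

```python
import torch

shifts_x = torch.arange(0, 3 * 8, step=8, dtype=torch.float32)  # [0, 8, 16]
shifts_y = torch.arange(0, 2 * 8, step=8, dtype=torch.float32)  # [0, 8]
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
locations = torch.stack(
    (shift_x.reshape(-1), shift_y.reshape(-1)), dim=1) + 8 // 2
# [[ 4,  4], [12,  4], [20,  4], [ 4, 12], [12, 12], [20, 12]]
```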
import torch

from ..inference import RPNPostProcessor
from ..utils import permute_and_flatten

from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes


class FCOSPostProcessor(torch.nn.Module):
    """
    Performs post-processing on the outputs of the FCOS head.
    This is only used during testing.
    """
    def __init__(
        self,
        pre_nms_thresh,
        pre_nms_top_n,
        nms_thresh,
        fpn_post_nms_top_n,
        min_size,
        num_classes,
    ):
        """
        Arguments:
            pre_nms_thresh (float)
            pre_nms_top_n (int)
            nms_thresh (float)
            fpn_post_nms_top_n (int)
            min_size (int)
            num_classes (int)
        """
        super(FCOSPostProcessor, self).__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes

    def forward_for_single_feature_map(
            self, locations, box_cls,
            box_regression, centerness,
            image_sizes):
        """
        Arguments:
            locations: tensor of size H*W, 2
            box_cls: tensor of size N, C, H, W
            box_regression: tensor of size N, 4, H, W
        """
        N, C, H, W = box_cls.shape

        # put in the same format as locations
        box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1)
        box_cls = box_cls.reshape(N, -1, C).sigmoid()
        box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1)
        box_regression = box_regression.reshape(N, -1, 4)
        centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1)
        centerness = centerness.reshape(N, -1).sigmoid()

        candidate_inds = box_cls > self.pre_nms_thresh
        pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        # multiply the classification scores with centerness scores
        box_cls = box_cls * centerness[:, :, None]

        results = []
        for i in range(N):
            per_box_cls = box_cls[i]
            per_candidate_inds = candidate_inds[i]
            per_box_cls = per_box_cls[per_candidate_inds]

            per_candidate_nonzeros = per_candidate_inds.nonzero()
            per_box_loc = per_candidate_nonzeros[:, 0]
            per_class = per_candidate_nonzeros[:, 1] + 1

            per_box_regression = box_regression[i]
            per_box_regression = per_box_regression[per_box_loc]
            per_locations = locations[per_box_loc]

            per_pre_nms_top_n = pre_nms_top_n[i]

            if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
                per_box_cls, top_k_indices = \
                    per_box_cls.topk(per_pre_nms_top_n, sorted=False)
                per_class = per_class[top_k_indices]
                per_box_regression = per_box_regression[top_k_indices]
                per_locations = per_locations[top_k_indices]

            detections = torch.stack([
                per_locations[:, 0] - per_box_regression[:, 0],
                per_locations[:, 1] - per_box_regression[:, 1],
                per_locations[:, 0] + per_box_regression[:, 2],
                per_locations[:, 1] + per_box_regression[:, 3],
            ], dim=1)

            h, w = image_sizes[i]
            boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy")
            boxlist.add_field("labels", per_class)
            boxlist.add_field("scores", per_box_cls)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results

    def forward(self, locations, box_cls, box_regression, centerness, image_sizes):
        """
        Arguments:
            locations: list[tensor]
            box_cls: list[tensor]
            box_regression: list[tensor]
            image_sizes: list[(h, w)]
        Returns:
            boxlists (list[BoxList]): the post-processed boxes, after
                applying box decoding and NMS
        """
        sampled_boxes = []
        for _, (l, o, b, c) in enumerate(zip(locations, box_cls, box_regression, centerness)):
            sampled_boxes.append(
                self.forward_for_single_feature_map(
                    l, o, b, c, image_sizes
                )
            )

        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
        boxlists = self.select_over_all_levels(boxlists)

        return boxlists

    # TODO very similar to filter_results from PostProcessor
    # but filter_results is per image
    # TODO Yang: solve this issue in the future. No good solution
    # right now.
    def select_over_all_levels(self, boxlists):
        num_images = len(boxlists)
        results = []
        for i in range(num_images):
            scores = boxlists[i].get_field("scores")
            labels = boxlists[i].get_field("labels")
            boxes = boxlists[i].bbox
            boxlist = boxlists[i]
            result = []
            # skip the background
            for j in range(1, self.num_classes):
                inds = (labels == j).nonzero().view(-1)

                scores_j = scores[inds]
                boxes_j = boxes[inds, :].view(-1, 4)
                boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
                boxlist_for_class.add_field("scores", scores_j)
                boxlist_for_class = boxlist_nms(
                    boxlist_for_class, self.nms_thresh,
                    score_field="scores"
                )
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels", torch.full((num_labels,), j,
                                         dtype=torch.int64,
                                         device=scores.device)
                )
                result.append(boxlist_for_class)

            result = cat_boxlist(result)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                image_thresh, _ = torch.kthvalue(
                    cls_scores.cpu(),
                    number_of_detections - self.fpn_post_nms_top_n + 1
                )
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results


def make_fcos_postprocessor(config):
    pre_nms_thresh = config.MODEL.FCOS.INFERENCE_TH
    pre_nms_top_n = config.MODEL.FCOS.PRE_NMS_TOP_N
    nms_thresh = config.MODEL.FCOS.NMS_TH
    fpn_post_nms_top_n = config.TEST.DETECTIONS_PER_IMG

    box_selector = FCOSPostProcessor(
        pre_nms_thresh=pre_nms_thresh,
        pre_nms_top_n=pre_nms_top_n,
        nms_thresh=nms_thresh,
        fpn_post_nms_top_n=fpn_post_nms_top_n,
        min_size=0,
        num_classes=config.MODEL.FCOS.NUM_CLASSES
    )

    return box_selector
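The box decoding inside `forward_for_single_feature_map` is just `(x - l, y - t, x + r, y + b)`; a numeric check with made-up values:

```python
import torch

location = torch.tensor([32., 48.])        # an FPN-level center (x, y)
ltrb = torch.tensor([10., 20., 30., 5.])   # made-up (l, t, r, b) prediction
box = torch.cat([location - ltrb[:2], location + ltrb[2:]])
# -> [22., 28., 62., 53.] as (x1, y1, x2, y2)
```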
"""
This file contains specific functions for computing the losses of FCOS.
"""

import torch
from torch.nn import functional as F
from torch import nn

from ..utils import concat_box_prediction_layers
from maskrcnn_benchmark.layers import IOULoss
from maskrcnn_benchmark.layers import SigmoidFocalLoss
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist


INF = 100000000


class FCOSLossComputation(object):
    """
    This class computes the FCOS losses.
    """

    def __init__(self, cfg):
        self.cls_loss_func = SigmoidFocalLoss(
            cfg.MODEL.FCOS.LOSS_GAMMA,
            cfg.MODEL.FCOS.LOSS_ALPHA
        )
        # we make use of IOU Loss for bounding box regression,
        # but we found that L1 in log scale can yield a similar performance
        self.box_reg_loss_func = IOULoss()
        self.centerness_loss_func = nn.BCEWithLogitsLoss()

    def prepare_targets(self, points, targets):
        object_sizes_of_interest = [
            [-1, 64],
            [64, 128],
            [128, 256],
            [256, 512],
            [512, INF],
        ]
        expanded_object_sizes_of_interest = []
        for l, points_per_level in enumerate(points):
            object_sizes_of_interest_per_level = \
                points_per_level.new_tensor(object_sizes_of_interest[l])
            expanded_object_sizes_of_interest.append(
                object_sizes_of_interest_per_level[None].expand(len(points_per_level), -1)
            )

        expanded_object_sizes_of_interest = torch.cat(expanded_object_sizes_of_interest, dim=0)
        num_points_per_level = [len(points_per_level) for points_per_level in points]
        points_all_level = torch.cat(points, dim=0)
        labels, reg_targets = self.compute_targets_for_locations(
            points_all_level, targets, expanded_object_sizes_of_interest
        )

        for i in range(len(labels)):
            labels[i] = torch.split(labels[i], num_points_per_level, dim=0)
            reg_targets[i] = torch.split(reg_targets[i], num_points_per_level, dim=0)

        labels_level_first = []
        reg_targets_level_first = []
        for level in range(len(points)):
            labels_level_first.append(
                torch.cat([labels_per_im[level] for labels_per_im in labels], dim=0)
            )
            reg_targets_level_first.append(
                torch.cat([reg_targets_per_im[level] for reg_targets_per_im in reg_targets], dim=0)
            )

        return labels_level_first, reg_targets_level_first

    def compute_targets_for_locations(self, locations, targets, object_sizes_of_interest):
        labels = []
        reg_targets = []
        xs, ys = locations[:, 0], locations[:, 1]

        for im_i in range(len(targets)):
            targets_per_im = targets[im_i]
            assert targets_per_im.mode == "xyxy"
            bboxes = targets_per_im.bbox
            labels_per_im = targets_per_im.get_field("labels")
            area = targets_per_im.area()

            l = xs[:, None] - bboxes[:, 0][None]
            t = ys[:, None] - bboxes[:, 1][None]
            r = bboxes[:, 2][None] - xs[:, None]
            b = bboxes[:, 3][None] - ys[:, None]
            reg_targets_per_im = torch.stack([l, t, r, b], dim=2)

            is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0

            max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0]
            # limit the regression range for each location
            is_cared_in_the_level = \
                (max_reg_targets_per_im >= object_sizes_of_interest[:, [0]]) & \
                (max_reg_targets_per_im <= object_sizes_of_interest[:, [1]])

            locations_to_gt_area = area[None].repeat(len(locations), 1)
            locations_to_gt_area[is_in_boxes == 0] = INF
            locations_to_gt_area[is_cared_in_the_level == 0] = INF

            # if there is still more than one object for a location,
            # we choose the one with minimal area
            locations_to_min_area, locations_to_gt_inds = locations_to_gt_area.min(dim=1)

            reg_targets_per_im = reg_targets_per_im[range(len(locations)), locations_to_gt_inds]
            labels_per_im = labels_per_im[locations_to_gt_inds]
            labels_per_im[locations_to_min_area == INF] = 0

            labels.append(labels_per_im)
            reg_targets.append(reg_targets_per_im)

        return labels, reg_targets

    def compute_centerness_targets(self, reg_targets):
        left_right = reg_targets[:, [0, 2]]
        top_bottom = reg_targets[:, [1, 3]]
        centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                     (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness)
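    # Worked example (made-up numbers): a centered location with
    # (l, t, r, b) = (4, 2, 4, 2) gives sqrt((4/4) * (2/2)) = 1.0, while a
    # near-corner location with (1, 1, 9, 9) gives sqrt((1/9) * (1/9)) = 1/9,
    # down-weighting its box in the IOU loss below and, via the centerness
    # branch, its score at inference time.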
    def __call__(self, locations, box_cls, box_regression, centerness, targets):
        """
        Arguments:
            locations (list[Tensor])
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            centerness (list[Tensor])
            targets (list[BoxList])

        Returns:
            cls_loss (Tensor)
            reg_loss (Tensor)
            centerness_loss (Tensor)
        """
        N = box_cls[0].size(0)
        num_classes = box_cls[0].size(1)
        labels, reg_targets = self.prepare_targets(locations, targets)

        box_cls_flatten = []
        box_regression_flatten = []
        centerness_flatten = []
        labels_flatten = []
        reg_targets_flatten = []
        for l in range(len(labels)):
            box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(-1, num_classes))
            box_regression_flatten.append(box_regression[l].permute(0, 2, 3, 1).reshape(-1, 4))
            labels_flatten.append(labels[l].reshape(-1))
            reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
            centerness_flatten.append(centerness[l].reshape(-1))

        box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
        box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
        centerness_flatten = torch.cat(centerness_flatten, dim=0)
        labels_flatten = torch.cat(labels_flatten, dim=0)
        reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

        pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)
        cls_loss = self.cls_loss_func(
            box_cls_flatten,
            labels_flatten.int()
        ) / (pos_inds.numel() + N)  # add N to avoid dividing by zero

        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]

        if pos_inds.numel() > 0:
            centerness_targets = self.compute_centerness_targets(reg_targets_flatten)
            reg_loss = self.box_reg_loss_func(
                box_regression_flatten,
                reg_targets_flatten,
                centerness_targets
            )
            centerness_loss = self.centerness_loss_func(
                centerness_flatten,
                centerness_targets
            )
        else:
            reg_loss = box_regression_flatten.sum()
            centerness_loss = centerness_flatten.sum()

        return cls_loss, reg_loss, centerness_loss


def make_fcos_loss_evaluator(cfg):
    loss_evaluator = FCOSLossComputation(cfg)
    return loss_evaluator
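A hedged illustration of the per-level filter in `prepare_targets`/`compute_targets_for_locations` (the P3-P7 names assume the FPN levels matching `FPN_STRIDES = [8, 16, 32, 64, 128]`): a location is kept for a level only if its largest regression target falls inside that level's `object_sizes_of_interest` window.

```python
max_ltrb = 100.0  # made-up max(l, t, r, b) for one location/target pair
ranges = [("P3", -1, 64), ("P4", 64, 128), ("P5", 128, 256),
          ("P6", 256, 512), ("P7", 512, float("inf"))]
assigned = [lvl for lvl, lo, hi in ranges if lo <= max_ltrb <= hi]
print(assigned)  # ['P4'] -> this target is regressed from P4 locations only
```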
......@@ -6,6 +6,7 @@ from torch import nn
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.modeling.rpn.retinanet.retinanet import build_retinanet
from maskrcnn_benchmark.modeling.rpn.fcos.fcos import build_fcos
from .loss import make_rpn_loss_evaluator
from .anchor_generator import make_anchor_generator
from .inference import make_rpn_postprocessor
......@@ -201,6 +202,8 @@ def build_rpn(cfg, in_channels):
"""
This gives the gist of it. Not super important because it doesn't change as much
"""
    if cfg.MODEL.FCOS_ON:
        return build_fcos(cfg, in_channels)
    if cfg.MODEL.RETINANET_ON:
        return build_retinanet(cfg, in_channels)
......
......@@ -195,8 +195,7 @@ class PolygonInstance(object):
            polygons = valid_polygons
        elif isinstance(polygons, PolygonInstance):
            polygons = polygons.polygons.copy()
            polygons = [p.clone() for p in polygons.polygons]
        else:
            RuntimeError(
                "Type of argument `polygons` is not allowed:%s" % (type(polygons))
......
......@@ -84,7 +84,7 @@ def main():
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
......