Unverified · Commit 4575dfe7 · authored by George Ni · committed by GitHub

[MOT] add MOT data (#2789)

* add mot data

* fix operators, source

* fix data source transform

* fix parse_dataset register_op

* fix scale_factor, RandomAffine

* add assert for check

* fix ci
Parent 385f9bbd
metric: MOTDet
num_classes: 1
TrainDataset:
!MOTDataSet
dataset_dir: dataset/mot
image_lists: ['mot17.train', 'caltech.train', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train']
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
EvalDataset:
!MOTDataSet
dataset_dir: dataset/mot
image_lists: ['citypersons.val', 'caltech.val'] # for detection
# image_lists: ['caltech.10k.val', 'cuhksysu.val', 'prw.val'] # for reid
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
TestDataset:
!ImageFolder
dataset_dir: dataset/mot
EvalMOTDataset:
!ImageFolder
dataset_dir: dataset/mot
  keep_ori_im: False # set True when saving visualization images or video
TestMOTDataset:
!MOTVideoDataset
dataset_dir: dataset/mot
keep_ori_im: False
# MOT Dataset
* **MIXMOT**
In this part we use the same training data as [JDE](https://github.com/Zhongdao/Towards-Realtime-MOT) and [FairMOT](https://github.com/ifzhang/FairMOT), which we call "MIXMOT". Please refer to their [DATA ZOO](https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/DATASET_ZOO.md) to download and prepare all the training data, including Caltech Pedestrian, CityPersons, CUHK-SYSU, PRW, ETHZ, MOT17 and MOT16.
* **2DMOT15 and MOT20**
[2DMOT15](https://motchallenge.net/data/2D_MOT_2015/) and [MOT20](https://motchallenge.net/data/MOT20/) can be downloaded from the official webpage of the MOT Challenge. After downloading, organize the data in the following structure:
```
MOT15
|——————images
| └——————train
| └——————test
└——————labels_with_ids
└——————train
MOT20
|——————images
| └——————train
| └——————test
└——————labels_with_ids
└——————train
```
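After extraction, a quick way to confirm the layout is a small Python check like the sketch below (illustrative only, not part of the official tooling; it assumes the datasets sit under `dataset/mot`):
```
import os

def check_mot_layout(root):
    # Verify the directory structure shown above, e.g. for
    # root='dataset/mot/MOT15' or root='dataset/mot/MOT20'.
    for sub in ('images/train', 'images/test', 'labels_with_ids/train'):
        path = os.path.join(root, sub)
        assert os.path.isdir(path), 'missing directory: {}'.format(path)

check_mot_layout('dataset/mot/MOT15')
check_mot_layout('dataset/mot/MOT20')
```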
Annotations for all of these datasets are provided in a unified format. If you use these datasets, please **follow their licenses**,
and if you use any of them in your research, please cite the original work (BibTeX entries are provided at the bottom).
## Data Format
All the datasets have the following structure:
```
Caltech
|——————images
| └——————00001.jpg
| |—————— ...
| └——————0000N.jpg
└——————labels_with_ids
└——————00001.txt
|—————— ...
└——————0000N.txt
```
Every image has a corresponding annotation text file. Given an image path,
the annotation text path can be generated by replacing the string `images` with `labels_with_ids` and replacing `.jpg` with `.txt`.
In the annotation text, each line describes a bounding box in the following format:
```
[class] [identity] [x_center] [y_center] [width] [height]
```
The field `[class]` should be `0`. Only single-class multi-object tracking is supported in this version.
The field `[identity]` is an integer from `0` to `num_identities - 1`, or `-1` if this box has no identity annotation.
**Note:** the values of `[x_center] [y_center] [width] [height]` are normalized by the image width/height, so they are floating-point numbers ranging from 0 to 1.
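The sketch below (illustrative, not part of the codebase) shows both conventions in code: deriving the annotation path from an image path, and denormalizing a label row into pixel `x1y1x2y2` boxes. `img_w` and `img_h` are assumed to be the image width and height read elsewhere:
```
import numpy as np

def label_path_for(image_path):
    # Replace 'images' with 'labels_with_ids' and the image suffix with '.txt'.
    return image_path.replace('images', 'labels_with_ids') \
                     .replace('.jpg', '.txt').replace('.png', '.txt')

def load_boxes(label_path, img_w, img_h):
    # Each row: [class, identity, x_center, y_center, width, height],
    # where the last four values are normalized to [0, 1].
    labels = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6)
    cls, tid = labels[:, 0].astype(int), labels[:, 1].astype(int)
    cx, cy = labels[:, 2] * img_w, labels[:, 3] * img_h
    w, h = labels[:, 4] * img_w, labels[:, 5] * img_h
    xyxy = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
    return cls, tid, xyxy
```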
## Final Dataset root
```
dataset/mot
|——————image_lists
|——————caltech.10k.val
|——————caltech.train
|——————caltech.val
|——————citypersons.train
|——————citypersons.val
|——————cuhksysu.train
|——————cuhksysu.val
|——————eth.train
|——————mot16.train
|——————mot17.train
|——————prw.train
|——————prw.val
|——————Caltech
|——————Cityscapes
|——————CUHKSYSU
|——————ETHZ
|——————MOT15
|——————MOT16
|——————MOT17
|——————MOT20
|——————PRW
```
## Download
### Caltech Pedestrian
Baidu NetDisk:
[[0]](https://pan.baidu.com/s/1sYBXXvQaXZ8TuNwQxMcAgg)
[[1]](https://pan.baidu.com/s/1lVO7YBzagex1xlzqPksaPw)
[[2]](https://pan.baidu.com/s/1PZXxxy_lrswaqTVg0GuHWg)
[[3]](https://pan.baidu.com/s/1M93NCo_E6naeYPpykmaNgA)
[[4]](https://pan.baidu.com/s/1ZXCdPNXfwbxQ4xCbVu5Dtw)
[[5]](https://pan.baidu.com/s/1kcZkh1tcEiBEJqnDtYuejg)
[[6]](https://pan.baidu.com/s/1sDjhtgdFrzR60KKxSjNb2A)
[[7]](https://pan.baidu.com/s/18Zvp_d33qj1pmutFDUbJyw)
Google Drive: [[annotations]](https://drive.google.com/file/d/1h8vxl_6tgi9QVYoer9XcY9YwNB32TE5k/view?usp=sharing),
please download all the image `.tar` files from [this page](http://www.vision.caltech.edu/Image_Datasets/CaltechPedestrians/datasets/USA/) and extract the images under `Caltech/images`.
You may need [this tool](https://github.com/mitmul/caltech-pedestrian-dataset-converter) to convert the original data format to JPEG images.
Original dataset webpage: [CaltechPedestrians](http://www.vision.caltech.edu/Image_Datasets/CaltechPedestrians/)
### CityPersons
Baidu NetDisk:
[[0]](https://pan.baidu.com/s/1g24doGOdkKqmbgbJf03vsw)
[[1]](https://pan.baidu.com/s/1mqDF9M5MdD3MGxSfe0ENsA)
[[2]](https://pan.baidu.com/s/1Qrbh9lQUaEORCIlfI25wdA)
[[3]](https://pan.baidu.com/s/1lw7shaffBgARDuk8mkkHhw)
Google Drive:
[[0]](https://drive.google.com/file/d/1DgLHqEkQUOj63mCrS_0UGFEM9BG8sIZs/view?usp=sharing)
[[1]](https://drive.google.com/file/d/1BH9Xz59UImIGUdYwUR-cnP1g7Ton_LcZ/view?usp=sharing)
[[2]](https://drive.google.com/file/d/1q_OltirP68YFvRWgYkBHLEFSUayjkKYE/view?usp=sharing)
[[3]](https://drive.google.com/file/d/1VSL0SFoQxPXnIdBamOZJzHrHJ1N2gsTW/view?usp=sharing)
Original dataset webpage: [Citypersons pedestrian detection dataset](https://bitbucket.org/shanshanzhang/citypersons)
### CUHK-SYSU
Baidu NetDisk:
[[0]](https://pan.baidu.com/s/1YFrlyB1WjcQmFW3Vt_sEaQ)
Google Drive:
[[0]](https://drive.google.com/file/d/1D7VL43kIV9uJrdSCYl53j89RE2K-IoQA/view?usp=sharing)
Original dataset webpage: [CUHK-SYSU Person Search Dataset](http://www.ee.cuhk.edu.hk/~xgwang/PS/dataset.html)
### PRW
Baidu NetDisk:
[[0]](https://pan.baidu.com/s/1iqOVKO57dL53OI1KOmWeGQ)
Google Drive:
[[0]](https://drive.google.com/file/d/116_mIdjgB-WJXGe8RYJDWxlFnc_4sqS8/view?usp=sharing)
Original dataset webpage: [Person Search in the Wild dataset](http://www.liangzheng.com.cn/Project/project_prw.html)
### ETHZ (overlapping videos with MOT-16 removed)
Baidu NetDisk:
[[0]](https://pan.baidu.com/s/14EauGb2nLrcB3GRSlQ4K9Q)
Google Drive:
[[0]](https://drive.google.com/file/d/19QyGOCqn8K_rc9TXJ8UwLSxCx17e0GoY/view?usp=sharing)
Original dataset webpage: [ETHZ pedestrian dataset](https://data.vision.ee.ethz.ch/cvl/aess/dataset/)
### MOT-17
Baidu NetDisk:
[[0]](https://pan.baidu.com/s/1lHa6UagcosRBz-_Y308GvQ)
Google Drive:
[[0]](https://drive.google.com/file/d/1ET-6w12yHNo8DKevOVgK1dBlYs739e_3/view?usp=sharing)
Original dataset webpage: [MOT-17](https://motchallenge.net/data/MOT17/)
### MOT-16 (for evaluation)
Baidu NetDisk:
[[0]](https://pan.baidu.com/s/10pUuB32Hro-h-KUZv8duiw)
Google Drive:
[[0]](https://drive.google.com/file/d/1254q3ruzBzgn4LUejDVsCtT05SIEieQg/view?usp=sharing)
Original dataset webpage: [MOT-16](https://motchallenge.net/data/MOT16/)
# Citation
Caltech:
```
@inproceedings{ dollarCVPR09peds,
author = "P. Doll\'ar and C. Wojek and B. Schiele and P. Perona",
title = "Pedestrian Detection: A Benchmark",
booktitle = "CVPR",
month = "June",
year = "2009",
city = "Miami",
}
```
Citypersons:
```
@INPROCEEDINGS{Shanshan2017CVPR,
Author = {Shanshan Zhang and Rodrigo Benenson and Bernt Schiele},
Title = {CityPersons: A Diverse Dataset for Pedestrian Detection},
Booktitle = {CVPR},
Year = {2017}
}
@INPROCEEDINGS{Cordts2016Cityscapes,
title={The Cityscapes Dataset for Semantic Urban Scene Understanding},
author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2016}
}
```
CUHK-SYSU:
```
@inproceedings{xiaoli2017joint,
title={Joint Detection and Identification Feature Learning for Person Search},
author={Xiao, Tong and Li, Shuang and Wang, Bochao and Lin, Liang and Wang, Xiaogang},
booktitle={CVPR},
year={2017}
}
```
PRW:
```
@inproceedings{zheng2017person,
title={Person re-identification in the wild},
author={Zheng, Liang and Zhang, Hengheng and Sun, Shaoyan and Chandraker, Manmohan and Yang, Yi and Tian, Qi},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={1367--1376},
year={2017}
}
```
ETHZ:
```
@InProceedings{eth_biwi_00534,
author = {A. Ess and B. Leibe and K. Schindler and L. van Gool},
title = {A Mobile Vision System for Robust Multi-Person Tracking},
booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR'08)},
year = {2008},
month = {June},
publisher = {IEEE Press},
keywords = {}
}
```
MOT-16&17:
```
@article{milan2016mot16,
title={MOT16: A benchmark for multi-object tracking},
author={Milan, Anton and Leal-Taix{\'e}, Laura and Reid, Ian and Roth, Stefan and Schindler, Konrad},
journal={arXiv preprint arXiv:1603.00831},
year={2016}
}
```
@@ -271,3 +271,39 @@ class TestReader(BaseDataLoader):
super(TestReader, self).__init__(sample_transforms, batch_transforms,
batch_size, shuffle, drop_last,
drop_empty, num_classes, **kwargs)
@register
class EvalMOTReader(BaseDataLoader):
__shared__ = ['num_classes']
def __init__(self,
sample_transforms=[],
batch_transforms=[],
batch_size=1,
shuffle=False,
drop_last=False,
drop_empty=True,
num_classes=1,
**kwargs):
super(EvalMOTReader, self).__init__(sample_transforms, batch_transforms,
batch_size, shuffle, drop_last,
drop_empty, num_classes, **kwargs)
@register
class TestMOTReader(BaseDataLoader):
__shared__ = ['num_classes']
def __init__(self,
sample_transforms=[],
batch_transforms=[],
batch_size=1,
shuffle=False,
drop_last=False,
drop_empty=True,
num_classes=1,
**kwargs):
super(TestMOTReader, self).__init__(sample_transforms, batch_transforms,
batch_size, shuffle, drop_last,
drop_empty, num_classes, **kwargs)
@@ -17,9 +17,11 @@ from . import voc
from . import widerface
from . import category
from . import keypoint_coco
from . import mot
from .coco import *
from .voc import *
from .widerface import *
from .category import *
from .keypoint_coco import *
from .mot import *
@@ -86,10 +86,28 @@ def get_categories(metric_type, anno_file=None, arch=None):
elif metric_type.lower() == 'keypointtopdowncocoeval':
return (None, {'id': 'keypoint'})
elif metric_type.lower() in ['mot', 'motdet', 'reid']:
return _mot_category()
else:
raise ValueError("unknown metric type {}".format(metric_type))
def _mot_category():
"""
Get class id to category id map and category id
to category name map of mot dataset
"""
label_map = {'person': 0}
label_map = sorted(label_map.items(), key=lambda x: x[1])
cats = [l[0] for l in label_map]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
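# Illustrative usage (not part of this commit): the MOT metrics resolve their
# category maps through get_categories, e.g.
#     clsid2catid, catid2name = get_categories('MOTDet')
#     # -> ({0: 0}, {0: 'person'})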
def _coco17_category():
"""
Get class id to category id map and category id
@@ -87,8 +87,13 @@ class DetDataset(Dataset):
return self.transform(roidb)
    def check_or_download_dataset(self):
        if isinstance(self.anno_path, list):
            for path in self.anno_path:
                self.dataset_dir = get_dataset_path(self.dataset_dir, path,
                                                    self.image_dir)
        else:
            self.dataset_dir = get_dataset_path(self.dataset_dir,
                                                self.anno_path, self.image_dir)
def set_kwargs(self, **kwargs):
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
@@ -134,19 +139,18 @@ class ImageFolder(DetDataset):
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
sample_num=-1,
use_default_label=None,
keep_ori_im=False,
**kwargs):
super(ImageFolder, self).__init__(
dataset_dir,
image_dir,
anno_path,
sample_num=sample_num,
use_default_label=use_default_label)
self.keep_ori_im = keep_ori_im
self._imid2path = {}
self.roidbs = None
self.sample_num = sample_num
def check_or_download_dataset(self):
return
@@ -178,6 +182,8 @@ class ImageFolder(DetDataset):
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
if self.keep_ori_im:
rec.update({'keep_ori_im': 1})
self._imid2path[ct] = image
ct += 1
records.append(rec)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
from collections import OrderedDict
from .dataset import DetDataset
from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
@register
@serializable
class MOTDataSet(DetDataset):
"""
Load dataset with MOT format.
Args:
dataset_dir (str): root directory for dataset.
        image_lists (str|list): MOT data image lists for a multi-source MOT dataset.
data_fields (list): key name of data dictionary, at least have 'image'.
sample_num (int): number of samples to load, -1 means all.
label_list (str): if use_default_label is False, will load
mapping between category and class index.
    Notes:
        The MOT dataset root directory is organized as follows:
dataset/mot
|——————image_lists
| |——————caltech.train
| |——————caltech.val
| |——————mot16.train
| |——————mot17.train
| ......
|——————Caltech
|——————MOT17
|——————......
All the MOT datasets have the following structure:
Caltech
|——————images
| └——————00001.jpg
| |—————— ...
| └——————0000N.jpg
└——————labels_with_ids
└——————00001.txt
|—————— ...
└——————0000N.txt
or
MOT17
|——————images
| └——————train
| └——————test
└——————labels_with_ids
└——————train
"""
def __init__(self,
dataset_dir=None,
image_lists=[],
data_fields=['image'],
sample_num=-1,
label_list=None):
super(MOTDataSet, self).__init__(
dataset_dir=dataset_dir,
data_fields=data_fields,
sample_num=sample_num)
self.dataset_dir = dataset_dir
self.image_lists = image_lists
self.label_list = label_list
if isinstance(self.image_lists, str):
self.image_lists = [self.image_lists]
def get_anno(self):
if self.image_lists == []:
return
# only used to get categories and metric
return os.path.join(self.dataset_dir, 'image_lists',
self.image_lists[0])
def parse_dataset(self):
self.img_files = OrderedDict()
self.img_start_index = OrderedDict()
self.label_files = OrderedDict()
self.tid_num = OrderedDict()
self.tid_start_index = OrderedDict()
img_index = 0
for data_name in self.image_lists:
# check every data image list
image_lists_dir = os.path.join(self.dataset_dir, 'image_lists')
assert os.path.isdir(image_lists_dir), \
"The {} is not a directory.".format(image_lists_dir)
list_path = os.path.join(image_lists_dir, data_name)
assert os.path.exists(list_path), \
"The list path {} does not exist.".format(list_path)
# record img_files, filter out empty ones
with open(list_path, 'r') as file:
self.img_files[data_name] = file.readlines()
self.img_files[data_name] = [
os.path.join(self.dataset_dir, x.strip())
for x in self.img_files[data_name]
]
self.img_files[data_name] = list(
filter(lambda x: len(x) > 0, self.img_files[data_name]))
self.img_start_index[data_name] = img_index
img_index += len(self.img_files[data_name])
# check data directory, images and labels_with_ids
if len(self.img_files[data_name]) == 0:
continue
else:
# self.img_files[data_name] each line following this:
# {self.dataset_dir}/MOT17/images/..
first_path = self.img_files[data_name][0]
data_dir = first_path.replace(self.dataset_dir,
'').split('/')[1]
data_dir = os.path.join(self.dataset_dir, data_dir)
assert os.path.exists(data_dir), \
"The data directory {} does not exist.".format(data_dir)
                data_dir_images = os.path.join(data_dir, 'images')
                assert os.path.exists(data_dir_images), \
                    "The data images directory {} does not exist.".format(
                        data_dir_images)
                data_dir_labels_with_ids = os.path.join(data_dir,
                                                        'labels_with_ids')
                assert os.path.exists(data_dir_labels_with_ids), \
                    "The data labels directory {} does not exist.".format(
                        data_dir_labels_with_ids)
# record label_files
self.label_files[data_name] = [
x.replace('images', 'labels_with_ids').replace(
'.png', '.txt').replace('.jpg', '.txt')
for x in self.img_files[data_name]
]
for data_name, label_paths in self.label_files.items():
max_index = -1
for lp in label_paths:
lb = np.loadtxt(lp)
if len(lb) < 1:
continue
if len(lb.shape) < 2:
img_max = lb[1]
else:
img_max = np.max(lb[:, 1])
if img_max > max_index:
max_index = img_max
self.tid_num[data_name] = int(max_index + 1)
last_index = 0
for i, (k, v) in enumerate(self.tid_num.items()):
self.tid_start_index[k] = last_index
last_index += v
self.total_identities = int(last_index + 1)
self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
self.total_imgs = sum(self.num_imgs_each_data)
logger.info('=' * 80)
logger.info('MOT dataset summary: ')
logger.info(self.tid_num)
logger.info('total images: {}'.format(self.total_imgs))
logger.info('image start index: {}'.format(self.img_start_index))
logger.info('total identities: {}'.format(self.total_identities))
logger.info('identity start index: {}'.format(self.tid_start_index))
logger.info('=' * 80)
# mapping category name to class id
# first_class:0, second_class:1, ...
records = []
        if self.label_list:
            label_list_path = os.path.join(self.dataset_dir, self.label_list)
            if not os.path.exists(label_list_path):
                raise ValueError("label_list {} does not exist".format(
                    label_list_path))
            cname2cid = {}  # initialize before filling from the label list
            with open(label_list_path, 'r') as fr:
                label_id = 0
                for line in fr.readlines():
                    cname2cid[line.strip()] = label_id
                    label_id += 1
        else:
            cname2cid = mot_label()
for img_index in range(self.total_imgs):
for i, (k, v) in enumerate(self.img_start_index.items()):
if img_index >= v:
data_name = list(self.label_files.keys())[i]
start_index = v
img_file = self.img_files[data_name][img_index - start_index]
lbl_file = self.label_files[data_name][img_index - start_index]
if not os.path.exists(img_file):
logger.warn('Illegal image file: {}, and it will be ignored'.
format(img_file))
continue
if not os.path.isfile(lbl_file):
logger.warn('Illegal label file: {}, and it will be ignored'.
format(lbl_file))
continue
labels = np.loadtxt(lbl_file, dtype=np.float32).reshape(-1, 6)
# each row in labels (N, 6) is [gt_class, gt_identity, cx, cy, w, h]
cx, cy = labels[:, 2], labels[:, 3]
w, h = labels[:, 4], labels[:, 5]
gt_bbox = np.stack((cx, cy, w, h)).T.astype('float32')
gt_class = labels[:, 0:1].astype('int32')
gt_score = np.ones((len(labels), 1)).astype('float32')
gt_ide = labels[:, 1:2].astype('int32')
mot_rec = {
'im_file': img_file,
'im_id': img_index,
} if 'image' in self.data_fields else {}
gt_rec = {
'gt_class': gt_class,
'gt_score': gt_score,
'gt_bbox': gt_bbox,
'gt_ide': gt_ide,
}
for k, v in gt_rec.items():
if k in self.data_fields:
mot_rec[k] = v
records.append(mot_rec)
if self.sample_num > 0 and img_index >= self.sample_num:
break
assert len(records) > 0, 'not found any mot record in %s' % (
self.image_lists)
self.roidbs, self.cname2cid = records, cname2cid
def mot_label():
labels_map = {'person': 0}
return labels_map
def _is_valid_video(f, extensions=('.mp4', '.avi', '.mov', '.rmvb', '.flv')):
    return f.lower().endswith(extensions)
@register
@serializable
class MOTVideoDataset(DetDataset):
"""
Load MOT dataset with MOT format from video for inference.
Args:
video_file (str): path of the video file
dataset_dir (str): root directory for dataset.
keep_ori_im (bool): whether to keep original image, default False.
Set True when used during MOT model inference while saving
images or video, or used in DeepSORT.
"""
def __init__(self,
video_file='',
dataset_dir=None,
keep_ori_im=False,
**kwargs):
super(MOTVideoDataset, self).__init__(dataset_dir=dataset_dir)
self.video_file = video_file
self.dataset_dir = dataset_dir
self.keep_ori_im = keep_ori_im
self.roidbs = None
    def parse_dataset(self):
if not self.roidbs:
self.roidbs = self._load_video_images()
def _load_video_images(self):
self.cap = cv2.VideoCapture(self.video_file)
self.vn = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
logger.info('Length of the video: {:d} frames'.format(self.vn))
        ct = 0
        records = []
        while True:
            res, img = self.cap.read()
            if not res:
                # stop at the first failed read instead of
                # converting a None frame below
                break
            image = np.ascontiguousarray(img, dtype=np.float32)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
im_shape = image.shape
rec = {
'im_id': np.array([ct]),
'image': image,
'h': im_shape[0],
'w': im_shape[1],
'im_shape': np.array(
im_shape[:2], dtype=np.float32),
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
}
if self.keep_ori_im:
rec.update({'ori_image': image})
ct += 1
records.append(rec)
        assert len(records) > 0, "No frames found in {}".format(self.video_file)
return records
def set_video(self, video_file):
self.video_file = video_file
assert os.path.isfile(self.video_file) and _is_valid_video(self.video_file), \
"wrong or unsupported file format: {}".format(self.video_file)
self.roidbs = self._load_video_images()
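# A minimal usage sketch (illustrative, not part of this commit); the video
# path 'demo.mp4' below is a hypothetical example:
#
#     dataset = MOTVideoDataset(dataset_dir='dataset/mot', keep_ori_im=True)
#     dataset.set_video('demo.mp4')  # decodes every frame into roidbs
#     print(len(dataset.roidbs))     # number of decoded frames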
@@ -15,11 +15,14 @@
from . import operators
from . import batch_operators
from . import keypoint_operators
from . import mot_operators
from .operators import *
from .batch_operators import *
from .keypoint_operators import *
from .mot_operators import *
__all__ = []
__all__ += registered_ops
__all__ += keypoint_operators.__all__
__all__ += mot_operators.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from numbers import Integral
import cv2
import copy
import numpy as np
from .operators import BaseOperator, register_op
from ppdet.modeling.bbox_utils import bbox_iou_np_expand
from ppdet.core.workspace import serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['LetterBoxResize', 'Gt2JDETargetThres', 'Gt2JDETargetMax']
@register_op
class LetterBoxResize(BaseOperator):
def __init__(self, target_size):
"""
Resize image to target size, convert normalized xywh to pixel xyxy
format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
Args:
target_size (int|list): image target size.
"""
super(LetterBoxResize, self).__init__()
if not isinstance(target_size, (Integral, Sequence)):
raise TypeError(
"Type of target_size is invalid. Must be Integer or List or Tuple, now is {}".
format(type(target_size)))
if isinstance(target_size, Integral):
target_size = [target_size, target_size]
self.target_size = target_size
def apply_image(self, img, height, width, color=(127.5, 127.5, 127.5)):
# letterbox: resize a rectangular image to a padded rectangular
shape = img.shape[:2] # [height, width]
ratio_h = float(height) / shape[0]
ratio_w = float(width) / shape[1]
ratio = min(ratio_h, ratio_w)
new_shape = (round(shape[1] * ratio),
round(shape[0] * ratio)) # [width, height]
padw = (width - new_shape[0]) / 2
padh = (height - new_shape[1]) / 2
top, bottom = round(padh - 0.1), round(padh + 0.1)
left, right = round(padw - 0.1), round(padw + 0.1)
img = cv2.resize(
img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT,
value=color) # padded rectangular
return img, ratio, padw, padh
def apply_bbox(self, bbox0, h, w, ratio, padw, padh):
bboxes = bbox0.copy()
bboxes[:, 0] = ratio * w * (bbox0[:, 0] - bbox0[:, 2] / 2) + padw
bboxes[:, 1] = ratio * h * (bbox0[:, 1] - bbox0[:, 3] / 2) + padh
bboxes[:, 2] = ratio * w * (bbox0[:, 0] + bbox0[:, 2] / 2) + padw
bboxes[:, 3] = ratio * h * (bbox0[:, 1] + bbox0[:, 3] / 2) + padh
return bboxes
    def apply(self, sample, context=None):
        """Resize the image array with letterbox padding.
        """
        im = sample['image']
        h, w = sample['im_shape']
        if not isinstance(im, np.ndarray):
            raise TypeError("{}: image type is not numpy.".format(self))
        if len(im.shape) != 3:
            # raise ValueError instead of the undefined ImageError
            raise ValueError('{}: image is not 3-dimensional.'.format(self))
# apply image
height, width = self.target_size
img, ratio, padw, padh = self.apply_image(
im, height=height, width=width)
sample['image'] = img
new_shape = (round(h * ratio), round(w * ratio))
sample['im_shape'] = np.asarray(new_shape, dtype=np.float32)
sample['scale_factor'] = np.asarray([ratio, ratio], dtype=np.float32)
# apply bbox
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], h, w, ratio,
padw, padh)
return sample
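# A minimal usage sketch (illustrative, not part of this commit): letterboxing
# a dummy 720x1280 image with one normalized cxcywh box to a 608x1088 canvas.
#
#     op = LetterBoxResize(target_size=[608, 1088])
#     sample = {
#         'image': np.zeros((720, 1280, 3), dtype=np.uint8),
#         'im_shape': np.array([720., 1280.], dtype=np.float32),
#         'gt_bbox': np.array([[0.5, 0.5, 0.2, 0.4]], dtype=np.float32),
#     }
#     sample = op.apply(sample)  # 'gt_bbox' becomes pixel xyxy on the canvas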
@register_op
class Gt2JDETargetThres(BaseOperator):
__shared__ = ['num_classes']
"""
Generate JDE targets by groud truth data when training
Args:
anchors (list): anchors of JDE model
anchor_masks (list): anchor_masks of JDE model
downsample_ratios (list): downsample ratios of JDE model
ide_thresh (float): thresh of identity, higher is groud truth
fg_thresh (float): thresh of foreground, higher is foreground
bg_thresh (float): thresh of background, lower is background
num_classes (int): number of classes
"""
def __init__(self,
anchors,
anchor_masks,
downsample_ratios,
ide_thresh=0.5,
fg_thresh=0.5,
bg_thresh=0.4,
num_classes=1):
super(Gt2JDETargetThres, self).__init__()
self.anchors = anchors
self.anchor_masks = anchor_masks
self.downsample_ratios = downsample_ratios
self.ide_thresh = ide_thresh
self.fg_thresh = fg_thresh
self.bg_thresh = bg_thresh
self.num_classes = num_classes
def generate_anchor(self, nGh, nGw, anchor_hw):
nA = len(anchor_hw)
yy, xx = np.meshgrid(np.arange(nGh), np.arange(nGw))
mesh = np.stack([xx.T, yy.T], axis=0) # [2, nGh, nGw]
mesh = np.repeat(mesh[None, :], nA, axis=0) # [nA, 2, nGh, nGw]
anchor_offset_mesh = anchor_hw[:, :, None][:, :, :, None]
anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGh, axis=-2)
anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGw, axis=-1)
anchor_mesh = np.concatenate(
[mesh, anchor_offset_mesh], axis=1) # [nA, 4, nGh, nGw]
return anchor_mesh
def encode_delta(self, gt_box_list, fg_anchor_list):
px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:,1], \
fg_anchor_list[:, 2], fg_anchor_list[:,3]
gx, gy, gw, gh = gt_box_list[:, 0], gt_box_list[:, 1], \
gt_box_list[:, 2], gt_box_list[:, 3]
dx = (gx - px) / pw
dy = (gy - py) / ph
dw = np.log(gw / pw)
dh = np.log(gh / ph)
return np.stack([dx, dy, dw, dh], axis=1)
def pad_box(self, sample, num_max):
assert 'gt_bbox' in sample
bbox = sample['gt_bbox']
gt_num = len(bbox)
pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
if gt_num > 0:
pad_bbox[:gt_num, :] = bbox[:gt_num, :]
sample['gt_bbox'] = pad_bbox
if 'gt_score' in sample:
pad_score = np.zeros((num_max, ), dtype=np.float32)
if gt_num > 0:
pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
sample['gt_score'] = pad_score
if 'difficult' in sample:
pad_diff = np.zeros((num_max, ), dtype=np.int32)
if gt_num > 0:
pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
sample['difficult'] = pad_diff
if 'is_crowd' in sample:
pad_crowd = np.zeros((num_max, ), dtype=np.int32)
if gt_num > 0:
pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
sample['is_crowd'] = pad_crowd
if 'gt_ide' in sample:
pad_ide = np.zeros((num_max, ), dtype=np.int32)
if gt_num > 0:
pad_ide[:gt_num] = sample['gt_ide'][:gt_num, 0]
sample['gt_ide'] = pad_ide
return sample
def __call__(self, samples, context=None):
        assert len(self.anchor_masks) == len(self.downsample_ratios), \
            "'anchor_masks' and 'downsample_ratios' should have the same length."
h, w = samples[0]['image'].shape[1:3]
num_max = 0
for sample in samples:
num_max = max(num_max, len(sample['gt_bbox']))
for sample in samples:
gt_bbox = sample['gt_bbox']
gt_ide = sample['gt_ide']
for i, (anchor_hw, downsample_ratio
) in enumerate(zip(self.anchors, self.downsample_ratios)):
anchor_hw = np.array(
anchor_hw, dtype=np.float32) / downsample_ratio
nA = len(anchor_hw)
nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)
gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
gxy[:, 0] = gxy[:, 0] * nGw
gxy[:, 1] = gxy[:, 1] * nGh
gwh[:, 0] = gwh[:, 0] * nGw
gwh[:, 1] = gwh[:, 1] * nGh
gxy[:, 0] = np.clip(gxy[:, 0], 0, nGw - 1)
gxy[:, 1] = np.clip(gxy[:, 1], 0, nGh - 1)
tboxes = np.concatenate([gxy, gwh], axis=1)
anchor_mesh = self.generate_anchor(nGh, nGw, anchor_hw)
anchor_list = np.transpose(anchor_mesh,
(0, 2, 3, 1)).reshape(-1, 4)
iou_pdist = bbox_iou_np_expand(
anchor_list, tboxes, x1y1x2y2=False)
iou_max = np.max(iou_pdist, axis=1)
max_gt_index = np.argmax(iou_pdist, axis=1)
iou_map = iou_max.reshape(nA, nGh, nGw)
gt_index_map = max_gt_index.reshape(nA, nGh, nGw)
id_index = iou_map > self.ide_thresh
fg_index = iou_map > self.fg_thresh
bg_index = iou_map < self.bg_thresh
ign_index = (iou_map < self.fg_thresh) * (
iou_map > self.bg_thresh)
tconf[fg_index] = 1
tconf[bg_index] = 0
tconf[ign_index] = -1
gt_index = gt_index_map[fg_index]
gt_box_list = tboxes[gt_index]
gt_id_list = gt_ide[gt_index_map[id_index]]
if np.sum(fg_index) > 0:
tid[id_index] = gt_id_list
fg_anchor_list = anchor_list.reshape(nA, nGh, nGw,
4)[fg_index]
delta_target = self.encode_delta(gt_box_list,
fg_anchor_list)
tbox[fg_index] = delta_target
sample['tbox{}'.format(i)] = tbox
sample['tconf{}'.format(i)] = tconf
sample['tide{}'.format(i)] = tid
sample.pop('gt_class')
sample = self.pad_box(sample, num_max)
return samples
@register_op
class Gt2JDETargetMax(BaseOperator):
__shared__ = ['num_classes']
"""
Generate JDE targets by groud truth data when evaluating
Args:
anchors (list): anchors of JDE model
anchor_masks (list): anchor_masks of JDE model
downsample_ratios (list): downsample ratios of JDE model
max_iou_thresh (float): iou thresh for high quality anchor
num_classes (int): number of classes
"""
def __init__(self,
anchors,
anchor_masks,
downsample_ratios,
max_iou_thresh=0.60,
num_classes=1):
super(Gt2JDETargetMax, self).__init__()
self.anchors = anchors
self.anchor_masks = anchor_masks
self.downsample_ratios = downsample_ratios
self.max_iou_thresh = max_iou_thresh
self.num_classes = num_classes
def __call__(self, samples, context=None):
        assert len(self.anchor_masks) == len(self.downsample_ratios), \
            "'anchor_masks' and 'downsample_ratios' should have the same length."
h, w = samples[0]['image'].shape[1:3]
for sample in samples:
gt_bbox = sample['gt_bbox']
gt_ide = sample['gt_ide']
for i, (anchor_hw, downsample_ratio
) in enumerate(zip(self.anchors, self.downsample_ratios)):
anchor_hw = np.array(
anchor_hw, dtype=np.float32) / downsample_ratio
nA = len(anchor_hw)
nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)
gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
gxy[:, 0] = gxy[:, 0] * nGw
gxy[:, 1] = gxy[:, 1] * nGh
gwh[:, 0] = gwh[:, 0] * nGw
gwh[:, 1] = gwh[:, 1] * nGh
gi = np.clip(gxy[:, 0], 0, nGw - 1).astype(int)
gj = np.clip(gxy[:, 1], 0, nGh - 1).astype(int)
# iou of targets-anchors (using wh only)
box1 = gwh
box2 = anchor_hw[:, None, :]
inter_area = np.minimum(box1, box2).prod(2)
iou = inter_area / (
box1.prod(1) + box2.prod(2) - inter_area + 1e-16)
# Select best iou_pred and anchor
iou_best = iou.max(0) # best anchor [0-2] for each target
a = np.argmax(iou, axis=0)
# Select best unique target-anchor combinations
iou_order = np.argsort(-iou_best) # best to worst
# Unique anchor selection
u = np.stack((gi, gj, a), 0)[:, iou_order]
_, first_unique = np.unique(u, axis=1, return_index=True)
mask = iou_order[first_unique]
# best anchor must share significant commonality (iou) with target
# TODO: examine arbitrary threshold
idx = mask[iou_best[mask] > self.max_iou_thresh]
if len(idx) > 0:
a_i, gj_i, gi_i = a[idx], gj[idx], gi[idx]
t_box = gt_bbox[idx]
t_id = gt_ide[idx]
if len(t_box.shape) == 1:
t_box = t_box.reshape(1, 4)
gxy, gwh = t_box[:, 0:2].copy(), t_box[:, 2:4].copy()
gxy[:, 0] = gxy[:, 0] * nGw
gxy[:, 1] = gxy[:, 1] * nGh
gwh[:, 0] = gwh[:, 0] * nGw
gwh[:, 1] = gwh[:, 1] * nGh
# XY coordinates
tbox[:, :, :, 0:2][a_i, gj_i, gi_i] = gxy - gxy.astype(int)
# Width and height in yolo method
tbox[:, :, :, 2:4][a_i, gj_i, gi_i] = np.log(gwh /
anchor_hw[a_i])
tconf[a_i, gj_i, gi_i] = 1
tid[a_i, gj_i, gi_i] = t_id
sample['tbox{}'.format(i)] = tbox
sample['tconf{}'.format(i)] = tconf
sample['tide{}'.format(i)] = tid
@@ -122,7 +122,8 @@ class Decode(BaseOperator):
im = sample['image']
data = np.frombuffer(im, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
if 'keep_ori_im' in sample and sample['keep_ori_im']:
sample['ori_image'] = im
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
sample['image'] = im
@@ -1640,6 +1641,11 @@ class Mixup(BaseOperator):
(is_difficult1, is_difficult2), axis=0)
result['difficult'] = is_difficult
if 'gt_ide' in sample[0]:
gt_ide1 = sample[0]['gt_ide']
gt_ide2 = sample[1]['gt_ide']
gt_ide = np.concatenate((gt_ide1, gt_ide2), axis=0)
result['gt_ide'] = gt_ide
return result
@@ -1736,6 +1742,11 @@ class PadBox(BaseOperator):
if gt_num > 0:
pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
sample['is_crowd'] = pad_crowd
if 'gt_ide' in sample:
pad_ide = np.zeros((num_max, ), dtype=np.int32)
if gt_num > 0:
pad_ide[:gt_num] = sample['gt_ide'][:gt_num, 0]
sample['gt_ide'] = pad_ide
return sample
@@ -1999,3 +2010,200 @@ class Rbox2Poly(BaseOperator):
polys = bbox_utils.rbox2poly(rrects)
sample['gt_rbox2poly'] = polys
return sample
@register_op
class AugmentHSV(BaseOperator):
def __init__(self, fraction=0.50, is_bgr=True):
"""
Augment the SV channel of image data.
Args:
fraction (float): the fraction for augment
is_bgr (bool): whether the image is BGR mode
"""
super(AugmentHSV, self).__init__()
self.fraction = fraction
self.is_bgr = is_bgr
def apply(self, sample, context=None):
img = sample['image']
if self.is_bgr:
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
else:
img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
S = img_hsv[:, :, 1].astype(np.float32)
V = img_hsv[:, :, 2].astype(np.float32)
a = (random.random() * 2 - 1) * self.fraction + 1
S *= a
if a > 1:
np.clip(S, a_min=0, a_max=255, out=S)
a = (random.random() * 2 - 1) * self.fraction + 1
V *= a
if a > 1:
np.clip(V, a_min=0, a_max=255, out=V)
img_hsv[:, :, 1] = S.astype(np.uint8)
img_hsv[:, :, 2] = V.astype(np.uint8)
if self.is_bgr:
cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
else:
cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB, dst=img)
sample['image'] = img
return sample
@register_op
class Norm2PixelBbox(BaseOperator):
"""
Transform the bounding box's coornidates which is in [0,1] to pixels.
"""
def __init__(self):
super(Norm2PixelBbox, self).__init__()
def apply(self, sample, context=None):
assert 'gt_bbox' in sample
bbox = sample['gt_bbox']
height, width = sample['image'].shape[:2]
bbox[:, 0::2] = bbox[:, 0::2] * width
bbox[:, 1::2] = bbox[:, 1::2] * height
sample['gt_bbox'] = bbox
return sample
@register_op
class RandomAffine(BaseOperator):
def __init__(self,
degrees=(-5, 5),
translate=(0.10, 0.10),
scale=(0.50, 1.20),
shear=(-2, 2),
borderValue=(127.5, 127.5, 127.5)):
"""
Transform the image data with random affine
"""
super(RandomAffine, self).__init__()
self.degrees = degrees
self.translate = translate
self.scale = scale
self.shear = shear
self.borderValue = borderValue
def apply(self, sample, context=None):
# https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
border = 0 # width of added border (optional)
img = sample['image']
height, width = img.shape[0], img.shape[1]
# Rotation and Scale
R = np.eye(3)
a = random.random() * (self.degrees[1] - self.degrees[0]
) + self.degrees[0]
s = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
R[:2] = cv2.getRotationMatrix2D(
angle=a, center=(width / 2, height / 2), scale=s)
# Translation
T = np.eye(3)
T[0, 2] = (
random.random() * 2 - 1
) * self.translate[0] * height + border # x translation (pixels)
T[1, 2] = (
random.random() * 2 - 1
) * self.translate[1] * width + border # y translation (pixels)
# Shear
S = np.eye(3)
S[0, 1] = math.tan((random.random() *
(self.shear[1] - self.shear[0]) + self.shear[0]) *
math.pi / 180) # x shear (deg)
S[1, 0] = math.tan((random.random() *
(self.shear[1] - self.shear[0]) + self.shear[0]) *
math.pi / 180) # y shear (deg)
        M = S @ T @ R  # combined transform matrix; the order is important
imw = cv2.warpPerspective(
img,
M,
dsize=(width, height),
flags=cv2.INTER_LINEAR,
borderValue=self.borderValue) # BGR order borderValue
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
targets = sample['gt_bbox']
n = targets.shape[0]
points = targets.copy()
area0 = (points[:, 2] - points[:, 0]) * (
points[:, 3] - points[:, 1])
# warp points
xy = np.ones((n * 4, 3))
xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
n * 4, 2) # x1y1, x2y2, x1y2, x2y1
            xy = (xy @ M.T)[:, :2].reshape(n, 8)
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
xy = np.concatenate(
(x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# apply angle-based reduction
radians = a * math.pi / 180
reduction = max(abs(math.sin(radians)), abs(math.cos(radians)))**0.5
x = (xy[:, 2] + xy[:, 0]) / 2
y = (xy[:, 3] + xy[:, 1]) / 2
w = (xy[:, 2] - xy[:, 0]) * reduction
h = (xy[:, 3] - xy[:, 1]) * reduction
xy = np.concatenate(
(x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
# reject warped points outside of image
np.clip(xy[:, 0], 0, width, out=xy[:, 0])
np.clip(xy[:, 2], 0, width, out=xy[:, 2])
np.clip(xy[:, 1], 0, height, out=xy[:, 1])
np.clip(xy[:, 3], 0, height, out=xy[:, 3])
w = xy[:, 2] - xy[:, 0]
h = xy[:, 3] - xy[:, 1]
area = w * h
ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)
if sum(i) > 0:
sample['gt_bbox'] = xy[i].astype(sample['gt_bbox'].dtype)
sample['gt_class'] = sample['gt_class'][i]
if 'difficult' in sample:
sample['difficult'] = sample['difficult'][i]
if 'gt_ide' in sample:
sample['gt_ide'] = sample['gt_ide'][i]
if 'is_crowd' in sample:
sample['is_crowd'] = sample['is_crowd'][i]
sample['image'] = imw
return sample
else:
return sample
@register_op
class BboxCXCYWH2XYXY(BaseOperator):
"""
Convert bbox CXCYWH format to XYXY format.
[center_x, center_y, width, height] -> [x0, y0, x1, y1]
"""
def __init__(self):
super(BboxCXCYWH2XYXY, self).__init__()
def apply(self, sample, context=None):
assert 'gt_bbox' in sample
bbox0 = sample['gt_bbox']
bbox = bbox0.copy()
bbox[:, :2] = bbox0[:, :2] - bbox0[:, 2:4] / 2.
bbox[:, 2:4] = bbox0[:, :2] + bbox0[:, 2:4] / 2.
sample['gt_bbox'] = bbox
return sample
@@ -570,3 +570,52 @@ def pd_rbox2poly(rrects):
polys[:, 5] += y_ctr
polys[:, 7] += y_ctr
return polys
def bbox_iou_np_expand(box1, box2, x1y1x2y2=True, eps=1e-16):
"""
Calculate the iou of box1 and box2 with numpy.
Args:
box1 (ndarray): [N, 4]
box2 (ndarray): [M, 4], usually N != M
        x1y1x2y2 (bool): whether boxes are in x1y1x2y2 style, default True
eps (float): epsilon to avoid divide by zero
Return:
iou (ndarray): iou of box1 and box2, [N, M]
"""
N, M = len(box1), len(box2) # usually N != M
if x1y1x2y2:
b1_x1, b1_y1 = box1[:, 0], box1[:, 1]
b1_x2, b1_y2 = box1[:, 2], box1[:, 3]
b2_x1, b2_y1 = box2[:, 0], box2[:, 1]
b2_x2, b2_y2 = box2[:, 2], box2[:, 3]
else:
# cxcywh style
# Transform from center and width to exact coordinates
b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
# get the coordinates of the intersection rectangle
inter_rect_x1 = np.zeros((N, M), dtype=np.float32)
inter_rect_y1 = np.zeros((N, M), dtype=np.float32)
inter_rect_x2 = np.zeros((N, M), dtype=np.float32)
inter_rect_y2 = np.zeros((N, M), dtype=np.float32)
for i in range(len(box2)):
inter_rect_x1[:, i] = np.maximum(b1_x1, b2_x1[i])
inter_rect_y1[:, i] = np.maximum(b1_y1, b2_y1[i])
inter_rect_x2[:, i] = np.minimum(b1_x2, b2_x2[i])
inter_rect_y2[:, i] = np.minimum(b1_y2, b2_y2[i])
# Intersection area
inter_area = np.maximum(inter_rect_x2 - inter_rect_x1, 0) * np.maximum(
inter_rect_y2 - inter_rect_y1, 0)
# Union Area
b1_area = np.repeat(
((b1_x2 - b1_x1) * (b1_y2 - b1_y1)).reshape(-1, 1), M, axis=-1)
b2_area = np.repeat(
((b2_x2 - b2_x1) * (b2_y2 - b2_y1)).reshape(1, -1), N, axis=0)
ious = inter_area / (b1_area + b2_area - inter_area + eps)
return ious
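# A quick sanity check (illustrative, not part of this commit): two identical
# unit boxes in x1y1x2y2 style should give an IoU of 1.
#
#     b = np.array([[0., 0., 1., 1.]], dtype=np.float32)
#     assert np.isclose(bbox_iou_np_expand(b, b)[0, 0], 1.0, atol=1e-3)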