未验证 提交 08f3c0be 编写于 作者: S shuluoshu 提交者: GitHub

add M3D-RPN model (#4822)

* Add M3d-RPN model.
Co-authored-by: Nyexiaoqing <yexiaoqing@baidu.com>
上级 a33f0814
# M3D-RPN: Monocular 3D Region Proposal Network for Object Detection
## Introduction
Monocular 3D region proposal network for object detection accepted to ICCV 2019 (Oral), detailed in [arXiv report](https://arxiv.org/abs/1907.06038).
## Setup
- **Cuda & Python**
This project uses PaddlePaddle 1.8 with Python 3, CUDA 9, and a few Anaconda packages.
- **Data**
Download the full [KITTI](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) detection dataset. Then place a softlink (or the actual data) in *M3D-RPN/dataset/kitti*.
```
cd M3D-RPN
ln -s /path/to/kitti dataset/kitti
```
Then use the following scripts to extract the data splits, which use softlinks to the above directory for efficient storage.
```
python dataset/kitti_split1/setup_split.py
python dataset/kitti_split2/setup_split.py
```
Next, build the KITTI devkit eval for each split.
```
sh dataset/kitti_split1/devkit/cpp/build.sh
sh dataset/kitti_split2/devkit/cpp/build.sh
```
Lastly, build the nms modules
```
cd lib/nms
make
```
## Training
Training is split into a warmup configuration and a main configuration. Review the configurations in *config* for details.
```
# First train the warmup (without depth-aware)
python train.py --config=kitti_3d_multi_warmup
# Then train the main experiment (with depth-aware)
python train.py --config=kitti_3d_multi_main
```
## Testing
We provide a model for the main experiment on the val1 data split, available for download here: [M3D-RPN-release.tar](https://pan.baidu.com/s/1VQa5hGzIbauLOQi-0kR9Hg), password: ls39.
Testing requires the paths to the configuration file and the model weights, which are exposed as variables near the top of *test.py* (and as the command-line arguments shown below). To test a configuration and model, simply set them and run the test file as below.
```
python test.py --conf_path M3D-RPN-release/conf.pkl --weights_path M3D-RPN-release/iter50000.0_params.pdparams
```
"""
config of main
"""
from easydict import EasyDict as edict
import numpy as np
def Config():
    """
    Build the configuration for the main training stage.

    Returns:
        EasyDict: solver, dataset, sampling, NMS, anchor and loss settings
        for the 'model_3d_dilate_depth_aware' model. Values are plain Python
        objects except anchor_scales / anchor_ratios, which are numpy arrays.
    """
    conf = edict()
    # ----------------------------------------
    # general
    # ----------------------------------------
    conf.model = 'model_3d_dilate_depth_aware'
    # solver settings
    conf.solver_type = 'sgd'
    conf.lr = 0.004
    conf.momentum = 0.9
    conf.weight_decay = 0.0005
    conf.max_iter = 50000
    conf.snapshot_iter = 10000
    conf.display = 20
    conf.do_test = True
    # sgd parameters
    conf.lr_policy = 'poly'
    conf.lr_steps = None
    # target learning rate is 1e-5 times the initial one
    conf.lr_target = conf.lr * 0.00001
    # random
    conf.rng_seed = 2
    conf.cuda_seed = 2
    # misc network
    # per-channel mean / std applied after the image is divided by 255
    # (see Normalize in data/augmentations.py)
    conf.image_means = [0.485, 0.456, 0.406]
    conf.image_stds = [0.229, 0.224, 0.225]
    conf.feat_stride = 16
    conf.has_3d = True
    # ----------------------------------------
    # image sampling and datasets
    # ----------------------------------------
    # scale sampling
    conf.test_scale = 512
    # [height, width] the Resize transform scales / crops / pads to
    conf.crop_size = [512, 1760]
    conf.mirror_prob = 0.50
    # a value <= 0 makes Augmentation skip the photometric distortion step
    conf.distort_prob = -1
    # datasets
    conf.dataset_test = 'kitti_split1'
    conf.datasets_train = [{
        'name': 'kitti_split1',
        'anno_fmt': 'kitti_det',
        'im_ext': '.png',
        'scale': 1
    }]
    conf.use_3d_for_2d = True
    # percent expected height ranges based on test_scale
    # used for anchor selection
    conf.percent_anc_h = [0.0625, 0.75]
    # labels settings
    # 512 * 0.0625 = 32 and 512 * 0.75 = 384
    conf.min_gt_h = conf.test_scale * conf.percent_anc_h[0]
    conf.max_gt_h = conf.test_scale * conf.percent_anc_h[1]
    conf.min_gt_vis = 0.65
    conf.ilbls = ['Van', 'ignore']
    conf.lbls = ['Car', 'Pedestrian', 'Cyclist']
    # ----------------------------------------
    # detection sampling
    # ----------------------------------------
    # detection sampling
    conf.batch_size = 2
    # a value >= 0 (and != 2) makes balance_samples give images without
    # valid ground truths a sampling weight of zero
    conf.fg_image_ratio = 1.0
    conf.box_samples = 0.20
    conf.fg_fraction = 0.20
    conf.bg_thresh_lo = 0
    conf.bg_thresh_hi = 0.5
    conf.fg_thresh = 0.5
    conf.ign_thresh = 0.5
    conf.best_thresh = 0.35
    # ----------------------------------------
    # inference and testing
    # ----------------------------------------
    # nms
    conf.nms_topN_pre = 3000
    conf.nms_topN_post = 40
    conf.nms_thres = 0.4
    conf.clip_boxes = False
    conf.test_protocol = 'kitti'
    conf.test_db = 'kitti'
    conf.test_min_h = 0
    conf.min_det_scales = [0, 0]
    # ----------------------------------------
    # anchor settings
    # ----------------------------------------
    # clustering settings
    conf.cluster_anchors = 0
    conf.even_anchors = 0
    conf.expand_anchors = 0
    conf.anchors = None
    conf.bbox_means = None
    conf.bbox_stds = None
    # initialize anchors
    # 12 scales spaced geometrically from min_gt_h to max_gt_h;
    # base is the common ratio between consecutive scales
    base = (conf.max_gt_h / conf.min_gt_h)**(1 / (12 - 1))
    conf.anchor_scales = np.array(
        [conf.min_gt_h * (base**i) for i in range(0, 12)])
    conf.anchor_ratios = np.array([0.5, 1.0, 1.5])
    # loss logic
    conf.hard_negatives = True
    conf.focal_loss = 0
    conf.cls_2d_lambda = 1
    conf.iou_2d_lambda = 1
    conf.bbox_2d_lambda = 0
    conf.bbox_3d_lambda = 1
    conf.bbox_3d_proj_lambda = 0.0
    conf.hill_climbing = True
    conf.bins = 32
    # visdom
    conf.visdom_port = 8100
    # NOTE(review): presumably the weights produced by the warmup stage;
    # confirm the expected location with train.py
    conf.pretrained = 'paddle.pdparams'
    return conf
"""
config of warmup
"""
from easydict import EasyDict as edict
import numpy as np
def Config():
    """
    Build the configuration for the warmup training stage.

    Returns:
        EasyDict: solver, dataset, sampling, NMS, anchor and loss settings
        for the 'model_3d_dilate' model. Values are plain Python objects
        except anchor_scales / anchor_ratios, which are numpy arrays.
    """
    conf = edict()
    # ----------------------------------------
    # general
    # ----------------------------------------
    conf.model = 'model_3d_dilate'
    # solver settings
    conf.solver_type = 'sgd'
    conf.lr = 0.004
    conf.momentum = 0.9
    conf.weight_decay = 0.0005
    conf.max_iter = 50000
    conf.snapshot_iter = 10000
    conf.display = 20
    conf.do_test = True
    # sgd parameters
    conf.lr_policy = 'poly'
    conf.lr_steps = None
    # target learning rate is 1e-5 times the initial one
    conf.lr_target = conf.lr * 0.00001
    # random
    conf.rng_seed = 2
    conf.cuda_seed = 2
    # misc network
    # per-channel mean / std applied after the image is divided by 255
    # (see Normalize in data/augmentations.py)
    conf.image_means = [0.485, 0.456, 0.406]
    conf.image_stds = [0.229, 0.224, 0.225]
    conf.feat_stride = 16
    conf.has_3d = True
    # ----------------------------------------
    # image sampling and datasets
    # ----------------------------------------
    # scale sampling
    conf.test_scale = 512
    # [height, width] the Resize transform scales / crops / pads to
    conf.crop_size = [512, 1760]
    conf.mirror_prob = 0.50
    # a value <= 0 makes Augmentation skip the photometric distortion step
    conf.distort_prob = -1
    # datasets
    conf.dataset_test = 'kitti_split1'
    conf.datasets_train = [{
        'name': 'kitti_split1',
        'anno_fmt': 'kitti_det',
        'im_ext': '.png',
        'scale': 1
    }]
    conf.use_3d_for_2d = True
    # percent expected height ranges based on test_scale
    # used for anchor selection
    conf.percent_anc_h = [0.0625, 0.75]
    # labels settings
    # 512 * 0.0625 = 32 and 512 * 0.75 = 384
    conf.min_gt_h = conf.test_scale * conf.percent_anc_h[0]
    conf.max_gt_h = conf.test_scale * conf.percent_anc_h[1]
    conf.min_gt_vis = 0.65
    conf.ilbls = ['Van', 'ignore']
    conf.lbls = ['Car', 'Pedestrian', 'Cyclist']
    # ----------------------------------------
    # detection sampling
    # ----------------------------------------
    # detection sampling
    conf.batch_size = 2
    # a value >= 0 (and != 2) makes balance_samples give images without
    # valid ground truths a sampling weight of zero
    conf.fg_image_ratio = 1.0
    conf.box_samples = 0.20
    conf.fg_fraction = 0.20
    conf.bg_thresh_lo = 0
    conf.bg_thresh_hi = 0.5
    conf.fg_thresh = 0.5
    conf.ign_thresh = 0.5
    conf.best_thresh = 0.35
    # ----------------------------------------
    # inference and testing
    # ----------------------------------------
    # nms
    conf.nms_topN_pre = 3000
    conf.nms_topN_post = 40
    conf.nms_thres = 0.4
    conf.clip_boxes = False
    conf.test_protocol = 'kitti'
    conf.test_db = 'kitti'
    conf.test_min_h = 0
    conf.min_det_scales = [0, 0]
    # ----------------------------------------
    # anchor settings
    # ----------------------------------------
    # clustering settings
    conf.cluster_anchors = 0
    conf.even_anchors = 0
    conf.expand_anchors = 0
    conf.anchors = None
    conf.bbox_means = None
    conf.bbox_stds = None
    # initialize anchors
    # 12 scales spaced geometrically from min_gt_h to max_gt_h;
    # base is the common ratio between consecutive scales
    base = (conf.max_gt_h / conf.min_gt_h)**(1 / (12 - 1))
    conf.anchor_scales = np.array(
        [conf.min_gt_h * (base**i) for i in range(0, 12)])
    conf.anchor_ratios = np.array([0.5, 1.0, 1.5])
    # loss logic
    conf.hard_negatives = True
    conf.focal_loss = 0
    conf.cls_2d_lambda = 1
    conf.iou_2d_lambda = 1
    conf.bbox_2d_lambda = 0
    conf.bbox_3d_lambda = 1
    conf.bbox_3d_proj_lambda = 0.0
    conf.hill_climbing = True
    # NOTE(review): presumably DenseNet backbone weights; confirm the
    # expected location with train.py
    conf.pretrained = 'pretrained_model/densenet.pdparams'
    # visdom
    conf.visdom_port = 8100
    return conf
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
init
"""
from . import m3drpn_reader
#from .m3drpn_reader import *
#__all__ = m3drpn_reader.__all__
\ No newline at end of file
"""
This code is based on https://github.com/garrickbrazil/M3D-RPN/blob/master/lib/augmentations.py
This file contains all data augmentation functions.
Every transform should have a __call__ function which takes in (self, image, imobj)
where imobj is an arbitrary dict containing relevant information to the image.
In many cases the imobj can be None, which enables the same augmentations to be used
during testing as they are in training.
Optionally, most transforms should have an __init__ function as well, if needed.
"""
import numpy as np
from numpy import random
import cv2
import math
import os
import sys
import lib.util as util
class Compose(object):
    """
    Chain several (image, imobj) transforms into a single callable.

    The transforms run in list order; each one receives the pair returned
    by the previous one and must itself return an (image, imobj) pair.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, imobj=None):
        for step in self.transforms:
            outputs = step(img, imobj)
            img, imobj = outputs
        return img, imobj
class ConvertToFloat(object):
    """
    Cast the image to float32; the image object is passed through untouched.
    """

    def __call__(self, image, imobj=None):
        as_float = image.astype(np.float32)
        return as_float, imobj
class Normalize(object):
    """
    Scale pixel values to [0, 1] (divide by 255), then subtract the
    per-channel mean and divide by the per-channel std.

    When the image carries a multiple of len(mean) channels (several
    frames stacked on the channel axis) the statistics are tiled to match.
    """

    def __init__(self, mean, stds):
        self.mean = np.array(mean, dtype=np.float32)
        self.stds = np.array(stds, dtype=np.float32)

    def __call__(self, image, imobj=None):
        channels = image.shape[2]
        tiled_mean = np.tile(self.mean, int(channels / self.mean.shape[0]))
        tiled_stds = np.tile(self.stds, int(channels / self.stds.shape[0]))
        # astype copies, so the caller's array is never modified
        normalized = image.astype(np.float32)
        normalized /= 255.0
        normalized -= tiled_mean
        normalized /= tiled_stds
        return normalized.astype(np.float32), imobj
class Resize(object):
    """
    Rescale the image so that its height equals size[0], keeping the aspect
    ratio. If size also gives a width, the result is cropped on the right
    when it is too wide, or zero-padded on the right when it is too narrow.

    When an image object is supplied, the scale factor is stored on it and
    the box coordinates of its ground truths ('gts' and 'gts_pre', when
    present) are scaled in place.
    """

    def __init__(self, size):
        self.size = size

    @staticmethod
    def _rescale_gts(gts, scale_factor):
        """Multiply the 2D coordinates of every ground truth by scale_factor."""
        for gt in gts:
            if 'bbox_full' in gt:
                gt.bbox_full *= scale_factor
            if 'bbox_vis' in gt:
                gt.bbox_vis *= scale_factor
            if 'bbox_3d' in gt:
                # only scale x/y center locations (in 2D space!)
                gt.bbox_3d[0] *= scale_factor
                gt.bbox_3d[1] *= scale_factor

    def __call__(self, image, imobj=None):
        src_h, src_w = image.shape[0], image.shape[1]
        scale_factor = self.size[0] / src_h
        h = np.round(src_h * scale_factor).astype(int)
        w = np.round(src_w * scale_factor).astype(int)

        image = cv2.resize(image, (w, h))

        if len(self.size) > 1:
            target_w = self.size[1]
            resized_w = image.shape[1]
            if resized_w > target_w:
                # too wide: keep the left part
                image = image[:, 0:target_w, :]
            elif resized_w < target_w:
                # too narrow: zero-pad on the right
                padW = target_w - resized_w
                image = np.pad(image, [(0, 0), (0, padW), (0, 0)], 'constant')

        if imobj:
            # store scale factor, just in case
            imobj.scale_factor = scale_factor
            if 'gts' in imobj:
                self._rescale_gts(imobj.gts, scale_factor)
            if 'gts_pre' in imobj:
                self._rescale_gts(imobj.gts_pre, scale_factor)

        return image, imobj
class RandomSaturation(object):
    """
    With probability distort_prob, multiply channel 1 of the image by a
    factor drawn uniformly from [lower, upper]. The image is modified in
    place.

    This function assumes the image is in HSV!!
    """

    def __init__(self, distort_prob, lower=0.5, upper=1.5):
        self.distort_prob = distort_prob
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "contrast upper must be >= lower."
        assert self.lower >= 0, "contrast lower must be non-negative."

    def __call__(self, image, imobj=None):
        apply_distortion = random.rand() <= self.distort_prob
        if not apply_distortion:
            return image, imobj
        factor = random.uniform(self.lower, self.upper)
        image[:, :, 1] *= factor
        return image, imobj
class RandomHue(object):
    """
    With probability distort_prob, shift channel 0 of the image by an
    offset drawn uniformly from [-delta, delta] and wrap values that leave
    the [0, 360] range by one turn. The image is modified in place.

    This function assumes the image is in HSV!!
    """

    def __init__(self, distort_prob, delta=18.0):
        assert 0.0 <= delta <= 360.0
        self.delta = delta
        self.distort_prob = distort_prob

    def __call__(self, image, imobj=None):
        if random.rand() <= self.distort_prob:
            # basic slicing yields a view, so edits land in `image`
            hue = image[:, :, 0]
            hue += random.uniform(-self.delta, self.delta)
            hue[hue > 360.0] -= 360.0
            hue[hue < 0.0] += 360.0
        return image, imobj
class ConvertColor(object):
    """
    Convert an image between the BGR and HSV color spaces with OpenCV.

    Only BGR -> HSV and HSV -> BGR are supported; any other combination
    raises NotImplementedError when the transform is called.
    """

    def __init__(self, current='BGR', transform='HSV'):
        self.transform = transform
        self.current = current

    def __call__(self, image, imobj=None):
        direction = (self.current, self.transform)
        if direction == ('BGR', 'HSV'):
            code = cv2.COLOR_BGR2HSV
        elif direction == ('HSV', 'BGR'):
            code = cv2.COLOR_HSV2BGR
        else:
            raise NotImplementedError
        return cv2.cvtColor(image, code), imobj
class RandomContrast(object):
    """
    With probability distort_prob, multiply the whole image by a factor
    drawn uniformly from [lower, upper]. The image is modified in place.
    """

    def __init__(self, distort_prob, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        self.distort_prob = distort_prob
        assert self.upper >= self.lower, "contrast upper must be >= lower."
        assert self.lower >= 0, "contrast lower must be non-negative."

    def __call__(self, image, imobj=None):
        apply_distortion = random.rand() <= self.distort_prob
        if not apply_distortion:
            return image, imobj
        gain = random.uniform(self.lower, self.upper)
        image *= gain
        return image, imobj
class RandomMirror(object):
    """
    Randomly mirror an image horizontally, given a mirror probability.
    When the image is flipped, the x coordinate of every ground-truth box
    is flipped too, and the 3D rotation (rotY) and observation angle
    (alpha) are recomputed for the mirrored scene.
    """

    def __init__(self, mirror_prob):
        self.mirror_prob = mirror_prob

    def __call__(self, image, imobj):
        _, width, _ = image.shape

        flip = random.rand() <= self.mirror_prob
        if not flip:
            return image, imobj

        image = np.ascontiguousarray(image[:, ::-1, :])

        # flip the coordinates w.r.t the horizontal flip (only adjust X)
        for gt in imobj.gts:
            if 'bbox_full' in gt:
                # boxes are stored as [x, y, w, h]
                gt.bbox_full[0] = width - gt.bbox_full[0] - gt.bbox_full[2]
            if 'bbox_vis' in gt:
                gt.bbox_vis[0] = width - gt.bbox_vis[0] - gt.bbox_vis[2]
            if 'bbox_3d' in gt:
                box3d = gt.bbox_3d
                box3d[0] = width - box3d[0] - 1

                # mirror the yaw, then wrap it back into [-pi, pi]
                rotY = box3d[10]
                if rotY < 0:
                    rotY = -math.pi - rotY
                else:
                    rotY = math.pi - rotY
                while rotY > math.pi:
                    rotY -= math.pi * 2
                while rotY < (-math.pi):
                    rotY += math.pi * 2

                # back-project the (already mirrored) center to recompute alpha
                cx2d, cy2d, cz2d = box3d[0], box3d[1], box3d[2]
                coord3d = imobj.p2_inv.dot(
                    np.array([cx2d * cz2d, cy2d * cz2d, cz2d, 1]))
                alpha = util.convertRot2Alpha(rotY, coord3d[2], coord3d[0])

                box3d[10] = rotY
                box3d[6] = alpha

        return image, imobj
class RandomBrightness(object):
    """
    With probability distort_prob, add an offset drawn uniformly from
    [-delta, delta] to every value of the image. The image is modified in
    place.
    """

    def __init__(self, distort_prob, delta=32):
        assert 0.0 <= delta
        assert delta <= 255.0
        self.delta = delta
        self.distort_prob = distort_prob

    def __call__(self, image, imobj=None):
        apply_distortion = random.rand() <= self.distort_prob
        if not apply_distortion:
            return image, imobj
        offset = random.uniform(-self.delta, self.delta)
        image += offset
        return image, imobj
class PhotometricDistort(object):
    """
    Bundle every photometric distortion (brightness, contrast, saturation,
    hue) into one transform that works on a copy of the input image.
    """

    def __init__(self, distort_prob):
        self.distort_prob = distort_prob

        # contrast is duplicated because it may happen before or after
        # the other transforms with equal probability.
        contrast_before = RandomContrast(distort_prob)
        to_hsv = ConvertColor(transform='HSV')
        saturation = RandomSaturation(distort_prob)
        hue = RandomHue(distort_prob)
        to_bgr = ConvertColor(current='HSV', transform='BGR')
        contrast_after = RandomContrast(distort_prob)
        self.transforms = [
            contrast_before, to_hsv, saturation, hue, to_bgr, contrast_after
        ]

        self.rand_brightness = RandomBrightness(distort_prob)

    def __call__(self, image, imobj):
        contrast_first = random.rand() <= 0.5
        if contrast_first:
            # drop the trailing contrast step
            chosen = self.transforms[:-1]
        else:
            # drop the leading contrast step
            chosen = self.transforms[1:]

        # brightness always runs first
        pipeline = Compose([self.rand_brightness] + chosen)
        return pipeline(image.copy(), imobj)
class Augmentation(object):
    """
    Training-time pipeline: float conversion, optional photometric
    distortion, random mirroring, resizing and normalization, packaged as a
    single (image, imobj) transform configured from conf.
    """

    def __init__(self, conf):
        self.mean = conf.image_means
        self.stds = conf.image_stds
        self.size = conf.crop_size
        self.mirror_prob = conf.mirror_prob
        self.distort_prob = conf.distort_prob

        steps = [ConvertToFloat()]
        # photometric distortion is only added for a positive probability
        if not conf.distort_prob <= 0:
            steps.append(PhotometricDistort(self.distort_prob))
        steps.append(RandomMirror(self.mirror_prob))
        steps.append(Resize(self.size))
        steps.append(Normalize(self.mean, self.stds))

        self.augment = Compose(steps)

    def __call__(self, img, imobj):
        return self.augment(img, imobj)
class Preprocess(object):
    """
    Test/eval-time pipeline: float conversion, resizing and normalization
    only (no random augmentation). The result has the channel order of each
    3-channel group reversed (BGR -> RGB) and is laid out channels-first.
    """

    def __init__(self, size, mean, stds):
        self.mean = mean
        self.stds = stds
        self.size = size
        self.preprocess = Compose([
            ConvertToFloat(), Resize(self.size), Normalize(self.mean, self.stds)
        ])

    def __call__(self, img):
        img, _ = self.preprocess(img)

        # reverse the channel order inside every group of 3 channels
        num_frames = int(img.shape[2] / 3)
        for frame in range(num_frames):
            first = frame * 3
            img[:, :, first:first + 3] = img[:, :, (first + 2, first + 1,
                                                    first)]

        # H x W x C -> C x H x W
        return np.transpose(img, [2, 0, 1])
# -*- coding:utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
data reader
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import os.path as osp
import signal
import numpy as np
import random
import logging
import math
import copy
import glob
import time
import re
from PIL import Image
import lib.util as util
from lib.rpn_util import *
import data.augmentations as augmentations
from easydict import EasyDict as edict
import cv2
import pdb
__all__ = ["M3drpnReader"]
logger = logging.getLogger(__name__)
class M3drpnReader(object):
    """
    KITTI data reader for M3D-RPN training.

    On construction the reader indexes every annotated image of the
    datasets listed in ``conf.datasets_train`` (labels, calibration, image
    size), computes per-image sampling weights and builds the augmentation
    pipeline. ``get_reader`` then returns a generator function that yields
    augmented batches.
    """

    def __init__(self, conf, data_dir):
        """
        Args:
            conf: experiment configuration; must support ``in`` and
                attribute access (an EasyDict in this project).
            data_dir (str): folder containing one sub-folder per dataset.
        """
        self.data_dir = data_dir
        self.conf = conf
        self.video_det = conf.video_det if 'video_det' in conf else False
        self.video_count = conf.video_count if 'video_count' in conf else 1
        self.use_3d_for_2d = ('use_3d_for_2d' in conf) and conf.use_3d_for_2d
        self.load_data()
        self.transform = augmentations.Augmentation(conf)

    def _read_data_file(self, fname):
        """Return the lines of ``fname`` with surrounding whitespace removed."""
        assert osp.isfile(fname), \
            "{} is not a file".format(fname)
        with open(fname) as f:
            return [line.strip() for line in f]

    def load_data(self):
        """
        Index all training images and compute their sampling weights.

        Sets ``self.data`` ('train' holds the list of image objects),
        ``self.len`` and ``self.sampled_weights``. Images themselves are not
        kept in memory; only their size is read here.
        """
        logger.info("Loading KITTI dataset from {} ...".format(self.data_dir))

        # image objects of every dataset, in configuration order
        imdb = []

        for dbind, db in enumerate(self.conf.datasets_train):
            logging.info('Loading imgs_label {}'.format(db['name']))

            # single imdb
            imdb_single_db = []

            # kitti formatting
            if db['anno_fmt'].lower() == 'kitti_det':
                train_folder = os.path.join(self.data_dir, db['name'],
                                            'training')
                # e.g. dataset/kitti_split1/training/label_2/
                ann_folder = os.path.join(train_folder, 'label_2', '')
                cal_folder = os.path.join(train_folder, 'calib', '')
                im_folder = os.path.join(train_folder, 'image_2', '')

                # get sorted filepaths
                # NOTE(review): glob() and time() are called as functions, so
                # they presumably come from the `lib.rpn_util` star import
                # (the plain `import glob` / `import time` modules are not
                # callable) - confirm.
                annlist = sorted(glob(ann_folder + '*.txt'))
                imdb_start = time()

                self.affine_size = None if not (
                    'affine_size' in self.conf) else self.conf.affine_size

                for annind, annpath in enumerate(annlist):
                    # get file parts
                    base = os.path.basename(annpath)
                    im_id, _ = os.path.splitext(base)

                    calpath = os.path.join(cal_folder, im_id + '.txt')
                    impath = os.path.join(im_folder, im_id + db['im_ext'])
                    impath_pre = os.path.join(train_folder, 'prev_2',
                                              im_id + '_01' + db['im_ext'])
                    impath_pre2 = os.path.join(train_folder, 'prev_2',
                                               im_id + '_02' + db['im_ext'])
                    impath_pre3 = os.path.join(train_folder, 'prev_2',
                                               im_id + '_03' + db['im_ext'])

                    # read gts
                    p2 = read_kitti_cal(calpath)
                    p2_inv = np.linalg.inv(p2)
                    gts = read_kitti_label(annpath, p2, self.use_3d_for_2d)

                    obj = edict()

                    # store gts
                    obj.id = im_id
                    obj.gts = gts
                    obj.p2 = p2
                    obj.p2_inv = p2_inv

                    # im properties
                    obj.path = impath
                    obj.path_pre = impath_pre
                    obj.path_pre2 = impath_pre2
                    obj.path_pre3 = impath_pre3
                    # only the size is needed; close the file right away so
                    # that indexing thousands of images does not leak handles
                    with Image.open(impath) as im:
                        obj.imW, obj.imH = im.size

                    # database properties
                    obj.dbname = db.name
                    obj.scale = db.scale
                    obj.dbind = dbind
                    obj.affine_gt = None  # did not compute transformer

                    # store
                    imdb_single_db.append(obj)

                    if (annind % 1000) == 0 and annind > 0:
                        time_str, dt = util.compute_eta(imdb_start, annind,
                                                        len(annlist))
                        logging.info('{}/{}, dt: {:0.4f}, eta: {}'.format(
                            annind, len(annlist), dt, time_str))

            # keep the images of every configured dataset, not just the last
            imdb.extend(imdb_single_db)

        self.data = {}
        self.data['train'] = imdb
        self.data['test'] = {}
        self.len = len(imdb)
        self.sampled_weights = balance_samples(self.conf, imdb)

    def _augmented_single(self, index):
        """
        Grabs the item at the given index. Specifically,
          - read the image from disk
          - read the imobj from RAM
          - applies data augmentation to (im, imobj)
          - converts image to RGB and [C H W]

        With ``video_det`` enabled, up to three previous frames are stacked
        on the channel axis before augmentation.

        Returns:
            tuple: (image array laid out as [C, H, W], augmented deep copy
            of the image object).
        """
        entry = self.data['train'][index]

        # read image
        im = cv2.imread(entry.path)

        if self.video_det:
            video_count = 1 if self.video_count is None else self.video_count

            if video_count >= 2:
                im_pre = cv2.imread(entry.path_pre)
                if not im_pre.shape == im.shape:
                    im_pre = cv2.resize(im_pre, (im.shape[1], im.shape[0]))
                im = np.concatenate((im, im_pre), axis=2)

            if video_count >= 3:
                im_pre2 = cv2.imread(entry.path_pre2)
                # fall back to the nearest available earlier frame
                if im_pre2 is None:
                    im_pre2 = im_pre
                if not im_pre2.shape == im.shape:
                    im_pre2 = cv2.resize(im_pre2, (im.shape[1], im.shape[0]))
                im = np.concatenate((im, im_pre2), axis=2)

            if video_count >= 4:
                im_pre3 = cv2.imread(entry.path_pre3)
                if im_pre3 is None:
                    im_pre3 = im_pre2
                if not im_pre3.shape == im.shape:
                    im_pre3 = cv2.resize(im_pre3, (im.shape[1], im.shape[0]))
                im = np.concatenate((im, im_pre3), axis=2)

        # transform / data augmentation (on a copy, the index stays pristine)
        im, imobj = self.transform(im, copy.deepcopy(entry))

        for i in range(int(im.shape[2] / 3)):
            # convert each 3-channel group from BGR to RGB
            im[:, :, (i * 3):(i * 3) + 3] = im[:, :, (i * 3 + 2, i * 3 + 1,
                                                      i * 3)]

        # H x W x C -> C x H x W
        im = np.transpose(im, [2, 0, 1])

        return im, imobj

    def get_reader(self, batch_size, mode='train', shuffle=True):
        """
        Build a batch generator.

        Indices are drawn once, with replacement, according to
        ``self.sampled_weights``; the returned function iterates over that
        fixed draw. A trailing batch smaller than ``batch_size`` is dropped.

        Args:
            batch_size (int): number of samples per yielded batch.
            mode (str): 'train' or 'test'.
            shuffle (bool): shuffle the drawn indices (train mode only).

        Returns:
            callable: generator function yielding lists of
            ``[image, imobj]`` pairs.
        """
        assert mode in ['train', 'test'], \
            "mode can only be 'train' or 'test'"

        idxs = np.random.choice(
            self.len, self.len, replace=True, p=self.sampled_weights)

        if mode == 'train' and shuffle:
            np.random.shuffle(idxs)

        def reader():
            """reader"""
            batch_out = []
            for ind in idxs:
                augmented_img, im_obj = self._augmented_single(ind)
                batch_out.append([augmented_img, im_obj])
                if len(batch_out) == batch_size:
                    yield batch_out
                    batch_out = []

        return reader
# derived from M3D-RPN
# A float written in scientific notation; the exponent marker is mandatory.
_KITTI_CAL_FLOAT = r'[-+]?[\d]+\.?[\d]*[Ee](?:[-+]?[\d]+)?'

# "P2:" followed by the 12 entries of the 3x4 projection matrix.
_KITTI_CAL_P2_PATTERN = re.compile(
    (r'(P2:)' + r'\s+(fpat)' * 12 + r'\s*').replace('fpat', _KITTI_CAL_FLOAT))


def read_kitti_cal(calfile):
    """
    Reads the kitti calibration projection matrix (p2) file from disc.

    Args:
        calfile (str): path to single calibration file

    Returns:
        ndarray: 4x4 float matrix whose first three rows hold the 12 values
        of the "P2:" line (row-major) and whose last row is [0, 0, 0, 1].
        If several "P2:" lines are present, the last one wins.

    Raises:
        ValueError: if the file contains no parsable "P2:" line.
    """
    p2 = None

    with open(calfile, 'r') as text_file:
        for line in text_file:
            parsed = _KITTI_CAL_P2_PATTERN.fullmatch(line)
            if parsed is None:
                continue

            # groups 2..13 are the matrix entries in row-major order
            values = [float(parsed.group(idx)) for idx in range(2, 14)]
            p2 = np.zeros([4, 4], dtype=float)
            p2[:3, :] = np.array(values).reshape(3, 4)
            p2[3, 3] = 1

    if p2 is None:
        raise ValueError(
            'no P2 projection matrix found in {}'.format(calfile))

    return p2
def balance_samples(conf, imdb):
    """
    Compute a sampling probability for every image of the dataset.

    When conf.fg_image_ratio >= 0, each image is checked for ground truths
    that are neither ignored nor removed (see determine_ignores). Images
    with at least one such ground truth get weight 1 and the others weight
    0; when fg_image_ratio == 2 the weight is instead the number of valid
    ground truths. For a negative fg_image_ratio all images weigh the same.

    Returns:
        ndarray: one weight per image, normalized to sum to 1.
    """
    sample_weights = np.ones(len(imdb))

    if conf.fg_image_ratio >= 0:
        valid_inds = []
        empty_inds = []

        for imind, imobj in enumerate(imdb):
            scale = conf.test_scale / imobj.imH
            igns, rmvs = determine_ignores(imobj.gts, conf.lbls, conf.ilbls,
                                           conf.min_gt_vis, conf.min_gt_h,
                                           conf.max_gt_h, scale)

            # count the ground truths that are kept for training
            valid = 0
            for gtind in range(len(imobj.gts)):
                if (not igns[gtind]) and (not rmvs[gtind]):
                    valid += 1

            sample_weights[imind] = valid

            bucket = valid_inds if valid > 0 else empty_inds
            bucket.append(imind)

        if not (conf.fg_image_ratio == 2):
            fg_weight = 1
            bg_weight = 0
            sample_weights[valid_inds] = fg_weight
            sample_weights[empty_inds] = bg_weight
            logging.info('weighted respectively as {:.2f} and {:.2f}'.format(
                fg_weight, bg_weight))

        num_fg = np.sum(sample_weights > 0)
        num_empty = np.sum(sample_weights <= 0)
        logging.info('Found {} foreground and {} empty images'.format(
            num_fg, num_empty))

    # force sampling weights to sum to 1
    sample_weights /= np.sum(sample_weights)

    return sample_weights
# One pose per line: 12 whitespace-separated floats (a row-major 3x4 matrix).
_KITTI_POSE_LINE = r'(fpat)' + r'\s+(fpat)' * 11 + r'\s*'

# Floats in scientific notation (exponent marker mandatory) ...
_KITTI_POSE_SCI_PATTERN = re.compile(
    _KITTI_POSE_LINE.replace('fpat',
                             r'[-+]?[\d]+\.?[\d]*[Ee](?:[-+]?[\d]+)?'))

# ... or in plain decimal notation.
_KITTI_POSE_PLAIN_PATTERN = re.compile(
    _KITTI_POSE_LINE.replace('fpat', r'[-+]?[\d]+\.?[\d]*'))


def read_kitti_poses(posefile):
    """
    Read a pose file containing one 3x4 matrix (12 floats) per line.

    A line is accepted when all 12 numbers use scientific notation or when
    all 12 use plain decimal notation; any other line is skipped.

    Args:
        posefile (str): path to the pose file

    Returns:
        list of ndarray: one 4x4 float matrix per accepted line, in file
        order. The first three rows hold the 12 values (row-major) and the
        last row is [0, 0, 0, 1].
    """
    ps = []

    with open(posefile, 'r') as text_file:
        for line in text_file:
            # same precedence as before: scientific notation first
            parsed = (_KITTI_POSE_SCI_PATTERN.fullmatch(line) or
                      _KITTI_POSE_PLAIN_PATTERN.fullmatch(line))
            if parsed is None:
                continue

            values = [float(parsed.group(idx)) for idx in range(1, 13)]
            p = np.zeros([4, 4], dtype=float)
            p[:3, :] = np.array(values).reshape(3, 4)
            p[3, 3] = 1
            ps.append(p)

    return ps
# One KITTI label line: the object type, 14 numeric fields and an optional
# trailing 16th field. Layout of the fields:
#
# Values  Name         Description
# --------------------------------------------------------------------------
#   1     type         Describes the type of object: 'Car', 'Van', 'Truck',
#                      'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram',
#                      'Misc' or 'DontCare'
#   1     truncated    Float from 0 (non-truncated) to 1 (truncated), where
#                      truncated refers to the object leaving image boundaries
#   1     occluded     Integer (0,1,2,3) indicating occlusion state:
#                      0 = fully visible, 1 = partly occluded
#                      2 = largely occluded, 3 = unknown
#   1     alpha        Observation angle of object, ranging [-pi..pi]
#   4     bbox         2D bounding box of object in the image (0-based index):
#                      contains left, top, right, bottom pixel coordinates
#   3     dimensions   3D object dimensions: height, width, length (in meters)
#   3     location     3D object location x,y,z in camera coordinates (meters)
#   1     rotation_y   Rotation ry around Y-axis in camera coordinates
#                      [-pi..pi]
#   1     score        Only for results: Float, indicating confidence in
#                      detection, needed for p/r curves, higher is better.
_KITTI_LABEL_PATTERN = re.compile(
    (r'([a-zA-Z\-\?\_]+)' + r'\s+(fpat)' * 14 + r'\s*((fpat)?)\n').replace(
        'fpat', r'[-+]?\d*\.\d+|[-+]?\d+'))


def read_kitti_label(file, p2, use_3d_for_2d=False):
    """
    Reads the kitti label file from disc.

    Lines that do not match the KITTI label layout are skipped.

    Args:
        file (str): path to single label file for an image
        p2 (ndarray): projection matrix for the given image
        use_3d_for_2d (bool): when True, the 2D box of every object with
            positive 3D dimensions is recomputed from its projected 3D box;
            objects with a corner at or behind the camera plane are flagged
            as ignored instead.

    Returns:
        list: one EasyDict per parsed object with the keys elevation, cls,
        occ, ign, visibility, trunc, alpha, rotY, bbox_full ([x, y, w, h]),
        bbox_3d, center_3d and, when the extra field is an integer, track.
    """
    gts = []

    with open(file, 'r') as text_file:
        for line in text_file:
            parsed = _KITTI_LABEL_PATTERN.fullmatch(line)
            if parsed is None:
                continue

            obj = edict()
            ign = False

            cls = parsed.group(1)
            trunc = float(parsed.group(2))
            occ = float(parsed.group(3))

            # group 4 (alpha from the file) is not used: alpha is
            # recomputed from rotY further down

            x = float(parsed.group(5))
            y = float(parsed.group(6))
            x2 = float(parsed.group(7))
            y2 = float(parsed.group(8))
            width = x2 - x + 1
            height = y2 - y + 1

            h3d = float(parsed.group(9))
            w3d = float(parsed.group(10))
            l3d = float(parsed.group(11))
            cx3d = float(parsed.group(12))  # center of car in 3d
            cy3d = float(parsed.group(13))  # bottom of car in 3d
            cz3d = float(parsed.group(14))  # center of car in 3d
            rotY = float(parsed.group(15))

            # actually center the box
            cy3d -= (h3d / 2)
            elevation = (1.65 - cy3d)

            if use_3d_for_2d and h3d > 0 and w3d > 0 and l3d > 0:
                # re-compute the 2D box using 3D (finally, avoids clipped
                # boxes)
                verts3d, corners_3d = project_3d(
                    p2, cx3d, cy3d, cz3d, w3d, h3d, l3d, rotY, return_3d=True)

                # any boxes behind camera plane?
                if np.any(corners_3d[2, :] <= 0):
                    ign = True
                else:
                    x = min(verts3d[:, 0])
                    y = min(verts3d[:, 1])
                    x2 = max(verts3d[:, 0])
                    y2 = max(verts3d[:, 1])
                    width = x2 - x + 1
                    height = y2 - y + 1

            # project cx, cy, cz
            coord3d = p2.dot(np.array([cx3d, cy3d, cz3d, 1]))

            # store the projected instead
            cx3d_2d = coord3d[0]
            cy3d_2d = coord3d[1]
            cz3d_2d = coord3d[2]
            cx = cx3d_2d / cz3d_2d
            cy = cy3d_2d / cz3d_2d

            # encode occlusion with range estimation
            # 0 = fully visible, 1 = partly occluded
            # 2 = largely occluded, 3 = unknown
            if occ == 0:
                vis = 1
            elif occ == 1:
                vis = 0.66
            elif occ == 2:
                vis = 0.33
            else:
                vis = 0.0

            # wrap the yaw into [-pi, pi]
            while rotY > math.pi:
                rotY -= math.pi * 2
            while rotY < (-math.pi):
                rotY += math.pi * 2

            # recompute alpha
            alpha = util.convertRot2Alpha(rotY, cz3d, cx3d)

            obj.elevation = elevation
            obj.cls = cls
            obj.occ = occ > 0
            obj.ign = ign
            obj.visibility = vis
            obj.trunc = trunc
            obj.alpha = alpha
            obj.rotY = rotY

            # is there an extra field? (assume to be track)
            if len(parsed.groups()) >= 16 and parsed.group(16).isdigit():
                obj.track = int(parsed.group(16))

            obj.bbox_full = np.array([x, y, width, height])
            obj.bbox_3d = [
                cx, cy, cz3d_2d, w3d, h3d, l3d, alpha, cx3d, cy3d, cz3d, rotY
            ]
            obj.center_3d = [cx3d, cy3d, cz3d]

            gts.append(obj)

    return gts
def _term_reader(signum, frame):
    """
    Signal handler: log the event, then send SIGKILL to the process group
    this process belongs to (which includes this process itself).
    """
    pid = os.getpid()
    logger.info('pid {} terminated, terminate reader process '
                'group {}...'.format(pid, os.getpgrp()))
    group_id = os.getpgid(pid)
    os.killpg(group_id, signal.SIGKILL)


# installed at import time for both interrupt and terminate requests
signal.signal(signal.SIGINT, _term_reader)
signal.signal(signal.SIGTERM, _term_reader)
"""
This code is based on https://github.com/garrickbrazil/M3D-RPN/blob/master/lib/core.py
This file is meant to contain all functions of the detective framework
which are "specific" to the framework but generic among experiments.
For example, all the experiments need to initialize configs, training models,
log stats, display stats, and etc. However, these functions are generally fixed
to this framework and cannot be easily transferred in other projects.
"""
# -----------------------------------------
# python modules
# -----------------------------------------
from easydict import EasyDict as edict
from shapely.geometry import Polygon
#import matplotlib.pyplot as plt
from copy import copy
import importlib
import random
#import visdom
#import torch
import paddle.fluid as fluid
import paddle
import shutil
import sys
import os
import cv2
import math
import numpy as np
import struct
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.proto.framework_pb2 import VarType
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
# stop python from writing so much bytecode
sys.dont_write_bytecode = True
# -----------------------------------------
# custom modules
# -----------------------------------------
from lib.util import *
def init_config(conf_name):
    """
    Import the module ``config.<conf_name>`` and return whatever its
    ``Config`` callable produces.

    Args:
        conf_name (str): name of the configuration module (without ``.py``)
            inside the ``config`` package.

    Returns:
        the object returned by ``Config()`` of that module.
    """
    module_path = 'config.' + conf_name
    config_module = importlib.import_module(module_path)
    return config_module.Config()
import paddle.fluid as fluid
class MyPolynomialDecay(fluid.dygraph.PolynomialDecay):
    """
    Polynomial learning-rate schedule with a custom ``step`` rule.

    NOTE(review): ``step_num``, ``decay_steps``, ``learning_rate``,
    ``end_learning_rate``, ``power`` and ``create_lr_var`` are inherited from
    the Paddle base class; their exact semantics are not visible here.
    """

    def step(self):
        """
        Return the learning rate for the current step counter.

        The counter is capped at ``decay_steps`` and the rate is
        ``learning_rate * (1 - step / scale) ** power``, where ``scale`` is
        chosen so that the rate equals ``end_learning_rate`` once the counter
        reaches ``decay_steps``.
        """
        horizon = self.decay_steps

        # never advance past the end of the schedule
        current = self.step_num
        if not current < self.decay_steps:
            current = self.decay_steps
        current = self.create_lr_var(current)

        lr_ratio = float(self.end_learning_rate / self.learning_rate)
        scale = float(horizon) / (1 - lr_ratio**(1 / self.power))
        remaining = 1 - float(current) / scale
        return self.learning_rate * (remaining**self.power)
def adjust_lr(conf):
    """
    Build the learning rate (or learning-rate schedule) described by ``conf``.

    For the 'sgd' solver with the 'poly' policy this returns a
    ``MyPolynomialDecay`` going from ``conf.lr`` to ``conf.lr_target`` over
    ``conf.max_iter`` steps with power 0.9. For every other solver the
    constant ``conf.lr`` is returned (previously the result was never
    assigned on that path).

    Args:
        conf: configuration object exposing ``solver_type`` and ``lr``; the
            'sgd' solver additionally needs ``lr_policy``, ``max_iter`` and
            ``lr_target``.

    Returns:
        a ``MyPolynomialDecay`` instance for sgd/poly, otherwise ``conf.lr``.

    Raises:
        ValueError: if the solver is 'sgd' and ``conf.lr_policy`` is not 'poly'.
    """
    if conf.solver_type.lower() != 'sgd':
        # only the sgd solver uses a schedule; others keep a constant rate
        return conf.lr

    lr_policy = conf.lr_policy
    if lr_policy.lower() != 'poly':
        raise ValueError('{} lr_policy not understood'.format(lr_policy))

    # polynomial decay from conf.lr down to conf.lr_target
    return MyPolynomialDecay(conf.lr, conf.max_iter, conf.lr_target, power=0.9)
def init_training_model(conf, backbone, cache_folder):
    """
    Build the training network and its optimizer.

    The model source ./models/<conf.model>.py is copied into ``cache_folder``
    BEFORE it is loaded, and the network is built from that cached copy, for
    easy reproducibility.

    Args:
        conf: configuration object; this function reads ``model``,
            ``solver_type`` and ``weight_decay``, plus ``momentum`` (sgd) or
            ``lr`` (adam / adamax). It is also handed to the model's
            ``build`` and, for sgd, to ``adjust_lr``.
        backbone: passed through unchanged to the model's ``build``.
        cache_folder (str): directory receiving the copy of the model file.

    Returns:
        tuple: ``(network, optimizer)``.

    Raises:
        ValueError: if ``conf.solver_type`` is not 'sgd', 'adam' or 'adamax'
            (previously this surfaced as an unbound ``optimizer`` variable).
    """
    src_path = os.path.join('.', 'models', conf.model + '.py')
    dst_path = os.path.join(cache_folder, conf.model + '.py')

    # (re-) copy the model file
    if os.path.exists(dst_path):
        os.remove(dst_path)
    shutil.copyfile(src_path, dst_path)

    # load and build from the cached copy
    network = absolute_import(dst_path)
    network = network.build(conf, backbone, 'train')

    solver = conf.solver_type.lower()

    # SGD with momentum and the scheduled learning rate
    if solver == 'sgd':
        optimizer = fluid.optimizer.MomentumOptimizer(
            learning_rate=adjust_lr(conf),
            momentum=conf.momentum,
            regularization=fluid.regularizer.L2Decay(conf.weight_decay),
            parameter_list=network.parameters())

    # adam with a constant learning rate
    elif solver == 'adam':
        optimizer = fluid.optimizer.Adam(
            learning_rate=conf.lr,
            regularization=fluid.regularizer.L2Decay(conf.weight_decay),
            parameter_list=network.parameters())

    # adamax with a constant learning rate
    elif solver == 'adamax':
        optimizer = fluid.optimizer.Adamax(
            learning_rate=conf.lr,
            regularization=fluid.regularizer.L2Decay(conf.weight_decay),
            parameter_list=network.parameters())

    else:
        raise ValueError('{} solver_type not understood'.format(
            conf.solver_type))

    return network, optimizer
def intersect(box_a, box_b, mode='combinations', data_type=None):
    """
    Computes the intersection area between two sets of boxes.

    Args:
        box_a: Mx4 boxes, defined by [x1, y1, x2, y2]
        box_b: Nx4 boxes, defined by [x1, y1, x2, y2]
        mode (str): 'combinations' pairs every box of box_b with every box of
            box_a (np.ndarray only) and returns an N x M array indexed as
            [box_b index, box_a index]. 'list' pairs the boxes row by row
            (same number of rows expected) and returns one area per row.
        data_type (type): container type of the boxes; taken from
            type(box_a) when not given.

    Raises:
        ValueError: for an unsupported data type or an unknown mode.
    """
    if data_type is None:
        data_type = type(box_a)

    # every box_b against every box_a
    if mode == 'combinations':
        if data_type != np.ndarray:
            raise ValueError('type {} is not implemented'.format(data_type))

        # the extra axis on box_b makes the broadcast produce (N, M, 2)
        b_hi = np.expand_dims(box_b[:, 2:4], axis=1)
        b_lo = np.expand_dims(box_b[:, 0:2], axis=1)
        upper = np.minimum(box_a[:, 2:4], b_hi)
        lower = np.maximum(box_a[:, 0:2], b_lo)

        # disjoint boxes give negative extents, which count as zero overlap
        extent = np.clip(upper - lower, a_min=0, a_max=None)
        return extent[:, :, 0] * extent[:, :, 1]

    # row i of box_a against row i of box_b
    if mode == 'list':
        if data_type == fluid.core_avx.VarBase:
            upper = fluid.layers.elementwise_min(box_a[:, 2:], box_b[:, 2:])
            lower = fluid.layers.elementwise_max(box_a[:, :2], box_b[:, :2])
            extent = fluid.layers.clamp((upper - lower), 0)
        elif data_type == np.ndarray:
            upper = np.minimum(box_a[:, 2:], box_b[:, 2:])
            lower = np.maximum(box_a[:, :2], box_b[:, :2])
            extent = np.clip(upper - lower, a_min=0, a_max=None)
        else:
            raise ValueError('unknown data type {}'.format(data_type))
        return extent[:, 0] * extent[:, 1]

    raise ValueError('unknown mode {}'.format(mode))
def iou(box_a, box_b, mode='combinations', data_type=None):
    """
    Computes the Intersection over Union (IoU) between two sets of boxes.

    Args:
        box_a: Mx4 boxes, defined by [x1, y1, x2, y2]
        box_b: Nx4 boxes, defined by [x1, y1, x2, y2]
        mode (str): 'combinations' compares every box of box_a with every box
            of box_b and returns an M x N array (np.ndarray only). 'list'
            compares the boxes row by row (same number of rows expected) and
            returns one value per row.
        data_type (type): container type of the boxes; taken from
            type(box_a) when not given. It is forwarded to ``intersect`` in
            both modes (the 'list' mode used to drop it).

    Raises:
        ValueError: for an unsupported data type or an unknown mode.
    """
    # determine type
    if data_type is None:
        data_type = type(box_a)

    # this mode computes the IoU in the sense of combinations.
    # i.e., box_a = M x 4, box_b = N x 4 then the output is M x N
    if mode == 'combinations':
        inter = intersect(box_a, box_b, data_type=data_type)
        area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])
        area_b = (box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])

        # inter is indexed [box_b, box_a], so the areas are broadcast to N x M
        union = np.expand_dims(area_a, 0) + np.expand_dims(area_b, 1) - inter

        if data_type == np.ndarray:
            # transpose back to the documented M x N orientation
            return (inter / union).T

        raise ValueError('unknown data type {}'.format(data_type))

    # this mode compares every box in box_a with its target in box_b
    # i.e., box_a = M x 4 and box_b = M x 4 then output has M entries
    if mode == 'list':
        inter = intersect(box_a, box_b, mode=mode, data_type=data_type)
        area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])
        area_b = (box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])
        union = area_a + area_b - inter
        return inter / union

    raise ValueError('unknown mode {}'.format(mode))
def iou_ign(box_a, box_b, mode='combinations', data_type=None):
    """
    Computes how much of each box in box_a lies inside each box of box_b,
    i.e. intersection divided by the area of the box_a entry only. This is
    handy for ignore regions: with box_b as ignore regions and box_a as
    anchors, it tells how much of every anchor falls inside an ignore region
    (the area of box_b is deliberately left out of the denominator).

    Args:
        box_a: Mx4 boxes, defined by [x1, y1, x2, y2]
        box_b: Nx4 boxes, defined by [x1, y1, x2, y2]
        mode (str): only 'combinations' is supported; the result is M x N.
        data_type (type): container type of the boxes; taken from
            type(box_a) when not given. Only np.ndarray is supported.

    Raises:
        ValueError: for an unsupported data type or any other mode.
    """
    if data_type is None:
        data_type = type(box_a)

    if mode != 'combinations':
        raise ValueError('unknown mode {}'.format(mode))

    overlap = intersect(box_a, box_b, data_type=data_type)

    size_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])
    size_b = (box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])

    # the zero-weighted terms add no area; they stretch the denominator to
    # the pairwise shape of overlap
    denom = np.expand_dims(size_a, 0) + np.expand_dims(size_b,
                                                       1) * 0 - overlap * 0

    if data_type != np.ndarray:
        raise ValueError('unknown data type {}'.format(data_type))

    # overlap is indexed [box_b, box_a]; transpose to M x N
    return (overlap / denom).T
def to_int(string, dest="I"):
    """
    Unpack ``string`` with the struct format ``dest`` and return the first
    unpacked field.

    Args:
        string (bytes): packed data whose size matches ``dest``.
        dest (str): ``struct`` format string, "I" by default.
    """
    fields = struct.unpack(dest, string)
    return fields[0]
def parse_shape_from_file(filename):
    """
    Read the tensor dims stored in a serialized variable file.

    Layout as consumed here: 4 skipped bytes, an 8-byte unsigned ("Q") count
    of LoD levels, then per level an 8-byte unsigned size followed by that
    many skipped bytes, 4 more skipped bytes, a 4-byte unsigned ("I")
    descriptor size and finally a serialized ``VarType.TensorDesc``.

    NOTE(review): both skipped 4-byte fields were named "version" in the
    previous implementation - presumably format version stamps; confirm
    against the Paddle serialization format.

    Args:
        filename (str): path of the variable file.

    Returns:
        tuple: the ``dims`` of the parsed tensor descriptor.
    """
    with open(filename, "rb") as stream:
        stream.read(4)  # first "version" field, not used

        # skip the LoD section level by level
        lod_levels = struct.unpack("Q", stream.read(8))[0]
        for _ in range(lod_levels):
            level_size = struct.unpack("Q", stream.read(8))[0]
            stream.read(level_size)

        stream.read(4)  # second "version" field, not used

        desc_size = struct.unpack("I", stream.read(4))[0]
        desc = VarType.TensorDesc()
        desc.ParseFromString(stream.read(desc_size))
        return tuple(desc.dims)
def load_vars(train_prog, path):
    """
    Splits the parameters of ``train_prog`` by whether a same-named file
    under ``path`` stores a tensor of the same shape as the parameter
    currently held in the global scope. Nothing is loaded here.

    Args:
        train_prog: program whose ``list_vars()`` is scanned; only
            ``fluid.framework.Parameter`` entries are considered.
        path (str): directory holding one file per variable name.

    Returns:
        tuple: ``(loadable, not_loadable)`` lists of parameters.
    """
    loadable = []
    not_loadable = []

    for var in train_prog.list_vars():
        if not isinstance(var, fluid.framework.Parameter):
            continue

        # shape of the parameter as it currently lives in the scope
        scope_tensor = fluid.global_scope().find_var(var.name).get_tensor()
        wanted_shape = tuple(scope_tensor.shape())

        # a missing file counts as a mismatch
        var_file = os.path.join(path, var.name)
        if os.path.exists(var_file) and \
                parse_shape_from_file(var_file) == wanted_shape:
            loadable.append(var)
        else:
            not_loadable.append(var)

    return loadable, not_loadable
def log_stats(tracker, iteration, start_time, start_iter, max_iter, skip=1):
    """
    This function writes the given stats to the log.
    Also, computes the estimated time arrival (eta) for completion and (dt) delta time per iteration.

    The logged line roughly has this format:
        iter: {}, group_1 (name: val, name: val), group_2 (name: val), dt: val, eta: val
    A closing parenthesis is only written when a group was actually opened,
    so an empty tracker yields "iter: {}, dt: val, eta: val".

    Args:
        tracker (array): dictionary array tracker objects. See below.
        iteration (int): the current iteration
        start_time (float): starting time of whole experiment
        start_iter (int): starting iteration of whole experiment
        max_iter (int): maximum iteration to go to
        skip (int): the displayed iteration is int(iteration / skip)

    For every key whose value is a list, the mean of the list is reported and
    tracker[key + '_obj'] must provide:
        "name": the name of the statistic being tracked, e.g., 'fg_acc', 'abs_z'
        "group": an arbitrary group key, e.g., 'loss', 'acc', 'misc'
        "format": the python string format to use (see official str format function in python), e.g., '{:.2f}' for
                  a float with 2 decimal places.
    """
    display_str = 'iter: {}'.format(int(iteration / skip))

    # compute eta
    time_str, dt = compute_eta(start_time, iteration - start_iter,
                               max_iter - start_iter)

    # cycle through all tracks
    last_group = ''
    paren_open = False

    for key in sorted(tracker.keys()):
        # the '<key>_obj' property entries are not lists and are skipped
        if not isinstance(tracker[key], list):
            continue

        # compute mean
        meanval = np.mean(tracker[key])

        # get properties
        props = tracker[key + '_obj']
        fmt = props.format
        group = props.group
        name = props.name

        if last_group != group:
            # start a new group, closing the previous one only if it exists
            lead = '), ' if paren_open else ', '
            display_str += (lead + '{} ({}: ' + fmt).format(group, name,
                                                            meanval)
            paren_open = True
        else:
            display_str += (', {}: ' + fmt).format(name, meanval)

        last_group = group

    # close the last group, if any, then append dt and eta
    if paren_open:
        display_str += ')'
    display_str += ', dt: {:0.2f}, eta: {}'.format(dt, time_str)

    # log
    logging.info(display_str)
def display_stats(vis,
                  tracker,
                  iteration,
                  start_time,
                  start_iter,
                  max_iter,
                  conf_name,
                  conf_pretty,
                  skip=1):
    """
    This function plots the statistics using visdom package, similar to the log_stats function.
    Also, computes the estimated time arrival (eta) for completion and (dt) delta time per iteration.

    Args:
        vis (visdom): the main visdom session object
        tracker (array): dictionary array tracker objects. See below.
        iteration (int): the current iteration
        start_time (float): starting time of whole experiment
        start_iter (int): starting iteration of whole experiment
        max_iter (int): maximum iteration to go to
        conf_name (str): experiment name used for visdom display
        conf_pretty (str): pretty string with ALL configuration params to display
        skip (int): accepted for symmetry with log_stats; not used here

    For every key whose value is a list, the mean of the list is plotted and
    tracker[key + '_obj'] must provide:
        "name": the name of the statistic being tracked, e.g., 'fg_acc', 'abs_z'
        "group": an arbitrary group key, e.g., 'loss', 'acc', 'misc'
    """
    # compute eta
    time_str, dt = compute_eta(start_time, iteration - start_iter,
                               max_iter - start_iter)

    # general info
    info = 'Experiment: <b>{}</b>, Eta: <b>{}</b>, Time/it: {:0.2f}s\n'.format(
        conf_name, time_str, dt)
    info += conf_pretty

    # replace all newlines and spaces with line break <br> and
    # non-breaking spaces &nbsp; (newlines first: '<br>' has no spaces)
    info = info.replace('\n', '<br>')
    info = info.replace(' ', '&nbsp;')

    # pre-formatted html tag, properly closed
    info = '<pre>' + info + '</pre>'

    # update the info window
    vis.text(
        info, win='info', opts={'title': 'info',
                                'width': 500,
                                'height': 350})

    # draw graphs for each track; one window per group, one line per name
    for key in sorted(tracker.keys()):
        # the '<key>_obj' property entries are not lists and are skipped
        if not isinstance(tracker[key], list):
            continue

        meanval = np.mean(tracker[key])
        group = tracker[key + '_obj'].group
        name = tracker[key + '_obj'].name

        # new data point
        vis.line(
            X=np.array([(iteration + 1)]),
            Y=np.array([meanval]),
            win=group,
            name=name,
            update='append',
            opts={
                'showlegend': True,
                'title': group,
                'width': 500,
                'height': 350,
                'xlabel': 'iteration'
            })
def compute_stats(tracker, stats):
    """
    Appends the value of every entry in 'stats' to the matching series in
    'tracker'. The series key is (group + name). The first time a series is
    seen, its stat dictionary (with 'val' removed) is also stored under
    (group + name + '_obj') so its properties can be retrieved later.

    Note that the stat dictionary itself is modified ('val' is popped) when
    it is stored as the '_obj' entry.

    Args:
        tracker (array): dictionary array tracker objects, updated in place.
        stats (array): list of stat dictionaries. See below.

    A stat dictionary provides at least:
        "name": the name of the statistic being tracked, e.g., 'fg_acc', 'abs_z'
        "group": an arbitrary group key, e.g., 'loss', 'acc', 'misc'
        "val": the value to record
    """
    for entry in stats:
        stat_name = entry['name']
        stat_group = entry['group']
        value = entry['val']

        # naming convention for the series and its property record
        series_key = stat_group + stat_name
        props_key = series_key + '_obj'

        # start a new series on first sight, then record the value
        if series_key not in tracker:
            tracker[series_key] = []
        tracker[series_key].append(value)

        # property record already present, nothing more to store
        if props_key in tracker:
            continue

        entry.pop('val', None)
        tracker[props_key] = entry
def init_training_paths(conf_name, use_tmp_folder=None):
    """
    Builds and creates the folders used by an experiment, all relative to
    the current working directory (cwd). For this reason, the experiments
    are expected to be run from the root folder.

    base    = cwd
    data    = ./dataset                          (not created here)
    output  = ./output/<conf_name>
    weights = ./output/<conf_name>/weights
    logs    = ./output/<conf_name>/log
    results = ./output/<conf_name>/results, or
              ./.tmp_results/<conf_name>/results when use_tmp_folder is truthy

    Args:
        conf_name (str): configuration experiment name (used for storage into ./output/<conf_name>)
        use_tmp_folder: when truthy, results go below ./.tmp_results instead

    Returns:
        edict with the keys base, data, output, weights, logs and results.
    """
    root = os.getcwd()
    experiment_dir = os.path.join(os.getcwd(), 'output', conf_name)

    paths = edict()
    paths.base = root
    paths.data = os.path.join(root, 'dataset')
    paths.output = experiment_dir
    paths.weights = os.path.join(experiment_dir, 'weights')
    paths.logs = os.path.join(experiment_dir, 'log')

    if not use_tmp_folder:
        paths.results = os.path.join(experiment_dir, 'results')
    else:
        paths.results = os.path.join(root, '.tmp_results', conf_name,
                                     'results')

    # make directories
    for folder in (paths.output, paths.logs, paths.weights, paths.results):
        mkdir_if_missing(folder)

    return paths
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""loss"""
import sys
from functools import reduce
sys.dont_write_bytecode = True
# -----------------------------------------
# custom modules
# -----------------------------------------
import paddle.fluid as fluid
import paddle
from paddle.fluid.dygraph import to_variable
sys.path.append("../../")
from lib.rpn_util import *
class RPN_3D_loss(fluid.dygraph.Layer):
def __init__(self, conf):
    """
    Copy the loss settings from ``conf`` onto the layer.

    ``num_classes`` is the number of configured labels plus one and
    ``num_anchors`` is the first dimension of ``conf.anchors``; every other
    attribute is taken over from ``conf`` unchanged under the same name.
    """
    super(RPN_3D_loss, self).__init__()

    # derived sizes
    # NOTE(review): the extra class slot is presumably background - confirm
    self.num_classes = len(conf.lbls) + 1
    self.num_anchors = conf.anchors.shape[0]

    # settings mirrored one-to-one from the configuration
    mirrored = (
        'anchors',
        'bbox_means',
        'bbox_stds',
        'feat_stride',
        'fg_fraction',
        'box_samples',
        'ign_thresh',
        'nms_thres',
        'fg_thresh',
        'bg_thresh_lo',
        'bg_thresh_hi',
        'best_thresh',
        'hard_negatives',
        'focal_loss',
        'crop_size',
        'cls_2d_lambda',
        'iou_2d_lambda',
        'bbox_2d_lambda',
        'bbox_3d_lambda',
        'bbox_3d_proj_lambda',
        'lbls',
        'ilbls',
        'min_gt_vis',
        'min_gt_h',
        'max_gt_h', )
    for field in mirrored:
        setattr(self, field, getattr(conf, field))
def forward(self, cls, prob, bbox_2d, bbox_3d, imobjs, feat_size):
stats = []
loss = np.array([0]).astype('float32')
loss = to_variable(loss)
FG_ENC = 1000
BG_ENC = 2000
IGN_FLAG = 3000
batch_size = cls.shape[0]
prob_detach = prob.detach().numpy()
bbox_x = bbox_2d[:, :, 0]
bbox_y = bbox_2d[:, :, 1]
bbox_w = bbox_2d[:, :, 2]
bbox_h = bbox_2d[:, :, 3]
bbox_x3d = bbox_3d[:, :, 0]
bbox_y3d = bbox_3d[:, :, 1]
bbox_z3d = bbox_3d[:, :, 2]
bbox_w3d = bbox_3d[:, :, 3]
bbox_h3d = bbox_3d[:, :, 4]
bbox_l3d = bbox_3d[:, :, 5]
bbox_ry3d = bbox_3d[:, :, 6]
bbox_x3d_proj = np.zeros(bbox_x3d.shape)
bbox_y3d_proj = np.zeros(bbox_x3d.shape)
bbox_z3d_proj = np.zeros(bbox_x3d.shape)
labels = np.zeros(cls.shape[0:2])
labels_weight = np.zeros(cls.shape[0:2])
labels_scores = np.zeros(cls.shape[0:2])
bbox_x_tar = np.zeros(cls.shape[0:2])
bbox_y_tar = np.zeros(cls.shape[0:2])
bbox_w_tar = np.zeros(cls.shape[0:2])
bbox_h_tar = np.zeros(cls.shape[0:2])
bbox_x3d_tar = np.zeros(cls.shape[0:2])
bbox_y3d_tar = np.zeros(cls.shape[0:2])
bbox_z3d_tar = np.zeros(cls.shape[0:2])
bbox_w3d_tar = np.zeros(cls.shape[0:2])
bbox_h3d_tar = np.zeros(cls.shape[0:2])
bbox_l3d_tar = np.zeros(cls.shape[0:2])
bbox_ry3d_tar = np.zeros(cls.shape[0:2])
bbox_x3d_proj_tar = np.zeros(cls.shape[0:2])
bbox_y3d_proj_tar = np.zeros(cls.shape[0:2])
bbox_z3d_proj_tar = np.zeros(cls.shape[0:2])
bbox_weights = np.zeros(cls.shape[0:2])
ious_2d = np.zeros(cls.shape[0:2])
ious_3d = np.zeros(cls.shape[0:2])
coords_abs_z = np.zeros(cls.shape[0:2])
coords_abs_ry = np.zeros(cls.shape[0:2])
# get all rois
# rois' type now is nparray
rois = locate_anchors(self.anchors, feat_size, self.feat_stride)
rois = rois.astype('float32')
#bbox_3d dtype is Variable, so bbox_3d_dn is
bbox_x3d_dn = bbox_x3d * self.bbox_stds[:, 4][0] + self.bbox_means[:,
4][0]
bbox_y3d_dn = bbox_y3d * self.bbox_stds[:, 5][0] + self.bbox_means[:,
5][0]
bbox_z3d_dn = bbox_z3d * self.bbox_stds[:, 6][0] + self.bbox_means[:,
6][0]
bbox_w3d_dn = bbox_w3d * self.bbox_stds[:, 7][0] + self.bbox_means[:,
7][0]
bbox_h3d_dn = bbox_h3d * self.bbox_stds[:, 8][0] + self.bbox_means[:,
8][0]
bbox_l3d_dn = bbox_l3d * self.bbox_stds[:, 9][0] + self.bbox_means[:,
9][0]
bbox_ry3d_dn = bbox_ry3d * self.bbox_stds[:, 10][
0] + self.bbox_means[:, 10][0]
src_anchors = self.anchors[rois[:, 4].astype('int64'), :] #nparray
src_anchors = src_anchors.astype('float32')
src_anchors = to_variable(src_anchors) #Variable
src_anchors.stop_gradient = True
if len(src_anchors.shape) == 1:
src_anchors = fluid.layers.unsqueeze(input=src_anchors, axis=0)
# compute 3d transform
#the following four all are nparrays
widths = rois[:, 2] - rois[:, 0] + 1.0
heights = rois[:, 3] - rois[:, 1] + 1.0
ctr_x = rois[:, 0] + 0.5 * widths
ctr_y = rois[:, 1] + 0.5 * heights
ctr_x_unsqueeze = fluid.layers.unsqueeze(
input=to_variable(ctr_x), axes=0)
ctr_y_unsqueeze = fluid.layers.unsqueeze(
input=to_variable(ctr_y), axes=0)
widths_unsqueeze = fluid.layers.unsqueeze(
input=to_variable(widths), axes=0)
heights_unsqueeze = fluid.layers.unsqueeze(
input=to_variable(heights), axes=0)
bbox_z3d_unsqueeze = fluid.layers.unsqueeze(
input=src_anchors[:, 4], axes=0)
bbox_w3d_unsqueeze = fluid.layers.unsqueeze(
input=src_anchors[:, 5], axes=0)
bbox_h3d_unsqueeze = fluid.layers.unsqueeze(
input=src_anchors[:, 6], axes=0)
bbox_l3d_unsqueeze = fluid.layers.unsqueeze(
input=src_anchors[:, 7], axes=0)
bbox_ry3d_unsqueeze = fluid.layers.unsqueeze(
input=src_anchors[:, 8], axes=0)
bbox_x3d_dn = bbox_x3d_dn * widths_unsqueeze + ctr_x_unsqueeze
bbox_y3d_dn = bbox_y3d_dn * heights_unsqueeze + ctr_y_unsqueeze
bbox_z3d_dn = bbox_z3d_unsqueeze + bbox_z3d_dn
bbox_w3d_dn = fluid.layers.exp(bbox_w3d_dn) * bbox_w3d_unsqueeze
bbox_h3d_dn = fluid.layers.exp(bbox_h3d_dn) * bbox_h3d_unsqueeze
bbox_l3d_dn = fluid.layers.exp(bbox_l3d_dn) * bbox_l3d_unsqueeze
bbox_ry3d_dn = bbox_ry3d_unsqueeze + bbox_ry3d_dn
ious_2d_var_list = []
for bind in range(0, batch_size):
imobj = imobjs[bind]
gts = imobj.gts
p2_inv = to_variable(imobj.p2_inv).astype('float32')
# filter gts
igns, rmvs = determine_ignores(gts, self.lbls, self.ilbls,
self.min_gt_vis, self.min_gt_h)
# accumulate boxes
gts_all = bbXYWH2Coords(np.array([gt.bbox_full for gt in gts]))
gts_3d = np.array([gt.bbox_3d for gt in gts])
if not ((rmvs == False) & (igns == False)).any():
continue
# filter out irrelevant cls, and ignore cls
gts_val = gts_all[(rmvs == False) & (igns == False), :]
gts_ign = gts_all[(rmvs == False) & (igns == True), :]
gts_3d = gts_3d[(rmvs == False) & (igns == False), :]
# accumulate labels
box_lbls = np.array([gt.cls for gt in gts])
box_lbls = box_lbls[(rmvs == False) & (igns == False)]
box_lbls = np.array(
[clsName2Ind(self.lbls, cls) for cls in box_lbls])
if gts_val.shape[0] > 0 or gts_ign.shape[0] > 0:
# bbox regression
transforms, ols, raw_gt = compute_targets(
gts_val,
gts_ign,
box_lbls,
rois,
self.fg_thresh,
self.ign_thresh,
self.bg_thresh_lo,
self.bg_thresh_hi,
self.best_thresh,
anchors=self.anchors,
gts_3d=gts_3d,
tracker=rois[:, 4])
# normalize 2d
transforms[:, 0:4] -= self.bbox_means[:, 0:4]
transforms[:, 0:4] /= self.bbox_stds[:, 0:4]
# normalize 3d
transforms[:, 5:12] -= self.bbox_means[:, 4:]
transforms[:, 5:12] /= self.bbox_stds[:, 4:]
labels_fg = transforms[:, 4] > 0
labels_bg = transforms[:, 4] < 0
labels_ign = transforms[:, 4] == 0
fg_inds = np.flatnonzero(labels_fg)
bg_inds = np.flatnonzero(labels_bg)
ign_inds = np.flatnonzero(labels_ign)
labels[bind, fg_inds] = transforms[fg_inds, 4]
labels[bind, ign_inds] = IGN_FLAG
labels[bind, bg_inds] = 0
bbox_x_tar[bind, :] = transforms[:, 0]
bbox_y_tar[bind, :] = transforms[:, 1]
bbox_w_tar[bind, :] = transforms[:, 2]
bbox_h_tar[bind, :] = transforms[:, 3]
bbox_x3d_tar[bind, :] = transforms[:, 5]
bbox_y3d_tar[bind, :] = transforms[:, 6]
bbox_z3d_tar[bind, :] = transforms[:, 7]
bbox_w3d_tar[bind, :] = transforms[:, 8]
bbox_h3d_tar[bind, :] = transforms[:, 9]
bbox_l3d_tar[bind, :] = transforms[:, 10]
bbox_ry3d_tar[bind, :] = transforms[:, 11]
bbox_x3d_proj_tar[bind, :] = raw_gt[:, 12]
bbox_y3d_proj_tar[bind, :] = raw_gt[:, 13]
bbox_z3d_proj_tar[bind, :] = raw_gt[:, 14]
transforms = to_variable(transforms)
# ----------------------------------------
# box sampling
# ----------------------------------------
if self.box_samples == np.inf:
fg_num = len(fg_inds)
bg_num = len(bg_inds)
else:
fg_num = min(
round(rois.shape[0] * self.box_samples *
self.fg_fraction), len(fg_inds))
bg_num = min(
round(rois.shape[0] * self.box_samples - fg_num),
len(bg_inds))
if self.hard_negatives:
if fg_num > 0 and fg_num != fg_inds.shape[0]:
scores = prob_detach[bind, fg_inds, labels[
bind, fg_inds].astype(int)]
fg_score_ascend = (scores).argsort()
fg_inds = fg_inds[fg_score_ascend]
fg_inds = fg_inds[0:fg_num]
if bg_num > 0 and bg_num != bg_inds.shape[0]:
scores = prob_detach[bind, bg_inds, labels[
bind, bg_inds].astype(int)]
bg_score_ascend = (scores).argsort()
bg_inds = bg_inds[bg_score_ascend]
bg_inds = bg_inds[0:bg_num]
else:
if fg_num > 0 and fg_num != fg_inds.shape[0]:
fg_inds = np.random.choice(
fg_inds, fg_num, replace=False)
if bg_num > 0 and bg_num != bg_inds.shape[0]:
bg_inds = np.random.choice(
bg_inds, bg_num, replace=False)
labels_weight[bind, bg_inds] = BG_ENC
labels_weight[bind, fg_inds] = FG_ENC
bbox_weights[bind, fg_inds] = 1
# ----------------------------------------
# compute IoU stats
# ----------------------------------------
if fg_num > 0:
# compile deltas pred (Variable)
bbox_x_bind = bbox_x[bind, :]
bbox_x_bind_unsqueeze = fluid.layers.unsqueeze(
bbox_x_bind, axes=1)
bbox_y_bind = bbox_y[bind, :]
bbox_y_bind_unsqueeze = fluid.layers.unsqueeze(
bbox_y_bind, axes=1)
bbox_w_bind = bbox_w[bind, :]
bbox_w_bind_unsqueeze = fluid.layers.unsqueeze(
bbox_w_bind, axes=1)
bbox_h_bind = bbox_h[bind, :]
bbox_h_bind_unsqueeze = fluid.layers.unsqueeze(
bbox_h_bind, axes=1)
deltas_2d = fluid.layers.concat(
(bbox_x_bind_unsqueeze, bbox_y_bind_unsqueeze,
bbox_w_bind_unsqueeze, bbox_h_bind_unsqueeze),
axis=1)
# compile deltas targets (nparray)
deltas_2d_tar = np.concatenate(
(bbox_x_tar[bind, :, np.newaxis],
bbox_y_tar[bind, :, np.newaxis],
bbox_w_tar[bind, :, np.newaxis],
bbox_h_tar[bind, :, np.newaxis]),
axis=1).astype('float32')
# move to gpu
deltas_2d_tar = to_variable(deltas_2d_tar)
deltas_2d_tar.stop_gradient = True
means = self.bbox_means[0, :]
stds = self.bbox_stds[0, :]
#variable
coords_2d = bbox_transform_inv(
rois, deltas_2d, means=means, stds=stds)
coords_2d_tar = bbox_transform_inv(
rois, deltas_2d_tar, means=means, stds=stds)
#vaiable
ious_2d_var = iou(coords_2d, coords_2d_tar, mode='list')
ious_2d_var_shape = ious_2d_var.shape
ious_2d_fg_mask = np.zeros(ious_2d_var_shape).astype(
'float32')
ious_2d_fg_mask[fg_inds] = 1
ious_2d_var = ious_2d_var * to_variable(ious_2d_fg_mask)
ious_2d_var_list.append(ious_2d_var)
bbox_x3d_dn_fg = bbox_x3d_dn.numpy()[bind, fg_inds]
bbox_y3d_dn_fg = bbox_y3d_dn.numpy()[bind, fg_inds]
src_anchors = self.anchors[rois[fg_inds, 4].astype('int64')]
src_anchors = to_variable(src_anchors).astype('float32')
src_anchors.stop_gradient = True
if len(src_anchors.shape) == 1:
src_anchors = fluid.layers.unsqueeze(
input=src_anchors, axes=0)
#nparray
bbox_x3d_dn_fg = bbox_x3d_dn.numpy()[bind, fg_inds]
bbox_y3d_dn_fg = bbox_y3d_dn.numpy()[bind, fg_inds]
bbox_z3d_dn_fg = bbox_z3d_dn.numpy()[bind, fg_inds]
bbox_w3d_dn_fg = bbox_w3d_dn.numpy()[bind, fg_inds]
bbox_h3d_dn_fg = bbox_h3d_dn.numpy()[bind, fg_inds]
bbox_l3d_dn_fg = bbox_l3d_dn.numpy()[bind, fg_inds]
bbox_ry3d_dn_fg = bbox_ry3d_dn.numpy()[bind, fg_inds]
# re-scale all 2D back to original
bbox_x3d_dn_fg /= imobj['scale_factor']
bbox_y3d_dn_fg /= imobj['scale_factor']
coords_2d = fluid.layers.concat(
(to_variable(bbox_x3d_dn_fg[np.newaxis, :] *
bbox_z3d_dn_fg[np.newaxis, :]),
to_variable(bbox_y3d_dn_fg[np.newaxis, :] *
bbox_z3d_dn_fg[np.newaxis, :]),
to_variable(bbox_z3d_dn_fg[np.newaxis, :])),
axis=0)
coords_2d = fluid.layers.concat(
(coords_2d,
to_variable(np.ones([1, coords_2d.shape[1]])).astype(
'float32')),
axis=0)
coords_3d = fluid.layers.matmul(p2_inv, coords_2d)
bbox_x3d_proj[bind, fg_inds] = coords_3d[0, :].numpy()
bbox_y3d_proj[bind, fg_inds] = coords_3d[1, :].numpy()
bbox_z3d_proj[bind, fg_inds] = coords_3d[2, :].numpy()
# absolute targets
bbox_z3d_dn_tar = bbox_z3d_tar[
bind, fg_inds] * self.bbox_stds[:, 6][
0] + self.bbox_means[:, 6][0]
bbox_z3d_dn_tar = to_variable(bbox_z3d_dn_tar).astype(
'float32')
bbox_z3d_dn_tar.stop_gradient = True
bbox_z3d_dn_tar = src_anchors[:, 4] + bbox_z3d_dn_tar
bbox_ry3d_dn_tar = bbox_ry3d_tar[
bind, fg_inds] * self.bbox_stds[:, 10][
0] + self.bbox_means[:, 10][0]
bbox_ry3d_dn_tar = to_variable(bbox_ry3d_dn_tar).astype(
'float32')
bbox_ry3d_dn_tar.stop_gradient = True
bbox_ry3d_dn_tar = src_anchors[:, 8] + bbox_ry3d_dn_tar
bbox_z3d_dn_fg = to_variable(bbox_z3d_dn_fg)
bbox_ry3d_dn_fg = to_variable(bbox_ry3d_dn_fg)
bbox_abs_z3d_var = fluid.layers.abs(bbox_z3d_dn_tar -
bbox_z3d_dn_fg)
coords_abs_z[bind, fg_inds] = bbox_abs_z3d_var.numpy()
bbox_abs_ry3d_var = fluid.layers.abs(bbox_ry3d_dn_tar -
bbox_ry3d_dn_fg)
coords_abs_ry[bind, fg_inds] = bbox_abs_ry3d_var.numpy()
else:
bg_inds = np.arange(0, rois.shape[0])
if self.box_samples == np.inf: bg_num = len(bg_inds)
else:
bg_num = min(
round(self.box_samples * (1 - self.fg_fraction)),
len(bg_inds))
if self.hard_negatives:
if bg_num > 0 and bg_num != bg_inds.shape[0]:
scores = prob_detach[bind, bg_inds, labels[
bind, bg_inds].astype(int)]
bg_score_ascend = (scores).argsort()
bg_inds = bg_inds[bg_score_ascend]
bg_inds = bg_inds[0:bg_num]
else:
if bg_num > 0 and bg_num != bg_inds.shape[0]:
bg_inds = np.random.choice(
bg_inds, bg_num, replace=False)
labels[bind, :] = 0
labels_weight[bind, bg_inds] = BG_ENC
# grab label predictions (for weighing purposes) dtype: nparray
active = labels[bind, :] != IGN_FLAG
labels_scores[bind, active] = prob_detach[bind, active, labels[
bind, active].astype(int)]
# ----------------------------------------
# useful statistics
# ----------------------------------------
fg_inds_all = np.flatnonzero((labels > 0) & (labels != IGN_FLAG))
bg_inds_all = np.flatnonzero((labels == 0) & (labels != IGN_FLAG))
fg_inds_unravel = np.unravel_index(fg_inds_all, prob_detach.shape[0:2])
bg_inds_unravel = np.unravel_index(bg_inds_all, prob_detach.shape[0:2])
cls_pred = np.argmax(cls.detach().numpy(), axis=2)
if self.cls_2d_lambda and len(fg_inds_all) > 0:
acc_fg = np.mean(
cls_pred[fg_inds_unravel] == labels[fg_inds_unravel])
stats.append({
'name': 'fg',
'val': acc_fg,
'format': '{:0.2f}',
'group': 'acc'
})
if self.cls_2d_lambda and len(bg_inds_all) > 0:
acc_bg = np.mean(
cls_pred[bg_inds_unravel] == labels[bg_inds_unravel])
stats.append({
'name': 'bg',
'val': acc_bg,
'format': '{:0.2f}',
'group': 'acc'
})
# ----------------------------------------
# box weighting
# ----------------------------------------
fg_inds = np.flatnonzero(labels_weight == FG_ENC)
bg_inds = np.flatnonzero(labels_weight == BG_ENC)
active_inds = np.concatenate((fg_inds, bg_inds), axis=0)
fg_num = len(fg_inds)
bg_num = len(bg_inds)
labels_weight[...] = 0.0
box_samples = fg_num + bg_num
fg_inds_unravel = np.unravel_index(fg_inds, labels_weight.shape)
bg_inds_unravel = np.unravel_index(bg_inds, labels_weight.shape)
active_inds_unravel = np.unravel_index(active_inds, labels_weight.shape)
labels_weight[active_inds_unravel] = 1.0
if self.fg_fraction is not None:
if fg_num > 0:
fg_weight = (self.fg_fraction /
(1 - self.fg_fraction)) * (bg_num / fg_num)
labels_weight[fg_inds_unravel] = fg_weight
labels_weight[bg_inds_unravel] = 1.0
else:
labels_weight[bg_inds_unravel] = 1.0
# different method of doing hard negative mining
# use the scores to normalize the importance of each sample
# hence, encourages the network to get all "correct" rather than
# becoming more correct at a decision it is already good at
# this method is equivelent to the focal loss with additional mean scaling
if self.focal_loss:
weights_sum = 0
# re-weight bg
if bg_num > 0:
bg_scores = labels_scores[bg_inds_unravel]
bg_weights = (1 - bg_scores)**self.focal_loss
weights_sum += np.sum(bg_weights)
labels_weight[bg_inds_unravel] *= bg_weights
# re-weight fg
if fg_num > 0:
fg_scores = labels_scores[fg_inds_unravel]
fg_weights = (1 - fg_scores)**self.focal_loss
weights_sum += np.sum(fg_weights)
labels_weight[fg_inds_unravel] *= fg_weights
# ----------------------------------------
# classification loss
# ----------------------------------------
labels_weight = labels_weight.view()
labels_weight.shape = np.product(labels_weight.shape)
active = labels_weight > 0
labels_weight_active = labels_weight[active]
labels_weight_active = to_variable(labels_weight_active)
labels_weight_active = labels_weight_active.astype('float32')
labels_weight_active.stop_gradient = True
labels = labels.view().astype('int64')
labels.shape = np.product(labels.shape)
labels_active = labels[active]
labels_active = to_variable(labels_active)
labels_active.stop_gradient = True
active_index = np.flatnonzero(active)
cls_reshape = fluid.layers.reshape(cls, shape=[-1, cls.shape[2]])
active_index_var = to_variable(active_index)
active_index_var.stop_gradient = True
cls_active = fluid.layers.gather(cls_reshape, index=active_index_var)
if self.cls_2d_lambda:
# cls loss
if np.any(active):
labels_active = fluid.layers.reshape(
labels_active, shape=[-1, 1])
loss_cls = fluid.layers.softmax_with_cross_entropy(
cls_active, labels_active, ignore_index=IGN_FLAG)
labels_weight_active = fluid.layers.unsqueeze(
labels_weight_active, axes=1)
loss_cls = fluid.layers.elementwise_mul(loss_cls,
labels_weight_active)
# simple gradient clipping
loss_cls = fluid.layers.clip(loss_cls, min=0.0, max=2000.0)
# take mean and scale lambda
loss_cls = fluid.layers.mean(loss_cls)
loss_cls *= self.cls_2d_lambda
loss += loss_cls
stats.append({
'name': 'cls',
'val': loss_cls.numpy(),
'format': '{:0.4f}',
'group': 'loss'
})
# ----------------------------------------
# bbox regression loss
# ----------------------------------------
if np.sum(bbox_weights) > 0:
bbox_total_nums = np.product(bbox_weights.shape)
bbox_weights = bbox_weights.view().astype('float32')
bbox_weights.shape = bbox_total_nums
active = bbox_weights > 0
active_index = np.flatnonzero(active)
active_len = active_index.size
active_index_var = to_variable(active_index)
active_index_var.stop_gradient = True
bbox_weights.shape = 1, bbox_total_nums
bbox_weights_active = bbox_weights[:, active]
bbox_weights_active = to_variable(bbox_weights_active)
bbox_weights_active.stop_gradient = True
if self.bbox_2d_lambda:
# bbox loss 2d
bbox_x_tar = bbox_x_tar.view().astype('float32')
bbox_x_tar.shape = 1, bbox_total_nums
bbox_x_tar_active = bbox_x_tar[:, active]
bbox_x_tar_active = to_variable(bbox_x_tar_active)
bbox_x_tar_active.stop_gradient = True
bbox_y_tar = bbox_y_tar.view().astype('float32')
bbox_y_tar.shape = 1, bbox_total_nums
bbox_y_tar_active = bbox_y_tar[:, active]
bbox_y_tar_active = to_variable(bbox_y_tar_active)
bbox_y_tar_active.stop_gradient = True
bbox_w_tar = bbox_w_tar.view().astype('float32')
bbox_w_tar.shape = 1, bbox_total_nums
bbox_w_tar_active = bbox_w_tar[:, active]
bbox_w_tar_active = to_variable(bbox_w_tar_active)
bbox_w_tar_active.stop_gradient = True
bbox_h_tar = bbox_h_tar.view().astype('float32')
bbox_h_tar.shape = 1, bbox_total_nums
bbox_h_tar_active = bbox_h_tar[:, active]
bbox_h_tar_active = to_variable(bbox_h_tar_active)
bbox_h_tar_active.stop_gradient = True
bbox_x = fluid.layers.reshape(bbox_x, shape=[-1])
bbox_x_active = fluid.layers.gather(bbox_x, active_index_var)
bbox_x_active = fluid.layers.unsqueeze(bbox_x_active, axes=0)
bbox_y = fluid.layers.reshape(bbox_y, shape=[-1])
bbox_y_active = fluid.layers.gather(bbox_y, active_index_var)
bbox_y_active = fluid.layers.unsqueeze(bbox_y_active, axes=0)
bbox_w = fluid.layers.reshape(bbox_w, shape=[-1])
bbox_w_active = fluid.layers.gather(bbox_w, active_index_var)
bbox_w_active = fluid.layers.unsqueeze(bbox_w_active, axes=0)
bbox_h = fluid.layers.reshape(bbox_h, shape=[-1])
bbox_h_active = fluid.layers.gather(bbox_h, active_index_var)
bbox_h_active = fluid.layers.unsqueeze(bbox_h_active, axes=0)
loss_bbox_x = fluid.layers.smooth_l1(
bbox_x_active,
bbox_x_tar_active,
outside_weight=bbox_weights_active)
loss_bbox_y = fluid.layers.smooth_l1(
bbox_y_active,
bbox_y_tar_active,
outside_weight=bbox_weights_active)
loss_bbox_w = fluid.layers.smooth_l1(
bbox_w_active,
bbox_w_tar_active,
outside_weight=bbox_weights_active)
loss_bbox_h = fluid.layers.smooth_l1(
bbox_h_active,
bbox_h_tar_active,
outside_weight=bbox_weights_active)
bbox_2d_loss = (
loss_bbox_x + loss_bbox_y + loss_bbox_w + loss_bbox_h
) / active_len
bbox_2d_loss *= self.bbox_2d_lambda
loss += bbox_2d_loss
stats.append({
'name': 'bbox_2d',
'val': bbox_2d_loss.numpy(),
'format': '{:0.4f}',
'group': 'loss'
})
if self.bbox_3d_lambda:
# bbox loss 3d
bbox_x3d_tar = bbox_x3d_tar.view().astype('float32')
bbox_x3d_tar.shape = 1, bbox_total_nums
bbox_x3d_tar_active = bbox_x3d_tar[:, active]
bbox_x3d_tar_active = to_variable(bbox_x3d_tar_active)
bbox_x3d_tar_active.stop_gradient = True
bbox_y3d_tar = bbox_y3d_tar.view().astype('float32')
bbox_y3d_tar.shape = 1, bbox_total_nums
bbox_y3d_tar_active = bbox_y3d_tar[:, active]
bbox_y3d_tar_active = to_variable(bbox_y3d_tar_active)
bbox_y3d_tar_active.stop_gradient = True
bbox_z3d_tar = bbox_z3d_tar.view().astype('float32')
bbox_z3d_tar.shape = 1, bbox_total_nums
bbox_z3d_tar_active = bbox_z3d_tar[:, active]
bbox_z3d_tar_active = to_variable(bbox_z3d_tar_active)
bbox_z3d_tar_active.stop_gradient = True
bbox_w3d_tar = bbox_w3d_tar.view().astype('float32')
bbox_w3d_tar.shape = 1, bbox_total_nums
bbox_w3d_tar_active = bbox_w3d_tar[:, active]
bbox_w3d_tar_active = to_variable(bbox_w3d_tar_active)
bbox_w3d_tar_active.stop_gradient = True
bbox_h3d_tar = bbox_h3d_tar.view().astype('float32')
bbox_h3d_tar.shape = 1, bbox_total_nums
bbox_h3d_tar_active = bbox_h3d_tar[:, active]
bbox_h3d_tar_active = to_variable(bbox_h3d_tar_active)
bbox_h3d_tar_active.stop_gradient = True
bbox_l3d_tar = bbox_l3d_tar.view().astype('float32')
bbox_l3d_tar.shape = 1, bbox_total_nums
bbox_l3d_tar_active = bbox_l3d_tar[:, active]
bbox_l3d_tar_active = to_variable(bbox_l3d_tar_active)
bbox_l3d_tar_active.stop_gradient = True
bbox_ry3d_tar = bbox_ry3d_tar.view().astype('float32')
bbox_ry3d_tar.shape = 1, bbox_total_nums
bbox_ry3d_tar_active = bbox_ry3d_tar[:, active]
bbox_ry3d_tar_active = to_variable(bbox_ry3d_tar_active)
bbox_ry3d_tar_active.stop_gradient = True
bbox_x3d = fluid.layers.reshape(bbox_x3d, shape=[-1])
bbox_x3d_active = fluid.layers.gather(bbox_x3d,
active_index_var)
bbox_x3d_active = fluid.layers.unsqueeze(
bbox_x3d_active, axes=0)
bbox_y3d = fluid.layers.reshape(bbox_y3d, shape=[-1])
bbox_y3d_active = fluid.layers.gather(bbox_y3d,
active_index_var)
bbox_y3d_active = fluid.layers.unsqueeze(
bbox_y3d_active, axes=0)
bbox_z3d = fluid.layers.reshape(bbox_z3d, shape=[-1])
bbox_z3d_active = fluid.layers.gather(bbox_z3d,
active_index_var)
bbox_z3d_active = fluid.layers.unsqueeze(
bbox_z3d_active, axes=0)
bbox_w3d = fluid.layers.reshape(bbox_w3d, shape=[-1])
bbox_w3d_active = fluid.layers.gather(bbox_w3d,
active_index_var)
bbox_w3d_active = fluid.layers.unsqueeze(
bbox_w3d_active, axes=0)
bbox_h3d = fluid.layers.reshape(bbox_h3d, shape=[-1])
bbox_h3d_active = fluid.layers.gather(bbox_h3d,
active_index_var)
bbox_h3d_active = fluid.layers.unsqueeze(
bbox_h3d_active, axes=0)
bbox_l3d = fluid.layers.reshape(bbox_l3d, shape=[-1])
bbox_l3d_active = fluid.layers.gather(bbox_l3d,
active_index_var)
bbox_l3d_active = fluid.layers.unsqueeze(
bbox_l3d_active, axes=0)
bbox_ry3d = fluid.layers.reshape(bbox_ry3d, shape=[-1])
bbox_ry3d_active = fluid.layers.gather(bbox_ry3d,
active_index_var)
bbox_ry3d_active = fluid.layers.unsqueeze(
bbox_ry3d_active, axes=0)
loss_bbox_x3d = fluid.layers.smooth_l1(
bbox_x3d_active.astype('float32'),
bbox_x3d_tar_active.astype('float32'),
outside_weight=bbox_weights_active.astype('float32'))
loss_bbox_y3d = fluid.layers.smooth_l1(
bbox_y3d_active.astype('float32'),
bbox_y3d_tar_active.astype('float32'),
outside_weight=bbox_weights_active.astype('float32'))
loss_bbox_z3d = fluid.layers.smooth_l1(
bbox_z3d_active.astype('float32'),
bbox_z3d_tar_active.astype('float32'),
outside_weight=bbox_weights_active.astype('float32'))
loss_bbox_w3d = fluid.layers.smooth_l1(
bbox_w3d_active.astype('float32'),
bbox_w3d_tar_active.astype('float32'),
outside_weight=bbox_weights_active.astype('float32'))
loss_bbox_h3d = fluid.layers.smooth_l1(
bbox_h3d_active.astype('float32'),
bbox_h3d_tar_active.astype('float32'),
outside_weight=bbox_weights_active.astype('float32'))
loss_bbox_l3d = fluid.layers.smooth_l1(
bbox_l3d_active.astype('float32'),
bbox_l3d_tar_active.astype('float32'),
outside_weight=bbox_weights_active.astype('float32'))
loss_bbox_ry3d = fluid.layers.smooth_l1(
bbox_ry3d_active.astype('float32'),
bbox_ry3d_tar_active.astype('float32'),
outside_weight=bbox_weights_active.astype('float32'))
bbox_3d_loss = (loss_bbox_x3d + loss_bbox_y3d + loss_bbox_z3d)
bbox_3d_loss += (loss_bbox_w3d + loss_bbox_h3d + loss_bbox_l3d +
loss_bbox_ry3d)
bbox_3d_loss = bbox_3d_loss / active_len
bbox_3d_loss *= self.bbox_3d_lambda
bbox_3d_loss = bbox_3d_loss
loss += bbox_3d_loss
stats.append({
'name': 'bbox_3d',
'val': bbox_3d_loss.numpy(),
'format': '{:0.4f}',
'group': 'loss'
})
if self.bbox_3d_proj_lambda:
# bbox loss 3d
bbox_x3d_proj_tar = bbox_x3d_proj_tar.view().astype('float32')
bbox_x3d_proj_tar.shape = 1, bbox_total_nums
bbox_x3d_proj_tar_active = bbox_x3d_proj_tar[:, active]
bbox_x3d_proj_tar_active = to_variable(bbox_x3d_proj_tar_active)
bbox_x3d_proj_tar_active.stop_gradient = True
bbox_y3d_proj_tar = bbox_y3d_proj_tar.view().astype('float32')
bbox_y3d_proj_tar.shape = 1, bbox_total_nums
bbox_y3d_proj_tar_active = bbox_y3d_proj_tar[:, active]
bbox_y3d_proj_tar_active = to_variable(bbox_y3d_proj_tar_active)
bbox_y3d_proj_tar_active.stop_gradient = True
bbox_z3d_proj_tar = bbox_z3d_proj_tar.view().astype('float32')
bbox_z3d_proj_tar.shape = 1, bbox_total_nums
bbox_z3d_proj_tar_active = bbox_z3d_proj_tar[:, active]
bbox_z3d_proj_tar_active = to_variable(bbox_z3d_proj_tar_active)
bbox_z3d_proj_tar_active.stop_gradient = True
bbox_x3d_proj = bbox_x3d_proj.view()
bbox_x3d_proj.shape = 1, bbox_total_nums
bbox_x3d_proj_active = bbox_x3d_proj[:, active]
bbox_x3d_proj_active = to_variable(bbox_x3d_proj_active)
bbox_y3d_proj = bbox_y3d_proj.view()
bbox_y3d_proj.shape = 1, bbox_total_nums
bbox_y3d_proj_active = bbox_y3d_proj[:, active]
bbox_y3d_proj_active = to_variable(bbox_y3d_proj_active)
bbox_y3d_proj_active.stop_gradient = True
bbox_z3d_proj = bbox_z3d_proj.view()
bbox_z3d_proj.shape = 1, bbox_total_nums
bbox_z3d_proj_active = bbox_z3d_proj[:, active]
bbox_z3d_proj_active = to_variable(bbox_z3d_proj_active)
bbox_z3d_proj_active.stop_gradient = True
loss_bbox_x3d_proj = fluid.layers.smooth_l1(
bbox_x3d_proj_active.astype('float32'),
bbox_x3d_proj_tar_active.astype('float32'),
outside_weight=bbox_weights_active.astype('float32'))
loss_bbox_y3d_proj = fluid.layers.smooth_l1(
bbox_y3d_proj_active.astype('float32'),
bbox_y3d_proj_tar_active.astype('float32'),
outside_weight=bbox_weights_active.astype('float32'))
loss_bbox_z3d_proj = fluid.layers.smooth_l1(
bbox_z3d_proj_active.astype('float32'),
bbox_z3d_proj_tar_active.astype('float32'),
outside_weight=bbox_weights_active.astype('float32'))
bbox_3d_proj_loss = (
loss_bbox_x3d_proj + loss_bbox_y3d_proj + loss_bbox_z3d_proj
)
bbox_3d_proj_loss = bbox_3d_proj_loss / active_len
bbox_3d_proj_loss *= self.bbox_3d_proj_lambda
bbox_3d_proj_loss = bbox_3d_proj_loss
loss += bbox_3d_proj_loss
stats.append({
'name': 'bbox_3d_proj',
'val': bbox_3d_proj_loss.numpy(),
'format': '{:0.4f}',
'group': 'loss'
})
coords_abs_z = fluid.layers.reshape(
to_variable(coords_abs_z), shape=[-1])
coords_abs_z_np = coords_abs_z.numpy()
coords_abs_z_active = coords_abs_z_np[active]
coords_abs_z = to_variable(coords_abs_z_active)
coords_abs_z_mean = fluid.layers.mean(coords_abs_z)
stats.append({
'name': 'z',
'val': coords_abs_z_mean.numpy(),
'format': '{:0.2f}',
'group': 'misc'
})
coords_abs_ry = fluid.layers.reshape(
to_variable(coords_abs_ry), shape=[-1])
coords_abs_ry_np = coords_abs_ry.numpy()
coords_abs_ry_active = coords_abs_ry_np[active]
coords_abs_ry = to_variable(coords_abs_ry_active)
coords_abs_ry_mean = fluid.layers.mean(coords_abs_ry)
stats.append({
'name': 'ry',
'val': coords_abs_ry_mean.numpy(),
'format': '{:0.2f}',
'group': 'misc'
})
ious_2d = fluid.layers.concat(ious_2d_var_list, axis=0)
ious_2d = fluid.layers.reshape(ious_2d, shape=[-1])
ious_2d_active = fluid.layers.gather(ious_2d, active_index_var)
ious_2d_mean = fluid.layers.mean(ious_2d_active)
stats.append({
'name': 'iou',
'val': ious_2d_mean.numpy(),
'format': '{:0.2f}',
'group': 'acc'
})
# use a 2d IoU based log loss
if self.iou_2d_lambda:
iou_2d_loss = -fluid.layers.log(ious_2d_active)
iou_2d_loss = (iou_2d_loss * bbox_weights_active)
iou_2d_loss = fluid.layers.mean(iou_2d_loss)
iou_2d_loss *= self.iou_2d_lambda
loss += iou_2d_loss
stats.append({
'name': 'iou',
'val': iou_2d_loss.numpy(),
'format': '{:0.4f}',
'group': 'loss'
})
return loss, stats
# Build the NMS extension modules in place via setup.py, then remove the
# temporary "build" directory that the build leaves behind.
all:
	python setup.py build_ext --inplace
	rm -rf build
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np
cimport numpy as np
cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
    # Larger of two float32 values; `a` is returned when a >= b, else `b`.
    if a >= b:
        return a
    return b
cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
    # Smaller of two float32 values; `a` is returned when a <= b, else `b`.
    if a <= b:
        return a
    return b
def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
    """Greedy non-maximum suppression over score-sorted boxes.

    Each row of ``dets`` is read as ``x1, y1, x2, y2, score`` (columns 0-4).
    Boxes are visited in descending score order; every kept box suppresses
    each later box whose IoU with it is ``>= thresh``. Widths and heights use
    the inclusive ``+ 1`` convention.

    Returns a Python list holding the row indices of the kept boxes, highest
    score first. An empty ``dets`` yields an empty list.
    """
    cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
    cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
    cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
    cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
    cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]

    cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1]

    cdef int ndets = dets.shape[0]
    # np.int_ is the dtype that corresponds to the np.int_t buffer type. The
    # former `np.int` alias was deprecated in NumPy 1.20 and removed in 1.24,
    # where it raises AttributeError.
    cdef np.ndarray[np.int_t, ndim=1] suppressed = \
        np.zeros((ndets), dtype=np.int_)

    # nominal indices
    cdef int _i, _j
    # sorted indices
    cdef int i, j
    # temp variables for box i's (the box currently under consideration)
    cdef np.float32_t ix1, iy1, ix2, iy2, iarea
    # variables for computing overlap with box j (lower scoring box)
    cdef np.float32_t xx1, yy1, xx2, yy2
    cdef np.float32_t w, h
    cdef np.float32_t inter, ovr

    keep = []
    for _i in range(ndets):
        i = order[_i]
        if suppressed[i] == 1:
            continue
        keep.append(i)
        ix1 = x1[i]
        iy1 = y1[i]
        ix2 = x2[i]
        iy2 = y2[i]
        iarea = areas[i]
        # Only boxes that come later in score order can be suppressed by i.
        for _j in range(_i + 1, ndets):
            j = order[_j]
            if suppressed[j] == 1:
                continue
            xx1 = max(ix1, x1[j])
            yy1 = max(iy1, y1[j])
            xx2 = min(ix2, x2[j])
            yy2 = min(iy2, y2[j])
            w = max(0.0, xx2 - xx1 + 1)
            h = max(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (iarea + areas[j] - inter)
            if ovr >= thresh:
                suppressed[j] = 1

    return keep
// Host entry point for GPU non-maximum suppression (defined in the .cu file).
// Writes the indices of the kept boxes to keep_out and their count to num_out.
// NOTE(review): boxes_host appears to be expected as boxes_num rows of
// boxes_dim floats already sorted by descending score - confirm with callers.
void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
          int boxes_dim, float nms_overlap_thresh, int device_id);
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np
cimport numpy as np
# The np.int32 `keep` buffer below is handed to _nms through its first
# pointer parameter, which gpu_nms.hpp declares as `int*`; the two element
# types must therefore have the same size.
assert sizeof(int) == sizeof(np.int32_t)

# C++ entry point declared in gpu_nms.hpp.
cdef extern from "gpu_nms.hpp":
    void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int)
def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh,
            np.int32_t device_id=0):
    """Run non-maximum suppression through the C++/CUDA ``_nms`` routine.

    The rows of ``dets`` are reordered by descending score (column 4) before
    being passed to ``_nms``; the positions it reports are then mapped back
    to row indices of the original ``dets`` array.

    Returns a list of kept row indices into ``dets``. An empty ``dets``
    returns an empty list without calling ``_nms``.
    """
    cdef int boxes_num = dets.shape[0]
    cdef int boxes_dim = dets.shape[1]
    cdef int num_out

    # With no rows there is nothing to suppress, and taking &keep[0] or
    # &sorted_dets[0, 0] below would index zero-length buffers.
    if boxes_num == 0:
        return []

    cdef np.ndarray[np.int32_t, ndim=1] \
        keep = np.zeros(boxes_num, dtype=np.int32)
    cdef np.ndarray[np.float32_t, ndim=1] \
        scores = dets[:, 4]
    cdef np.ndarray[np.int_t, ndim=1] \
        order = scores.argsort()[::-1]
    cdef np.ndarray[np.float32_t, ndim=2] \
        sorted_dets = dets[order, :]
    _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id)
    # Only the first num_out entries of keep were filled in by _nms.
    keep = keep[:num_out]
    return list(order[keep])
// ------------------------------------------------------------------
// Faster R-CNN
// Copyright (c) 2015 Microsoft
// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
// Written by Shaoqing Ren
// ------------------------------------------------------------------
#include "gpu_nms.hpp"
#include <vector>
#include <iostream>
// Evaluate a CUDA runtime call and, if it did not return cudaSuccess, print
// the corresponding error string to stdout. The failure is only reported:
// execution continues after the macro.
#define CUDA_CHECK(condition) \
  /* Code block avoids redefinition of cudaError_t error */ \
  do { \
    cudaError_t error = condition; \
    if (error != cudaSuccess) { \
      std::cout << cudaGetErrorString(error) << std::endl; \
    } \
  } while (0)
// Integer division rounded up: adds 1 to m / n when there is a remainder.
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
// Number of bits in an unsigned long long. Used as the thread-block size and
// as the number of boxes whose suppression flags fit in one mask word.
int const threadsPerBlock = sizeof(unsigned long long) * 8;
__device__ inline float devIoU(float const * const a, float const * const b) {
  // Intersection-over-union of two boxes whose first four floats are
  // x1, y1, x2, y2. Widths and heights use the inclusive "+ 1" convention,
  // and a negative overlap extent is clamped to zero.
  const float ix1 = max(a[0], b[0]);
  const float ix2 = min(a[2], b[2]);
  const float iy1 = max(a[1], b[1]);
  const float iy2 = min(a[3], b[3]);

  const float overlap_w = max(ix2 - ix1 + 1, 0.f);
  const float overlap_h = max(iy2 - iy1 + 1, 0.f);
  const float overlap = overlap_w * overlap_h;

  const float area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
  const float area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
  return overlap / (area_a + area_b - overlap);
}
__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
                           const float *dev_boxes, unsigned long long *dev_mask) {
  // Each thread block compares one tile of "row" boxes (blockIdx.y) against
  // one tile of "column" boxes (blockIdx.x); a tile holds at most
  // threadsPerBlock boxes. Every row thread writes one mask word in which
  // bit i is set when column box i of the tile has IoU > nms_overlap_thresh
  // with the thread's row box. Boxes are read with a stride of 5 floats.
  const int row_tile = blockIdx.y;
  const int col_tile = blockIdx.x;

  // The last tile in each direction may be only partially filled.
  const int rows_in_tile =
      min(n_boxes - row_tile * threadsPerBlock, threadsPerBlock);
  const int cols_in_tile =
      min(n_boxes - col_tile * threadsPerBlock, threadsPerBlock);

  // Stage the column tile in shared memory, one box (5 floats) per thread.
  __shared__ float block_boxes[threadsPerBlock * 5];
  if (threadIdx.x < cols_in_tile) {
    const float *src =
        dev_boxes + (threadsPerBlock * col_tile + threadIdx.x) * 5;
    float *dst = block_boxes + threadIdx.x * 5;
    for (int k = 0; k < 5; ++k) {
      dst[k] = src[k];
    }
  }
  __syncthreads();

  // Threads beyond the end of the row tile have no box to compare.
  if (threadIdx.x >= rows_in_tile) {
    return;
  }

  const int box_idx = threadsPerBlock * row_tile + threadIdx.x;
  const float *box = dev_boxes + box_idx * 5;

  // On a diagonal tile, compare only against boxes that come after this one,
  // so a box is never compared with itself or with an earlier box.
  const int first_col =
      (row_tile == col_tile) ? static_cast<int>(threadIdx.x) + 1 : 0;

  unsigned long long bits = 0;
  for (int c = first_col; c < cols_in_tile; c++) {
    if (devIoU(box, block_boxes + c * 5) > nms_overlap_thresh) {
      bits |= 1ULL << c;
    }
  }

  const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
  dev_mask[box_idx * col_blocks + col_tile] = bits;
}
void _set_device(int device_id) {
  // Make `device_id` the current CUDA device, skipping the cudaSetDevice
  // call when it is already the one reported by cudaGetDevice.
  int active_device;
  CUDA_CHECK(cudaGetDevice(&active_device));
  if (active_device != device_id) {
    CUDA_CHECK(cudaSetDevice(device_id));
  }
}
void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
          int boxes_dim, float nms_overlap_thresh, int device_id) {
  // Host driver for GPU non-maximum suppression.
  //
  // Copies boxes_num * boxes_dim floats from boxes_host to the device, runs
  // nms_kernel to build the pairwise suppression bit masks, then walks the
  // boxes in their given order on the host: a box is kept unless an earlier
  // kept box has flagged it. Kept indices are written to keep_out (up to
  // boxes_num entries) and their count to *num_out.
  //
  // CUDA errors are reported through CUDA_CHECK (printed, not propagated).
  _set_device(device_id);

  // Nothing to do for an empty input. Returning here also avoids taking
  // &v[0] on empty vectors and launching a kernel with a zero-sized grid.
  if (boxes_num <= 0) {
    *num_out = 0;
    return;
  }

  float* boxes_dev = NULL;
  unsigned long long* mask_dev = NULL;

  const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
  // Do the size arithmetic in size_t so large inputs cannot overflow int.
  const size_t boxes_bytes =
      static_cast<size_t>(boxes_num) * boxes_dim * sizeof(float);
  const size_t mask_words = static_cast<size_t>(boxes_num) * col_blocks;
  const size_t mask_bytes = mask_words * sizeof(unsigned long long);

  CUDA_CHECK(cudaMalloc(&boxes_dev, boxes_bytes));
  CUDA_CHECK(cudaMemcpy(boxes_dev,
                        boxes_host,
                        boxes_bytes,
                        cudaMemcpyHostToDevice));

  CUDA_CHECK(cudaMalloc(&mask_dev, mask_bytes));

  dim3 blocks(col_blocks, col_blocks);
  dim3 threads(threadsPerBlock);
  nms_kernel<<<blocks, threads>>>(boxes_num,
                                  nms_overlap_thresh,
                                  boxes_dev,
                                  mask_dev);
  // Report launch failures (e.g. an invalid configuration) immediately.
  CUDA_CHECK(cudaGetLastError());

  std::vector<unsigned long long> mask_host(mask_words);
  CUDA_CHECK(cudaMemcpy(&mask_host[0],
                        mask_dev,
                        mask_bytes,
                        cudaMemcpyDeviceToHost));

  // One bit per box, set once the box has been suppressed. The vector
  // constructor value-initialises every word to 0, so no memset is needed.
  std::vector<unsigned long long> remv(col_blocks);

  int num_to_keep = 0;
  for (int i = 0; i < boxes_num; i++) {
    int nblock = i / threadsPerBlock;
    int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[num_to_keep++] = i;
      // Merge in everything box i suppresses. Mask words before nblock only
      // cover boxes that have already been visited.
      const unsigned long long *p =
          &mask_host[0] + static_cast<size_t>(i) * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }
  *num_out = num_to_keep;

  CUDA_CHECK(cudaFree(boxes_dev));
  CUDA_CHECK(cudaFree(mask_dev));
}
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np
def py_cpu_nms(dets, thresh):
    """Greedy non-maximum suppression implemented with NumPy only.

    Each row of ``dets`` is read as ``x1, y1, x2, y2, score`` (columns 0-4).
    The highest-scoring remaining box is kept, and every remaining box whose
    IoU with it exceeds ``thresh`` is dropped; this repeats until no boxes
    remain. Widths and heights use the inclusive ``+ 1`` convention.

    Returns a list with the row indices of the kept boxes, highest score
    first.
    """
    x1, y1, x2, y2, scores = (dets[:, col] for col in range(5))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    # Candidate indices, best score first.
    remaining = scores.argsort()[::-1]
    keep = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        keep.append(best)

        # Intersection of the best box with every other candidate.
        left = np.maximum(x1[best], x1[rest])
        top = np.maximum(y1[best], y1[rest])
        right = np.minimum(x2[best], x2[rest])
        bottom = np.minimum(y2[best], y2[rest])
        inter = np.maximum(0.0, right - left + 1) * \
            np.maximum(0.0, bottom - top + 1)

        iou = inter / (areas[best] + areas[rest] - inter)
        # Candidates at or below the threshold survive to the next round.
        remaining = rest[iou <= thresh]
    return keep
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import os
import subprocess
import sysconfig
from os.path import join as pjoin

# Keep setuptools ahead of distutils so that its patches are applied first.
from setuptools import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
import numpy as np
def find_in_path(name, path):
    """Return the absolute path of the first ``name`` found along ``path``.

    ``path`` is a string of directories separated by ``os.pathsep``; they are
    probed in order. Returns ``None`` when none of them contains ``name``.
    """
    candidates = (pjoin(folder, name) for folder in path.split(os.pathsep))
    matches = (os.path.abspath(c) for c in candidates if os.path.exists(c))
    return next(matches, None)
def locate_cuda():
    """Find the CUDA installation to build against.

    If the ``CUDAHOME`` environment variable is set, it is used as the CUDA
    root and nvcc is taken to be ``<root>/bin/nvcc``. Otherwise ``nvcc`` is
    searched for along ``$PATH`` followed by ``/usr/local/cuda/bin``, and the
    root is the directory two levels above the binary that was found.

    Returns a dict with the keys ``'home'``, ``'nvcc'``, ``'include'`` and
    ``'lib64'``.

    Raises ``EnvironmentError`` if nvcc cannot be found on the search path or
    if any of the four resulting paths does not exist.
    """
    cuda_root = os.environ.get('CUDAHOME')
    if cuda_root is not None:
        nvcc_path = pjoin(cuda_root, 'bin', 'nvcc')
    else:
        # No explicit root: look for nvcc on PATH, then in the default place.
        fallback_dir = pjoin(os.sep, 'usr', 'local', 'cuda', 'bin')
        search_path = os.pathsep.join((os.environ['PATH'], fallback_dir))
        nvcc_path = find_in_path('nvcc', search_path)
        if nvcc_path is None:
            raise EnvironmentError(
                'The nvcc binary could not be '
                'located in your $PATH. Either add it to your path, or set $CUDAHOME'
            )
        cuda_root = os.path.dirname(os.path.dirname(nvcc_path))

    cudaconfig = {
        'home': cuda_root,
        'nvcc': nvcc_path,
        'include': pjoin(cuda_root, 'include'),
        'lib64': pjoin(cuda_root, 'lib64'),
    }
    # Fail early, naming the first component that is missing on disk.
    for label, location in cudaconfig.items():
        if not os.path.exists(location):
            raise EnvironmentError(
                'The CUDA %s path could not be located in %s' % (label,
                                                                 location))
    return cudaconfig
# Resolve the CUDA toolchain when this script is loaded; locate_cuda() raises
# EnvironmentError if nvcc or one of the CUDA directories cannot be found.
CUDA = locate_cuda()
# Obtain the numpy include directory. This logic works across numpy versions.
try:
    numpy_include = np.get_include()
except AttributeError:
    # Fallback accessor for numpy builds that lack get_include().
    numpy_include = np.get_numpy_include()
def customize_compiler_for_nvcc(self):
    """Patch a compiler instance so that ``.cu`` sources are built by nvcc.

    ``self`` is the compiler object to modify in place. ``.cu`` is added to
    its recognised source extensions and its ``_compile`` method is replaced
    by a wrapper that, for each source file:

    * switches ``compiler_so`` to nvcc and uses ``extra_postargs['nvcc']``
      when the source ends in ``.cu``, or uses ``extra_postargs['gcc']``
      otherwise;
    * delegates to the original ``_compile``;
    * restores the original ``compiler_so`` afterwards.

    ``extra_postargs`` is therefore expected to be a dict with ``'gcc'`` and
    ``'nvcc'`` keys rather than a flat list.
    """
    self.src_extensions.append('.cu')

    # Remember the stock settings so the wrapper can delegate and restore.
    host_compiler_so = self.compiler_so
    stock_compile = self._compile

    def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
        is_cuda_source = os.path.splitext(src)[1] == '.cu'
        if is_cuda_source:
            self.set_executable('compiler_so', CUDA['nvcc'])
        postargs = extra_postargs['nvcc' if is_cuda_source else 'gcc']

        stock_compile(obj, src, ext, cc_args, postargs, pp_opts)
        # Undo the nvcc switch (a no-op for non-CUDA sources).
        self.compiler_so = host_compiler_so

    self._compile = _compile
# run the customize_compiler
class custom_build_ext(build_ext):
    """``build_ext`` command that installs the nvcc hook before building."""

    def build_extensions(self):
        # Patch the compiler first so .cu sources are dispatched to nvcc.
        customize_compiler_for_nvcc(self.compiler)
        super().build_extensions()
# Link against the libpython that belongs to the interpreter running this
# script, rather than a directory and version that exist only on one
# developer's machine.
_python_lib_dirs = [
    lib_dir for lib_dir in (sysconfig.get_config_var('LIBDIR'), ) if lib_dir
]
_python_lib = 'python' + (sysconfig.get_config_var('LDVERSION') or
                          sysconfig.get_python_version())

ext_modules = [
    Extension(
        "cpu_nms", ["cpu_nms.pyx"],
        extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]},
        include_dirs=[numpy_include],
        library_dirs=list(_python_lib_dirs),
        libraries=[_python_lib]),
    Extension(
        'gpu_nms',
        ['nms_kernel.cu', 'gpu_nms.pyx'],
        library_dirs=[CUDA['lib64']] + _python_lib_dirs,
        libraries=['cudart', _python_lib],
        language='c++',
        runtime_library_dirs=[CUDA['lib64']],
        # This syntax is specific to this build system: extra_compile_args is
        # a dict so that some arguments go only to nvcc and others only to
        # gcc. The dispatch is implemented in customize_compiler_for_nvcc()
        # above.
        extra_compile_args={
            'gcc': ["-Wno-unused-function"],
            'nvcc': [
                '-arch=sm_35', '--ptxas-options=-v', '-c', '--compiler-options',
                "'-fPIC'"
            ]
        },
        include_dirs=[numpy_include, CUDA['include']]),
]
# Register the extension modules and replace the stock build_ext command with
# custom_build_ext.
setup(
    name='fast_rcnn',
    ext_modules=ext_modules,
    # inject our custom trigger
    cmdclass={'build_ext': custom_build_ext}, )
"""
This code is based on https://github.com/garrickbrazil/M3D-RPN/blob/master/lib/rpn_util.py
This file is meant to contain functions which are
specific to region proposal networks.
"""
#import matplotlib.pyplot as plt
import subprocess
#import torch
import math
import os
import re
import gc
from lib.util import *
from lib.core import *
from data.augmentations import *
from lib.nms.gpu_nms import gpu_nms
from copy import deepcopy
import numpy as np
import pdb
import cv2
import paddle
from paddle.fluid import framework
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
def generate_anchors(conf, imdb, cache_folder):
    """
    Generates the anchors according to the configuration and
    (optionally) based on the imdb properties.

    Anchors are read from ``<cache_folder>/anchors.pkl`` when that file
    exists. Otherwise one 2D anchor is built per (scale, ratio) pair of
    ``conf.anchor_scales`` x ``conf.anchor_ratios``; the set is then either
    clustered (``conf.cluster_anchors``) or, when ``conf.has_3d`` is set,
    extended with per-anchor 3D statistics. Freshly generated anchors are
    written back to the cache file.

    Args:
        conf: configuration object; the result is stored in ``conf.anchors``.
        imdb: iterable of image objects carrying ground truths (``gts``).
        cache_folder: folder holding ``anchors.pkl``, or None to skip caching.

    Raises:
        ValueError: if an anchor has no matched ground truth while the 3D
            statistics are computed.
    """
    cache_file = None
    if cache_folder is not None:
        cache_file = os.path.join(cache_folder, 'anchors.pkl')

    # use cache?
    if cache_file is not None and os.path.exists(cache_file):
        anchors = pickle_read(cache_file)

    # generate anchors
    else:
        anchors = np.zeros(
            [len(conf.anchor_scales) * len(conf.anchor_ratios), 4],
            dtype=np.float32)

        # compute simple anchors based on scale/ratios
        aind = 0
        for scale in conf.anchor_scales:
            for ratio in conf.anchor_ratios:
                h = scale
                w = scale * ratio
                anchors[aind, 0:4] = anchor_center(w, h, conf.feat_stride)
                aind += 1

        # optionally cluster anchors
        if conf.cluster_anchors:
            anchors = cluster_anchors(
                conf.feat_stride, anchors, conf.test_scale, imdb, conf.lbls,
                conf.ilbls, conf.anchor_ratios, conf.min_gt_vis, conf.min_gt_h,
                conf.max_gt_h, conf.even_anchors, conf.expand_anchors)

        # has 3d? then need to compute stats for each new dimension
        # presuming that anchors are initialized in "2d"
        elif conf.has_3d:
            anchors = _anchors_with_3d_stats(conf, imdb, anchors)

        if cache_file is not None:
            pickle_write(cache_file, anchors)

    conf.anchors = anchors


def _anchors_with_3d_stats(conf, imdb, anchors):
    """Append mean [z3d, w3d, h3d, l3d, rotY] of matched gts to 2D anchors."""
    normalized_gts = []

    # check all images
    for imobj in imdb:

        # has ground truths?
        if len(imobj.gts) == 0:
            continue

        scale = imobj.scale * conf.test_scale / imobj.imH

        # determine ignores
        igns, rmvs = determine_ignores(imobj.gts, conf.lbls, conf.ilbls,
                                       conf.min_gt_vis, conf.min_gt_h, np.inf,
                                       scale)
        keep = ~rmvs & ~igns

        # accumulate boxes
        gts_all = bbXYWH2Coords(
            np.array([gt.bbox_full * scale for gt in imobj.gts]))
        gts_val = gts_all[keep, :]
        gts_3d = np.array([gt.bbox_3d for gt in imobj.gts])[keep, :]

        if gts_val.shape[0] > 0:

            # center all 2D ground truths so only their shape is compared
            for gtind in range(gts_val.shape[0]):
                w = gts_val[gtind, 2] - gts_val[gtind, 0] + 1
                h = gts_val[gtind, 3] - gts_val[gtind, 1] + 1
                gts_val[gtind, 0:4] = anchor_center(w, h, conf.feat_stride)

            normalized_gts += np.concatenate(
                (gts_val, gts_3d), axis=1).tolist()

    # convert to np
    normalized_gts = np.array(normalized_gts)

    # expand dimensions
    anchors = np.concatenate(
        (anchors, np.zeros([anchors.shape[0], 5])), axis=1)

    # find best matches for each ground truth
    ols = iou(anchors[:, 0:4], normalized_gts[:, 0:4])
    gt_target_ols = np.amax(ols, axis=0)
    gt_target_anchor = np.argmax(ols, axis=0)

    # assign each box to an anchor
    # bbox_3d order --> [cx3d, cy3d, cz3d, w3d, h3d, l3d, rotY], stored after
    # the four 2D coordinates, so columns 6..10 hold [z, w, h, l, rotY]
    matched = [[] for _ in range(anchors.shape[0])]
    for gtind, gt in enumerate(normalized_gts):
        if gt_target_ols[gtind] > 0.2:
            matched[gt_target_anchor[gtind]].append(gt[6:11])

    # update anchors
    for aind, stats in enumerate(matched):
        if len(stats) == 0:
            raise ValueError('Non-used anchor #{} found'.format(aind))
        for offset, column in enumerate(np.array(stats).T):
            anchors[aind, 4 + offset] = np.mean(column)

    return anchors
def anchor_center(w, h, stride):
    """
    Builds a [x1, y1, x2, y2] float32 anchor of shape (w, h) whose center
    sits at ((stride - 1) / 2, (stride - 1) / 2).

    For a feature stride of 16 px the box is therefore centered at
    (7.5, 7.5) rather than (0, 0).
    """
    shift = (stride - 1) / 2
    half_w = w / 2
    half_h = h / 2

    return np.array(
        [-half_w + shift, -half_h + shift, half_w + shift, half_h + shift],
        dtype=np.float32)
def cluster_anchors(feat_stride,
anchors,
test_scale,
imdb,
lbls,
ilbls,
anchor_ratios,
min_gt_vis=0.99,
min_gt_h=0,
max_gt_h=10e10,
even_anchor_distribution=False,
expand_anchors=False,
expand_stop_dt=0.0025):
"""
Clusters the anchors based on the imdb boxes (in 2D and/or 3D).
#
Generally, this method does a custom k-means clustering using 2D IoU
as a distance metric.
"""
# Overview of what the code below does:
#   1. gather every kept ground-truth box from imdb, re-centered with
#      anchor_center so only its width/height matter;
#   2. repeatedly assign each box to its highest-IoU anchor and move each
#      anchor to the mean shape (and mean 3D stats) of its boxes;
#   3. optionally append one more anchor and repeat while the best mean IoU
#      keeps improving by more than expand_stop_dt.
# Returns best_anchors: the non-zero anchor rows of the clustering round
# with the highest mean IoU (or the incoming anchors if no round beat 0).
normalized_gts = []
# keep track if using 3d
has_3d = False
# check all images
for imind, imobj in enumerate(imdb):
# has ground truths?
if len(imobj.gts) > 0:
# factor applied to the 2D boxes of this image before clustering
scale = imobj.scale * test_scale / imobj.imH
# determine ignores
igns, rmvs = determine_ignores(imobj.gts, lbls, ilbls, min_gt_vis,
min_gt_h, np.inf, scale)
# check for 3d box
# NOTE(review): has_3d is overwritten for each image, so after the loop it
# reflects the last image that was inspected — confirm all images agree.
has_3d = 'bbox_3d' in imobj.gts[0]
# accumulate boxes
gts_all = bbXYWH2Coords(
np.array([gt.bbox_full * scale for gt in imobj.gts]))
# keep only boxes that are neither removed nor ignored
gts_val = gts_all[(rmvs == False) & (igns == False), :]
if has_3d:
gts_3d = np.array([gt.bbox_3d for gt in imobj.gts])
gts_3d = gts_3d[(rmvs == False) & (igns == False), :]
if gts_val.shape[0] > 0:
# center all 2D ground truths
for gtind in range(0, gts_val.shape[0]):
w = gts_val[gtind, 2] - gts_val[gtind, 0] + 1
h = gts_val[gtind, 3] - gts_val[gtind, 1] + 1
gts_val[gtind, 0:4] = anchor_center(w, h, feat_stride)
if gts_val.shape[0] > 0:
# add normalized gts given 3d or 2d boxes
if has_3d:
normalized_gts += np.concatenate(
(gts_val, gts_3d), axis=1).tolist()
else:
normalized_gts += gts_val.tolist()
# convert to np
normalized_gts = np.array(normalized_gts)
# sort by height
sorted_inds = np.argsort((normalized_gts[:, 3] - normalized_gts[:, 1] + 1))
normalized_gts = normalized_gts[sorted_inds, :]
# NOTE(review): min_h and max_h are computed but not read again below.
min_h = normalized_gts[0, 3] - normalized_gts[0, 1] + 1
max_h = normalized_gts[-1, 3] - normalized_gts[-1, 1] + 1
# for 3d, expand dimensions
# (five extra columns, filled below with z3d, w3d, h3d, l3d, rotY means)
if has_3d:
anchors = np.concatenate(
(anchors, np.zeros([anchors.shape[0], 5])), axis=1)
# init expand
best_anchors = anchors
expand_last_iou = 0
expand_dif = 1
best_iou = 0
best_cov = 0
# outer loop: one clustering run per anchor count, until the gain in best
# mean IoU (rounded to 5 decimals) is no larger than expand_stop_dt
while np.round(expand_dif, 5) > expand_stop_dt:
# init cluster
max_rounds = 1000
# NOTE(review): `round` shadows the builtin from here on; the np.round
# call in the loop condition above is not affected.
round = 0
last_iou = 0
dif = 1
if even_anchor_distribution:
# number of (height-sorted) ground truths handed to each anchor
sample_num = int(
np.floor(normalized_gts.shape[0] / anchors.shape[0]))
# evenly distribute the anchors
for aind in range(0, anchors.shape[0]):
x1 = normalized_gts[aind * sample_num:(aind * sample_num +
sample_num), 0]
y1 = normalized_gts[aind * sample_num:(aind * sample_num +
sample_num), 1]
x2 = normalized_gts[aind * sample_num:(aind * sample_num +
sample_num), 2]
y2 = normalized_gts[aind * sample_num:(aind * sample_num +
sample_num), 3]
w = np.mean(x2 - x1 + 1)
h = np.mean(y2 - y1 + 1)
anchors[aind, 0:4] = anchor_center(w, h, feat_stride)
else:
# geometric progression of heights from min_gt_h up to max_gt_h
# NOTE(review): this divides by min_gt_h (default 0) and by
# anchors.shape[0] - 1 (0 for a single anchor) — confirm callers always
# pass a positive min_gt_h and more than one anchor on this path.
base = ((max_gt_h) / (min_gt_h))**(1 / (anchors.shape[0] - 1))
anchor_scales = np.array(
[(min_gt_h) * (base**i) for i in range(0, anchors.shape[0])])
aind = 0
# compute new anchors
# NOTE(review): aind advances once per (scale, ratio) pair while there is
# one scale per anchor row; looks like this assumes a single ratio — verify.
for scale in anchor_scales:
for ratio in anchor_ratios:
h = scale
w = scale * ratio
anchors[aind, 0:4] = anchor_center(w, h, feat_stride)
aind += 1
# inner loop: k-means style rounds; `dif > -0.0` is the same test as
# `dif > 0`, i.e. stop once the mean IoU no longer increases
while round < max_rounds and dif > -0.0:
# make empty arrays for each anchor
anchors_h = [[] for x in range(anchors.shape[0])]
anchors_w = [[] for x in range(anchors.shape[0])]
if has_3d:
# bbox_3d order --> [cx3d, cy3d, cz3d, w3d, h3d, l3d, rotY]
anchors_z3d = [[] for x in range(anchors.shape[0])]
anchors_w3d = [[] for x in range(anchors.shape[0])]
anchors_h3d = [[] for x in range(anchors.shape[0])]
anchors_l3d = [[] for x in range(anchors.shape[0])]
anchors_rotY = [[] for x in range(anchors.shape[0])]
round_ious = []
# find best matches for each ground truth
ols = iou(anchors[:, 0:4], normalized_gts[:, 0:4])
gt_target_ols = np.amax(ols, axis=0)
gt_target_anchor = np.argmax(ols, axis=0)
# assign each box to an anchor
for gtind, gt in enumerate(normalized_gts):
anum = gt_target_anchor[gtind]
w = gt[2] - gt[0] + 1
h = gt[3] - gt[1] + 1
anchors_h[anum].append(h)
anchors_w[anum].append(w)
if has_3d:
# columns 6..10 of a normalized gt hold cz3d, w3d, h3d, l3d, rotY
anchors_z3d[anum].append(gt[6])
anchors_w3d[anum].append(gt[7])
anchors_h3d[anum].append(gt[8])
anchors_l3d[anum].append(gt[9])
anchors_rotY[anum].append(gt[10])
round_ious.append(gt_target_ols[gtind])
# compute current iou
cur_iou = np.mean(np.array(round_ious))
# update anchors
for aind in range(0, anchors.shape[0]):
# compute mean h/w
if len(np.array(anchors_h[aind])) > 0:
mean_h = np.mean(np.array(anchors_h[aind]))
mean_w = np.mean(np.array(anchors_w[aind]))
anchors[aind, 0:4] = anchor_center(mean_w, mean_h,
feat_stride)
if has_3d:
anchors[aind, 4] = np.mean(np.array(anchors_z3d[aind]))
anchors[aind, 5] = np.mean(np.array(anchors_w3d[aind]))
anchors[aind, 6] = np.mean(np.array(anchors_h3d[aind]))
anchors[aind, 7] = np.mean(np.array(anchors_l3d[aind]))
anchors[aind, 8] = np.mean(np.array(anchors_rotY[aind]))
else:
# anchor not used
anchors[aind, :] = 0
anchors = np.nan_to_num(anchors)
# an anchor counts as valid unless its whole row is zero
valid_anchors = np.invert(np.all(anchors == 0, axis=1))
# redistribute non-valid anchors
valid_anchors_inds = np.flatnonzero(valid_anchors)
# determine most heavy anchors (to be split up)
valid_multi = np.array([len(x) for x in anchors_h])
valid_multi = valid_multi[valid_anchors_inds]
valid_multi = valid_multi / np.sum(valid_multi)
# store best configuration
if cur_iou > best_iou:
best_iou = cur_iou
best_anchors = anchors[valid_anchors, :]
best_cov = np.mean(np.array(round_ious) > 0.5)
# add random new anchors for any not used
for aind in range(0, anchors.shape[0]):
# make a new anchor
if not valid_anchors[aind]:
# random convex combination of the valid anchors, biased toward the
# anchors that currently hold the most ground truths
randomness = 0.5
multi = randomness * np.random.rand(len(valid_anchors_inds))
multi += valid_multi
multi /= np.sum(multi)
anchors[aind, :] = np.dot(anchors[valid_anchors_inds, :].T,
multi.T)
if not all(valid_anchors):
logging.info(
'warning: round {} some anchors not used during clustering'.
format(round))
dif = cur_iou - last_iou
last_iou = cur_iou
round += 1
logging.info(
'anchors={}, rounds={}, mean_iou={:.4f}, gt_coverage={:.4f}'.format(
anchors.shape[0], round, best_iou, best_cov))
expand_dif = best_iou - expand_last_iou
expand_last_iou = best_iou
# expand anchors to next size
# (expand_anchors is compared as a number; the default False behaves as 0,
# so no anchor is ever appended with the default)
if anchors.shape[0] < expand_anchors and expand_dif > expand_stop_dt:
# append blank anchor
if has_3d:
anchors = np.vstack((anchors, [0, 0, 0, 0, 0, 0, 0, 0, 0]))
else:
anchors = np.vstack((anchors, [0, 0, 0, 0]))
# force stop
else:
# a negative difference fails the outer loop condition
expand_dif = -1
logging.info('final_iou={:.4f}, final_coverage={:.4f}'.format(best_iou,
best_cov))
return best_anchors
def compute_targets(gts_val,
                    gts_ign,
                    box_lbls,
                    rois,
                    fg_thresh,
                    ign_thresh,
                    bg_thresh_lo,
                    bg_thresh_hi,
                    best_thresh,
                    gts_3d=None,
                    anchors=None,
                    tracker=None):
    """
    Computes the bbox targets of a set of rois and a set
    of ground truth boxes, provided various ignore
    settings in configuration

    Args:
        gts_val: valid ground-truth boxes as [x1, y1, x2, y2] rows.
        gts_ign: ignored ground-truth boxes as [x1, y1, x2, y2] rows.
        box_lbls: class index (>= 1) for each row of gts_val.
        rois: candidate boxes; the first four columns are [x1, y1, x2, y2].
        fg_thresh: minimum overlap with a valid gt for a roi to be foreground.
        ign_thresh: minimum overlap with an ignored gt for a roi to be ignored.
        bg_thresh_lo, bg_thresh_hi: overlap range [lo, hi) marking background.
        best_thresh: the best roi of each gt is foreground as well when its
            overlap is at least this value.
        gts_3d: optional 3D ground truths aligned with gts_val; enables the
            3D targets.
        anchors: anchor table whose columns from index 4 on hold the 3D
            priors (only used with gts_3d).
        tracker: anchor index of every roi (only used with gts_3d).

    Returns:
        (transforms, ols, raw_gt): ``transforms`` holds
        [dx, dy, dw, dh, label] (+ 3D terms) per roi with labels
        bg=-1, ign=0, fg>=1; ``ols`` is the roi/valid-gt overlap matrix or
        None when there are no valid gts; ``raw_gt`` holds the matched
        ground truths of the foreground rois (its column 4 is left at 0).
    """
    ols = None
    has_3d = gts_3d is not None

    # init transforms which respectively hold [dx, dy, dw, dh, label]
    # for labels bg=-1, ign=0, fg>=1
    transforms = np.zeros([len(rois), 5], dtype=np.float32)
    raw_gt = np.zeros([len(rois), 5], dtype=np.float32)

    # if 3d, then init other terms after
    if has_3d:
        extra_cols = [(0, 0), (0, gts_3d.shape[1])]
        transforms = np.pad(transforms, extra_cols, 'constant')
        raw_gt = np.pad(raw_gt, extra_cols, 'constant')

    # no ground truth of any kind: all background
    if gts_val.shape[0] == 0 and gts_ign.shape[0] == 0:
        transforms[:, 4] = -1
        return transforms, ols, raw_gt

    if gts_ign.shape[0] > 0:
        # compute overlaps ign
        ols_ign_max = np.amax(iou_ign(rois, gts_ign), axis=1)
    else:
        ols_ign_max = np.zeros([rois.shape[0]], dtype=np.float32)

    if gts_val.shape[0] > 0:

        # compute overlaps valid
        ols = iou(rois, gts_val)
        ols_max = np.amax(ols, axis=1)
        targets = np.argmax(ols, axis=1)

        # find best matches for each ground truth
        gt_best_rois = np.argmax(ols, axis=0)
        gt_best_ols = np.amax(ols, axis=0)
        gt_best_rois = gt_best_rois[gt_best_ols >= best_thresh]

        fg_inds = np.flatnonzero(ols_max >= fg_thresh)
        fg_inds = np.unique(np.concatenate((fg_inds, gt_best_rois)))

        if len(fg_inds) > 0:

            target_rois = gts_val[targets[fg_inds], :]
            src_rois = rois[fg_inds, :]

            # compute 2d transform
            transforms[fg_inds, 0:4] = bbox_transform(src_rois, target_rois)
            raw_gt[fg_inds, 0:4] = target_rois

            if has_3d:
                tracker = tracker.astype(np.int64)
                src_3d = anchors[tracker[fg_inds], 4:]
                target_3d = gts_3d[targets[fg_inds]]
                raw_gt[fg_inds, 5:] = target_3d

                # compute 3d transform
                transforms[fg_inds, 5:] = bbox_transform_3d(
                    src_rois, src_3d, target_3d)

            # store labels
            transforms[fg_inds, 4] = [box_lbls[x] for x in targets[fg_inds]]
            assert (all(transforms[fg_inds, 4] >= 1))

    else:
        ols_max = np.zeros(rois.shape[0], dtype=int)
        fg_inds = np.empty(shape=[0])
        gt_best_rois = np.empty(shape=[0])

    # determine ignores
    ign_inds = np.flatnonzero(ols_ign_max >= ign_thresh)

    # determine background
    bg_inds = np.flatnonzero((ols_max >= bg_thresh_lo) &
                             (ols_max < bg_thresh_hi))

    # subtract fg and igns from background
    for excluded in (ign_inds, fg_inds, gt_best_rois):
        bg_inds = np.setdiff1d(bg_inds, excluded)

    # mark background
    transforms[bg_inds, 4] = -1

    return transforms, ols, raw_gt
def hill_climb(p2,
               p2_inv,
               box_2d,
               x2d,
               y2d,
               z2d,
               w3d,
               h3d,
               l3d,
               ry3d,
               step_z_init=0,
               step_r_init=0,
               z_lim=0,
               r_lim=0,
               min_ol_dif=0.0):
    """
    Greedy coordinate search over depth (z2d) and rotation (ry3d).

    Each pass probes one step below and one step above the current value
    with test_projection. A move is taken when it raises the overlap by
    more than min_ol_dif and the probe is not flagged invalid; otherwise
    the step is halved. The search ends once both steps are at or below
    their limits (z_lim / r_lim).

    Returns:
        (z2d, ry3d, verts_best) with ry3d wrapped into [-pi, pi].
    """

    def _probe(z, ry):
        # overlap, vertices and invalid flag of one candidate pose
        overlap, verts, _, bad = test_projection(p2, p2_inv, box_2d, x2d, y2d,
                                                 z, w3d, h3d, l3d, ry)
        return overlap, verts, bad

    def _direction(current, lower, upper):
        # +1 -> take the upper probe, -1 -> take the lower probe,
        #  0 -> neither helps (caller halves the step)
        gain_up = upper[0] - current
        gain_down = lower[0] - current
        if gain_up <= min_ol_dif and gain_down <= min_ol_dif:
            return 0
        if gain_up > min_ol_dif and upper[0] > lower[0] and not upper[2]:
            return 1
        if gain_down > min_ol_dif and not lower[2]:
            return -1
        return 0

    ol_best, verts_best, invalid = _probe(z2d, ry3d)
    if invalid:
        return z2d, ry3d, verts_best

    step_z, step_r = step_z_init, step_r_init

    # attempt to fit z/rot more properly
    while step_z > z_lim or step_r > r_lim:

        if step_z > z_lim:
            lower = _probe(z2d - step_z, ry3d)
            upper = _probe(z2d + step_z, ry3d)
            move = _direction(ol_best, lower, upper)
            if move > 0:
                z2d += step_z
                ol_best, verts_best = upper[0], upper[1]
            elif move < 0:
                z2d -= step_z
                ol_best, verts_best = lower[0], lower[1]
            else:
                step_z = step_z * 0.5

        if step_r > r_lim:
            lower = _probe(z2d, ry3d - step_r)
            upper = _probe(z2d, ry3d + step_r)
            move = _direction(ol_best, lower, upper)
            if move > 0:
                ry3d += step_r
                ol_best, verts_best = upper[0], upper[1]
            elif move < 0:
                ry3d -= step_r
                ol_best, verts_best = lower[0], lower[1]
            else:
                step_r = step_r * 0.5

    # wrap the rotation back into [-pi, pi]
    full_turn = math.pi * 2
    while ry3d > math.pi:
        ry3d -= full_turn
    while ry3d < (-math.pi):
        ry3d += full_turn

    return z2d, ry3d, verts_best
# def clsInd2Name(lbls, ind):
# """
# Converts a cls ind to string name
# """
# if ind>=0 and ind<len(lbls):
# return lbls[ind]
# else:
# raise ValueError('unknown class')
def clsName2Ind(lbls, cls):
    """
    Maps a class name to its 1-based position in lbls.

    Raises:
        ValueError: if cls is not contained in lbls.
    """
    if cls not in lbls:
        raise ValueError('unknown class')

    position = lbls.index(cls)
    return position + 1
def compute_bbox_stats(conf, imdb, cache_folder=''):
    """
    Computes the mean and standard deviation for each regression
    parameter (usually pertaining to [dx, dy, sw, sh] but sometimes
    for 3d parameters too).
    Once these stats are known we normalize the regression targets
    to have 0 mean and 1 variance, to hypothetically ease training.

    The statistics are read from ``bbox_means.pkl`` / ``bbox_stds.pkl`` in
    cache_folder when both files exist; otherwise they are computed over the
    foreground targets of every image in imdb and written to those files.

    Args:
        conf: configuration object; ``conf.bbox_means`` and
            ``conf.bbox_stds`` are assigned on exit (nothing is returned).
        imdb: iterable of image objects carrying ground truths (``gts``).
        cache_folder: folder for the two cache files, or None to disable
            caching. The default '' resolves to the current directory.
    """
    means_file = stds_file = None
    if cache_folder is not None:
        means_file = os.path.join(cache_folder, 'bbox_means.pkl')
        stds_file = os.path.join(cache_folder, 'bbox_stds.pkl')

    if (means_file is not None and os.path.exists(means_file) and
            os.path.exists(stds_file)):
        means = pickle_read(means_file)
        stds = pickle_read(stds_file)

    else:
        # 4 parameters in 2D, plus 7 more when 3D targets are used
        num_params = 11 if conf.has_3d else 4

        # np.longdouble is the widest float available on every platform
        # (np.float128 is only defined where long double is that wide)
        sums = np.zeros([1, num_params], dtype=np.longdouble)
        squared_sums = np.zeros([1, num_params], dtype=np.longdouble)
        class_counts = np.zeros([1], dtype=np.longdouble) + 1e-10

        # compute the mean first
        logging.info('Computing bbox regression mean..')

        # foreground targets are kept so the std pass does not have to
        # recompute them for every image
        fg_per_image = []
        for imobj in imdb:
            if len(imobj.gts) == 0:
                continue
            fg_targets = _fg_regression_targets(conf, imobj)
            if fg_targets.shape[0] > 0:
                fg_per_image.append(fg_targets)
                sums += np.sum(fg_targets, axis=0)
                class_counts += fg_targets.shape[0]

        means = sums / class_counts

        logging.info('Computing bbox regression stds..')

        for fg_targets in fg_per_image:
            squared_sums += np.sum(np.power(fg_targets - means, 2), axis=0)

        stds = np.sqrt((squared_sums / class_counts))

        means = means.astype(float)
        stds = stds.astype(float)

        logging.info('used {:d} boxes with avg std {:.4f}'.format(
            int(class_counts[0]), np.mean(stds)))

        if means_file is not None:
            pickle_write(means_file, means)
            pickle_write(stds_file, stds)

    conf.bbox_means = means
    conf.bbox_stds = stds


def _fg_regression_targets(conf, imobj):
    """Regression parameters (label column dropped) of one image's fg rois."""
    scale_factor = imobj.scale * conf.test_scale / imobj.imH
    feat_size = calc_output_size(
        np.array([imobj.imH, imobj.imW]) * scale_factor, conf.feat_stride)
    rois = locate_anchors(conf.anchors, feat_size, conf.feat_stride)

    # determine ignores
    igns, rmvs = determine_ignores(imobj.gts, conf.lbls, conf.ilbls,
                                   conf.min_gt_vis, conf.min_gt_h, np.inf,
                                   scale_factor)
    valid = ~rmvs & ~igns
    ignored = ~rmvs & igns

    # accumulate boxes
    gts_all = bbXYWH2Coords(
        np.array([gt.bbox_full * scale_factor for gt in imobj.gts]))

    # filter out irrelevant cls, and ignore cls
    gts_val = gts_all[valid, :]
    gts_ign = gts_all[ignored, :]

    # accumulate labels
    box_lbls = np.array([gt.cls for gt in imobj.gts])[valid]
    box_lbls = np.array([clsName2Ind(conf.lbls, cls) for cls in box_lbls])

    if conf.has_3d:

        # accumulate 3d boxes
        gts_3d = np.array([gt.bbox_3d for gt in imobj.gts])[valid, :]

        # rescale centers (in 2d)
        gts_3d[:, 0:2] *= scale_factor

        # compute transforms for all 3d
        transforms, _, _ = compute_targets(
            gts_val,
            gts_ign,
            box_lbls,
            rois,
            conf.fg_thresh,
            conf.ign_thresh,
            conf.bg_thresh_lo,
            conf.bg_thresh_hi,
            conf.best_thresh,
            gts_3d=gts_3d,
            anchors=conf.anchors,
            tracker=rois[:, 4])
    else:

        # compute transforms for 2d
        transforms, _, _ = compute_targets(
            gts_val, gts_ign, box_lbls, rois, conf.fg_thresh, conf.ign_thresh,
            conf.bg_thresh_lo, conf.bg_thresh_hi, conf.best_thresh)

    # foreground rois carry a label > 0 in column 4
    fg_rows = transforms[np.flatnonzero(transforms[:, 4] > 0)]

    if conf.has_3d:
        # skip the label column between the 2D and the 3D parameters
        return np.hstack((fg_rows[:, 0:4], fg_rows[:, 5:12]))
    return fg_rows[:, 0:4]
def flatten_tensor(input):
    """
    Moves the channel axis last and merges the two spatial axes:
    [B x C x H x W] --> [B x (H x W) x C]
    """
    batch, channels, height, width = input.shape

    channels_last = fluid.layers.transpose(input, [0, 2, 3, 1])

    return fluid.layers.reshape(channels_last,
                                [batch, height * width, channels])
# def unflatten_tensor(input, feat_size, anchors):
# """
# Un-flattens and un-permutes a tensor from size
# [B x (W x H) x C] --> [B x C x W x H]
# """
# bsize = input.shape[0]
# if len(input.shape) >= 3: csize = input.shape[2]
# else: csize = 1
# input = input.view(bsize, feat_size[0] * anchors.shape[0], feat_size[1], csize)
# input = input.permute(0, 3, 1, 2).contiguous()
# return input
def project_3d(p2, x3d, y3d, z3d, w3d, h3d, l3d, ry3d, return_3d=False):
    """
    Projects a 3D box into 2D vertices

    Args:
        p2 (nparray): projection matrix with 4 columns and at least 3 rows
            (it is multiplied with the 4 x 8 homogeneous corners and row 2
            is used as the divisor)
        x3d: x-coordinate of center of object
        y3d: y-coordinate of center of object
        z3d: z-cordinate of center of object
        w3d: width of object
        h3d: height of object
        l3d: length of object
        ry3d: rotation w.r.t y-axis
        return_3d: also return the rotated and translated 3D corners

    Returns:
        A 16 x 2 array of projected vertices, ordered so that consecutive
        entries trace the box edges (corners are repeated). With
        return_3d=True, the tuple (verts3d, corners_3d) where corners_3d is
        the 3 x 8 array of 3D corners.
    """
    # compute rotational matrix around yaw axis
    cos_ry = math.cos(ry3d)
    sin_ry = math.sin(ry3d)
    R = np.array([[+cos_ry, 0, +sin_ry], [0, 1, 0], [-sin_ry, 0, +cos_ry]])

    # 3D bounding box corners, centered on the object origin.
    # dtype=float keeps the centering valid for integer dimensions (an
    # integer array cannot take the fractional offset in place).
    x_corners = np.array([0, l3d, l3d, l3d, l3d, 0, 0, 0], dtype=float)
    y_corners = np.array([0, 0, h3d, h3d, 0, 0, h3d, h3d], dtype=float)
    z_corners = np.array([0, 0, 0, w3d, w3d, w3d, w3d, 0], dtype=float)

    x_corners += -l3d / 2
    y_corners += -h3d / 2
    z_corners += -w3d / 2

    # bounding box in object co-ordinate
    corners_3d = np.array([x_corners, y_corners, z_corners])

    # rotate
    corners_3d = R.dot(corners_3d)

    # translate
    corners_3d += np.array([x3d, y3d, z3d]).reshape((3, 1))

    # homogeneous coordinates, then perspective divide by the third row
    corners_3D_1 = np.vstack((corners_3d, np.ones((corners_3d.shape[-1]))))
    corners_2D = p2.dot(corners_3D_1)
    corners_2D = corners_2D / corners_2D[2]

    bb3d_lines_verts_idx = [0, 1, 2, 3, 4, 5, 6, 7, 0, 5, 4, 1, 2, 7, 6, 3]

    verts3d = (corners_2D[:, bb3d_lines_verts_idx][:2]).astype(float).T

    if return_3d:
        return verts3d, corners_3d
    return verts3d
def project_3d_corners(p2, x3d, y3d, z3d, w3d, h3d, l3d, ry3d):
    """
    Projects the 8 corners of a 3D box into the image

    Args:
        p2 (nparray): projection matrix with 4 columns and at least 3 rows
            (it is multiplied with the 4 x 8 homogeneous corners and row 2
            is used as the divisor)
        x3d: x-coordinate of center of object
        y3d: y-coordinate of center of object
        z3d: z-cordinate of center of object
        w3d: width of object
        h3d: height of object
        l3d: length of object
        ry3d: rotation w.r.t y-axis

    Returns:
        (corners_2D, corners_3D_1): the projected corners after division by
        their third row (one column per corner), and the 4 x 8 homogeneous
        3D corners.
    """
    # compute rotational matrix around yaw axis
    cos_ry = math.cos(ry3d)
    sin_ry = math.sin(ry3d)
    R = np.array([[+cos_ry, 0, +sin_ry], [0, 1, 0], [-sin_ry, 0, +cos_ry]])

    # 3D bounding box corners, centered on the object origin.
    # dtype=float keeps the centering valid for integer dimensions (an
    # integer array cannot take the fractional offset in place).
    x_corners = np.array([0, l3d, l3d, l3d, l3d, 0, 0, 0], dtype=float)
    y_corners = np.array([0, 0, h3d, h3d, 0, 0, h3d, h3d], dtype=float)
    z_corners = np.array([0, 0, 0, w3d, w3d, w3d, w3d, 0], dtype=float)

    # order of vertices, as noted by the original author:
    #   0 upper back right
    #   1 upper front right
    #   2 bottom front right
    #   3 bottom front left
    #   4 upper front left
    #   5 upper back left
    #   6 bottom back left
    #   7 bottom back right
    #   bot_inds = np.array([2,3,6,7])
    #   top_inds = np.array([0,1,4,5])

    x_corners += -l3d / 2
    y_corners += -h3d / 2
    z_corners += -w3d / 2

    # bounding box in object co-ordinate
    corners_3d = np.array([x_corners, y_corners, z_corners])

    # rotate
    corners_3d = R.dot(corners_3d)

    # translate
    corners_3d += np.array([x3d, y3d, z3d]).reshape((3, 1))

    # homogeneous coordinates, then perspective divide by the third row
    corners_3D_1 = np.vstack((corners_3d, np.ones((corners_3d.shape[-1]))))
    corners_2D = p2.dot(corners_3D_1)
    corners_2D = corners_2D / corners_2D[2]

    return corners_2D, corners_3D_1
# def bbCoords2XYWH(box):
# """
# Convert from [x1, y1, x2, y2] to [x,y,w,h]
# """
# if box.shape[0] == 0: return np.empty([0, 4], dtype=float)
# box[:, 2] -= box[:, 0] + 1
# box[:, 3] -= box[:, 1] + 1
# return box
def bbXYWH2Coords(box):
    """
    Converts rows of [x, y, w, h] into [x1, y1, x2, y2] using
    x2 = x + w - 1 and y2 = y + h - 1.

    The conversion is done in place and the same array is returned; an
    input with zero rows yields a new empty float array of shape [0, 4].
    """
    if box.shape[0] == 0:
        return np.empty([0, 4], dtype=float)

    # columns 2:4 (w, h) become (x2, y2)
    box[:, 2:4] += box[:, 0:2] - 1

    return box
def bbox_transform_3d(ex_rois_2d, ex_rois_3d, gt_rois):
    """
    Compute the bbox target transforms in 3D.

    The projected center (gt columns 0, 1) is encoded as an offset from the
    2D anchor center divided by the anchor width / height, depth (gt column
    2) and rotation (gt column 6) as plain differences to the anchor priors,
    and the three dimensions (gt columns 3..5) as log ratios to the anchor
    priors. Any gt columns from index 7 on are appended unchanged.

    Args:
        ex_rois_2d: anchors as [x1, y1, x2, y2, ...] rows.
        ex_rois_3d: anchor priors as [z, w, h, l, rotY] rows.
        gt_rois: 3D ground truths, one row per anchor row.
    """
    anchor_w = ex_rois_2d[:, 2] - ex_rois_2d[:, 0] + 1.0
    anchor_h = ex_rois_2d[:, 3] - ex_rois_2d[:, 1] + 1.0
    anchor_cx = ex_rois_2d[:, 0] + 0.5 * (anchor_w - 1)
    anchor_cy = ex_rois_2d[:, 1] + 0.5 * (anchor_h - 1)

    center_and_depth = [
        (gt_rois[:, 0] - anchor_cx) / anchor_w,
        (gt_rois[:, 1] - anchor_cy) / anchor_h,
        gt_rois[:, 2] - ex_rois_3d[:, 0],
    ]

    # w, h, l in log space
    log_scales = [
        np.log(gt_rois[:, 3 + k] / ex_rois_3d[:, 1 + k]) for k in range(3)
    ]

    rotation = gt_rois[:, 6] - ex_rois_3d[:, 4]

    targets = np.vstack(center_and_depth + log_scales + [rotation]).T

    return np.hstack((targets, gt_rois[:, 7:]))
def bbox_transform(ex_rois, gt_rois):
    """
    Compute the bbox target transforms in 2D.

    Center offsets are expressed relative to the source box size, width and
    height as log ratios (so equal sizes give 0, smaller targets a negative
    and larger targets a positive value).

    Args:
        ex_rois: source boxes as [x1, y1, x2, y2, ...] rows.
        gt_rois: target boxes as [x1, y1, x2, y2, ...] rows.

    Returns:
        N x 4 array of [dx, dy, dw, dh].
    """

    def _extent_and_center(lo, hi):
        # inclusive pixel extent and its center along one axis
        extent = hi - lo + 1.0
        return extent, lo + 0.5 * (extent - 1)

    src_w, src_cx = _extent_and_center(ex_rois[:, 0], ex_rois[:, 2])
    src_h, src_cy = _extent_and_center(ex_rois[:, 1], ex_rois[:, 3])
    dst_w, dst_cx = _extent_and_center(gt_rois[:, 0], gt_rois[:, 2])
    dst_h, dst_cy = _extent_and_center(gt_rois[:, 1], gt_rois[:, 3])

    rows = (
        (dst_cx - src_cx) / src_w,
        (dst_cy - src_cy) / src_h,
        np.log(dst_w / src_w),
        np.log(dst_h / src_h),
    )

    return np.vstack(rows).T
def bbox_transform_inv(boxes, deltas, means=None, stds=None):
    """
    Applies predicted 2D deltas to boxes and returns the resulting boxes.

    The deltas are optionally de-normalized (``delta * std + mean``) first.
    Centers are shifted by ``dx * width`` / ``dy * height`` and sizes scaled
    by ``exp(dw)`` / ``exp(dh)``.

    Args:
        boxes (nparray): N x 5 array describing [x1, y1, x2, y2, anchor_index]
        deltas (nparray or paddle VarBase): N x 4 array describing
            [dx, dy, dw, dh]. A numpy input is not modified.
        means: optional per-parameter means, indexed 0..3
        stds: optional per-parameter stds, indexed 0..3

    Returns:
        Predicted boxes as [x1, y1, x2, y2], of the same kind as deltas.
        For numpy input the result has the shape of deltas, with any columns
        beyond the first four left at zero.

    Raises:
        ValueError: if deltas is neither a numpy array nor a paddle VarBase.
    """
    if boxes.shape[0] == 0:
        return np.zeros((0, deltas.shape[1]), dtype=deltas.dtype)

    is_numpy = isinstance(deltas, np.ndarray)

    if is_numpy:
        # the column slices below are views; scale a private copy so the
        # de-normalization never leaks into the caller's array
        deltas = deltas.copy()
    elif isinstance(deltas, fluid.core.VarBase):
        boxes = to_variable(boxes)
    else:
        raise ValueError('unknown data type {}'.format(type(deltas)))

    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    dx = deltas[:, 0]
    dy = deltas[:, 1]
    dw = deltas[:, 2]
    dh = deltas[:, 3]

    if stds is not None:
        dx *= stds[0]
        dy *= stds[1]
        dw *= stds[2]
        dh *= stds[3]

    if means is not None:
        dx += means[0]
        dy += means[1]
        dw += means[2]
        dh += means[3]

    pred_ctr_x = dx * widths + ctr_x
    pred_ctr_y = dy * heights + ctr_y

    if is_numpy:
        pred_w = np.exp(dw) * widths
        pred_h = np.exp(dh) * heights

        pred_boxes = np.zeros(deltas.shape)

        # x1, y1, x2, y2
        pred_boxes[:, 0] = pred_ctr_x - 0.5 * pred_w
        pred_boxes[:, 1] = pred_ctr_y - 0.5 * pred_h
        pred_boxes[:, 2] = pred_ctr_x + 0.5 * pred_w
        pred_boxes[:, 3] = pred_ctr_y + 0.5 * pred_h

        return pred_boxes

    pred_w = fluid.layers.exp(dw) * widths
    pred_h = fluid.layers.exp(dh) * heights

    pred_x1 = fluid.layers.unsqueeze(pred_ctr_x - 0.5 * pred_w, 1)
    pred_y1 = fluid.layers.unsqueeze(pred_ctr_y - 0.5 * pred_h, 1)
    pred_x2 = fluid.layers.unsqueeze(pred_ctr_x + 0.5 * pred_w, 1)
    pred_y2 = fluid.layers.unsqueeze(pred_ctr_y + 0.5 * pred_h, 1)

    return fluid.layers.concat(
        input=[pred_x1, pred_y1, pred_x2, pred_y2], axis=1)
def determine_ignores(gts,
                      lbls,
                      ilbls,
                      min_gt_vis=0.99,
                      min_gt_h=0,
                      max_gt_h=10e10,
                      scale_factor=1):
    """
    Flags, for every ground truth, whether it is ignored and whether it is
    removed.

    A ground truth is ignored when its own ``ign`` flag is set, its
    visibility is below min_gt_vis, its scaled height (``bbox_full[3] *
    scale_factor``) lies outside [min_gt_h, max_gt_h], or its class is in
    ilbls. It is removed when its class is in neither lbls nor ilbls.

    Returns:
        (igns, rmvs): two boolean arrays with one entry per ground truth.
    """
    count = len(gts)
    igns = np.zeros([count], dtype=bool)
    rmvs = np.zeros([count], dtype=bool)

    for idx, gt in enumerate(gts):
        igns[idx] = (gt.ign
                     | (gt.visibility < min_gt_vis)
                     | (gt.bbox_full[3] * scale_factor < min_gt_h)
                     | (gt.bbox_full[3] * scale_factor > max_gt_h)
                     | (gt.cls in ilbls))
        rmvs[idx] = gt.cls not in (lbls + ilbls)

    return igns, rmvs
def locate_anchors(anchors, feat_size, stride):
    """
    Places every anchor shape at every cell of a feature map, with cells
    spaced by the given stride.

    Args:
        anchors (ndarray): array whose first four columns hold the
            [x1, y1, x2, y2] displacements of each anchor
        feat_size: feature map resolution; index 0 is the number of rows
            (y direction), index 1 the number of columns (x direction)
        stride (int): stride of a network

    Returns:
        ndarray: [(num_anchors x rows x cols) x 5] array of
        [x1, y1, x2, y2, anchor_index], ordered by anchor, then row, then
        column
    """
    rows, cols = feat_size[0], feat_size[1]

    # pixel offset of every feature cell
    offsets_x = np.array(range(0, cols, 1)) * float(stride)
    offsets_y = np.array(range(0, rows, 1)) * float(stride)
    grid_x, grid_y = np.meshgrid(offsets_x, offsets_y)

    shapes = anchors[:, 0:4]
    num_anchors = shapes.shape[0]

    # broadcast [A x 1 x 1] anchor coordinates against the [rows x cols] grid
    columns = [
        (grid + shapes[:, coord].reshape(-1, 1, 1)).reshape(-1, 1)
        for coord, grid in enumerate((grid_x, grid_y, grid_x, grid_y))
    ]

    # compute anchor tracker
    tracker = np.repeat(np.arange(num_anchors, dtype=float), rows * cols)
    columns.append(tracker.reshape(-1, 1))

    return np.concatenate(columns, 1)
def calc_output_size(res, stride):
    """
    Approximate the output size of a network as the input resolution
    divided by the stride, rounded up.

    Args:
        res (ndarray): input resolution
        stride (int): stride of a network

    Returns:
        ndarray: output resolution (integer dtype)
    """
    scaled = np.array(res) / stride
    rounded_up = np.ceil(scaled)

    return rounded_up.astype(int)
def im_detect_3d(im, net, rpn_conf, preprocess, p2, gpu=0, synced=False):
    """
    Object detection in 3D

    Runs `net` on one image, decodes the 2D and 3D regressions against the
    rois / anchors returned by the network, rescales them to the original
    image size and applies NMS.

    Args:
        im (ndarray): input image; its first two dims are read as (H, W)
        net (callable): model returning
            (cls, prob, bbox_2d, bbox_3d, feat_size, rois)
        rpn_conf: config providing anchors, bbox_means, bbox_stds, nms_thres,
            nms_topN_pre and clip_boxes
        preprocess (callable): pre-processing applied to `im` before the net
        p2: not used by this function; kept for interface compatibility
        gpu (int): device id forwarded to gpu_nms
        synced (bool): if True, no box is dropped; a keep flag is appended
            and rows are restored to the original roi order

    Returns:
        ndarray:
            synced=False -> rows of
            [x1, y1, x2, y2, score, cls, x3d, y3d, z3d, w3d, h3d, l3d, ry3d,
            anchor index] for the boxes kept by NMS.
            synced=True -> rows of [x1, y1, x2, y2, score, keep] for every
            roi, in the original roi order.
    """
    imH_orig = im.shape[0]
    imW_orig = im.shape[1]

    im = preprocess(im)
    im = im[np.newaxis, :, :, :]
    imH = im.shape[2]

    # wrap as a framework variable for the forward pass
    im = to_variable(im)

    # ratio between the network input height and the original image height
    scale_factor = imH / imH_orig

    cls, prob, bbox_2d, bbox_3d, feat_size, rois = net(im)

    bbox_x = bbox_2d[:, :, 0]
    bbox_y = bbox_2d[:, :, 1]
    bbox_w = bbox_2d[:, :, 2]
    bbox_h = bbox_2d[:, :, 3]

    bbox_x3d = bbox_3d[:, :, 0]
    bbox_y3d = bbox_3d[:, :, 1]
    bbox_z3d = bbox_3d[:, :, 2]
    bbox_w3d = bbox_3d[:, :, 3]
    bbox_h3d = bbox_3d[:, :, 4]
    bbox_l3d = bbox_3d[:, :, 5]
    bbox_ry3d = bbox_3d[:, :, 6]

    # detransform 3d: undo the normalisation (columns 4..10 of means / stds)
    bbox_x3d = bbox_x3d * rpn_conf.bbox_stds[:, 4][
        0] + rpn_conf.bbox_means[:, 4][0]
    bbox_y3d = bbox_y3d * rpn_conf.bbox_stds[:, 5][
        0] + rpn_conf.bbox_means[:, 5][0]
    bbox_z3d = bbox_z3d * rpn_conf.bbox_stds[:, 6][
        0] + rpn_conf.bbox_means[:, 6][0]
    bbox_w3d = bbox_w3d * rpn_conf.bbox_stds[:, 7][
        0] + rpn_conf.bbox_means[:, 7][0]
    bbox_h3d = bbox_h3d * rpn_conf.bbox_stds[:, 8][
        0] + rpn_conf.bbox_means[:, 8][0]
    bbox_l3d = bbox_l3d * rpn_conf.bbox_stds[:, 9][
        0] + rpn_conf.bbox_means[:, 9][0]
    bbox_ry3d = bbox_ry3d * rpn_conf.bbox_stds[:, 10][
        0] + rpn_conf.bbox_means[:, 10][0]

    # find 3d source: column 4 of each roi indexes the anchor it came from
    tracker = rois[:, 4].astype(np.int64)
    src_3d = rpn_conf.anchors[tracker, 4:]

    # compute 3d transform
    widths = rois[:, 2] - rois[:, 0] + 1.0
    heights = rois[:, 3] - rois[:, 1] + 1.0
    ctr_x = rois[:, 0] + 0.5 * widths
    ctr_y = rois[:, 1] + 0.5 * heights

    bbox_x3d_np = bbox_x3d.numpy()
    bbox_y3d_np = bbox_y3d.numpy()
    bbox_z3d_np = bbox_z3d.numpy()
    bbox_w3d_np = bbox_w3d.numpy()
    bbox_l3d_np = bbox_l3d.numpy()
    bbox_h3d_np = bbox_h3d.numpy()
    bbox_ry3d_np = bbox_ry3d.numpy()

    # only the first batch element is decoded
    bbox_x3d_np = bbox_x3d_np[0, :] * widths + ctr_x
    bbox_y3d_np = bbox_y3d_np[0, :] * heights + ctr_y

    bbox_x_np = bbox_x.numpy()
    bbox_y_np = bbox_y.numpy()
    bbox_w_np = bbox_w.numpy()
    bbox_h_np = bbox_h.numpy()

    # depth / rotation are additive offsets, sizes are log-scale factors
    bbox_z3d_np = src_3d[:, 0] + bbox_z3d_np[0, :]
    bbox_w3d_np = np.exp(bbox_w3d_np[0, :]) * src_3d[:, 1]
    bbox_h3d_np = np.exp(bbox_h3d_np[0, :]) * src_3d[:, 2]
    bbox_l3d_np = np.exp(bbox_l3d_np[0, :]) * src_3d[:, 3]
    bbox_ry3d_np = src_3d[:, 4] + bbox_ry3d_np[0, :]

    # bundle into one row per roi with 7 columns
    num_rois = bbox_x3d_np.shape[0]
    coords_3d = np.stack(
        (bbox_x3d_np, bbox_y3d_np, bbox_z3d_np[:num_rois],
         bbox_w3d_np[:num_rois], bbox_h3d_np[:num_rois],
         bbox_l3d_np[:num_rois], bbox_ry3d_np[:num_rois]),
        axis=1)

    # compile deltas pred
    deltas_2d = np.concatenate(
        (bbox_x_np[0, :, np.newaxis], bbox_y_np[0, :, np.newaxis],
         bbox_w_np[0, :, np.newaxis], bbox_h_np[0, :, np.newaxis]),
        axis=1)
    coords_2d = bbox_transform_inv(
        rois,
        deltas_2d,
        means=rpn_conf.bbox_means[0, :],
        stds=rpn_conf.bbox_stds[0, :])

    prob_np = prob[0, :, :].numpy()

    # scale coords back to the original image resolution
    coords_2d[:, 0:4] /= scale_factor
    coords_3d[:, 0:2] /= scale_factor

    # column 0 of prob is skipped, so class ids start at 1
    cls_pred = np.argmax(prob_np[:, 1:], axis=1) + 1
    scores = np.amax(prob_np[:, 1:], axis=1)

    aboxes = np.hstack((coords_2d, scores[:, np.newaxis]))

    # sort by descending score; original_inds maps sorted rows back
    sorted_inds = (-aboxes[:, 4]).argsort()
    original_inds = (sorted_inds).argsort()
    aboxes = aboxes[sorted_inds, :]
    coords_3d = coords_3d[sorted_inds, :]
    cls_pred = cls_pred[sorted_inds]
    tracker = tracker[sorted_inds]

    if synced:
        # nms
        keep_inds = gpu_nms(
            aboxes[:, 0:5].astype(np.float32),
            rpn_conf.nms_thres,
            device_id=gpu)

        # convert to bool
        keep = np.zeros([aboxes.shape[0], 1], dtype=bool)
        keep[keep_inds, :] = True

        # stack the keep array,
        # sync to the original order (the result must be assigned back,
        # a bare indexing expression has no effect)
        aboxes = np.hstack((aboxes, keep))
        aboxes = aboxes[original_inds, :]
    else:
        # pre-nms: keep at most nms_topN_pre highest scoring candidates
        cls_pred = cls_pred[0:min(rpn_conf.nms_topN_pre, cls_pred.shape[0])]
        tracker = tracker[0:min(rpn_conf.nms_topN_pre, tracker.shape[0])]
        aboxes = aboxes[0:min(rpn_conf.nms_topN_pre, aboxes.shape[0]), :]
        coords_3d = coords_3d[0:min(rpn_conf.nms_topN_pre, coords_3d.shape[0])]

        # nms
        keep_inds = gpu_nms(
            aboxes[:, 0:5].astype(np.float32),
            rpn_conf.nms_thres,
            device_id=gpu)

        # stack cls prediction
        aboxes = np.hstack((aboxes, cls_pred[:, np.newaxis], coords_3d,
                            tracker[:, np.newaxis]))

        # suppress boxes
        aboxes = aboxes[keep_inds, :]

    # clip boxes to the original image bounds
    if rpn_conf.clip_boxes:
        aboxes[:, 0] = np.clip(aboxes[:, 0], 0, imW_orig - 1)
        aboxes[:, 1] = np.clip(aboxes[:, 1], 0, imH_orig - 1)
        aboxes[:, 2] = np.clip(aboxes[:, 2], 0, imW_orig - 1)
        aboxes[:, 3] = np.clip(aboxes[:, 3], 0, imH_orig - 1)

    return aboxes
def get_2D_from_3D(p2, cx3d, cy3d, cz3d, w3d, h3d, l3d, rotY):
    """
    Compute the 2D box enclosing the projection of a 3D box.

    All arguments are forwarded unchanged to project_3d.

    Returns:
        ndarray or None: [x, y, x2, y2] spanning the projected vertices, or
        None when any 3D corner has a non-positive third coordinate
        (i.e. lies on or behind the camera plane).
    """
    verts3d, corners_3d = project_3d(
        p2, cx3d, cy3d, cz3d, w3d, h3d, l3d, rotY, return_3d=True)

    # any boxes behind camera plane?
    if np.any(corners_3d[2, :] <= 0):
        return None

    x = min(verts3d[:, 0])
    y = min(verts3d[:, 1])
    x2 = max(verts3d[:, 0])
    y2 = max(verts3d[:, 1])
    return np.array([x, y, x2, y2])
def test_kitti_3d(dataset_test,
                  net,
                  rpn_conf,
                  results_path,
                  test_path,
                  use_log=True):
    """
    Test the KITTI framework for object detection in 3D

    For every validation image: run detection, refine depth / rotation with
    hill_climb, and write one KITTI-format result file. Afterwards the
    devkit `evaluate_object` binary is run and the stats it produced are
    reported.

    Args:
        dataset_test (str): dataset folder name under `test_path`
        net: detection network passed to im_detect_3d
        rpn_conf: config providing test_scale, image_means, image_stds,
            nms_topN_post and lbls (plus what im_detect_3d reads)
        results_path (str): folder receiving one `<image name>.txt` per image
        test_path (str): root folder containing the datasets
        use_log (bool): report through logging.info if True, else print
    """
    # import read_kitti_cal
    from data.m3drpn_reader import read_kitti_cal

    def _report(message):
        # single place deciding between logging and stdout
        if use_log:
            logging.info(message)
        else:
            print(message)

    imlist = list_files(
        os.path.join(test_path, dataset_test, 'validation', 'image_2', ''),
        '*.png')
    preprocess = Preprocess([rpn_conf.test_scale], rpn_conf.image_means,
                            rpn_conf.image_stds)

    # fix paths slightly: evaluation works on the path without '/data'
    results_root = results_path.replace('/data', '')
    _, test_iter, _ = file_parts(results_root)
    test_iter = test_iter.replace('results_', '')

    # init
    test_start = time()

    for imind, impath in enumerate(imlist):
        im = cv2.imread(impath)
        _, name, _ = file_parts(impath)

        # read in calib
        p2 = read_kitti_cal(
            os.path.join(test_path, dataset_test, 'validation', 'calib', name +
                         '.txt'))
        p2_inv = np.linalg.inv(p2)

        # forward test batch
        aboxes = im_detect_3d(im, net, rpn_conf, preprocess, p2)

        result_lines = []
        for boxind in range(0, min(rpn_conf.nms_topN_post, aboxes.shape[0])):
            box = aboxes[boxind, :]
            score = box[4]
            cls = rpn_conf.lbls[int(box[5] - 1)]

            # TODO(yexiaoqing): score threshold (an earlier value was 0.75)
            if score < 0.5:
                continue

            x1 = box[0]
            y1 = box[1]
            x2 = box[2]
            y2 = box[3]
            width = (x2 - x1 + 1)
            height = (y2 - y1 + 1)

            # 3D box: center in image coordinates, depth, dims, angle
            x3d = box[6]
            y3d = box[7]
            z3d = box[8]
            w3d = box[9]
            h3d = box[10]
            l3d = box[11]
            ry3d = box[12]

            # convert alpha into ry3d
            coord3d = p2_inv.dot(
                np.array([x3d * z3d, y3d * z3d, 1 * z3d, 1]))
            ry3d = convertAlpha2Rot(ry3d, coord3d[2], coord3d[0])

            step_r = 0.3 * math.pi
            r_lim = 0.01
            box_2d = np.array([x1, y1, width, height])

            z3d, ry3d, verts_best = hill_climb(
                p2,
                p2_inv,
                box_2d,
                x3d,
                y3d,
                z3d,
                w3d,
                h3d,
                l3d,
                ry3d,
                step_r_init=step_r,
                r_lim=r_lim)

            # predict a more accurate projection with the refined depth
            coord3d = p2_inv.dot(
                np.array([x3d * z3d, y3d * z3d, 1 * z3d, 1]))
            alpha = convertRot2Alpha(ry3d, coord3d[2], coord3d[0])

            x3d = coord3d[0]
            y3d = coord3d[1]
            z3d = coord3d[2]

            # shift y down by half the box height before writing
            y3d += h3d / 2

            result_lines.append((
                '{} -1 -1 {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} '
                + '{:.6f} {:.6f}\n').format(cls, alpha, x1, y1, x2, y2, h3d,
                                            w3d, l3d, x3d, y3d, z3d, ry3d,
                                            score))

        # a (possibly empty) file is written for every image
        with open(os.path.join(results_path, name + '.txt'), 'w') as file:
            file.write(''.join(result_lines))

        # display stats
        if (imind + 1) % 1000 == 0:
            time_str, dt = compute_eta(test_start, imind + 1, len(imlist))
            _report('testing {}/{}, dt: {:0.3f}, eta: {}'.format(
                imind + 1, len(imlist), dt, time_str))

    # evaluate
    script = os.path.join(test_path, dataset_test, 'devkit', 'cpp',
                          'evaluate_object')
    with open(os.devnull, 'w') as devnull:
        subprocess.check_output([script, results_root], stderr=devnull)

    # (tag used in the report, stats file name template)
    stat_files = (('2d', 'stats_{}_detection.txt'),
                  ('gr', 'stats_{}_detection_ground.txt'),
                  ('3d', 'stats_{}_detection_3d.txt'))

    for lbl in rpn_conf.lbls:
        lbl = lbl.lower()
        for tag, template in stat_files:
            respath = os.path.join(results_root, template.format(lbl))
            if not os.path.exists(respath):
                continue
            easy, mod, hard = parse_kitti_result(respath)
            _report(
                'test_iter {} {} {} --> easy: {:0.4f}, mod: {:0.4f}, hard: {:0.4f}'.
                format(test_iter, tag, lbl, easy, mod, hard))
def parse_kitti_result(respath):
    """
    Parse a KITTI devkit stats file into easy / moderate / hard scores.

    Lines 0, 1 and 2 of the file are read as the easy, moderate and hard
    rows. Every number found on a line is stored, and each score is the mean
    of every 4th value among the first 41 columns (indices 0, 4, ..., 40).

    Args:
        respath (str): path to the stats text file

    Returns:
        tuple: (easy, mod, hard)
    """
    acc = np.zeros([3, 41], dtype=float)

    # context manager guarantees the handle is closed even if parsing fails
    with open(respath, 'r') as text_file:
        for lind, line in enumerate(text_file):
            parsed = re.findall(r'([\d]+\.?[\d]*)', line)
            for i, num in enumerate(parsed):
                acc[lind, i] = float(num)

    easy = np.mean(acc[0, 0:41:4])
    mod = np.mean(acc[1, 0:41:4])
    hard = np.mean(acc[2, 0:41:4])
    return easy, mod, hard
# def parse_kitti_vo(respath):
# text_file = open(respath, 'r')
# acc = np.zeros([1, 2], dtype=float)
# lind = 0
# for line in text_file:
# parsed = re.findall('([\d]+\.?[\d]*)', line)
# for i, num in enumerate(parsed):
# acc[lind, i] = float(num)
# lind += 1
# text_file.close()
# t = acc[0, 0]*100
# r = acc[0, 1]
# return t, r
def test_projection(p2, p2_inv, box_2d, cx, cy, z, w3d, h3d, l3d, rotY):
    """
    Tests the consistency of a 3D projection compared to a 2D box

    Args:
        p2: projection matrix forwarded to project_3d
        p2_inv: matrix used to back-project (cx * z, cy * z, z, 1)
        box_2d: reference box given as [x, y, width, height]
        cx, cy: image-space center of the 3D box
        z: depth used for the back-projection
        w3d, h3d, l3d, rotY: forwarded to project_3d

    Returns:
        tuple: (ol, verts3d, b2, invalid) where
            ol is the negative L1 distance between the reference box and the
            box enclosing the projection (0 means a perfect match),
            verts3d are the projected vertices from project_3d,
            b2 is the enclosing box [x, y, x2, y2] with shape (1, 4),
            invalid is True when any corner has a non-positive third
            coordinate.
    """
    x = box_2d[0]
    y = box_2d[1]
    x2 = x + box_2d[2] - 1
    y2 = y + box_2d[3] - 1

    coord3d = p2_inv.dot(np.array([cx * z, cy * z, z, 1]))
    cx3d = coord3d[0]
    cy3d = coord3d[1]
    cz3d = coord3d[2]

    # re-compute the 2D box using 3D (finally, avoids clipped boxes)
    verts3d, corners_3d = project_3d(
        p2, cx3d, cy3d, cz3d, w3d, h3d, l3d, rotY, return_3d=True)

    invalid = np.any(corners_3d[2, :] <= 0)

    x_new = min(verts3d[:, 0])
    y_new = min(verts3d[:, 1])
    x2_new = max(verts3d[:, 0])
    y2_new = max(verts3d[:, 1])

    b2 = np.array([x_new, y_new, x2_new, y2_new])[np.newaxis, :]

    # L1 box distance, negated so that larger is better
    ol = -(np.abs(x - x_new) + np.abs(y - y_new) + np.abs(x2 - x2_new) +
           np.abs(y2 - y2_new))

    return ol, verts3d, b2, invalid
"""
This code is based on https://github.com/garrickbrazil/M3D-RPN/blob/master/lib/util.py
This file is meant to contain generic utility functions
which can be easily re-used in any project, and are not
specific to any project or framework (except python!).
"""
import os
import sys
from glob import glob
from time import time
import matplotlib.pyplot as plt
import numpy as np
import importlib
import pickle
import logging
import datetime
import pprint
import shutil
import math
import copy
import cv2
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
def copyfile(src, dst):
    """
    Copy the contents of file `src` to `dst` via shutil.copyfile.

    Args:
        src (str): path of the file to read
        dst (str): path of the file to write

    Returns nothing; the value returned by shutil.copyfile is discarded.
    """
    shutil.copyfile(src=src, dst=dst)
    return None
def pretty_print(name, input, val_width=40, key_width=0):
    """
    This function creates a formatted string from a given dictionary input.
    It may not support all data types, but can probably be extended.

    Args:
        name (str): name of the variable root
        input (dict): dictionary to print; keys may be any type, they are
            rendered with str()
        val_width (int): the width of the right hand side values
        key_width (int): the minimum key width, (always auto-defaults to the longest key!)

    Returns:
        str: the formatted block, one `key: value` line per entry

    Example:
        pretty_str = pretty_print('conf', conf.__dict__)
        pretty_str = pretty_print('conf', {'key1': 'example', 'key2': [1,2,3,4,5], 'key3': np.random.rand(4,4)})
        print(pretty_str)
        or
        logging.info(pretty_str)
    """
    # determine key width
    for key in input.keys():
        key_width = max(key_width, len(str(key)) + 4)

    line_format = '{0:4}{1:' + str(key_width) + '} {2:' + str(
        val_width) + '}\n'

    # wrapped values continue aligned underneath the value column
    continuation = ' ' * (4 + key_width)

    # root
    pieces = [name + ': {\n']

    # cycle keys
    for key in input.keys():
        val = input[key]

        # round values to 3 decimals..
        if isinstance(val, np.ndarray):
            val = np.round(val, 3).tolist()

        # values that do not fit are wrapped by pprint and re-indented
        val_str = str(val)
        if len(val_str) > val_width:
            val_str = pprint.pformat(val, width=val_width)
            val_str = val_str.replace('\n', '\n' + continuation)

        # str(key) keeps this consistent with the width computation above
        pieces.append(line_format.format('', str(key) + ':', val_str))

    # close root object
    pieces.append('}')

    return ''.join(pieces)
def absolute_import(file_path):
    """
    Imports a python module / file given its ABSOLUTE path.

    Args:
        file_path (str): absolute path to a python file to attempt to import

    Returns:
        module: the executed module; its name is the file name without
        extension

    Raises:
        ImportError: if no import spec / loader can be built for file_path
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # has been loaded, so import it explicitly
    import importlib.util

    # module name
    _, name, _ = file_parts(file_path)

    # load the spec and module
    spec = importlib.util.spec_from_file_location(name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            'cannot create an import spec for {!r}'.format(file_path))

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    return module
def init_log_file(folder_path, suffix=None, log_level=logging.INFO):
    """
    Configure root logging to write into a timestamped file in `folder_path`
    and attach an additional stdout handler to the root logger.

    The file name is `YYYYmmdd_HHMMSS`, followed by `_<suffix>` when a suffix
    is given. Note that logging.basicConfig does nothing if the root logger
    already has handlers.

    Args:
        folder_path (str): folder that receives the log file
        suffix (str): optional text appended to the timestamp
        log_level (int): level passed to logging.basicConfig

    Returns:
        str: path of the log file

    Example:
        import logging
        init_log_file('output/logs/')
        logging.info('this will show up in both the log AND console!')
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = stamp if suffix is None else stamp + '_' + suffix
    file_path = os.path.join(folder_path, file_name)

    logging.basicConfig(
        filename=file_path,
        level=log_level,
        format='[%(levelname)s]: %(asctime)s %(message)s')

    console = logging.StreamHandler(sys.stdout)
    logging.getLogger().addHandler(console)

    return file_path
def denorm_image(im, image_means, image_stds):
    """
    Undo a per-channel normalisation on a copy of `im`.

    Each of the first three channels (last axis) is multiplied by its std
    and then shifted by its mean; the input is left untouched.

    :param im: image indexed as [row, col, channel]
    :param image_means: per-channel means (indices 0..2 are used)
    :param image_stds: per-channel stds (indices 0..2 are used)
    :return: the de-normalised deep copy
    """
    restored = copy.deepcopy(im)

    # scale every channel first ...
    for channel in range(3):
        restored[:, :, channel] *= image_stds[channel]

    # ... then shift every channel
    for channel in range(3):
        restored[:, :, channel] += image_means[channel]

    return restored
def compute_eta(start_time, idx, total):
    """
    Computes the estimated time as a formatted string as well
    as the change in delta time dt.

    Args:
        start_time (float): value of time() when processing started
        idx (int): number of items processed so far; 0 is accepted and is
            treated as one item for the purpose of the average
        total (int): total number of items

    Returns:
        tuple: (time_str, dt) where time_str is the remaining time formatted
        in hours ('h'), minutes ('m') or seconds ('s') and dt is the average
        time per processed item

    Example:
        from time import time
        start_time = time()
        for i in range(0, total):
            <lengthly computation>
            time_str, dt = compute_eta(start_time, i, total)
    """
    # guard against idx == 0 (first loop iteration in the example above)
    dt = (time() - start_time) / max(idx, 1)
    timeleft = np.max([dt * (total - idx), 0])

    if timeleft > 3600:
        time_str = '{:.1f}h'.format(timeleft / 3600)
    elif timeleft > 60:
        time_str = '{:.1f}m'.format(timeleft / 60)
    else:
        time_str = '{:.1f}s'.format(timeleft)

    return time_str, dt
def interp_color(dist,
                 bounds=(0, 1),
                 color_lo=(0, 0, 250),
                 color_hi=(0, 250, 250)):
    """
    Linearly interpolate between two colors.

    :param dist: value to map; bounds[0] maps to color_lo and bounds[1] to
        color_hi (values outside the bounds extrapolate, nothing is clamped)
    :param bounds: (low, high) pair; the two values must differ
    :param color_lo: 3-tuple returned when dist == bounds[0]
    :param color_hi: 3-tuple returned when dist == bounds[1]
    :return: 3-tuple of interpolated channel values, in the same channel
        order as the inputs
    """
    percent = (dist - bounds[0]) / (bounds[1] - bounds[0])

    b = color_lo[0] * (1 - percent) + color_hi[0] * percent
    g = color_lo[1] * (1 - percent) + color_hi[1] * percent
    r = color_lo[2] * (1 - percent) + color_hi[2] * percent

    return (b, g, r)
def create_colorbar(height, width, color_lo=(0, 0, 250),
                    color_hi=(0, 250, 250)):
    """
    Build a vertical color gradient image.

    Every row is filled with one color obtained from interp_color at the row
    center (row + 0.5). The colors are passed to interp_color in swapped
    order, so row 0 is closest to color_hi and the last row to color_lo.

    :param height: number of rows
    :param width: number of columns
    :param color_lo: color approached at the bottom row
    :param color_hi: color approached at the top row
    :return: uint8 array of shape (height, width, 3)
    """
    bar = np.zeros([height, width, 3])

    for row in range(height):
        first, second, third = interp_color(row + 0.5, [0, height], color_hi,
                                            color_lo)
        bar[row, :, 0] = first
        bar[row, :, 1] = second
        bar[row, :, 2] = third

    return bar.astype(np.uint8)
def mkdir_if_missing(directory, delete_if_exist=False):
    """
    Recursively make a directory structure even if missing.

    if delete_if_exist=True then we will delete it first
    which can be useful when better control over initialization is needed.

    Args:
        directory (str): directory path to create (parents included)
        delete_if_exist (bool): remove an existing tree before creating
    """
    if delete_if_exist and os.path.exists(directory):
        shutil.rmtree(directory)

    # create directly instead of check-then-create, so that a directory
    # appearing in between (e.g. from another process) is not an error
    try:
        os.makedirs(directory)
    except FileExistsError:
        pass
def list_files(base_dir, file_pattern):
    """
    Returns a list of files given a directory and pattern
    The results are sorted alphabetically

    Args:
        base_dir (str): directory to search, with or without a trailing
            path separator
        file_pattern (str): glob pattern matched inside base_dir

    Returns:
        list: sorted matching paths

    Example:
        files = list_files('path/to/images/', '*.jpg')
    """
    # join directory and pattern properly so a missing trailing separator
    # does not fuse the two into a different pattern
    return sorted(glob(os.path.join(base_dir, file_pattern)))
def file_parts(file_path):
    """
    Split a path into its directory, file name (without extension) and
    extension.

    Example
        base, name, ext = file_parts('path/to/file/dog.jpg')
        print(base, name, ext) --> ('path/to/file', 'dog', '.jpg')
    """
    folder = os.path.dirname(file_path)
    stem, extension = os.path.splitext(os.path.basename(file_path))
    return folder, stem, extension
def pickle_write(file_path, obj):
    """
    Pickle `obj` into the file at `file_path`, overwriting existing content.

    Args:
        file_path (str): destination path, opened in binary write mode
        obj: any picklable object
    """
    with open(file_path, mode='wb') as sink:
        pickle.dump(obj, sink)
def pickle_read(file_path):
    """
    Load and return the pickled object stored at `file_path`.

    NOTE: unpickling can execute arbitrary code; only use this on files from
    a trusted source.

    Args:
        file_path (str): path of the pickle file, opened in binary read mode
    """
    with open(file_path, mode='rb') as source:
        loaded = pickle.load(source)
    return loaded
def get_color(ind, hex=False):
    """
    Pick a color from a fixed palette; the index wraps around.

    :param ind: integer index, reduced modulo the palette size
    :param hex: if True return a '#rrggbb' string built from the three
        values in stored order, otherwise return the 3-tuple itself
    :return: 3-tuple of ints or hex string
    """
    palette = (
        (111, 74, 0), (81, 0, 81), (128, 64, 128), (244, 35, 232),
        (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156),
        (190, 153, 153), (180, 165, 180), (150, 100, 100), (150, 120, 90),
        (153, 153, 153), (250, 170, 30), (220, 220, 0), (107, 142, 35),
        (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0),
        (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 0, 90), (0, 0, 110),
        (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142),
    )
    chosen = palette[ind % len(palette)]

    if not hex:
        return chosen
    return '#%02x%02x%02x' % chosen
def draw_3d_box(im, verts, color=(0, 200, 200), thickness=1):
    """
    Draw a poly-line through consecutive vertices onto `im` (in place).

    :param im: image to draw on
    :param verts: array of vertices; row i is connected to row i + 1 and
        only the first two entries of each row are used
    :param color: line color passed to cv2.line
    :param thickness: line thickness passed to cv2.line
    :return: None
    """
    for start, end in zip(verts[:-1], verts[1:]):
        p_start = (int(start[0]), int(start[1]))
        p_end = (int(end[0]), int(end[1]))
        cv2.line(im, p_start, p_end, color, thickness)
def draw_bev(canvas_bev,
             z3d,
             l3d,
             w3d,
             x3d,
             ry3d,
             color=(0, 200, 200),
             scale=1,
             thickness=2):
    """
    Draw a rotated rectangle for one object onto a bird's-eye-view canvas.

    :param canvas_bev: image drawn on in place; its width centers the x axis
    :param z3d: position along the canvas vertical axis (before scaling)
    :param l3d: object length; used as the rectangle extent along canvas x
    :param w3d: object width; used as the rectangle extent along canvas y
    :param x3d: position along the canvas horizontal axis (before scaling)
    :param ry3d: rotation; applied with flipped sign
    :param color: line color
    :param scale: factor applied to all metric inputs
    :param thickness: line thickness
    :return: None
    """
    extent_x = l3d * scale
    extent_y = w3d * scale
    shift_x = x3d * scale
    shift_y = z3d * scale
    angle = ry3d * -1

    half_x = extent_x / 2
    half_y = extent_y / 2

    # rectangle corners around the origin, in homogeneous form
    corners = np.array([[-half_x, -half_y, 1], [+half_x, -half_y, 1],
                        [+half_x, +half_y, 1], [-half_x, +half_y, 1]])

    cos_a = math.cos(angle)
    sin_a = math.sin(angle)
    rotation = np.array([
        [+cos_a, -sin_a, 0],
        [+sin_a, cos_a, 0],
        [0, 0, 1],
    ])

    placed = rotation.dot(corners.T).T
    placed[:, 0] += extent_x / 2 + shift_x + canvas_bev.shape[1] / 2
    placed[:, 1] += extent_y / 2 + shift_y

    # connect corner i with corner i + 1, closing the loop at the end
    for first in range(4):
        second = (first + 1) % 4
        draw_line(
            canvas_bev,
            placed[first],
            placed[second],
            color=color,
            thickness=thickness)
def draw_line(im, v1, v2, color=(0, 200, 200), thickness=1):
    """
    Draw a straight line between two points onto `im` (in place).

    :param im: image to draw on
    :param v1: start point; its first two entries are truncated to ints
    :param v2: end point; its first two entries are truncated to ints
    :param color: line color passed to cv2.line
    :param thickness: line thickness passed to cv2.line
    :return: None
    """
    start = (int(v1[0]), int(v1[1]))
    end = (int(v2[0]), int(v2[1]))
    cv2.line(im, start, end, color, thickness)
def draw_circle(im,
                pos,
                radius=5,
                thickness=1,
                color=(250, 100, 100),
                fill=True):
    """
    Draw a circle onto `im` (in place).

    :param im: image to draw on
    :param pos: center; its first two entries are truncated to ints
    :param radius: circle radius passed to cv2.circle
    :param thickness: outline thickness, ignored when fill is True
    :param color: circle color
    :param fill: if True the circle is filled (thickness -1 is used)
    :return: None
    """
    center = (int(pos[0]), int(pos[1]))
    stroke = -1 if fill else thickness
    cv2.circle(im, center, radius, color=color, thickness=stroke)
def draw_2d_box(im, box, color=(0, 200, 200), thickness=1):
    """
    Draw an axis-aligned rectangle onto `im` (in place).

    :param im: image to draw on
    :param box: sequence whose first four entries are [x, y, width, height]
    :param color: rectangle color passed to cv2.rectangle
    :param thickness: line thickness passed to cv2.rectangle
    :return: None
    """
    left, top, box_w, box_h = box[0], box[1], box[2], box[3]

    # inclusive bottom-right corner
    right = (left + box_w) - 1
    bottom = (top + box_h) - 1

    top_left = (int(left), int(top))
    bottom_right = (int(right), int(bottom))
    cv2.rectangle(im, top_left, bottom_right, color, thickness)
def imshow(im, fig_num=None):
    """
    Show an image with matplotlib without blocking.

    A 2-dimensional image is replicated into three identical channels first.
    The image is cast to uint8 and its first and third channels are swapped
    (cv2.COLOR_RGB2BGR) before display.

    :param im: image array
    :param fig_num: optional matplotlib figure number to select
    :return: None
    """
    if fig_num is not None:
        plt.figure(fig_num)

    if len(im.shape) == 2:
        stacked = np.tile(im, [3, 1, 1])
        im = stacked.transpose([1, 2, 0])

    as_bytes = im.astype(np.uint8)
    swapped = cv2.cvtColor(as_bytes, cv2.COLOR_RGB2BGR)
    plt.imshow(swapped)
    plt.show(block=False)
def imwrite(im, path):
    """
    Write an image to disk with cv2.imwrite.

    Note the argument order is (image, path), the reverse of cv2.imwrite.

    :param im: image to save
    :param path: destination file path
    :return: None (the status returned by cv2.imwrite is discarded)
    """
    cv2.imwrite(filename=path, img=im)
    return None
def imread(path):
    """
    Read an image from disk with cv2.imread.

    :param path: file path of the image
    :return: whatever cv2.imread returns for `path`
    """
    image = cv2.imread(path)
    return image
def draw_tick_marks(im, ticks):
    """
    Draw evenly spaced '-<tick>m' labels near the right edge of `im`.

    Labels are spread from the top to the bottom of the image; each vertical
    position is clamped to the range [50, height - 10] (the upper limit wins
    if they conflict). Handles a single tick and images shorter than the
    number of ticks without raising.

    :param im: image drawn on in place
    :param ticks: sequence of values to label
    :return: None
    """
    num_ticks = len(ticks)
    height = im.shape[0]

    # vertical distance between consecutive labels; a single tick needs no
    # spacing (avoids dividing by zero)
    step = int(height / (num_ticks - 1)) if num_ticks > 1 else 0

    x = im.shape[1] - 115

    for tind, tick in enumerate(ticks):
        y = min(max(tind * step, 50), height - 10)
        draw_text(
            im,
            '-{}m'.format(tick), (x, y),
            lineType=2,
            scale=1.1,
            bg_color=None)
def draw_text(im,
              text,
              pos,
              scale=0.4,
              color=(0, 0, 0),
              font=cv2.FONT_HERSHEY_SIMPLEX,
              bg_color=(0, 255, 255),
              blend=0.33,
              lineType=1):
    """
    Draw `text` onto `im` in place, optionally over a blended background patch.

    :param im: image modified in place
    :param text: string to render
    :param pos: (x, y) anchor of the text; truncated to ints
    :param scale: font scale forwarded to cv2
    :param color: text color
    :param font: cv2 font face
    :param bg_color: color of the background patch, or None for no patch
    :param blend: weight of the existing pixels inside the patch; bg_color
        is weighted with (1 - blend)
    :param lineType: passed positionally right after `scale` / `color`
        NOTE(review): in cv2.getTextSize / cv2.putText that positional slot
        is `thickness`, not the line type - confirm the intended meaning
    :return: None
    """
    pos = [int(pos[0]), int(pos[1])]

    # NOTE(review): the original indentation was lost; the whole patch logic
    # including the final pos adjustment is assumed to belong to this branch
    if bg_color is not None:
        text_size, _ = cv2.getTextSize(text, font, scale, lineType)

        # patch extent, clipped to the image size
        x_s = int(np.clip(pos[0], a_min=0, a_max=im.shape[1]))
        x_e = int(
            np.clip(
                pos[0] + text_size[0] - 1 + 4, a_min=0, a_max=im.shape[1]))
        y_s = int(
            np.clip(
                pos[1] - text_size[1] - 2, a_min=0, a_max=im.shape[0]))
        y_e = int(np.clip(pos[1] + 1 - 2, a_min=0, a_max=im.shape[0]))

        # blend the patch into each of the three channels
        im[y_s:y_e + 1, x_s:x_e + 1, 0] = im[
            y_s:y_e + 1, x_s:x_e + 1, 0] * blend + bg_color[0] * (1 - blend)
        im[y_s:y_e + 1, x_s:x_e + 1, 1] = im[
            y_s:y_e + 1, x_s:x_e + 1, 1] * blend + bg_color[1] * (1 - blend)
        im[y_s:y_e + 1, x_s:x_e + 1, 2] = im[
            y_s:y_e + 1, x_s:x_e + 1, 2] * blend + bg_color[2] * (1 - blend)

        # nudge the text 2 px right and 2 px up so it sits inside the patch
        pos[0] = int(np.clip(pos[0] + 2, a_min=0, a_max=im.shape[1]))
        pos[1] = int(np.clip(pos[1] - 2, a_min=0, a_max=im.shape[0]))

    cv2.putText(im, text, tuple(pos), font, scale, color, lineType)
# Calculates rotation matrix to euler angles
# The result is the same as MATLAB except the order
# of the euler angles ( x and z are swapped ).
# adopted from https://www.learnopencv.com/rotation-matrix-to-euler-angles/
def mat2euler(R):
    """
    Convert a 3x3 rotation matrix into euler angles.

    :param R: rotation matrix indexed as R[row, col]
    :return: ndarray [x, y, z] of angles in radians
    :raises ValueError: if sqrt(R[0,0]^2 + R[1,0]^2) is below 1e-6
        (singular case, not handled)
    """
    sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0])

    if sy < 1e-6:
        raise ValueError('singular matrix found in mat2euler')

    angle_x = math.atan2(R[2, 1], R[2, 2])
    angle_y = math.atan2(-R[2, 0], sy)
    angle_z = math.atan2(R[1, 0], R[0, 0])

    return np.array([angle_x, angle_y, angle_z])
def fig_to_im(fig):
    """
    Render a figure and return its pixels as an image array.

    :param fig: figure whose canvas provides draw(), get_width_height() and
        tostring_argb()
    :return: uint8 array of shape (height, width, 3) holding the channels
        that follow the alpha channel in the canvas ARGB buffer
    """
    fig.canvas.draw()

    # Get the ARGB buffer from the figure
    w, h = fig.canvas.get_width_height()
    argb = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)

    # the buffer is laid out row by row, 4 bytes per pixel
    argb = argb.reshape(h, w, 4)

    # canvas.tostring_argb gives pixels in ARGB order. Roll the ALPHA
    # channel to the end to get RGBA, then drop it
    rgba = np.roll(argb, 3, axis=2)

    return rgba[:, :, :3]
def imzoom(im, zoom=0):
    """
    Crop the center of an image to emulate zooming in.

    :param im: image indexed as [row, col, channel]
    :param zoom: fraction to remove, clipped to [0, 0.99]; a single value is
        used for both axes, otherwise the first entry applies to the width
        and the last entry to the height
    :return: the cropped view of `im`
    """
    # single value passed in for both axis?
    # extend same val for w, h
    factors = np.array(zoom)
    if factors.size == 1:
        factors = np.array([factors, factors])
    factors = np.clip(factors, a_min=0, a_max=0.99)

    rows = im.shape[0]
    cols = im.shape[1]

    center_x = cols / 2
    center_y = rows / 2
    crop_w = cols * (1 - factors[0])
    crop_h = rows * (1 - factors[-1])

    # inclusive crop bounds, kept inside the image
    left = int(np.clip(center_x - crop_w / 2, a_min=0, a_max=cols - 1))
    right = int(np.clip(center_x + crop_w / 2, a_min=0, a_max=cols - 1))
    top = int(np.clip(center_y - crop_h / 2, a_min=0, a_max=rows - 1))
    bottom = int(np.clip(center_y + crop_h / 2, a_min=0, a_max=rows - 1))

    return im[top:bottom + 1, left:right + 1, :]
def imhstack(im1, im2):
    """
    Horizontally concatenate two images after matching their heights.

    The shorter image is enlarged to the height of the taller one and its
    width is scaled by the same factor, so its aspect ratio is preserved.

    :param im1: left image
    :param im2: right image
    :return: the concatenated image
    """
    # ratio of the two heights
    sf = im1.shape[0] / im2.shape[0]

    if sf > 1:
        # im2 is shorter: enlarge it by sf in both dimensions
        # (cv2.resize takes the target size as (width, height))
        im2 = cv2.resize(im2, (int(im2.shape[1] * sf), im1.shape[0]))
    else:
        # im1 is shorter (or equal): enlarge it by 1 / sf
        im1 = cv2.resize(im1, (int(im1.shape[1] / sf), im2.shape[0]))

    im_concat = np.hstack((im1, im2))

    return im_concat
# Calculates Rotation Matrix given euler angles.
# adopted from https://www.learnopencv.com/rotation-matrix-to-euler-angles/
def euler2mat(x, y, z):
    """
    Build a rotation matrix from euler angles as R = R_z . R_y . R_x.

    :param x: rotation about the x axis, in radians
    :param y: rotation about the y axis, in radians
    :param z: rotation about the z axis, in radians
    :return: 3x3 ndarray
    """
    cos_x, sin_x = math.cos(x), math.sin(x)
    cos_y, sin_y = math.cos(y), math.sin(y)
    cos_z, sin_z = math.cos(z), math.sin(z)

    rot_x = np.array([[1, 0, 0], [0, cos_x, -sin_x], [0, sin_x, cos_x]])
    rot_y = np.array([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]])
    rot_z = np.array([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]])

    return np.dot(rot_z, np.dot(rot_y, rot_x))
def convertAlpha2Rot(alpha, z3d, x3d):
    """
    Convert an angle `alpha` into `ry3d` for an object at (x3d, z3d).

    Computes alpha + atan2(-z3d, x3d) + pi / 2 and shifts the result by
    whole turns until it is within [-pi, pi].

    :param alpha: input angle in radians
    :param z3d: z coordinate of the object
    :param x3d: x coordinate of the object
    :return: the converted angle in radians
    """
    full_turn = math.pi * 2
    ry3d = alpha + math.atan2(-z3d, x3d) + 0.5 * math.pi

    while ry3d > math.pi:
        ry3d -= full_turn
    while ry3d < -math.pi:
        ry3d += full_turn

    return ry3d
def convertRot2Alpha(ry3d, z3d, x3d):
    """
    Convert an angle `ry3d` into `alpha` for an object at (x3d, z3d).

    Computes ry3d - atan2(-z3d, x3d) - pi / 2 and shifts the result by
    whole turns until it is within [-pi, pi]. Inverse of convertAlpha2Rot
    up to that wrapping.

    :param ry3d: input angle in radians
    :param z3d: z coordinate of the object
    :param x3d: x coordinate of the object
    :return: the converted angle in radians
    """
    full_turn = math.pi * 2
    alpha = ry3d - math.atan2(-z3d, x3d) - 0.5 * math.pi

    while alpha > math.pi:
        alpha -= full_turn
    while alpha < -math.pi:
        alpha += full_turn

    return alpha
"""
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
from .densenet import densenet121
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""densenet backbone"""
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
def _bn_function_factory(norm, conv):
    """
    Build a function that concatenates its inputs along axis 1 and then
    applies `norm` followed by `conv`.

    Args:
        norm (callable): applied to the concatenated features
        conv (callable): applied to the output of `norm`

    Returns:
        callable: function taking any number of feature tensors
    """

    def bn_function(*inputs):
        merged = fluid.layers.concat(inputs, 1)
        normalized = norm(merged)
        return conv(normalized)

    return bn_function
class _DenseLayer_rpn(fluid.dygraph.Layer):
    """
    Dense layer whose 3x3 convolution is dilated (dilation 2, padding 2).

    Structure: BatchNorm+relu -> 1x1 conv -> BatchNorm+relu -> dilated 3x3
    conv, followed by dropout when drop_rate > 0.

    Args:
        num_input_features (int): channels of the concatenated input
        growth_rate (int): channels produced by this layer
        bn_size (int): the 1x1 conv outputs bn_size * growth_rate channels
        drop_rate (float): dropout probability, 0 disables dropout
        memory_efficient (bool): stored on the layer, not used by it
    """

    def __init__(self,
                 num_input_features,
                 growth_rate,
                 bn_size,
                 drop_rate,
                 memory_efficient=False):
        super(_DenseLayer_rpn, self).__init__()
        bottleneck_channels = bn_size * growth_rate

        norm1 = BatchNorm(num_input_features, act='relu')
        self.add_sublayer('norm1', norm1)

        conv1 = Conv2D(
            num_input_features,
            bottleneck_channels,
            filter_size=1,
            stride=1,
            bias_attr=False)
        self.add_sublayer('conv1', conv1)

        norm2 = BatchNorm(bottleneck_channels, act='relu')
        self.add_sublayer('norm2', norm2)

        conv2 = Conv2D(
            bottleneck_channels,
            growth_rate,
            filter_size=3,
            dilation=2,
            stride=1,
            padding=2,
            bias_attr=False)
        self.add_sublayer('conv2', conv2)

        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def forward(self, *prev_features):
        # all earlier feature maps are concatenated on the channel axis
        stacked = fluid.layers.concat(prev_features, 1)
        bottleneck_output = self.conv1(self.norm1(stacked))
        produced = self.conv2(self.norm2(bottleneck_output))
        if self.drop_rate > 0:
            produced = fluid.layers.dropout(produced, self.drop_rate)
        return produced
class _DenseBlock_rpn(fluid.dygraph.Layer):
    """
    Dense block built from `_DenseLayer_rpn` (dilated) layers.

    Layer i receives num_input_features + i * growth_rate channels and is
    registered as sublayer 'denselayer<i + 1>'.

    Args:
        num_layers (int): number of dense layers in the block
        num_input_features (int): channels entering the block
        bn_size (int): bottleneck multiplier passed to every layer
        growth_rate (int): channels added by every layer
        drop_rate (float): dropout probability passed to every layer
        memory_efficient (bool): passed to every layer
    """

    def __init__(self,
                 num_layers,
                 num_input_features,
                 bn_size,
                 growth_rate,
                 drop_rate,
                 memory_efficient=False):
        super(_DenseBlock_rpn, self).__init__()
        registered = []
        for index in range(num_layers):
            in_channels = num_input_features + index * growth_rate
            dense_layer = _DenseLayer_rpn(
                in_channels,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient)
            registered.append(
                self.add_sublayer('denselayer%d' % (index + 1), dense_layer))
        self.res_out_list = registered

    def forward(self, init_features):
        # every layer sees the block input plus all previous layer outputs
        collected = [init_features]
        for dense_layer in self.res_out_list:
            collected.append(dense_layer(*collected))
        return fluid.layers.concat(collected, axis=1)
class _DenseLayer(fluid.dygraph.Layer):
    """
    Standard dense layer.

    Structure: BatchNorm+relu -> 1x1 conv -> BatchNorm+relu -> 3x3 conv
    (padding 1), followed by dropout when drop_rate > 0.

    Args:
        num_input_features (int): channels of the concatenated input
        growth_rate (int): channels produced by this layer
        bn_size (int): the 1x1 conv outputs bn_size * growth_rate channels
        drop_rate (float): dropout probability, 0 disables dropout
        memory_efficient (bool): stored on the layer, not used by it
    """

    def __init__(self,
                 num_input_features,
                 growth_rate,
                 bn_size,
                 drop_rate,
                 memory_efficient=False):
        super(_DenseLayer, self).__init__()
        bottleneck_channels = bn_size * growth_rate

        norm1 = BatchNorm(num_input_features, act='relu')
        self.add_sublayer('norm1', norm1)

        conv1 = Conv2D(
            num_input_features,
            bottleneck_channels,
            filter_size=1,
            stride=1,
            bias_attr=False)
        self.add_sublayer('conv1', conv1)

        norm2 = BatchNorm(bottleneck_channels, act='relu')
        self.add_sublayer('norm2', norm2)

        conv2 = Conv2D(
            bottleneck_channels,
            growth_rate,
            filter_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.add_sublayer('conv2', conv2)

        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def forward(self, *prev_features):
        # all earlier feature maps are concatenated on the channel axis
        stacked = fluid.layers.concat(prev_features, 1)
        bottleneck_output = self.conv1(self.norm1(stacked))
        produced = self.conv2(self.norm2(bottleneck_output))
        if self.drop_rate > 0:
            produced = fluid.layers.dropout(produced, self.drop_rate)
        return produced
class _DenseBlock(fluid.dygraph.Layer):
    """
    Dense block built from standard `_DenseLayer` layers.

    Layer i receives num_input_features + i * growth_rate channels and is
    registered as sublayer 'denselayer<i + 1>'.

    Args:
        num_layers (int): number of dense layers in the block
        num_input_features (int): channels entering the block
        bn_size (int): bottleneck multiplier passed to every layer
        growth_rate (int): channels added by every layer
        drop_rate (float): dropout probability passed to every layer
        memory_efficient (bool): passed to every layer
    """

    def __init__(self,
                 num_layers,
                 num_input_features,
                 bn_size,
                 growth_rate,
                 drop_rate,
                 memory_efficient=False):
        super(_DenseBlock, self).__init__()
        registered = []
        for index in range(num_layers):
            in_channels = num_input_features + index * growth_rate
            dense_layer = _DenseLayer(
                in_channels,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient)
            registered.append(
                self.add_sublayer('denselayer%d' % (index + 1), dense_layer))
        self.res_out_list = registered

    def forward(self, init_features):
        # every layer sees the block input plus all previous layer outputs
        collected = [init_features]
        for dense_layer in self.res_out_list:
            collected.append(dense_layer(*collected))
        return fluid.layers.concat(collected, axis=1)
class _Transition(Sequential):
    """
    Transition between dense blocks: BatchNorm+relu ('norm'), 1x1 conv
    ('conv') and 2x2 average pooling with stride 2 ('pool').

    Args:
        num_input_features (int): channels entering the transition
        num_output_features (int): channels produced by the 1x1 conv
    """

    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        stages = (
            ('norm', BatchNorm(num_input_features, act='relu')),
            ('conv', Conv2D(
                num_input_features,
                num_output_features,
                filter_size=1,
                stride=1,
                bias_attr=False)),
            ('pool', Pool2D(
                pool_size=2, pool_stride=2, pool_type='avg')),
        )
        # sublayer names are part of the interface (callers address them)
        for stage_name, stage in stages:
            self.add_sublayer(stage_name, stage)
class DenseNet(fluid.dygraph.Layer):
    """
    DenseNet backbone whose 4th dense block (index 3) uses dilated layers.

    `self.features` holds the stem (conv0 / norm0 / pool0), the dense blocks
    'denseblock<i>', the transitions 'transition<i>' between them and a
    final 'norm5'. The number of channels produced by `features` is stored
    as `self.features.num_features`.

    Args:
        growth_rate (int): channels added by every dense layer
        block_config (tuple): number of layers in each dense block
        num_init_features (int): channels produced by the stem convolution
        bn_size (int): bottleneck multiplier of the dense layers
        drop_rate (float): dropout probability of the dense layers
        num_classes (int): output size of the classifier
        memory_efficient (bool): forwarded to the dense blocks
    """

    def __init__(self,
                 growth_rate=32,
                 block_config=(6, 12, 24, 16),
                 num_init_features=64,
                 bn_size=4,
                 drop_rate=0,
                 num_classes=1000,
                 memory_efficient=False):
        super(DenseNet, self).__init__()

        stem = [
            ('conv0', Conv2D(
                3,
                num_init_features,
                filter_size=7,
                stride=2,
                padding=3,
                bias_attr=False)),
            ('norm0', BatchNorm(
                num_init_features, act='relu')),
            ('pool0', Pool2D(
                pool_size=3,
                pool_stride=2,
                pool_padding=1,
                pool_type='max')),
        ]
        self.features = Sequential(*stem)

        # Each denseblock
        num_features = num_init_features
        last_index = len(block_config) - 1
        for i, num_layers in enumerate(block_config):
            # only the block with index 3 uses the dilated variant
            block_type = _DenseBlock_rpn if i == 3 else _DenseBlock
            block = block_type(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient)
            self.features.add_sublayer('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate

            # every block except the last is followed by a transition that
            # halves the channel count
            if i != last_index:
                trans = _Transition(
                    num_input_features=num_features,
                    num_output_features=num_features // 2)
                self.features.add_sublayer('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_sublayer('norm5', BatchNorm(num_features))

        # Linear layer
        self.classifier = Linear(num_features, num_classes)

        # expose the channel count of the feature extractor
        self.features.num_features = num_features

    def forward(self, x):
        feature_map = self.features(x)
        activated = fluid.layers.relu(feature_map)
        pooled = fluid.layers.adaptive_pool2d(
            input=activated, pool_size=[1, 1], pool_type='avg')
        flat = fluid.layers.flatten(pooled, 1)
        return self.classifier(flat)
def _densenet(arch, growth_rate, block_config, num_init_features, **kwargs):
    """
    Construct a DenseNet with the given configuration.

    Args:
        arch (str): architecture name; not used by this function
        growth_rate (int): forwarded to DenseNet
        block_config (tuple): forwarded to DenseNet
        num_init_features (int): forwarded to DenseNet
        **kwargs: extra keyword arguments forwarded to DenseNet
    """
    return DenseNet(growth_rate, block_config, num_init_features, **kwargs)
def densenet121(**kwargs):
    """Build a DenseNet with growth rate 32, blocks (6, 12, 24, 16) and 64 stem channels."""
    return _densenet(
        'densenet121',
        growth_rate=32,
        block_config=(6, 12, 24, 16),
        num_init_features=64,
        **kwargs)
def densenet161(**kwargs):
    """Build a DenseNet with growth rate 48, blocks (6, 12, 36, 24) and 96 stem channels."""
    return _densenet(
        'densenet161',
        growth_rate=48,
        block_config=(6, 12, 36, 24),
        num_init_features=96,
        **kwargs)
def densenet169(**kwargs):
    """Build a DenseNet with growth rate 32, blocks (6, 12, 32, 32) and 64 stem channels."""
    return _densenet(
        'densenet169',
        growth_rate=32,
        block_config=(6, 12, 32, 32),
        num_init_features=64,
        **kwargs)
def densenet201(**kwargs):
    """Build a DenseNet with growth rate 32, blocks (6, 12, 48, 32) and 64 stem channels."""
    return _densenet(
        'densenet201',
        growth_rate=32,
        block_config=(6, 12, 48, 32),
        num_init_features=64,
        **kwargs)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
model_3d_dilate
"""
from lib.rpn_util import *
from models.backbone.densenet import densenet121
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.initializer import Normal
import math
def initial_type(name,
                 input_channels,
                 init="kaiming",
                 use_bias=False,
                 filter_size=0,
                 stddev=0.02):
    """Build the weight and bias attributes for a convolution called ``name``.

    Args:
        name: Prefix for the parameter names (``<name>_weight``, ``<name>_bias``).
        input_channels: Number of input channels of the convolution.
        init: ``"kaiming"`` draws weight and bias uniformly from
            ``[-bound, bound]`` with ``bound = 1 / sqrt(fan_in)`` and
            ``fan_in = input_channels * filter_size ** 2``. Any other value
            gives a zero-mean normal weight (scale ``stddev``) and a zero bias.
        use_bias: A bias attribute is only created when ``use_bias == True``.
        filter_size: Side length of the square kernel. The default of 0 is
            only a placeholder and must be overridden for ``"kaiming"``.
        stddev: Scale of the normal initializer on the non-kaiming path.

    Returns:
        ``(param_attr, bias_attr)``; ``bias_attr`` is ``False`` when no bias
        was requested.

    Raises:
        ValueError: On the kaiming path when the fan-in is not positive
            (previously an opaque ``ZeroDivisionError``).
    """
    kaiming = init == "kaiming"
    if kaiming:
        fan_in = input_channels * filter_size * filter_size
        if fan_in <= 0:
            raise ValueError(
                "kaiming init needs positive input_channels and filter_size, "
                "got input_channels=%r, filter_size=%r" %
                (input_channels, filter_size))
        bound = 1 / math.sqrt(fan_in)
        weight_init = fluid.initializer.Uniform(low=-bound, high=bound)
    else:
        weight_init = fluid.initializer.NormalInitializer(
            loc=0.0, scale=stddev)
    param_attr = fluid.ParamAttr(name=name + "_weight", initializer=weight_init)

    if not (use_bias == True):
        return param_attr, False

    if kaiming:
        bias_init = fluid.initializer.Uniform(low=-bound, high=bound)
    else:
        bias_init = fluid.initializer.Constant(0.0)
    bias_attr = fluid.ParamAttr(name=name + "_bias", initializer=bias_init)
    return param_attr, bias_attr
class ConvLayer(fluid.dygraph.Layer):
    """A biased ``Conv2D`` whose parameter attributes come from ``initial_type``.

    ``name`` is the prefix handed to ``initial_type`` for the parameter names.
    ``num_filters`` is stored on the instance so callers can read the output
    channel count back.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 padding=0,
                 stride=1,
                 groups=None,
                 act=None,
                 name=None):
        super(ConvLayer, self).__init__()
        # Bias is always requested; the initializer kind is left at its default.
        weight_attr, bias_attr = initial_type(
            name=name,
            input_channels=num_channels,
            use_bias=True,
            filter_size=filter_size)
        self.num_filters = num_filters
        conv_kwargs = dict(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            padding=padding,
            stride=stride,
            groups=groups,
            act=act,
            param_attr=weight_attr,
            bias_attr=bias_attr)
        self._conv = Conv2D(**conv_kwargs)

    def forward(self, inputs):
        """Apply the wrapped convolution to ``inputs``."""
        return self._conv(inputs)
class RPN(fluid.dygraph.Layer):
    """Region proposal head on top of a backbone feature extractor.

    A shared 3x3 ``ConvLayer`` feeds one 1x1 classification head and eleven
    1x1 regression heads (four 2D box terms, seven 3D box terms). Every
    regression head emits one channel per anchor.
    """

    # Regression heads, in the order their outputs are concatenated.
    _BBOX_2D = ('bbox_x', 'bbox_y', 'bbox_w', 'bbox_h')
    _BBOX_3D = ('bbox_x3d', 'bbox_y3d', 'bbox_z3d', 'bbox_w3d', 'bbox_h3d',
                'bbox_l3d', 'bbox_rY3d')

    def __init__(self, phase, base, conf):
        """Create the heads and precompute the anchor grid for ``conf.crop_size``.

        Args:
            phase: ``"train"`` selects the training return signature of
                ``forward``; any other value selects the evaluation one.
            base: Backbone feature extractor exposing ``num_features`` and a
                ``transition3`` sublayer.
            conf: Configuration providing ``lbls``, ``anchors``,
                ``feat_stride`` and ``crop_size``.
        """
        super(RPN, self).__init__()
        self.base = base
        # Remove the `pool` attribute of the backbone's third transition
        # (presumably so that stage stops downsampling -- confirm against the
        # backbone definition).
        del self.base.transition3.pool
        self.phase = phase
        self.num_classes = len(conf['lbls']) + 1
        self.num_anchors = conf['anchors'].shape[0]

        self.prop_feats = ConvLayer(
            num_channels=self.base.num_features,
            num_filters=512,
            filter_size=3,
            padding=1,
            act='relu',
            name='rpn_prop_feats')
        hidden = self.prop_feats.num_filters
        self.cls = ConvLayer(
            num_channels=hidden,
            num_filters=self.num_classes * self.num_anchors,
            filter_size=1,
            name='rpn_cls')
        # Heads are registered in a fixed order so the sublayer order is stable.
        for head in self._BBOX_2D + self._BBOX_3D:
            setattr(self, head,
                    ConvLayer(
                        num_channels=hidden,
                        num_filters=self.num_anchors,
                        filter_size=1,
                        name='rpn_' + head))

        self.feat_stride = conf.feat_stride
        self.feat_size = calc_output_size(
            np.array(conf.crop_size), self.feat_stride)
        self.rois = locate_anchors(conf.anchors, self.feat_size,
                                   conf.feat_stride)
        self.anchors = conf.anchors

    def forward(self, inputs):
        """Predict class scores and box regressions for ``inputs``.

        Returns:
            ``(cls, prob, bbox_2d, bbox_3d, feat_size)`` when ``phase`` is
            ``"train"``; otherwise the same tuple with ``self.rois`` appended,
            recomputed when the feature map size differs from the one derived
            from ``conf.crop_size``.
        """
        feats = self.prop_feats(self.base(inputs))
        cls = self.cls(feats)
        raw_2d = [getattr(self, head)(feats) for head in self._BBOX_2D]
        raw_3d = [getattr(self, head)(feats) for head in self._BBOX_3D]

        batch_size, _, feat_h, feat_w = cls.shape
        feat_size = fluid.layers.shape(cls)[2:4]

        # reshape for cross entropy
        cls = fluid.layers.reshape(
            x=cls,
            shape=[
                batch_size, self.num_classes, feat_h * self.num_anchors, feat_w
            ])
        # score probabilities
        prob = fluid.layers.softmax(cls, axis=1)

        def _reshape_and_flatten(tensor):
            # Same layout change as for `cls`, with a single "class" channel.
            return flatten_tensor(
                fluid.layers.reshape(
                    x=tensor,
                    shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))

        flat_2d = [_reshape_and_flatten(t) for t in raw_2d]
        flat_3d = [_reshape_and_flatten(t) for t in raw_3d]

        # bundle
        bbox_2d = fluid.layers.concat(input=flat_2d, axis=2)
        bbox_3d = fluid.layers.concat(input=flat_3d, axis=2)
        cls = flatten_tensor(cls)
        prob = flatten_tensor(prob)

        if self.phase == "train":
            return cls, prob, bbox_2d, bbox_3d, feat_size

        # self.feat_size is deliberately left untouched here, so a differing
        # map size triggers a fresh anchor grid on every call.
        if self.feat_size[0] != feat_h or self.feat_size[1] != feat_w:
            self.rois = locate_anchors(self.anchors, [feat_h, feat_w],
                                       self.feat_stride)
        return cls, prob, bbox_2d, bbox_3d, feat_size, self.rois
def build(conf, backbone, phase='train'):
    """Create the RPN network on a DenseNet-121 backbone.

    Args:
        conf: Model configuration; if it contains a non-``None``
            ``pretrained`` entry, those weights are loaded into the backbone.
        backbone: Backbone name, matched case-insensitively. Only
            ``"densenet121"`` is supported.
        phase: ``"train"`` (case-insensitive) puts the network in training
            mode, anything else in evaluation mode. The value is also passed
            unchanged to ``RPN``.

    Returns:
        The constructed ``RPN`` layer.

    Raises:
        ValueError: If ``backbone`` is not supported. Previously this fell
            through to an ``UnboundLocalError``.
    """
    if backbone.lower() != "densenet121":
        raise ValueError(
            "unsupported backbone %r; only 'densenet121' is available" %
            (backbone, ))

    model_backbone = densenet121()
    rpn_net = RPN(phase, model_backbone.features, conf)

    # pretrain: only the backbone receives the pretrained weights here
    if 'pretrained' in conf and conf.pretrained is not None:
        print("load pretrain model from ", conf.pretrained)
        pretrained, _ = fluid.load_dygraph(conf.pretrained)
        rpn_net.base.set_dict(pretrained, use_structured_name=True)

    if phase.lower() == 'train':
        rpn_net.train()
    else:
        rpn_net.eval()
    return rpn_net
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
model_3d_dilate_depth_aware
"""
from lib.rpn_util import *
from models.backbone.densenet import densenet121
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
def initial_type(name,
                 input_channels,
                 init="kaiming",
                 use_bias=False,
                 filter_size=0,
                 stddev=0.02):
    """Build the weight and bias attributes for a convolution called ``name``.

    Args:
        name: Prefix for the parameter names (``<name>_weight``, ``<name>_bias``).
        input_channels: Number of input channels of the convolution.
        init: ``"kaiming"`` draws weight and bias uniformly from
            ``[-bound, bound]`` with ``bound = 1 / sqrt(fan_in)`` and
            ``fan_in = input_channels * filter_size ** 2``. Any other value
            gives a zero-mean normal weight (scale ``stddev``) and a zero bias.
        use_bias: A bias attribute is only created when ``use_bias == True``.
        filter_size: Side length of the square kernel. The default of 0 is
            only a placeholder and must be overridden for ``"kaiming"``.
        stddev: Scale of the normal initializer on the non-kaiming path.

    Returns:
        ``(param_attr, bias_attr)``; ``bias_attr`` is ``False`` when no bias
        was requested.

    Raises:
        ValueError: On the kaiming path when the fan-in is not positive
            (previously an opaque ``ZeroDivisionError``).
    """
    # Imported here because this module has no explicit `import math`; the
    # name was only reachable if a wildcard import happened to re-export it.
    import math

    kaiming = init == "kaiming"
    if kaiming:
        fan_in = input_channels * filter_size * filter_size
        if fan_in <= 0:
            raise ValueError(
                "kaiming init needs positive input_channels and filter_size, "
                "got input_channels=%r, filter_size=%r" %
                (input_channels, filter_size))
        bound = 1 / math.sqrt(fan_in)
        weight_init = fluid.initializer.Uniform(low=-bound, high=bound)
    else:
        weight_init = fluid.initializer.NormalInitializer(
            loc=0.0, scale=stddev)
    param_attr = fluid.ParamAttr(name=name + "_weight", initializer=weight_init)

    if not (use_bias == True):
        return param_attr, False

    if kaiming:
        bias_init = fluid.initializer.Uniform(low=-bound, high=bound)
    else:
        bias_init = fluid.initializer.Constant(0.0)
    bias_attr = fluid.ParamAttr(name=name + "_bias", initializer=bias_init)
    return param_attr, bias_attr
class ConvLayer(fluid.dygraph.Layer):
    """A biased ``Conv2D`` whose parameter attributes come from ``initial_type``.

    ``name`` is the prefix handed to ``initial_type`` for the parameter names.
    ``num_filters`` is stored on the instance so callers can read the output
    channel count back.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 padding=0,
                 stride=1,
                 groups=None,
                 act=None,
                 name=None):
        super(ConvLayer, self).__init__()
        # Bias is always requested; the initializer kind is left at its default.
        weight_attr, bias_attr = initial_type(
            name=name,
            input_channels=num_channels,
            use_bias=True,
            filter_size=filter_size)
        self.num_filters = num_filters
        conv_kwargs = dict(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            padding=padding,
            stride=stride,
            groups=groups,
            act=act,
            param_attr=weight_attr,
            bias_attr=bias_attr)
        self._conv = Conv2D(**conv_kwargs)

    def forward(self, inputs):
        """Apply the wrapped convolution to ``inputs``."""
        return self._conv(inputs)
class LocalConv2d(fluid.dygraph.Layer):
    """Convolution with separate filters for each horizontal band of the input.

    The feature map is cut into ``num_rows`` bands along the height axis. The
    bands are laid out band-major along the channel axis and pushed through a
    single grouped ``Conv2D`` with ``groups=num_rows``, so each band is
    convolved with its own weights.

    ``param_attr`` and ``bias_attr`` are accepted but are not forwarded to the
    inner ``Conv2D``.
    """

    def __init__(self,
                 num_rows,
                 num_feats_in,
                 num_feats_out,
                 kernel=1,
                 padding=0,
                 param_attr=None,
                 bias_attr=None):
        super(LocalConv2d, self).__init__()
        self.num_rows = num_rows
        self.out_channels = num_feats_out
        self.kernel = kernel
        self.pad = padding
        # One group per band; channel counts are scaled by the band count.
        self.group_conv = Conv2D(
            num_feats_in * num_rows,
            num_feats_out * num_rows,
            kernel,
            stride=1,
            groups=num_rows)

    def forward(self, x):
        """Apply the band-wise convolution; output is ``[b, out_channels, h, w]``.

        The final reshape only works when ``h`` splits evenly into
        ``num_rows`` bands.
        """
        b, c, h, w = x.shape
        pad = self.pad
        if pad:
            x = fluid.layers.pad2d(
                x,
                paddings=[pad, pad, pad, pad],
                mode='constant',
                pad_value=0.0)

        band = int(h / self.num_rows)
        padded_band = band + pad * 2

        # Cut overlapping windows of height `padded_band`, one every `band` rows.
        windows = [
            fluid.layers.transpose(
                fluid.layers.slice(
                    x,
                    axes=[2],
                    starts=[row * band],
                    ends=[row * band + padded_band]), [0, 1, 3, 2])
            for row in range(0, self.num_rows)
        ]
        x = fluid.layers.stack(windows, axis=2)
        x = fluid.layers.transpose(x, [0, 2, 1, 4, 3])
        # Fold the band axis into the channel axis (band-major).
        x = fluid.layers.reshape(
            x, [b, c * self.num_rows, padded_band, (w + pad * 2)])

        # group convolution for efficient parallel processing
        y = self.group_conv(x)

        # Unfold the bands and stitch them back together along the height axis.
        y = fluid.layers.reshape(
            y, [b, self.num_rows, self.out_channels, band, w])
        y = fluid.layers.transpose(y, [0, 2, 1, 3, 4])
        return fluid.layers.reshape(y, [b, self.out_channels, h, w])
class RPN(fluid.dygraph.Layer):
    """RPN head that blends global and row-wise (depth-aware) predictions.

    Each output (class scores plus eleven box terms) is predicted twice: by a
    regular ``ConvLayer`` branch and by a ``LocalConv2d`` branch. The two are
    mixed with a learnable scalar per output, squashed through a sigmoid.
    """

    # Regression heads, in the order their outputs are concatenated.
    _BBOX_2D = ('bbox_x', 'bbox_y', 'bbox_w', 'bbox_h')
    _BBOX_3D = ('bbox_x3d', 'bbox_y3d', 'bbox_z3d', 'bbox_w3d', 'bbox_h3d',
                'bbox_l3d', 'bbox_rY3d')

    def __init__(self, phase, base, conf):
        """Create both branches, the blend weights and the default anchor grid.

        Args:
            phase: ``"train"`` selects the training return signature of
                ``forward``; any other value selects the evaluation one.
            base: Backbone feature extractor exposing ``num_features`` and a
                ``transition3`` sublayer.
            conf: Configuration providing ``lbls``, ``anchors``, ``bins``,
                ``test_scale``, ``feat_stride``, ``crop_size``,
                ``bbox_means`` and ``bbox_stds``.
        """
        super(RPN, self).__init__()
        self.base = base
        self.conf = conf
        # Remove the `pool` attribute of the backbone's third transition
        # (presumably so that stage stops downsampling -- confirm against the
        # backbone definition).
        del self.base.transition3.pool

        # settings
        self.num_classes = len(conf['lbls']) + 1
        self.num_anchors = conf['anchors'].shape[0]
        self.num_rows = int(
            min(conf.bins, calc_output_size(conf.test_scale, conf.feat_stride)))
        self.phase = phase

        bbox_heads = self._BBOX_2D + self._BBOX_3D

        # ---- global branch ----
        self.prop_feats = ConvLayer(
            num_channels=self.base.num_features,
            num_filters=512,
            filter_size=3,
            padding=1,
            act='relu',
            name='rpn_prop_feats')
        hidden = self.prop_feats.num_filters
        self.cls = ConvLayer(
            num_channels=hidden,
            num_filters=self.num_classes * self.num_anchors,
            filter_size=1,
            name='rpn_cls')
        for head in bbox_heads:
            setattr(self, head,
                    ConvLayer(
                        num_channels=hidden,
                        num_filters=self.num_anchors,
                        filter_size=1,
                        name='rpn_' + head))

        # ---- row-wise branch ----
        self.prop_feats_loc = LocalConv2d(
            self.num_rows,
            self.base.num_features,
            512,
            3,
            padding=1,
            param_attr=ParamAttr(name='rpn_prop_feats_weights_loc'),
            bias_attr=ParamAttr(name='rpn_prop_feats_bias_loc'))
        self.cls_loc = LocalConv2d(
            self.num_rows,
            hidden,
            self.num_classes * self.num_anchors,
            1,
            param_attr=ParamAttr(name='rpn_cls_loc_weights_loc'),
            bias_attr=ParamAttr(name='rpn_cls_loc_bias_loc'))
        for head in bbox_heads:
            loc_name = head + '_loc'
            setattr(self, loc_name,
                    LocalConv2d(
                        self.num_rows,
                        hidden,
                        self.num_anchors,
                        1,
                        param_attr=ParamAttr(
                            name='rpn_' + loc_name + '_weights_loc'),
                        bias_attr=ParamAttr(
                            name='rpn_' + loc_name + '_bias_loc')))

        self.feat_stride = conf.feat_stride
        self.feat_size = calc_output_size(
            np.array(conf.crop_size), self.feat_stride)
        self.rois = locate_anchors(conf.anchors, self.feat_size,
                                   conf.feat_stride)
        self.anchors = conf.anchors
        self.bbox_means = conf.bbox_means
        self.bbox_stds = conf.bbox_stds

        # ---- blend weights: one scalar per output, in head order ----
        # TODO check the initial value (carried over from the original code).
        for head in ('cls', ) + bbox_heads:
            setattr(self, head + '_ble',
                    self.create_parameter(
                        shape=[1],
                        dtype='float32',
                        default_initializer=fluid.initializer.Constant(
                            value=10e-5)))

    def forward(self, inputs):
        """Predict blended class scores and box regressions for ``inputs``.

        Returns:
            ``(cls, prob, bbox_2d, bbox_3d, feat_size)`` when ``phase`` is
            ``"train"`` (``feat_size`` is the tensor slice of the score map's
            shape). Otherwise ``(cls, prob, bbox_2d, bbox_3d, feat_size,
            rois)``, where ``feat_size`` and ``rois`` are recomputed from the
            configuration and corrected if the actual map size differs.
        """
        # backbone
        x = self.base(inputs)
        prop_feats = self.prop_feats(x)
        prop_feats_loc = fluid.layers.relu(self.prop_feats_loc(x))

        heads = ('cls', ) + self._BBOX_2D + self._BBOX_3D
        global_out = [getattr(self, head)(prop_feats) for head in heads]
        local_out = [
            getattr(self, head + '_loc')(prop_feats_loc) for head in heads
        ]
        weights = [
            fluid.layers.sigmoid(getattr(self, head + '_ble'))
            for head in heads
        ]

        # blend: weight * global + (1 - weight) * row-wise
        blended = [(glob * alpha) + (loc * (1 - alpha))
                   for glob, loc, alpha in zip(global_out, local_out, weights)]
        cls = blended[0]
        raw_2d = blended[1:1 + len(self._BBOX_2D)]
        raw_3d = blended[1 + len(self._BBOX_2D):]

        batch_size, _, feat_h, feat_w = cls.shape
        feat_size = fluid.layers.shape(cls)[2:4]

        # reshape for cross entropy
        cls = fluid.layers.reshape(
            x=cls,
            shape=[
                batch_size, self.num_classes, feat_h * self.num_anchors, feat_w
            ])
        # score probabilities
        prob = fluid.layers.softmax(cls, axis=1)

        def _reshape_and_flatten(tensor):
            # Same layout change as for `cls`, with a single "class" channel.
            return flatten_tensor(
                fluid.layers.reshape(
                    x=tensor,
                    shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))

        flat_2d = [_reshape_and_flatten(t) for t in raw_2d]
        flat_3d = [_reshape_and_flatten(t) for t in raw_3d]

        # bundle
        bbox_2d = fluid.layers.concat(input=flat_2d, axis=2)
        bbox_3d = fluid.layers.concat(input=flat_3d, axis=2)
        cls = flatten_tensor(cls)
        prob = flatten_tensor(prob)

        if self.phase == "train":
            return cls, prob, bbox_2d, bbox_3d, feat_size

        # Evaluation: start from the grid implied by conf.crop_size and
        # rebuild it if the actual feature map has a different size.
        stride = self.conf.feat_stride
        anchors = self.conf.anchors
        feat_size = calc_output_size(np.array(self.conf.crop_size), stride)
        rois = locate_anchors(anchors, feat_size, stride)
        if feat_size[0] != feat_h or feat_size[1] != feat_w:
            feat_size = [feat_h, feat_w]
            rois = locate_anchors(anchors, feat_size, stride)
        return cls, prob, bbox_2d, bbox_3d, feat_size, rois
def build(conf, backbone, phase='train'):
    """Create the depth-aware RPN on a DenseNet-121 backbone.

    When ``conf.pretrained`` is set, the checkpoint is loaded and the weights
    of every global ``ConvLayer`` head are replicated once per row to seed the
    matching ``*_loc`` grouped convolution. Entries the checkpoint does not
    provide keep the network's freshly initialised values.

    Args:
        conf: Model configuration (see ``RPN``), optionally with ``pretrained``.
        backbone: Backbone name, matched case-insensitively. Only
            ``"densenet121"`` is supported.
        phase: ``"train"`` (case-insensitive) puts the network in training
            mode, anything else in evaluation mode. The value is also passed
            unchanged to ``RPN``.

    Returns:
        The constructed ``RPN`` layer.

    Raises:
        ValueError: If ``backbone`` is not supported. Previously this fell
            through to an ``UnboundLocalError``.
    """
    # Backbone
    if backbone.lower() != "densenet121":
        raise ValueError(
            "unsupported backbone %r; only 'densenet121' is available" %
            (backbone, ))
    backbone_res = densenet121()

    # RPN
    rpn_net = RPN(phase, backbone_res.features, conf)

    # pretrain
    if 'pretrained' in conf and conf.pretrained is not None:
        print("load pretrain model from ", conf.pretrained)
        src_weights, _ = fluid.load_dygraph(conf.pretrained)
        conv_layers = [
            'prop_feats', 'cls', 'bbox_x', 'bbox_y', 'bbox_w', 'bbox_h',
            'bbox_x3d', 'bbox_y3d', 'bbox_w3d', 'bbox_h3d', 'bbox_l3d',
            'bbox_z3d', 'bbox_rY3d'
        ]
        # The grouped convolutions are sized by rpn_net.num_rows, which can be
        # smaller than conf.bins, so tile by the count the layers really use.
        num_rows = rpn_net.num_rows
        for layer in conv_layers:
            src_weight_key = '{}._conv.weight'.format(layer)
            src_bias_key = '{}._conv.bias'.format(layer)
            dst_weight_key = '{}.group_conv.weight'.format(layer + '_loc')
            dst_bias_key = '{}.group_conv.bias'.format(layer + '_loc')
            src_weights[dst_weight_key] = np.tile(src_weights[src_weight_key],
                                                  (num_rows, 1, 1, 1))
            src_weights[dst_bias_key] = np.tile(src_weights[src_bias_key],
                                                num_rows)

        # Fill in whatever the checkpoint lacks with the current values so
        # set_dict() sees a complete state.
        for key, value in rpn_net.state_dict().items():
            if key not in src_weights:
                src_weights[key] = value
        rpn_net.set_dict(src_weights, use_structured_name=True)

    if phase.lower() == 'train':
        rpn_net.train()
    else:
        rpn_net.eval()
    return rpn_net
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import shutil
import sys
import tarfile
import time
import zipfile

import requests
# Time of the most recent progress write; used to throttle terminal output.
lasttime = time.time()
# Directory containing this script.
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
# Minimum number of seconds between two progress writes.
FLUSH_INTERVAL = 0.1


def progress(str, end=False):
    """Rewrite the current terminal line with ``str``, at most every FLUSH_INTERVAL.

    With ``end=True`` a newline is appended and the throttle is bypassed, so
    the final state is always shown.
    """
    global lasttime
    text = str + "\n" if end else str
    if end:
        # Resetting the timestamp guarantees the check below succeeds.
        lasttime = 0
    if not (time.time() - lasttime >= FLUSH_INTERVAL):
        return
    sys.stdout.write("\r%s" % text)
    lasttime = time.time()
    sys.stdout.flush()
def _download_file(url, savepath, print_progress):
    """Stream ``url`` into the file ``savepath``.

    When the server reports a ``content-length``, the body is read in 4 KiB
    chunks and, if ``print_progress`` is true, a 50-column progress bar is
    drawn. Without a length the raw stream is copied straight to disk.

    Args:
        url: Address to download.
        savepath: Destination file path (overwritten).
        print_progress: Whether to print the file name and a progress bar.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never saved in place of the archive.
    """
    # The context manager releases the connection even if writing fails.
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_length = r.headers.get('content-length')

        with open(savepath, 'wb') as f:
            if total_length is None:
                shutil.copyfileobj(r.raw, f)
                return

            total_length = int(total_length)
            if print_progress:
                print("Downloading %s" % os.path.basename(savepath))
            dl = 0
            for data in r.iter_content(chunk_size=4096):
                dl += len(data)
                f.write(data)
                if print_progress:
                    done = int(50 * dl / total_length)
                    progress("[%-50s] %.2f%%" %
                             ('=' * done, float(100 * dl) / total_length))

    if print_progress:
        progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
def _uncompress_file(filepath, extrapath, delete_file, print_progress):
    """Extract the archive ``filepath`` into ``extrapath``.

    Paths ending in ``zip`` use the zip extractor and paths ending in ``tgz``
    the gzip tar extractor; everything else is opened as a tar with mode
    ``"r"``. The archive is removed afterwards when ``delete_file`` is true.

    Returns:
        The first member name of the archive, as reported by the extractor.
    """
    if print_progress:
        print("Uncompress %s" % os.path.basename(filepath))

    if filepath.endswith("zip"):
        extractor = _uncompress_file_zip
    elif filepath.endswith("tgz"):
        extractor = _uncompress_file_tar
    else:
        extractor = functools.partial(_uncompress_file_tar, mode="r")

    # The extractors are generators that report progress after every member.
    for total_num, index, rootpath in extractor(filepath, extrapath):
        if not print_progress:
            continue
        filled = int(50 * float(index) / total_num)
        progress("[%-50s] %.2f%%" %
                 ('=' * filled, float(100 * index) / total_num))

    if print_progress:
        progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)

    if delete_file:
        os.remove(filepath)

    return rootpath
def _uncompress_file_zip(filepath, extrapath):
files = zipfile.ZipFile(filepath, 'r')
filelist = files.namelist()
rootpath = filelist[0]
total_num = len(filelist)
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index, rootpath
files.close()
yield total_num, index, rootpath
def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
files = tarfile.open(filepath, mode)
filelist = files.getnames()
total_num = len(filelist)
rootpath = filelist[0]
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index, rootpath
files.close()
yield total_num, index, rootpath
def _remove_path(path):
    """Delete ``path`` whether it is a directory tree or a single file."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        os.remove(path)


def download_file_and_uncompress(url,
                                 savepath=None,
                                 extrapath=None,
                                 extraname=None,
                                 print_progress=True,
                                 cover=False,
                                 delete_file=True):
    """Download an archive, extract it and move the result to its final name.

    Work that is already done is skipped: nothing happens if the final path
    exists, extraction is skipped if the extracted directory exists, and the
    download is skipped if the archive is already on disk.

    Args:
        url: Archive URL; its last path component is the archive file name.
        savepath: Directory for the downloaded archive (default ``"."``).
        extrapath: Directory to extract into (default ``"."``).
        extraname: Final name inside ``extrapath``; defaults to the archive
            name without its last extension.
        print_progress: Whether to print progress bars.
        cover: When true, remove any existing archive, extracted directory
            and final path first so everything is fetched again.
        delete_file: Whether to delete the archive after extraction.
    """
    if savepath is None:
        savepath = "."

    if extrapath is None:
        extrapath = "."

    savename = url.split("/")[-1]
    savepath = os.path.join(savepath, savename)
    savename = ".".join(savename.split(".")[:-1])
    savename = os.path.join(extrapath, savename)
    extraname = savename if extraname is None else os.path.join(extrapath,
                                                                extraname)

    if cover:
        # savepath is the archive *file*; shutil.rmtree() alone would raise on
        # it, so pick the right removal for each leftover.
        for stale in (savepath, savename, extraname):
            if os.path.exists(stale):
                _remove_path(stale)

    if not os.path.exists(extraname):
        if not os.path.exists(savename):
            if not os.path.exists(savepath):
                _download_file(url, savepath, print_progress)
            savename = _uncompress_file(savepath, extrapath, delete_file,
                                        print_progress)
            savename = os.path.join(extrapath, savename)
        shutil.move(savename, extraname)
# Pretrained backbone archives that can be fetched by name.
model_urls = dict(
    densenet121="https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar",
    resnet101="http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar",
)

if __name__ == "__main__":
    # Exactly one positional argument is expected: the model name.
    if len(sys.argv[1:]) != 1:
        print("usage:\n python download_model.py ${MODEL_NAME}")
        exit(1)

    model_name = sys.argv[1]
    if model_name not in model_urls:
        print("Only support: \n {}".format("\n ".join(model_urls)))
        exit(1)

    url = model_urls[model_name]
    # Download next to this script and name the extracted directory after the model.
    download_file_and_uncompress(
        url=url,
        savepath=LOCAL_PATH,
        extrapath=LOCAL_PATH,
        extraname=model_name)

    print("Pretrained Model download success!")
# Evaluate the warm-up run's epoch-1 weights using GPU 0 only.
export CUDA_VISIBLE_DEVICES=0
python test.py \
    --data_dir dataset \
    --conf_path output/kitti_3d_multi_warmup/conf.pkl \
    --weights_path output/kitti_3d_multi_warmup/epoch1.pdparams
# Train the warm-up configuration using GPU 1 only.
export CUDA_VISIBLE_DEVICES=1
# Paddle runtime flag controlling the fraction of GPU memory it claims.
# (The original line repeated the `export` keyword, which also exported a
# stray variable literally named "export".)
export FLAGS_fraction_of_gpu_memory_to_use=0.1
python train.py --data_dir dataset --conf kitti_3d_multi_warmup
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""test"""
import os
import sys
import argparse
import ast
import logging
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from models import *
from easydict import EasyDict as edict
from lib.rpn_util import *
sys.path.append(os.getcwd())
import lib.core as core
from lib.util import *
import pdb
import paddle
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
# Clear any handlers already attached to the root logger; basicConfig() below
# would otherwise leave the existing configuration in place.
logging.root.handlers = []
# Record layout: timestamp, level name, then the message.
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
# Route INFO-and-above records to stdout.
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
# Module-level logger for this script.
logger = logging.getLogger(__name__)
def parse_args():
    """Parse the evaluation script's command-line options.

    Options: ``--conf_path`` (pickled config), ``--weights_path``,
    ``--backbone`` and ``--data_dir``. The parser's program name is the
    literal string ``"M3D-RPN train script"``.
    """
    parser = argparse.ArgumentParser("M3D-RPN train script")
    options = (
        ("--conf_path", dict(type=str, default='', help="config.pkl")),
        ('--weights_path',
         dict(type=str, default='', help='weights save path')),
        ('--backbone',
         dict(
             type=str,
             default='DenseNet121',
             help='backbone model to train, default DenseNet121')),
        ('--data_dir',
         dict(type=str, default='dataset', help='dataset directory')),
    )
    # Registered in declaration order so the help output keeps its layout.
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def test():
    """Load a saved model and run ``test_kitti_3d`` on ``conf.dataset_test``.

    The configuration is read from ``--conf_path``, the weights from
    ``--weights_path``. Results go to ``output/tmp_results/data``, which is
    recreated on every run.
    """
    args = parse_args()

    # load config
    conf = edict(pickle_read(args.conf_path))
    # Weights come from --weights_path below, so skip pretrained loading in build().
    conf.pretrained = None

    results_path = os.path.join('output', 'tmp_results', 'data')
    # make directory
    mkdir_if_missing(results_path, delete_if_exist=True)

    with fluid.dygraph.guard(fluid.CUDAPlace(0)):
        # Import the model definition named in the config and build the network.
        model_module = absolute_import(
            os.path.join('.', 'models', conf.model + '.py'))
        net = model_module.build(conf, args.backbone, 'train')
        # Built with phase 'train', then switched over to evaluation here.
        net.eval()
        net.phase = "eval"

        state, _ = fluid.load_dygraph(args.weights_path)
        print("loaded model from ", args.weights_path)
        net.set_dict(state)  #, use_structured_name=True)

        print("start evaluation...")
        test_kitti_3d(conf.dataset_test, net, conf, results_path,
                      args.data_dir)
        print("Evaluation Finished!")


if __name__ == '__main__':
    test()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""main """
import os
import sys
import time
import shutil
import argparse
import ast
import logging
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from models import *
from utils import *
sys.path.append(os.getcwd())
from data.m3drpn_reader import M3drpnReader
import lib.core as core
from lib.rpn_util import *
import pdb
from easydict import EasyDict as edict
import paddle
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
import math
from lib.loss.rpn_3d import *
import time
# Drop any handlers already attached to the root logger: logging.basicConfig
# does nothing when the root logger already has handlers, so clearing them
# makes the configuration below take effect.
logging.root.handlers = []
# Record layout: timestamp, level name, then the message.
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
# Emit records of level INFO and above to stdout.
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def parse_args():
    """Parse the command-line options of the M3D-RPN training script.

    Options are read from ``sys.argv`` by ``argparse``. The configuration
    name is accepted as ``--conf`` and, as an alias, ``--config``; both store
    into ``conf``.

    Returns:
        argparse.Namespace: namespace with the attributes
        ``use_data_parallel``, ``backbone``, ``conf``, ``use_gpu``,
        ``data_dir``, ``save_dir``, ``resume``, ``log_interval`` and ``ce``.
    """
    parser = argparse.ArgumentParser("M3D-RPN train script")
    parser.add_argument(
        "--use_data_parallel",  # TODO
        # literal_eval turns the strings "True"/"False" into real booleans.
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to use data parallel mode to train the model."
    )
    parser.add_argument(
        '--backbone',
        type=str,
        default='DenseNet121',
        help='backbone model to train, default DenseNet121')
    parser.add_argument(
        '--conf',
        # Alias so that `--config=<name>` is accepted as well as `--conf`.
        '--config',
        dest='conf',
        type=str,
        default='kitti_3d_multi_main',
        help='config')
    parser.add_argument(
        '--use_gpu',
        type=ast.literal_eval,
        default=True,
        help='default use gpu.')
    parser.add_argument(
        '--data_dir', type=str, default='dataset', help='dataset directory')
    parser.add_argument(
        '--save_dir',
        type=str,
        default='output',
        help='directory name to save train snapshot')
    parser.add_argument(
        '--resume',
        type=str,
        default=None,
        help='path to resume training based on previous checkpoints. '
        'None for not resuming any checkpoints.')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=1,
        help='mini-batch interval for logging.')
    parser.add_argument(
        '--ce',
        action='store_true',
        help='The flag indicating whether to run the task '
        'for continuous evaluation.')
    return parser.parse_args()
def train():
    """Train M3D-RPN in dygraph mode with periodic logging and snapshots.

    Steps, in order:

    1. Parse the command line, create ``save_dir`` and initialise the
       configuration and output paths through ``lib.core``.
    2. Build the data reader, generate anchors and bbox statistics and pickle
       the resulting configuration to ``<paths.output>/conf.pkl``.
    3. Build the model, optimizer and ``RPN_3D_loss`` criterion, wrapping the
       model and reader for data-parallel training when requested.
    4. Loop over epochs and batches: forward, loss, backward, optimizer step
       (optionally every ``conf.batch_skip`` batches), then log every
       ``conf.display`` iterations and snapshot (and optionally evaluate)
       every ``conf.snapshot_iter`` iterations.

    The global iteration counter is an integer, so the modulo checks behave
    the same in every epoch and snapshots are named ``iter<N>_params`` /
    ``iter<N>_opt``.

    NOTE(review): ``--use_gpu``, ``--resume`` and ``--log_interval`` are
    parsed but not used by this function; a CUDA place is always selected.
    """
    args = parse_args()
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    assert args.backbone in ['DenseNet121'], "--backbone unsupported"

    # conf init
    conf = core.init_config(args.conf)
    paths = core.init_training_paths(args.conf)

    tracker = edict()
    start_iter = 0
    start_time = time.time()

    # get reader and anchor
    m3drpn_reader = M3drpnReader(conf, args.data_dir)
    # Number of epochs needed to reach conf.max_iter iterations (rounded up
    # by the +1 and the int() truncation in the loop below).
    epoch = (conf.max_iter / (m3drpn_reader.len / conf.batch_size)) + 1
    # Whole batches per epoch. Kept integral: a float here makes the global
    # iteration fractional whenever the dataset size is not a multiple of the
    # batch size, and the display/snapshot modulo checks then never fire.
    iters_per_epoch = int(m3drpn_reader.len // conf.batch_size)
    train_reader = m3drpn_reader.get_reader(conf.batch_size, mode='train')
    generate_anchors(conf, m3drpn_reader.data['train'], paths.output)
    compute_bbox_stats(conf, m3drpn_reader.data['train'], paths.output)
    # Persist the configuration so the test script can reload it.
    pickle_write(os.path.join(paths.output, 'conf.pkl'), conf)

    # train
    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
        if args.use_data_parallel else fluid.CUDAPlace(0)
    with fluid.dygraph.guard(place):
        if args.ce:
            # Continuous-evaluation mode: fix the random seeds.
            print("ce mode")
            seed = 33
            np.random.seed(seed)
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

        if args.use_data_parallel:
            strategy = fluid.dygraph.parallel.prepare_context()

        # -----------------------------------------
        # network and loss
        # -----------------------------------------

        # training network
        train_model, optimizer = core.init_training_model(conf, args.backbone,
                                                          paths.output)

        # setup loss
        criterion_det = RPN_3D_loss(conf)

        if args.use_data_parallel:
            train_model = fluid.dygraph.parallel.DataParallel(train_model,
                                                              strategy)
            train_reader = fluid.contrib.reader.distributed_batch_reader(
                train_reader)

        for epo in range(int(epoch)):
            for batch_id, data in enumerate(train_reader()):
                batch_start = time.time()

                # Incomplete batches are skipped before any array is built.
                if len(data) != conf.batch_size:
                    continue

                # assumes every sample image holds 3*512*1760 values, as
                # required by the hard-coded reshape - TODO confirm against
                # the reader/config
                images = np.array([x[0].reshape(3, 512, 1760)
                                   for x in data]).astype('float32')
                imobjs = np.array([x[1] for x in data])

                img = to_variable(images)
                cls, prob, bbox_2d, bbox_3d, feat_size = train_model(img)

                # loss
                det_loss, det_stats = criterion_det(cls, prob, bbox_2d,
                                                    bbox_3d, imobjs, feat_size)
                total_loss = det_loss
                stats = det_stats

                # backprop
                if total_loss > 0:
                    if args.use_data_parallel:
                        total_loss = train_model.scale_loss(total_loss)
                        total_loss.backward()
                        train_model.apply_collective_grads()
                    else:
                        total_loss.backward()

                    # batch skip, simulates larger batches by skipping gradient step
                    if ('batch_skip' not in conf) or (
                            (batch_id + 1) % conf.batch_skip) == 0:
                        optimizer.minimize(total_loss)
                        optimizer.clear_gradients()

                batch_end = time.time()
                train_batch_cost = batch_end - batch_start

                # keep track of stats
                compute_stats(tracker, stats)

                # -----------------------------------------
                # display
                # -----------------------------------------
                iteration = epo * iters_per_epoch + batch_id

                if iteration % conf.display == 0 and iteration > start_iter:
                    # log results
                    log_stats(tracker, iteration, start_time, start_iter,
                              conf.max_iter)
                    print( "epoch %d | batch step %d | iter %d, batch cost: %.5f, loss %0.3f" % \
                        (epo, batch_id, iteration, train_batch_cost, total_loss.numpy()))

                    # reset tracker
                    tracker = edict()

                # snapshot, do_test
                if iteration % conf.snapshot_iter == 0 and iteration > start_iter:
                    fluid.save_dygraph(
                        train_model.state_dict(),
                        '{}/iter{}_params'.format(paths.weights, iteration))
                    fluid.save_dygraph(
                        optimizer.state_dict(),
                        '{}/iter{}_opt'.format(paths.weights, iteration))

                    # do test
                    if conf.do_test:
                        train_model.phase = "eval"
                        train_model.eval()

                        results_path = os.path.join(paths.results,
                                                    'results_{}'.format((epo)))
                        if conf.test_protocol.lower() == 'kitti':
                            results_path = os.path.join(results_path, 'data')
                            mkdir_if_missing(results_path, delete_if_exist=True)
                            test_kitti_3d(conf.dataset_test, train_model, conf,
                                          results_path, paths.data)

                        # Back to training mode after the evaluation pass.
                        train_model.phase = "train"
                        train_model.train()


if __name__ == '__main__':
    train()
"""
if args.resume:
if not os.path.isdir(args.resume):
assert os.path.exists("{}.pdparams".format(args.resume)), \
"Given resume weight {}.pdparams not exist.".format(args.resume)
assert os.path.exists("{}.pdopt".format(args.resume)), \
"Given resume optimizer state {}.pdopt not exist.".format(args.resume)
fluid.load(train_prog, args.resume, exe)
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
Contains common utility functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import six
import logging
import numpy as np
import paddle.fluid as fluid
__all__ = ["check_gpu", "print_arguments", "parse_outputs", "Stat"]
logger = logging.getLogger(__name__)
def check_gpu(use_gpu):
    """Exit when GPU mode is requested on a CPU-only paddlepaddle build.

    Logs an error and calls ``sys.exit(1)`` if ``use_gpu`` is truthy while
    ``fluid.is_compiled_with_cuda()`` reports no CUDA support. The check is
    best effort: if the capability probe itself raises, a warning is logged
    and the function returns without exiting.

    Args:
        use_gpu (bool): whether the caller intends to run on GPU.
    """
    err = "Config use_gpu cannot be set as True while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set --use_gpu=False to run model on CPU"

    if not use_gpu:
        return

    # Only the probe is guarded, so a failing probe is reported instead of
    # being silently swallowed together with everything else.
    try:
        has_cuda = fluid.is_compiled_with_cuda()
    except Exception as e:
        logger.warning("Skipping GPU availability check: %s", e)
        return

    if not has_cuda:
        logger.error(err)
        sys.exit(1)
def print_arguments(args):
    """Log every attribute of a parsed argument namespace, sorted by name.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("name", default="John", type=str, help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    """
    header = "----------- Configuration Arguments -----------"
    footer = "------------------------------------------------"

    logger.info(header)
    # One "name: value" line per argument, in alphabetical order of the name.
    for name, val in sorted(vars(args).items()):
        logger.info("%s: %s" % (name, val))
    logger.info(footer)
def parse_outputs(outputs):
    """Split an outputs mapping into its keys and the variables' names.

    Every value is marked ``persistable = True`` as a side effect.

    Args:
        outputs (dict): maps a label to an object exposing ``name`` and
            ``persistable`` attributes.

    Returns:
        tuple: ``(keys, values)`` where ``keys`` is the list of labels and
        ``values`` the list of the corresponding ``name`` attributes, both in
        the mapping's iteration order.
    """
    labels = list(outputs.keys())
    var_names = []
    for var in outputs.values():
        var.persistable = True
        var_names.append(var.name)
    return labels, var_names
class Stat(object):
    """Accumulate named series of values and report their means."""

    def __init__(self):
        # Maps a key to the list of values recorded for it so far.
        self.stats = {}

    def update(self, keys, values):
        """Append each value to the series of the key at the same position."""
        for name, val in zip(keys, values):
            self.stats.setdefault(name, []).append(val)

    def reset(self):
        """Discard everything recorded so far."""
        self.stats = {}

    def get_mean_log(self):
        """Return ``"avg_<key>: <mean>, "`` for every key, concatenated.

        Means are formatted with four decimals; the result is an empty string
        when nothing has been recorded.
        """
        return "".join("avg_{}: {:.4f}, ".format(name, np.mean(series))
                       for name, series in self.stats.items())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册