Unverified commit 08f3c0be authored by shuluoshu, committed by GitHub

add M3D-RPN model (#4822)

* Add M3D-RPN model.
Co-authored-by: yexiaoqing <yexiaoqing@baidu.com>
Parent a33f0814
# M3D-RPN: Monocular 3D Region Proposal Network for Object Detection
## Introduction
A PaddlePaddle implementation of M3D-RPN, a monocular 3D region proposal network for object detection, accepted to ICCV 2019 (oral) and detailed in the [arXiv report](https://arxiv.org/abs/1907.06038).
## Setup
- **Cuda & Python**
This project uses PaddlePaddle 1.8 with Python 3, CUDA 9, and a few Anaconda packages.
- **Data**
Download the full [KITTI](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) detection dataset, then place a softlink (or the actual data) at *M3D-RPN/dataset/kitti*.
```
cd M3D-RPN
ln -s /path/to/kitti dataset/kitti
```
Then use the following scripts to extract the data splits, which use softlinks to the above directory for efficient storage.
```
python dataset/kitti_split1/setup_split.py
python dataset/kitti_split2/setup_split.py
```
Next, build the KITTI devkit eval for each split.
```
sh dataset/kitti_split1/devkit/cpp/build.sh
sh dataset/kitti_split2/devkit/cpp/build.sh
```
Lastly, build the NMS modules.
```
cd lib/nms
make
```
## Training
Training is split into warmup and main configurations; review the files in *config* for details.
```
# First train the warmup (without depth-aware convolutions)
python train.py --config=kitti_3d_multi_warmup
# Then train the main experiment (with depth-aware convolutions)
python train.py --config=kitti_3d_multi_main
```
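Each configuration in *config* is a plain Python module exposing a `Config()` function that returns an `easydict` of all solver, dataset, anchor, and loss settings. As a rough sketch of how one can be inspected (hypothetical; *train.py* wires this up internally via the `--config` flag), using the `absolute_import` helper from *lib/util.py*:
```
# hypothetical sketch, assuming the config module path below
from lib.util import absolute_import

conf = absolute_import('config/kitti_3d_multi_warmup.py').Config()
print(conf.model, conf.lr, conf.max_iter)
```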
## Testing
We provide models for the main experiments on the val1 data split, available for download here: [M3D-RPN-release.tar](https://pan.baidu.com/s/1VQa5hGzIbauLOQi-0kR9Hg) (password: ls39).
Testing requires the paths to the configuration file and the model weights. To test a configuration and model, pass both to *test.py* as below.
```
python test.py --conf_path M3D-RPN-release/conf.pkl --weights_path M3D-RPN-release/iter50000.0_params.pdparams
```
"""
config of main
"""
from easydict import EasyDict as edict
import numpy as np
def Config():
"""
config
"""
conf = edict()
# ----------------------------------------
# general
# ----------------------------------------
conf.model = 'model_3d_dilate_depth_aware'
# solver settings
conf.solver_type = 'sgd'
conf.lr = 0.004
conf.momentum = 0.9
conf.weight_decay = 0.0005
conf.max_iter = 50000
conf.snapshot_iter = 10000
conf.display = 20
conf.do_test = True
# sgd parameters
conf.lr_policy = 'poly'
conf.lr_steps = None
conf.lr_target = conf.lr * 0.00001
# random
conf.rng_seed = 2
conf.cuda_seed = 2
# misc network
conf.image_means = [0.485, 0.456, 0.406]
conf.image_stds = [0.229, 0.224, 0.225]
conf.feat_stride = 16
conf.has_3d = True
# ----------------------------------------
# image sampling and datasets
# ----------------------------------------
# scale sampling
conf.test_scale = 512
conf.crop_size = [512, 1760]
conf.mirror_prob = 0.50
conf.distort_prob = -1
# datasets
conf.dataset_test = 'kitti_split1'
conf.datasets_train = [{
'name': 'kitti_split1',
'anno_fmt': 'kitti_det',
'im_ext': '.png',
'scale': 1
}]
conf.use_3d_for_2d = True
# percent expected height ranges based on test_scale
# used for anchor selection
conf.percent_anc_h = [0.0625, 0.75]
# labels settings
conf.min_gt_h = conf.test_scale * conf.percent_anc_h[0]
conf.max_gt_h = conf.test_scale * conf.percent_anc_h[1]
conf.min_gt_vis = 0.65
conf.ilbls = ['Van', 'ignore']
conf.lbls = ['Car', 'Pedestrian', 'Cyclist']
# ----------------------------------------
# detection sampling
# ----------------------------------------
# detection sampling
conf.batch_size = 2
conf.fg_image_ratio = 1.0
conf.box_samples = 0.20
conf.fg_fraction = 0.20
conf.bg_thresh_lo = 0
conf.bg_thresh_hi = 0.5
conf.fg_thresh = 0.5
conf.ign_thresh = 0.5
conf.best_thresh = 0.35
# ----------------------------------------
# inference and testing
# ----------------------------------------
# nms
conf.nms_topN_pre = 3000
conf.nms_topN_post = 40
conf.nms_thres = 0.4
conf.clip_boxes = False
conf.test_protocol = 'kitti'
conf.test_db = 'kitti'
conf.test_min_h = 0
conf.min_det_scales = [0, 0]
# ----------------------------------------
# anchor settings
# ----------------------------------------
# clustering settings
conf.cluster_anchors = 0
conf.even_anchors = 0
conf.expand_anchors = 0
conf.anchors = None
conf.bbox_means = None
conf.bbox_stds = None
# initialize anchors
base = (conf.max_gt_h / conf.min_gt_h)**(1 / (12 - 1))
conf.anchor_scales = np.array(
[conf.min_gt_h * (base**i) for i in range(0, 12)])
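    # note: with test_scale = 512 and percent_anc_h = [0.0625, 0.75], the 12
    # scales form a geometric progression over box heights 32..384 px,
    # with base = (384 / 32) ** (1 / 11) ≈ 1.254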
conf.anchor_ratios = np.array([0.5, 1.0, 1.5])
# loss logic
conf.hard_negatives = True
conf.focal_loss = 0
conf.cls_2d_lambda = 1
conf.iou_2d_lambda = 1
conf.bbox_2d_lambda = 0
conf.bbox_3d_lambda = 1
conf.bbox_3d_proj_lambda = 0.0
conf.hill_climbing = True
conf.bins = 32
# visdom
conf.visdom_port = 8100
conf.pretrained = 'paddle.pdparams'
return conf
"""
config of warmup
"""
from easydict import EasyDict as edict
import numpy as np
def Config():
"""
config
"""
conf = edict()
# ----------------------------------------
# general
# ----------------------------------------
conf.model = 'model_3d_dilate'
# solver settings
conf.solver_type = 'sgd'
conf.lr = 0.004
conf.momentum = 0.9
conf.weight_decay = 0.0005
conf.max_iter = 50000
conf.snapshot_iter = 10000
conf.display = 20
conf.do_test = True
# sgd parameters
conf.lr_policy = 'poly'
conf.lr_steps = None
conf.lr_target = conf.lr * 0.00001
# random
conf.rng_seed = 2
conf.cuda_seed = 2
# misc network
conf.image_means = [0.485, 0.456, 0.406]
conf.image_stds = [0.229, 0.224, 0.225]
conf.feat_stride = 16
conf.has_3d = True
# ----------------------------------------
# image sampling and datasets
# ----------------------------------------
# scale sampling
conf.test_scale = 512
conf.crop_size = [512, 1760]
conf.mirror_prob = 0.50
conf.distort_prob = -1
# datasets
conf.dataset_test = 'kitti_split1'
conf.datasets_train = [{
'name': 'kitti_split1',
'anno_fmt': 'kitti_det',
'im_ext': '.png',
'scale': 1
}]
conf.use_3d_for_2d = True
# percent expected height ranges based on test_scale
# used for anchor selection
conf.percent_anc_h = [0.0625, 0.75]
# labels settings
conf.min_gt_h = conf.test_scale * conf.percent_anc_h[0]
conf.max_gt_h = conf.test_scale * conf.percent_anc_h[1]
conf.min_gt_vis = 0.65
conf.ilbls = ['Van', 'ignore']
conf.lbls = ['Car', 'Pedestrian', 'Cyclist']
# ----------------------------------------
# detection sampling
# ----------------------------------------
# detection sampling
conf.batch_size = 2
conf.fg_image_ratio = 1.0
conf.box_samples = 0.20
conf.fg_fraction = 0.20
conf.bg_thresh_lo = 0
conf.bg_thresh_hi = 0.5
conf.fg_thresh = 0.5
conf.ign_thresh = 0.5
conf.best_thresh = 0.35
# ----------------------------------------
# inference and testing
# ----------------------------------------
# nms
conf.nms_topN_pre = 3000
conf.nms_topN_post = 40
conf.nms_thres = 0.4
conf.clip_boxes = False
conf.test_protocol = 'kitti'
conf.test_db = 'kitti'
conf.test_min_h = 0
conf.min_det_scales = [0, 0]
# ----------------------------------------
# anchor settings
# ----------------------------------------
# clustering settings
conf.cluster_anchors = 0
conf.even_anchors = 0
conf.expand_anchors = 0
conf.anchors = None
conf.bbox_means = None
conf.bbox_stds = None
# initialize anchors
base = (conf.max_gt_h / conf.min_gt_h)**(1 / (12 - 1))
conf.anchor_scales = np.array(
[conf.min_gt_h * (base**i) for i in range(0, 12)])
conf.anchor_ratios = np.array([0.5, 1.0, 1.5])
# loss logic
conf.hard_negatives = True
conf.focal_loss = 0
conf.cls_2d_lambda = 1
conf.iou_2d_lambda = 1
conf.bbox_2d_lambda = 0
conf.bbox_3d_lambda = 1
conf.bbox_3d_proj_lambda = 0.0
conf.hill_climbing = True
conf.pretrained = 'pretrained_model/densenet.pdparams'
# visdom
conf.visdom_port = 8100
return conf
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
init
"""
from . import m3drpn_reader
#from .m3drpn_reader import *
#__all__ = m3drpn_reader.__all__
"""
This code is based on https://github.com/garrickbrazil/M3D-RPN/blob/master/lib/augmentations.py
This file contains all data augmentation functions.
Every transform should have a __call__ function which takes in (self, image, imobj)
where imobj is an arbitrary dict containing relevant information about the image.
In many cases the imobj can be None, which enables the same augmentations to be used
during testing as they are in training.
Optionally, most transforms should have an __init__ function as well, if needed.
"""
import numpy as np
from numpy import random
import cv2
import math
import os
import sys
import lib.util as util
class Compose(object):
"""
Composes a set of functions which take in an image and an object, into a single transform
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, imobj=None):
for t in self.transforms:
img, imobj = t(img, imobj)
return img, imobj
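# Example (hypothetical) of composing transforms:
#   aug = Compose([ConvertToFloat(), Resize([512, 1760]), Normalize(means, stds)])
#   img, imobj = aug(img, imobj)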
class ConvertToFloat(object):
"""
Converts image data type to float.
"""
def __call__(self, image, imobj=None):
return image.astype(np.float32), imobj
class Normalize(object):
"""
Normalize the image
"""
def __init__(self, mean, stds):
self.mean = np.array(mean, dtype=np.float32)
self.stds = np.array(stds, dtype=np.float32)
def __call__(self, image, imobj=None):
image = image.astype(np.float32)
image /= 255.0
image -= np.tile(self.mean, int(image.shape[2] / self.mean.shape[0]))
image /= np.tile(self.stds, int(image.shape[2] / self.stds.shape[0]))
return image.astype(np.float32), imobj
class Resize(object):
"""
Resize the image according to the target size height and the image height.
If the image needs to be cropped after the resize, we crop it to self.size,
otherwise we pad it with zeros along the right edge
If the object has ground truths we also scale the (known) box coordinates.
"""
def __init__(self, size):
self.size = size
def __call__(self, image, imobj=None):
scale_factor = self.size[0] / image.shape[0]
h = np.round(image.shape[0] * scale_factor).astype(int)
w = np.round(image.shape[1] * scale_factor).astype(int)
# resize
image = cv2.resize(image, (w, h))
if len(self.size) > 1:
# crop in
if image.shape[1] > self.size[1]:
image = image[:, 0:self.size[1], :]
# pad out
elif image.shape[1] < self.size[1]:
padW = self.size[1] - image.shape[1]
image = np.pad(image, [(0, 0), (0, padW), (0, 0)], 'constant')
if imobj:
# store scale factor, just in case
imobj.scale_factor = scale_factor
if 'gts' in imobj:
# scale all coordinates
for gtind, gt in enumerate(imobj.gts):
if 'bbox_full' in imobj.gts[gtind]:
imobj.gts[gtind].bbox_full *= scale_factor
if 'bbox_vis' in imobj.gts[gtind]:
imobj.gts[gtind].bbox_vis *= scale_factor
if 'bbox_3d' in imobj.gts[gtind]:
# only scale x/y center locations (in 2D space!)
imobj.gts[gtind].bbox_3d[0] *= scale_factor
imobj.gts[gtind].bbox_3d[1] *= scale_factor
if 'gts_pre' in imobj:
# scale all coordinates
for gtind, gt in enumerate(imobj.gts_pre):
if 'bbox_full' in imobj.gts_pre[gtind]:
imobj.gts_pre[gtind].bbox_full *= scale_factor
if 'bbox_vis' in imobj.gts_pre[gtind]:
imobj.gts_pre[gtind].bbox_vis *= scale_factor
if 'bbox_3d' in imobj.gts_pre[gtind]:
# only scale x/y center locations (in 2D space!)
imobj.gts_pre[gtind].bbox_3d[0] *= scale_factor
imobj.gts_pre[gtind].bbox_3d[1] *= scale_factor
return image, imobj
class RandomSaturation(object):
"""
Randomly adjust the saturation of an image given a lower and upper bound,
and a distortion probability.
This function assumes the image is in HSV!!
"""
def __init__(self, distort_prob, lower=0.5, upper=1.5):
self.distort_prob = distort_prob
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "contrast upper must be >= lower."
assert self.lower >= 0, "contrast lower must be non-negative."
def __call__(self, image, imobj=None):
if random.rand() <= self.distort_prob:
image[:, :, 1] *= random.uniform(self.lower, self.upper)
return image, imobj
class RandomHue(object):
"""
Randomly adjust the hue of an image given a delta degree to rotate by,
and a distortion probability.
This function assumes the image is in HSV!!
"""
def __init__(self, distort_prob, delta=18.0):
assert delta >= 0.0 and delta <= 360.0
self.delta = delta
self.distort_prob = distort_prob
def __call__(self, image, imobj=None):
if random.rand() <= self.distort_prob:
image[:, :, 0] += random.uniform(-self.delta, self.delta)
image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0
image[:, :, 0][image[:, :, 0] < 0.0] += 360.0
return image, imobj
class ConvertColor(object):
"""
Converts color spaces to/from HSV and BGR
"""
def __init__(self, current='BGR', transform='HSV'):
self.transform = transform
self.current = current
def __call__(self, image, imobj=None):
# BGR --> HSV
if self.current == 'BGR' and self.transform == 'HSV':
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# HSV --> BGR
elif self.current == 'HSV' and self.transform == 'BGR':
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
else:
raise NotImplementedError
return image, imobj
class RandomContrast(object):
"""
Randomly adjust contrast of an image given lower and upper bound,
and a distortion probability.
"""
def __init__(self, distort_prob, lower=0.5, upper=1.5):
self.lower = lower
self.upper = upper
self.distort_prob = distort_prob
assert self.upper >= self.lower, "contrast upper must be >= lower."
assert self.lower >= 0, "contrast lower must be non-negative."
def __call__(self, image, imobj=None):
if random.rand() <= self.distort_prob:
alpha = random.uniform(self.lower, self.upper)
image *= alpha
return image, imobj
class RandomMirror(object):
"""
    Randomly mirror an image horizontally, given a mirror probability.
    Also adjusts all box coordinates accordingly.
"""
def __init__(self, mirror_prob):
self.mirror_prob = mirror_prob
def __call__(self, image, imobj):
_, width, _ = image.shape
if random.rand() <= self.mirror_prob:
image = image[:, ::-1, :]
image = np.ascontiguousarray(image)
# flip the coordinates w.r.t the horizontal flip (only adjust X)
for gtind, gt in enumerate(imobj.gts):
if 'bbox_full' in imobj.gts[gtind]:
imobj.gts[gtind].bbox_full[0] = image.shape[
1] - gt.bbox_full[0] - gt.bbox_full[2]
if 'bbox_vis' in imobj.gts[gtind]:
imobj.gts[gtind].bbox_vis[0] = image.shape[1] - gt.bbox_vis[
0] - gt.bbox_vis[2]
if 'bbox_3d' in imobj.gts[gtind]:
imobj.gts[gtind].bbox_3d[0] = image.shape[1] - gt.bbox_3d[
0] - 1
rotY = gt.bbox_3d[10]
rotY = (-math.pi - rotY) if rotY < 0 else (math.pi - rotY)
while rotY > math.pi:
rotY -= math.pi * 2
while rotY < (-math.pi):
rotY += math.pi * 2
cx2d = gt.bbox_3d[0]
cy2d = gt.bbox_3d[1]
cz2d = gt.bbox_3d[2]
coord3d = imobj.p2_inv.dot(
np.array([cx2d * cz2d, cy2d * cz2d, cz2d, 1]))
alpha = util.convertRot2Alpha(rotY, coord3d[2], coord3d[0])
imobj.gts[gtind].bbox_3d[10] = rotY
imobj.gts[gtind].bbox_3d[6] = alpha
return image, imobj
class RandomBrightness(object):
"""
    Randomly adjust the brightness of an image given a +/- delta range,
and a distortion probability.
"""
def __init__(self, distort_prob, delta=32):
assert delta >= 0.0
assert delta <= 255.0
self.delta = delta
self.distort_prob = distort_prob
def __call__(self, image, imobj=None):
if random.rand() <= self.distort_prob:
delta = random.uniform(-self.delta, self.delta)
image += delta
return image, imobj
class PhotometricDistort(object):
"""
Packages all photometric distortions into a single transform.
"""
def __init__(self, distort_prob):
self.distort_prob = distort_prob
# contrast is duplicated because it may happen before or after
# the other transforms with equal probability.
self.transforms = [
RandomContrast(distort_prob), ConvertColor(transform='HSV'),
RandomSaturation(distort_prob), RandomHue(distort_prob),
ConvertColor(
current='HSV', transform='BGR'), RandomContrast(distort_prob)
]
self.rand_brightness = RandomBrightness(distort_prob)
def __call__(self, image, imobj):
# do contrast first
if random.rand() <= 0.5:
distortion = self.transforms[:-1]
# do contrast last
else:
distortion = self.transforms[1:]
# add random brightness
distortion.insert(0, self.rand_brightness)
# compose transformation
distortion = Compose(distortion)
return distortion(image.copy(), imobj)
class Augmentation(object):
"""
Data Augmentation class which packages the typical pre-processing
and all data augmentation transformations (mirror and photometric distort)
into a single transform.
"""
def __init__(self, conf):
self.mean = conf.image_means
self.stds = conf.image_stds
self.size = conf.crop_size
self.mirror_prob = conf.mirror_prob
self.distort_prob = conf.distort_prob
if conf.distort_prob <= 0:
self.augment = Compose([
ConvertToFloat(), RandomMirror(self.mirror_prob),
Resize(self.size), Normalize(self.mean, self.stds)
])
else:
self.augment = Compose([
ConvertToFloat(), PhotometricDistort(self.distort_prob),
RandomMirror(self.mirror_prob), Resize(self.size),
Normalize(self.mean, self.stds)
])
def __call__(self, img, imobj):
return self.augment(img, imobj)
class Preprocess(object):
"""
Preprocess function which ONLY does the basic pre-processing of an image,
meant to be used during the testing/eval stages.
"""
def __init__(self, size, mean, stds):
self.mean = mean
self.stds = stds
self.size = size
self.preprocess = Compose([
ConvertToFloat(), Resize(self.size), Normalize(self.mean, self.stds)
])
def __call__(self, img):
img = self.preprocess(img)[0]
for i in range(int(img.shape[2] / 3)):
            # convert BGR to RGB, then permute to [C, H, W]
img[:, :, (i * 3):(i * 3) + 3] = img[:, :, (i * 3 + 2, i * 3 + 1, i
* 3)]
img = np.transpose(img, [2, 0, 1])
return img
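# Example (hypothetical) test-time usage with the config statistics:
#   preprocess = Preprocess(conf.crop_size, conf.image_means, conf.image_stds)
#   im = preprocess(cv2.imread(impath))  # float image in [C, H, W], RGB order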
# -*- coding:utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
data reader
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import os.path as osp
import signal
import numpy as np
import random
import logging
import math
import copy
import glob
import time
import re
from PIL import Image
import lib.util as util
from lib.rpn_util import *
import data.augmentations as augmentations
from easydict import EasyDict as edict
import cv2
__all__ = ["M3drpnReader"]
logger = logging.getLogger(__name__)
class M3drpnReader(object):
"""m3drpn reader"""
def __init__(self, conf, data_dir):
self.data_dir = data_dir
self.conf = conf
self.video_det = False if not ('video_det' in conf) else conf.video_det
self.video_count = 1 if not (
'video_count' in conf) else conf.video_count
self.use_3d_for_2d = ('use_3d_for_2d' in conf) and conf.use_3d_for_2d
self.load_data()
self.transform = augmentations.Augmentation(conf)
def _read_data_file(self, fname):
assert osp.isfile(fname), \
"{} is not a file".format(fname)
with open(fname) as f:
return [line.strip() for line in f]
def load_data(self):
"""load data"""
logger.info("Loading KITTI dataset from {} ...".format(self.data_dir))
# read all_files.txt
for dbind, db in enumerate(self.conf.datasets_train):
logging.info('Loading imgs_label {}'.format(db['name']))
# single imdb
imdb_single_db = []
# kitti formatting
if db['anno_fmt'].lower() == 'kitti_det':
train_folder = os.path.join(self.data_dir, db['name'],
'training')
                ann_folder = os.path.join(
                    train_folder, 'label_2',
                    '')  # dataset/kitti_split1/training/label_2/
cal_folder = os.path.join(train_folder, 'calib', '')
im_folder = os.path.join(train_folder, 'image_2', '')
# get sorted filepaths
                annlist = sorted(glob.glob(ann_folder + '*.txt'))  # 3712 files
                imdb_start = time.time()
self.affine_size = None if not (
'affine_size' in self.conf) else self.conf.affine_size
for annind, annpath in enumerate(annlist):
# get file parts
base = os.path.basename(annpath)
id, ext = os.path.splitext(base)
calpath = os.path.join(cal_folder, id + '.txt')
impath = os.path.join(im_folder, id + db['im_ext'])
impath_pre = os.path.join(train_folder, 'prev_2',
id + '_01' + db['im_ext'])
impath_pre2 = os.path.join(train_folder, 'prev_2',
id + '_02' + db['im_ext'])
impath_pre3 = os.path.join(train_folder, 'prev_2',
id + '_03' + db['im_ext'])
# read gts
p2 = read_kitti_cal(calpath)
p2_inv = np.linalg.inv(p2)
gts = read_kitti_label(annpath, p2, self.use_3d_for_2d)
obj = edict()
# store gts
obj.id = id
obj.gts = gts
obj.p2 = p2
obj.p2_inv = p2_inv
# im properties
im = Image.open(impath)
obj.path = impath
obj.path_pre = impath_pre
obj.path_pre2 = impath_pre2
obj.path_pre3 = impath_pre3
obj.imW, obj.imH = im.size
# database properties
obj.dbname = db.name
obj.scale = db.scale
obj.dbind = dbind
obj.affine_gt = None # did not compute transformer
# store
imdb_single_db.append(obj)
if (annind % 1000) == 0 and annind > 0:
time_str, dt = util.compute_eta(imdb_start, annind,
len(annlist))
logging.info('{}/{}, dt: {:0.4f}, eta: {}'.format(
annind, len(annlist), dt, time_str))
self.data = {}
self.data['train'] = imdb_single_db
self.data['test'] = {}
self.len = len(imdb_single_db)
self.sampled_weights = balance_samples(self.conf, imdb_single_db)
def _augmented_single(self, index):
"""
Grabs the item at the given index. Specifically,
- read the image from disk
- read the imobj from RAM
- applies data augmentation to (im, imobj)
        - converts image to RGB and permutes to [C, H, W]
"""
if not self.video_det:
# read image
im = cv2.imread(self.data['train'][index].path)
else:
# read images
im = cv2.imread(self.data['train'][index].path)
video_count = 1 if self.video_count is None else self.video_count
if video_count >= 2:
im_pre = cv2.imread(self.data['train'][index].path_pre)
if not im_pre.shape == im.shape:
im_pre = cv2.resize(im_pre, (im.shape[1], im.shape[0]))
im = np.concatenate((im, im_pre), axis=2)
if video_count >= 3:
im_pre2 = cv2.imread(self.data['train'][index].path_pre2)
if im_pre2 is None:
im_pre2 = im_pre
if not im_pre2.shape == im.shape:
im_pre2 = cv2.resize(im_pre2, (im.shape[1], im.shape[0]))
im = np.concatenate((im, im_pre2), axis=2)
if video_count >= 4:
im_pre3 = cv2.imread(self.data['train'][index].path_pre3)
if im_pre3 is None:
im_pre3 = im_pre2
if not im_pre3.shape == im.shape:
im_pre3 = cv2.resize(im_pre3, (im.shape[1], im.shape[0]))
im = np.concatenate((im, im_pre3), axis=2)
# transform / data augmentation
im, imobj = self.transform(im, copy.deepcopy(self.data['train'][index]))
for i in range(int(im.shape[2] / 3)):
            # convert BGR to RGB, then permute to [C, H, W]
im[:, :, (i * 3):(i * 3) + 3] = im[:, :, (i * 3 + 2, i * 3 + 1, i *
3)]
im = np.transpose(im, [2, 0, 1])
return im, imobj
def get_reader(self, batch_size, mode='train', shuffle=True):
"""
get reader
"""
assert mode in ['train', 'test'], \
"mode can only be 'train' or 'test'"
        imgs = self.data[mode]
        if mode == 'train':
            # sample with replacement according to the fg/bg balancing weights
            idxs = np.random.choice(
                self.len, self.len, replace=True, p=self.sampled_weights)
            if shuffle:
                np.random.shuffle(idxs)
        else:
            idxs = np.arange(len(imgs))
def reader():
"""reader"""
batch_out = []
for ind in idxs:
augmented_img, im_obj = self._augmented_single(ind)
batch_out.append([augmented_img, im_obj])
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
return reader
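# Example usage (hypothetical):
#   reader = M3drpnReader(conf, 'dataset')
#   for batch in reader.get_reader(conf.batch_size)():
#       ims, imobjs = zip(*batch)  # augmented [C, H, W] images and metadata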
# derived from M3D-RPN
def read_kitti_cal(calfile):
"""
Reads the kitti calibration projection matrix (p2) file from disc.
Args:
calfile (str): path to single calibration file
"""
text_file = open(calfile, 'r')
p2pat = re.compile((
'(P2:)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)'
+ '\s+(fpat)\s+(fpat)\s+(fpat)\s*\n').replace(
'fpat', '[-+]?[\d]+\.?[\d]*[Ee](?:[-+]?[\d]+)?'))
for line in text_file:
parsed = p2pat.fullmatch(line)
        # matched the P2 line; fill the 3x4 projection into a 4x4 matrix
if parsed is not None:
p2 = np.zeros([4, 4], dtype=float)
p2[0, 0] = parsed.group(2)
p2[0, 1] = parsed.group(3)
p2[0, 2] = parsed.group(4)
p2[0, 3] = parsed.group(5)
p2[1, 0] = parsed.group(6)
p2[1, 1] = parsed.group(7)
p2[1, 2] = parsed.group(8)
p2[1, 3] = parsed.group(9)
p2[2, 0] = parsed.group(10)
p2[2, 1] = parsed.group(11)
p2[2, 2] = parsed.group(12)
p2[2, 3] = parsed.group(13)
p2[3, 3] = 1
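            # pad the 3x4 projection to a homogeneous 4x4 so it is invertible;
            # the reader uses p2_inv to back-project 2D centers into camera
            # coordinates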
text_file.close()
return p2
def balance_samples(conf, imdb):
"""
Balances the samples in an image dataset according to the given configuration.
Basically we check which images have relevant foreground samples and which are empty,
then we compute the sampling weights according to a desired fg_image_ratio.
This is primarily useful in datasets which have a lot of empty (background) images, which may
cause instability during training if not properly balanced against.
"""
sample_weights = np.ones(len(imdb))
if conf.fg_image_ratio >= 0:
empty_inds = []
valid_inds = []
for imind, imobj in enumerate(imdb):
valid = 0
scale = conf.test_scale / imobj.imH
igns, rmvs = determine_ignores(imobj.gts, conf.lbls, conf.ilbls,
conf.min_gt_vis, conf.min_gt_h,
conf.max_gt_h, scale)
for gtind, gt in enumerate(imobj.gts):
if (not igns[gtind]) and (not rmvs[gtind]):
valid += 1
sample_weights[imind] = valid
if valid > 0:
valid_inds.append(imind)
else:
empty_inds.append(imind)
if not (conf.fg_image_ratio == 2):
fg_weight = 1
bg_weight = 0
sample_weights[valid_inds] = fg_weight
sample_weights[empty_inds] = bg_weight
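            # with fg_weight = 1 and bg_weight = 0, only images containing at
            # least one valid (non-ignored) ground truth are ever sampled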
logging.info('weighted respectively as {:.2f} and {:.2f}'.format(
fg_weight, bg_weight))
logging.info('Found {} foreground and {} empty images'.format(
np.sum(sample_weights > 0), np.sum(sample_weights <= 0)))
# force sampling weights to sum to 1
sample_weights /= np.sum(sample_weights)
return sample_weights
def read_kitti_poses(posefile):
"""
read_kitti_poses
"""
text_file = open(posefile, 'r')
ppat1 = re.compile((
'(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)'
+ '\s+(fpat)\s+(fpat)\s+(fpat)\s*\n').replace(
'fpat', '[-+]?[\d]+\.?[\d]*[Ee](?:[-+]?[\d]+)?'))
ppat2 = re.compile((
'(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)'
+ '\s+(fpat)\s+(fpat)\s+(fpat)\s*\n').replace('fpat',
'[-+]?[\d]+\.?[\d]*'))
ps = []
for line in text_file:
parsed1 = ppat1.fullmatch(line)
parsed2 = ppat2.fullmatch(line)
if parsed1 is not None:
p = np.zeros([4, 4], dtype=float)
p[0, 0] = parsed1.group(1)
p[0, 1] = parsed1.group(2)
p[0, 2] = parsed1.group(3)
p[0, 3] = parsed1.group(4)
p[1, 0] = parsed1.group(5)
p[1, 1] = parsed1.group(6)
p[1, 2] = parsed1.group(7)
p[1, 3] = parsed1.group(8)
p[2, 0] = parsed1.group(9)
p[2, 1] = parsed1.group(10)
p[2, 2] = parsed1.group(11)
p[2, 3] = parsed1.group(12)
p[3, 3] = 1
ps.append(p)
elif parsed2 is not None:
p = np.zeros([4, 4], dtype=float)
p[0, 0] = parsed2.group(1)
p[0, 1] = parsed2.group(2)
p[0, 2] = parsed2.group(3)
p[0, 3] = parsed2.group(4)
p[1, 0] = parsed2.group(5)
p[1, 1] = parsed2.group(6)
p[1, 2] = parsed2.group(7)
p[1, 3] = parsed2.group(8)
p[2, 0] = parsed2.group(9)
p[2, 1] = parsed2.group(10)
p[2, 2] = parsed2.group(11)
p[2, 3] = parsed2.group(12)
p[3, 3] = 1
ps.append(p)
text_file.close()
return ps
def read_kitti_label(file, p2, use_3d_for_2d=False):
"""
Reads the kitti label file from disc.
Args:
file (str): path to single label file for an image
p2 (ndarray): projection matrix for the given image
"""
gts = []
text_file = open(file, 'r')
'''
Values Name Description
----------------------------------------------------------------------------
1 type Describes the type of object: 'Car', 'Van', 'Truck',
'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram',
'Misc' or 'DontCare'
1 truncated Float from 0 (non-truncated) to 1 (truncated), where
truncated refers to the object leaving image boundaries
1 occluded Integer (0,1,2,3) indicating occlusion state:
0 = fully visible, 1 = partly occluded
2 = largely occluded, 3 = unknown
1 alpha Observation angle of object, ranging [-pi..pi]
4 bbox 2D bounding box of object in the image (0-based index):
contains left, top, right, bottom pixel coordinates
3 dimensions 3D object dimensions: height, width, length (in meters)
3 location 3D object location x,y,z in camera coordinates (in meters)
1 rotation_y Rotation ry around Y-axis in camera coordinates [-pi..pi]
1 score Only for results: Float, indicating confidence in
detection, needed for p/r curves, higher is better.
'''
pattern = re.compile((
'([a-zA-Z\-\?\_]+)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+'
+
'(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s+(fpat)\s*((fpat)?)\n'
).replace('fpat', '[-+]?\d*\.\d+|[-+]?\d+'))
for line in text_file:
parsed = pattern.fullmatch(line)
# bbGt annotation in text format of:
# cls x y w h occ x y w h ign ang
if parsed is not None:
obj = edict()
ign = False
cls = parsed.group(1)
trunc = float(parsed.group(2))
occ = float(parsed.group(3))
alpha = float(parsed.group(4))
x = float(parsed.group(5))
y = float(parsed.group(6))
x2 = float(parsed.group(7))
y2 = float(parsed.group(8))
width = x2 - x + 1
height = y2 - y + 1
h3d = float(parsed.group(9))
w3d = float(parsed.group(10))
l3d = float(parsed.group(11))
cx3d = float(parsed.group(12)) # center of car in 3d
cy3d = float(parsed.group(13)) # bottom of car in 3d
cz3d = float(parsed.group(14)) # center of car in 3d
rotY = float(parsed.group(15))
# actually center the box
cy3d -= (h3d / 2)
elevation = (1.65 - cy3d)
if use_3d_for_2d and h3d > 0 and w3d > 0 and l3d > 0:
# re-compute the 2D box using 3D (finally, avoids clipped boxes)
verts3d, corners_3d = project_3d(
p2, cx3d, cy3d, cz3d, w3d, h3d, l3d, rotY, return_3d=True)
# any boxes behind camera plane?
if np.any(corners_3d[2, :] <= 0):
ign = True
else:
x = min(verts3d[:, 0])
y = min(verts3d[:, 1])
x2 = max(verts3d[:, 0])
y2 = max(verts3d[:, 1])
width = x2 - x + 1
height = y2 - y + 1
# project cx, cy, cz
coord3d = p2.dot(np.array([cx3d, cy3d, cz3d, 1]))
# store the projected instead
cx3d_2d = coord3d[0]
cy3d_2d = coord3d[1]
cz3d_2d = coord3d[2]
cx = cx3d_2d / cz3d_2d
cy = cy3d_2d / cz3d_2d
# encode occlusion with range estimation
# 0 = fully visible, 1 = partly occluded
# 2 = largely occluded, 3 = unknown
if occ == 0:
vis = 1
elif occ == 1:
vis = 0.66
elif occ == 2:
vis = 0.33
else:
vis = 0.0
while rotY > math.pi:
rotY -= math.pi * 2
while rotY < (-math.pi):
rotY += math.pi * 2
# recompute alpha
alpha = util.convertRot2Alpha(rotY, cz3d, cx3d)
obj.elevation = elevation
obj.cls = cls
obj.occ = occ > 0
obj.ign = ign
obj.visibility = vis
obj.trunc = trunc
obj.alpha = alpha
obj.rotY = rotY
# is there an extra field? (assume to be track)
if len(parsed.groups()) >= 16 and parsed.group(16).isdigit():
obj.track = int(parsed.group(16))
obj.bbox_full = np.array([x, y, width, height])
obj.bbox_3d = [
cx, cy, cz3d_2d, w3d, h3d, l3d, alpha, cx3d, cy3d, cz3d, rotY
]
obj.center_3d = [cx3d, cy3d, cz3d]
gts.append(obj)
text_file.close()
return gts
def _term_reader(signum, frame):
"""_term_reader"""
logger.info('pid {} terminated, terminate reader process '
'group {}...'.format(os.getpid(), os.getpgrp()))
os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)
signal.signal(signal.SIGINT, _term_reader)
signal.signal(signal.SIGTERM, _term_reader)
This diff is collapsed.
This diff is collapsed.
all:
python setup.py build_ext --inplace
rm -rf build
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np
cimport numpy as np
cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
return a if a >= b else b
cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
return a if a <= b else b
def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1]
cdef int ndets = dets.shape[0]
cdef np.ndarray[np.int_t, ndim=1] suppressed = \
np.zeros((ndets), dtype=np.int)
# nominal indices
cdef int _i, _j
# sorted indices
cdef int i, j
# temp variables for box i's (the box currently under consideration)
cdef np.float32_t ix1, iy1, ix2, iy2, iarea
# variables for computing overlap with box j (lower scoring box)
cdef np.float32_t xx1, yy1, xx2, yy2
cdef np.float32_t w, h
cdef np.float32_t inter, ovr
keep = []
for _i in range(ndets):
i = order[_i]
if suppressed[i] == 1:
continue
keep.append(i)
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0.0, xx2 - xx1 + 1)
h = max(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (iarea + areas[j] - inter)
if ovr >= thresh:
suppressed[j] = 1
return keep
void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
int boxes_dim, float nms_overlap_thresh, int device_id);
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np
cimport numpy as np
assert sizeof(int) == sizeof(np.int32_t)
cdef extern from "gpu_nms.hpp":
void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int)
def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh,
np.int32_t device_id=0):
cdef int boxes_num = dets.shape[0]
cdef int boxes_dim = dets.shape[1]
cdef int num_out
cdef np.ndarray[np.int32_t, ndim=1] \
keep = np.zeros(boxes_num, dtype=np.int32)
cdef np.ndarray[np.float32_t, ndim=1] \
scores = dets[:, 4]
cdef np.ndarray[np.int_t, ndim=1] \
order = scores.argsort()[::-1]
cdef np.ndarray[np.float32_t, ndim=2] \
sorted_dets = dets[order, :]
_nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id)
keep = keep[:num_out]
return list(order[keep])
// ------------------------------------------------------------------
// Faster R-CNN
// Copyright (c) 2015 Microsoft
// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
// Written by Shaoqing Ren
// ------------------------------------------------------------------
#include "gpu_nms.hpp"
#include <vector>
#include <iostream>
#define CUDA_CHECK(condition) \
/* Code block avoids redefinition of cudaError_t error */ \
do { \
cudaError_t error = condition; \
if (error != cudaSuccess) { \
std::cout << cudaGetErrorString(error) << std::endl; \
} \
} while (0)
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long) * 8;
__device__ inline float devIoU(float const * const a, float const * const b) {
float left = max(a[0], b[0]), right = min(a[2], b[2]);
float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
float interS = width * height;
float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
return interS / (Sa + Sb - interS);
}
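// Each (row, col) pair of 64-box blocks is handled by one thread block: thread
// x tests its box from the row block against every box in the col block and
// records IoU > threshold as set bits of a 64-bit mask in dev_mask. The host
// then scans these masks to greedily keep the highest-scoring boxes.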
__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
const float *dev_boxes, unsigned long long *dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
__shared__ float block_boxes[threadsPerBlock * 5];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const float *cur_box = dev_boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
void _set_device(int device_id) {
int current_device;
CUDA_CHECK(cudaGetDevice(&current_device));
if (current_device == device_id) {
return;
}
// The call to cudaSetDevice must come before any calls to Get, which
// may perform initialization using the GPU.
CUDA_CHECK(cudaSetDevice(device_id));
}
void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
int boxes_dim, float nms_overlap_thresh, int device_id) {
_set_device(device_id);
float* boxes_dev = NULL;
unsigned long long* mask_dev = NULL;
const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
CUDA_CHECK(cudaMalloc(&boxes_dev,
boxes_num * boxes_dim * sizeof(float)));
CUDA_CHECK(cudaMemcpy(boxes_dev,
boxes_host,
boxes_num * boxes_dim * sizeof(float),
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMalloc(&mask_dev,
boxes_num * col_blocks * sizeof(unsigned long long)));
dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
DIVUP(boxes_num, threadsPerBlock));
dim3 threads(threadsPerBlock);
nms_kernel<<<blocks, threads>>>(boxes_num,
nms_overlap_thresh,
boxes_dev,
mask_dev);
std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
CUDA_CHECK(cudaMemcpy(&mask_host[0],
mask_dev,
sizeof(unsigned long long) * boxes_num * col_blocks,
cudaMemcpyDeviceToHost));
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long *p = &mask_host[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
*num_out = num_to_keep;
CUDA_CHECK(cudaFree(boxes_dev));
CUDA_CHECK(cudaFree(mask_dev));
}
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np
def py_cpu_nms(dets, thresh):
"""Pure Python NMS baseline."""
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
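# Example (hypothetical): dets is an (N, 5) float array of [x1, y1, x2, y2, score]
#   keep = py_cpu_nms(dets, thresh=0.4)  # indices of the boxes that survive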
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import os
from os.path import join as pjoin
from setuptools import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
import subprocess
import numpy as np
def find_in_path(name, path):
"Find a file in a search path"
    # Adapted from
    # http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/
for dir in path.split(os.pathsep):
binpath = pjoin(dir, name)
if os.path.exists(binpath):
return os.path.abspath(binpath)
return None
def locate_cuda():
"""Locate the CUDA environment on the system
Returns a dict with keys 'home', 'nvcc', 'include', and 'lib64'
and values giving the absolute path to each directory.
Starts by looking for the CUDAHOME env variable. If not found, everything
is based on finding 'nvcc' in the PATH.
"""
# first check if the CUDAHOME env variable is in use
if 'CUDAHOME' in os.environ:
home = os.environ['CUDAHOME']
nvcc = pjoin(home, 'bin', 'nvcc')
else:
# otherwise, search the PATH for NVCC
default_path = pjoin(os.sep, 'usr', 'local', 'cuda', 'bin')
nvcc = find_in_path('nvcc',
os.environ['PATH'] + os.pathsep + default_path)
if nvcc is None:
raise EnvironmentError(
'The nvcc binary could not be '
'located in your $PATH. Either add it to your path, or set $CUDAHOME'
)
home = os.path.dirname(os.path.dirname(nvcc))
cudaconfig = {
'home': home,
'nvcc': nvcc,
'include': pjoin(home, 'include'),
'lib64': pjoin(home, 'lib64')
}
for k, v in cudaconfig.items():
if not os.path.exists(v):
raise EnvironmentError(
'The CUDA %s path could not be located in %s' % (k, v))
return cudaconfig
CUDA = locate_cuda()
# Obtain the numpy include directory. This logic works across numpy versions.
try:
numpy_include = np.get_include()
except AttributeError:
numpy_include = np.get_numpy_include()
def customize_compiler_for_nvcc(self):
"""inject deep into distutils to customize how the dispatch
to gcc/nvcc works.
If you subclass UnixCCompiler, it's not trivial to get your subclass
injected in, and still have the right customizations (i.e.
distutils.sysconfig.customize_compiler) run on it. So instead of going
    the OO route, I have this. Note, it's kind of like a weird functional
    subclassing going on."""
    # tell the compiler it can process .cu files
self.src_extensions.append('.cu')
    # save references to the default compiler_so and _compile methods
default_compiler_so = self.compiler_so
super = self._compile
# now redefine the _compile method. This gets executed for each
# object but distutils doesn't have the ability to change compilers
# based on source extension: we add it.
def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
if os.path.splitext(src)[1] == '.cu':
# use the cuda for .cu files
self.set_executable('compiler_so', CUDA['nvcc'])
# use only a subset of the extra_postargs, which are 1-1 translated
# from the extra_compile_args in the Extension class
postargs = extra_postargs['nvcc']
else:
postargs = extra_postargs['gcc']
super(obj, src, ext, cc_args, postargs, pp_opts)
# reset the default compiler_so, which we might have changed for cuda
self.compiler_so = default_compiler_so
# inject our redefined _compile method into the class
self._compile = _compile
# run the customize_compiler
class custom_build_ext(build_ext):
def build_extensions(self):
customize_compiler_for_nvcc(self.compiler)
build_ext.build_extensions(self)
ext_modules = [
Extension(
"cpu_nms", ["cpu_nms.pyx"],
extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]},
include_dirs=[numpy_include],
library_dirs=["/home/vis/yexiaoqing/06_pp3d/python3.7_paddle1.8/lib"],
libraries=['python3.7m']),
Extension(
'gpu_nms',
['nms_kernel.cu', 'gpu_nms.pyx'],
#library_dirs=[CUDA['lib64']],
library_dirs=[
CUDA['lib64'],
"/home/vis/yexiaoqing/06_pp3d/python3.7_paddle1.8/lib"
],
libraries=['cudart', 'python3.7m'],
language='c++',
runtime_library_dirs=[CUDA['lib64']],
        # this syntax is specific to this build system
        # we're only going to use certain compiler args with nvcc and not with
        # gcc; the implementation of this trick is in customize_compiler_for_nvcc() above
extra_compile_args={
'gcc': ["-Wno-unused-function"],
'nvcc': [
'-arch=sm_35', '--ptxas-options=-v', '-c', '--compiler-options',
"'-fPIC'"
]
},
include_dirs=[numpy_include, CUDA['include']]),
]
setup(
name='fast_rcnn',
ext_modules=ext_modules,
# inject our custom trigger
cmdclass={'build_ext': custom_build_ext}, )
此差异已折叠。
"""
This code is based on https://github.com/garrickbrazil/M3D-RPN/blob/master/lib/util.py
This file is meant to contain generic utility functions
which can be easily re-used in any project, and are not
specific to any project or framework (except python!).
"""
import os
import sys
from glob import glob
from time import time
import matplotlib.pyplot as plt
import numpy as np
import importlib
import pickle
import logging
import datetime
import pprint
import shutil
import math
import copy
import cv2
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
def copyfile(src, dst):
"""
copyfile
"""
shutil.copyfile(src, dst)
def pretty_print(name, input, val_width=40, key_width=0):
"""
This function creates a formatted string from a given dictionary input.
It may not support all data types, but can probably be extended.
Args:
name (str): name of the variable root
input (dict): dictionary to print
val_width (int): the width of the right hand side values
key_width (int): the minimum key width, (always auto-defaults to the longest key!)
Example:
pretty_str = pretty_print('conf', conf.__dict__)
pretty_str = pretty_print('conf', {'key1': 'example', 'key2': [1,2,3,4,5], 'key3': np.random.rand(4,4)})
print(pretty_str)
or
logging.info(pretty_str)
"""
# root
pretty_str = name + ': {\n'
# determine key width
for key in input.keys():
key_width = max(key_width, len(str(key)) + 4)
# cycle keys
for key in input.keys():
val = input[key]
# round values to 3 decimals..
if type(val) == np.ndarray: val = np.round(val, 3).tolist()
# difficult formatting
val_str = str(val)
if len(val_str) > val_width:
val_str = pprint.pformat(val, width=val_width)
val_str = val_str.replace('\n', '\n{tab}')
tab = ('{0:' + str(4 + key_width) + '}').format('')
val_str = val_str.replace('{tab}', tab)
# more difficult formatting
format_str = '{0:' + str(4) + '}{1:' + str(key_width) + '} {2:' + str(
val_width) + '}\n'
pretty_str += format_str.format('', key + ':', val_str)
# close root object
pretty_str += '}'
return pretty_str
def absolute_import(file_path):
"""
Imports a python module / file given its ABSOLUTE path.
Args:
file_path (str): absolute path to a python file to attempt to import
"""
# module name
_, name, _ = file_parts(file_path)
# load the spec and module
spec = importlib.util.spec_from_file_location(name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def init_log_file(folder_path, suffix=None, log_level=logging.INFO):
"""
This function inits a log file given a folder to write the log to.
it automatically adds a timestamp and optional suffix to the log.
Anything written to the log will automatically write to console too.
Example:
import logging
init_log_file('output/logs/')
logging.info('this will show up in both the log AND console!')
"""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
log_format = '[%(levelname)s]: %(asctime)s %(message)s'
if suffix is not None:
file_name = timestamp + '_' + suffix
else:
file_name = timestamp
file_path = os.path.join(folder_path, file_name)
logging.basicConfig(filename=file_path, level=log_level, format=log_format)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
return file_path
def denorm_image(im, image_means, image_stds):
"""
:param im:
:param image_means:
:param image_stds:
:return:
"""
im = copy.deepcopy(im)
im[:, :, 0] *= image_stds[0]
im[:, :, 1] *= image_stds[1]
im[:, :, 2] *= image_stds[2]
im[:, :, 0] += image_means[0]
im[:, :, 1] += image_means[1]
im[:, :, 2] += image_means[2]
return im
def compute_eta(start_time, idx, total):
"""
Computes the estimated time as a formatted string as well
as the change in delta time dt.
Example:
from time import time
start_time = time()
for i in range(0, total):
            <lengthy computation>
time_str, dt = compute_eta(start_time, i, total)
"""
dt = (time() - start_time) / idx
timeleft = np.max([dt * (total - idx), 0])
if timeleft > 3600:
time_str = '{:.1f}h'.format(timeleft / 3600)
elif timeleft > 60:
time_str = '{:.1f}m'.format(timeleft / 60)
else:
time_str = '{:.1f}s'.format(timeleft)
return time_str, dt
def interp_color(dist,
bounds=[0, 1],
color_lo=(0, 0, 250),
color_hi=(0, 250, 250)):
"""
:param dist:
:param bounds:
:param color_lo:
:param color_hi:
:return:
"""
percent = (dist - bounds[0]) / (bounds[1] - bounds[0])
b = color_lo[0] * (1 - percent) + color_hi[0] * percent
g = color_lo[1] * (1 - percent) + color_hi[1] * percent
r = color_lo[2] * (1 - percent) + color_hi[2] * percent
return (b, g, r)
def create_colorbar(height, width, color_lo=(0, 0, 250),
color_hi=(0, 250, 250)):
"""
:param height:
:param width:
:param color_lo:
:param color_hi:
:return:
"""
im = np.zeros([height, width, 3])
for h in range(0, height):
color = interp_color(h + 0.5, [0, height], color_hi, color_lo)
im[h, :, 0] = (color[0])
im[h, :, 1] = (color[1])
im[h, :, 2] = (color[2])
return im.astype(np.uint8)
def mkdir_if_missing(directory, delete_if_exist=False):
"""
Recursively make a directory structure even if missing.
if delete_if_exist=True then we will delete it first
which can be useful when better control over initialization is needed.
"""
if delete_if_exist and os.path.exists(directory): shutil.rmtree(directory)
# check if not exist, then make
if not os.path.exists(directory):
os.makedirs(directory)
def list_files(base_dir, file_pattern):
"""
Returns a list of files given a directory and pattern
The results are sorted alphabetically
Example:
files = list_files('path/to/images/', '*.jpg')
"""
return sorted(glob(os.path.join(base_dir) + file_pattern))
def file_parts(file_path):
"""
Lists a files parts such as base_path, file name and extension
Example
base, name, ext = file_parts('path/to/file/dog.jpg')
print(base, name, ext) --> ('path/to/file/', 'dog', '.jpg')
"""
base_path, tail = os.path.split(file_path)
name, ext = os.path.splitext(tail)
return base_path, name, ext
def pickle_write(file_path, obj):
"""
Serialize an object to a provided file_path
"""
with open(file_path, 'wb') as file:
pickle.dump(obj, file)
def pickle_read(file_path):
"""
De-serialize an object from a provided file_path
"""
with open(file_path, 'rb') as file:
return pickle.load(file)
def get_color(ind, hex=False):
"""
:param ind:
:param hex:
:return:
"""
colors = [(111, 74, 0), (81, 0, 81), (128, 64, 128), (244, 35, 232),
(250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156),
(190, 153, 153), (180, 165, 180), (150, 100, 100), (150, 120, 90),
(153, 153, 153), (250, 170, 30), (220, 220, 0), (107, 142, 35),
(152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0),
(0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 0, 90), (0, 0, 110),
(0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)]
color = colors[ind % len(colors)]
if hex:
return '#%02x%02x%02x' % (color[0], color[1], color[2])
else:
return color
def draw_3d_box(im, verts, color=(0, 200, 200), thickness=1):
"""
:param im:
:param verts:
:param color:
:param thickness:
:return:
"""
for lind in range(0, verts.shape[0] - 1):
v1 = verts[lind]
v2 = verts[lind + 1]
cv2.line(im, (int(v1[0]), int(v1[1])), (int(v2[0]), int(v2[1])), color,
thickness)
def draw_bev(canvas_bev,
z3d,
l3d,
w3d,
x3d,
ry3d,
color=(0, 200, 200),
scale=1,
thickness=2):
"""
:param canvas_bev:
:param z3d:
:param l3d:
:param w3d:
:param x3d:
:param ry3d:
:param color:
:param scale:
:param thickness:
:return:
"""
w = l3d * scale
l = w3d * scale
x = x3d * scale
z = z3d * scale
r = ry3d * -1
corners1 = np.array([[-w / 2, -l / 2, 1], [+w / 2, -l / 2, 1],
[+w / 2, +l / 2, 1], [-w / 2, +l / 2, 1]])
ry = np.array([
[+math.cos(r), -math.sin(r), 0],
[+math.sin(r), math.cos(r), 0],
[0, 0, 1],
])
corners2 = ry.dot(corners1.T).T
corners2[:, 0] += w / 2 + x + canvas_bev.shape[1] / 2
corners2[:, 1] += l / 2 + z
draw_line(
canvas_bev, corners2[0], corners2[1], color=color, thickness=thickness)
draw_line(
canvas_bev, corners2[1], corners2[2], color=color, thickness=thickness)
draw_line(
canvas_bev, corners2[2], corners2[3], color=color, thickness=thickness)
draw_line(
canvas_bev, corners2[3], corners2[0], color=color, thickness=thickness)
def draw_line(im, v1, v2, color=(0, 200, 200), thickness=1):
"""
:param im:
:param v1:
:param v2:
:param color:
:param thickness:
:return:
"""
cv2.line(im, (int(v1[0]), int(v1[1])), (int(v2[0]), int(v2[1])), color,
thickness)
def draw_circle(im,
pos,
radius=5,
thickness=1,
color=(250, 100, 100),
fill=True):
"""
:param im:
:param pos:
:param radius:
:param thickness:
:param color:
:param fill:
:return:
"""
if fill: thickness = -1
cv2.circle(
im, (int(pos[0]), int(pos[1])),
radius,
color=color,
thickness=thickness)
def draw_2d_box(im, box, color=(0, 200, 200), thickness=1):
"""
:param im:
:param box:
:param color:
:param thickness:
:return:
"""
x = box[0]
y = box[1]
w = box[2]
h = box[3]
x2 = (x + w) - 1
y2 = (y + h) - 1
cv2.rectangle(im, (int(x), int(y)), (int(x2), int(y2)), color, thickness)
def imshow(im, fig_num=None):
"""
:param im:
:param fig_num:
:return:
"""
if fig_num is not None: plt.figure(fig_num)
if len(im.shape) == 2:
im = np.tile(im, [3, 1, 1]).transpose([1, 2, 0])
plt.imshow(cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_RGB2BGR))
plt.show(block=False)
def imwrite(im, path):
"""
:param im:
:param path:
:return:
"""
cv2.imwrite(path, im)
def imread(path):
"""
:param path:
:return:
"""
return cv2.imread(path)
def draw_tick_marks(im, ticks):
"""
:param im:
:param ticks:
:return:
"""
ticks_loc = list(
range(0, im.shape[0] + 1, int((im.shape[0]) / (len(ticks) - 1))))
for tind, tick in enumerate(ticks):
y = min(max(ticks_loc[tind], 50), im.shape[0] - 10)
x = im.shape[1] - 115
draw_text(
im,
'-{}m'.format(tick), (x, y),
lineType=2,
scale=1.1,
bg_color=None)
def draw_text(im,
text,
pos,
scale=0.4,
color=(0, 0, 0),
font=cv2.FONT_HERSHEY_SIMPLEX,
bg_color=(0, 255, 255),
blend=0.33,
lineType=1):
"""
:param im:
:param text:
:param pos:
:param scale:
:param color:
:param font:
:param bg_color:
:param blend:
:param lineType:
:return:
"""
pos = [int(pos[0]), int(pos[1])]
if bg_color is not None:
text_size, _ = cv2.getTextSize(text, font, scale, lineType)
x_s = int(np.clip(pos[0], a_min=0, a_max=im.shape[1]))
x_e = int(
np.clip(
pos[0] + text_size[0] - 1 + 4, a_min=0, a_max=im.shape[1]))
y_s = int(
np.clip(
pos[1] - text_size[1] - 2, a_min=0, a_max=im.shape[0]))
y_e = int(np.clip(pos[1] + 1 - 2, a_min=0, a_max=im.shape[0]))
im[y_s:y_e + 1, x_s:x_e + 1, 0] = im[
y_s:y_e + 1, x_s:x_e + 1, 0] * blend + bg_color[0] * (1 - blend)
im[y_s:y_e + 1, x_s:x_e + 1, 1] = im[
y_s:y_e + 1, x_s:x_e + 1, 1] * blend + bg_color[1] * (1 - blend)
im[y_s:y_e + 1, x_s:x_e + 1, 2] = im[
y_s:y_e + 1, x_s:x_e + 1, 2] * blend + bg_color[2] * (1 - blend)
pos[0] = int(np.clip(pos[0] + 2, a_min=0, a_max=im.shape[1]))
pos[1] = int(np.clip(pos[1] - 2, a_min=0, a_max=im.shape[0]))
cv2.putText(im, text, tuple(pos), font, scale, color, lineType)
# Calculates rotation matrix to euler angles
# The result is the same as MATLAB except the order
# of the euler angles ( x and z are swapped ).
# adopted from https://www.learnopencv.com/rotation-matrix-to-euler-angles/
def mat2euler(R):
"""
:param R:
:return:
"""
sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0])
singular = sy < 1e-6
if not singular:
x = math.atan2(R[2, 1], R[2, 2])
y = math.atan2(-R[2, 0], sy)
z = math.atan2(R[1, 0], R[0, 0])
else:
raise ValueError('singular matrix found in mat2euler')
return np.array([x, y, z])
def fig_to_im(fig):
"""
:param fig:
:return:
"""
fig.canvas.draw()
# Get the RGBA buffer from the figure
w, h = fig.canvas.get_width_height()
buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8)
buf.shape = (w, h, 4)
# canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
buf = np.roll(buf, 3, axis=2)
w, h, d = buf.shape
im_pil = Image.frombytes("RGBA", (w, h), buf.tostring())
im_np = np.array(im_pil)[:, :, :3]
return im_np
def imzoom(im, zoom=0):
"""
:param im:
:param zoom:
:return:
"""
# single value passed in for both axis?
# extend same val for w, h
zoom = np.array(zoom)
if zoom.size == 1: zoom = np.array([zoom, zoom])
zoom = np.clip(zoom, a_min=0, a_max=0.99)
cx = im.shape[1] / 2
cy = im.shape[0] / 2
w = im.shape[1] * (1 - zoom[0])
h = im.shape[0] * (1 - zoom[-1])
x1 = int(np.clip(cx - w / 2, a_min=0, a_max=im.shape[1] - 1))
x2 = int(np.clip(cx + w / 2, a_min=0, a_max=im.shape[1] - 1))
y1 = int(np.clip(cy - h / 2, a_min=0, a_max=im.shape[0] - 1))
y2 = int(np.clip(cy + h / 2, a_min=0, a_max=im.shape[0] - 1))
im = im[y1:y2 + 1, x1:x2 + 1, :]
return im
def imhstack(im1, im2):
"""
:param im1:
:param im2:
:return:
"""
sf = im1.shape[0] / im2.shape[0]
if sf > 1:
im2 = cv2.resize(im2, (int(im2.shape[1] / sf), im1.shape[0]))
else:
im1 = cv2.resize(im1, (int(im1.shape[1] / sf), im2.shape[0]))
im_concat = np.hstack((im1, im2))
return im_concat
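For reference, a minimal usage sketch combining the I/O and layout helpers above (the file paths are hypothetical):
```
# hypothetical paths; imread/imzoom/imhstack/imwrite are defined above
im_left = imread('example_left.png')
im_right = imzoom(imread('example_right.png'), zoom=0.2)  # crop 20% around center
imwrite(imhstack(im_left, im_right), 'side_by_side.png')
```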
# Calculates Rotation Matrix given euler angles.
# adopted from https://www.learnopencv.com/rotation-matrix-to-euler-angles/
def euler2mat(x, y, z):
"""
:param x:
:param y:
:param z:
:return:
"""
R_x = np.array([[1, 0, 0], [0, math.cos(x), -math.sin(x)],
[0, math.sin(x), math.cos(x)]])
R_y = np.array([[math.cos(y), 0, math.sin(y)], [0, 1, 0],
[-math.sin(y), 0, math.cos(y)]])
R_z = np.array([[math.cos(z), -math.sin(z), 0],
[math.sin(z), math.cos(z), 0], [0, 0, 1]])
R = np.dot(R_z, np.dot(R_y, R_x))
return R
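As a quick sanity check (a sketch, not part of the original file), euler2mat and mat2euler above should round-trip for non-singular angles:
```
angles = np.array([0.1, -0.3, 0.7])       # radians, away from gimbal lock
R = euler2mat(*angles)                    # R = Rz * Ry * Rx
assert np.allclose(mat2euler(R), angles)  # recovers (x, y, z)
```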
def convertAlpha2Rot(alpha, z3d, x3d):
"""
:param alpha:
:param z3d:
:param x3d:
:return:
"""
ry3d = alpha + math.atan2(-z3d, x3d) + 0.5 * math.pi
while ry3d > math.pi:
ry3d -= math.pi * 2
while ry3d < (-math.pi):
ry3d += math.pi * 2
return ry3d
def convertRot2Alpha(ry3d, z3d, x3d):
"""
convertRot2Alpha
"""
alpha = ry3d - math.atan2(-z3d, x3d) - 0.5 * math.pi
while alpha > math.pi:
alpha -= math.pi * 2
while alpha < (-math.pi):
alpha += math.pi * 2
return alpha
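The two angle helpers are inverses of each other up to the [-pi, pi) wrap; a short sketch with a hypothetical object position:
```
z3d, x3d = 20.0, 3.0                      # hypothetical object center (camera coords)
ry3d = convertAlpha2Rot(0.4, z3d, x3d)
alpha = convertRot2Alpha(ry3d, z3d, x3d)  # recovers 0.4
```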
"""
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
from .densenet import densenet121
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""densenet backbone"""
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
def _bn_function_factory(norm, conv):
def bn_function(*inputs):
concated_features = fluid.layers.concat(inputs, 1)
bottleneck_output = conv(norm(concated_features))
return bottleneck_output
return bn_function
class _DenseLayer_rpn(fluid.dygraph.Layer):
    """Dense layer variant using a dilated 3x3 conv (dilation=2) for the RPN stage."""
def __init__(self,
num_input_features,
growth_rate,
bn_size,
drop_rate,
memory_efficient=False):
super(_DenseLayer_rpn, self).__init__()
self.add_sublayer('norm1', BatchNorm(num_input_features, act='relu'))
self.add_sublayer(
'conv1',
Conv2D(
num_input_features,
bn_size * growth_rate,
filter_size=1,
stride=1,
bias_attr=False))
self.add_sublayer('norm2', BatchNorm(bn_size * growth_rate, act='relu'))
self.add_sublayer(
'conv2',
Conv2D(
bn_size * growth_rate,
growth_rate,
filter_size=3,
dilation=2,
stride=1,
padding=2,
bias_attr=False))
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient
def forward(self, *prev_features):
bn_function = _bn_function_factory(self.norm1, self.conv1)
bottleneck_output = bn_function(*prev_features)
new_features = self.conv2(self.norm2(bottleneck_output))
if self.drop_rate > 0:
new_features = fluid.layers.dropout(new_features, self.drop_rate)
return new_features
class _DenseBlock_rpn(fluid.dygraph.Layer):
def __init__(self,
num_layers,
num_input_features,
bn_size,
growth_rate,
drop_rate,
memory_efficient=False):
super(_DenseBlock_rpn, self).__init__()
self.res_out_list = []
for i in range(num_layers):
layer = _DenseLayer_rpn(
num_input_features + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
drop_rate=drop_rate,
memory_efficient=memory_efficient, )
res_out = self.add_sublayer('denselayer%d' % (i + 1), layer)
self.res_out_list.append(res_out)
def forward(self, init_features):
features = [init_features]
for layer in self.res_out_list:
new_features = layer(*features)
features.append(new_features)
return fluid.layers.concat(features, axis=1)
class _DenseLayer(fluid.dygraph.Layer):
def __init__(self,
num_input_features,
growth_rate,
bn_size,
drop_rate,
memory_efficient=False):
super(_DenseLayer, self).__init__()
self.add_sublayer('norm1', BatchNorm(num_input_features, act='relu'))
self.add_sublayer(
'conv1',
Conv2D(
num_input_features,
bn_size * growth_rate,
filter_size=1,
stride=1,
bias_attr=False))
self.add_sublayer('norm2', BatchNorm(bn_size * growth_rate, act='relu'))
self.add_sublayer(
'conv2',
Conv2D(
bn_size * growth_rate,
growth_rate,
filter_size=3,
stride=1,
padding=1,
bias_attr=False))
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient
def forward(self, *prev_features):
bn_function = _bn_function_factory(self.norm1, self.conv1)
bottleneck_output = bn_function(*prev_features)
new_features = self.conv2(self.norm2(bottleneck_output))
if self.drop_rate > 0:
new_features = fluid.layers.dropout(new_features, self.drop_rate)
return new_features
class _DenseBlock(fluid.dygraph.Layer):
def __init__(self,
num_layers,
num_input_features,
bn_size,
growth_rate,
drop_rate,
memory_efficient=False):
super(_DenseBlock, self).__init__()
self.res_out_list = []
for i in range(num_layers):
layer = _DenseLayer(
num_input_features + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
drop_rate=drop_rate,
memory_efficient=memory_efficient, )
res_out = self.add_sublayer('denselayer%d' % (i + 1), layer)
self.res_out_list.append(res_out)
def forward(self, init_features):
features = [init_features]
for layer in self.res_out_list:
new_features = layer(*features)
features.append(new_features)
return fluid.layers.concat(features, axis=1)
class _Transition(Sequential):
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.add_sublayer('norm', BatchNorm(num_input_features, act='relu'))
self.add_sublayer(
'conv',
Conv2D(
num_input_features,
num_output_features,
filter_size=1,
stride=1,
bias_attr=False))
self.add_sublayer(
'pool', Pool2D(
pool_size=2, pool_stride=2, pool_type='avg'))
class DenseNet(fluid.dygraph.Layer):
"""
"""
def __init__(self,
growth_rate=32,
block_config=(6, 12, 24, 16),
num_init_features=64,
bn_size=4,
drop_rate=0,
num_classes=1000,
memory_efficient=False):
super(DenseNet, self).__init__()
self.features = Sequential(
('conv0', Conv2D(
3,
num_init_features,
filter_size=7,
stride=2,
padding=3,
bias_attr=False)), ('norm0', BatchNorm(
num_init_features, act='relu')), ('pool0', Pool2D(
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')))
# Each denseblock
num_features = num_init_features
for i, num_layers in enumerate(block_config):
if i == 3:
block = _DenseBlock_rpn(
num_layers=num_layers,
num_input_features=num_features,
bn_size=bn_size,
growth_rate=growth_rate,
drop_rate=drop_rate,
memory_efficient=memory_efficient)
self.features.add_sublayer('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
else:
block = _DenseBlock(
num_layers=num_layers,
num_input_features=num_features,
bn_size=bn_size,
growth_rate=growth_rate,
drop_rate=drop_rate,
memory_efficient=memory_efficient)
self.features.add_sublayer('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(
num_input_features=num_features,
num_output_features=num_features // 2)
self.features.add_sublayer('transition%d' % (i + 1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_sublayer('norm5', BatchNorm(num_features))
# Linear layer
self.classifier = Linear(num_features, num_classes)
# init num_features
self.features.num_features = num_features
def forward(self, x):
features = self.features(x)
out = fluid.layers.relu(features)
out = fluid.layers.adaptive_pool2d(
input=out, pool_size=[1, 1], pool_type='avg')
out = fluid.layers.flatten(out, 1)
out = self.classifier(out)
return out
def _densenet(arch, growth_rate, block_config, num_init_features, **kwargs):
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
return model
def densenet121(**kwargs):
return _densenet('densenet121', 32, (6, 12, 24, 16), 64, **kwargs)
def densenet161(**kwargs):
return _densenet('densenet161', 48, (6, 12, 36, 24), 96, **kwargs)
def densenet169(**kwargs):
return _densenet('densenet169', 32, (6, 12, 32, 32), 64, **kwargs)
def densenet201(**kwargs):
return _densenet('densenet201', 32, (6, 12, 48, 32), 64, **kwargs)
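A minimal forward-pass sketch for this backbone under PaddlePaddle 1.8's dygraph API (the input size is an arbitrary example):
```
import numpy as np
import paddle.fluid as fluid
from models.backbone.densenet import densenet121

with fluid.dygraph.guard():
    net = densenet121()
    x = fluid.dygraph.to_variable(
        np.random.rand(1, 3, 224, 224).astype('float32'))
    logits = net(x)           # [1, 1000] classification logits
    feats = net.features(x)   # backbone feature map consumed by the RPN head
```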
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
model_3d_dilate
"""
from lib.rpn_util import *
from models.backbone.densenet import densenet121
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.initializer import Normal
import math
def initial_type(name,
                 input_channels,
                 init="kaiming",
                 use_bias=False,
                 filter_size=0,
                 stddev=0.02):
    """Builds weight/bias ParamAttrs: Kaiming-uniform or normal initialization."""
    if init == "kaiming":
        fan_in = input_channels * filter_size * filter_size
        bound = 1 / math.sqrt(fan_in)
        param_attr = fluid.ParamAttr(
            name=name + "_weight",
            initializer=fluid.initializer.Uniform(
                low=-bound, high=bound))
        if use_bias:
            bias_attr = fluid.ParamAttr(
                name=name + '_bias',
                initializer=fluid.initializer.Uniform(
                    low=-bound, high=bound))
        else:
            bias_attr = False
    else:
        param_attr = fluid.ParamAttr(
            name=name + "_weight",
            initializer=fluid.initializer.NormalInitializer(
                loc=0.0, scale=stddev))
        if use_bias:
            bias_attr = fluid.ParamAttr(
                name=name + "_bias",
                initializer=fluid.initializer.Constant(0.0))
        else:
            bias_attr = False
    return param_attr, bias_attr
class ConvLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
padding=0,
stride=1,
groups=None,
act=None,
name=None):
super(ConvLayer, self).__init__()
param_attr, bias_attr = initial_type(
name=name,
input_channels=num_channels,
use_bias=True,
filter_size=filter_size)
self.num_filters = num_filters
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
padding=padding,
stride=stride,
groups=groups,
act=act,
param_attr=param_attr,
bias_attr=bias_attr)
def forward(self, inputs):
x = self._conv(inputs)
return x
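A short usage sketch of ConvLayer (shapes and the layer name are illustrative only):
```
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    conv = ConvLayer(num_channels=64, num_filters=128, filter_size=3,
                     padding=1, act='relu', name='demo_conv')
    x = fluid.dygraph.to_variable(
        np.random.rand(1, 64, 32, 32).astype('float32'))
    y = conv(x)  # -> [1, 128, 32, 32]
```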
class RPN(fluid.dygraph.Layer):
    """M3D-RPN head: per-anchor class scores plus 2D and 3D box regressions."""
def __init__(self, phase, base, conf):
super(RPN, self).__init__()
self.base = base
del self.base.transition3.pool
self.phase = phase
self.num_classes = len(conf['lbls']) + 1
self.num_anchors = conf['anchors'].shape[0]
self.prop_feats = ConvLayer(
num_channels=self.base.num_features,
num_filters=512,
filter_size=3,
padding=1,
act='relu',
name='rpn_prop_feats')
self.cls = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_classes * self.num_anchors,
filter_size=1,
name='rpn_cls')
self.bbox_x = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_x')
self.bbox_y = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_y')
self.bbox_w = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_w')
self.bbox_h = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_h')
self.bbox_x3d = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_x3d')
self.bbox_y3d = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_y3d')
self.bbox_z3d = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_z3d')
self.bbox_w3d = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_w3d')
self.bbox_h3d = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_h3d')
self.bbox_l3d = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_l3d')
self.bbox_rY3d = ConvLayer(
num_channels=self.prop_feats.num_filters,
num_filters=self.num_anchors,
filter_size=1,
name='rpn_bbox_rY3d')
self.feat_stride = conf.feat_stride
self.feat_size = calc_output_size(
np.array(conf.crop_size), self.feat_stride)
self.rois = locate_anchors(conf.anchors, self.feat_size,
conf.feat_stride)
self.anchors = conf.anchors
def forward(self, inputs):
# backbone
x = self.base(inputs)
prop_feats = self.prop_feats(x)
cls = self.cls(prop_feats)
# bbox 2d
bbox_x = self.bbox_x(prop_feats)
bbox_y = self.bbox_y(prop_feats)
bbox_w = self.bbox_w(prop_feats)
bbox_h = self.bbox_h(prop_feats)
# bbox 3d
bbox_x3d = self.bbox_x3d(prop_feats)
bbox_y3d = self.bbox_y3d(prop_feats)
bbox_z3d = self.bbox_z3d(prop_feats)
bbox_w3d = self.bbox_w3d(prop_feats)
bbox_h3d = self.bbox_h3d(prop_feats)
bbox_l3d = self.bbox_l3d(prop_feats)
bbox_rY3d = self.bbox_rY3d(prop_feats)
batch_size, c, feat_h, feat_w = cls.shape
feat_size = fluid.layers.shape(cls)[2:4]
# reshape for cross entropy
cls = fluid.layers.reshape(
x=cls,
shape=[
batch_size, self.num_classes, feat_h * self.num_anchors, feat_w
])
# score probabilities
prob = fluid.layers.softmax(cls, axis=1)
# reshape for consistency
bbox_x = flatten_tensor(
fluid.layers.reshape(
x=bbox_x,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
bbox_y = flatten_tensor(
fluid.layers.reshape(
x=bbox_y,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
bbox_w = flatten_tensor(
fluid.layers.reshape(
x=bbox_w,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
bbox_h = flatten_tensor(
fluid.layers.reshape(
x=bbox_h,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
bbox_x3d = flatten_tensor(
fluid.layers.reshape(
x=bbox_x3d,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
bbox_y3d = flatten_tensor(
fluid.layers.reshape(
x=bbox_y3d,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
bbox_z3d = flatten_tensor(
fluid.layers.reshape(
x=bbox_z3d,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
bbox_w3d = flatten_tensor(
fluid.layers.reshape(
x=bbox_w3d,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
bbox_h3d = flatten_tensor(
fluid.layers.reshape(
x=bbox_h3d,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
bbox_l3d = flatten_tensor(
fluid.layers.reshape(
x=bbox_l3d,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
bbox_rY3d = flatten_tensor(
fluid.layers.reshape(
x=bbox_rY3d,
shape=[batch_size, 1, feat_h * self.num_anchors, feat_w]))
# bundle
bbox_2d = fluid.layers.concat(
input=[bbox_x, bbox_y, bbox_w, bbox_h], axis=2)
bbox_3d = fluid.layers.concat(
input=[
bbox_x3d, bbox_y3d, bbox_z3d, bbox_w3d, bbox_h3d, bbox_l3d,
bbox_rY3d
],
axis=2)
cls = flatten_tensor(cls)
prob = flatten_tensor(prob)
if self.phase == "train":
return cls, prob, bbox_2d, bbox_3d, feat_size
else:
            if self.feat_size[0] != feat_h or self.feat_size[1] != feat_w:
                # feature-map size changed at test time: recompute anchor rois
                self.rois = locate_anchors(self.anchors, [feat_h, feat_w],
                                           self.feat_stride)
return cls, prob, bbox_2d, bbox_3d, feat_size, self.rois
def build(conf, backbone, phase='train'):
    """Builds the RPN network and optionally loads pretrained backbone weights."""
    train = phase.lower() == 'train'
if backbone.lower() == "densenet121":
model_backbone = densenet121() # pretrain
rpn_net = RPN(phase, model_backbone.features, conf)
# pretrain
if 'pretrained' in conf and conf.pretrained is not None:
print("load pretrain model from ", conf.pretrained)
pretrained, _ = fluid.load_dygraph(conf.pretrained)
rpn_net.base.set_dict(pretrained, use_structured_name=True)
if train: rpn_net.train()
else: rpn_net.eval()
return rpn_net
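A hedged sketch of constructing the network from a pickled training config (the path is hypothetical; pickle_read and edict are assumed available as in test.py below):
```
from easydict import EasyDict as edict
from lib.util import pickle_read  # assumption: same helper test.py uses

conf = edict(pickle_read('output/kitti_3d_multi_warmup/conf.pkl'))  # hypothetical path
with fluid.dygraph.guard():
    net = build(conf, backbone='densenet121', phase='train')
```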
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import requests
import time
import functools
import tarfile
import shutil
import zipfile
lasttime = time.time()
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
FLUSH_INTERVAL = 0.1
def progress(msg, end=False):
    """Prints an in-place progress line, throttled to FLUSH_INTERVAL seconds."""
    global lasttime
    if end:
        msg += "\n"
        lasttime = 0
    if time.time() - lasttime >= FLUSH_INTERVAL:
        sys.stdout.write("\r%s" % msg)
        lasttime = time.time()
        sys.stdout.flush()
def _download_file(url, savepath, print_progress):
r = requests.get(url, stream=True)
total_length = r.headers.get('content-length')
if total_length is None:
with open(savepath, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(savepath, 'wb') as f:
dl = 0
total_length = int(total_length)
starttime = time.time()
if print_progress:
print("Downloading %s" % os.path.basename(savepath))
for data in r.iter_content(chunk_size=4096):
dl += len(data)
f.write(data)
if print_progress:
done = int(50 * dl / total_length)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * dl) / total_length))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
def _uncompress_file(filepath, extrapath, delete_file, print_progress):
if print_progress:
print("Uncompress %s" % os.path.basename(filepath))
if filepath.endswith("zip"):
handler = _uncompress_file_zip
elif filepath.endswith("tgz"):
handler = _uncompress_file_tar
else:
handler = functools.partial(_uncompress_file_tar, mode="r")
for total_num, index, rootpath in handler(filepath, extrapath):
if print_progress:
done = int(50 * float(index) / total_num)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * index) / total_num))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
if delete_file:
os.remove(filepath)
return rootpath
def _uncompress_file_zip(filepath, extrapath):
files = zipfile.ZipFile(filepath, 'r')
filelist = files.namelist()
rootpath = filelist[0]
total_num = len(filelist)
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index, rootpath
files.close()
yield total_num, index, rootpath
def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
files = tarfile.open(filepath, mode)
filelist = files.getnames()
total_num = len(filelist)
rootpath = filelist[0]
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index, rootpath
files.close()
yield total_num, index, rootpath
def download_file_and_uncompress(url,
savepath=None,
extrapath=None,
extraname=None,
print_progress=True,
cover=False,
                                 delete_file=True):
    """Downloads `url` and unpacks the archive, skipping steps already done."""
if savepath is None:
savepath = "."
if extrapath is None:
extrapath = "."
savename = url.split("/")[-1]
savepath = os.path.join(savepath, savename)
savename = ".".join(savename.split(".")[:-1])
savename = os.path.join(extrapath, savename)
extraname = savename if extraname is None else os.path.join(extrapath,
extraname)
if cover:
if os.path.exists(savepath):
shutil.rmtree(savepath)
if os.path.exists(savename):
shutil.rmtree(savename)
if os.path.exists(extraname):
shutil.rmtree(extraname)
if not os.path.exists(extraname):
if not os.path.exists(savename):
if not os.path.exists(savepath):
_download_file(url, savepath, print_progress)
savename = _uncompress_file(savepath, extrapath, delete_file,
print_progress)
savename = os.path.join(extrapath, savename)
shutil.move(savename, extraname)
model_urls = {
"densenet121":
"https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar",
"resnet101":
"http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar",
}
if __name__ == "__main__":
if len(sys.argv) != 2:
print("usage:\n python download_model.py ${MODEL_NAME}")
exit(1)
model_name = sys.argv[1]
    if model_name not in model_urls:
print("Only support: \n {}".format("\n ".join(
list(model_urls.keys()))))
exit(1)
url = model_urls[model_name]
download_file_and_uncompress(
url=url,
savepath=LOCAL_PATH,
extrapath=LOCAL_PATH,
extraname=model_name)
print("Pretrained Model download success!")
export CUDA_VISIBLE_DEVICES=0
python test.py --data_dir dataset --conf_path output/kitti_3d_multi_warmup/conf.pkl --weights_path output/kitti_3d_multi_warmup/epoch1.pdparams
export CUDA_VISIBLE_DEVICES=1
export FLAGS_fraction_of_gpu_memory_to_use=0.1
python train.py --data_dir dataset --conf kitti_3d_multi_warmup
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""test"""
import os
import sys
import argparse
import ast
import logging
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from models import *
from easydict import EasyDict as edict
from lib.rpn_util import *
sys.path.append(os.getcwd())
import lib.core as core
from lib.util import *
import paddle
from paddle.fluid.dygraph.base import to_variable
logging.root.handlers = []
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
"""parse"""
    parser = argparse.ArgumentParser("M3D-RPN test script")
parser.add_argument("--conf_path", type=str, default='', help="config.pkl")
    parser.add_argument(
        '--weights_path', type=str, default='', help='path to trained weights')
parser.add_argument(
'--backbone',
type=str,
default='DenseNet121',
help='backbone model to train, default DenseNet121')
parser.add_argument(
'--data_dir', type=str, default='dataset', help='dataset directory')
args = parser.parse_args()
return args
def test():
    """main test: load config and trained weights, then evaluate on KITTI"""
    args = parse_args()
    # load config
    conf = edict(pickle_read(args.conf_path))
    conf.pretrained = None
    results_path = os.path.join('output', 'tmp_results', 'data')
    # make directory
    mkdir_if_missing(results_path, delete_if_exist=True)
    with fluid.dygraph.guard(fluid.CUDAPlace(0)):
        # build the network and switch it to evaluation mode
        src_path = os.path.join('.', 'models', conf.model + '.py')
        model = absolute_import(src_path)
        model = model.build(conf, args.backbone, 'train')
        model.eval()
        model.phase = "eval"
        trained_state, _ = fluid.load_dygraph(args.weights_path)
        print("loaded model from ", args.weights_path)
        model.set_dict(trained_state)
        print("start evaluation...")
        test_kitti_3d(conf.dataset_test, model, conf, results_path,
                      args.data_dir)
    print("Evaluation Finished!")
if __name__ == '__main__':
test()