"""
This code is based on https://github.com/sshaoshuai/PointRCNN/blob/master/lib/datasets/kitti_dataset.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import cv2
import numpy as np
import utils.calibration as calibration
from utils.object3d import get_objects_from_label
from PIL import Image

__all__ = ["KittiDataset"]


class KittiDataset(object):
    """File-level accessor for a KITTI object-detection directory tree.

    The constructor only resolves paths and reads the list of sample ids for
    the requested split; the ``get_*`` methods load one sample's data on
    demand.

    Expected layout under ``data_dir`` (as built in ``__init__``)::

        KITTI/ImageSets/<split>.txt
        KITTI/object/{training,testing}/{image_2,velodyne,calib,label_2,planes}

    Every ``get_*`` method formats ``idx`` with ``'%06d'``, so ``idx`` must be
    a number (for example ``int(self.image_idx_list[i])``), not a string.
    """

    def __init__(self, data_dir, split='train'):
        """Resolve the dataset directories and read the split's sample ids.

        Args:
            data_dir: Root directory that contains the ``KITTI`` folder.
            split: One of ``'train'``, ``'train_aug'``, ``'val'``, ``'test'``.
                ``'test'`` reads from ``object/testing``; every other split
                reads from ``object/training``.
        """
        assert split in ['train', 'train_aug', 'val', 'test'], "unknown split {}".format(split)
        self.split = split
        self.is_test = self.split == 'test'
        self.imageset_dir = os.path.join(data_dir, 'KITTI', 'object', 'testing' if self.is_test else 'training')

        split_dir = os.path.join(data_dir, 'KITTI', 'ImageSets', split + '.txt')
        # Use a context manager so the split file is closed deterministically.
        with open(split_dir) as split_file:
            # One sample id per line, kept as stripped strings.
            self.image_idx_list = [line.strip() for line in split_file]
        self.num_sample = len(self.image_idx_list)

        self.image_dir = os.path.join(self.imageset_dir, 'image_2')
        self.lidar_dir = os.path.join(self.imageset_dir, 'velodyne')
        self.calib_dir = os.path.join(self.imageset_dir, 'calib')
        self.label_dir = os.path.join(self.imageset_dir, 'label_2')
        self.plane_dir = os.path.join(self.imageset_dir, 'planes')

    def get_image(self, idx):
        """Load the image of sample ``idx`` with OpenCV.

        Args:
            idx: Numeric sample id.

        Returns:
            The result of ``cv2.imread``: (H, W, 3) array in BGR mode.

        Raises:
            AssertionError: If the image file does not exist.
        """
        img_file = os.path.join(self.image_dir, '%06d.png' % idx)
        assert os.path.exists(img_file), img_file
        return cv2.imread(img_file)  # (H, W, 3) BGR mode

    def get_image_shape(self, idx):
        """Return the image shape of sample ``idx`` without decoding pixels.

        Args:
            idx: Numeric sample id.

        Returns:
            Tuple ``(height, width, 3)``; the channel count is hard-coded.

        Raises:
            AssertionError: If the image file does not exist.
        """
        img_file = os.path.join(self.image_dir, '%06d.png' % idx)
        assert os.path.exists(img_file), img_file
        # Only the header is needed for the size; the context manager closes
        # the underlying file instead of leaving it to garbage collection.
        with Image.open(img_file) as im:
            width, height = im.size
        return height, width, 3

    def get_lidar(self, idx):
        """Load the point cloud of sample ``idx``.

        Args:
            idx: Numeric sample id.

        Returns:
            ``float32`` array of shape (N, 4) read from the raw ``.bin`` file.
            NOTE(review): columns are presumably x, y, z, intensity - confirm.

        Raises:
            AssertionError: If the lidar file does not exist.
        """
        lidar_file = os.path.join(self.lidar_dir, '%06d.bin' % idx)
        assert os.path.exists(lidar_file), lidar_file
        return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4)

    def get_calib(self, idx):
        """Load the calibration of sample ``idx``.

        Args:
            idx: Numeric sample id.

        Returns:
            A ``calibration.Calibration`` built from the calib file path.

        Raises:
            AssertionError: If the calibration file does not exist.
        """
        calib_file = os.path.join(self.calib_dir, '%06d.txt' % idx)
        assert os.path.exists(calib_file), calib_file
        return calibration.Calibration(calib_file)

    def get_label(self, idx):
        """Load the annotated objects of sample ``idx``.

        Args:
            idx: Numeric sample id.

        Returns:
            Whatever ``get_objects_from_label`` returns for the label file.

        Raises:
            AssertionError: If the label file does not exist.
        """
        label_file = os.path.join(self.label_dir, '%06d.txt' % idx)
        assert os.path.exists(label_file), label_file
        return get_objects_from_label(label_file)

    def get_road_plane(self, idx):
        """Load the road plane of sample ``idx``.

        The coefficients are read from the fourth line of the plane file as
        whitespace-separated floats, flipped if needed, and scaled so that the
        first three components have unit length.

        Args:
            idx: Numeric sample id.

        Returns:
            1-D numpy array of plane coefficients divided by the norm of the
            first three.
        """
        plane_file = os.path.join(self.plane_dir, '%06d.txt' % idx)
        with open(plane_file, 'r') as f:
            lines = f.readlines()
        # The coefficients live on the fourth line of the file.
        coefficients = [float(value) for value in lines[3].split()]
        plane = np.asarray(coefficients)

        # Ensure normal is always facing up, this is in the rectified camera coordinate
        if plane[1] > 0:
            plane = -plane

        norm = np.linalg.norm(plane[0:3])
        plane = plane / norm
        return plane