未验证 提交 8d6e1137 编写于 作者: Z zhiboniu 提交者: GitHub

delay coco data load (#4095)

上级 f74aa666
...@@ -63,6 +63,9 @@ class KeypointBottomUpBaseDataset(DetDataset): ...@@ -63,6 +63,9 @@ class KeypointBottomUpBaseDataset(DetDataset):
self.ann_info['num_joints'] = num_joints self.ann_info['num_joints'] = num_joints
self.img_ids = [] self.img_ids = []
def parse_dataset(self):
pass
def __len__(self): def __len__(self):
"""Get dataset length.""" """Get dataset length."""
return len(self.img_ids) return len(self.img_ids)
...@@ -136,26 +139,30 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): ...@@ -136,26 +139,30 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
super().__init__(dataset_dir, image_dir, anno_path, num_joints, super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform, shard, test_mode) transform, shard, test_mode)
ann_file = os.path.join(dataset_dir, anno_path) self.ann_file = os.path.join(dataset_dir, anno_path)
self.coco = COCO(ann_file) self.shard = shard
self.test_mode = test_mode
def parse_dataset(self):
self.coco = COCO(self.ann_file)
self.img_ids = self.coco.getImgIds() self.img_ids = self.coco.getImgIds()
if not test_mode: if not self.test_mode:
self.img_ids = [ self.img_ids = [
img_id for img_id in self.img_ids img_id for img_id in self.img_ids
if len(self.coco.getAnnIds( if len(self.coco.getAnnIds(
imgIds=img_id, iscrowd=None)) > 0 imgIds=img_id, iscrowd=None)) > 0
] ]
blocknum = int(len(self.img_ids) / shard[1]) blocknum = int(len(self.img_ids) / self.shard[1])
self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0] self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
+ 1))] self.shard[0] + 1))]
self.num_images = len(self.img_ids) self.num_images = len(self.img_ids)
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs) self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
self.dataset_name = 'coco' self.dataset_name = 'coco'
cat_ids = self.coco.getCatIds() cat_ids = self.coco.getCatIds()
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)}) self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
print(f'=> num_images: {self.num_images}') print('=> num_images: {}'.format(self.num_images))
@staticmethod @staticmethod
def _get_mapping_id_name(imgs): def _get_mapping_id_name(imgs):
...@@ -301,20 +308,23 @@ class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset): ...@@ -301,20 +308,23 @@ class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
super().__init__(dataset_dir, image_dir, anno_path, num_joints, super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform, shard, test_mode) transform, shard, test_mode)
ann_file = os.path.join(dataset_dir, anno_path) self.ann_file = os.path.join(dataset_dir, anno_path)
self.shard = shard
self.test_mode = test_mode
self.coco = COCO(ann_file) def parse_dataset(self):
self.coco = COCO(self.ann_file)
self.img_ids = self.coco.getImgIds() self.img_ids = self.coco.getImgIds()
if not test_mode: if not self.test_mode:
self.img_ids = [ self.img_ids = [
img_id for img_id in self.img_ids img_id for img_id in self.img_ids
if len(self.coco.getAnnIds( if len(self.coco.getAnnIds(
imgIds=img_id, iscrowd=None)) > 0 imgIds=img_id, iscrowd=None)) > 0
] ]
blocknum = int(len(self.img_ids) / shard[1]) blocknum = int(len(self.img_ids) / self.shard[1])
self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0] self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
+ 1))] self.shard[0] + 1))]
self.num_images = len(self.img_ids) self.num_images = len(self.img_ids)
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs) self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册