未验证 提交 8d6e1137 编写于 作者: Z zhiboniu 提交者: GitHub

delay coco data load (#4095)

上级 f74aa666
......@@ -63,6 +63,9 @@ class KeypointBottomUpBaseDataset(DetDataset):
self.ann_info['num_joints'] = num_joints
self.img_ids = []
def parse_dataset(self):
pass
def __len__(self):
"""Get dataset length."""
return len(self.img_ids)
......@@ -136,26 +139,30 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform, shard, test_mode)
ann_file = os.path.join(dataset_dir, anno_path)
self.coco = COCO(ann_file)
self.ann_file = os.path.join(dataset_dir, anno_path)
self.shard = shard
self.test_mode = test_mode
def parse_dataset(self):
self.coco = COCO(self.ann_file)
self.img_ids = self.coco.getImgIds()
if not test_mode:
if not self.test_mode:
self.img_ids = [
img_id for img_id in self.img_ids
if len(self.coco.getAnnIds(
imgIds=img_id, iscrowd=None)) > 0
]
blocknum = int(len(self.img_ids) / shard[1])
self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0]
+ 1))]
blocknum = int(len(self.img_ids) / self.shard[1])
self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
self.shard[0] + 1))]
self.num_images = len(self.img_ids)
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
self.dataset_name = 'coco'
cat_ids = self.coco.getCatIds()
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
print(f'=> num_images: {self.num_images}')
print('=> num_images: {}'.format(self.num_images))
@staticmethod
def _get_mapping_id_name(imgs):
......@@ -301,20 +308,23 @@ class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform, shard, test_mode)
ann_file = os.path.join(dataset_dir, anno_path)
self.ann_file = os.path.join(dataset_dir, anno_path)
self.shard = shard
self.test_mode = test_mode
self.coco = COCO(ann_file)
def parse_dataset(self):
self.coco = COCO(self.ann_file)
self.img_ids = self.coco.getImgIds()
if not test_mode:
if not self.test_mode:
self.img_ids = [
img_id for img_id in self.img_ids
if len(self.coco.getAnnIds(
imgIds=img_id, iscrowd=None)) > 0
]
blocknum = int(len(self.img_ids) / shard[1])
self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0]
+ 1))]
blocknum = int(len(self.img_ids) / self.shard[1])
self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
self.shard[0] + 1))]
self.num_images = len(self.img_ids)
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册