From 8d6e1137ff75bc55a701c8700b0f36414a1f3bb2 Mon Sep 17 00:00:00 2001 From: zhiboniu <31800336+zhiboniu@users.noreply.github.com> Date: Tue, 31 Aug 2021 14:18:11 +0800 Subject: [PATCH] delay coco data load (#4095) --- ppdet/data/source/keypoint_coco.py | 36 +++++++++++++++++++----------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/ppdet/data/source/keypoint_coco.py b/ppdet/data/source/keypoint_coco.py index 5b7b99ee9..fdea57ada 100644 --- a/ppdet/data/source/keypoint_coco.py +++ b/ppdet/data/source/keypoint_coco.py @@ -63,6 +63,9 @@ class KeypointBottomUpBaseDataset(DetDataset): self.ann_info['num_joints'] = num_joints self.img_ids = [] + def parse_dataset(self): + pass + def __len__(self): """Get dataset length.""" return len(self.img_ids) @@ -136,26 +139,30 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): super().__init__(dataset_dir, image_dir, anno_path, num_joints, transform, shard, test_mode) - ann_file = os.path.join(dataset_dir, anno_path) - self.coco = COCO(ann_file) + self.ann_file = os.path.join(dataset_dir, anno_path) + self.shard = shard + self.test_mode = test_mode + + def parse_dataset(self): + self.coco = COCO(self.ann_file) self.img_ids = self.coco.getImgIds() - if not test_mode: + if not self.test_mode: self.img_ids = [ img_id for img_id in self.img_ids if len(self.coco.getAnnIds( imgIds=img_id, iscrowd=None)) > 0 ] - blocknum = int(len(self.img_ids) / shard[1]) - self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0] - + 1))] + blocknum = int(len(self.img_ids) / self.shard[1]) + self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * ( + self.shard[0] + 1))] self.num_images = len(self.img_ids) self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs) self.dataset_name = 'coco' cat_ids = self.coco.getCatIds() self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)}) - print(f'=> num_images: {self.num_images}') + print('=> num_images: {}'.format(self.num_images)) @staticmethod def _get_mapping_id_name(imgs): @@ -301,20 +308,23 @@ class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset): super().__init__(dataset_dir, image_dir, anno_path, num_joints, transform, shard, test_mode) - ann_file = os.path.join(dataset_dir, anno_path) + self.ann_file = os.path.join(dataset_dir, anno_path) + self.shard = shard + self.test_mode = test_mode - self.coco = COCO(ann_file) + def parse_dataset(self): + self.coco = COCO(self.ann_file) self.img_ids = self.coco.getImgIds() - if not test_mode: + if not self.test_mode: self.img_ids = [ img_id for img_id in self.img_ids if len(self.coco.getAnnIds( imgIds=img_id, iscrowd=None)) > 0 ] - blocknum = int(len(self.img_ids) / shard[1]) - self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0] - + 1))] + blocknum = int(len(self.img_ids) / self.shard[1]) + self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * ( + self.shard[0] + 1))] self.num_images = len(self.img_ids) self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs) -- GitLab