diff --git a/ppdet/data/transform/batch_operator.py b/ppdet/data/transform/batch_operator.py index 870f1ad3b052abf6937f509476010f1c38580909..fa4c9e9ca580716898a564cb3bc32fcb743c0823 100644 --- a/ppdet/data/transform/batch_operator.py +++ b/ppdet/data/transform/batch_operator.py @@ -94,7 +94,8 @@ class PadBatchOp(BaseOperator): if self.pad_gt: gt_num = [] - if data['gt_poly'] is not None and len(data['gt_poly']) > 0: + if 'gt_poly' in data and data['gt_poly'] is not None and len(data[ + 'gt_poly']) > 0: pad_mask = True else: pad_mask = False @@ -112,18 +113,20 @@ class PadBatchOp(BaseOperator): for p_p in poly: point_num.append(int(len(p_p) / 2)) gt_num_max = max(gt_num) - gt_box_data = np.zeros([gt_num_max, 4]) - gt_class_data = np.zeros([gt_num_max]) - is_crowd_data = np.ones([gt_num_max]) - - if pad_mask: - poly_num_max = max(poly_num) - poly_part_num_max = max(poly_part_num) - point_num_max = max(point_num) - gt_masks_data = -np.ones( - [poly_num_max, poly_part_num_max, point_num_max, 2]) for i, data in enumerate(samples): + gt_box_data = np.zeros([gt_num_max, 4], dtype=np.float32) + gt_class_data = np.zeros([gt_num_max], dtype=np.int32) + is_crowd_data = np.ones([gt_num_max], dtype=np.int32) + + if pad_mask: + poly_num_max = max(poly_num) + poly_part_num_max = max(poly_part_num) + point_num_max = max(point_num) + gt_masks_data = -np.ones( + [poly_num_max, poly_part_num_max, point_num_max, 2], + dtype=np.float32) + gt_num = data['gt_bbox'].shape[0] gt_box_data[0:gt_num, :] = data['gt_bbox'] gt_class_data[0:gt_num] = np.squeeze(data['gt_class']) diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index d4c1c81d8a40303749ba381fe0aae92126c22916..dbb640496a8be6bb128ed74ccb61e3e08a853fad 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -85,7 +85,8 @@ class PadBatch(BaseOperator): data['im_info'][:2] = max_shape[1:3] if self.pad_gt: gt_num = [] - if data['gt_poly'] is not None and len(data['gt_poly']) > 0: + if 'gt_poly' in data and data['gt_poly'] is not None and len(data[ + 'gt_poly']) > 0: pad_mask = True else: pad_mask = False @@ -103,18 +104,19 @@ class PadBatch(BaseOperator): for p_p in poly: point_num.append(int(len(p_p) / 2)) gt_num_max = max(gt_num) - gt_box_data = np.zeros([gt_num_max, 4]) - gt_class_data = np.zeros([gt_num_max]) - is_crowd_data = np.ones([gt_num_max]) - - if pad_mask: - poly_num_max = max(poly_num) - poly_part_num_max = max(poly_part_num) - point_num_max = max(point_num) - gt_masks_data = -np.ones( - [poly_num_max, poly_part_num_max, point_num_max, 2]) for i, data in enumerate(samples): + gt_box_data = np.zeros([gt_num_max, 4], dtype=np.float32) + gt_class_data = np.zeros([gt_num_max], dtype=np.int32) + is_crowd_data = np.ones([gt_num_max], dtype=np.int32) + if pad_mask: + poly_num_max = max(poly_num) + poly_part_num_max = max(poly_part_num) + point_num_max = max(point_num) + gt_masks_data = -np.ones( + [poly_num_max, poly_part_num_max, point_num_max, 2], + dtype=np.float32) + gt_num = data['gt_bbox'].shape[0] gt_box_data[0:gt_num, :] = data['gt_bbox'] gt_class_data[0:gt_num] = np.squeeze(data['gt_class'])