From 7e04419658c05815b837ec9e822160fa1d693a7d Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 3 Dec 2020 11:34:57 +0800 Subject: [PATCH] fix PadBatch (#1800) * fix PadBatch * fix PadBatchOp --- ppdet/data/transform/batch_operator.py | 25 ++++++++++++++----------- ppdet/data/transform/batch_operators.py | 24 +++++++++++++----------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/ppdet/data/transform/batch_operator.py b/ppdet/data/transform/batch_operator.py index 870f1ad3b..fa4c9e9ca 100644 --- a/ppdet/data/transform/batch_operator.py +++ b/ppdet/data/transform/batch_operator.py @@ -94,7 +94,8 @@ class PadBatchOp(BaseOperator): if self.pad_gt: gt_num = [] - if data['gt_poly'] is not None and len(data['gt_poly']) > 0: + if 'gt_poly' in data and data['gt_poly'] is not None and len(data[ + 'gt_poly']) > 0: pad_mask = True else: pad_mask = False @@ -112,18 +113,20 @@ class PadBatchOp(BaseOperator): for p_p in poly: point_num.append(int(len(p_p) / 2)) gt_num_max = max(gt_num) - gt_box_data = np.zeros([gt_num_max, 4]) - gt_class_data = np.zeros([gt_num_max]) - is_crowd_data = np.ones([gt_num_max]) - - if pad_mask: - poly_num_max = max(poly_num) - poly_part_num_max = max(poly_part_num) - point_num_max = max(point_num) - gt_masks_data = -np.ones( - [poly_num_max, poly_part_num_max, point_num_max, 2]) for i, data in enumerate(samples): + gt_box_data = np.zeros([gt_num_max, 4], dtype=np.float32) + gt_class_data = np.zeros([gt_num_max], dtype=np.int32) + is_crowd_data = np.ones([gt_num_max], dtype=np.int32) + + if pad_mask: + poly_num_max = max(poly_num) + poly_part_num_max = max(poly_part_num) + point_num_max = max(point_num) + gt_masks_data = -np.ones( + [poly_num_max, poly_part_num_max, point_num_max, 2], + dtype=np.float32) + gt_num = data['gt_bbox'].shape[0] gt_box_data[0:gt_num, :] = data['gt_bbox'] gt_class_data[0:gt_num] = np.squeeze(data['gt_class']) diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index d4c1c81d8..dbb640496 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -85,7 +85,8 @@ class PadBatch(BaseOperator): data['im_info'][:2] = max_shape[1:3] if self.pad_gt: gt_num = [] - if data['gt_poly'] is not None and len(data['gt_poly']) > 0: + if 'gt_poly' in data and data['gt_poly'] is not None and len(data[ + 'gt_poly']) > 0: pad_mask = True else: pad_mask = False @@ -103,18 +104,19 @@ class PadBatch(BaseOperator): for p_p in poly: point_num.append(int(len(p_p) / 2)) gt_num_max = max(gt_num) - gt_box_data = np.zeros([gt_num_max, 4]) - gt_class_data = np.zeros([gt_num_max]) - is_crowd_data = np.ones([gt_num_max]) - - if pad_mask: - poly_num_max = max(poly_num) - poly_part_num_max = max(poly_part_num) - point_num_max = max(point_num) - gt_masks_data = -np.ones( - [poly_num_max, poly_part_num_max, point_num_max, 2]) for i, data in enumerate(samples): + gt_box_data = np.zeros([gt_num_max, 4], dtype=np.float32) + gt_class_data = np.zeros([gt_num_max], dtype=np.int32) + is_crowd_data = np.ones([gt_num_max], dtype=np.int32) + if pad_mask: + poly_num_max = max(poly_num) + poly_part_num_max = max(poly_part_num) + point_num_max = max(point_num) + gt_masks_data = -np.ones( + [poly_num_max, poly_part_num_max, point_num_max, 2], + dtype=np.float32) + gt_num = data['gt_bbox'].shape[0] gt_box_data[0:gt_num, :] = data['gt_bbox'] gt_class_data[0:gt_num] = np.squeeze(data['gt_class']) -- GitLab