未验证 提交 7e044196 编写于 作者: W wangguanzhong 提交者: GitHub

fix PadBatch (#1800)

* fix PadBatch

* fix PadBatchOp
上级 94f02683
......@@ -94,7 +94,8 @@ class PadBatchOp(BaseOperator):
if self.pad_gt:
gt_num = []
if data['gt_poly'] is not None and len(data['gt_poly']) > 0:
if 'gt_poly' in data and data['gt_poly'] is not None and len(data[
'gt_poly']) > 0:
pad_mask = True
else:
pad_mask = False
......@@ -112,18 +113,20 @@ class PadBatchOp(BaseOperator):
for p_p in poly:
point_num.append(int(len(p_p) / 2))
gt_num_max = max(gt_num)
gt_box_data = np.zeros([gt_num_max, 4])
gt_class_data = np.zeros([gt_num_max])
is_crowd_data = np.ones([gt_num_max])
for i, data in enumerate(samples):
gt_box_data = np.zeros([gt_num_max, 4], dtype=np.float32)
gt_class_data = np.zeros([gt_num_max], dtype=np.int32)
is_crowd_data = np.ones([gt_num_max], dtype=np.int32)
if pad_mask:
poly_num_max = max(poly_num)
poly_part_num_max = max(poly_part_num)
point_num_max = max(point_num)
gt_masks_data = -np.ones(
[poly_num_max, poly_part_num_max, point_num_max, 2])
[poly_num_max, poly_part_num_max, point_num_max, 2],
dtype=np.float32)
for i, data in enumerate(samples):
gt_num = data['gt_bbox'].shape[0]
gt_box_data[0:gt_num, :] = data['gt_bbox']
gt_class_data[0:gt_num] = np.squeeze(data['gt_class'])
......
......@@ -85,7 +85,8 @@ class PadBatch(BaseOperator):
data['im_info'][:2] = max_shape[1:3]
if self.pad_gt:
gt_num = []
if data['gt_poly'] is not None and len(data['gt_poly']) > 0:
if 'gt_poly' in data and data['gt_poly'] is not None and len(data[
'gt_poly']) > 0:
pad_mask = True
else:
pad_mask = False
......@@ -103,18 +104,19 @@ class PadBatch(BaseOperator):
for p_p in poly:
point_num.append(int(len(p_p) / 2))
gt_num_max = max(gt_num)
gt_box_data = np.zeros([gt_num_max, 4])
gt_class_data = np.zeros([gt_num_max])
is_crowd_data = np.ones([gt_num_max])
for i, data in enumerate(samples):
gt_box_data = np.zeros([gt_num_max, 4], dtype=np.float32)
gt_class_data = np.zeros([gt_num_max], dtype=np.int32)
is_crowd_data = np.ones([gt_num_max], dtype=np.int32)
if pad_mask:
poly_num_max = max(poly_num)
poly_part_num_max = max(poly_part_num)
point_num_max = max(point_num)
gt_masks_data = -np.ones(
[poly_num_max, poly_part_num_max, point_num_max, 2])
[poly_num_max, poly_part_num_max, point_num_max, 2],
dtype=np.float32)
for i, data in enumerate(samples):
gt_num = data['gt_bbox'].shape[0]
gt_box_data[0:gt_num, :] = data['gt_bbox']
gt_class_data[0:gt_num] = np.squeeze(data['gt_class'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册