diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index a55869a641f8b36a85b1771d487f04c60124651a..d63901a03df045709632db1d5403de4368422339 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -591,7 +591,7 @@ class TableLabelEncode(AttnLabelEncode):
                  replace_empty_cell_token=False,
                  merge_no_span_structure=False,
                  learn_empty_box=False,
-                 point_num=4,
+                 point_num=2,
                  **kwargs):
         self.max_text_len = max_text_length
         self.lower = False
@@ -669,13 +669,15 @@ class TableLabelEncode(AttnLabelEncode):
 
         # encode box
         bboxes = np.zeros(
-            (self._max_text_len, self.point_num), dtype=np.float32)
+            (self._max_text_len, self.point_num * 2), dtype=np.float32)
         bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
         bbox_idx = 0
+
         for i, token in enumerate(structure):
             if self.idx2char[token] in self.td_token:
-                if 'bbox' in cells[bbox_idx]:
+                if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
+                        'tokens']) > 0:
                     bbox = cells[bbox_idx]['bbox'].copy()
                     bbox = np.array(bbox, dtype=np.float32).reshape(-1)
                     bboxes[i] = bbox
@@ -723,11 +725,13 @@ class TableMasterLabelEncode(TableLabelEncode):
                  replace_empty_cell_token=False,
                  merge_no_span_structure=False,
                  learn_empty_box=False,
-                 point_num=4,
+                 point_num=2,
                  **kwargs):
         super(TableMasterLabelEncode, self).__init__(
             max_text_length, character_dict_path, replace_empty_cell_token,
             merge_no_span_structure, learn_empty_box, point_num, **kwargs)
+        self.pad_idx = self.dict[self.pad_str]
+        self.unknown_idx = self.dict[self.unknown_str]
 
     @property
     def _max_text_len(self):
diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py
index d7fd99e6952aacc0182a482ca5ae5ddaf959a026..4bdccad3998c00bfc2b0ef12bec2983d2953fdb3 100644
--- a/ppocr/losses/table_att_loss.py
+++ b/ppocr/losses/table_att_loss.py
@@ -21,15 +21,21 @@ from paddle import nn
 from paddle.nn import functional as F
 from paddle import fluid
 
+
 class TableAttentionLoss(nn.Layer):
-    def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
+    def __init__(self,
+                 structure_weight,
+                 loc_weight,
+                 use_giou=False,
+                 giou_weight=1.0,
+                 **kwargs):
         super(TableAttentionLoss, self).__init__()
         self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
         self.structure_weight = structure_weight
         self.loc_weight = loc_weight
         self.use_giou = use_giou
         self.giou_weight = giou_weight
-        
+
     def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
         '''
         :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
@@ -48,9 +54,10 @@ class TableAttentionLoss(nn.Layer):
         inters = iw * ih
 
         # union
-        uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
-              ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
-                  bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
+        uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
+            preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
+                                                 ) * (bbox[:, 3] - bbox[:, 1] +
+                                                      1e-3) - inters + eps
 
         # ious
         ious = inters / uni
@@ -80,30 +87,34 @@ class TableAttentionLoss(nn.Layer):
         structure_probs = predicts['structure_probs']
         structure_targets = batch[1].astype("int64")
         structure_targets = structure_targets[:, 1:]
-        if len(batch) == 6:
-            structure_mask = batch[5].astype("int64")
-            structure_mask = structure_mask[:, 1:]
-            structure_mask = paddle.reshape(structure_mask, [-1])
-        structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
+        structure_probs = paddle.reshape(structure_probs,
+                                         [-1, structure_probs.shape[-1]])
         structure_targets = paddle.reshape(structure_targets, [-1])
         structure_loss = self.loss_func(structure_probs, structure_targets)
-        
-        if len(batch) == 6:
-            structure_loss = structure_loss * structure_mask
-            
-#         structure_loss = paddle.sum(structure_loss) * self.structure_weight
+
         structure_loss = paddle.mean(structure_loss) * self.structure_weight
-        
+
         loc_preds = predicts['loc_preds']
         loc_targets = batch[2].astype("float32")
-        loc_targets_mask = batch[4].astype("float32")
+        loc_targets_mask = batch[3].astype("float32")
         loc_targets = loc_targets[:, 1:, :]
         loc_targets_mask = loc_targets_mask[:, 1:, :]
-        loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
+        loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
+                              loc_targets) * self.loc_weight
         if self.use_giou:
-            loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
+            loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
+                                           loc_targets) * self.giou_weight
             total_loss = structure_loss + loc_loss + loc_loss_giou
-            return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
+            return {
+                'loss': total_loss,
+                "structure_loss": structure_loss,
+                "loc_loss": loc_loss,
+                "loc_loss_giou": loc_loss_giou
+            }
         else:
-            total_loss = structure_loss + loc_loss
-            return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
\ No newline at end of file
+            total_loss = structure_loss + loc_loss
+            return {
+                'loss': total_loss,
+                "structure_loss": structure_loss,
+                "loc_loss": loc_loss
+            }
diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py
index 17f3dc92b27cda3e9a19dea2a3bf72988c00b415..26f577a03aeba0977eda2866a9046715f03a1f63 100644
--- a/ppocr/metrics/table_metric.py
+++ b/ppocr/metrics/table_metric.py
@@ -31,6 +31,8 @@ class TableStructureMetric(object):
                 gt_structure_batch_list):
             pred_str = ''.join(pred)
             target_str = ''.join(target)
+            # pred_str = pred_str.replace('','').replace('','').replace('','').replace('','')
+            # target_str = target_str.replace('','').replace('','').replace('','').replace('','')
             if pred_str == target_str:
                 correct_num += 1
             all_num += 1
@@ -131,10 +133,10 @@ class TableMetric(object):
         self.bbox_metric.reset()
 
     def format_box(self, box):
-        if self.point_num == 4:
+        if self.point_num == 2:
             x1, y1, x2, y2 = box
             box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
-        elif self.point_num == 8:
+        elif self.point_num == 4:
             x1, y1, x2, y2, x3, y3, x4, y4 = box
             box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
         return box
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index b64713898d40d48f19b3fafc7c175153bcba09a4..4f39d6253d8d596fecdc4736666a6d3106601a82 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -31,16 +31,18 @@ class TableAttentionHead(nn.Layer):
                  loc_type,
                  in_max_len=488,
                  max_text_length=800,
+                 out_channels=30,
+                 point_num=2,
                  **kwargs):
         super(TableAttentionHead, self).__init__()
         self.input_size = in_channels[-1]
         self.hidden_size = hidden_size
-        self.elem_num = 30
+        self.out_channels = out_channels
         self.max_text_length = max_text_length
 
         self.structure_attention_cell = AttentionGRUCell(
-            self.input_size, hidden_size, self.elem_num, use_gru=False)
-        self.structure_generator = nn.Linear(hidden_size, self.elem_num)
+            self.input_size, hidden_size, self.out_channels, use_gru=False)
+        self.structure_generator = nn.Linear(hidden_size, self.out_channels)
         self.loc_type = loc_type
         self.in_max_len = in_max_len
 
@@ -53,7 +55,8 @@ class TableAttentionHead(nn.Layer):
             self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
         else:
             self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
-        self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
+        self.loc_generator = nn.Linear(self.input_size + hidden_size,
+                                       point_num * 2)
 
     def _char_to_onehot(self, input_char, onehot_dim):
         input_ont_hot = F.one_hot(input_char, onehot_dim)
@@ -77,7 +80,7 @@ class TableAttentionHead(nn.Layer):
             structure = targets[0]
             for i in range(self.max_text_length + 1):
                 elem_onehots = self._char_to_onehot(
-                    structure[:, i], onehot_dim=self.elem_num)
+                    structure[:, i], onehot_dim=self.out_channels)
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
                 output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
@@ -104,7 +107,7 @@ class TableAttentionHead(nn.Layer):
             i = 0
             while i < max_text_length + 1:
                 elem_onehots = self._char_to_onehot(
-                    temp_elem, onehot_dim=self.elem_num)
+                    temp_elem, onehot_dim=self.out_channels)
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
                 output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
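
Reviewer note on the GIoU term (not part of the patch): the hunk above only reformats the IoU half of giou_loss; the enclosing-box step lies outside the visible hunks, so the sketch below assumes the standard GIoU definition (GIoU = IoU - (enclose - union) / enclose, loss = 1 - GIoU). NumPy stands in for paddle, and all names here are illustrative.

    import numpy as np

    def giou_loss_np(preds, gts, eps=1e-7):
        """1 - GIoU for [x1, y1, x2, y2] boxes, mirroring the 1e-3 side padding above."""
        # intersection
        iw = np.clip(np.minimum(preds[:, 2], gts[:, 2]) -
                     np.maximum(preds[:, 0], gts[:, 0]) + 1e-3, 0., 1e10)
        ih = np.clip(np.minimum(preds[:, 3], gts[:, 3]) -
                     np.maximum(preds[:, 1], gts[:, 1]) + 1e-3, 0., 1e10)
        inters = iw * ih
        # union, same padding as the patch
        uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3) + \
              (gts[:, 2] - gts[:, 0] + 1e-3) * (gts[:, 3] - gts[:, 1] + 1e-3) - inters + eps
        ious = inters / uni
        # smallest enclosing box (assumed standard GIoU; outside the hunk)
        ew = np.maximum(preds[:, 2], gts[:, 2]) - np.minimum(preds[:, 0], gts[:, 0]) + 1e-3
        eh = np.maximum(preds[:, 3], gts[:, 3]) - np.minimum(preds[:, 1], gts[:, 1]) + 1e-3
        enclose = ew * eh + eps
        gious = ious - (enclose - uni) / enclose
        return 1 - gious

    # identical boxes -> ~0; disjoint boxes -> > 1
    print(giou_loss_np(np.array([[0., 0., 1., 1.]]), np.array([[0., 0., 1., 1.]])))
    print(giou_loss_np(np.array([[0., 0., 1., 1.]]), np.array([[2., 2., 3., 3.]])))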
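Reviewer note on the point_num rename: the patch changes point_num from "number of coordinates" to "number of points", so the label encoder now allocates point_num * 2 floats per box, the head predicts point_num * 2 outputs, and TableMetric.format_box branches on 2 (axis-aligned box) vs 4 (quadrilateral). A minimal stand-alone illustration of that convention; the free-function form is ours, not the patch's.

    def format_box(box, point_num=2):
        # point_num points -> point_num * 2 flat coordinates
        if point_num == 2:      # axis-aligned: x1, y1, x2, y2
            x1, y1, x2, y2 = box
            return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
        elif point_num == 4:    # quadrilateral: x1, y1, ..., x4, y4
            x1, y1, x2, y2, x3, y3, x4, y4 = box
            return [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
        raise ValueError("point_num must be 2 or 4")

    assert format_box([0, 0, 2, 1]) == [[0, 0], [2, 0], [2, 1], [0, 1]]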