Commit 5dfcc983 authored by 文幕地方

fix bug

Parent 9b1e9ae6
@@ -591,7 +591,7 @@ class TableLabelEncode(AttnLabelEncode):
                  replace_empty_cell_token=False,
                  merge_no_span_structure=False,
                  learn_empty_box=False,
-                 point_num=4,
+                 point_num=2,
                  **kwargs):
         self.max_text_len = max_text_length
         self.lower = False
@@ -669,13 +669,15 @@ class TableLabelEncode(AttnLabelEncode):
         # encode box
         bboxes = np.zeros(
-            (self._max_text_len, self.point_num), dtype=np.float32)
+            (self._max_text_len, self.point_num * 2), dtype=np.float32)
         bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
         bbox_idx = 0
         for i, token in enumerate(structure):
             if self.idx2char[token] in self.td_token:
-                if 'bbox' in cells[bbox_idx]:
+                if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
+                        'tokens']) > 0:
                     bbox = cells[bbox_idx]['bbox'].copy()
                     bbox = np.array(bbox, dtype=np.float32).reshape(-1)
                     bboxes[i] = bbox
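Note on the change above: point_num now counts (x, y) points rather than raw coordinates, so the encoder allocates point_num * 2 columns per structure token. A minimal sketch of the resulting shapes (values are illustrative, not from this commit):

    import numpy as np

    max_text_len = 500  # illustrative; the real value comes from max_text_length
    point_num = 2       # two corner points -> (x1, y1, x2, y2)

    # Each structure token gets point_num * 2 box coordinates.
    bboxes = np.zeros((max_text_len, point_num * 2), dtype=np.float32)
    assert bboxes.shape == (500, 4)

    # A quadrilateral (point_num=4) would occupy 8 columns instead.
    quad = np.zeros((max_text_len, 4 * 2), dtype=np.float32)
    assert quad.shape == (500, 8)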
@@ -723,11 +725,13 @@ class TableMasterLabelEncode(TableLabelEncode):
                  replace_empty_cell_token=False,
                  merge_no_span_structure=False,
                  learn_empty_box=False,
-                 point_num=4,
+                 point_num=2,
                  **kwargs):
         super(TableMasterLabelEncode, self).__init__(
             max_text_length, character_dict_path, replace_empty_cell_token,
             merge_no_span_structure, learn_empty_box, point_num, **kwargs)
+        self.pad_idx = self.dict[self.pad_str]
+        self.unknown_idx = self.dict[self.unknown_str]

     @property
     def _max_text_len(self):
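The two added attributes cache the dictionary indices of the padding and unknown tokens (pad_str and unknown_str come from the encoder's base class). A hypothetical sketch of how such a cached pad index is typically used when right-padding a structure-id sequence; pad_sequence below is illustrative, not part of this commit:

    # Hypothetical helper: right-pad a token-id list to a fixed length.
    def pad_sequence(ids, max_len, pad_idx):
        return ids[:max_len] + [pad_idx] * max(0, max_len - len(ids))

    print(pad_sequence([3, 7, 2], 6, pad_idx=0))  # [3, 7, 2, 0, 0, 0]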
@@ -21,8 +21,14 @@ from paddle import nn
 from paddle.nn import functional as F
 from paddle import fluid

 class TableAttentionLoss(nn.Layer):
-    def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
+    def __init__(self,
+                 structure_weight,
+                 loc_weight,
+                 use_giou=False,
+                 giou_weight=1.0,
+                 **kwargs):
         super(TableAttentionLoss, self).__init__()
         self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
         self.structure_weight = structure_weight
@@ -48,9 +54,10 @@ class TableAttentionLoss(nn.Layer):
         inters = iw * ih

         # union
-        uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
-              ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
-                  bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
+        uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
+            preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
+        ) * (bbox[:, 3] - bbox[:, 1] +
+             1e-3) - inters + eps

         # ious
         ious = inters / uni
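The reflowed expression is unchanged arithmetic: union = area(pred) + area(gt) - intersection, with 1e-3 padding each side length against zero-area boxes and eps guarding the division. A quick NumPy recheck of the formula, assuming the usual clipped intersection computed just above it:

    import numpy as np

    eps = 1e-10  # assumed small constant; the class defines its own value
    preds = np.array([[0., 0., 10., 10.]])  # x1, y1, x2, y2
    bbox = np.array([[5., 5., 15., 15.]])

    # Clipped intersection width/height.
    iw = np.clip(np.minimum(preds[:, 2], bbox[:, 2]) -
                 np.maximum(preds[:, 0], bbox[:, 0]), 0, None)
    ih = np.clip(np.minimum(preds[:, 3], bbox[:, 3]) -
                 np.maximum(preds[:, 1], bbox[:, 1]), 0, None)
    inters = iw * ih  # 5 * 5 = 25

    uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3) \
        + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (bbox[:, 3] - bbox[:, 1] + 1e-3) \
        - inters + eps
    print(inters / uni)  # ~0.143, i.e. 25 / 175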
@@ -80,30 +87,34 @@ class TableAttentionLoss(nn.Layer):
         structure_probs = predicts['structure_probs']
         structure_targets = batch[1].astype("int64")
         structure_targets = structure_targets[:, 1:]
-        if len(batch) == 6:
-            structure_mask = batch[5].astype("int64")
-            structure_mask = structure_mask[:, 1:]
-            structure_mask = paddle.reshape(structure_mask, [-1])
-        structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
+        structure_probs = paddle.reshape(structure_probs,
+                                         [-1, structure_probs.shape[-1]])
         structure_targets = paddle.reshape(structure_targets, [-1])
         structure_loss = self.loss_func(structure_probs, structure_targets)
-        if len(batch) == 6:
-            structure_loss = structure_loss * structure_mask
-        # structure_loss = paddle.sum(structure_loss) * self.structure_weight
         structure_loss = paddle.mean(structure_loss) * self.structure_weight

         loc_preds = predicts['loc_preds']
         loc_targets = batch[2].astype("float32")
-        loc_targets_mask = batch[4].astype("float32")
+        loc_targets_mask = batch[3].astype("float32")
         loc_targets = loc_targets[:, 1:, :]
         loc_targets_mask = loc_targets_mask[:, 1:, :]
-        loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
+        loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
+                              loc_targets) * self.loc_weight
         if self.use_giou:
-            loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
+            loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
+                                           loc_targets) * self.giou_weight
             total_loss = structure_loss + loc_loss + loc_loss_giou
-            return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
+            return {
+                'loss': total_loss,
+                "structure_loss": structure_loss,
+                "loc_loss": loc_loss,
+                "loc_loss_giou": loc_loss_giou
+            }
         else:
             total_loss = structure_loss + loc_loss
-            return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
\ No newline at end of file
+            return {
+                'loss': total_loss,
+                "structure_loss": structure_loss,
+                "loc_loss": loc_loss
+            }
@@ -31,6 +31,8 @@ class TableStructureMetric(object):
                  gt_structure_batch_list):
             pred_str = ''.join(pred)
             target_str = ''.join(target)
+            # pred_str = pred_str.replace('<thead>','').replace('</thead>','').replace('<tbody>','').replace('</tbody>','')
+            # target_str = target_str.replace('<thead>','').replace('</thead>','').replace('<tbody>','').replace('</tbody>','')
             if pred_str == target_str:
                 correct_num += 1
             all_num += 1
@@ -131,10 +133,10 @@ class TableMetric(object):
         self.bbox_metric.reset()

     def format_box(self, box):
-        if self.point_num == 4:
+        if self.point_num == 2:
             x1, y1, x2, y2 = box
             box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
-        elif self.point_num == 8:
+        elif self.point_num == 4:
             x1, y1, x2, y2, x3, y3, x4, y4 = box
             box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
         return box
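The renumbering above follows the new point_num convention (pairs of coordinates, not coordinate count): point_num == 2 is an axis-aligned box given by two corners and expanded to a 4-point polygon, while point_num == 4 is a quadrilateral whose corners are taken as-is. A standalone check of the same logic:

    def format_box(box, point_num):
        # point_num counts (x, y) pairs: 2 -> axis-aligned box, 4 -> quad.
        if point_num == 2:
            x1, y1, x2, y2 = box
            box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
        elif point_num == 4:
            x1, y1, x2, y2, x3, y3, x4, y4 = box
            box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
        return box

    print(format_box([0, 0, 4, 2], point_num=2))
    # [[0, 0], [4, 0], [4, 2], [0, 2]]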
@@ -31,16 +31,18 @@ class TableAttentionHead(nn.Layer):
                  loc_type,
                  in_max_len=488,
                  max_text_length=800,
+                 out_channels=30,
+                 point_num=2,
                  **kwargs):
         super(TableAttentionHead, self).__init__()
         self.input_size = in_channels[-1]
         self.hidden_size = hidden_size
-        self.elem_num = 30
+        self.out_channels = out_channels
         self.max_text_length = max_text_length

         self.structure_attention_cell = AttentionGRUCell(
-            self.input_size, hidden_size, self.elem_num, use_gru=False)
-        self.structure_generator = nn.Linear(hidden_size, self.elem_num)
+            self.input_size, hidden_size, self.out_channels, use_gru=False)
+        self.structure_generator = nn.Linear(hidden_size, self.out_channels)
         self.loc_type = loc_type
         self.in_max_len = in_max_len
@@ -53,7 +55,8 @@ class TableAttentionHead(nn.Layer):
             self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
         else:
             self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
-        self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
+        self.loc_generator = nn.Linear(self.input_size + hidden_size,
+                                       point_num * 2)

     def _char_to_onehot(self, input_char, onehot_dim):
         input_ont_hot = F.one_hot(input_char, onehot_dim)
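With out_channels (previously the hard-coded elem_num = 30) and point_num now exposed as constructor arguments, the head's output layers are sized from the config instead of magic numbers. A sketch of the resulting layer dimensions under illustrative values:

    import paddle.nn as nn

    input_size, hidden_size = 96, 256  # illustrative backbone/hidden sizes
    out_channels = 30                  # structure vocabulary size
    point_num = 2                      # two points -> 4 regressed coordinates

    structure_generator = nn.Linear(hidden_size, out_channels)
    loc_generator = nn.Linear(input_size + hidden_size, point_num * 2)
    print(loc_generator.weight.shape)  # [352, 4]; paddle stores [in, out]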
@@ -77,7 +80,7 @@ class TableAttentionHead(nn.Layer):
             structure = targets[0]
             for i in range(self.max_text_length + 1):
                 elem_onehots = self._char_to_onehot(
-                    structure[:, i], onehot_dim=self.elem_num)
+                    structure[:, i], onehot_dim=self.out_channels)
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
                 output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
@@ -104,7 +107,7 @@ class TableAttentionHead(nn.Layer):
             i = 0
             while i < max_text_length + 1:
                 elem_onehots = self._char_to_onehot(
-                    temp_elem, onehot_dim=self.elem_num)
+                    temp_elem, onehot_dim=self.out_channels)
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
                 output_hiddens.append(paddle.unsqueeze(outputs, axis=1))