提交 a0c33908 编写于 作者: 文幕地方's avatar 文幕地方

add TableMaster

上级 2d89b2ce
Global:
use_gpu: true
epoch_num: 17
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/table_master/
save_epoch_step: 17
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: ppstructure/docs/table/table.jpg
save_res_path: output/table_master
# for data or label process
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: False
max_text_length: 500
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: MultiStepDecay
learning_rate: 0.001
milestones: [12, 15]
gamma: 0.1
warmup_epoch: 0.02
regularizer:
name: 'L2'
factor: 0.00000
Architecture:
model_type: table
algorithm: TableMaster
Backbone:
name: TableResNetExtra
gcb_config:
ratio: 0.0625
headers: 1
att_scale: False
fusion_type: channel_add
layers: [False, True, True, True]
layers: [1,2,5,3]
Head:
name: TableMasterHead
hidden_size: 512
headers: 8
dropout: 0
d_ff: 2024
max_text_length: 500
Loss:
name: TableMasterLoss
ignore_index: 42 # set to len of dict + 3
PostProcess:
name: TableMasterLabelDecode
box_shape: pad
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: true # cost many time, set False for training
Train:
dataset:
name: PubTabDataSet
data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/train/
label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: true
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bboxes', 'bbox_masks','shape']
loader:
shuffle: True
batch_size_per_card: 8
drop_last: True
num_workers: 1
Eval:
dataset:
name: PubTabDataSet
data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [ 480, 480 ]
- TableBoxEncode:
use_xywh: true
- NormalizeImage:
scale: 1./255.
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks','shape' ]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 2
num_workers: 8
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: ./output/table_mv3/ save_model_dir: ./output/table_mv3/
save_epoch_step: 3 save_epoch_step: 400
# evaluation is run every 400 iterations after the 0th iteration # evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400] eval_batch_step: [0, 400]
cal_metric_during_train: True cal_metric_during_train: True
...@@ -12,13 +12,12 @@ Global: ...@@ -12,13 +12,12 @@ Global:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: doc/table/table.jpg infer_img: ppstructure/docs/table/table.jpg
save_res_path: output/table_mv3
# for data or label process # for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en character_type: en
max_text_length: 100 max_text_length: 500
max_elem_length: 800
max_cell_num: 500
infer_mode: False infer_mode: False
process_total_num: 0 process_total_num: 0
process_cut_num: 0 process_cut_num: 0
...@@ -44,11 +43,8 @@ Architecture: ...@@ -44,11 +43,8 @@ Architecture:
Head: Head:
name: TableAttentionHead name: TableAttentionHead
hidden_size: 256 hidden_size: 256
l2_decay: 0.00001
loc_type: 2 loc_type: 2
max_text_length: 100 max_text_length: 500
max_elem_length: 800
max_cell_num: 500
Loss: Loss:
name: TableAttentionLoss name: TableAttentionLoss
...@@ -61,6 +57,7 @@ PostProcess: ...@@ -61,6 +57,7 @@ PostProcess:
Metric: Metric:
name: TableMetric name: TableMetric
main_indicator: acc main_indicator: acc
compute_bbox_metric: False # cost many time, set False for training
Train: Train:
dataset: dataset:
...@@ -71,18 +68,23 @@ Train: ...@@ -71,18 +68,23 @@ Train:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage: - ResizeTableImage:
max_len: 488 max_len: 488
- TableLabelEncode:
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: 'hwc' order: 'hwc'
- PaddingTableImage: - PaddingTableImage:
size: [488, 488]
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader: loader:
shuffle: True shuffle: True
batch_size_per_card: 32 batch_size_per_card: 32
...@@ -92,24 +94,29 @@ Train: ...@@ -92,24 +94,29 @@ Train:
Eval: Eval:
dataset: dataset:
name: PubTabDataSet name: PubTabDataSet
data_dir: train_data/table/pubtabnet/val/ data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage: - ResizeTableImage:
max_len: 488 max_len: 488
- TableLabelEncode:
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: 'hwc' order: 'hwc'
- PaddingTableImage: - PaddingTableImage:
size: [488, 488]
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -48,10 +48,12 @@ class GenTableMask(object): ...@@ -48,10 +48,12 @@ class GenTableMask(object):
in_text = False # 是否遍历到了字符区内 in_text = False # 是否遍历到了字符区内
box_list = [] box_list = []
for i in range(len(project_val_array)): for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 if in_text == False and project_val_array[
i] > spilt_threshold: # 进入字符区了
in_text = True in_text = True
start_idx = i start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 elif project_val_array[
i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i end_idx = i
in_text = False in_text = False
if end_idx - start_idx <= 2: if end_idx - start_idx <= 2:
...@@ -70,7 +72,8 @@ class GenTableMask(object): ...@@ -70,7 +72,8 @@ class GenTableMask(object):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY) box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape h, w = box_gray_img.shape
# 灰度图片进行二值化处理 # 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV) ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
cv2.THRESH_BINARY_INV)
# 纵向腐蚀 # 纵向腐蚀
if h < w: if h < w:
kernel = np.ones((2, 1), np.uint8) kernel = np.ones((2, 1), np.uint8)
...@@ -95,10 +98,12 @@ class GenTableMask(object): ...@@ -95,10 +98,12 @@ class GenTableMask(object):
box_list = [] box_list = []
spilt_threshold = 0 spilt_threshold = 0
for i in range(len(project_val_array)): for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 if in_text == False and project_val_array[
i] > spilt_threshold: # 进入字符区了
in_text = True in_text = True
start_idx = i start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 elif project_val_array[
i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i end_idx = i
in_text = False in_text = False
if end_idx - start_idx <= 2: if end_idx - start_idx <= 2:
...@@ -120,7 +125,8 @@ class GenTableMask(object): ...@@ -120,7 +125,8 @@ class GenTableMask(object):
h_end = h h_end = h
word_img = erosion[h_start:h_end + 1, :] word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h) w_split_list, w_projection_map = self.projection(word_img.T,
word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1] w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0: if h_start > 0:
h_start -= 1 h_start -= 1
...@@ -170,7 +176,8 @@ class GenTableMask(object): ...@@ -170,7 +176,8 @@ class GenTableMask(object):
for sno in range(len(split_bbox_list)): for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno] left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom]) left, top, right, bottom = self.shrink_bbox(
[left, top, right, bottom])
if self.mask_type == 1: if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0 mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img data['mask_img'] = mask_img
...@@ -179,66 +186,44 @@ class GenTableMask(object): ...@@ -179,66 +186,44 @@ class GenTableMask(object):
data['image'] = mask_img data['image'] = mask_img
return data return data
class ResizeTableImage(object): class ResizeTableImage(object):
def __init__(self, max_len, **kwargs): def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
**kwargs):
super(ResizeTableImage, self).__init__() super(ResizeTableImage, self).__init__()
self.max_len = max_len self.max_len = max_len
self.resize_bboxes = resize_bboxes
self.infer_mode = infer_mode
def get_img_bbox(self, cells): def __call__(self, data):
bbox_list = [] img = data['image']
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
height, width = img.shape[0:2] height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0) ratio = self.max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio) resize_h = int(height * ratio)
resize_w = int(width * ratio) resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h)) resize_img = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = [] if self.resize_bboxes and not self.infer_mode:
for bno in range(len(bbox_list)): data['bboxes'] = data['bboxes'] * ratio
left, top, right, bottom = bbox_list[bno].copy() data['image'] = resize_img
left = int(left * ratio) data['src_img'] = img
top = int(top * ratio) data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
data['max_len'] = self.max_len data['max_len'] = self.max_len
return data return data
class PaddingTableImage(object): class PaddingTableImage(object):
def __init__(self, **kwargs): def __init__(self, size, **kwargs):
super(PaddingTableImage, self).__init__() super(PaddingTableImage, self).__init__()
self.size = size
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
max_len = data['max_len'] pad_h, pad_w = self.size
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32) padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
height, width = img.shape[0:2] height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy() padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img data['image'] = padding_img
shape = data['shape'].tolist()
shape.extend([pad_h, pad_w])
data['shape'] = np.array(shape)
return data return data
\ No newline at end of file
...@@ -443,7 +443,9 @@ class KieLabelEncode(object): ...@@ -443,7 +443,9 @@ class KieLabelEncode(object):
elif 'key_cls' in anno.keys(): elif 'key_cls' in anno.keys():
labels.append(anno['key_cls']) labels.append(anno['key_cls'])
else: else:
raise ValueError("Cannot found 'key_cls' in ann.keys(), please check your training annotation.") raise ValueError(
"Cannot found 'key_cls' in ann.keys(), please check your training annotation."
)
edges.append(ann.get('edge', 0)) edges.append(ann.get('edge', 0))
ann_infos = dict( ann_infos = dict(
image=data['image'], image=data['image'],
...@@ -580,171 +582,197 @@ class SRNLabelEncode(BaseRecLabelEncode): ...@@ -580,171 +582,197 @@ class SRNLabelEncode(BaseRecLabelEncode):
return idx return idx
class TableLabelEncode(object): class TableLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self, def __init__(self,
max_text_length, max_text_length,
max_elem_length,
max_cell_num,
character_dict_path, character_dict_path,
span_weight=1.0, replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=4,
**kwargs): **kwargs):
self.max_text_length = max_text_length self.max_text_len = max_text_length
self.max_elem_length = max_elem_length self.lower = False
self.max_cell_num = max_cell_num self.learn_empty_box = learn_empty_box
list_character, list_elem = self.load_char_elem_dict( self.merge_no_span_structure = merge_no_span_structure
character_dict_path) self.replace_empty_cell_token = replace_empty_cell_token
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem) dict_character = []
self.dict_character = {}
for i, char in enumerate(list_character):
self.dict_character[char] = i
self.dict_elem = {}
for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i
self.span_weight = span_weight
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\r\n").split("\t") for line in lines:
character_num = int(substr[0]) line = line.decode('utf-8').strip("\n").strip("\r\n")
elem_num = int(substr[1]) dict_character.append(line)
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_span_idx_list(self): dict_character = self.add_special_char(dict_character)
span_idx_list = [] self.dict = {}
for elem in self.dict_elem: for i, char in enumerate(dict_character):
if 'span' in elem: self.dict[char] = i
span_idx_list.append(self.dict_elem[elem]) self.idx2char = {v: k for k, v in self.dict.items()}
return span_idx_list
self.character = dict_character
self.point_num = point_num
self.pad_idx = self.dict[self.beg_str]
self.start_idx = self.dict[self.beg_str]
self.end_idx = self.dict[self.end_str]
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
self.empty_bbox_token_dict = {
"[]": '<eb></eb>',
"[' ']": '<eb1></eb1>',
"['<b>', ' ', '</b>']": '<eb2></eb2>',
"['\\u2028', '\\u2028']": '<eb3></eb3>',
"['<sup>', ' ', '</sup>']": '<eb4></eb4>',
"['<b>', '</b>']": '<eb5></eb5>',
"['<i>', ' ', '</i>']": '<eb6></eb6>',
"['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
"['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
"['<i>', '</i>']": '<eb9></eb9>',
"['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']":
'<eb10></eb10>',
}
@property
def _max_text_len(self):
return self.max_text_len + 2
def __call__(self, data): def __call__(self, data):
cells = data['cells'] cells = data['cells']
structure = data['structure']['tokens'] structure = data['structure']
structure = self.encode(structure, 'elem') if self.merge_no_span_structure:
structure = self._merge_no_span_structure(structure)
if self.replace_empty_cell_token:
structure = self._replace_empty_cell_token(structure, cells)
# remove empty token and add " " to span token
new_structure = []
for token in structure:
if token != '':
if 'span' in token and token[0] != ' ':
token = ' ' + token
new_structure.append(token)
# encode structure
structure = self.encode(new_structure)
if structure is None: if structure is None:
return None return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1] structure = [self.start_idx] + structure + [self.end_idx
structure = structure + [0] * (self.max_elem_length + 2 - len(structure) ] # add sos abd eos
) structure = structure + [self.pad_idx] * (self._max_text_len -
len(structure)) # pad
structure = np.array(structure) structure = np.array(structure)
data['structure'] = structure data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
span_idx_list = self.get_span_idx_list()
td_idx_list = np.logical_or(structure == elem_char_idx1,
structure == elem_char_idx2)
td_idx_list = np.where(td_idx_list)[0]
structure_mask = np.ones(
(self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
bbox_list_mask = np.zeros(
(self.max_elem_length + 2, 1), dtype=np.float32)
img_height, img_width, img_ch = data['image'].shape
if len(span_idx_list) > 0:
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
span_weight = min(max(span_weight, 1.0), self.span_weight)
for cno in range(len(cells)):
if 'bbox' in cells[cno]:
bbox = cells[cno]['bbox'].copy()
bbox[0] = bbox[0] * 1.0 / img_width
bbox[1] = bbox[1] * 1.0 / img_height
bbox[2] = bbox[2] * 1.0 / img_width
bbox[3] = bbox[3] * 1.0 / img_height
td_idx = td_idx_list[cno]
bbox_list[td_idx] = bbox
bbox_list_mask[td_idx] = 1.0
cand_span_idx = td_idx + 1
if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight
data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask
data['structure_mask'] = structure_mask
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num
])
return data
def encode(self, text, char_or_elem): if len(structure) > self._max_text_len:
"""convert text-label into text-index.
"""
if char_or_elem == "char":
max_len = self.max_text_length
current_dict = self.dict_character
else:
max_len = self.max_elem_length
current_dict = self.dict_elem
if len(text) > max_len:
return None
if len(text) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
text_list = []
for char in text:
if char not in current_dict:
return None
text_list.append(current_dict[char])
if len(text_list) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None return None
return text_list
def get_ignored_tokens(self, char_or_elem): # encode box
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) bboxes = np.zeros(
end_idx = self.get_beg_end_flag_idx("end", char_or_elem) (self._max_text_len, self.point_num), dtype=np.float32)
return [beg_idx, end_idx] bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
bbox_idx = 0
for i, token in enumerate(structure):
if self.idx2char[token] in self.td_token:
if 'bbox' in cells[bbox_idx]:
bbox = cells[bbox_idx]['bbox'].copy()
bbox = np.array(bbox, dtype=np.float32).reshape(-1)
bboxes[i] = bbox
bbox_masks[i] = 1.0
if self.learn_empty_box:
bbox_masks[i] = 1.0
bbox_idx += 1
data['bboxes'] = bboxes
data['bbox_masks'] = bbox_masks
return data
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): def _merge_no_span_structure(self, structure):
if char_or_elem == "char": new_structure = []
if beg_or_end == "beg": i = 0
idx = np.array(self.dict_character[self.beg_str]) while i < len(structure):
elif beg_or_end == "end": token = structure[i]
idx = np.array(self.dict_character[self.end_str]) if token == '<td>':
else: token = '<td></td>'
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ i += 1
% beg_or_end new_structure.append(token)
elif char_or_elem == "elem": i += 1
if beg_or_end == "beg": return new_structure
idx = np.array(self.dict_elem[self.beg_str])
elif beg_or_end == "end": def _replace_empty_cell_token(self, token_list, cells):
idx = np.array(self.dict_elem[self.end_str]) bbox_idx = 0
else: add_empty_bbox_token_list = []
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ for token in token_list:
% beg_or_end if token in ['<td></td>', '<td', '<td>']:
if 'bbox' not in cells[bbox_idx].keys():
content = str(cells[bbox_idx]['tokens'])
token = self.empty_bbox_token_dict[content]
add_empty_bbox_token_list.append(token)
bbox_idx += 1
else: else:
assert False, "Unsupport type %s in char_or_elem" \ add_empty_bbox_token_list.append(token)
% char_or_elem return add_empty_bbox_token_list
return idx
class TableMasterLabelEncode(TableLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path,
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=4,
**kwargs):
super(TableMasterLabelEncode, self).__init__(
max_text_length, character_dict_path, replace_empty_cell_token,
merge_no_span_structure, learn_empty_box, point_num, **kwargs)
@property
def _max_text_len(self):
return self.max_text_len
def add_special_char(self, dict_character):
self.beg_str = '<SOS>'
self.end_str = '<EOS>'
self.unknown_str = '<UKN>'
self.pad_str = '<PAD>'
dict_character = dict_character
dict_character = dict_character + [
self.unknown_str, self.beg_str, self.end_str, self.pad_str
]
return dict_character
class TableBoxEncode(object):
def __init__(self, use_xywh=False, **kwargs):
self.use_xywh = use_xywh
def __call__(self, data):
img_height, img_width = data['image'].shape[:2]
bboxes = data['bboxes']
if self.use_xywh and bboxes.shape[1] == 4:
bboxes = self.xyxy2xywh(bboxes)
bboxes[:, 0::2] /= img_width
bboxes[:, 1::2] /= img_height
data['bboxes'] = bboxes
return data
def xyxy2xywh(self, bboxes):
"""
Convert coord (x1,y1,x2,y2) to (x,y,w,h).
where (x1,y1) is top-left, (x2,y2) is bottom-right.
(x,y) is bbox center and (w,h) is width and height.
:param bboxes: (x1, y1, x2, y2)
:return:
"""
new_bboxes = np.empty_like(bboxes)
new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center
new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width
new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height
return new_bboxes
class SARLabelEncode(BaseRecLabelEncode): class SARLabelEncode(BaseRecLabelEncode):
...@@ -1030,7 +1058,6 @@ class MultiLabelEncode(BaseRecLabelEncode): ...@@ -1030,7 +1058,6 @@ class MultiLabelEncode(BaseRecLabelEncode):
use_space_char, **kwargs) use_space_char, **kwargs)
def __call__(self, data): def __call__(self, data):
data_ctc = copy.deepcopy(data) data_ctc = copy.deepcopy(data)
data_sar = copy.deepcopy(data) data_sar = copy.deepcopy(data)
data_out = dict() data_out = dict()
......
...@@ -16,6 +16,7 @@ import os ...@@ -16,6 +16,7 @@ import os
import random import random
from paddle.io import Dataset from paddle.io import Dataset
import json import json
from copy import deepcopy
from .imaug import transform, create_operators from .imaug import transform, create_operators
...@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset): ...@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset):
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
label_file_path = dataset_config.pop('label_file_path') label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir'] self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
self.do_hard_select = False
if 'hard_select' in loader_config:
self.do_hard_select = loader_config['hard_select']
self.hard_prob = loader_config['hard_prob']
if self.do_hard_select:
self.img_select_prob = self.load_hard_select_prob()
self.table_select_type = None
if 'table_select_type' in loader_config:
self.table_select_type = loader_config['table_select_type']
self.table_select_prob = loader_config['table_select_prob']
self.seed = seed self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_path) self.mode = mode.lower()
with open(label_file_path, "rb") as f: logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = f.readlines() self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines))) # self.check(config['Global']['max_text_length'])
if mode.lower() == "train":
if mode.lower() == "train" and self.do_shuffle:
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list] self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def check(self, max_text_length):
data_lines = []
for line in self.data_lines:
data_line = line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
cells = info['html']['cells'].copy()
structure = info['html']['structure']['tokens'].copy()
img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
self.logger.warning("{} does not exist!".format(img_path))
continue
if len(structure) == 0 or len(structure) > max_text_length:
continue
# data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name}
data_lines.append(line)
self.data_lines = data_lines
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed) random.seed(self.seed)
...@@ -68,47 +99,34 @@ class PubTabDataSet(Dataset): ...@@ -68,47 +99,34 @@ class PubTabDataSet(Dataset):
data_line = data_line.decode('utf-8').strip("\n") data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line) info = json.loads(data_line)
file_name = info['filename'] file_name = info['filename']
select_flag = True
if self.do_hard_select:
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1):
select_flag = False
if self.table_select_type:
structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure)
table_type = "simple"
if 'colspan' in structure_str or 'rowspan' in structure_str:
table_type = "complex"
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
select_flag = False
if select_flag:
cells = info['html']['cells'].copy() cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy() structure = info['html']['structure']['tokens'].copy()
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
data = { data = {
'img_path': img_path, 'img_path': img_path,
'cells': cells, 'cells': cells,
'structure': structure 'structure': structure,
'file_name': file_name
} }
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
img = f.read() img = f.read()
data['image'] = img data['image'] = img
outs = transform(data, self.ops) outs = transform(data, self.ops)
else: except:
outs = None import traceback
except Exception as e: err = traceback.format_exc()
self.logger.error( self.logger.error(
"When parsing line {}, error happened with msg: {}".format( "When parsing line {}, error happened with msg: {}".format(err))
data_line, e))
outs = None outs = None
if outs is None: if outs is None:
return self.__getitem__(np.random.randint(self.__len__())) rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs return outs
def __len__(self): def __len__(self):
return len(self.data_idx_order_list) return len(self.data_lines)
...@@ -51,7 +51,7 @@ from .combined_loss import CombinedLoss ...@@ -51,7 +51,7 @@ from .combined_loss import CombinedLoss
# table loss # table loss
from .table_att_loss import TableAttentionLoss from .table_att_loss import TableAttentionLoss
from .table_master_loss import TableMasterLoss
# vqa token loss # vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
...@@ -61,7 +61,8 @@ def build_loss(config): ...@@ -61,7 +61,8 @@ def build_loss(config):
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss', 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss' 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
class TableMasterLoss(nn.Layer):
def __init__(self, ignore_index=-1):
super(TableMasterLoss, self).__init__()
self.structure_loss = nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction='mean')
self.box_loss = nn.L1Loss(reduction='sum')
self.eps = 1e-12
def forward(self, predicts, batch):
# structure_loss
structure_probs = predicts['structure_probs']
structure_targets = batch[1]
structure_targets = structure_targets[:, 1:]
structure_probs = structure_probs.reshape(
[-1, structure_probs.shape[-1]])
structure_targets = structure_targets.reshape([-1])
structure_loss = self.structure_loss(structure_probs, structure_targets)
structure_loss = structure_loss.mean()
losses = dict(structure_loss=structure_loss)
# box loss
bboxes_preds = predicts['loc_preds']
bboxes_targets = batch[2][:, 1:, :]
bbox_masks = batch[3][:, 1:]
# mask empty-bbox or non-bbox structure token's bbox.
masked_bboxes_preds = bboxes_preds * bbox_masks
masked_bboxes_targets = bboxes_targets * bbox_masks
# horizon loss (x and width)
horizon_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 0::2],
masked_bboxes_targets[:, :, 0::2])
horizon_loss = horizon_sum_loss / (bbox_masks.sum() + self.eps)
# vertical loss (y and height)
vertical_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 1::2],
masked_bboxes_targets[:, :, 1::2])
vertical_loss = vertical_sum_loss / (bbox_masks.sum() + self.eps)
horizon_loss = horizon_loss.mean()
vertical_loss = vertical_loss.mean()
all_loss = structure_loss + horizon_loss + vertical_loss
losses.update({
'loss': all_loss,
'horizon_bbox_loss': horizon_loss,
'vertical_bbox_loss': vertical_loss
})
return losses
...@@ -12,29 +12,30 @@ ...@@ -12,29 +12,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from ppocr.metrics.det_metric import DetMetric
class TableMetric(object): class TableStructureMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', eps=1e-6, **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5 self.eps = eps
self.reset() self.reset()
def __call__(self, pred, batch, *args, **kwargs): def __call__(self, pred_label, batch=None, *args, **kwargs):
structure_probs = pred['structure_probs'].numpy() preds, labels = pred_label
structure_labels = batch[1] pred_structure_batch_list = preds['structure_batch_list']
gt_structure_batch_list = labels['structure_batch_list']
correct_num = 0 correct_num = 0
all_num = 0 all_num = 0
structure_probs = np.argmax(structure_probs, axis=2) for (pred, pred_conf), target in zip(pred_structure_batch_list,
structure_labels = structure_labels[:, 1:] gt_structure_batch_list):
batch_size = structure_probs.shape[0] pred_str = ''.join(pred)
for bno in range(batch_size): target_str = ''.join(target)
all_num += 1 if pred_str == target_str:
if (structure_probs[bno] == structure_labels[bno]).all():
correct_num += 1 correct_num += 1
all_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return {'acc': correct_num * 1.0 / (all_num + self.eps), }
def get_metric(self): def get_metric(self):
""" """
...@@ -49,3 +50,91 @@ class TableMetric(object): ...@@ -49,3 +50,91 @@ class TableMetric(object):
def reset(self): def reset(self):
self.correct_num = 0 self.correct_num = 0
self.all_num = 0 self.all_num = 0
self.len_acc_num = 0
self.token_nums = 0
self.anys_dict = dict()
from collections import defaultdict
self.error_num_dict = defaultdict(int)
class TableMetric(object):
def __init__(self,
main_indicator='acc',
compute_bbox_metric=False,
point_num=4,
**kwargs):
"""
@param sub_metrics: configs of sub_metric
@param main_matric: main_matric for save best_model
@param kwargs:
"""
self.structure_metric = TableStructureMetric()
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
self.point_num = point_num
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
self.structure_metric(pred_label)
if self.bbox_metric is not None:
self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))
def prepare_bbox_metric_input(self, pred_label):
pred_bbox_batch_list = []
gt_ignore_tags_batch_list = []
gt_bbox_batch_list = []
preds, labels = pred_label
batch_num = len(preds['bbox_batch_list'])
for batch_idx in range(batch_num):
# pred
pred_bbox_list = [
self.format_box(pred_box)
for pred_box in preds['bbox_batch_list'][batch_idx]
]
pred_bbox_batch_list.append({'points': pred_bbox_list})
# gt
gt_bbox_list = []
gt_ignore_tags_list = []
for gt_box in labels['bbox_batch_list'][batch_idx]:
gt_bbox_list.append(self.format_box(gt_box))
gt_ignore_tags_list.append(0)
gt_bbox_batch_list.append(gt_bbox_list)
gt_ignore_tags_batch_list.append(gt_ignore_tags_list)
return [
pred_bbox_batch_list,
[0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list]
]
def get_metric(self):
structure_metric = self.structure_metric.get_metric()
if self.bbox_metric is None:
return structure_metric
bbox_metric = self.bbox_metric.get_metric()
if self.main_indicator == self.bbox_metric.main_indicator:
output = bbox_metric
for sub_key in structure_metric:
output["structure_metric_{}".format(
sub_key)] = structure_metric[sub_key]
else:
output = structure_metric
for sub_key in bbox_metric:
output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key]
return output
def reset(self):
self.structure_metric.reset()
if self.bbox_metric is not None:
self.bbox_metric.reset()
def format_box(self, box):
if self.point_num == 4:
x1, y1, x2, y2 = box
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
elif self.point_num == 8:
x1, y1, x2, y2, x3, y3, x4, y4 = box
box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
return box
...@@ -20,7 +20,10 @@ def build_backbone(config, model_type): ...@@ -20,7 +20,10 @@ def build_backbone(config, model_type):
from .det_mobilenet_v3 import MobileNetV3 from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet from .det_resnet_vd import ResNet
from .det_resnet_vd_sast import ResNet_SAST from .det_resnet_vd_sast import ResNet_SAST
support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"] from .table_master_resnet import TableResNetExtra
support_dict = [
"MobileNetV3", "ResNet", "ResNet_SAST", "TableResNetExtra"
]
elif model_type == "rec" or model_type == "cls": elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3 from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet from .rec_resnet_vd import ResNet
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
gcb_config=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2D(
inplanes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes, momentum=0.9)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(
planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes, momentum=0.9)
self.downsample = downsample
self.stride = stride
self.gcb_config = gcb_config
if self.gcb_config is not None:
gcb_ratio = gcb_config['ratio']
gcb_headers = gcb_config['headers']
att_scale = gcb_config['att_scale']
fusion_type = gcb_config['fusion_type']
self.context_block = MultiAspectGCAttention(
inplanes=planes,
ratio=gcb_ratio,
headers=gcb_headers,
att_scale=att_scale,
fusion_type=fusion_type)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.gcb_config is not None:
out = self.context_block(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
def get_gcb_config(gcb_config, layer):
if gcb_config is None or not gcb_config['layers'][layer]:
return None
else:
return gcb_config
class TableResNetExtra(nn.Layer):
def __init__(self, layers, in_channels=3, gcb_config=None):
assert len(layers) >= 4
super(TableResNetExtra, self).__init__()
self.inplanes = 128
self.conv1 = nn.Conv2D(
in_channels,
64,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(64)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(128)
self.relu2 = nn.ReLU()
self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer1 = self._make_layer(
BasicBlock,
256,
layers[0],
stride=1,
gcb_config=get_gcb_config(gcb_config, 0))
self.conv3 = nn.Conv2D(
256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(256)
self.relu3 = nn.ReLU()
self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer2 = self._make_layer(
BasicBlock,
256,
layers[1],
stride=1,
gcb_config=get_gcb_config(gcb_config, 1))
self.conv4 = nn.Conv2D(
256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn4 = nn.BatchNorm2D(256)
self.relu4 = nn.ReLU()
self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer3 = self._make_layer(
BasicBlock,
512,
layers[2],
stride=1,
gcb_config=get_gcb_config(gcb_config, 2))
self.conv5 = nn.Conv2D(
512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn5 = nn.BatchNorm2D(512)
self.relu5 = nn.ReLU()
self.layer4 = self._make_layer(
BasicBlock,
512,
layers[3],
stride=1,
gcb_config=get_gcb_config(gcb_config, 3))
self.conv6 = nn.Conv2D(
512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn6 = nn.BatchNorm2D(512)
self.relu6 = nn.ReLU()
self.out_channels = [256, 256, 512]
def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm2D(planes * block.expansion), )
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
gcb_config=gcb_config))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
f = []
x = self.conv1(x) # 1,64,480,480
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x) # 1,128,480,480
x = self.bn2(x)
x = self.relu2(x)
# (48, 160)
x = self.maxpool1(x) # 1,64,240,240
x = self.layer1(x)
x = self.conv3(x) # 1,256,240,240
x = self.bn3(x)
x = self.relu3(x)
f.append(x)
# (24, 80)
x = self.maxpool2(x) # 1,256,120,120
x = self.layer2(x)
x = self.conv4(x) # 1,256,120,120
x = self.bn4(x)
x = self.relu4(x)
f.append(x)
# (12, 40)
x = self.maxpool3(x) # 1,256,60,60
x = self.layer3(x) # 1,512,60,60
x = self.conv5(x) # 1,512,60,60
x = self.bn5(x)
x = self.relu5(x)
x = self.layer4(x) # 1,512,60,60
x = self.conv6(x) # 1,512,60,60
x = self.bn6(x)
x = self.relu6(x)
f.append(x)
# (6, 40)
return f
class MultiAspectGCAttention(nn.Layer):
def __init__(self,
inplanes,
ratio,
headers,
pooling_type='att',
att_scale=False,
fusion_type='channel_add'):
super(MultiAspectGCAttention, self).__init__()
assert pooling_type in ['avg', 'att']
assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
assert inplanes % headers == 0 and inplanes >= 8 # inplanes must be divided by headers evenly
self.headers = headers
self.inplanes = inplanes
self.ratio = ratio
self.planes = int(inplanes * ratio)
self.pooling_type = pooling_type
self.fusion_type = fusion_type
self.att_scale = False
self.single_header_inplanes = int(inplanes / headers)
if pooling_type == 'att':
self.conv_mask = nn.Conv2D(
self.single_header_inplanes, 1, kernel_size=1)
self.softmax = nn.Softmax(axis=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2D(1)
if fusion_type == 'channel_add':
self.channel_add_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
elif fusion_type == 'channel_concat':
self.channel_concat_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
# for concat
self.cat_conv = nn.Conv2D(
2 * self.inplanes, self.inplanes, kernel_size=1)
elif fusion_type == 'channel_mul':
self.channel_mul_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
def spatial_pool(self, x):
batch, channel, height, width = x.shape
if self.pooling_type == 'att':
# [N*headers, C', H , W] C = headers * C'
x = x.reshape([
batch * self.headers, self.single_header_inplanes, height, width
])
input_x = x
# [N*headers, C', H * W] C = headers * C'
# input_x = input_x.view(batch, channel, height * width)
input_x = input_x.reshape([
batch * self.headers, self.single_header_inplanes,
height * width
])
# [N*headers, 1, C', H * W]
input_x = input_x.unsqueeze(1)
# [N*headers, 1, H, W]
context_mask = self.conv_mask(x)
# [N*headers, 1, H * W]
context_mask = context_mask.reshape(
[batch * self.headers, 1, height * width])
# scale variance
if self.att_scale and self.headers > 1:
context_mask = context_mask / paddle.sqrt(
self.single_header_inplanes)
# [N*headers, 1, H * W]
context_mask = self.softmax(context_mask)
# [N*headers, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N*headers, 1, C', 1] = [N*headers, 1, C', H * W] * [N*headers, 1, H * W, 1]
context = paddle.matmul(input_x, context_mask)
# [N, headers * C', 1, 1]
context = context.reshape(
[batch, self.headers * self.single_header_inplanes, 1, 1])
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x):
# [N, C, 1, 1]
context = self.spatial_pool(x)
out = x
if self.fusion_type == 'channel_mul':
# [N, C, 1, 1]
channel_mul_term = F.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
elif self.fusion_type == 'channel_add':
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
else:
# [N, C, 1, 1]
channel_concat_term = self.channel_concat_conv(context)
# use concat
_, C1, _, _ = channel_concat_term.shape
N, C2, H, W = out.shape
out = paddle.concat(
[out, channel_concat_term.expand([-1, -1, H, W])], axis=1)
out = self.cat_conv(out)
out = F.layer_norm(out, [self.inplanes, H, W])
out = F.relu(out)
return out
...@@ -41,12 +41,13 @@ def build_head(config): ...@@ -41,12 +41,13 @@ def build_head(config):
from .kie_sdmgr_head import SDMGRHead from .kie_sdmgr_head import SDMGRHead
from .table_att_head import TableAttentionHead from .table_att_head import TableAttentionHead
from .table_master_head import TableMasterHead
support_dict = [ support_dict = [
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead' 'MultiHead', 'TableMasterHead'
] ]
#table head #table head
......
...@@ -21,6 +21,8 @@ import paddle.nn as nn ...@@ -21,6 +21,8 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
import numpy as np import numpy as np
from .rec_att_head import AttentionGRUCell
class TableAttentionHead(nn.Layer): class TableAttentionHead(nn.Layer):
def __init__(self, def __init__(self,
...@@ -28,17 +30,13 @@ class TableAttentionHead(nn.Layer): ...@@ -28,17 +30,13 @@ class TableAttentionHead(nn.Layer):
hidden_size, hidden_size,
loc_type, loc_type,
in_max_len=488, in_max_len=488,
max_text_length=100, max_text_length=800,
max_elem_length=800,
max_cell_num=500,
**kwargs): **kwargs):
super(TableAttentionHead, self).__init__() super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1] self.input_size = in_channels[-1]
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.elem_num = 30 self.elem_num = 30
self.max_text_length = max_text_length self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
self.structure_attention_cell = AttentionGRUCell( self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False) self.input_size, hidden_size, self.elem_num, use_gru=False)
...@@ -50,11 +48,11 @@ class TableAttentionHead(nn.Layer): ...@@ -50,11 +48,11 @@ class TableAttentionHead(nn.Layer):
self.loc_generator = nn.Linear(hidden_size, 4) self.loc_generator = nn.Linear(hidden_size, 4)
else: else:
if self.in_max_len == 640: if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1) self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
elif self.in_max_len == 800: elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1) self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
else: else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1) self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def _char_to_onehot(self, input_char, onehot_dim): def _char_to_onehot(self, input_char, onehot_dim):
...@@ -77,7 +75,7 @@ class TableAttentionHead(nn.Layer): ...@@ -77,7 +75,7 @@ class TableAttentionHead(nn.Layer):
output_hiddens = [] output_hiddens = []
if self.training and targets is not None: if self.training and targets is not None:
structure = targets[0] structure = targets[0]
for i in range(self.max_elem_length + 1): for i in range(self.max_text_length + 1):
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num) structure[:, i], onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
...@@ -102,9 +100,9 @@ class TableAttentionHead(nn.Layer): ...@@ -102,9 +100,9 @@ class TableAttentionHead(nn.Layer):
elem_onehots = None elem_onehots = None
outputs = None outputs = None
alpha = None alpha = None
max_elem_length = paddle.to_tensor(self.max_elem_length) max_text_length = paddle.to_tensor(self.max_text_length)
i = 0 i = 0
while i < max_elem_length + 1: while i < max_text_length + 1:
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num) temp_elem, onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
...@@ -128,119 +126,3 @@ class TableAttentionHead(nn.Layer): ...@@ -128,119 +126,3 @@ class TableAttentionHead(nn.Layer):
loc_preds = self.loc_generator(loc_concat) loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds) loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds} return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
class AttentionLSTM(nn.Layer):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionLSTM, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
(batch_size, self.hidden_size)))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for a i-th char
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
[probs, paddle.unsqueeze(
probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1)
targets = next_input
return probs
class AttentionLSTMCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
class TableMasterHead(nn.Layer):
"""
Split to two transformer header at the last layer.
Cls_layer is used to structure token classification.
Bbox_layer is used to regress bbox coord.
"""
def __init__(self,
in_channels,
out_channels=30,
headers=8,
d_ff=2048,
dropout=0,
max_text_length=500,
point_num=4,
**kwargs):
super(TableMasterHead, self).__init__()
hidden_size = in_channels[-1]
self.layers = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 2)
self.cls_layer = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.bbox_layer = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.cls_fc = nn.Linear(hidden_size, out_channels)
self.bbox_fc = nn.Sequential(
# nn.Linear(hidden_size, hidden_size),
nn.Linear(hidden_size, point_num),
nn.Sigmoid())
self.norm = nn.LayerNorm(hidden_size)
self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
self.positional_encoding = PositionalEncoding(d_model=hidden_size)
self.SOS = out_channels - 3
self.PAD = out_channels - 1
self.out_channels = out_channels
self.point_num = point_num
self.max_text_length = max_text_length
def make_mask(self, tgt):
"""
Make mask for self attention.
:param src: [b, c, h, l_src]
:param tgt: [b, l_tgt]
:return:
"""
trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3)
tgt_len = paddle.shape(tgt)[1]
trg_sub_mask = paddle.tril(
paddle.ones(
([tgt_len, tgt_len]), dtype=paddle.float32))
tgt_mask = paddle.logical_and(
trg_pad_mask.astype(paddle.float32), trg_sub_mask)
return tgt_mask.astype(paddle.float32)
def decode(self, input, feature, src_mask, tgt_mask):
# main process of transformer decoder.
x = self.embedding(input) # x: 1*x*512, feature: 1*3600,512
x = self.positional_encoding(x)
# origin transformer layers
for i, layer in enumerate(self.layers):
x = layer(x, feature, src_mask, tgt_mask)
# cls head
for layer in self.cls_layer:
cls_x = layer(x, feature, src_mask, tgt_mask)
cls_x = self.norm(cls_x)
# bbox head
for layer in self.bbox_layer:
bbox_x = layer(x, feature, src_mask, tgt_mask)
bbox_x = self.norm(bbox_x)
return self.cls_fc(cls_x), self.bbox_fc(bbox_x)
def greedy_forward(self, SOS, feature):
input = SOS
output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.out_channels])
bbox_output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.point_num])
max_text_length = paddle.to_tensor(self.max_text_length)
for i in range(max_text_length + 1):
target_mask = self.make_mask(input)
out_step, bbox_output_step = self.decode(input, feature, None,
target_mask)
prob = F.softmax(out_step, axis=-1)
next_word = prob.argmax(axis=2, dtype="int64")
input = paddle.concat(
[input, next_word[:, -1].unsqueeze(-1)], axis=1)
if i == self.max_text_length:
output = out_step
bbox_output = bbox_output_step
return output, bbox_output
def forward_train(self, out_enc, targets):
# x is token of label
# feat is feature after backbone before pe.
# out_enc is feature after pe.
padded_targets = targets[0]
src_mask = None
tgt_mask = self.make_mask(padded_targets[:, :-1])
output, bbox_output = self.decode(padded_targets[:, :-1], out_enc,
src_mask, tgt_mask)
return {'structure_probs': output, 'loc_preds': bbox_output}
def forward_test(self, out_enc):
batch_size = out_enc.shape[0]
SOS = paddle.zeros([batch_size, 1], dtype='int64') + self.SOS
output, bbox_output = self.greedy_forward(SOS, out_enc)
# output = F.softmax(output)
return {'structure_probs': output, 'loc_preds': bbox_output}
def forward(self, feat, targets=None):
feat = feat[-1]
b, c, h, w = feat.shape
feat = feat.reshape([b, c, h * w]) # flatten 2D feature map
feat = feat.transpose((0, 2, 1))
out_enc = self.positional_encoding(feat)
if self.training:
return self.forward_train(out_enc, targets)
return self.forward_test(out_enc)
class DecoderLayer(nn.Layer):
"""
Decoder is made of self attention, srouce attention and feed forward.
"""
def __init__(self, headers, d_model, dropout, d_ff):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(headers, d_model, dropout)
self.src_attn = MultiHeadAttention(headers, d_model, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.sublayer = clones(SubLayerConnection(d_model, dropout), 3)
def forward(self, x, feature, src_mask, tgt_mask):
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
x = self.sublayer[1](
x, lambda x: self.src_attn(x, feature, feature, src_mask))
return self.sublayer[2](x, self.feed_forward)
class MultiHeadAttention(nn.Layer):
def __init__(self, headers, d_model, dropout):
super(MultiHeadAttention, self).__init__()
assert d_model % headers == 0
self.d_k = int(d_model / headers)
self.headers = headers
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
B = query.shape[0]
# 1) Do all the linear projections in batch from d_model => h x d_k
query, key, value = \
[l(x).reshape([B, 0, self.headers, self.d_k]).transpose([0, 2, 1, 3])
for l, x in zip(self.linears, (query, key, value))]
# 2) Apply attention on all the projected vectors in batch
x, self.attn = self_attention(
query, key, value, mask=mask, dropout=self.dropout)
x = x.transpose([0, 2, 1, 3]).reshape([B, 0, self.headers * self.d_k])
return self.linears[-1](x)
class FeedForward(nn.Layer):
def __init__(self, d_model, d_ff, dropout):
super(FeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class SubLayerConnection(nn.Layer):
"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
def __init__(self, size, dropout):
super(SubLayerConnection, self).__init__()
self.norm = nn.LayerNorm(size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, sublayer):
return x + self.dropout(sublayer(self.norm(x)))
def masked_fill(x, mask, value):
mask = mask.astype(x.dtype)
return x * paddle.logical_not(mask).astype(x.dtype) + mask * value
def self_attention(query, key, value, mask=None, dropout=None):
"""
Compute 'Scale Dot Product Attention'
"""
d_k = value.shape[-1]
score = paddle.matmul(query, key.transpose([0, 1, 3, 2]) / math.sqrt(d_k))
if mask is not None:
# score = score.masked_fill(mask == 0, -1e9) # b, h, L, L
score = masked_fill(score, mask == 0, -6.55e4) # for fp16
p_attn = F.softmax(score, axis=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return paddle.matmul(p_attn, value), p_attn
def clones(module, N):
""" Produce N identical layers """
return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
class Embeddings(nn.Layer):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, *input):
x = input[0]
return self.lut(x) * math.sqrt(self.d_model)
class PositionalEncoding(nn.Layer):
""" Implement the PE function. """
def __init__(self, d_model, dropout=0., max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = paddle.zeros([max_len, d_model])
position = paddle.arange(0, max_len).unsqueeze(1).astype('float32')
div_term = paddle.exp(
paddle.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, feat, **kwargs):
feat = feat + self.pe[:, :paddle.shape(feat)[1]] # pe 1*5000*512
return self.dropout(feat)
...@@ -308,3 +308,46 @@ class Const(object): ...@@ -308,3 +308,46 @@ class Const(object):
end_lr=self.learning_rate, end_lr=self.learning_rate,
last_epoch=self.last_epoch) last_epoch=self.last_epoch)
return learning_rate return learning_rate
class MultiStepDecay(object):
"""
Piecewise learning rate decay
Args:
step_each_epoch(int): steps each epoch
learning_rate (float): The initial learning rate. It is a python float number.
step_size (int): the interval to update.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self,
learning_rate,
milestones,
step_each_epoch,
gamma,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
super(MultiStepDecay, self).__init__()
self.milestones = [step_each_epoch * e for e in milestones]
self.learning_rate = learning_rate
self.gamma = gamma
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
def __call__(self):
learning_rate = lr.MultiStepDecay(
learning_rate=self.learning_rate,
milestones=self.milestones,
gamma=self.gamma,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
\ No newline at end of file
...@@ -26,8 +26,9 @@ from .east_postprocess import EASTPostProcess ...@@ -26,8 +26,9 @@ from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode SEEDLabelDecode, PRENLabelDecode
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
...@@ -42,7 +43,7 @@ def build_post_process(config, global_config=None): ...@@ -42,7 +43,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode' 'DistillationSARLabelDecode', 'TableMasterLabelDecode'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
...@@ -444,146 +444,6 @@ class SRNLabelDecode(BaseRecLabelDecode): ...@@ -444,146 +444,6 @@ class SRNLabelDecode(BaseRecLabelDecode):
return idx return idx
class TableLabelDecode(object):
""" """
def __init__(self, character_dict_path, **kwargs):
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
"\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
structure_idx, structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {
'res_html_code': res_html_code_list,
'res_loc': res_loc_list,
'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,
'structure_str_list': structure_str
}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class SARLabelDecode(BaseRecLabelDecode): class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .rec_postprocess import AttnLabelDecode
class TableLabelDecode(AttnLabelDecode):
""" """
def __init__(self, character_dict_path, **kwargs):
super(TableLabelDecode, self).__init__(character_dict_path)
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
def __call__(self, preds, batch=None):
structure_probs = preds['structure_probs']
bbox_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(bbox_preds, paddle.Tensor):
bbox_preds = bbox_preds.numpy()
shape_list = batch[-1]
result = self.decode(structure_probs, bbox_preds, shape_list)
if len(batch) == 1: # only contains shape
return result
label_decode_result = self.decode_label(batch)
return result, label_decode_result
def decode(self, structure_probs, bbox_preds, shape_list):
"""convert text-label into text-index.
"""
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
score_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
text = self.character[char_idx]
if text in self.td_token:
bbox = bbox_preds[batch_idx, idx]
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_list.append(text)
score_list.append(structure_probs[batch_idx, idx])
structure_batch_list.append([structure_list, np.mean(score_list)])
bbox_batch_list.append(np.array(bbox_list))
result = {
'bbox_batch_list': bbox_batch_list,
'structure_batch_list': structure_batch_list,
}
return result
def decode_label(self, batch):
"""convert text-label into text-index.
"""
structure_idx = batch[1]
gt_bbox_list = batch[2]
shape_list = batch[-1]
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
structure_list.append(self.character[char_idx])
bbox = gt_bbox_list[batch_idx][idx]
if bbox.sum() != 0:
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_batch_list.append(structure_list)
bbox_batch_list.append(bbox_list)
result = {
'bbox_batch_list': bbox_batch_list,
'structure_batch_list': structure_batch_list,
}
return result
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
src_h = h / ratio_h
src_w = w / ratio_w
bbox[0::2] *= src_w
bbox[1::2] *= src_h
return bbox
class TableMasterLabelDecode(TableLabelDecode):
""" """
def __init__(self, character_dict_path, box_shape='ori', **kwargs):
super(TableMasterLabelDecode, self).__init__(character_dict_path)
self.box_shape = box_shape
assert box_shape in [
'ori', 'pad'
], 'The shape used for box normalization must be ori or pad'
def add_special_char(self, dict_character):
self.beg_str = '<SOS>'
self.end_str = '<EOS>'
self.unknown_str = '<UKN>'
self.pad_str = '<PAD>'
dict_character = dict_character
dict_character = dict_character + [
self.unknown_str, self.beg_str, self.end_str, self.pad_str
]
return dict_character
def get_ignored_tokens(self):
pad_idx = self.dict[self.pad_str]
start_idx = self.dict[self.beg_str]
end_idx = self.dict[self.end_str]
unknown_idx = self.dict[self.unknown_str]
return [start_idx, end_idx, pad_idx, unknown_idx]
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
if self.box_shape == 'pad':
h, w = pad_h, pad_w
bbox[0::2] *= w
bbox[1::2] *= h
bbox[0::2] /= ratio_w
bbox[1::2] /= ratio_h
return bbox
<thead>
<tr>
<td></td>
</tr>
</thead>
<tbody>
<eb></eb>
</tbody>
<td
colspan="5"
>
</td>
colspan="2"
colspan="3"
<eb2></eb2>
<eb1></eb1>
rowspan="2"
colspan="4"
colspan="6"
rowspan="3"
colspan="9"
colspan="10"
colspan="7"
rowspan="4"
rowspan="5"
rowspan="9"
colspan="8"
rowspan="8"
rowspan="6"
rowspan="7"
rowspan="10"
<eb3></eb3>
<eb4></eb4>
<eb5></eb5>
<eb6></eb6>
<eb7></eb7>
<eb8></eb8>
<eb9></eb9>
<eb10></eb10>
...@@ -23,6 +23,7 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth' ...@@ -23,6 +23,7 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2 import cv2
import numpy as np import numpy as np
import time import time
import json
import tools.infer.utility as utility import tools.infer.utility as utility
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
...@@ -34,32 +35,50 @@ from ppstructure.utility import parse_args ...@@ -34,32 +35,50 @@ from ppstructure.utility import parse_args
logger = get_logger() logger = get_logger()
class TableStructurer(object): def build_pre_process_list(args):
def __init__(self, args): resize_op = {'ResizeTableImage': {'max_len': args.table_max_len, }}
pre_process_list = [{ pad_op = {
'ResizeTableImage': { 'PaddingTableImage': {
'max_len': args.table_max_len 'size': [args.table_max_len, args.table_max_len]
} }
}, { }
normalize_op = {
'NormalizeImage': { 'NormalizeImage': {
'std': [0.229, 0.224, 0.225], 'std': [0.229, 0.224, 0.225] if
'mean': [0.485, 0.456, 0.406], args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
'mean': [0.485, 0.456, 0.406] if
args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
'scale': '1./255.', 'scale': '1./255.',
'order': 'hwc' 'order': 'hwc'
} }
}, {
'PaddingTableImage': None
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image']
} }
}] to_chw_op = {'ToCHWImage': None}
keep_keys_op = {'KeepKeys': {'keep_keys': ['image', 'shape']}}
if args.table_algorithm not in ['TableMaster']:
pre_process_list = [
resize_op, normalize_op, pad_op, to_chw_op, keep_keys_op
]
else:
pre_process_list = [
resize_op, pad_op, normalize_op, to_chw_op, keep_keys_op
]
return pre_process_list
class TableStructurer(object):
def __init__(self, args):
pre_process_list = build_pre_process_list(args)
if args.table_algorithm not in ['TableMaster']:
postprocess_params = { postprocess_params = {
'name': 'TableLabelDecode', 'name': 'TableLabelDecode',
"character_dict_path": args.table_char_dict_path, "character_dict_path": args.table_char_dict_path,
} }
else:
postprocess_params = {
'name': 'TableMasterLabelDecode',
"character_dict_path": args.table_char_dict_path,
'box_shape': 'pad'
}
self.preprocess_op = create_operators(pre_process_list) self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
...@@ -88,27 +107,30 @@ class TableStructurer(object): ...@@ -88,27 +107,30 @@ class TableStructurer(object):
preds['structure_probs'] = outputs[1] preds['structure_probs'] = outputs[1]
preds['loc_preds'] = outputs[0] preds['loc_preds'] = outputs[0]
post_result = self.postprocess_op(preds) shape_list = np.expand_dims(data[-1], axis=0)
post_result = self.postprocess_op(preds, [shape_list])
structure_str_list = post_result['structure_str_list']
res_loc = post_result['res_loc'] structure_str_list = post_result['structure_batch_list'][0]
imgh, imgw = ori_im.shape[0:2] bbox_list = post_result['bbox_batch_list'][0]
res_loc_final = [] structure_str_list = structure_str_list[0]
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno]
left = max(int(imgw * x0), 0)
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1)
bottom = min(int(imgh * y1), imgh - 1)
res_loc_final.append([left, top, right, bottom])
structure_str_list = structure_str_list[0][:-1]
structure_str_list = [ structure_str_list = [
'<html>', '<body>', '<table>' '<html>', '<body>', '<table>'
] + structure_str_list + ['</table>', '</body>', '</html>'] ] + structure_str_list + ['</table>', '</body>', '</html>']
elapse = time.time() - starttime elapse = time.time() - starttime
return (structure_str_list, res_loc_final), elapse return structure_str_list, bbox_list, elapse
def draw_rectangle(img_path, boxes, use_xywh=False):
img = cv2.imread(img_path)
img_show = img.copy()
for box in boxes.astype(int):
if use_xywh:
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
else:
x1, y1, x2, y2 = box
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
return img_show
def main(args): def main(args):
...@@ -116,6 +138,11 @@ def main(args): ...@@ -116,6 +138,11 @@ def main(args):
table_structurer = TableStructurer(args) table_structurer = TableStructurer(args)
count = 0 count = 0
total_time = 0 total_time = 0
use_xywh = args.table_algorithm in ['TableMaster']
os.makedirs(args.output, exist_ok=True)
with open(
os.path.join(args.output, 'infer.txt'), mode='w',
encoding='utf-8') as f_w:
for image_file in image_file_list: for image_file in image_file_list:
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
...@@ -123,10 +150,19 @@ def main(args): ...@@ -123,10 +150,19 @@ def main(args):
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.info("error in loading image:{}".format(image_file))
continue continue
structure_res, elapse = table_structurer(img) structure_str_list, bbox_list, elapse = table_structurer(img)
logger.info("result: {}".format(structure_res)) bbox_list_str = json.dumps(bbox_list.tolist())
logger.info("result: {}, {}".format(structure_str_list,
bbox_list_str))
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(image_file, bbox_list, use_xywh)
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
logger.info("save vis result to {}".format(img_save_path))
if count > 0: if count > 0:
total_time += elapse total_time += elapse
count += 1 count += 1
......
...@@ -25,6 +25,7 @@ def init_args(): ...@@ -25,6 +25,7 @@ def init_args():
parser.add_argument("--output", type=str, default='./output') parser.add_argument("--output", type=str, default='./output')
# params for table structure # params for table structure
parser.add_argument("--table_max_len", type=int, default=488) parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_algorithm", type=str, default='TableAttn')
parser.add_argument("--table_model_dir", type=str) parser.add_argument("--table_model_dir", type=str)
parser.add_argument( parser.add_argument(
"--table_char_dict_path", "--table_char_dict_path",
......
...@@ -88,6 +88,8 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None): ...@@ -88,6 +88,8 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
infer_shape = [1, 32, 100] infer_shape = [1, 32, 100]
elif arch_config["model_type"] == "table": elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488] infer_shape = [3, 488, 488]
if arch_config["algorithm"] == "TableMaster":
infer_shape = [3, 480, 480]
model = to_static( model = to_static(
model, model,
input_spec=[ input_spec=[
......
...@@ -40,6 +40,7 @@ import tools.program as program ...@@ -40,6 +40,7 @@ import tools.program as program
import cv2 import cv2
@paddle.no_grad()
def main(config, device, logger, vdl_writer): def main(config, device, logger, vdl_writer):
global_config = config['Global'] global_config = config['Global']
...@@ -53,27 +54,31 @@ def main(config, device, logger, vdl_writer): ...@@ -53,27 +54,31 @@ def main(config, device, logger, vdl_writer):
getattr(post_process_class, 'character')) getattr(post_process_class, 'character'))
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
algorithm = config['Architecture']['algorithm']
use_xywh = algorithm in ['TableMaster']
load_model(config, model) load_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
use_padding = False
for op in config['Eval']['dataset']['transforms']: for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0] op_name = list(op)[0]
if 'Label' in op_name: if 'Encode' in op_name:
continue continue
if op_name == 'KeepKeys': if op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image'] op[op_name]['keep_keys'] = ['image', 'shape']
if op_name == "ResizeTableImage":
use_padding = True
padding_max_len = op['ResizeTableImage']['max_len']
transforms.append(op) transforms.append(op)
global_config['infer_mode'] = True global_config['infer_mode'] = True
ops = create_operators(transforms, global_config) ops = create_operators(transforms, global_config)
save_res_path = config['Global']['save_res_path']
os.makedirs(save_res_path, exist_ok=True)
model.eval() model.eval()
with open(
os.path.join(save_res_path, 'infer.txt'), mode='w',
encoding='utf-8') as f_w:
for file in get_image_file_list(config['Global']['infer_img']): for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file)) logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f: with open(file, 'rb') as f:
...@@ -81,27 +86,44 @@ def main(config, device, logger, vdl_writer): ...@@ -81,27 +86,44 @@ def main(config, device, logger, vdl_writer):
data = {'image': img} data = {'image': img}
batch = transform(data, ops) batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0) images = np.expand_dims(batch[0], axis=0)
shape_list = np.expand_dims(batch[1], axis=0)
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
preds = model(images) preds = model(images)
post_result = post_process_class(preds) post_result = post_process_class(preds, [shape_list])
res_html_code = post_result['res_html_code']
res_loc = post_result['res_loc'] structure_str_list = post_result['structure_batch_list'][0]
img = cv2.imread(file) bbox_list = post_result['bbox_batch_list'][0]
imgh, imgw = img.shape[0:2] structure_str_list = structure_str_list[0]
res_loc_final = [] structure_str_list = [
for rno in range(len(res_loc[0])): '<html>', '<body>', '<table>'
x0, y0, x1, y1 = res_loc[0][rno] ] + structure_str_list + ['</table>', '</body>', '</html>']
left = max(int(imgw * x0), 0) bbox_list_str = json.dumps(bbox_list.tolist())
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1) logger.info("result: {}, {}".format(structure_str_list,
bottom = min(int(imgh * y1), imgh - 1) bbox_list_str))
cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2) f_w.write("result: {}, {}\n".format(structure_str_list,
res_loc_final.append([left, top, right, bottom]) bbox_list_str))
res_loc_str = json.dumps(res_loc_final)
logger.info("result: {}, {}".format(res_html_code, res_loc_final)) img = draw_rectangle(file, bbox_list, use_xywh)
cv2.imwrite(
os.path.join(save_res_path, os.path.basename(file)), img)
logger.info("success!") logger.info("success!")
def draw_rectangle(img_path, boxes, use_xywh=False):
img = cv2.imread(img_path)
img_show = img.copy()
for box in boxes.astype(int):
if use_xywh:
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
else:
x1, y1, x2, y2 = box
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
return img_show
if __name__ == '__main__': if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess() config, device, logger, vdl_writer = program.preprocess()
main(config, device, logger, vdl_writer) main(config, device, logger, vdl_writer)
...@@ -274,8 +274,11 @@ def train(config, ...@@ -274,8 +274,11 @@ def train(config,
if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
if model_type in ['table', 'kie']: if model_type in ['kie']:
eval_class(preds, batch) eval_class(preds, batch)
elif model_type in ['table']:
post_result = post_process_class(preds, batch)
eval_class(post_result, batch)
else: else:
if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2' if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
]: # for multi head loss ]: # for multi head loss
...@@ -302,7 +305,8 @@ def train(config, ...@@ -302,7 +305,8 @@ def train(config,
train_stats.update(stats) train_stats.update(stats)
if log_writer is not None and dist.get_rank() == 0: if log_writer is not None and dist.get_rank() == 0:
log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step) log_writer.log_metrics(
metrics=train_stats.get(), prefix="TRAIN", step=global_step)
if dist.get_rank() == 0 and ( if dist.get_rank() == 0 and (
(global_step > 0 and global_step % print_batch_step == 0) or (global_step > 0 and global_step % print_batch_step == 0) or
...@@ -349,7 +353,8 @@ def train(config, ...@@ -349,7 +353,8 @@ def train(config,
# logger metric # logger metric
if log_writer is not None: if log_writer is not None:
log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step) log_writer.log_metrics(
metrics=cur_metric, prefix="EVAL", step=global_step)
if cur_metric[main_indicator] >= best_model_dict[ if cur_metric[main_indicator] >= best_model_dict[
main_indicator]: main_indicator]:
...@@ -372,11 +377,18 @@ def train(config, ...@@ -372,11 +377,18 @@ def train(config,
logger.info(best_str) logger.info(best_str)
# logger best metric # logger best metric
if log_writer is not None: if log_writer is not None:
log_writer.log_metrics(metrics={ log_writer.log_metrics(
"best_{}".format(main_indicator): best_model_dict[main_indicator] metrics={
}, prefix="EVAL", step=global_step) "best_{}".format(main_indicator):
best_model_dict[main_indicator]
log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict) },
prefix="EVAL",
step=global_step)
log_writer.log_model(
is_best=True,
prefix="best_accuracy",
metadata=best_model_dict)
reader_start = time.time() reader_start = time.time()
if dist.get_rank() == 0: if dist.get_rank() == 0:
...@@ -408,7 +420,8 @@ def train(config, ...@@ -408,7 +420,8 @@ def train(config,
epoch=epoch, epoch=epoch,
global_step=global_step) global_step=global_step)
if log_writer is not None: if log_writer is not None:
log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch)) log_writer.log_model(
is_best=False, prefix='iter_epoch_{}'.format(epoch))
best_str = 'best metric, {}'.format(', '.join( best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
...@@ -446,7 +459,6 @@ def eval(model, ...@@ -446,7 +459,6 @@ def eval(model,
preds = model(batch) preds = model(batch)
else: else:
preds = model(images) preds = model(images)
batch_numpy = [] batch_numpy = []
for item in batch: for item in batch:
if isinstance(item, paddle.Tensor): if isinstance(item, paddle.Tensor):
...@@ -456,9 +468,9 @@ def eval(model, ...@@ -456,9 +468,9 @@ def eval(model,
# Obtain usable results from post-processing methods # Obtain usable results from post-processing methods
total_time += time.time() - start total_time += time.time() - start
# Evaluate the results of the current batch # Evaluate the results of the current batch
if model_type in ['table', 'kie']: if model_type in ['kie']:
eval_class(preds, batch_numpy) eval_class(preds, batch_numpy)
elif model_type in ['vqa']: elif model_type in ['table', 'vqa']:
post_result = post_process_class(preds, batch_numpy) post_result = post_process_class(preds, batch_numpy)
eval_class(post_result, batch_numpy) eval_class(post_result, batch_numpy)
else: else:
...@@ -559,7 +571,8 @@ def preprocess(is_train=False): ...@@ -559,7 +571,8 @@ def preprocess(is_train=False):
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR' 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
'TableMaster'
] ]
device = 'cpu' device = 'cpu'
...@@ -578,7 +591,8 @@ def preprocess(is_train=False): ...@@ -578,7 +591,8 @@ def preprocess(is_train=False):
vdl_writer_path = '{}/vdl/'.format(save_model_dir) vdl_writer_path = '{}/vdl/'.format(save_model_dir)
log_writer = VDLLogger(save_model_dir) log_writer = VDLLogger(save_model_dir)
loggers.append(log_writer) loggers.append(log_writer)
if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config: if ('use_wandb' in config['Global'] and
config['Global']['use_wandb']) or 'wandb' in config:
save_dir = config['Global']['save_model_dir'] save_dir = config['Global']['save_model_dir']
wandb_writer_path = "{}/wandb".format(save_dir) wandb_writer_path = "{}/wandb".format(save_dir)
if "wandb" in config: if "wandb" in config:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册