', '', ''] + dict_character
- return dict_character
-
-
class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
@@ -290,15 +259,26 @@ class E2ELabelEncodeTrain(object):
class KieLabelEncode(object):
- def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
+ def __init__(self,
+ character_dict_path,
+ class_path,
+ norm=10,
+ directed=False,
+ **kwargs):
super(KieLabelEncode, self).__init__()
self.dict = dict({'': 0})
+ self.label2classid_map = dict()
with open(character_dict_path, 'r', encoding='utf-8') as fr:
idx = 1
for line in fr:
char = line.strip()
self.dict[char] = idx
idx += 1
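+ # build the label -> class-id map from class_path; a label's class id is its line index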
+ with open(class_path, "r") as fin:
+ lines = fin.readlines()
+ for idx, line in enumerate(lines):
+ line = line.strip("\n")
+ self.label2classid_map[line] = idx
self.norm = norm
self.directed = directed
@@ -438,12 +418,14 @@ class KieLabelEncode(object):
texts.append(ann['transcription'])
text_ind = [self.dict[c] for c in text if c in self.dict]
text_inds.append(text_ind)
- if 'label' in anno.keys():
- labels.append(ann['label'])
- elif 'key_cls' in anno.keys():
- labels.append(anno['key_cls'])
+ if 'label' in ann.keys():
+ labels.append(self.label2classid_map[ann['label']])
+ elif 'key_cls' in ann.keys():
+ labels.append(ann['key_cls'])
else:
- raise ValueError("Cannot found 'key_cls' in ann.keys(), please check your training annotation.")
+ raise ValueError(
+ "Cannot found 'key_cls' in ann.keys(), please check your training annotation."
+ )
edges.append(ann.get('edge', 0))
ann_infos = dict(
image=data['image'],
@@ -580,171 +562,210 @@ class SRNLabelEncode(BaseRecLabelEncode):
return idx
-class TableLabelEncode(object):
+class TableLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
- max_elem_length,
- max_cell_num,
character_dict_path,
- span_weight=1.0,
+ replace_empty_cell_token=False,
+ merge_no_span_structure=False,
+ learn_empty_box=False,
+ point_num=2,
**kwargs):
- self.max_text_length = max_text_length
- self.max_elem_length = max_elem_length
- self.max_cell_num = max_cell_num
- list_character, list_elem = self.load_char_elem_dict(
- character_dict_path)
- list_character = self.add_special_char(list_character)
- list_elem = self.add_special_char(list_elem)
- self.dict_character = {}
- for i, char in enumerate(list_character):
- self.dict_character[char] = i
- self.dict_elem = {}
- for i, elem in enumerate(list_elem):
- self.dict_elem[elem] = i
- self.span_weight = span_weight
-
- def load_char_elem_dict(self, character_dict_path):
- list_character = []
- list_elem = []
+ self.max_text_len = max_text_length
+ self.lower = False
+ self.learn_empty_box = learn_empty_box
+ self.merge_no_span_structure = merge_no_span_structure
+ self.replace_empty_cell_token = replace_empty_cell_token
+
+ dict_character = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
- substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
- character_num = int(substr[0])
- elem_num = int(substr[1])
- for cno in range(1, 1 + character_num):
- character = lines[cno].decode('utf-8').strip("\r\n")
- list_character.append(character)
- for eno in range(1 + character_num, 1 + character_num + elem_num):
- elem = lines[eno].decode('utf-8').strip("\r\n")
- list_elem.append(elem)
- return list_character, list_elem
-
- def add_special_char(self, list_character):
- self.beg_str = "sos"
- self.end_str = "eos"
- list_character = [self.beg_str] + list_character + [self.end_str]
- return list_character
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ dict_character.append(line)
+
+ dict_character = self.add_special_char(dict_character)
+ self.dict = {}
+ for i, char in enumerate(dict_character):
+ self.dict[char] = i
+ self.idx2char = {v: k for k, v in self.dict.items()}
- def get_span_idx_list(self):
- span_idx_list = []
- for elem in self.dict_elem:
- if 'span' in elem:
- span_idx_list.append(self.dict_elem[elem])
- return span_idx_list
+ self.character = dict_character
+ self.point_num = point_num
+ self.pad_idx = self.dict[self.beg_str]
+ self.start_idx = self.dict[self.beg_str]
+ self.end_idx = self.dict[self.end_str]
+
+ self.td_token = ['<td>', '<td', '<td></td>']
+ self.empty_bbox_token_dict = {
+ "[]": '<eb></eb>',
+ "[' ']": '<eb1></eb1>',
+ "['<b>', ' ', '</b>']": '<eb2></eb2>',
+ "['\\u2028', '\\u2028']": '<eb3></eb3>',
+ "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
+ "['<b>', '</b>']": '<eb5></eb5>',
+ "['<i>', ' ', '</i>']": '<eb6></eb6>',
+ "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
+ "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
+ "['<i>', '</i>']": '<eb9></eb9>',
+ "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']":
+ '<eb10></eb10>',
+ }
+
+ @property
+ def _max_text_len(self):
+ return self.max_text_len + 2
def __call__(self, data):
cells = data['cells']
- structure = data['structure']['tokens']
- structure = self.encode(structure, 'elem')
+ structure = data['structure']
+ if self.merge_no_span_structure:
+ structure = self._merge_no_span_structure(structure)
+ if self.replace_empty_cell_token:
+ structure = self._replace_empty_cell_token(structure, cells)
+ # remove empty token and add " " to span token
+ new_structure = []
+ for token in structure:
+ if token != '':
+ if 'span' in token and token[0] != ' ':
+ token = ' ' + token
+ new_structure.append(token)
+ # encode structure
+ structure = self.encode(new_structure)
if structure is None:
return None
- elem_num = len(structure)
- structure = [0] + structure + [len(self.dict_elem) - 1]
- structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
- )
+
+ structure = [self.start_idx] + structure + [self.end_idx
+ ] # add sos and eos
+ structure = structure + [self.pad_idx] * (self._max_text_len -
+ len(structure)) # pad
structure = np.array(structure)
data['structure'] = structure
- elem_char_idx1 = self.dict_elem['<td>']
- elem_char_idx2 = self.dict_elem['<td']
- span_idx_list = self.get_span_idx_list()
- td_idx_list = np.logical_or(structure == elem_char_idx1,
- structure == elem_char_idx2)
- td_idx_list = np.where(td_idx_list)[0]
- structure_mask = np.ones(
- (self.max_elem_length + 2, 1), dtype=np.float32)
- bbox_list = np.zeros(
- (self.max_elem_length + 2, 4), dtype=np.float32)
- bbox_list_mask = np.zeros(
- (self.max_elem_length + 2, 1), dtype=np.float32)
- img_height, img_width, img_ch = data['image'].shape
- if len(span_idx_list) > 0:
- span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
- span_weight = min(max(span_weight, 1.0), self.span_weight)
- for cno in range(len(cells)):
- if 'bbox' in cells[cno]:
- bbox = cells[cno]['bbox'].copy()
- bbox[0] = bbox[0] * 1.0 / img_width
- bbox[1] = bbox[1] * 1.0 / img_height
- bbox[2] = bbox[2] * 1.0 / img_width
- bbox[3] = bbox[3] * 1.0 / img_height
- td_idx = td_idx_list[cno]
- bbox_list[td_idx] = bbox
- bbox_list_mask[td_idx] = 1.0
- cand_span_idx = td_idx + 1
- if cand_span_idx < (self.max_elem_length + 2):
- if structure[cand_span_idx] in span_idx_list:
- structure_mask[cand_span_idx] = span_weight
-
- data['bbox_list'] = bbox_list
- data['bbox_list_mask'] = bbox_list_mask
- data['structure_mask'] = structure_mask
- char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
- char_end_idx = self.get_beg_end_flag_idx('end', 'char')
- elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
- elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
- data['sp_tokens'] = np.array([
- char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
- elem_char_idx1, elem_char_idx2, self.max_text_length,
- self.max_elem_length, self.max_cell_num, elem_num
- ])
+
+ if len(structure) > self._max_text_len:
+ return None
+
+ # encode box
+ bboxes = np.zeros(
+ (self._max_text_len, self.point_num * 2), dtype=np.float32)
+ bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
+
+ bbox_idx = 0
+
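+ # align each cell's bbox with the position of its <td> token in the encoded structure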
+ for i, token in enumerate(structure):
+ if self.idx2char[token] in self.td_token:
+ if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
+ 'tokens']) > 0:
+ bbox = cells[bbox_idx]['bbox'].copy()
+ bbox = np.array(bbox, dtype=np.float32).reshape(-1)
+ bboxes[i] = bbox
+ bbox_masks[i] = 1.0
+ if self.learn_empty_box:
+ bbox_masks[i] = 1.0
+ bbox_idx += 1
+ data['bboxes'] = bboxes
+ data['bbox_masks'] = bbox_masks
return data
- def encode(self, text, char_or_elem):
- """convert text-label into text-index.
+ def _merge_no_span_structure(self, structure):
"""
- if char_or_elem == "char":
- max_len = self.max_text_length
- current_dict = self.dict_character
- else:
- max_len = self.max_elem_length
- current_dict = self.dict_elem
- if len(text) > max_len:
- return None
- if len(text) == 0:
- if char_or_elem == "char":
- return [self.dict_character['space']]
- else:
- return None
- text_list = []
- for char in text:
- if char not in current_dict:
- return None
- text_list.append(current_dict[char])
- if len(text_list) == 0:
- if char_or_elem == "char":
- return [self.dict_character['space']]
+ This code is adapted from:
+ https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
+ """
+ new_structure = []
+ i = 0
+ while i < len(structure):
+ token = structure[i]
+ if token == '<td>':
+ token = '<td></td>'
+ i += 1
+ new_structure.append(token)
+ i += 1
+ return new_structure
+
+ def _replace_empty_cell_token(self, token_list, cells):
+ """
+ This function is adapted from:
+ https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
+ """
+
+ bbox_idx = 0
+ add_empty_bbox_token_list = []
+ for token in token_list:
+ if token in ['<td></td>', '<td', '<td>']:
+ if 'bbox' not in cells[bbox_idx].keys():
+ content = str(cells[bbox_idx]['tokens'])
+ token = self.empty_bbox_token_dict[content]
+ add_empty_bbox_token_list.append(token)
+ bbox_idx += 1
else:
- return None
- return text_list
+ add_empty_bbox_token_list.append(token)
+ return add_empty_bbox_token_list
- def get_ignored_tokens(self, char_or_elem):
- beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
- end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
- return [beg_idx, end_idx]
- def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
- if char_or_elem == "char":
- if beg_or_end == "beg":
- idx = np.array(self.dict_character[self.beg_str])
- elif beg_or_end == "end":
- idx = np.array(self.dict_character[self.end_str])
- else:
- assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
- % beg_or_end
- elif char_or_elem == "elem":
- if beg_or_end == "beg":
- idx = np.array(self.dict_elem[self.beg_str])
- elif beg_or_end == "end":
- idx = np.array(self.dict_elem[self.end_str])
- else:
- assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
- % beg_or_end
- else:
- assert False, "Unsupport type %s in char_or_elem" \
- % char_or_elem
- return idx
+class TableMasterLabelEncode(TableLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path,
+ replace_empty_cell_token=False,
+ merge_no_span_structure=False,
+ learn_empty_box=False,
+ point_num=2,
+ **kwargs):
+ super(TableMasterLabelEncode, self).__init__(
+ max_text_length, character_dict_path, replace_empty_cell_token,
+ merge_no_span_structure, learn_empty_box, point_num, **kwargs)
+ self.pad_idx = self.dict[self.pad_str]
+ self.unknown_idx = self.dict[self.unknown_str]
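+ # unlike the base class, TableMaster pads with a dedicated <PAD> token rather than reusing the SOS index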
+
+ @property
+ def _max_text_len(self):
+ return self.max_text_len
+
+ def add_special_char(self, dict_character):
+ self.beg_str = '<SOS>'
+ self.end_str = '<EOS>'
+ self.unknown_str = '<UKN>'
+ self.pad_str = '<PAD>'
+ dict_character = dict_character
+ dict_character = dict_character + [
+ self.unknown_str, self.beg_str, self.end_str, self.pad_str
+ ]
+ return dict_character
+
+
+class TableBoxEncode(object):
+ def __init__(self, use_xywh=False, **kwargs):
+ self.use_xywh = use_xywh
+
+ def __call__(self, data):
+ img_height, img_width = data['image'].shape[:2]
+ bboxes = data['bboxes']
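+ # normalize coordinates to [0, 1] relative to the (already resized) image size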
+ if self.use_xywh and bboxes.shape[1] == 4:
+ bboxes = self.xyxy2xywh(bboxes)
+ bboxes[:, 0::2] /= img_width
+ bboxes[:, 1::2] /= img_height
+ data['bboxes'] = bboxes
+ return data
+
+ def xyxy2xywh(self, bboxes):
+ """
+ Convert coord (x1,y1,x2,y2) to (x,y,w,h).
+ where (x1,y1) is top-left, (x2,y2) is bottom-right.
+ (x,y) is bbox center and (w,h) is width and height.
+ :param bboxes: (x1, y1, x2, y2)
+ :return:
+ """
+ new_bboxes = np.empty_like(bboxes)
+ new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
+ new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center
+ new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width
+ new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height
+ return new_bboxes
class SARLabelEncode(BaseRecLabelEncode):
@@ -848,6 +869,7 @@ class VQATokenLabelEncode(object):
contains_re=False,
add_special_ids=False,
algorithm='LayoutXLM',
+ use_textline_bbox_info=True,
infer_mode=False,
ocr_engine=None,
**kwargs):
@@ -876,11 +898,51 @@ class VQATokenLabelEncode(object):
self.add_special_ids = add_special_ids
self.infer_mode = infer_mode
self.ocr_engine = ocr_engine
+ self.use_textline_bbox_info = use_textline_bbox_info
+
+ def split_bbox(self, bbox, text, tokenizer):
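+ # split a text-line bbox into per-token boxes, assuming each character covers an equal width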
+ words = text.split()
+ token_bboxes = []
+ curr_word_idx = 0
+ x1, y1, x2, y2 = bbox
+ unit_w = (x2 - x1) / len(text)
+ for idx, word in enumerate(words):
+ curr_w = len(word) * unit_w
+ word_bbox = [x1, y1, x1 + curr_w, y2]
+ token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
+ x1 += (len(word) + 1) * unit_w
+ return token_bboxes
+
+ def filter_empty_contents(self, ocr_info):
+ """
+ filter out entries with empty transcriptions and drop links that reference them
+ """
+ new_ocr_info = []
+ empty_index = []
+ for idx, info in enumerate(ocr_info):
+ if len(info["transcription"]) > 0:
+ new_ocr_info.append(copy.deepcopy(info))
+ else:
+ empty_index.append(info["id"])
+
+ for idx, info in enumerate(new_ocr_info):
+ new_link = []
+ for link in info["linking"]:
+ if link[0] in empty_index or link[1] in empty_index:
+ continue
+ new_link.append(link)
+ new_ocr_info[idx]["linking"] = new_link
+ return new_ocr_info
def __call__(self, data):
# load bbox and label info
ocr_info = self._load_ocr_info(data)
+ # for re
+ train_re = self.contains_re and not self.infer_mode
+ if train_re:
+ ocr_info = self.filter_empty_contents(ocr_info)
+
height, width, _ = data['image'].shape
words_list = []
@@ -892,8 +954,6 @@ class VQATokenLabelEncode(object):
entities = []
- # for re
- train_re = self.contains_re and not self.infer_mode
if train_re:
relations = []
id2label = {}
@@ -903,17 +963,19 @@ class VQATokenLabelEncode(object):
data['ocr_info'] = copy.deepcopy(ocr_info)
for info in ocr_info:
+ text = info["transcription"]
+ if len(text) <= 0:
+ continue
if train_re:
# for re
- if len(info["text"]) == 0:
+ if len(text) == 0:
empty_entity.add(info["id"])
continue
id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]])
# smooth_box
- bbox = self._smooth_box(info["bbox"], height, width)
+ info["bbox"] = self.trans_poly_to_bbox(info["points"])
- text = info["text"]
encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
@@ -924,6 +986,19 @@ class VQATokenLabelEncode(object):
-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1]
+
+ if self.use_textline_bbox_info:
+ bbox = [info["bbox"]] * len(encode_res["input_ids"])
+ else:
+ bbox = self.split_bbox(info["bbox"], info["transcription"],
+ self.tokenizer)
+ if len(bbox) <= 0:
+ continue
+ bbox = self._smooth_box(bbox, height, width)
+ if self.add_special_ids:
+ bbox.insert(0, [0, 0, 0, 0])
+ bbox.append([0, 0, 0, 0])
+
# parse label
if not self.infer_mode:
label = info['label']
@@ -948,7 +1023,7 @@ class VQATokenLabelEncode(object):
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
- bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ bbox_list.extend(bbox)
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
if not self.infer_mode:
@@ -973,40 +1048,42 @@ class VQATokenLabelEncode(object):
data['entity_id_to_index_map'] = entity_id_to_index_map
return data
- def _load_ocr_info(self, data):
- def trans_poly_to_bbox(poly):
- x1 = np.min([p[0] for p in poly])
- x2 = np.max([p[0] for p in poly])
- y1 = np.min([p[1] for p in poly])
- y2 = np.max([p[1] for p in poly])
- return [x1, y1, x2, y2]
+ def trans_poly_to_bbox(self, poly):
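+ # axis-aligned bounding box (x1, y1, x2, y2) enclosing a polygon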
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+ def _load_ocr_info(self, data):
if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = []
for res in ocr_result:
ocr_info.append({
- "text": res[1][0],
- "bbox": trans_poly_to_bbox(res[0]),
- "poly": res[0],
+ "transcription": res[1][0],
+ "bbox": self.trans_poly_to_bbox(res[0]),
+ "points": res[0],
})
return ocr_info
else:
info = data['label']
# read text info
info_dict = json.loads(info)
- return info_dict["ocr_info"]
+ return info_dict
- def _smooth_box(self, bbox, height, width):
- bbox[0] = int(bbox[0] * 1000.0 / width)
- bbox[2] = int(bbox[2] * 1000.0 / width)
- bbox[1] = int(bbox[1] * 1000.0 / height)
- bbox[3] = int(bbox[3] * 1000.0 / height)
- return bbox
+ def _smooth_box(self, bboxes, height, width):
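+ # scale boxes into the 0-1000 integer coordinate space used by LayoutLM-style models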
+ bboxes = np.array(bboxes)
+ bboxes[:, 0] = bboxes[:, 0] * 1000 / width
+ bboxes[:, 2] = bboxes[:, 2] * 1000 / width
+ bboxes[:, 1] = bboxes[:, 1] * 1000 / height
+ bboxes[:, 3] = bboxes[:, 3] * 1000 / height
+ bboxes = bboxes.astype("int64").tolist()
+ return bboxes
def _parse_label(self, label, encode_res):
gt_label = []
- if label.lower() == "other":
+ if label.lower() in ["other", "others", "ignore"]:
gt_label.extend([0] * len(encode_res["input_ids"]))
else:
gt_label.append(self.label2id_map[("b-" + label).upper()])
@@ -1030,7 +1107,6 @@ class MultiLabelEncode(BaseRecLabelEncode):
use_space_char, **kwargs)
def __call__(self, data):
-
data_ctc = copy.deepcopy(data)
data_sar = copy.deepcopy(data)
data_out = dict()
@@ -1044,3 +1120,99 @@ class MultiLabelEncode(BaseRecLabelEncode):
data_out['label_sar'] = sar['label']
data_out['length'] = ctc['length']
return data_out
+
+
+class NRTRLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+
+ super(NRTRLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len - 1:
+ return None
+ data['length'] = np.array(len(text))
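+ # 2 and 3 are the <s> / </s> indices in the NRTR dictionary (see add_special_char)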
+ text.insert(0, 2)
+ text.append(3)
+ text = text + [0] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
+ return dict_character
+
+
+class ViTSTRLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ ignore_index=0,
+ **kwargs):
+
+ super(ViTSTRLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+ self.ignore_index = ignore_index
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
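+ # prepend the GO token (ignore_index, 0 = '<s>' by default) and append index 1 ('</s>') as the stop token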
+ text.insert(0, self.ignore_index)
+ text.append(1)
+ text = text + [self.ignore_index] * (self.max_text_len + 2 - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = ['<s>', '</s>'] + dict_character
+ return dict_character
+
+
+class ABINetLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ ignore_index=100,
+ **kwargs):
+
+ super(ABINetLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+ self.ignore_index = ignore_index
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
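+ # index 0 is '</s>'; the remaining positions are padded with ignore_index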
+ text.append(0)
+ text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = ['</s>'] + dict_character
+ return dict_character
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 09736515e7a388e191a12826e1e9e348e2fcde86..04cc2848fb4d25baaf553c6eda235ddb0e86511f 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -67,39 +67,6 @@ class DecodeImage(object):
return data
-class NRTRDecodeImage(object):
- """ decode image """
-
- def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
- self.img_mode = img_mode
- self.channel_first = channel_first
-
- def __call__(self, data):
- img = data['image']
- if six.PY2:
- assert type(img) is str and len(
- img) > 0, "invalid input 'img' in DecodeImage"
- else:
- assert type(img) is bytes and len(
- img) > 0, "invalid input 'img' in DecodeImage"
- img = np.frombuffer(img, dtype='uint8')
-
- img = cv2.imdecode(img, 1)
-
- if img is None:
- return None
- if self.img_mode == 'GRAY':
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
- elif self.img_mode == 'RGB':
- assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
- img = img[:, :, ::-1]
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
- if self.channel_first:
- img = img.transpose((2, 0, 1))
- data['image'] = img
- return data
-
-
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
@@ -238,9 +205,12 @@ class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
+ self.keep_ratio = False
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
+ if 'keep_ratio' in kwargs:
+ self.keep_ratio = kwargs['keep_ratio']
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
@@ -270,6 +240,10 @@ class DetResizeForTest(object):
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
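+ # keep_ratio: derive resize_w from the original aspect ratio, rounded up to a multiple of 32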
+ if self.keep_ratio is True:
+ resize_w = ori_w * resize_h / ori_h
+ N = math.ceil(resize_w / 32)
+ resize_w = N * 32
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index aa4523329d3881a4cadb185f00beea38bf109cd3..1055e369e4cdf8edfbff94fec0b20520001de11d 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -19,6 +19,8 @@ import random
import copy
from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
+from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter
+from paddle.vision.transforms import Compose
class RecAug(object):
@@ -94,6 +96,36 @@ class BaseDataAugmentation(object):
return data
+class ABINetRecAug(object):
+ def __init__(self,
+ geometry_p=0.5,
+ deterioration_p=0.25,
+ colorjitter_p=0.25,
+ **kwargs):
+ self.transforms = Compose([
+ CVGeometry(
+ degrees=45,
+ translate=(0.0, 0.0),
+ scale=(0.5, 2.),
+ shear=(45, 15),
+ distortion=0.5,
+ p=geometry_p), CVDeterioration(
+ var=20, degrees=6, factor=4, p=deterioration_p),
+ CVColorJitter(
+ brightness=0.5,
+ contrast=0.5,
+ saturation=0.5,
+ hue=0.1,
+ p=colorjitter_p)
+ ])
+
+ def __call__(self, data):
+ img = data['image']
+ img = self.transforms(img)
+ data['image'] = img
+ return data
+
+
class RecConAug(object):
def __init__(self,
prob=0.5,
@@ -148,46 +180,6 @@ class ClsResizeImg(object):
return data
-class NRTRRecResizeImg(object):
- def __init__(self, image_shape, resize_type, padding=False, **kwargs):
- self.image_shape = image_shape
- self.resize_type = resize_type
- self.padding = padding
-
- def __call__(self, data):
- img = data['image']
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
- image_shape = self.image_shape
- if self.padding:
- imgC, imgH, imgW = image_shape
- # todo: change to 0 and modified image shape
- h = img.shape[0]
- w = img.shape[1]
- ratio = w / float(h)
- if math.ceil(imgH * ratio) > imgW:
- resized_w = imgW
- else:
- resized_w = int(math.ceil(imgH * ratio))
- resized_image = cv2.resize(img, (resized_w, imgH))
- norm_img = np.expand_dims(resized_image, -1)
- norm_img = norm_img.transpose((2, 0, 1))
- resized_image = norm_img.astype(np.float32) / 128. - 1.
- padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
- padding_im[:, :, 0:resized_w] = resized_image
- data['image'] = padding_im
- return data
- if self.resize_type == 'PIL':
- image_pil = Image.fromarray(np.uint8(img))
- img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
- img = np.array(img)
- if self.resize_type == 'OpenCV':
- img = cv2.resize(img, self.image_shape)
- norm_img = np.expand_dims(img, -1)
- norm_img = norm_img.transpose((2, 0, 1))
- data['image'] = norm_img.astype(np.float32) / 128. - 1.
- return data
-
-
class RecResizeImg(object):
def __init__(self,
image_shape,
@@ -285,6 +277,84 @@ class RobustScannerRecResizeImg(object):
data['word_positons'] = word_positons
return data
+class GrayRecResizeImg(object):
+ def __init__(self,
+ image_shape,
+ resize_type,
+ inter_type='Image.ANTIALIAS',
+ scale=True,
+ padding=False,
+ **kwargs):
+ self.image_shape = image_shape
+ self.resize_type = resize_type
+ self.padding = padding
+ self.inter_type = eval(inter_type)
+ self.scale = scale
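+ # scale=True maps pixels to [-1, 1]; scale=False maps them to [0, 1]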
+
+ def __call__(self, data):
+ img = data['image']
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ image_shape = self.image_shape
+ if self.padding:
+ imgC, imgH, imgW = image_shape
+ # todo: change to 0 and modified image shape
+ h = img.shape[0]
+ w = img.shape[1]
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ norm_img = np.expand_dims(resized_image, -1)
+ norm_img = norm_img.transpose((2, 0, 1))
+ resized_image = norm_img.astype(np.float32) / 128. - 1.
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, 0:resized_w] = resized_image
+ data['image'] = padding_im
+ return data
+ if self.resize_type == 'PIL':
+ image_pil = Image.fromarray(np.uint8(img))
+ img = image_pil.resize(self.image_shape, self.inter_type)
+ img = np.array(img)
+ if self.resize_type == 'OpenCV':
+ img = cv2.resize(img, self.image_shape)
+ norm_img = np.expand_dims(img, -1)
+ norm_img = norm_img.transpose((2, 0, 1))
+ if self.scale:
+ data['image'] = norm_img.astype(np.float32) / 128. - 1.
+ else:
+ data['image'] = norm_img.astype(np.float32) / 255.
+ return data
+
+
+class ABINetRecResizeImg(object):
+ def __init__(self, image_shape, **kwargs):
+ self.image_shape = image_shape
+
+ def __call__(self, data):
+ img = data['image']
+ norm_img, valid_ratio = resize_norm_img_abinet(img, self.image_shape)
+ data['image'] = norm_img
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
+class SVTRRecResizeImg(object):
+ def __init__(self, image_shape, padding=True, **kwargs):
+ self.image_shape = image_shape
+ self.padding = padding
+
+ def __call__(self, data):
+ img = data['image']
+
+ norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
+ self.padding)
+ data['image'] = norm_img
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
@@ -403,6 +473,26 @@ def resize_norm_img_srn(img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32)
+def resize_norm_img_abinet(img, image_shape):
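+ # resize to a fixed (imgH, imgW), normalize with ImageNet mean/std in CHW layout;
+ # no padding is applied, so valid_ratio is always 1.0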
+ imgC, imgH, imgW = image_shape
+
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_w = imgW
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image / 255.
+
+ mean = np.array([0.485, 0.456, 0.406])
+ std = np.array([0.229, 0.224, 0.225])
+ resized_image = (
+ resized_image - mean[None, None, ...]) / std[None, None, ...]
+ resized_image = resized_image.transpose((2, 0, 1))
+ resized_image = resized_image.astype('float32')
+
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ return resized_image, valid_ratio
+
+
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
diff --git a/ppocr/data/imaug/gen_table_mask.py b/ppocr/data/imaug/table_ops.py
similarity index 77%
rename from ppocr/data/imaug/gen_table_mask.py
rename to ppocr/data/imaug/table_ops.py
index 08e35d5d1df7f9663b4e008451308d0ee409cf5a..8d139190ab4b22c553036ddc8e31cfbc7ec3423d 100644
--- a/ppocr/data/imaug/gen_table_mask.py
+++ b/ppocr/data/imaug/table_ops.py
@@ -32,7 +32,7 @@ class GenTableMask(object):
self.shrink_h_max = 5
self.shrink_w_max = 5
self.mask_type = mask_type
-
+
def projection(self, erosion, h, w, spilt_threshold=0):
# horizontal projection
projection_map = np.ones_like(erosion)
@@ -48,10 +48,12 @@ class GenTableMask(object):
in_text = False # whether the scan is currently inside a text region
box_list = []
for i in range(len(project_val_array)):
- if in_text == False and project_val_array[i] > spilt_threshold: # entering a text region
+ if in_text == False and project_val_array[
+ i] > spilt_threshold: # entering a text region
in_text = True
start_idx = i
- elif project_val_array[i] <= spilt_threshold and in_text == True: # entering a blank region
+ elif project_val_array[
+ i] <= spilt_threshold and in_text == True: # entering a blank region
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
@@ -70,7 +72,8 @@ class GenTableMask(object):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# binarize the grayscale image
- ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
+ ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
+ cv2.THRESH_BINARY_INV)
# erode along the vertical direction
if h < w:
kernel = np.ones((2, 1), np.uint8)
@@ -95,10 +98,12 @@ class GenTableMask(object):
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
- if in_text == False and project_val_array[i] > spilt_threshold: # entering a text region
+ if in_text == False and project_val_array[
+ i] > spilt_threshold: # entering a text region
in_text = True
start_idx = i
- elif project_val_array[i] <= spilt_threshold and in_text == True: # entering a blank region
+ elif project_val_array[
+ i] <= spilt_threshold and in_text == True: # entering a blank region
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
@@ -120,7 +125,8 @@ class GenTableMask(object):
h_end = h
word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape
- w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
+ w_split_list, w_projection_map = self.projection(word_img.T,
+ word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
@@ -170,75 +176,54 @@ class GenTableMask(object):
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
- left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
+ left, top, right, bottom = self.shrink_bbox(
+ [left, top, right, bottom])
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img
else:
- mask_img[top:bottom, left:right, :] = (255, 255, 255)
+ mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img
return data
+
class ResizeTableImage(object):
- def __init__(self, max_len, **kwargs):
+ def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
+ **kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
+ self.resize_bboxes = resize_bboxes
+ self.infer_mode = infer_mode
- def get_img_bbox(self, cells):
- bbox_list = []
- if len(cells) == 0:
- return bbox_list
- cell_num = len(cells)
- for cno in range(cell_num):
- if "bbox" in cells[cno]:
- bbox = cells[cno]['bbox']
- bbox_list.append(bbox)
- return bbox_list
-
- def resize_img_table(self, img, bbox_list, max_len):
+ def __call__(self, data):
+ img = data['image']
height, width = img.shape[0:2]
- ratio = max_len / (max(height, width) * 1.0)
+ ratio = self.max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
- img_new = cv2.resize(img, (resize_w, resize_h))
- bbox_list_new = []
- for bno in range(len(bbox_list)):
- left, top, right, bottom = bbox_list[bno].copy()
- left = int(left * ratio)
- top = int(top * ratio)
- right = int(right * ratio)
- bottom = int(bottom * ratio)
- bbox_list_new.append([left, top, right, bottom])
- return img_new, bbox_list_new
-
- def __call__(self, data):
- img = data['image']
- if 'cells' not in data:
- cells = []
- else:
- cells = data['cells']
- bbox_list = self.get_img_bbox(cells)
- img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
- data['image'] = img_new
- cell_num = len(cells)
- bno = 0
- for cno in range(cell_num):
- if "bbox" in data['cells'][cno]:
- data['cells'][cno]['bbox'] = bbox_list_new[bno]
- bno += 1
+ resize_img = cv2.resize(img, (resize_w, resize_h))
+ if self.resize_bboxes and not self.infer_mode:
+ data['bboxes'] = data['bboxes'] * ratio
+ data['image'] = resize_img
+ data['src_img'] = img
+ data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
data['max_len'] = self.max_len
return data
+
class PaddingTableImage(object):
- def __init__(self, **kwargs):
+ def __init__(self, size, **kwargs):
super(PaddingTableImage, self).__init__()
-
+ self.size = size
+
def __call__(self, data):
img = data['image']
- max_len = data['max_len']
- padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
+ pad_h, pad_w = self.size
+ padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img
+ shape = data['shape'].tolist()
+ shape.extend([pad_h, pad_w])
+ data['shape'] = np.array(shape)
return data
-
\ No newline at end of file
diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py
index a5025e7985198e7ee40d6c92d8e1814eb1797032..bde175115536a3f644750260082204fe5f10dc05 100644
--- a/ppocr/data/imaug/vqa/__init__.py
+++ b/ppocr/data/imaug/vqa/__init__.py
@@ -13,7 +13,12 @@
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
+from .augment import DistortBBox
__all__ = [
- 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
+ 'VQATokenPad',
+ 'VQASerTokenChunk',
+ 'VQAReTokenChunk',
+ 'VQAReTokenRelation',
+ 'DistortBBox',
]
diff --git a/ppocr/data/imaug/vqa/augment.py b/ppocr/data/imaug/vqa/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcdc9685e9855c3a2d8e9f6f5add270f95f15a6c
--- /dev/null
+++ b/ppocr/data/imaug/vqa/augment.py
@@ -0,0 +1,37 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import numpy as np
+import random
+
+
+class DistortBBox:
+ def __init__(self, prob=0.5, max_scale=1, **kwargs):
+ """Random distort bbox
+ """
+ self.prob = prob
+ self.max_scale = max_scale
+
+ def __call__(self, data):
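+ # with probability prob, jitter every bbox coordinate by up to ±max_scale, then clip to the 0-1000 range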
+ if random.random() > self.prob:
+ return data
+ bbox = np.array(data['bbox'])
+ rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
+ bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
+ bbox = np.clip(bbox, 0, 1000)
+ data['bbox'] = bbox.tolist()
+ return data
diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py
index 671cda76fb4c36f3ac6bcc7da5a7fc4de241c0e2..642d3eb1961cbf0e829e6fb122f38c6af99df1c5 100644
--- a/ppocr/data/pubtab_dataset.py
+++ b/ppocr/data/pubtab_dataset.py
@@ -16,6 +16,7 @@ import os
import random
from paddle.io import Dataset
import json
+from copy import deepcopy
from .imaug import transform, create_operators
@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset):
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
- label_file_path = dataset_config.pop('label_file_path')
+ label_file_list = dataset_config.pop('label_file_list')
+ data_source_num = len(label_file_list)
+ ratio_list = dataset_config.get("ratio_list", [1.0])
+ if isinstance(ratio_list, (float, int)):
+ ratio_list = [float(ratio_list)] * int(data_source_num)
+
+ assert len(
+ ratio_list
+ ) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
- self.do_hard_select = False
- if 'hard_select' in loader_config:
- self.do_hard_select = loader_config['hard_select']
- self.hard_prob = loader_config['hard_prob']
- if self.do_hard_select:
- self.img_select_prob = self.load_hard_select_prob()
- self.table_select_type = None
- if 'table_select_type' in loader_config:
- self.table_select_type = loader_config['table_select_type']
- self.table_select_prob = loader_config['table_select_prob']
self.seed = seed
- logger.info("Initialize indexs of datasets:%s" % label_file_path)
- with open(label_file_path, "rb") as f:
- self.data_lines = f.readlines()
- self.data_idx_order_list = list(range(len(self.data_lines)))
- if mode.lower() == "train":
+ self.mode = mode.lower()
+ logger.info("Initialize indexs of datasets:%s" % label_file_list)
+ self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
+ # self.check(config['Global']['max_text_length'])
+
+ if mode.lower() == "train" and self.do_shuffle:
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
-
- ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
+ def get_image_info_list(self, file_list, ratio_list):
+ if isinstance(file_list, str):
+ file_list = [file_list]
+ data_lines = []
+ for idx, file in enumerate(file_list):
+ with open(file, "rb") as f:
+ lines = f.readlines()
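+ # randomly subsample round(len(lines) * ratio) lines from this label file, seeded for reproducibility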
+ if self.mode == "train" or ratio_list[idx] < 1.0:
+ random.seed(self.seed)
+ lines = random.sample(lines,
+ round(len(lines) * ratio_list[idx]))
+ data_lines.extend(lines)
+ return data_lines
+
+ def check(self, max_text_length):
+ data_lines = []
+ for line in self.data_lines:
+ data_line = line.decode('utf-8').strip("\n")
+ info = json.loads(data_line)
+ file_name = info['filename']
+ cells = info['html']['cells'].copy()
+ structure = info['html']['structure']['tokens'].copy()
+
+ img_path = os.path.join(self.data_dir, file_name)
+ if not os.path.exists(img_path):
+ self.logger.warning("{} does not exist!".format(img_path))
+ continue
+ if len(structure) == 0 or len(structure) > max_text_length:
+ continue
+ # data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name}
+ data_lines.append(line)
+ self.data_lines = data_lines
+
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
@@ -68,47 +99,35 @@ class PubTabDataSet(Dataset):
data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
- select_flag = True
- if self.do_hard_select:
- prob = self.img_select_prob[file_name]
- if prob < random.uniform(0, 1):
- select_flag = False
-
- if self.table_select_type:
- structure = info['html']['structure']['tokens'].copy()
- structure_str = ''.join(structure)
- table_type = "simple"
- if 'colspan' in structure_str or 'rowspan' in structure_str:
- table_type = "complex"
- if table_type == "complex":
- if self.table_select_prob < random.uniform(0, 1):
- select_flag = False
-
- if select_flag:
- cells = info['html']['cells'].copy()
- structure = info['html']['structure'].copy()
- img_path = os.path.join(self.data_dir, file_name)
- data = {
- 'img_path': img_path,
- 'cells': cells,
- 'structure': structure
- }
- if not os.path.exists(img_path):
- raise Exception("{} does not exist!".format(img_path))
- with open(data['img_path'], 'rb') as f:
- img = f.read()
- data['image'] = img
- outs = transform(data, self.ops)
- else:
- outs = None
- except Exception as e:
+ cells = info['html']['cells'].copy()
+ structure = info['html']['structure']['tokens'].copy()
+
+ img_path = os.path.join(self.data_dir, file_name)
+ if not os.path.exists(img_path):
+ raise Exception("{} does not exist!".format(img_path))
+ data = {
+ 'img_path': img_path,
+ 'cells': cells,
+ 'structure': structure,
+ 'file_name': file_name
+ }
+
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ outs = transform(data, self.ops)
+ except:
+ import traceback
+ err = traceback.format_exc()
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
- data_line, e))
+ data_line, err))
outs = None
if outs is None:
- return self.__getitem__(np.random.randint(self.__len__()))
+ rnd_idx = np.random.randint(self.__len__(
+ )) if self.mode == "train" else (idx + 1) % self.__len__()
+ return self.__getitem__(rnd_idx)
return outs
def __len__(self):
- return len(self.data_idx_order_list)
+ return len(self.data_lines)
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index de8419b7c1cf6a30ab7195a1cbcbb10a5e52642d..62e0544ea94daaaff7d019e6a48e65a2d508aca0 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -30,7 +30,7 @@ from .det_fce_loss import FCELoss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
-from .rec_nrtr_loss import NRTRLoss
+from .rec_ce_loss import CELoss
from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
@@ -51,7 +51,7 @@ from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
-
+from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
@@ -60,8 +60,9 @@ def build_loss(config):
support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
- 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
- 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
+ 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
+ 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
+ 'TableMasterLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index 2df96ea2642d10a50eb892d738f89318dc5e0f4c..74490791c2af0be54dab8ab30ac323790fcac657 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -57,17 +57,24 @@ class CELoss(nn.Layer):
class KLJSLoss(object):
def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'
- ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
+ ], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
- loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
-
- if self.mode.lower() == "js":
+ if self.mode.lower() == 'kl':
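+ # symmetric KL: 0.5 * (KL(p2||p1) + KL(p1||p2)), with 1e-5 smoothing for numerical stability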
+ loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
+ loss += paddle.multiply(
+ p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
+ loss *= 0.5
+ elif self.mode.lower() == "js":
+ loss = paddle.multiply(p2, paddle.log((2*p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss += paddle.multiply(
- p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
+ p1, paddle.log((2*p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss *= 0.5
+ else:
+ raise ValueError("The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
+
if reduction == "mean":
loss = paddle.mean(loss, axis=[1, 2])
elif reduction == "none" or reduction is None:
@@ -95,7 +102,7 @@ class DMLLoss(nn.Layer):
self.act = None
self.use_log = use_log
- self.jskl_loss = KLJSLoss(mode="js")
+ self.jskl_loss = KLJSLoss(mode="kl")
def _kldiv(self, x, target):
eps = 1.0e-10
diff --git a/ppocr/losses/rec_ce_loss.py b/ppocr/losses/rec_ce_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..614384de863c15b106aef831f8e938b89dadc246
--- /dev/null
+++ b/ppocr/losses/rec_ce_loss.py
@@ -0,0 +1,66 @@
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+
+class CELoss(nn.Layer):
+ def __init__(self,
+ smoothing=False,
+ with_all=False,
+ ignore_index=-1,
+ **kwargs):
+ super(CELoss, self).__init__()
+ if ignore_index >= 0:
+ self.loss_func = nn.CrossEntropyLoss(
+ reduction='mean', ignore_index=ignore_index)
+ else:
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean')
+ self.smoothing = smoothing
+ self.with_all = with_all
+
+ def forward(self, pred, batch):
+
+ if isinstance(pred, dict): # for ABINet
+ loss = {}
+ loss_sum = []
+ for name, logits in pred.items():
+ if isinstance(logits, list):
+ logit_num = len(logits)
+ all_tgt = paddle.concat([batch[1]] * logit_num, 0)
+ all_logits = paddle.concat(logits, 0)
+ flt_logits = all_logits.reshape([-1, all_logits.shape[2]])
+ flt_tgt = all_tgt.reshape([-1])
+ else:
+ flt_logits = logits.reshape([-1, logits.shape[2]])
+ flt_tgt = batch[1].reshape([-1])
+ loss[name + '_loss'] = self.loss_func(flt_logits, flt_tgt)
+ loss_sum.append(loss[name + '_loss'])
+ loss['loss'] = sum(loss_sum)
+ return loss
+ else:
+ if self.with_all: # for ViTSTR
+ tgt = batch[1]
+ pred = pred.reshape([-1, pred.shape[2]])
+ tgt = tgt.reshape([-1])
+ loss = self.loss_func(pred, tgt)
+ return {'loss': loss}
+ else: # for NRTR
+ max_len = batch[2].max()
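+ # drop the leading <s>; keep at most max_len characters plus the trailing </s>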
+ tgt = batch[1][:, 1:2 + max_len]
+ pred = pred.reshape([-1, pred.shape[2]])
+ tgt = tgt.reshape([-1])
+ if self.smoothing:
+ eps = 0.1
+ n_class = pred.shape[1]
+ one_hot = F.one_hot(tgt, pred.shape[1])
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (
+ n_class - 1)
+ log_prb = F.log_softmax(pred, axis=1)
+ non_pad_mask = paddle.not_equal(
+ tgt, paddle.zeros(
+ tgt.shape, dtype=tgt.dtype))
+ loss = -(one_hot * log_prb).sum(axis=1)
+ loss = loss.masked_select(non_pad_mask).mean()
+ else:
+ loss = self.loss_func(pred, tgt)
+ return {'loss': loss}
diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py
deleted file mode 100644
index 200a6d0486dbf6f76dd674eb58f641b31a70f31c..0000000000000000000000000000000000000000
--- a/ppocr/losses/rec_nrtr_loss.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import paddle
-from paddle import nn
-import paddle.nn.functional as F
-
-
-class NRTRLoss(nn.Layer):
- def __init__(self, smoothing=True, **kwargs):
- super(NRTRLoss, self).__init__()
- self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
- self.smoothing = smoothing
-
- def forward(self, pred, batch):
- pred = pred.reshape([-1, pred.shape[2]])
- max_len = batch[2].max()
- tgt = batch[1][:, 1:2 + max_len]
- tgt = tgt.reshape([-1])
- if self.smoothing:
- eps = 0.1
- n_class = pred.shape[1]
- one_hot = F.one_hot(tgt, pred.shape[1])
- one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
- log_prb = F.log_softmax(pred, axis=1)
- non_pad_mask = paddle.not_equal(
- tgt, paddle.zeros(
- tgt.shape, dtype=tgt.dtype))
- loss = -(one_hot * log_prb).sum(axis=1)
- loss = loss.masked_select(non_pad_mask).mean()
- else:
- loss = self.loss_func(pred, tgt)
- return {'loss': loss}
diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py
index 51377efa2b5e802fe9f9fc1973c74deb00fc4816..3496c9072553d839017eaa017fe47dfb66fb9d3b 100644
--- a/ppocr/losses/table_att_loss.py
+++ b/ppocr/losses/table_att_loss.py
@@ -20,15 +20,21 @@ import paddle
from paddle import nn
from paddle.nn import functional as F
+
class TableAttentionLoss(nn.Layer):
- def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
+ def __init__(self,
+ structure_weight,
+ loc_weight,
+ use_giou=False,
+ giou_weight=1.0,
+ **kwargs):
super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
self.structure_weight = structure_weight
self.loc_weight = loc_weight
self.use_giou = use_giou
self.giou_weight = giou_weight
-
+
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
'''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
@@ -47,9 +53,10 @@ class TableAttentionLoss(nn.Layer):
inters = iw * ih
# union
- uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
- ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
- bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
+ uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
+ preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
+ ) * (bbox[:, 3] - bbox[:, 1] +
+ 1e-3) - inters + eps
# ious
ious = inters / uni
@@ -79,30 +86,34 @@ class TableAttentionLoss(nn.Layer):
structure_probs = predicts['structure_probs']
structure_targets = batch[1].astype("int64")
structure_targets = structure_targets[:, 1:]
- if len(batch) == 6:
- structure_mask = batch[5].astype("int64")
- structure_mask = structure_mask[:, 1:]
- structure_mask = paddle.reshape(structure_mask, [-1])
- structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
+ structure_probs = paddle.reshape(structure_probs,
+ [-1, structure_probs.shape[-1]])
structure_targets = paddle.reshape(structure_targets, [-1])
structure_loss = self.loss_func(structure_probs, structure_targets)
-
- if len(batch) == 6:
- structure_loss = structure_loss * structure_mask
-
-# structure_loss = paddle.sum(structure_loss) * self.structure_weight
+
structure_loss = paddle.mean(structure_loss) * self.structure_weight
-
+
loc_preds = predicts['loc_preds']
loc_targets = batch[2].astype("float32")
- loc_targets_mask = batch[4].astype("float32")
+ loc_targets_mask = batch[3].astype("float32")
loc_targets = loc_targets[:, 1:, :]
loc_targets_mask = loc_targets_mask[:, 1:, :]
- loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
+ loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
+ loc_targets) * self.loc_weight
if self.use_giou:
- loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
+ loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
+ loc_targets) * self.giou_weight
total_loss = structure_loss + loc_loss + loc_loss_giou
- return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
+ return {
+ 'loss': total_loss,
+ "structure_loss": structure_loss,
+ "loc_loss": loc_loss,
+ "loc_loss_giou": loc_loss_giou
+ }
else:
- total_loss = structure_loss + loc_loss
- return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
\ No newline at end of file
+ total_loss = structure_loss + loc_loss
+ return {
+ 'loss': total_loss,
+ "structure_loss": structure_loss,
+ "loc_loss": loc_loss
+ }
diff --git a/ppocr/losses/table_master_loss.py b/ppocr/losses/table_master_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca982dbd43e2c14f15503e1e98d6fe6c18878c5
--- /dev/null
+++ b/ppocr/losses/table_master_loss.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/JiaquanYe/TableMASTER-mmocr/tree/master/mmocr/models/textrecog/losses
+"""
+
+import paddle
+from paddle import nn
+
+
+class TableMasterLoss(nn.Layer):
+ def __init__(self, ignore_index=-1):
+ super(TableMasterLoss, self).__init__()
+ self.structure_loss = nn.CrossEntropyLoss(
+ ignore_index=ignore_index, reduction='mean')
+ self.box_loss = nn.L1Loss(reduction='sum')
+ self.eps = 1e-12
+
+ def forward(self, predicts, batch):
+ # structure_loss
+ structure_probs = predicts['structure_probs']
+ structure_targets = batch[1]
+ structure_targets = structure_targets[:, 1:]
+ structure_probs = structure_probs.reshape(
+ [-1, structure_probs.shape[-1]])
+ structure_targets = structure_targets.reshape([-1])
+
+ structure_loss = self.structure_loss(structure_probs, structure_targets)
+ structure_loss = structure_loss.mean()
+ losses = dict(structure_loss=structure_loss)
+
+ # box loss
+ bboxes_preds = predicts['loc_preds']
+ bboxes_targets = batch[2][:, 1:, :]
+ bbox_masks = batch[3][:, 1:]
+ # mask empty-bbox or non-bbox structure token's bbox.
+
+ masked_bboxes_preds = bboxes_preds * bbox_masks
+ masked_bboxes_targets = bboxes_targets * bbox_masks
+
+ # horizon loss (x and width)
+ horizon_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 0::2],
+ masked_bboxes_targets[:, :, 0::2])
+ horizon_loss = horizon_sum_loss / (bbox_masks.sum() + self.eps)
+ # vertical loss (y and height)
+ vertical_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 1::2],
+ masked_bboxes_targets[:, :, 1::2])
+ vertical_loss = vertical_sum_loss / (bbox_masks.sum() + self.eps)
+
+ horizon_loss = horizon_loss.mean()
+ vertical_loss = vertical_loss.mean()
+ all_loss = structure_loss + horizon_loss + vertical_loss
+ losses.update({
+ 'loss': all_loss,
+ 'horizon_bbox_loss': horizon_loss,
+ 'vertical_bbox_loss': vertical_loss
+ })
+ return losses
diff --git a/ppocr/losses/vqa_token_layoutlm_loss.py b/ppocr/losses/vqa_token_layoutlm_loss.py
index 244893d97d0e422c5ca270bdece689e13aba2b07..f9cd4634731a26dd990d6ffac3d8defc8cdf7e97 100755
--- a/ppocr/losses/vqa_token_layoutlm_loss.py
+++ b/ppocr/losses/vqa_token_layoutlm_loss.py
@@ -27,8 +27,8 @@ class VQASerTokenLayoutLMLoss(nn.Layer):
self.ignore_index = self.loss_class.ignore_index
def forward(self, predicts, batch):
- labels = batch[1]
- attention_mask = batch[4]
+ labels = batch[5]
+ attention_mask = batch[2]
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = predicts.reshape(
diff --git a/ppocr/metrics/eval_det_iou.py b/ppocr/metrics/eval_det_iou.py
index bc05e7df7d1d21abfb9d9fbd224ecd7254d9f393..c144886b3f84a458a88931d6beb2153054eba7d0 100644
--- a/ppocr/metrics/eval_det_iou.py
+++ b/ppocr/metrics/eval_det_iou.py
@@ -83,14 +83,10 @@ class DetectionIoUEvaluator(object):
evaluationLog = ""
- # print(len(gt))
for n in range(len(gt)):
points = gt[n]['points']
- # transcription = gt[n]['text']
dontCare = gt[n]['ignore']
- # points = Polygon(points)
- # points = points.buffer(0)
- if not Polygon(points).is_valid or not Polygon(points).is_simple:
+ if not Polygon(points).is_valid:
continue
gtPol = points
@@ -105,9 +101,7 @@ class DetectionIoUEvaluator(object):
for n in range(len(pred)):
points = pred[n]['points']
- # points = Polygon(points)
- # points = points.buffer(0)
- if not Polygon(points).is_valid or not Polygon(points).is_simple:
+ if not Polygon(points).is_valid:
continue
detPol = points
@@ -191,8 +185,6 @@ class DetectionIoUEvaluator(object):
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
methodRecall * methodPrecision / (
methodRecall + methodPrecision)
- # print(methodRecall, methodPrecision, methodHmean)
- # sys.exit(-1)
methodMetrics = {
'precision': methodPrecision,
'recall': methodRecall,
diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py
index ca4d6474202b4e85cadf86ccb2fe2726c7fa9aeb..fd2631e442b8d111c64d5cf4b34ea9063d8c60dd 100644
--- a/ppocr/metrics/table_metric.py
+++ b/ppocr/metrics/table_metric.py
@@ -12,29 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
+from ppocr.metrics.det_metric import DetMetric
-class TableMetric(object):
- def __init__(self, main_indicator='acc', **kwargs):
+class TableStructureMetric(object):
+ def __init__(self, main_indicator='acc', eps=1e-6, **kwargs):
self.main_indicator = main_indicator
- self.eps = 1e-5
+ self.eps = eps
self.reset()
- def __call__(self, pred, batch, *args, **kwargs):
- structure_probs = pred['structure_probs'].numpy()
- structure_labels = batch[1]
+ def __call__(self, pred_label, batch=None, *args, **kwargs):
+ preds, labels = pred_label
+ pred_structure_batch_list = preds['structure_batch_list']
+ gt_structure_batch_list = labels['structure_batch_list']
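+ # exact-match accuracy: a sample counts as correct only if the joined token strings are identical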
correct_num = 0
all_num = 0
- structure_probs = np.argmax(structure_probs, axis=2)
- structure_labels = structure_labels[:, 1:]
- batch_size = structure_probs.shape[0]
- for bno in range(batch_size):
- all_num += 1
- if (structure_probs[bno] == structure_labels[bno]).all():
+ for (pred, pred_conf), target in zip(pred_structure_batch_list,
+ gt_structure_batch_list):
+ pred_str = ''.join(pred)
+ target_str = ''.join(target)
+ if pred_str == target_str:
correct_num += 1
+ all_num += 1
self.correct_num += correct_num
self.all_num += all_num
- return {'acc': correct_num * 1.0 / (all_num + self.eps), }
def get_metric(self):
"""
@@ -49,3 +50,89 @@ class TableMetric(object):
def reset(self):
self.correct_num = 0
self.all_num = 0
+ self.len_acc_num = 0
+ self.token_nums = 0
+ self.anys_dict = dict()
+
+
+class TableMetric(object):
+ def __init__(self,
+ main_indicator='acc',
+ compute_bbox_metric=False,
+ point_num=2,
+ **kwargs):
+ """
+        @param main_indicator: metric key used to select the best model
+        @param compute_bbox_metric: whether to also evaluate cell boxes with DetMetric
+        @param point_num: number of points per box (2 or 4)
+        @param kwargs:
+ """
+ self.structure_metric = TableStructureMetric()
+ self.bbox_metric = DetMetric() if compute_bbox_metric else None
+ self.main_indicator = main_indicator
+ self.point_num = point_num
+ self.reset()
+
+ def __call__(self, pred_label, batch=None, *args, **kwargs):
+ self.structure_metric(pred_label)
+ if self.bbox_metric is not None:
+ self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))
+
+ def prepare_bbox_metric_input(self, pred_label):
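+        # Repack per-image predicted and ground-truth cell boxes into the
+        # (preds, batch) arguments expected by DetMetric.__call__.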
+ pred_bbox_batch_list = []
+ gt_ignore_tags_batch_list = []
+ gt_bbox_batch_list = []
+ preds, labels = pred_label
+
+ batch_num = len(preds['bbox_batch_list'])
+ for batch_idx in range(batch_num):
+ # pred
+ pred_bbox_list = [
+ self.format_box(pred_box)
+ for pred_box in preds['bbox_batch_list'][batch_idx]
+ ]
+ pred_bbox_batch_list.append({'points': pred_bbox_list})
+
+ # gt
+ gt_bbox_list = []
+ gt_ignore_tags_list = []
+ for gt_box in labels['bbox_batch_list'][batch_idx]:
+ gt_bbox_list.append(self.format_box(gt_box))
+ gt_ignore_tags_list.append(0)
+ gt_bbox_batch_list.append(gt_bbox_list)
+ gt_ignore_tags_batch_list.append(gt_ignore_tags_list)
+
+ return [
+ pred_bbox_batch_list,
+ [0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list]
+ ]
+
+ def get_metric(self):
+ structure_metric = self.structure_metric.get_metric()
+ if self.bbox_metric is None:
+ return structure_metric
+ bbox_metric = self.bbox_metric.get_metric()
+ if self.main_indicator == self.bbox_metric.main_indicator:
+ output = bbox_metric
+ for sub_key in structure_metric:
+ output["structure_metric_{}".format(
+ sub_key)] = structure_metric[sub_key]
+ else:
+ output = structure_metric
+ for sub_key in bbox_metric:
+ output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key]
+ return output
+
+ def reset(self):
+ self.structure_metric.reset()
+ if self.bbox_metric is not None:
+ self.bbox_metric.reset()
+
+ def format_box(self, box):
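+        # Expand a flat 2-point (x1, y1, x2, y2) or 4-point box into a list of
+        # four (x, y) corner points.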
+ if self.point_num == 2:
+ x1, y1, x2, y2 = box
+ box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
+ elif self.point_num == 4:
+ x1, y1, x2, y2, x3, y3, x4, y4 = box
+ box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
+ return box
diff --git a/ppocr/metrics/vqa_token_re_metric.py b/ppocr/metrics/vqa_token_re_metric.py
index 8a13bc081298284194d365933cd67d5633957ee8..f84387d8beb729bcc4b420ceea24a5e9b2993c64 100644
--- a/ppocr/metrics/vqa_token_re_metric.py
+++ b/ppocr/metrics/vqa_token_re_metric.py
@@ -37,23 +37,26 @@ class VQAReTokenMetric(object):
gt_relations = []
for b in range(len(self.relations_list)):
rel_sent = []
- for head, tail in zip(self.relations_list[b]["head"],
- self.relations_list[b]["tail"]):
- rel = {}
- rel["head_id"] = head
- rel["head"] = (self.entities_list[b]["start"][rel["head_id"]],
- self.entities_list[b]["end"][rel["head_id"]])
- rel["head_type"] = self.entities_list[b]["label"][rel[
- "head_id"]]
-
- rel["tail_id"] = tail
- rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
- self.entities_list[b]["end"][rel["tail_id"]])
- rel["tail_type"] = self.entities_list[b]["label"][rel[
- "tail_id"]]
-
- rel["type"] = 1
- rel_sent.append(rel)
+ if "head" in self.relations_list[b]:
+ for head, tail in zip(self.relations_list[b]["head"],
+ self.relations_list[b]["tail"]):
+ rel = {}
+ rel["head_id"] = head
+ rel["head"] = (
+ self.entities_list[b]["start"][rel["head_id"]],
+ self.entities_list[b]["end"][rel["head_id"]])
+ rel["head_type"] = self.entities_list[b]["label"][rel[
+ "head_id"]]
+
+ rel["tail_id"] = tail
+ rel["tail"] = (
+ self.entities_list[b]["start"][rel["tail_id"]],
+ self.entities_list[b]["end"][rel["tail_id"]])
+ rel["tail_type"] = self.entities_list[b]["label"][rel[
+ "tail_id"]]
+
+ rel["type"] = 1
+ rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = self.re_score(
self.pred_relations_list, gt_relations, mode="boundaries")
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 0cc894dbfbb44e7433d9e07a41ce2b9f5a6f4bca..f4094d796b1f14c955e5962936e86bd6b3f5ec78 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -18,9 +18,13 @@ __all__ = ["build_backbone"]
def build_backbone(config, model_type):
if model_type == "det" or model_type == "table":
from .det_mobilenet_v3 import MobileNetV3
- from .det_resnet_vd import ResNet
+ from .det_resnet import ResNet
+ from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST
- support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
+ support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
+ if model_type == "table":
+ from .table_master_resnet import TableResNetExtra
+ support_dict.append('TableResNetExtra')
elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
@@ -28,35 +32,37 @@ def build_backbone(config, model_type):
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
+ from .rec_resnet_45 import ResNet45
from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet
from .rec_efficientb3_pren import EfficientNetb3_PREN
from .rec_svtrnet import SVTRNet
+ from .rec_vitstr import ViTSTR
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
- "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
- 'SVTRNet', "ResNet31V2"
+ 'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
+ 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR'
]
- elif model_type == "e2e":
+ elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet']
elif model_type == 'kie':
from .kie_unet_sdmgr import Kie_backbone
support_dict = ['Kie_backbone']
- elif model_type == "table":
+ elif model_type == 'table':
from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3
- support_dict = ["ResNet", "MobileNetV3"]
+ support_dict = ['ResNet', 'MobileNetV3']
elif model_type == 'vqa':
from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
support_dict = [
- "LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe',
- "LayoutXLMForSer", 'LayoutXLMForRe'
+ 'LayoutLMForSer', 'LayoutLMv2ForSer', 'LayoutLMv2ForRe',
+ 'LayoutXLMForSer', 'LayoutXLMForRe'
]
else:
raise NotImplementedError
- module_name = config.pop("name")
+ module_name = config.pop('name')
assert module_name in support_dict, Exception(
"when model typs is {}, backbone only support {}".format(model_type,
support_dict))
diff --git a/ppocr/modeling/backbones/det_resnet.py b/ppocr/modeling/backbones/det_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..87eef11cf0e33c24c0f539c8074b21f589345282
--- /dev/null
+++ b/ppocr/modeling/backbones/det_resnet.py
@@ -0,0 +1,236 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
+from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
+from paddle.nn.initializer import Uniform
+
+import math
+
+from paddle.vision.ops import DeformConv2D
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import Normal, Constant, XavierUniform
+from .det_resnet_vd import DeformableConvV2, ConvBNLayer
+
+
+class BottleneckBlock(nn.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters,
+ stride,
+ shortcut=True,
+ is_dcn=False):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=1,
+ act="relu", )
+ self.conv1 = ConvBNLayer(
+ in_channels=num_filters,
+ out_channels=num_filters,
+ kernel_size=3,
+ stride=stride,
+ act="relu",
+ is_dcn=is_dcn,
+ dcn_groups=1, )
+ self.conv2 = ConvBNLayer(
+ in_channels=num_filters,
+ out_channels=num_filters * 4,
+ kernel_size=1,
+ act=None, )
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=num_channels,
+ out_channels=num_filters * 4,
+ kernel_size=1,
+ stride=stride, )
+
+ self.shortcut = shortcut
+
+ self._num_channels_out = num_filters * 4
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+ conv2 = self.conv2(conv1)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+
+ y = paddle.add(x=short, y=conv2)
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters,
+ stride,
+ shortcut=True,
+ name=None):
+ super(BasicBlock, self).__init__()
+ self.stride = stride
+ self.conv0 = ConvBNLayer(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=3,
+ stride=stride,
+ act="relu")
+ self.conv1 = ConvBNLayer(
+ in_channels=num_filters,
+ out_channels=num_filters,
+ kernel_size=3,
+ act=None)
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=1,
+ stride=stride)
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv1)
+ y = F.relu(y)
+ return y
+
+
+class ResNet(nn.Layer):
+ def __init__(self,
+ in_channels=3,
+ layers=50,
+ out_indices=None,
+ dcn_stage=None):
+ super(ResNet, self).__init__()
+
+ self.layers = layers
+ self.input_image_channel = in_channels
+
+ supported_layers = [18, 34, 50, 101, 152]
+ assert layers in supported_layers, \
+ "supported layers are {} but input layer is {}".format(
+ supported_layers, layers)
+
+ if layers == 18:
+ depth = [2, 2, 2, 2]
+ elif layers == 34 or layers == 50:
+ depth = [3, 4, 6, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ num_channels = [64, 256, 512,
+ 1024] if layers >= 50 else [64, 64, 128, 256]
+ num_filters = [64, 128, 256, 512]
+
+ self.dcn_stage = dcn_stage if dcn_stage is not None else [
+ False, False, False, False
+ ]
+ self.out_indices = out_indices if out_indices is not None else [
+ 0, 1, 2, 3
+ ]
+
+ self.conv = ConvBNLayer(
+ in_channels=self.input_image_channel,
+ out_channels=64,
+ kernel_size=7,
+ stride=2,
+ act="relu", )
+ self.pool2d_max = MaxPool2D(
+ kernel_size=3,
+ stride=2,
+ padding=1, )
+
+ self.stages = []
+ self.out_channels = []
+ if layers >= 50:
+ for block in range(len(depth)):
+ shortcut = False
+ block_list = []
+ is_dcn = self.dcn_stage[block]
+ for i in range(depth[block]):
+ if layers in [101, 152] and block == 2:
+ if i == 0:
+ conv_name = "res" + str(block + 2) + "a"
+ else:
+ conv_name = "res" + str(block + 2) + "b" + str(i)
+ else:
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ bottleneck_block = self.add_sublayer(
+ conv_name,
+ BottleneckBlock(
+ num_channels=num_channels[block]
+ if i == 0 else num_filters[block] * 4,
+ num_filters=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ is_dcn=is_dcn))
+ block_list.append(bottleneck_block)
+ shortcut = True
+ if block in self.out_indices:
+ self.out_channels.append(num_filters[block] * 4)
+ self.stages.append(nn.Sequential(*block_list))
+ else:
+ for block in range(len(depth)):
+ shortcut = False
+ block_list = []
+ for i in range(depth[block]):
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ basic_block = self.add_sublayer(
+ conv_name,
+ BasicBlock(
+ num_channels=num_channels[block]
+ if i == 0 else num_filters[block],
+ num_filters=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut))
+ block_list.append(basic_block)
+ shortcut = True
+ if block in self.out_indices:
+ self.out_channels.append(num_filters[block])
+ self.stages.append(nn.Sequential(*block_list))
+
+ def forward(self, inputs):
+ y = self.conv(inputs)
+ y = self.pool2d_max(y)
+ out = []
+ for i, block in enumerate(self.stages):
+ y = block(y)
+ if i in self.out_indices:
+ out.append(y)
+ return out
diff --git a/ppocr/modeling/backbones/det_resnet_vd.py b/ppocr/modeling/backbones/det_resnet_vd.py
index 8c955a4af377374f21e7c09f0d10952f2fe1ceed..a421da0ab440e9b87c1c7efc7d2448f8f76ad205 100644
--- a/ppocr/modeling/backbones/det_resnet_vd.py
+++ b/ppocr/modeling/backbones/det_resnet_vd.py
@@ -25,7 +25,7 @@ from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform
-__all__ = ["ResNet"]
+__all__ = ["ResNet_vd", "ConvBNLayer", "DeformableConvV2"]
class DeformableConvV2(nn.Layer):
@@ -104,6 +104,7 @@ class ConvBNLayer(nn.Layer):
kernel_size,
stride=1,
groups=1,
+ dcn_groups=1,
is_vd_mode=False,
act=None,
is_dcn=False):
@@ -128,7 +129,7 @@ class ConvBNLayer(nn.Layer):
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
- groups=2, #groups,
+ groups=dcn_groups, #groups,
bias_attr=False)
self._batch_norm = nn.BatchNorm(out_channels, act=act)
@@ -162,7 +163,8 @@ class BottleneckBlock(nn.Layer):
kernel_size=3,
stride=stride,
act='relu',
- is_dcn=is_dcn)
+ is_dcn=is_dcn,
+ dcn_groups=2)
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
@@ -238,14 +240,14 @@ class BasicBlock(nn.Layer):
return y
-class ResNet(nn.Layer):
+class ResNet_vd(nn.Layer):
def __init__(self,
in_channels=3,
layers=50,
dcn_stage=None,
out_indices=None,
**kwargs):
- super(ResNet, self).__init__()
+ super(ResNet_vd, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
@@ -321,7 +323,6 @@ class ResNet(nn.Layer):
for block in range(len(depth)):
block_list = []
shortcut = False
- # is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
diff --git a/ppocr/modeling/backbones/rec_resnet_31.py b/ppocr/modeling/backbones/rec_resnet_31.py
index e1d77c405dfed1541f6a0197af14d4edeb908803..b7990b67a248e2b22750f117a41e01820d3e83cc 100644
--- a/ppocr/modeling/backbones/rec_resnet_31.py
+++ b/ppocr/modeling/backbones/rec_resnet_31.py
@@ -27,7 +27,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
-__all__ = ["ResNet31V2"]
+__all__ = ["ResNet31"]
conv_weight_attr = nn.initializer.KaimingNormal()
diff --git a/ppocr/modeling/backbones/rec_resnet_45.py b/ppocr/modeling/backbones/rec_resnet_45.py
new file mode 100644
index 0000000000000000000000000000000000000000..9093d0bc99b78806d36662dec36b6cfbdd4ae493
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_45.py
@@ -0,0 +1,147 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/FangShancheng/ABINet/tree/main/modules
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+from paddle.nn.initializer import KaimingNormal
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+import math
+
+__all__ = ["ResNet45"]
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ return nn.Conv2D(
+ in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=1,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ bias_attr=False)
+
+
+def conv3x3(in_channel, out_channel, stride=1):
+ return nn.Conv2D(
+ in_channel,
+ out_channel,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ bias_attr=False)
+
+
+class BasicBlock(nn.Layer):
+ expansion = 1
+
+ def __init__(self, in_channels, channels, stride=1, downsample=None):
+ super().__init__()
+ self.conv1 = conv1x1(in_channels, channels)
+ self.bn1 = nn.BatchNorm2D(channels)
+ self.relu = nn.ReLU()
+ self.conv2 = conv3x3(channels, channels, stride)
+ self.bn2 = nn.BatchNorm2D(channels)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet45(nn.Layer):
+ def __init__(self, block=BasicBlock, layers=[3, 4, 6, 6, 3], in_channels=3):
+ self.inplanes = 32
+ super(ResNet45, self).__init__()
+ self.conv1 = nn.Conv2D(
+            in_channels,
+ 32,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm2D(32)
+ self.relu = nn.ReLU()
+
+ self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
+ self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
+ self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
+ self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
+ self.out_channels = 512
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2D(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ bias_attr=False),
+ nn.BatchNorm2D(planes * block.expansion), )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.layer5(x)
+ return x
diff --git a/ppocr/modeling/backbones/rec_svtrnet.py b/ppocr/modeling/backbones/rec_svtrnet.py
index c57bf46345d6e08f23b9258358f77f2285366314..c2c07f4476929d49237c8e9a10713f881f5f556b 100644
--- a/ppocr/modeling/backbones/rec_svtrnet.py
+++ b/ppocr/modeling/backbones/rec_svtrnet.py
@@ -147,7 +147,7 @@ class Attention(nn.Layer):
dim,
num_heads=8,
mixer='Global',
- HW=[8, 25],
+ HW=None,
local_k=[7, 11],
qkv_bias=False,
qk_scale=None,
@@ -210,7 +210,7 @@ class Block(nn.Layer):
num_heads,
mixer='Global',
local_mixer=[7, 11],
- HW=[8, 25],
+ HW=None,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
@@ -274,7 +274,9 @@ class PatchEmbed(nn.Layer):
img_size=[32, 100],
in_channels=3,
embed_dim=768,
- sub_num=2):
+ sub_num=2,
+ patch_size=[4, 4],
+ mode='pope'):
super().__init__()
num_patches = (img_size[1] // (2 ** sub_num)) * \
(img_size[0] // (2 ** sub_num))
@@ -282,50 +284,56 @@ class PatchEmbed(nn.Layer):
self.num_patches = num_patches
self.embed_dim = embed_dim
self.norm = None
- if sub_num == 2:
- self.proj = nn.Sequential(
- ConvBNLayer(
- in_channels=in_channels,
- out_channels=embed_dim // 2,
- kernel_size=3,
- stride=2,
- padding=1,
- act=nn.GELU,
- bias_attr=None),
- ConvBNLayer(
- in_channels=embed_dim // 2,
- out_channels=embed_dim,
- kernel_size=3,
- stride=2,
- padding=1,
- act=nn.GELU,
- bias_attr=None))
- if sub_num == 3:
- self.proj = nn.Sequential(
- ConvBNLayer(
- in_channels=in_channels,
- out_channels=embed_dim // 4,
- kernel_size=3,
- stride=2,
- padding=1,
- act=nn.GELU,
- bias_attr=None),
- ConvBNLayer(
- in_channels=embed_dim // 4,
- out_channels=embed_dim // 2,
- kernel_size=3,
- stride=2,
- padding=1,
- act=nn.GELU,
- bias_attr=None),
- ConvBNLayer(
- in_channels=embed_dim // 2,
- out_channels=embed_dim,
- kernel_size=3,
- stride=2,
- padding=1,
- act=nn.GELU,
- bias_attr=None))
+ if mode == 'pope':
+ if sub_num == 2:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None))
+ if sub_num == 3:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None),
+ ConvBNLayer(
+ in_channels=embed_dim // 4,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None))
+ elif mode == 'linear':
+ self.proj = nn.Conv2D(
+                in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.num_patches = img_size[0] // patch_size[0] * img_size[
+ 1] // patch_size[1]
def forward(self, x):
B, C, H, W = x.shape
diff --git a/ppocr/modeling/backbones/rec_vitstr.py b/ppocr/modeling/backbones/rec_vitstr.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5d7d5148a1120e6f97a321b4135c6780c0c5db2
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_vitstr.py
@@ -0,0 +1,120 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/roatienza/deep-text-recognition-benchmark/blob/master/modules/vitstr.py
+"""
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from ppocr.modeling.backbones.rec_svtrnet import Block, PatchEmbed, zeros_, trunc_normal_, ones_
+
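+# (embed_dim, num_heads) presets for the ViTSTR tiny/small/base variants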
+scale_dim_heads = {'tiny': [192, 3], 'small': [384, 6], 'base': [768, 12]}
+
+
+class ViTSTR(nn.Layer):
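+    """ViT encoder for scene text recognition; the first `seqlen` output tokens
+    are returned as the character sequence features."""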
+ def __init__(self,
+ img_size=[224, 224],
+ in_channels=1,
+ scale='tiny',
+ seqlen=27,
+ patch_size=[16, 16],
+ embed_dim=None,
+ depth=12,
+ num_heads=None,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_path_rate=0.,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ norm_layer='nn.LayerNorm',
+ act_layer='nn.GELU',
+ epsilon=1e-6,
+ out_channels=None,
+ **kwargs):
+ super().__init__()
+ self.seqlen = seqlen
+ embed_dim = embed_dim if embed_dim is not None else scale_dim_heads[
+ scale][0]
+ num_heads = num_heads if num_heads is not None else scale_dim_heads[
+ scale][1]
+ out_channels = out_channels if out_channels is not None else embed_dim
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim,
+ patch_size=patch_size,
+ mode='linear')
+ num_patches = self.patch_embed.num_patches
+
+ self.pos_embed = self.create_parameter(
+ shape=[1, num_patches + 1, embed_dim], default_initializer=zeros_)
+ self.add_parameter("pos_embed", self.pos_embed)
+ self.cls_token = self.create_parameter(
+ shape=[1, 1, embed_dim], default_initializer=zeros_)
+ self.add_parameter("cls_token", self.cls_token)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = np.linspace(0, drop_path_rate, depth)
+ self.blocks = nn.LayerList([
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=eval(act_layer),
+ epsilon=epsilon,
+ prenorm=False) for i in range(depth)
+ ])
+ self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
+
+ self.out_channels = out_channels
+
+ trunc_normal_(self.pos_embed)
+ trunc_normal_(self.cls_token)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+ cls_tokens = paddle.tile(self.cls_token, repeat_times=[B, 1, 1])
+ x = paddle.concat((cls_tokens, x), axis=1)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+ for blk in self.blocks:
+ x = blk(x)
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = x[:, :self.seqlen]
+ return x.transpose([0, 2, 1]).unsqueeze(2)
diff --git a/ppocr/modeling/backbones/table_master_resnet.py b/ppocr/modeling/backbones/table_master_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..dacf5ed26e5374b3c93c1a983be1d7b5b4c471fc
--- /dev/null
+++ b/ppocr/modeling/backbones/table_master_resnet.py
@@ -0,0 +1,369 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/backbones/table_resnet_extra.py
+"""
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class BasicBlock(nn.Layer):
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ downsample=None,
+ gcb_config=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2D(
+ inplanes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm2D(planes, momentum=0.9)
+ self.relu = nn.ReLU()
+ self.conv2 = nn.Conv2D(
+ planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ self.bn2 = nn.BatchNorm2D(planes, momentum=0.9)
+ self.downsample = downsample
+ self.stride = stride
+ self.gcb_config = gcb_config
+
+ if self.gcb_config is not None:
+ gcb_ratio = gcb_config['ratio']
+ gcb_headers = gcb_config['headers']
+ att_scale = gcb_config['att_scale']
+ fusion_type = gcb_config['fusion_type']
+ self.context_block = MultiAspectGCAttention(
+ inplanes=planes,
+ ratio=gcb_ratio,
+ headers=gcb_headers,
+ att_scale=att_scale,
+ fusion_type=fusion_type)
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.gcb_config is not None:
+ out = self.context_block(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+def get_gcb_config(gcb_config, layer):
+ if gcb_config is None or not gcb_config['layers'][layer]:
+ return None
+ else:
+ return gcb_config
+
+
+class TableResNetExtra(nn.Layer):
+ def __init__(self, layers, in_channels=3, gcb_config=None):
+ assert len(layers) >= 4
+
+ super(TableResNetExtra, self).__init__()
+ self.inplanes = 128
+ self.conv1 = nn.Conv2D(
+ in_channels,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm2D(64)
+ self.relu1 = nn.ReLU()
+
+ self.conv2 = nn.Conv2D(
+ 64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ self.bn2 = nn.BatchNorm2D(128)
+ self.relu2 = nn.ReLU()
+
+ self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2)
+
+ self.layer1 = self._make_layer(
+ BasicBlock,
+ 256,
+ layers[0],
+ stride=1,
+ gcb_config=get_gcb_config(gcb_config, 0))
+
+ self.conv3 = nn.Conv2D(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ self.bn3 = nn.BatchNorm2D(256)
+ self.relu3 = nn.ReLU()
+
+ self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2)
+
+ self.layer2 = self._make_layer(
+ BasicBlock,
+ 256,
+ layers[1],
+ stride=1,
+ gcb_config=get_gcb_config(gcb_config, 1))
+
+ self.conv4 = nn.Conv2D(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ self.bn4 = nn.BatchNorm2D(256)
+ self.relu4 = nn.ReLU()
+
+ self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2)
+
+ self.layer3 = self._make_layer(
+ BasicBlock,
+ 512,
+ layers[2],
+ stride=1,
+ gcb_config=get_gcb_config(gcb_config, 2))
+
+ self.conv5 = nn.Conv2D(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ self.bn5 = nn.BatchNorm2D(512)
+ self.relu5 = nn.ReLU()
+
+ self.layer4 = self._make_layer(
+ BasicBlock,
+ 512,
+ layers[3],
+ stride=1,
+ gcb_config=get_gcb_config(gcb_config, 3))
+
+ self.conv6 = nn.Conv2D(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
+ self.bn6 = nn.BatchNorm2D(512)
+ self.relu6 = nn.ReLU()
+
+ self.out_channels = [256, 256, 512]
+
+ def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2D(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias_attr=False),
+ nn.BatchNorm2D(planes * block.expansion), )
+
+ layers = []
+ layers.append(
+ block(
+ self.inplanes,
+ planes,
+ stride,
+ downsample,
+ gcb_config=gcb_config))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ f = []
+ x = self.conv1(x)
+
+ x = self.bn1(x)
+ x = self.relu1(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu2(x)
+
+ x = self.maxpool1(x)
+ x = self.layer1(x)
+
+ x = self.conv3(x)
+ x = self.bn3(x)
+ x = self.relu3(x)
+ f.append(x)
+
+ x = self.maxpool2(x)
+ x = self.layer2(x)
+
+ x = self.conv4(x)
+ x = self.bn4(x)
+ x = self.relu4(x)
+ f.append(x)
+
+ x = self.maxpool3(x)
+
+ x = self.layer3(x)
+ x = self.conv5(x)
+ x = self.bn5(x)
+ x = self.relu5(x)
+
+ x = self.layer4(x)
+ x = self.conv6(x)
+ x = self.bn6(x)
+ x = self.relu6(x)
+ f.append(x)
+ return f
+
+
+class MultiAspectGCAttention(nn.Layer):
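+    """Multi-head (multi-aspect) global-context attention block used in the
+    TableMaster ResNet backbone."""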
+ def __init__(self,
+ inplanes,
+ ratio,
+ headers,
+ pooling_type='att',
+ att_scale=False,
+ fusion_type='channel_add'):
+ super(MultiAspectGCAttention, self).__init__()
+ assert pooling_type in ['avg', 'att']
+
+ assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
+ assert inplanes % headers == 0 and inplanes >= 8 # inplanes must be divided by headers evenly
+
+ self.headers = headers
+ self.inplanes = inplanes
+ self.ratio = ratio
+ self.planes = int(inplanes * ratio)
+ self.pooling_type = pooling_type
+ self.fusion_type = fusion_type
+        self.att_scale = att_scale
+
+ self.single_header_inplanes = int(inplanes / headers)
+
+ if pooling_type == 'att':
+ self.conv_mask = nn.Conv2D(
+ self.single_header_inplanes, 1, kernel_size=1)
+ self.softmax = nn.Softmax(axis=2)
+ else:
+ self.avg_pool = nn.AdaptiveAvgPool2D(1)
+
+ if fusion_type == 'channel_add':
+ self.channel_add_conv = nn.Sequential(
+ nn.Conv2D(
+ self.inplanes, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(),
+ nn.Conv2D(
+ self.planes, self.inplanes, kernel_size=1))
+ elif fusion_type == 'channel_concat':
+ self.channel_concat_conv = nn.Sequential(
+ nn.Conv2D(
+ self.inplanes, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(),
+ nn.Conv2D(
+ self.planes, self.inplanes, kernel_size=1))
+ # for concat
+ self.cat_conv = nn.Conv2D(
+ 2 * self.inplanes, self.inplanes, kernel_size=1)
+ elif fusion_type == 'channel_mul':
+ self.channel_mul_conv = nn.Sequential(
+ nn.Conv2D(
+ self.inplanes, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(),
+ nn.Conv2D(
+ self.planes, self.inplanes, kernel_size=1))
+
+ def spatial_pool(self, x):
+ batch, channel, height, width = x.shape
+ if self.pooling_type == 'att':
+ # [N*headers, C', H , W] C = headers * C'
+ x = x.reshape([
+ batch * self.headers, self.single_header_inplanes, height, width
+ ])
+ input_x = x
+
+ # [N*headers, C', H * W] C = headers * C'
+ # input_x = input_x.view(batch, channel, height * width)
+ input_x = input_x.reshape([
+ batch * self.headers, self.single_header_inplanes,
+ height * width
+ ])
+
+ # [N*headers, 1, C', H * W]
+ input_x = input_x.unsqueeze(1)
+ # [N*headers, 1, H, W]
+ context_mask = self.conv_mask(x)
+ # [N*headers, 1, H * W]
+ context_mask = context_mask.reshape(
+ [batch * self.headers, 1, height * width])
+
+ # scale variance
+ if self.att_scale and self.headers > 1:
+ context_mask = context_mask / paddle.sqrt(
+ self.single_header_inplanes)
+
+ # [N*headers, 1, H * W]
+ context_mask = self.softmax(context_mask)
+
+ # [N*headers, 1, H * W, 1]
+ context_mask = context_mask.unsqueeze(-1)
+ # [N*headers, 1, C', 1] = [N*headers, 1, C', H * W] * [N*headers, 1, H * W, 1]
+ context = paddle.matmul(input_x, context_mask)
+
+ # [N, headers * C', 1, 1]
+ context = context.reshape(
+ [batch, self.headers * self.single_header_inplanes, 1, 1])
+ else:
+ # [N, C, 1, 1]
+ context = self.avg_pool(x)
+
+ return context
+
+ def forward(self, x):
+ # [N, C, 1, 1]
+ context = self.spatial_pool(x)
+
+ out = x
+
+ if self.fusion_type == 'channel_mul':
+ # [N, C, 1, 1]
+ channel_mul_term = F.sigmoid(self.channel_mul_conv(context))
+ out = out * channel_mul_term
+ elif self.fusion_type == 'channel_add':
+ # [N, C, 1, 1]
+ channel_add_term = self.channel_add_conv(context)
+ out = out + channel_add_term
+ else:
+ # [N, C, 1, 1]
+ channel_concat_term = self.channel_concat_conv(context)
+
+ # use concat
+ _, C1, _, _ = channel_concat_term.shape
+ N, C2, H, W = out.shape
+
+ out = paddle.concat(
+ [out, channel_concat_term.expand([-1, -1, H, W])], axis=1)
+ out = self.cat_conv(out)
+ out = F.layer_norm(out, [self.inplanes, H, W])
+ out = F.relu(out)
+
+ return out
diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
index ede5b7a35af65fac351277cefccd89b251f5cdb7..34dd9d10ea36758059448d96674d4d2c249d3ad0 100644
--- a/ppocr/modeling/backbones/vqa_layoutlm.py
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -43,9 +43,11 @@ class NLPBaseModel(nn.Layer):
super(NLPBaseModel, self).__init__()
if checkpoints is not None:
self.model = model_class.from_pretrained(checkpoints)
+ elif isinstance(pretrained, (str, )) and os.path.exists(pretrained):
+ self.model = model_class.from_pretrained(pretrained)
else:
pretrained_model_name = pretrained_model_dict[base_model_class]
- if pretrained:
+ if pretrained is True:
base_model = base_model_class.from_pretrained(
pretrained_model_name)
else:
@@ -74,9 +76,9 @@ class LayoutLMForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
- bbox=x[2],
- attention_mask=x[4],
- token_type_ids=x[5],
+ bbox=x[1],
+ attention_mask=x[2],
+ token_type_ids=x[3],
position_ids=None,
output_hidden_states=False)
return x
@@ -96,13 +98,15 @@ class LayoutLMv2ForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
- bbox=x[2],
- image=x[3],
- attention_mask=x[4],
- token_type_ids=x[5],
+ bbox=x[1],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
labels=None)
+ if not self.training:
+ return x
return x[0]
@@ -120,13 +124,15 @@ class LayoutXLMForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
- bbox=x[2],
- image=x[3],
- attention_mask=x[4],
- token_type_ids=x[5],
+ bbox=x[1],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
labels=None)
+ if not self.training:
+ return x
return x[0]
@@ -140,12 +146,12 @@ class LayoutLMv2ForRe(NLPBaseModel):
x = self.model(
input_ids=x[0],
bbox=x[1],
- labels=None,
- image=x[2],
- attention_mask=x[3],
- token_type_ids=x[4],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
+ labels=None,
entities=x[5],
relations=x[6])
return x
@@ -161,12 +167,12 @@ class LayoutXLMForRe(NLPBaseModel):
x = self.model(
input_ids=x[0],
bbox=x[1],
- labels=None,
- image=x[2],
- attention_mask=x[3],
- token_type_ids=x[4],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
+ labels=None,
entities=x[5],
relations=x[6])
return x
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index fd2d89315b3c8ef6f9b5edc418b80249bc8d20a0..99cb59e6725d01af4483ec21ff43fb3d3d7b5ae7 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -34,6 +34,7 @@ def build_head(config):
from .rec_pren_head import PRENHead
from .rec_multi_head import MultiHead
from .rec_robustscanner_head import RobustScannerHead
+ from .rec_abinet_head import ABINetHead
# cls head
from .cls_head import ClsHead
@@ -42,12 +43,13 @@ def build_head(config):
from .kie_sdmgr_head import SDMGRHead
from .table_att_head import TableAttentionHead
+ from .table_master_head import TableMasterHead
support_dict = [
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
- 'MultiHead', 'RobustScannerHead'
+ 'MultiHead', 'ABINetHead', 'TableMasterHead', 'RobustScannerHead'
]
#table head
diff --git a/ppocr/modeling/heads/multiheadAttention.py b/ppocr/modeling/heads/multiheadAttention.py
deleted file mode 100755
index 900865ba1a8d80a108b3247ce1aff91c242860f2..0000000000000000000000000000000000000000
--- a/ppocr/modeling/heads/multiheadAttention.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import paddle
-from paddle import nn
-import paddle.nn.functional as F
-from paddle.nn import Linear
-from paddle.nn.initializer import XavierUniform as xavier_uniform_
-from paddle.nn.initializer import Constant as constant_
-from paddle.nn.initializer import XavierNormal as xavier_normal_
-
-zeros_ = constant_(value=0.)
-ones_ = constant_(value=1.)
-
-
-class MultiheadAttention(nn.Layer):
- """Allows the model to jointly attend to information
- from different representation subspaces.
- See reference: Attention Is All You Need
-
- .. math::
- \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
- \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
-
- Args:
- embed_dim: total dimension of the model
- num_heads: parallel attention layers, or heads
-
- """
-
- def __init__(self,
- embed_dim,
- num_heads,
- dropout=0.,
- bias=True,
- add_bias_kv=False,
- add_zero_attn=False):
- super(MultiheadAttention, self).__init__()
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.dropout = dropout
- self.head_dim = embed_dim // num_heads
- assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
- self.scaling = self.head_dim**-0.5
- self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
- self._reset_parameters()
- self.conv1 = paddle.nn.Conv2D(
- in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
- self.conv2 = paddle.nn.Conv2D(
- in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
- self.conv3 = paddle.nn.Conv2D(
- in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
-
- def _reset_parameters(self):
- xavier_uniform_(self.out_proj.weight)
-
- def forward(self,
- query,
- key,
- value,
- key_padding_mask=None,
- incremental_state=None,
- attn_mask=None):
- """
- Inputs of forward function
- query: [target length, batch size, embed dim]
- key: [sequence length, batch size, embed dim]
- value: [sequence length, batch size, embed dim]
- key_padding_mask: if True, mask padding based on batch size
- incremental_state: if provided, previous time steps are cashed
- need_weights: output attn_output_weights
- static_kv: key and value are static
-
- Outputs of forward function
- attn_output: [target length, batch size, embed dim]
- attn_output_weights: [batch size, target length, sequence length]
- """
- q_shape = paddle.shape(query)
- src_shape = paddle.shape(key)
- q = self._in_proj_q(query)
- k = self._in_proj_k(key)
- v = self._in_proj_v(value)
- q *= self.scaling
- q = paddle.transpose(
- paddle.reshape(
- q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
- [1, 2, 0, 3])
- k = paddle.transpose(
- paddle.reshape(
- k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
- [1, 2, 0, 3])
- v = paddle.transpose(
- paddle.reshape(
- v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
- [1, 2, 0, 3])
- if key_padding_mask is not None:
- assert key_padding_mask.shape[0] == q_shape[1]
- assert key_padding_mask.shape[1] == src_shape[0]
- attn_output_weights = paddle.matmul(q,
- paddle.transpose(k, [0, 1, 3, 2]))
- if attn_mask is not None:
- attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
- attn_output_weights += attn_mask
- if key_padding_mask is not None:
- attn_output_weights = paddle.reshape(
- attn_output_weights,
- [q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
- key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
- key = paddle.cast(key, 'float32')
- y = paddle.full(
- shape=paddle.shape(key), dtype='float32', fill_value='-inf')
- y = paddle.where(key == 0., key, y)
- attn_output_weights += y
- attn_output_weights = F.softmax(
- attn_output_weights.astype('float32'),
- axis=-1,
- dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
- else attn_output_weights.dtype)
- attn_output_weights = F.dropout(
- attn_output_weights, p=self.dropout, training=self.training)
-
- attn_output = paddle.matmul(attn_output_weights, v)
- attn_output = paddle.reshape(
- paddle.transpose(attn_output, [2, 0, 1, 3]),
- [q_shape[0], q_shape[1], self.embed_dim])
- attn_output = self.out_proj(attn_output)
-
- return attn_output
-
- def _in_proj_q(self, query):
- query = paddle.transpose(query, [1, 2, 0])
- query = paddle.unsqueeze(query, axis=2)
- res = self.conv1(query)
- res = paddle.squeeze(res, axis=2)
- res = paddle.transpose(res, [2, 0, 1])
- return res
-
- def _in_proj_k(self, key):
- key = paddle.transpose(key, [1, 2, 0])
- key = paddle.unsqueeze(key, axis=2)
- res = self.conv2(key)
- res = paddle.squeeze(res, axis=2)
- res = paddle.transpose(res, [2, 0, 1])
- return res
-
- def _in_proj_v(self, value):
- value = paddle.transpose(value, [1, 2, 0]) #(1, 2, 0)
- value = paddle.unsqueeze(value, axis=2)
- res = self.conv3(value)
- res = paddle.squeeze(res, axis=2)
- res = paddle.transpose(res, [2, 0, 1])
- return res
diff --git a/ppocr/modeling/heads/rec_abinet_head.py b/ppocr/modeling/heads/rec_abinet_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0f60f1be1727e85380eedb7d311ce9445f88b8e
--- /dev/null
+++ b/ppocr/modeling/heads/rec_abinet_head.py
@@ -0,0 +1,296 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/FangShancheng/ABINet/tree/main/modules
+"""
+
+import math
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle.nn import LayerList
+from ppocr.modeling.heads.rec_nrtr_head import TransformerBlock, PositionalEncoding
+
+
+class BCNLanguage(nn.Layer):
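+    """Bidirectional cloze network (BCN) language model from ABINet: refines the
+    vision predictions with a stack of cross-attention transformer blocks whose
+    queries are positional encodings."""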
+ def __init__(self,
+ d_model=512,
+ nhead=8,
+ num_layers=4,
+ dim_feedforward=2048,
+ dropout=0.,
+ max_length=25,
+ detach=True,
+ num_classes=37):
+ super().__init__()
+
+ self.d_model = d_model
+ self.detach = detach
+ self.max_length = max_length + 1 # additional stop token
+ self.proj = nn.Linear(num_classes, d_model, bias_attr=False)
+ self.token_encoder = PositionalEncoding(
+ dropout=0.1, dim=d_model, max_len=self.max_length)
+ self.pos_encoder = PositionalEncoding(
+ dropout=0, dim=d_model, max_len=self.max_length)
+
+ self.decoder = nn.LayerList([
+ TransformerBlock(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ attention_dropout_rate=dropout,
+ residual_dropout_rate=dropout,
+ with_self_attn=False,
+ with_cross_attn=True) for i in range(num_layers)
+ ])
+
+ self.cls = nn.Linear(d_model, num_classes)
+
+ def forward(self, tokens, lengths):
+ """
+ Args:
+ tokens: (B, N, C) where N is length, B is batch size and C is classes number
+ lengths: (B,)
+ """
+ if self.detach: tokens = tokens.detach()
+ embed = self.proj(tokens) # (B, N, C)
+ embed = self.token_encoder(embed) # (B, N, C)
+ padding_mask = _get_mask(lengths, self.max_length)
+ zeros = paddle.zeros_like(embed) # (B, N, C)
+        query = self.pos_encoder(zeros)
+        for decoder_layer in self.decoder:
+            query = decoder_layer(query, embed, cross_mask=padding_mask)
+        output = query  # (B, N, C)
+
+ logits = self.cls(output) # (B, N, C)
+
+ return output, logits
+
+
+def encoder_layer(in_c, out_c, k=3, s=2, p=1):
+ return nn.Sequential(
+ nn.Conv2D(in_c, out_c, k, s, p), nn.BatchNorm2D(out_c), nn.ReLU())
+
+
+def decoder_layer(in_c,
+ out_c,
+ k=3,
+ s=1,
+ p=1,
+ mode='nearest',
+ scale_factor=None,
+ size=None):
+ align_corners = False if mode == 'nearest' else True
+ return nn.Sequential(
+ nn.Upsample(
+ size=size,
+ scale_factor=scale_factor,
+ mode=mode,
+ align_corners=align_corners),
+ nn.Conv2D(in_c, out_c, k, s, p),
+ nn.BatchNorm2D(out_c),
+ nn.ReLU())
+
+
+class PositionAttention(nn.Layer):
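+    """Position attention decoder of ABINet's vision branch: a small U-Net over
+    the feature map produces the keys, and positional encodings of zero vectors
+    serve as the queries."""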
+ def __init__(self,
+ max_length,
+ in_channels=512,
+ num_channels=64,
+ h=8,
+ w=32,
+ mode='nearest',
+ **kwargs):
+ super().__init__()
+ self.max_length = max_length
+ self.k_encoder = nn.Sequential(
+ encoder_layer(
+ in_channels, num_channels, s=(1, 2)),
+ encoder_layer(
+ num_channels, num_channels, s=(2, 2)),
+ encoder_layer(
+ num_channels, num_channels, s=(2, 2)),
+ encoder_layer(
+ num_channels, num_channels, s=(2, 2)))
+ self.k_decoder = nn.Sequential(
+ decoder_layer(
+ num_channels, num_channels, scale_factor=2, mode=mode),
+ decoder_layer(
+ num_channels, num_channels, scale_factor=2, mode=mode),
+ decoder_layer(
+ num_channels, num_channels, scale_factor=2, mode=mode),
+ decoder_layer(
+ num_channels, in_channels, size=(h, w), mode=mode))
+
+ self.pos_encoder = PositionalEncoding(
+ dropout=0, dim=in_channels, max_len=max_length)
+ self.project = nn.Linear(in_channels, in_channels)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ k, v = x, x
+
+ # calculate key vector
+ features = []
+ for i in range(0, len(self.k_encoder)):
+ k = self.k_encoder[i](k)
+ features.append(k)
+ for i in range(0, len(self.k_decoder) - 1):
+ k = self.k_decoder[i](k)
+ k = k + features[len(self.k_decoder) - 2 - i]
+ k = self.k_decoder[-1](k)
+
+ # calculate query vector
+ # TODO q=f(q,k)
+ zeros = paddle.zeros(
+            (B, self.max_length, C), dtype=x.dtype)  # (B, N, C)
+ q = self.pos_encoder(zeros) # (B, N, C)
+ q = self.project(q) # (B, N, C)
+
+ # calculate attention
+        attn_scores = q @ k.flatten(2)  # (B, N, (H*W))
+ attn_scores = attn_scores / (C**0.5)
+ attn_scores = F.softmax(attn_scores, axis=-1)
+
+ v = v.flatten(2).transpose([0, 2, 1]) # (B, (H*W), C)
+        attn_vecs = attn_scores @ v  # (B, N, C)
+
+ return attn_vecs, attn_scores.reshape([0, self.max_length, H, W])
+
+
+class ABINetHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ d_model=512,
+ nhead=8,
+ num_layers=3,
+ dim_feedforward=2048,
+ dropout=0.1,
+ max_length=25,
+ use_lang=False,
+ iter_size=1):
+ super().__init__()
+ self.max_length = max_length + 1
+ self.pos_encoder = PositionalEncoding(
+ dropout=0.1, dim=d_model, max_len=8 * 32)
+ self.encoder = nn.LayerList([
+ TransformerBlock(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ attention_dropout_rate=dropout,
+ residual_dropout_rate=dropout,
+ with_self_attn=True,
+ with_cross_attn=False) for i in range(num_layers)
+ ])
+ self.decoder = PositionAttention(
+ max_length=max_length + 1, # additional stop token
+ mode='nearest', )
+ self.out_channels = out_channels
+ self.cls = nn.Linear(d_model, self.out_channels)
+ self.use_lang = use_lang
+ if use_lang:
+ self.iter_size = iter_size
+ self.language = BCNLanguage(
+ d_model=d_model,
+ nhead=nhead,
+ num_layers=4,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ max_length=max_length,
+ num_classes=self.out_channels)
+ # alignment
+ self.w_att_align = nn.Linear(2 * d_model, d_model)
+ self.cls_align = nn.Linear(d_model, self.out_channels)
+
+ def forward(self, x, targets=None):
+ x = x.transpose([0, 2, 3, 1])
+ _, H, W, C = x.shape
+ feature = x.flatten(1, 2)
+ feature = self.pos_encoder(feature)
+ for encoder_layer in self.encoder:
+ feature = encoder_layer(feature)
+ feature = feature.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
+ v_feature, attn_scores = self.decoder(
+ feature) # (B, N, C), (B, C, H, W)
+ vis_logits = self.cls(v_feature) # (B, N, C)
+ logits = vis_logits
+ vis_lengths = _get_length(vis_logits)
+ if self.use_lang:
+ align_logits = vis_logits
+ align_lengths = vis_lengths
+ all_l_res, all_a_res = [], []
+ for i in range(self.iter_size):
+ tokens = F.softmax(align_logits, axis=-1)
+ lengths = align_lengths
+ lengths = paddle.clip(
+                    lengths, 2, self.max_length)  # TODO: move to language model
+ l_feature, l_logits = self.language(tokens, lengths)
+
+ # alignment
+ all_l_res.append(l_logits)
+ fuse = paddle.concat((l_feature, v_feature), -1)
+ f_att = F.sigmoid(self.w_att_align(fuse))
+ output = f_att * v_feature + (1 - f_att) * l_feature
+ align_logits = self.cls_align(output) # (B, N, C)
+
+ align_lengths = _get_length(align_logits)
+ all_a_res.append(align_logits)
+ if self.training:
+ return {
+ 'align': all_a_res,
+ 'lang': all_l_res,
+ 'vision': vis_logits
+ }
+ else:
+ logits = align_logits
+ if self.training:
+ return logits
+ else:
+ return F.softmax(logits, -1)
+
+
+def _get_length(logit):
+ """ Greed decoder to obtain length from logit"""
+ out = (logit.argmax(-1) == 0)
+ abn = out.any(-1)
+ out_int = out.cast('int32')
+ out = (out_int.cumsum(-1) == 1) & out
+ out = out.cast('int32')
+ out = out.argmax(-1)
+ out = out + 1
+ out = paddle.where(abn, out, paddle.to_tensor(logit.shape[1]))
+ return out
+
+
+def _get_mask(length, max_length):
+ """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+ Unmasked positions are filled with float(0.0).
+ """
+ length = length.unsqueeze(-1)
+ B = paddle.shape(length)[0]
+ grid = paddle.arange(0, max_length).unsqueeze(0).tile([B, 1])
+ zero_mask = paddle.zeros([B, max_length], dtype='float32')
+ inf_mask = paddle.full([B, max_length], '-inf', dtype='float32')
+ diag_mask = paddle.diag(
+ paddle.full(
+ [max_length], '-inf', dtype=paddle.float32),
+ offset=0,
+ name=None)
+ mask = paddle.where(grid >= length, inf_mask, zero_mask)
+ mask = mask.unsqueeze(1) + diag_mask
+ return mask.unsqueeze(1)
diff --git a/ppocr/modeling/heads/rec_nrtr_head.py b/ppocr/modeling/heads/rec_nrtr_head.py
index 38ba0c917840ea7d1e2a3c2bf0da32c2c35f2b40..bf9ef56145e6edfb15bd30235b4a62588396ba96 100644
--- a/ppocr/modeling/heads/rec_nrtr_head.py
+++ b/ppocr/modeling/heads/rec_nrtr_head.py
@@ -14,20 +14,15 @@
import math
import paddle
-import copy
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import LayerList
-from paddle.nn.initializer import XavierNormal as xavier_uniform_
-from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
+# from paddle.nn.initializer import XavierNormal as xavier_uniform_
+from paddle.nn import Dropout, Linear, LayerNorm
import numpy as np
-from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
-from paddle.nn.initializer import Constant as constant_
+from ppocr.modeling.backbones.rec_svtrnet import Mlp, zeros_, ones_
from paddle.nn.initializer import XavierNormal as xavier_normal_
-zeros_ = constant_(value=0.)
-ones_ = constant_(value=1.)
-
class Transformer(nn.Layer):
"""A transformer model. User is able to modify the attributes as needed. The architechture
@@ -45,7 +40,6 @@ class Transformer(nn.Layer):
dropout: the dropout value (default=0.1).
custom_encoder: custom encoder (default=None).
custom_decoder: custom decoder (default=None).
-
"""
def __init__(self,
@@ -54,45 +48,49 @@ class Transformer(nn.Layer):
num_encoder_layers=6,
beam_size=0,
num_decoder_layers=6,
+ max_len=25,
dim_feedforward=1024,
attention_dropout_rate=0.0,
residual_dropout_rate=0.1,
- custom_encoder=None,
- custom_decoder=None,
in_channels=0,
out_channels=0,
scale_embedding=True):
super(Transformer, self).__init__()
self.out_channels = out_channels + 1
+ self.max_len = max_len
self.embedding = Embeddings(
d_model=d_model,
vocab=self.out_channels,
padding_idx=0,
scale_embedding=scale_embedding)
self.positional_encoding = PositionalEncoding(
- dropout=residual_dropout_rate,
- dim=d_model, )
- if custom_encoder is not None:
- self.encoder = custom_encoder
- else:
- if num_encoder_layers > 0:
- encoder_layer = TransformerEncoderLayer(
- d_model, nhead, dim_feedforward, attention_dropout_rate,
- residual_dropout_rate)
- self.encoder = TransformerEncoder(encoder_layer,
- num_encoder_layers)
- else:
- self.encoder = None
-
- if custom_decoder is not None:
- self.decoder = custom_decoder
+ dropout=residual_dropout_rate, dim=d_model)
+
+ if num_encoder_layers > 0:
+ self.encoder = nn.LayerList([
+ TransformerBlock(
+ d_model,
+ nhead,
+ dim_feedforward,
+ attention_dropout_rate,
+ residual_dropout_rate,
+ with_self_attn=True,
+ with_cross_attn=False) for i in range(num_encoder_layers)
+ ])
else:
- decoder_layer = TransformerDecoderLayer(
- d_model, nhead, dim_feedforward, attention_dropout_rate,
- residual_dropout_rate)
- self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
+ self.encoder = None
+
+ self.decoder = nn.LayerList([
+ TransformerBlock(
+ d_model,
+ nhead,
+ dim_feedforward,
+ attention_dropout_rate,
+ residual_dropout_rate,
+ with_self_attn=True,
+ with_cross_attn=True) for i in range(num_decoder_layers)
+ ])
- self._reset_parameters()
self.beam_size = beam_size
self.d_model = d_model
self.nhead = nhead
@@ -105,7 +103,7 @@ class Transformer(nn.Layer):
def _init_weights(self, m):
- if isinstance(m, nn.Conv2D):
+ if isinstance(m, nn.Linear):
xavier_normal_(m.weight)
if m.bias is not None:
zeros_(m.bias)
@@ -113,24 +111,20 @@ class Transformer(nn.Layer):
def forward_train(self, src, tgt):
tgt = tgt[:, :-1]
- tgt_key_padding_mask = self.generate_padding_mask(tgt)
- tgt = self.embedding(tgt).transpose([1, 0, 2])
+ tgt = self.embedding(tgt)
tgt = self.positional_encoding(tgt)
- tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0])
+ tgt_mask = self.generate_square_subsequent_mask(tgt.shape[1])
if self.encoder is not None:
- src = self.positional_encoding(src.transpose([1, 0, 2]))
- memory = self.encoder(src)
+ src = self.positional_encoding(src)
+ for encoder_layer in self.encoder:
+ src = encoder_layer(src)
+ memory = src # B N C
else:
- memory = src.squeeze(2).transpose([2, 0, 1])
- output = self.decoder(
- tgt,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=None,
- tgt_key_padding_mask=tgt_key_padding_mask,
- memory_key_padding_mask=None)
- output = output.transpose([1, 0, 2])
+ memory = src # B N C
+ for decoder_layer in self.decoder:
+ tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
+ output = tgt
logit = self.tgt_word_prj(output)
return logit
@@ -140,8 +134,8 @@ class Transformer(nn.Layer):
src: the sequence to the encoder (required).
tgt: the sequence to the decoder (required).
Shape:
- - src: :math:`(S, N, E)`.
- - tgt: :math:`(T, N, E)`.
+ - src: :math:`(B, sN, C)`.
+ - tgt: :math:`(B, tN, C)`.
Examples:
>>> output = transformer_model(src, tgt)
"""
@@ -157,36 +151,35 @@ class Transformer(nn.Layer):
return self.forward_test(src)
def forward_test(self, src):
+
bs = paddle.shape(src)[0]
if self.encoder is not None:
- src = self.positional_encoding(paddle.transpose(src, [1, 0, 2]))
- memory = self.encoder(src)
+ src = self.positional_encoding(src)
+ for encoder_layer in self.encoder:
+ src = encoder_layer(src)
+ memory = src # B N C
else:
- memory = paddle.transpose(paddle.squeeze(src, 2), [2, 0, 1])
+ memory = src
dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
- for len_dec_seq in range(1, 25):
- dec_seq_embed = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
+ for len_dec_seq in range(1, self.max_len):
+ dec_seq_embed = self.embedding(dec_seq)
dec_seq_embed = self.positional_encoding(dec_seq_embed)
tgt_mask = self.generate_square_subsequent_mask(
- paddle.shape(dec_seq_embed)[0])
- output = self.decoder(
- dec_seq_embed,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=None,
- tgt_key_padding_mask=None,
- memory_key_padding_mask=None)
- dec_output = paddle.transpose(output, [1, 0, 2])
+ paddle.shape(dec_seq_embed)[1])
+ tgt = dec_seq_embed
+ for decoder_layer in self.decoder:
+ tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
+ dec_output = tgt
dec_output = dec_output[:, -1, :]
- word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
- preds_idx = paddle.argmax(word_prob, axis=1)
+ word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=-1)
+ preds_idx = paddle.argmax(word_prob, axis=-1)
if paddle.equal_all(
preds_idx,
paddle.full(
paddle.shape(preds_idx), 3, dtype='int64')):
break
- preds_prob = paddle.max(word_prob, axis=1)
+ preds_prob = paddle.max(word_prob, axis=-1)
dec_seq = paddle.concat(
[dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
dec_prob = paddle.concat(
@@ -194,10 +187,10 @@ class Transformer(nn.Layer):
return [dec_seq, dec_prob]
def forward_beam(self, images):
- ''' Translation work in one batch '''
+ """ Translation work in one batch """
def get_inst_idx_to_tensor_position_map(inst_idx_list):
- ''' Indicate the position of an instance in a tensor. '''
+ """ Indicate the position of an instance in a tensor. """
return {
inst_idx: tensor_position
for tensor_position, inst_idx in enumerate(inst_idx_list)
@@ -205,7 +198,7 @@ class Transformer(nn.Layer):
def collect_active_part(beamed_tensor, curr_active_inst_idx,
n_prev_active_inst, n_bm):
- ''' Collect tensor parts associated to active instances. '''
+ """ Collect tensor parts associated to active instances. """
beamed_tensor_shape = paddle.shape(beamed_tensor)
n_curr_active_inst = len(curr_active_inst_idx)
@@ -237,9 +230,8 @@ class Transformer(nn.Layer):
return active_src_enc, active_inst_idx_to_position_map
def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
- inst_idx_to_position_map, n_bm,
- memory_key_padding_mask):
- ''' Decode and update beam status, and then return active beam idx '''
+ inst_idx_to_position_map, n_bm):
+ """ Decode and update beam status, and then return active beam idx """
def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
dec_partial_seq = [
@@ -249,19 +241,15 @@ class Transformer(nn.Layer):
dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
return dec_partial_seq
- def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
- memory_key_padding_mask):
- dec_seq = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
+ def predict_word(dec_seq, enc_output, n_active_inst, n_bm):
+ dec_seq = self.embedding(dec_seq)
dec_seq = self.positional_encoding(dec_seq)
tgt_mask = self.generate_square_subsequent_mask(
- paddle.shape(dec_seq)[0])
- dec_output = self.decoder(
- dec_seq,
- enc_output,
- tgt_mask=tgt_mask,
- tgt_key_padding_mask=None,
- memory_key_padding_mask=memory_key_padding_mask, )
- dec_output = paddle.transpose(dec_output, [1, 0, 2])
+ paddle.shape(dec_seq)[1])
+ tgt = dec_seq
+ for decoder_layer in self.decoder:
+ tgt = decoder_layer(tgt, enc_output, self_mask=tgt_mask)
+ dec_output = tgt
dec_output = dec_output[:,
-1, :] # Pick the last step: (bh * bm) * d_h
word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
@@ -281,8 +269,7 @@ class Transformer(nn.Layer):
n_active_inst = len(inst_idx_to_position_map)
dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
- word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
- None)
+ word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm)
# Update the beam with predicted word prob information and collect incomplete instances
active_inst_idx_list = collect_active_inst_idx_list(
inst_dec_beams, word_prob, inst_idx_to_position_map)
@@ -303,10 +290,10 @@ class Transformer(nn.Layer):
with paddle.no_grad():
#-- Encode
if self.encoder is not None:
- src = self.positional_encoding(images.transpose([1, 0, 2]))
+ src = self.positional_encoding(images)
src_enc = self.encoder(src)
else:
- src_enc = images.squeeze(2).transpose([0, 2, 1])
+ src_enc = images
n_bm = self.beam_size
src_shape = paddle.shape(src_enc)
@@ -317,11 +304,11 @@ class Transformer(nn.Layer):
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
active_inst_idx_list)
# Decode
- for len_dec_seq in range(1, 25):
+ for len_dec_seq in range(1, self.max_len):
src_enc_copy = src_enc.clone()
active_inst_idx_list = beam_decode_step(
inst_dec_beams, len_dec_seq, src_enc_copy,
- inst_idx_to_position_map, n_bm, None)
+ inst_idx_to_position_map, n_bm)
if not active_inst_idx_list:
break # all instances have finished their path to
src_enc, inst_idx_to_position_map = collate_active_info(
@@ -354,261 +341,124 @@ class Transformer(nn.Layer):
shape=[sz, sz], dtype='float32', fill_value='-inf'),
diagonal=1)
mask = mask + mask_inf
- return mask
-
- def generate_padding_mask(self, x):
- padding_mask = paddle.equal(x, paddle.to_tensor(0, dtype=x.dtype))
- return padding_mask
+ return mask.unsqueeze([0, 1])
- def _reset_parameters(self):
- """Initiate parameters in the transformer model."""
- for p in self.parameters():
- if p.dim() > 1:
- xavier_uniform_(p)
+class MultiheadAttention(nn.Layer):
+ """Allows the model to jointly attend to information
+ from different representation subspaces.
+ See reference: Attention Is All You Need
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
-class TransformerEncoder(nn.Layer):
- """TransformerEncoder is a stack of N encoder layers
Args:
- encoder_layer: an instance of the TransformerEncoderLayer() class (required).
- num_layers: the number of sub-encoder-layers in the encoder (required).
- norm: the layer normalization component (optional).
- """
+ embed_dim: total dimension of the model
+ num_heads: parallel attention layers, or heads
- def __init__(self, encoder_layer, num_layers):
- super(TransformerEncoder, self).__init__()
- self.layers = _get_clones(encoder_layer, num_layers)
- self.num_layers = num_layers
+ """
- def forward(self, src):
- """Pass the input through the endocder layers in turn.
- Args:
- src: the sequnce to the encoder (required).
- mask: the mask for the src sequence (optional).
- src_key_padding_mask: the mask for the src keys per batch (optional).
- """
- output = src
+ def __init__(self, embed_dim, num_heads, dropout=0., self_attn=False):
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ # self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scale = self.head_dim**-0.5
+ self.self_attn = self_attn
+ if self_attn:
+ self.qkv = nn.Linear(embed_dim, embed_dim * 3)
+ else:
+ self.q = nn.Linear(embed_dim, embed_dim)
+ self.kv = nn.Linear(embed_dim, embed_dim * 2)
+ self.attn_drop = nn.Dropout(dropout)
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
- for i in range(self.num_layers):
- output = self.layers[i](output,
- src_mask=None,
- src_key_padding_mask=None)
+ def forward(self, query, key=None, attn_mask=None):
- return output
+ qN = query.shape[1]
+ if self.self_attn:
+ qkv = self.qkv(query).reshape(
+ (0, qN, 3, self.num_heads, self.head_dim)).transpose(
+ (2, 0, 3, 1, 4))
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ else:
+ kN = key.shape[1]
+ q = self.q(query).reshape(
+ [0, qN, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
+ kv = self.kv(key).reshape(
+ (0, kN, 2, self.num_heads, self.head_dim)).transpose(
+ (2, 0, 3, 1, 4))
+ k, v = kv[0], kv[1]
-class TransformerDecoder(nn.Layer):
- """TransformerDecoder is a stack of N decoder layers
+ attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
- Args:
- decoder_layer: an instance of the TransformerDecoderLayer() class (required).
- num_layers: the number of sub-decoder-layers in the decoder (required).
- norm: the layer normalization component (optional).
+ if attn_mask is not None:
+ attn += attn_mask
- """
+ attn = F.softmax(attn, axis=-1)
+ attn = self.attn_drop(attn)
- def __init__(self, decoder_layer, num_layers):
- super(TransformerDecoder, self).__init__()
- self.layers = _get_clones(decoder_layer, num_layers)
- self.num_layers = num_layers
+ x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape(
+ (0, qN, self.embed_dim))
+ x = self.out_proj(x)
- def forward(self,
- tgt,
- memory,
- tgt_mask=None,
- memory_mask=None,
- tgt_key_padding_mask=None,
- memory_key_padding_mask=None):
- """Pass the inputs (and mask) through the decoder layer in turn.
+ return x
- Args:
- tgt: the sequence to the decoder (required).
- memory: the sequnce from the last layer of the encoder (required).
- tgt_mask: the mask for the tgt sequence (optional).
- memory_mask: the mask for the memory sequence (optional).
- tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
- memory_key_padding_mask: the mask for the memory keys per batch (optional).
- """
- output = tgt
- for i in range(self.num_layers):
- output = self.layers[i](
- output,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- tgt_key_padding_mask=tgt_key_padding_mask,
- memory_key_padding_mask=memory_key_padding_mask)
-
- return output
-
-
-class TransformerEncoderLayer(nn.Layer):
- """TransformerEncoderLayer is made up of self-attn and feedforward network.
- This standard encoder layer is based on the paper "Attention Is All You Need".
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
- Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
- Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
- in a different way during application.
-
- Args:
- d_model: the number of expected features in the input (required).
- nhead: the number of heads in the multiheadattention models (required).
- dim_feedforward: the dimension of the feedforward network model (default=2048).
- dropout: the dropout value (default=0.1).
-
- """
+class TransformerBlock(nn.Layer):
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
attention_dropout_rate=0.0,
- residual_dropout_rate=0.1):
- super(TransformerEncoderLayer, self).__init__()
- self.self_attn = MultiheadAttention(
- d_model, nhead, dropout=attention_dropout_rate)
-
- self.conv1 = Conv2D(
- in_channels=d_model,
- out_channels=dim_feedforward,
- kernel_size=(1, 1))
- self.conv2 = Conv2D(
- in_channels=dim_feedforward,
- out_channels=d_model,
- kernel_size=(1, 1))
-
- self.norm1 = LayerNorm(d_model)
- self.norm2 = LayerNorm(d_model)
- self.dropout1 = Dropout(residual_dropout_rate)
- self.dropout2 = Dropout(residual_dropout_rate)
-
- def forward(self, src, src_mask=None, src_key_padding_mask=None):
- """Pass the input through the endocder layer.
- Args:
- src: the sequnce to the encoder layer (required).
- src_mask: the mask for the src sequence (optional).
- src_key_padding_mask: the mask for the src keys per batch (optional).
- """
- src2 = self.self_attn(
- src,
- src,
- src,
- attn_mask=src_mask,
- key_padding_mask=src_key_padding_mask)
- src = src + self.dropout1(src2)
- src = self.norm1(src)
-
- src = paddle.transpose(src, [1, 2, 0])
- src = paddle.unsqueeze(src, 2)
- src2 = self.conv2(F.relu(self.conv1(src)))
- src2 = paddle.squeeze(src2, 2)
- src2 = paddle.transpose(src2, [2, 0, 1])
- src = paddle.squeeze(src, 2)
- src = paddle.transpose(src, [2, 0, 1])
-
- src = src + self.dropout2(src2)
- src = self.norm2(src)
- return src
-
-
-class TransformerDecoderLayer(nn.Layer):
- """TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
- This standard decoder layer is based on the paper "Attention Is All You Need".
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
- Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
- Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
- in a different way during application.
-
- Args:
- d_model: the number of expected features in the input (required).
- nhead: the number of heads in the multiheadattention models (required).
- dim_feedforward: the dimension of the feedforward network model (default=2048).
- dropout: the dropout value (default=0.1).
-
- """
+ residual_dropout_rate=0.1,
+ with_self_attn=True,
+ with_cross_attn=False,
+ epsilon=1e-5):
+ super(TransformerBlock, self).__init__()
+ self.with_self_attn = with_self_attn
+ if with_self_attn:
+ self.self_attn = MultiheadAttention(
+ d_model,
+ nhead,
+ dropout=attention_dropout_rate,
+ self_attn=with_self_attn)
+ self.norm1 = LayerNorm(d_model, epsilon=epsilon)
+ self.dropout1 = Dropout(residual_dropout_rate)
+ self.with_cross_attn = with_cross_attn
+ if with_cross_attn:
+ self.cross_attn = MultiheadAttention( # cross attention over the encoder memory (decoder layers only)
+ d_model,
+ nhead,
+ dropout=attention_dropout_rate)
+ self.norm2 = LayerNorm(d_model, epsilon=epsilon)
+ self.dropout2 = Dropout(residual_dropout_rate)
+
+ self.mlp = Mlp(in_features=d_model,
+ hidden_features=dim_feedforward,
+ act_layer=nn.ReLU,
+ drop=residual_dropout_rate)
+
+ self.norm3 = LayerNorm(d_model, epsilon=epsilon)
- def __init__(self,
- d_model,
- nhead,
- dim_feedforward=2048,
- attention_dropout_rate=0.0,
- residual_dropout_rate=0.1):
- super(TransformerDecoderLayer, self).__init__()
- self.self_attn = MultiheadAttention(
- d_model, nhead, dropout=attention_dropout_rate)
- self.multihead_attn = MultiheadAttention(
- d_model, nhead, dropout=attention_dropout_rate)
-
- self.conv1 = Conv2D(
- in_channels=d_model,
- out_channels=dim_feedforward,
- kernel_size=(1, 1))
- self.conv2 = Conv2D(
- in_channels=dim_feedforward,
- out_channels=d_model,
- kernel_size=(1, 1))
-
- self.norm1 = LayerNorm(d_model)
- self.norm2 = LayerNorm(d_model)
- self.norm3 = LayerNorm(d_model)
- self.dropout1 = Dropout(residual_dropout_rate)
- self.dropout2 = Dropout(residual_dropout_rate)
self.dropout3 = Dropout(residual_dropout_rate)
- def forward(self,
- tgt,
- memory,
- tgt_mask=None,
- memory_mask=None,
- tgt_key_padding_mask=None,
- memory_key_padding_mask=None):
- """Pass the inputs (and mask) through the decoder layer.
+ def forward(self, tgt, memory=None, self_mask=None, cross_mask=None):
+ if self.with_self_attn:
+ tgt1 = self.self_attn(tgt, attn_mask=self_mask)
+ tgt = self.norm1(tgt + self.dropout1(tgt1))
- Args:
- tgt: the sequence to the decoder layer (required).
- memory: the sequnce from the last layer of the encoder (required).
- tgt_mask: the mask for the tgt sequence (optional).
- memory_mask: the mask for the memory sequence (optional).
- tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
- memory_key_padding_mask: the mask for the memory keys per batch (optional).
-
- """
- tgt2 = self.self_attn(
- tgt,
- tgt,
- tgt,
- attn_mask=tgt_mask,
- key_padding_mask=tgt_key_padding_mask)
- tgt = tgt + self.dropout1(tgt2)
- tgt = self.norm1(tgt)
- tgt2 = self.multihead_attn(
- tgt,
- memory,
- memory,
- attn_mask=memory_mask,
- key_padding_mask=memory_key_padding_mask)
- tgt = tgt + self.dropout2(tgt2)
- tgt = self.norm2(tgt)
-
- # default
- tgt = paddle.transpose(tgt, [1, 2, 0])
- tgt = paddle.unsqueeze(tgt, 2)
- tgt2 = self.conv2(F.relu(self.conv1(tgt)))
- tgt2 = paddle.squeeze(tgt2, 2)
- tgt2 = paddle.transpose(tgt2, [2, 0, 1])
- tgt = paddle.squeeze(tgt, 2)
- tgt = paddle.transpose(tgt, [2, 0, 1])
-
- tgt = tgt + self.dropout3(tgt2)
- tgt = self.norm3(tgt)
+ if self.with_cross_attn:
+ tgt2 = self.cross_attn(tgt, key=memory, attn_mask=cross_mask)
+ tgt = self.norm2(tgt + self.dropout2(tgt2))
+ tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
return tgt
-def _get_clones(module, N):
- return LayerList([copy.deepcopy(module) for i in range(N)])
-
-
class PositionalEncoding(nn.Layer):
"""Inject some information about the relative or absolute position of the tokens
in the sequence. The positional encodings have the same dimension as
@@ -651,8 +501,9 @@ class PositionalEncoding(nn.Layer):
Examples:
>>> output = pos_encoder(x)
"""
+ x = x.transpose([1, 0, 2])
x = x + self.pe[:paddle.shape(x)[0], :]
- return self.dropout(x)
+ return self.dropout(x).transpose([1, 0, 2])
class PositionalEncoding_2d(nn.Layer):
@@ -725,7 +576,7 @@ class PositionalEncoding_2d(nn.Layer):
class Embeddings(nn.Layer):
- def __init__(self, d_model, vocab, padding_idx, scale_embedding):
+ def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
super(Embeddings, self).__init__()
self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
w0 = np.random.normal(0.0, d_model**-0.5,
@@ -742,7 +593,7 @@ class Embeddings(nn.Layer):
class Beam():
- ''' Beam search '''
+ """ Beam search """
def __init__(self, size, device=False):
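The rewritten NRTR head keeps every tensor batch-first ([B, N, C]) and, for self attention, projects q, k and v with a single fused linear layer. A hedged numpy sketch of the head split performed by the new MultiheadAttention (toy sizes are assumptions):

import numpy as np

B, N, C, num_heads = 2, 5, 8, 4
head_dim = C // num_heads
qkv = np.random.rand(B, N, 3 * C).astype("float32")       # output of nn.Linear(C, 3 * C)
qkv = qkv.reshape(B, N, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]                           # each [B, num_heads, N, head_dim]
attn = (q @ k.transpose(0, 1, 3, 2)) * head_dim ** -0.5    # [B, num_heads, N, N]
print(attn.shape)  # (2, 4, 5, 5)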
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index e354f40d6518c1f7ca22e93694b1c6668fc003d2..4f39d6253d8d596fecdc4736666a6d3106601a82 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -21,6 +21,8 @@ import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
+from .rec_att_head import AttentionGRUCell
+
class TableAttentionHead(nn.Layer):
def __init__(self,
@@ -28,21 +30,19 @@ class TableAttentionHead(nn.Layer):
hidden_size,
loc_type,
in_max_len=488,
- max_text_length=100,
- max_elem_length=800,
- max_cell_num=500,
+ max_text_length=800,
+ out_channels=30,
+ point_num=2,
**kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
- self.elem_num = 30
+ self.out_channels = out_channels
self.max_text_length = max_text_length
- self.max_elem_length = max_elem_length
- self.max_cell_num = max_cell_num
self.structure_attention_cell = AttentionGRUCell(
- self.input_size, hidden_size, self.elem_num, use_gru=False)
- self.structure_generator = nn.Linear(hidden_size, self.elem_num)
+ self.input_size, hidden_size, self.out_channels, use_gru=False)
+ self.structure_generator = nn.Linear(hidden_size, self.out_channels)
self.loc_type = loc_type
self.in_max_len = in_max_len
@@ -50,12 +50,13 @@ class TableAttentionHead(nn.Layer):
self.loc_generator = nn.Linear(hidden_size, 4)
else:
if self.in_max_len == 640:
- self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
+ self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
elif self.in_max_len == 800:
- self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
+ self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
else:
- self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
- self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
+ self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
+ self.loc_generator = nn.Linear(self.input_size + hidden_size,
+ point_num * 2)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
@@ -77,9 +78,9 @@ class TableAttentionHead(nn.Layer):
output_hiddens = []
if self.training and targets is not None:
structure = targets[0]
- for i in range(self.max_elem_length + 1):
+ for i in range(self.max_text_length + 1):
elem_onehots = self._char_to_onehot(
- structure[:, i], onehot_dim=self.elem_num)
+ structure[:, i], onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
@@ -102,11 +103,11 @@ class TableAttentionHead(nn.Layer):
elem_onehots = None
outputs = None
alpha = None
- max_elem_length = paddle.to_tensor(self.max_elem_length)
+ max_text_length = paddle.to_tensor(self.max_text_length)
i = 0
- while i < max_elem_length + 1:
+ while i < max_text_length + 1:
elem_onehots = self._char_to_onehot(
- temp_elem, onehot_dim=self.elem_num)
+ temp_elem, onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
@@ -128,119 +129,3 @@ class TableAttentionHead(nn.Layer):
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
-
-
-class AttentionGRUCell(nn.Layer):
- def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
- super(AttentionGRUCell, self).__init__()
- self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
- self.h2h = nn.Linear(hidden_size, hidden_size)
- self.score = nn.Linear(hidden_size, 1, bias_attr=False)
- self.rnn = nn.GRUCell(
- input_size=input_size + num_embeddings, hidden_size=hidden_size)
- self.hidden_size = hidden_size
-
- def forward(self, prev_hidden, batch_H, char_onehots):
- batch_H_proj = self.i2h(batch_H)
- prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
- res = paddle.add(batch_H_proj, prev_hidden_proj)
- res = paddle.tanh(res)
- e = self.score(res)
- alpha = F.softmax(e, axis=1)
- alpha = paddle.transpose(alpha, [0, 2, 1])
- context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
- concat_context = paddle.concat([context, char_onehots], 1)
- cur_hidden = self.rnn(concat_context, prev_hidden)
- return cur_hidden, alpha
-
-
-class AttentionLSTM(nn.Layer):
- def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
- super(AttentionLSTM, self).__init__()
- self.input_size = in_channels
- self.hidden_size = hidden_size
- self.num_classes = out_channels
-
- self.attention_cell = AttentionLSTMCell(
- in_channels, hidden_size, out_channels, use_gru=False)
- self.generator = nn.Linear(hidden_size, out_channels)
-
- def _char_to_onehot(self, input_char, onehot_dim):
- input_ont_hot = F.one_hot(input_char, onehot_dim)
- return input_ont_hot
-
- def forward(self, inputs, targets=None, batch_max_length=25):
- batch_size = inputs.shape[0]
- num_steps = batch_max_length
-
- hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
- (batch_size, self.hidden_size)))
- output_hiddens = []
-
- if targets is not None:
- for i in range(num_steps):
- # one-hot vectors for a i-th char
- char_onehots = self._char_to_onehot(
- targets[:, i], onehot_dim=self.num_classes)
- hidden, alpha = self.attention_cell(hidden, inputs,
- char_onehots)
-
- hidden = (hidden[1][0], hidden[1][1])
- output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
- output = paddle.concat(output_hiddens, axis=1)
- probs = self.generator(output)
-
- else:
- targets = paddle.zeros(shape=[batch_size], dtype="int32")
- probs = None
-
- for i in range(num_steps):
- char_onehots = self._char_to_onehot(
- targets, onehot_dim=self.num_classes)
- hidden, alpha = self.attention_cell(hidden, inputs,
- char_onehots)
- probs_step = self.generator(hidden[0])
- hidden = (hidden[1][0], hidden[1][1])
- if probs is None:
- probs = paddle.unsqueeze(probs_step, axis=1)
- else:
- probs = paddle.concat(
- [probs, paddle.unsqueeze(
- probs_step, axis=1)], axis=1)
-
- next_input = probs_step.argmax(axis=1)
-
- targets = next_input
-
- return probs
-
-
-class AttentionLSTMCell(nn.Layer):
- def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
- super(AttentionLSTMCell, self).__init__()
- self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
- self.h2h = nn.Linear(hidden_size, hidden_size)
- self.score = nn.Linear(hidden_size, 1, bias_attr=False)
- if not use_gru:
- self.rnn = nn.LSTMCell(
- input_size=input_size + num_embeddings, hidden_size=hidden_size)
- else:
- self.rnn = nn.GRUCell(
- input_size=input_size + num_embeddings, hidden_size=hidden_size)
-
- self.hidden_size = hidden_size
-
- def forward(self, prev_hidden, batch_H, char_onehots):
- batch_H_proj = self.i2h(batch_H)
- prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
- res = paddle.add(batch_H_proj, prev_hidden_proj)
- res = paddle.tanh(res)
- e = self.score(res)
-
- alpha = F.softmax(e, axis=1)
- alpha = paddle.transpose(alpha, [0, 2, 1])
- context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
- concat_context = paddle.concat([context, char_onehots], 1)
- cur_hidden = self.rnn(concat_context, prev_hidden)
-
- return cur_hidden, alpha
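With the head now parameterized by point_num, the location branch emits point_num * 2 normalized coordinates per decoding step (2 points for an axis-aligned box, 4 points for a quadrilateral). A small usage sketch with assumed toy sizes:

import paddle
import paddle.nn.functional as F

hidden_size, input_size, point_num = 16, 32, 2
loc_generator = paddle.nn.Linear(input_size + hidden_size, point_num * 2)
loc_concat = paddle.rand([1, 10, input_size + hidden_size])   # [B, steps, features]
loc_preds = F.sigmoid(loc_generator(loc_concat))
print(loc_preds.shape)  # [1, 10, 4] -> normalized (x1, y1, x2, y2) per step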
diff --git a/ppocr/modeling/heads/table_master_head.py b/ppocr/modeling/heads/table_master_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..fddbcc63fcd6d5380f9fdd96f9ca85756d666442
--- /dev/null
+++ b/ppocr/modeling/heads/table_master_head.py
@@ -0,0 +1,281 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is referred from:
+https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/decoders/master_decoder.py
+"""
+
+import copy
+import math
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+
+
+class TableMasterHead(nn.Layer):
+ """
+ Split into two transformer heads at the last layer.
+ cls_layer is used for structure token classification.
+ bbox_layer is used to regress bbox coordinates.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels=30,
+ headers=8,
+ d_ff=2048,
+ dropout=0,
+ max_text_length=500,
+ point_num=2,
+ **kwargs):
+ super(TableMasterHead, self).__init__()
+ hidden_size = in_channels[-1]
+ self.layers = clones(
+ DecoderLayer(headers, hidden_size, dropout, d_ff), 2)
+ self.cls_layer = clones(
+ DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
+ self.bbox_layer = clones(
+ DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
+ self.cls_fc = nn.Linear(hidden_size, out_channels)
+ self.bbox_fc = nn.Sequential(
+ # nn.Linear(hidden_size, hidden_size),
+ nn.Linear(hidden_size, point_num * 2),
+ nn.Sigmoid())
+ self.norm = nn.LayerNorm(hidden_size)
+ self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
+ self.positional_encoding = PositionalEncoding(d_model=hidden_size)
+
+ self.SOS = out_channels - 3
+ self.PAD = out_channels - 1
+ self.out_channels = out_channels
+ self.point_num = point_num
+ self.max_text_length = max_text_length
+
+ def make_mask(self, tgt):
+ """
+ Make the causal and padding mask for decoder self attention.
+ :param tgt: [b, l_tgt]
+ :return: mask of shape [b, 1, l_tgt, l_tgt]
+ """
+ trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3)
+
+ tgt_len = paddle.shape(tgt)[1]
+ trg_sub_mask = paddle.tril(
+ paddle.ones(
+ ([tgt_len, tgt_len]), dtype=paddle.float32))
+
+ tgt_mask = paddle.logical_and(
+ trg_pad_mask.astype(paddle.float32), trg_sub_mask)
+ return tgt_mask.astype(paddle.float32)
+
+ def decode(self, input, feature, src_mask, tgt_mask):
+ # main process of transformer decoder.
+ x = self.embedding(input) # x: [b, len, d_model], feature: [b, h*w, d_model]
+ x = self.positional_encoding(x)
+
+ # origin transformer layers
+ for i, layer in enumerate(self.layers):
+ x = layer(x, feature, src_mask, tgt_mask)
+
+ # cls head
+ for layer in self.cls_layer:
+ cls_x = layer(x, feature, src_mask, tgt_mask)
+ cls_x = self.norm(cls_x)
+
+ # bbox head
+ for layer in self.bbox_layer:
+ bbox_x = layer(x, feature, src_mask, tgt_mask)
+ bbox_x = self.norm(bbox_x)
+ return self.cls_fc(cls_x), self.bbox_fc(bbox_x)
+
+ def greedy_forward(self, SOS, feature):
+ input = SOS
+ output = paddle.zeros(
+ [input.shape[0], self.max_text_length + 1, self.out_channels])
+ bbox_output = paddle.zeros(
+ [input.shape[0], self.max_text_length + 1, self.point_num * 2])
+ max_text_length = paddle.to_tensor(self.max_text_length)
+ for i in range(max_text_length + 1):
+ target_mask = self.make_mask(input)
+ out_step, bbox_output_step = self.decode(input, feature, None,
+ target_mask)
+ prob = F.softmax(out_step, axis=-1)
+ next_word = prob.argmax(axis=2, dtype="int64")
+ input = paddle.concat(
+ [input, next_word[:, -1].unsqueeze(-1)], axis=1)
+ if i == self.max_text_length:
+ output = out_step
+ bbox_output = bbox_output_step
+ return output, bbox_output
+
+ def forward_train(self, out_enc, targets):
+ # x is token of label
+ # feat is feature after backbone before pe.
+ # out_enc is feature after pe.
+ padded_targets = targets[0]
+ src_mask = None
+ tgt_mask = self.make_mask(padded_targets[:, :-1])
+ output, bbox_output = self.decode(padded_targets[:, :-1], out_enc,
+ src_mask, tgt_mask)
+ return {'structure_probs': output, 'loc_preds': bbox_output}
+
+ def forward_test(self, out_enc):
+ batch_size = out_enc.shape[0]
+ SOS = paddle.zeros([batch_size, 1], dtype='int64') + self.SOS
+ output, bbox_output = self.greedy_forward(SOS, out_enc)
+ output = F.softmax(output)
+ return {'structure_probs': output, 'loc_preds': bbox_output}
+
+ def forward(self, feat, targets=None):
+ feat = feat[-1]
+ b, c, h, w = feat.shape
+ feat = feat.reshape([b, c, h * w]) # flatten 2D feature map
+ feat = feat.transpose((0, 2, 1))
+ out_enc = self.positional_encoding(feat)
+ if self.training:
+ return self.forward_train(out_enc, targets)
+
+ return self.forward_test(out_enc)
+
+
+class DecoderLayer(nn.Layer):
+ """
+ Decoder is made of self attention, source attention and feed forward.
+ """
+
+ def __init__(self, headers, d_model, dropout, d_ff):
+ super(DecoderLayer, self).__init__()
+ self.self_attn = MultiHeadAttention(headers, d_model, dropout)
+ self.src_attn = MultiHeadAttention(headers, d_model, dropout)
+ self.feed_forward = FeedForward(d_model, d_ff, dropout)
+ self.sublayer = clones(SubLayerConnection(d_model, dropout), 3)
+
+ def forward(self, x, feature, src_mask, tgt_mask):
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
+ x = self.sublayer[1](
+ x, lambda x: self.src_attn(x, feature, feature, src_mask))
+ return self.sublayer[2](x, self.feed_forward)
+
+
+class MultiHeadAttention(nn.Layer):
+ def __init__(self, headers, d_model, dropout):
+ super(MultiHeadAttention, self).__init__()
+
+ assert d_model % headers == 0
+ self.d_k = int(d_model / headers)
+ self.headers = headers
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
+ self.attn = None
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, query, key, value, mask=None):
+ B = query.shape[0]
+
+ # 1) Do all the linear projections in batch from d_model => h x d_k
+ query, key, value = \
+ [l(x).reshape([B, 0, self.headers, self.d_k]).transpose([0, 2, 1, 3])
+ for l, x in zip(self.linears, (query, key, value))]
+ # 2) Apply attention on all the projected vectors in batch
+ x, self.attn = self_attention(
+ query, key, value, mask=mask, dropout=self.dropout)
+ x = x.transpose([0, 2, 1, 3]).reshape([B, 0, self.headers * self.d_k])
+ return self.linears[-1](x)
+
+
+class FeedForward(nn.Layer):
+ def __init__(self, d_model, d_ff, dropout):
+ super(FeedForward, self).__init__()
+ self.w_1 = nn.Linear(d_model, d_ff)
+ self.w_2 = nn.Linear(d_ff, d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
+
+
+class SubLayerConnection(nn.Layer):
+ """
+ A residual connection followed by a layer norm.
+ Note for code simplicity the norm is first as opposed to last.
+ """
+
+ def __init__(self, size, dropout):
+ super(SubLayerConnection, self).__init__()
+ self.norm = nn.LayerNorm(size)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, sublayer):
+ return x + self.dropout(sublayer(self.norm(x)))
+
+
+def masked_fill(x, mask, value):
+ mask = mask.astype(x.dtype)
+ return x * paddle.logical_not(mask).astype(x.dtype) + mask * value
+
+
+def self_attention(query, key, value, mask=None, dropout=None):
+ """
+ Compute 'Scale Dot Product Attention'
+ """
+ d_k = value.shape[-1]
+
+ score = paddle.matmul(query, key.transpose([0, 1, 3, 2]) / math.sqrt(d_k))
+ if mask is not None:
+ # score = score.masked_fill(mask == 0, -1e9) # b, h, L, L
+ score = masked_fill(score, mask == 0, -6.55e4) # for fp16
+
+ p_attn = F.softmax(score, axis=-1)
+
+ if dropout is not None:
+ p_attn = dropout(p_attn)
+ return paddle.matmul(p_attn, value), p_attn
+
+
+def clones(module, N):
+ """ Produce N identical layers """
+ return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
+
+
+class Embeddings(nn.Layer):
+ def __init__(self, d_model, vocab):
+ super(Embeddings, self).__init__()
+ self.lut = nn.Embedding(vocab, d_model)
+ self.d_model = d_model
+
+ def forward(self, *input):
+ x = input[0]
+ return self.lut(x) * math.sqrt(self.d_model)
+
+
+class PositionalEncoding(nn.Layer):
+ """ Implement the PE function. """
+
+ def __init__(self, d_model, dropout=0., max_len=5000):
+ super(PositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ # Compute the positional encodings once in log space.
+ pe = paddle.zeros([max_len, d_model])
+ position = paddle.arange(0, max_len).unsqueeze(1).astype('float32')
+ div_term = paddle.exp(
+ paddle.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
+ pe[:, 0::2] = paddle.sin(position * div_term)
+ pe[:, 1::2] = paddle.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer('pe', pe)
+
+ def forward(self, feat, **kwargs):
+ feat = feat + self.pe[:, :paddle.shape(feat)[1]] # pe 1*5000*512
+ return self.dropout(feat)
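The masked_fill helper above avoids -inf so the attention scores stay representable in fp16: masked positions receive -6.55e4 before the softmax and end up with (near-)zero weight. A small sketch reusing the same helper logic on toy scores:

import paddle
import paddle.nn.functional as F

def masked_fill(x, mask, value):
    mask = mask.astype(x.dtype)
    return x * paddle.logical_not(mask).astype(x.dtype) + mask * value

score = paddle.to_tensor([[1.0, 2.0, 3.0]])
mask = paddle.to_tensor([[1, 1, 0]])              # last position is masked out
score = masked_fill(score, mask == 0, -6.55e4)
print(F.softmax(score, axis=-1).numpy())          # third weight is ~0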
diff --git a/ppocr/modeling/necks/db_fpn.py b/ppocr/modeling/necks/db_fpn.py
index 93ed2dbfd1fac9bf2d163c54d23a20e16b537981..8c3f52a331db5daafab2a38c0a441edd44eb141d 100644
--- a/ppocr/modeling/necks/db_fpn.py
+++ b/ppocr/modeling/necks/db_fpn.py
@@ -105,9 +105,10 @@ class DSConv(nn.Layer):
class DBFPN(nn.Layer):
- def __init__(self, in_channels, out_channels, **kwargs):
+ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
super(DBFPN, self).__init__()
self.out_channels = out_channels
+ self.use_asf = use_asf
weight_attr = paddle.nn.initializer.KaimingUniform()
self.in2_conv = nn.Conv2D(
@@ -163,6 +164,9 @@ class DBFPN(nn.Layer):
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
+ if self.use_asf is True:
+ self.asf = ASFBlock(self.out_channels, self.out_channels // 4)
+
def forward(self, x):
c2, c3, c4, c5 = x
@@ -187,6 +191,10 @@ class DBFPN(nn.Layer):
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
fuse = paddle.concat([p5, p4, p3, p2], axis=1)
+
+ if self.use_asf is True:
+ fuse = self.asf(fuse, [p5, p4, p3, p2])
+
return fuse
@@ -356,3 +364,64 @@ class LKPAN(nn.Layer):
fuse = paddle.concat([p5, p4, p3, p2], axis=1)
return fuse
+
+
+class ASFBlock(nn.Layer):
+ """
+ This code is referred from:
+ https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
+ """
+
+ def __init__(self, in_channels, inter_channels, out_features_num=4):
+ """
+ Adaptive Scale Fusion (ASF) block of DBNet++
+ Args:
+ in_channels: the number of channels in the input data
+ inter_channels: the number of middle channels
+ out_features_num: the number of fused stages
+ """
+ super(ASFBlock, self).__init__()
+ weight_attr = paddle.nn.initializer.KaimingUniform()
+ self.in_channels = in_channels
+ self.inter_channels = inter_channels
+ self.out_features_num = out_features_num
+ self.conv = nn.Conv2D(in_channels, inter_channels, 3, padding=1)
+
+ self.spatial_scale = nn.Sequential(
+ #Nx1xHxW
+ nn.Conv2D(
+ in_channels=1,
+ out_channels=1,
+ kernel_size=3,
+ bias_attr=False,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr)),
+ nn.ReLU(),
+ nn.Conv2D(
+ in_channels=1,
+ out_channels=1,
+ kernel_size=1,
+ bias_attr=False,
+ weight_attr=ParamAttr(initializer=weight_attr)),
+ nn.Sigmoid())
+
+ self.channel_scale = nn.Sequential(
+ nn.Conv2D(
+ in_channels=inter_channels,
+ out_channels=out_features_num,
+ kernel_size=1,
+ bias_attr=False,
+ weight_attr=ParamAttr(initializer=weight_attr)),
+ nn.Sigmoid())
+
+ def forward(self, fuse_features, features_list):
+ fuse_features = self.conv(fuse_features)
+ spatial_x = paddle.mean(fuse_features, axis=1, keepdim=True)
+ attention_scores = self.spatial_scale(spatial_x) + fuse_features
+ attention_scores = self.channel_scale(attention_scores)
+ assert len(features_list) == self.out_features_num
+
+ out_list = []
+ for i in range(self.out_features_num):
+ out_list.append(attention_scores[:, i:i + 1] * features_list[i])
+ return paddle.concat(out_list, axis=1)
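A hedged usage sketch for ASFBlock: fuse is the channel-wise concat of the four FPN stages (each out_channels // 4 channels), and the block re-weights each stage with spatial and per-stage attention before concatenating them again. The sizes below are assumptions:

import paddle
from ppocr.modeling.necks.db_fpn import ASFBlock  # assumes the patched module is importable

out_channels = 96
asf = ASFBlock(in_channels=out_channels, inter_channels=out_channels // 4)
p = [paddle.rand([1, out_channels // 4, 160, 160]) for _ in range(4)]  # p5, p4, p3, p2
fuse = paddle.concat(p, axis=1)                                        # [1, 96, 160, 160]
out = asf(fuse, p)
print(out.shape)                                                       # [1, 96, 160, 160]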
diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py
index fe251f36e736bb1eac8a71a8115c941cbd7443e6..7d45109b4857871f52764c64d6d32e5322fc7c57 100644
--- a/ppocr/optimizer/learning_rate.py
+++ b/ppocr/optimizer/learning_rate.py
@@ -308,3 +308,81 @@ class Const(object):
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
+
+
+class DecayLearningRate(object):
+ """
+ DecayLearningRate learning rate decay
+ new_lr = (lr - end_lr) * (1 - epoch/decay_steps)**power + end_lr
+ Args:
+ learning_rate(float): initial learning rate
+ step_each_epoch(int): steps each epoch
+ epochs(int): total training epochs
+ factor(float): Power of polynomial, should be greater than 0.0 to get learning rate decay. Default: 0.9
+ end_lr(float): The minimum final learning rate. Default: 0.0.
+ """
+
+ def __init__(self,
+ learning_rate,
+ step_each_epoch,
+ epochs,
+ factor=0.9,
+ end_lr=0,
+ **kwargs):
+ super(DecayLearningRate, self).__init__()
+ self.learning_rate = learning_rate
+ self.epochs = epochs + 1
+ self.factor = factor
+ self.end_lr = end_lr
+ self.decay_steps = step_each_epoch * epochs
+
+ def __call__(self):
+ learning_rate = lr.PolynomialDecay(
+ learning_rate=self.learning_rate,
+ decay_steps=self.decay_steps,
+ power=self.factor,
+ end_lr=self.end_lr)
+ return learning_rate
+
+
+class MultiStepDecay(object):
+ """
+ Piecewise learning rate decay
+ Args:
+ step_each_epoch(int): steps each epoch
+ learning_rate (float): The initial learning rate. It is a python float number.
+ step_size (int): the interval to update.
+ gamma (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma``.
+ It should be less than 1.0. Default: 0.1.
+ last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+ """
+
+ def __init__(self,
+ learning_rate,
+ milestones,
+ step_each_epoch,
+ gamma,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs):
+ super(MultiStepDecay, self).__init__()
+ self.milestones = [step_each_epoch * e for e in milestones]
+ self.learning_rate = learning_rate
+ self.gamma = gamma
+ self.last_epoch = last_epoch
+ self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+ def __call__(self):
+ learning_rate = lr.MultiStepDecay(
+ learning_rate=self.learning_rate,
+ milestones=self.milestones,
+ gamma=self.gamma,
+ last_epoch=self.last_epoch)
+ if self.warmup_epoch > 0:
+ learning_rate = lr.LinearWarmup(
+ learning_rate=learning_rate,
+ warmup_steps=self.warmup_epoch,
+ start_lr=0.0,
+ end_lr=self.learning_rate,
+ last_epoch=self.last_epoch)
+ return learning_rate
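MultiStepDecay converts epoch milestones into step milestones and optionally wraps the schedule in a linear warmup. A sketch of the equivalent paddle scheduler calls (the numbers are made up for illustration):

from paddle.optimizer import lr

step_each_epoch = 100
scheduler = lr.MultiStepDecay(
    learning_rate=0.001,
    milestones=[step_each_epoch * e for e in [3, 6]],  # epoch milestones -> step milestones
    gamma=0.1)
scheduler = lr.LinearWarmup(
    learning_rate=scheduler, warmup_steps=100, start_lr=0.0, end_lr=0.001)
for step in range(3):
    print(step, scheduler.get_lr())
    scheduler.step()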
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index f50b5f1c5f8e617066bb47636c8f4d2b171b6ecb..1d414eb2e8562925f461b0c6f6ce15774b81bb8f 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -26,12 +26,13 @@ from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
- DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
- SEEDLabelDecode, PRENLabelDecode
+ DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
+ SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
+from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
def build_post_process(config, global_config=None):
@@ -42,7 +43,8 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
- 'DistillationSARLabelDecode'
+ 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
+ 'TableMasterLabelDecode'
]
if config['name'] == 'PSEPostProcess':
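The newly registered decoders are instantiated through build_post_process from a config dict that mirrors the YAML PostProcess section. A minimal sketch (the dictionary path is an assumption, not taken from this patch):

from ppocr.postprocess import build_post_process

config = {
    'name': 'TableMasterLabelDecode',  # registered in the support list above
    'character_dict_path': 'ppocr/utils/dict/table_master_structure_dict.txt',  # assumed path
}
post_process = build_post_process(config)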
diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py
index 27b428ef2e73c9abf81d3881b23979343c8595b2..5e2553c3a09f8359d1641d2d49b1bfb84df695ac 100755
--- a/ppocr/postprocess/db_postprocess.py
+++ b/ppocr/postprocess/db_postprocess.py
@@ -38,6 +38,7 @@ class DBPostProcess(object):
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
+ use_polygon=False,
**kwargs):
self.thresh = thresh
self.box_thresh = box_thresh
@@ -45,6 +46,7 @@ class DBPostProcess(object):
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
+ self.use_polygon = use_polygon
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
@@ -52,6 +54,53 @@ class DBPostProcess(object):
self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])
+ def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+ '''
+ _bitmap: single map with shape (1, H, W),
+ whose values are binarized as {0, 1}
+ '''
+
+ bitmap = _bitmap
+ height, width = bitmap.shape
+
+ boxes = []
+ scores = []
+
+ contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
+ cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+
+ for contour in contours[:self.max_candidates]:
+ epsilon = 0.002 * cv2.arcLength(contour, True)
+ approx = cv2.approxPolyDP(contour, epsilon, True)
+ points = approx.reshape((-1, 2))
+ if points.shape[0] < 4:
+ continue
+
+ score = self.box_score_fast(pred, points.reshape(-1, 2))
+ if self.box_thresh > score:
+ continue
+
+ if points.shape[0] > 2:
+ box = self.unclip(points, self.unclip_ratio)
+ if len(box) > 1:
+ continue
+ else:
+ continue
+ box = box.reshape(-1, 2)
+
+ _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
+ if sside < self.min_size + 2:
+ continue
+
+ box = np.array(box)
+ box[:, 0] = np.clip(
+ np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 1] = np.clip(
+ np.round(box[:, 1] / height * dest_height), 0, dest_height)
+ boxes.append(box.tolist())
+ scores.append(score)
+ return boxes, scores
+
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
@@ -85,7 +134,7 @@ class DBPostProcess(object):
if self.box_thresh > score:
continue
- box = self.unclip(points).reshape(-1, 1, 2)
+ box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
@@ -99,8 +148,7 @@ class DBPostProcess(object):
scores.append(score)
return np.array(boxes, dtype=np.int16), scores
- def unclip(self, box):
- unclip_ratio = self.unclip_ratio
+ def unclip(self, box, unclip_ratio):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
@@ -185,8 +233,12 @@ class DBPostProcess(object):
self.dilation_kernel)
else:
mask = segmentation[batch_index]
- boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
- src_w, src_h)
+ if self.use_polygon is True:
+ boxes, scores = self.polygons_from_bitmap(pred[batch_index],
+ mask, src_w, src_h)
+ else:
+ boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
+ src_w, src_h)
boxes_batch.append({'points': boxes})
return boxes_batch
@@ -202,6 +254,7 @@ class DistillationDBPostProcess(object):
unclip_ratio=1.5,
use_dilation=False,
score_mode="fast",
+ use_polygon=False,
**kwargs):
self.model_name = model_name
self.key = key
@@ -211,7 +264,8 @@ class DistillationDBPostProcess(object):
max_candidates=max_candidates,
unclip_ratio=unclip_ratio,
use_dilation=use_dilation,
- score_mode=score_mode)
+ score_mode=score_mode,
+ use_polygon=use_polygon)
def __call__(self, predicts, shape_list):
results = {}
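unclip now receives unclip_ratio explicitly; the step grows a candidate polygon outward by area / perimeter * unclip_ratio using pyclipper. A standalone sketch of that expansion (the rectangle is a toy input):

import numpy as np
import pyclipper
from shapely.geometry import Polygon

box = [[0, 0], [10, 0], [10, 4], [0, 4]]
unclip_ratio = 2.0
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
print(expanded.shape)  # one expanded polygon, with extra vertices from the rounded joins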
diff --git a/ppocr/postprocess/pse_postprocess/pse_postprocess.py b/ppocr/postprocess/pse_postprocess/pse_postprocess.py
index 34f1b8c9b5397a5513462468a9ee3d8530389607..962f3efe922c4a2656e0f44f478e1baf301a5542 100755
--- a/ppocr/postprocess/pse_postprocess/pse_postprocess.py
+++ b/ppocr/postprocess/pse_postprocess/pse_postprocess.py
@@ -58,6 +58,8 @@ class PSEPostProcess(object):
kernels = (pred > self.thresh).astype('float32')
text_mask = kernels[:, 0, :, :]
+ text_mask = paddle.unsqueeze(text_mask, axis=1)
+
kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
score = score.numpy()
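The added unsqueeze aligns shapes for broadcasting: kernels is [N, K, H, W] while text_mask starts as [N, H, W], so a channel axis is needed before the elementwise multiply. A small shape check with assumed toy sizes:

import paddle

kernels = paddle.ones([2, 7, 8, 8])
text_mask = kernels[:, 0, :, :]                  # [2, 8, 8]
text_mask = paddle.unsqueeze(text_mask, axis=1)  # [2, 1, 8, 8]
kernels = kernels * text_mask                    # broadcasts over the 7 kernel maps
print(kernels.shape)                             # [2, 7, 8, 8]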
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index bf0fd890bf25949361665d212bf8e1a657054e5b..cc7c2cb379cc476943152507569f0b0066189c46 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -140,70 +140,6 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
return output
-class NRTRLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
-
- def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
- super(NRTRLabelDecode, self).__init__(character_dict_path,
- use_space_char)
-
- def __call__(self, preds, label=None, *args, **kwargs):
-
- if len(preds) == 2:
- preds_id = preds[0]
- preds_prob = preds[1]
- if isinstance(preds_id, paddle.Tensor):
- preds_id = preds_id.numpy()
- if isinstance(preds_prob, paddle.Tensor):
- preds_prob = preds_prob.numpy()
- if preds_id[0][0] == 2:
- preds_idx = preds_id[:, 1:]
- preds_prob = preds_prob[:, 1:]
- else:
- preds_idx = preds_id
- text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
- if label is None:
- return text
- label = self.decode(label[:, 1:])
- else:
- if isinstance(preds, paddle.Tensor):
- preds = preds.numpy()
- preds_idx = preds.argmax(axis=2)
- preds_prob = preds.max(axis=2)
- text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
- if label is None:
- return text
- label = self.decode(label[:, 1:])
- return text, label
-
- def add_special_char(self, dict_character):
- dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
- return dict_character
-
- def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
- result_list = []
- batch_size = len(text_index)
- for batch_idx in range(batch_size):
- char_list = []
- conf_list = []
- for idx in range(len(text_index[batch_idx])):
- if text_index[batch_idx][idx] == 3: # end
- break
- try:
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
- except:
- continue
- if text_prob is not None:
- conf_list.append(text_prob[batch_idx][idx])
- else:
- conf_list.append(1)
- text = ''.join(char_list)
- result_list.append((text.lower(), np.mean(conf_list).tolist()))
- return result_list
-
-
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
@@ -444,146 +380,6 @@ class SRNLabelDecode(BaseRecLabelDecode):
return idx
-class TableLabelDecode(object):
- """ """
-
- def __init__(self, character_dict_path, **kwargs):
- list_character, list_elem = self.load_char_elem_dict(
- character_dict_path)
- list_character = self.add_special_char(list_character)
- list_elem = self.add_special_char(list_elem)
- self.dict_character = {}
- self.dict_idx_character = {}
- for i, char in enumerate(list_character):
- self.dict_idx_character[i] = char
- self.dict_character[char] = i
- self.dict_elem = {}
- self.dict_idx_elem = {}
- for i, elem in enumerate(list_elem):
- self.dict_idx_elem[i] = elem
- self.dict_elem[elem] = i
-
- def load_char_elem_dict(self, character_dict_path):
- list_character = []
- list_elem = []
- with open(character_dict_path, "rb") as fin:
- lines = fin.readlines()
- substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
- "\t")
- character_num = int(substr[0])
- elem_num = int(substr[1])
- for cno in range(1, 1 + character_num):
- character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
- list_character.append(character)
- for eno in range(1 + character_num, 1 + character_num + elem_num):
- elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
- list_elem.append(elem)
- return list_character, list_elem
-
- def add_special_char(self, list_character):
- self.beg_str = "sos"
- self.end_str = "eos"
- list_character = [self.beg_str] + list_character + [self.end_str]
- return list_character
-
- def __call__(self, preds):
- structure_probs = preds['structure_probs']
- loc_preds = preds['loc_preds']
- if isinstance(structure_probs, paddle.Tensor):
- structure_probs = structure_probs.numpy()
- if isinstance(loc_preds, paddle.Tensor):
- loc_preds = loc_preds.numpy()
- structure_idx = structure_probs.argmax(axis=2)
- structure_probs = structure_probs.max(axis=2)
- structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
- structure_idx, structure_probs, 'elem')
- res_html_code_list = []
- res_loc_list = []
- batch_num = len(structure_str)
- for bno in range(batch_num):
- res_loc = []
- for sno in range(len(structure_str[bno])):
- text = structure_str[bno][sno]
- if text in ['<td>', '<td']:
- if idx > 0 and tmp_elem_idx == end_idx:
- break
- if tmp_elem_idx in ignored_tokens:
- continue
-
- char_list.append(current_dict[tmp_elem_idx])
- elem_pos_list.append(idx)
- score_list.append(structure_probs[batch_idx, idx])
- elem_idx_list.append(tmp_elem_idx)
- result_list.append(char_list)
- result_pos_list.append(elem_pos_list)
- result_score_list.append(score_list)
- result_elem_idx_list.append(elem_idx_list)
- return result_list, result_pos_list, result_score_list, result_elem_idx_list
-
- def get_ignored_tokens(self, char_or_elem):
- beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
- end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
- return [beg_idx, end_idx]
-
- def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
- if char_or_elem == "char":
- if beg_or_end == "beg":
- idx = self.dict_character[self.beg_str]
- elif beg_or_end == "end":
- idx = self.dict_character[self.end_str]
- else:
- assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
- % beg_or_end
- elif char_or_elem == "elem":
- if beg_or_end == "beg":
- idx = self.dict_elem[self.beg_str]
- elif beg_or_end == "end":
- idx = self.dict_elem[self.end_str]
- else:
- assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
- % beg_or_end
- else:
- assert False, "Unsupport type %s in char_or_elem" \
- % char_or_elem
- return idx
-
-
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
@@ -752,3 +548,122 @@ class PRENLabelDecode(BaseRecLabelDecode):
return text
label = self.decode(label)
return text, label
+
+
+class NRTRLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
+ super(NRTRLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+
+ if len(preds) == 2:
+ preds_id = preds[0]
+ preds_prob = preds[1]
+ if isinstance(preds_id, paddle.Tensor):
+ preds_id = preds_id.numpy()
+ if isinstance(preds_prob, paddle.Tensor):
+ preds_prob = preds_prob.numpy()
+ if preds_id[0][0] == 2:
+ preds_idx = preds_id[:, 1:]
+ preds_prob = preds_prob[:, 1:]
+ else:
+ preds_idx = preds_id
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label[:, 1:])
+ else:
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label[:, 1:])
+ return text, label
+
+ def add_special_char(self, dict_character):
+ dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ try:
+ char_idx = self.character[int(text_index[batch_idx][idx])]
+ except:
+ continue
+ if char_idx == '</s>': # end
+ break
+ char_list.append(char_idx)
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text.lower(), np.mean(conf_list).tolist()))
+ return result_list
+
+
+class ViTSTRLabelDecode(NRTRLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(ViTSTRLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds[:, 1:].numpy()
+ else:
+ preds = preds[:, 1:]
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label[:, 1:])
+ return text, label
+
+ def add_special_char(self, dict_character):
+ dict_character = ['<s>', '</s>'] + dict_character
+ return dict_character
+
+
+class ABINetLabelDecode(NRTRLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(ABINetLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ if isinstance(preds, dict):
+ preds = preds['align'][-1].numpy()
+ elif isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ else:
+ preds = preds
+
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label)
+ return text, label
+
+ def add_special_char(self, dict_character):
+        dict_character = ['</s>'] + dict_character
+ return dict_character
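The three decoders above share the same decode loop: take the argmax over the class axis, walk the sequence until the end token, and average the per-step probabilities into a confidence score. Below is a minimal sketch of that loop, assuming a hypothetical six-entry dictionary (the real one comes from `character_dict_path`); it is an illustration, not the library code itself.

```python
import numpy as np

# Hypothetical toy dictionary mirroring NRTRLabelDecode.add_special_char output.
character = ['blank', '<unk>', '<s>', '</s>', 'a', 'b']

# Fake network output: (batch=1, seq_len=4, num_classes); argmax per step -> a, b, </s>, a
logits = np.zeros((1, 4, len(character)), dtype=np.float32)
logits[0, 0, 4] = 0.9   # 'a'
logits[0, 1, 5] = 0.8   # 'b'
logits[0, 2, 3] = 0.95  # '</s>' ends decoding
logits[0, 3, 4] = 0.7   # ignored, comes after the end token

preds_idx = logits.argmax(axis=2)
preds_prob = logits.max(axis=2)

chars, confs = [], []
for idx, prob in zip(preds_idx[0], preds_prob[0]):
    ch = character[int(idx)]
    if ch == '</s>':  # stop at the end token, as in NRTRLabelDecode.decode
        break
    chars.append(ch)
    confs.append(float(prob))

print(''.join(chars).lower(), float(np.mean(confs)))  # -> "ab" with mean confidence 0.85
```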
diff --git a/ppocr/postprocess/table_postprocess.py b/ppocr/postprocess/table_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..4396ec4f701478e7bdcdd8c7752738c5c8ef148d
--- /dev/null
+++ b/ppocr/postprocess/table_postprocess.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+
+from .rec_postprocess import AttnLabelDecode
+
+
+class TableLabelDecode(AttnLabelDecode):
+ """ """
+
+ def __init__(self, character_dict_path, **kwargs):
+ super(TableLabelDecode, self).__init__(character_dict_path)
+        self.td_token = ['<td>', '<td', '<td></td>']
+
+ def __call__(self, preds, batch=None):
+ structure_probs = preds['structure_probs']
+ bbox_preds = preds['loc_preds']
+ if isinstance(structure_probs, paddle.Tensor):
+ structure_probs = structure_probs.numpy()
+ if isinstance(bbox_preds, paddle.Tensor):
+ bbox_preds = bbox_preds.numpy()
+ shape_list = batch[-1]
+ result = self.decode(structure_probs, bbox_preds, shape_list)
+ if len(batch) == 1: # only contains shape
+ return result
+
+ label_decode_result = self.decode_label(batch)
+ return result, label_decode_result
+
+ def decode(self, structure_probs, bbox_preds, shape_list):
+ """convert text-label into text-index.
+ """
+ ignored_tokens = self.get_ignored_tokens()
+ end_idx = self.dict[self.end_str]
+
+ structure_idx = structure_probs.argmax(axis=2)
+ structure_probs = structure_probs.max(axis=2)
+
+ structure_batch_list = []
+ bbox_batch_list = []
+ batch_size = len(structure_idx)
+ for batch_idx in range(batch_size):
+ structure_list = []
+ bbox_list = []
+ score_list = []
+ for idx in range(len(structure_idx[batch_idx])):
+ char_idx = int(structure_idx[batch_idx][idx])
+ if idx > 0 and char_idx == end_idx:
+ break
+ if char_idx in ignored_tokens:
+ continue
+ text = self.character[char_idx]
+ if text in self.td_token:
+ bbox = bbox_preds[batch_idx, idx]
+ bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+ bbox_list.append(bbox)
+ structure_list.append(text)
+ score_list.append(structure_probs[batch_idx, idx])
+ structure_batch_list.append([structure_list, np.mean(score_list)])
+ bbox_batch_list.append(np.array(bbox_list))
+ result = {
+ 'bbox_batch_list': bbox_batch_list,
+ 'structure_batch_list': structure_batch_list,
+ }
+ return result
+
+ def decode_label(self, batch):
+ """convert text-label into text-index.
+ """
+ structure_idx = batch[1]
+ gt_bbox_list = batch[2]
+ shape_list = batch[-1]
+ ignored_tokens = self.get_ignored_tokens()
+ end_idx = self.dict[self.end_str]
+
+ structure_batch_list = []
+ bbox_batch_list = []
+ batch_size = len(structure_idx)
+ for batch_idx in range(batch_size):
+ structure_list = []
+ bbox_list = []
+ for idx in range(len(structure_idx[batch_idx])):
+ char_idx = int(structure_idx[batch_idx][idx])
+ if idx > 0 and char_idx == end_idx:
+ break
+ if char_idx in ignored_tokens:
+ continue
+ structure_list.append(self.character[char_idx])
+
+ bbox = gt_bbox_list[batch_idx][idx]
+ if bbox.sum() != 0:
+ bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+ bbox_list.append(bbox)
+ structure_batch_list.append(structure_list)
+ bbox_batch_list.append(bbox_list)
+ result = {
+ 'bbox_batch_list': bbox_batch_list,
+ 'structure_batch_list': structure_batch_list,
+ }
+ return result
+
+ def _bbox_decode(self, bbox, shape):
+ h, w, ratio_h, ratio_w, pad_h, pad_w = shape
+ src_h = h / ratio_h
+ src_w = w / ratio_w
+ bbox[0::2] *= src_w
+ bbox[1::2] *= src_h
+ return bbox
+
+
+class TableMasterLabelDecode(TableLabelDecode):
+ """ """
+
+ def __init__(self, character_dict_path, box_shape='ori', **kwargs):
+ super(TableMasterLabelDecode, self).__init__(character_dict_path)
+ self.box_shape = box_shape
+ assert box_shape in [
+ 'ori', 'pad'
+ ], 'The shape used for box normalization must be ori or pad'
+
+ def add_special_char(self, dict_character):
+        self.beg_str = '<SOS>'
+        self.end_str = '<EOS>'
+        self.unknown_str = '<UKN>'
+        self.pad_str = '<PAD>'
+ dict_character = dict_character
+ dict_character = dict_character + [
+ self.unknown_str, self.beg_str, self.end_str, self.pad_str
+ ]
+ return dict_character
+
+ def get_ignored_tokens(self):
+ pad_idx = self.dict[self.pad_str]
+ start_idx = self.dict[self.beg_str]
+ end_idx = self.dict[self.end_str]
+ unknown_idx = self.dict[self.unknown_str]
+ return [start_idx, end_idx, pad_idx, unknown_idx]
+
+ def _bbox_decode(self, bbox, shape):
+ h, w, ratio_h, ratio_w, pad_h, pad_w = shape
+ if self.box_shape == 'pad':
+ h, w = pad_h, pad_w
+ bbox[0::2] *= w
+ bbox[1::2] *= h
+ bbox[0::2] /= ratio_w
+ bbox[1::2] /= ratio_h
+ return bbox
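`_bbox_decode` is the main piece the two table decoders differ on: `TableLabelDecode` rescales each normalized box by the original (pre-resize) image size, while `TableMasterLabelDecode` with `box_shape='pad'` rescales by the padded canvas and then undoes the resize ratio. A small numpy sketch with made-up shape metadata `(h, w, ratio_h, ratio_w, pad_h, pad_w)`:

```python
import numpy as np

shape = (240.0, 320.0, 0.5, 0.5, 488.0, 488.0)  # assumed resize/pad metadata
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
bbox = np.array([0.1, 0.2, 0.3, 0.4])           # normalized x1, y1, x2, y2

# TableLabelDecode: scale by the original image size (resized size / ratio).
attn = bbox.copy()
attn[0::2] *= w / ratio_w
attn[1::2] *= h / ratio_h

# TableMasterLabelDecode with box_shape='pad': scale by the padded canvas,
# then undo the resize ratio.
master = bbox.copy()
master[0::2] = master[0::2] * pad_w / ratio_w
master[1::2] = master[1::2] * pad_h / ratio_h

print(attn, master)
```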
diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
index 782cdea6c58c69e0d728787e0e21e200c9e13790..8a6669f71f5ae6a7a16931e565b43355de5928d9 100644
--- a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
@@ -41,11 +41,13 @@ class VQASerTokenLayoutLMPostProcess(object):
self.id2label_map_for_show[val] = key
def __call__(self, preds, batch=None, *args, **kwargs):
+ if isinstance(preds, tuple):
+ preds = preds[0]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if batch is not None:
- return self._metric(preds, batch[1])
+ return self._metric(preds, batch[5])
else:
return self._infer(preds, **kwargs)
@@ -63,11 +65,11 @@ class VQASerTokenLayoutLMPostProcess(object):
j]])
return decode_out_list, label_decode_out_list
- def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
+ def _infer(self, preds, segment_offset_ids, ocr_infos):
results = []
- for pred, attention_mask, segment_offset_id, ocr_info in zip(
- preds, attention_masks, segment_offset_ids, ocr_infos):
+ for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
+ ocr_infos):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
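After dropping the attention masks, `_infer` still reduces each token's logits with an argmax and maps the resulting ids through `id2label_map`. A toy illustration with an assumed five-class BIO map (not the real project config):

```python
import numpy as np

# Hypothetical id-to-label map; the real one is built from the SER class list.
id2label_map = {0: 'O', 1: 'B-QUESTION', 2: 'I-QUESTION', 3: 'B-ANSWER', 4: 'I-ANSWER'}

token_logits = np.array([[2.0, 0.1, 0.1, 0.5, 0.3],   # -> O
                         [0.1, 3.0, 0.2, 0.1, 0.1],   # -> B-QUESTION
                         [0.1, 0.2, 2.5, 0.1, 0.1]])  # -> I-QUESTION
pred = np.argmax(token_logits, axis=1)
print([id2label_map[int(idx)] for idx in pred])  # ['O', 'B-QUESTION', 'I-QUESTION']
```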
diff --git a/ppocr/utils/dict/table_master_structure_dict.txt b/ppocr/utils/dict/table_master_structure_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..95ab2539a70aca4f695c53a38cdc1c3e164fcfb3
--- /dev/null
+++ b/ppocr/utils/dict/table_master_structure_dict.txt
@@ -0,0 +1,39 @@
+
+
+ |
+
+
+
+
+
+
+ |
+ colspan="2"
+ colspan="3"
+
+
+ rowspan="2"
+ colspan="4"
+ colspan="6"
+ rowspan="3"
+ colspan="9"
+ colspan="10"
+ colspan="7"
+ rowspan="4"
+ rowspan="5"
+ rowspan="9"
+ colspan="8"
+ rowspan="8"
+ rowspan="6"
+ rowspan="7"
+ rowspan="10"
+
+
+
+
+
+
+
+
diff --git a/ppocr/utils/dict/table_structure_dict.txt b/ppocr/utils/dict/table_structure_dict.txt
index 9c4531e5f3b8c498e70d3c2ea0471e5e746a2c30..8edb10b8817ad596af6c63b6b8fc5eb2349b7464 100644
--- a/ppocr/utils/dict/table_structure_dict.txt
+++ b/ppocr/utils/dict/table_structure_dict.txt
@@ -1,281 +1,3 @@
-277 28 1267 1186
-
-V
-a
-r
-i
-b
-l
-e
-
-H
-z
-d
-
-t
-o
-9
-5
-%
-C
-I
-
-p
-
-v
-u
-*
-A
-g
-(
-m
-n
-)
-0
-.
-7
-1
-6
-≤
->
-8
-3
-–
-2
-G
-4
-M
-F
-T
-y
-f
-s
-L
-w
-c
-U
-h
-D
-S
-Q
-R
-x
-P
--
-E
-O
-/
-k
-,
-+
-N
-K
-q
-′
-[
-]
-<
-≥
-
-−
-
-μ
-±
-J
-j
-W
-_
-Δ
-B
-“
-:
-Y
-α
-λ
-;
-
-
-?
-∼
-=
-°
-#
-̊
-̈
-̂
-’
-Z
-X
-∗
-—
-β
-'
-†
-~
-@
-"
-γ
-↓
-↑
-&
-‡
-χ
-”
-σ
-§
-|
-¶
-‐
-×
-$
-→
-√
-✓
-‘
-\
-∞
-π
-•
-®
-^
-∆
-≧
-
-
-́
-♀
-♂
-‒
-⁎
-▲
-·
-£
-φ
-Ψ
-ß
-△
-☆
-▪
-η
-€
-∧
-̃
-Φ
-ρ
-̄
-δ
-‰
-̧
-Ω
-♦
-{
-}
-̀
-∑
-∫
-ø
-κ
-ε
-¥
-※
-`
-ω
-Σ
-➔
-‖
-Β
-̸
-
-─
-●
-⩾
-Χ
-Α
-⋅
-◆
-★
-■
-ψ
-ǂ
-□
-ζ
-!
-Γ
-↔
-θ
-⁄
-〈
-〉
-―
-υ
-τ
-⋆
-Ø
-©
-∥
-С
-˂
-➢
-ɛ
-
-✗
-←
-○
-¢
-⩽
-∖
-˃
-
-≈
-Π
-̌
-≦
-∅
-ᅟ
-
-
-∣
-¤
-♯
-̆
-ξ
-÷
-▼
-
-ι
-ν
-║
-
-
-◦
-
-◊
-∙
-«
-»
-ł
-ı
-Θ
-∈
-„
-∘
-✔
-̇
-æ
-ʹ
-ˆ
-♣
-⇓
-∩
-⊕
-⇒
-⇑
-̨
-Ι
-Λ
-⋯
-А
-⋮
@@ -303,2457 +25,4 @@ $
rowspan="8"
rowspan="6"
rowspan="7"
- rowspan="10"
-0 2924682
-1 3405345
-2 2363468
-3 2709165
-4 4078680
-5 3250792
-6 1923159
-7 1617890
-8 1450532
-9 1717624
-10 1477550
-11 1489223
-12 915528
-13 819193
-14 593660
-15 518924
-16 682065
-17 494584
-18 400591
-19 396421
-20 340994
-21 280688
-22 250328
-23 226786
-24 199927
-25 182707
-26 164629
-27 141613
-28 127554
-29 116286
-30 107682
-31 96367
-32 88002
-33 79234
-34 72186
-35 65921
-36 60374
-37 55976
-38 52166
-39 47414
-40 44932
-41 41279
-42 38232
-43 35463
-44 33703
-45 30557
-46 29639
-47 27000
-48 25447
-49 23186
-50 22093
-51 20412
-52 19844
-53 18261
-54 17561
-55 16499
-56 15597
-57 14558
-58 14372
-59 13445
-60 13514
-61 12058
-62 11145
-63 10767
-64 10370
-65 9630
-66 9337
-67 8881
-68 8727
-69 8060
-70 7994
-71 7740
-72 7189
-73 6729
-74 6749
-75 6548
-76 6321
-77 5957
-78 5740
-79 5407
-80 5370
-81 5035
-82 4921
-83 4656
-84 4600
-85 4519
-86 4277
-87 4023
-88 3939
-89 3910
-90 3861
-91 3560
-92 3483
-93 3406
-94 3346
-95 3229
-96 3122
-97 3086
-98 3001
-99 2884
-100 2822
-101 2677
-102 2670
-103 2610
-104 2452
-105 2446
-106 2400
-107 2300
-108 2316
-109 2196
-110 2089
-111 2083
-112 2041
-113 1881
-114 1838
-115 1896
-116 1795
-117 1786
-118 1743
-119 1765
-120 1750
-121 1683
-122 1563
-123 1499
-124 1513
-125 1462
-126 1388
-127 1441
-128 1417
-129 1392
-130 1306
-131 1321
-132 1274
-133 1294
-134 1240
-135 1126
-136 1157
-137 1130
-138 1084
-139 1130
-140 1083
-141 1040
-142 980
-143 1031
-144 974
-145 980
-146 932
-147 898
-148 960
-149 907
-150 852
-151 912
-152 859
-153 847
-154 876
-155 792
-156 791
-157 765
-158 788
-159 787
-160 744
-161 673
-162 683
-163 697
-164 666
-165 680
-166 632
-167 677
-168 657
-169 618
-170 587
-171 585
-172 567
-173 549
-174 562
-175 548
-176 542
-177 539
-178 542
-179 549
-180 547
-181 526
-182 525
-183 514
-184 512
-185 505
-186 515
-187 467
-188 475
-189 458
-190 435
-191 443
-192 427
-193 424
-194 404
-195 389
-196 429
-197 404
-198 386
-199 351
-200 388
-201 408
-202 361
-203 346
-204 324
-205 361
-206 363
-207 364
-208 323
-209 336
-210 342
-211 315
-212 325
-213 328
-214 314
-215 327
-216 320
-217 300
-218 295
-219 315
-220 310
-221 295
-222 275
-223 248
-224 274
-225 232
-226 293
-227 259
-228 286
-229 263
-230 242
-231 214
-232 261
-233 231
-234 211
-235 250
-236 233
-237 206
-238 224
-239 210
-240 233
-241 223
-242 216
-243 222
-244 207
-245 212
-246 196
-247 205
-248 201
-249 202
-250 211
-251 201
-252 215
-253 179
-254 163
-255 179
-256 191
-257 188
-258 196
-259 150
-260 154
-261 176
-262 211
-263 166
-264 171
-265 165
-266 149
-267 182
-268 159
-269 161
-270 164
-271 161
-272 141
-273 151
-274 127
-275 129
-276 142
-277 158
-278 148
-279 135
-280 127
-281 134
-282 138
-283 131
-284 126
-285 125
-286 130
-287 126
-288 135
-289 125
-290 135
-291 131
-292 95
-293 135
-294 106
-295 117
-296 136
-297 128
-298 128
-299 118
-300 109
-301 112
-302 117
-303 108
-304 120
-305 100
-306 95
-307 108
-308 112
-309 77
-310 120
-311 104
-312 109
-313 89
-314 98
-315 82
-316 98
-317 93
-318 77
-319 93
-320 77
-321 98
-322 93
-323 86
-324 89
-325 73
-326 70
-327 71
-328 77
-329 87
-330 77
-331 93
-332 100
-333 83
-334 72
-335 74
-336 69
-337 77
-338 68
-339 78
-340 90
-341 98
-342 75
-343 80
-344 63
-345 71
-346 83
-347 66
-348 71
-349 70
-350 62
-351 62
-352 59
-353 63
-354 62
-355 52
-356 64
-357 64
-358 56
-359 49
-360 57
-361 63
-362 60
-363 68
-364 62
-365 55
-366 54
-367 40
-368 75
-369 70
-370 53
-371 58
-372 57
-373 55
-374 69
-375 57
-376 53
-377 43
-378 45
-379 47
-380 56
-381 51
-382 59
-383 51
-384 43
-385 34
-386 57
-387 49
-388 39
-389 46
-390 48
-391 43
-392 40
-393 54
-394 50
-395 41
-396 43
-397 33
-398 27
-399 49
-400 44
-401 44
-402 38
-403 30
-404 32
-405 37
-406 39
-407 42
-408 53
-409 39
-410 34
-411 31
-412 32
-413 52
-414 27
-415 41
-416 34
-417 36
-418 50
-419 35
-420 32
-421 33
-422 45
-423 35
-424 40
-425 29
-426 41
-427 40
-428 39
-429 32
-430 31
-431 34
-432 29
-433 27
-434 26
-435 22
-436 34
-437 28
-438 30
-439 38
-440 35
-441 36
-442 36
-443 27
-444 24
-445 33
-446 31
-447 25
-448 33
-449 27
-450 32
-451 46
-452 31
-453 35
-454 35
-455 34
-456 26
-457 21
-458 25
-459 26
-460 24
-461 27
-462 33
-463 30
-464 35
-465 21
-466 32
-467 19
-468 27
-469 16
-470 28
-471 26
-472 27
-473 26
-474 25
-475 25
-476 27
-477 20
-478 28
-479 22
-480 23
-481 16
-482 25
-483 27
-484 19
-485 23
-486 19
-487 15
-488 15
-489 23
-490 24
-491 19
-492 20
-493 18
-494 17
-495 30
-496 28
-497 20
-498 29
-499 17
-500 19
-501 21
-502 15
-503 24
-504 15
-505 19
-506 25
-507 16
-508 23
-509 26
-510 21
-511 15
-512 12
-513 16
-514 18
-515 24
-516 26
-517 18
-518 8
-519 25
-520 14
-521 8
-522 24
-523 20
-524 18
-525 15
-526 13
-527 17
-528 18
-529 22
-530 21
-531 9
-532 16
-533 17
-534 13
-535 17
-536 15
-537 13
-538 20
-539 13
-540 19
-541 29
-542 10
-543 8
-544 18
-545 13
-546 9
-547 18
-548 10
-549 18
-550 18
-551 9
-552 9
-553 15
-554 13
-555 15
-556 14
-557 14
-558 18
-559 8
-560 13
-561 9
-562 7
-563 12
-564 6
-565 9
-566 9
-567 18
-568 9
-569 10
-570 13
-571 14
-572 13
-573 21
-574 8
-575 16
-576 12
-577 9
-578 16
-579 17
-580 22
-581 6
-582 14
-583 13
-584 15
-585 11
-586 13
-587 5
-588 12
-589 13
-590 15
-591 13
-592 15
-593 12
-594 7
-595 18
-596 12
-597 13
-598 13
-599 13
-600 12
-601 12
-602 10
-603 11
-604 6
-605 6
-606 2
-607 9
-608 8
-609 12
-610 9
-611 12
-612 13
-613 12
-614 14
-615 9
-616 8
-617 9
-618 14
-619 13
-620 12
-621 6
-622 8
-623 8
-624 8
-625 12
-626 8
-627 7
-628 5
-629 8
-630 12
-631 6
-632 10
-633 10
-634 7
-635 8
-636 9
-637 6
-638 9
-639 4
-640 12
-641 4
-642 3
-643 11
-644 10
-645 6
-646 12
-647 12
-648 4
-649 4
-650 9
-651 8
-652 6
-653 5
-654 14
-655 10
-656 11
-657 8
-658 5
-659 5
-660 9
-661 13
-662 4
-663 5
-664 9
-665 11
-666 12
-667 7
-668 13
-669 2
-670 1
-671 7
-672 7
-673 7
-674 10
-675 9
-676 6
-677 5
-678 7
-679 6
-680 3
-681 3
-682 4
-683 9
-684 8
-685 5
-686 3
-687 11
-688 9
-689 2
-690 6
-691 5
-692 9
-693 5
-694 6
-695 5
-696 9
-697 8
-698 3
-699 7
-700 5
-701 9
-702 8
-703 7
-704 2
-705 3
-706 7
-707 6
-708 6
-709 10
-710 2
-711 10
-712 6
-713 7
-714 5
-715 6
-716 4
-717 6
-718 8
-719 4
-720 6
-721 7
-722 5
-723 7
-724 3
-725 10
-726 10
-727 3
-728 7
-729 7
-730 5
-731 2
-732 1
-733 5
-734 1
-735 5
-736 6
-737 2
-738 2
-739 3
-740 7
-741 2
-742 7
-743 4
-744 5
-745 4
-746 5
-747 3
-748 1
-749 4
-750 4
-751 2
-752 4
-753 6
-754 6
-755 6
-756 3
-757 2
-758 5
-759 5
-760 3
-761 4
-762 2
-763 1
-764 8
-765 3
-766 4
-767 3
-768 1
-769 5
-770 3
-771 3
-772 4
-773 4
-774 1
-775 3
-776 2
-777 2
-778 3
-779 3
-780 1
-781 4
-782 3
-783 4
-784 6
-785 3
-786 5
-787 4
-788 2
-789 4
-790 5
-791 4
-792 6
-794 4
-795 1
-796 1
-797 4
-798 2
-799 3
-800 3
-801 1
-802 5
-803 5
-804 3
-805 3
-806 3
-807 4
-808 4
-809 2
-811 5
-812 4
-813 6
-814 3
-815 2
-816 2
-817 3
-818 5
-819 3
-820 1
-821 1
-822 4
-823 3
-824 4
-825 8
-826 3
-827 5
-828 5
-829 3
-830 6
-831 3
-832 4
-833 8
-834 5
-835 3
-836 3
-837 2
-838 4
-839 2
-840 1
-841 3
-842 2
-843 1
-844 3
-846 4
-847 4
-848 3
-849 3
-850 2
-851 3
-853 1
-854 4
-855 4
-856 2
-857 4
-858 1
-859 2
-860 5
-861 1
-862 1
-863 4
-864 2
-865 2
-867 5
-868 1
-869 4
-870 1
-871 1
-872 1
-873 2
-875 5
-876 3
-877 1
-878 3
-879 3
-880 3
-881 2
-882 1
-883 6
-884 2
-885 2
-886 1
-887 1
-888 3
-889 2
-890 2
-891 3
-892 1
-893 3
-894 1
-895 5
-896 1
-897 3
-899 2
-900 2
-902 1
-903 2
-904 4
-905 4
-906 3
-907 1
-908 1
-909 2
-910 5
-911 2
-912 3
-914 1
-915 1
-916 2
-918 2
-919 2
-920 4
-921 4
-922 1
-923 1
-924 4
-925 5
-926 1
-928 2
-929 1
-930 1
-931 1
-932 1
-933 1
-934 2
-935 1
-936 1
-937 1
-938 2
-939 1
-941 1
-942 4
-944 2
-945 2
-946 2
-947 1
-948 1
-950 1
-951 2
-953 1
-954 2
-955 1
-956 1
-957 2
-958 1
-960 3
-962 4
-963 1
-964 1
-965 3
-966 2
-967 2
-968 1
-969 3
-970 3
-972 1
-974 4
-975 3
-976 3
-977 2
-979 2
-980 1
-981 1
-983 5
-984 1
-985 3
-986 1
-987 2
-988 4
-989 2
-991 2
-992 2
-993 1
-994 1
-996 2
-997 2
-998 1
-999 3
-1000 2
-1001 1
-1002 3
-1003 3
-1004 2
-1005 3
-1006 1
-1007 2
-1009 1
-1011 1
-1013 3
-1014 1
-1016 2
-1017 1
-1018 1
-1019 1
-1020 4
-1021 1
-1022 2
-1025 1
-1026 1
-1027 2
-1028 1
-1030 1
-1031 2
-1032 4
-1034 3
-1035 2
-1036 1
-1038 1
-1039 1
-1040 1
-1041 1
-1042 2
-1043 1
-1044 2
-1045 4
-1048 1
-1050 1
-1051 1
-1052 2
-1054 1
-1055 3
-1056 2
-1057 1
-1059 1
-1061 2
-1063 1
-1064 1
-1065 1
-1066 1
-1067 1
-1068 1
-1069 2
-1074 1
-1075 1
-1077 1
-1078 1
-1079 1
-1082 1
-1085 1
-1088 1
-1090 1
-1091 1
-1092 2
-1094 2
-1097 2
-1098 1
-1099 2
-1101 2
-1102 1
-1104 1
-1105 1
-1107 1
-1109 1
-1111 2
-1112 1
-1114 2
-1115 2
-1116 2
-1117 1
-1118 1
-1119 1
-1120 1
-1122 1
-1123 1
-1127 1
-1128 3
-1132 2
-1138 3
-1142 1
-1145 4
-1150 1
-1153 2
-1154 1
-1158 1
-1159 1
-1163 1
-1165 1
-1169 2
-1174 1
-1176 1
-1177 1
-1178 2
-1179 1
-1180 2
-1181 1
-1182 1
-1183 2
-1185 1
-1187 1
-1191 2
-1193 1
-1195 3
-1196 1
-1201 3
-1203 1
-1206 1
-1210 1
-1213 1
-1214 1
-1215 2
-1218 1
-1220 1
-1221 1
-1225 1
-1226 1
-1233 2
-1241 1
-1243 1
-1249 1
-1250 2
-1251 1
-1254 1
-1255 2
-1260 1
-1268 1
-1270 1
-1273 1
-1274 1
-1277 1
-1284 1
-1287 1
-1291 1
-1292 2
-1294 1
-1295 2
-1297 1
-1298 1
-1301 1
-1307 1
-1308 3
-1311 2
-1313 1
-1316 1
-1321 1
-1324 1
-1325 1
-1330 1
-1333 1
-1334 1
-1338 2
-1340 1
-1341 1
-1342 1
-1343 1
-1345 1
-1355 1
-1357 1
-1360 2
-1375 1
-1376 1
-1380 1
-1383 1
-1387 1
-1389 1
-1393 1
-1394 1
-1396 1
-1398 1
-1410 1
-1414 1
-1419 1
-1425 1
-1434 1
-1435 1
-1438 1
-1439 1
-1447 1
-1455 2
-1460 1
-1461 1
-1463 1
-1466 1
-1470 1
-1473 1
-1478 1
-1480 1
-1483 1
-1484 1
-1485 2
-1492 2
-1499 1
-1509 1
-1512 1
-1513 1
-1523 1
-1524 1
-1525 2
-1529 1
-1539 1
-1544 1
-1568 1
-1584 1
-1591 1
-1598 1
-1600 1
-1604 1
-1614 1
-1617 1
-1621 1
-1622 1
-1626 1
-1638 1
-1648 1
-1658 1
-1661 1
-1679 1
-1682 1
-1693 1
-1700 1
-1705 1
-1707 1
-1722 1
-1728 1
-1758 1
-1762 1
-1763 1
-1775 1
-1776 1
-1801 1
-1810 1
-1812 1
-1827 1
-1834 1
-1846 1
-1847 1
-1848 1
-1851 1
-1862 1
-1866 1
-1877 2
-1884 1
-1888 1
-1903 1
-1912 1
-1925 1
-1938 1
-1955 1
-1998 1
-2054 1
-2058 1
-2065 1
-2069 1
-2076 1
-2089 1
-2104 1
-2111 1
-2133 1
-2138 1
-2156 1
-2204 1
-2212 1
-2237 1
-2246 2
-2298 1
-2304 1
-2360 1
-2400 1
-2481 1
-2544 1
-2586 1
-2622 1
-2666 1
-2682 1
-2725 1
-2920 1
-3997 1
-4019 1
-5211 1
-12 19
-14 1
-16 401
-18 2
-20 421
-22 557
-24 625
-26 50
-28 4481
-30 52
-32 550
-34 5840
-36 4644
-38 87
-40 5794
-41 33
-42 571
-44 11805
-46 4711
-47 7
-48 597
-49 12
-50 678
-51 2
-52 14715
-53 3
-54 7322
-55 3
-56 508
-57 39
-58 3486
-59 11
-60 8974
-61 45
-62 1276
-63 4
-64 15693
-65 15
-66 657
-67 13
-68 6409
-69 10
-70 3188
-71 25
-72 1889
-73 27
-74 10370
-75 9
-76 12432
-77 23
-78 520
-79 15
-80 1534
-81 29
-82 2944
-83 23
-84 12071
-85 36
-86 1502
-87 10
-88 10978
-89 11
-90 889
-91 16
-92 4571
-93 17
-94 7855
-95 21
-96 2271
-97 33
-98 1423
-99 15
-100 11096
-101 21
-102 4082
-103 13
-104 5442
-105 25
-106 2113
-107 26
-108 3779
-109 43
-110 1294
-111 29
-112 7860
-113 29
-114 4965
-115 22
-116 7898
-117 25
-118 1772
-119 28
-120 1149
-121 38
-122 1483
-123 32
-124 10572
-125 25
-126 1147
-127 31
-128 1699
-129 22
-130 5533
-131 22
-132 4669
-133 34
-134 3777
-135 10
-136 5412
-137 21
-138 855
-139 26
-140 2485
-141 46
-142 1970
-143 27
-144 6565
-145 40
-146 933
-147 15
-148 7923
-149 16
-150 735
-151 23
-152 1111
-153 33
-154 3714
-155 27
-156 2445
-157 30
-158 3367
-159 10
-160 4646
-161 27
-162 990
-163 23
-164 5679
-165 25
-166 2186
-167 17
-168 899
-169 32
-170 1034
-171 22
-172 6185
-173 32
-174 2685
-175 17
-176 1354
-177 38
-178 1460
-179 15
-180 3478
-181 20
-182 958
-183 20
-184 6055
-185 23
-186 2180
-187 15
-188 1416
-189 30
-190 1284
-191 22
-192 1341
-193 21
-194 2413
-195 18
-196 4984
-197 13
-198 830
-199 22
-200 1834
-201 19
-202 2238
-203 9
-204 3050
-205 22
-206 616
-207 17
-208 2892
-209 22
-210 711
-211 30
-212 2631
-213 19
-214 3341
-215 21
-216 987
-217 26
-218 823
-219 9
-220 3588
-221 20
-222 692
-223 7
-224 2925
-225 31
-226 1075
-227 16
-228 2909
-229 18
-230 673
-231 20
-232 2215
-233 14
-234 1584
-235 21
-236 1292
-237 29
-238 1647
-239 25
-240 1014
-241 30
-242 1648
-243 19
-244 4465
-245 10
-246 787
-247 11
-248 480
-249 25
-250 842
-251 15
-252 1219
-253 23
-254 1508
-255 8
-256 3525
-257 16
-258 490
-259 12
-260 1678
-261 14
-262 822
-263 16
-264 1729
-265 28
-266 604
-267 11
-268 2572
-269 7
-270 1242
-271 15
-272 725
-273 18
-274 1983
-275 13
-276 1662
-277 19
-278 491
-279 12
-280 1586
-281 14
-282 563
-283 10
-284 2363
-285 10
-286 656
-287 14
-288 725
-289 28
-290 871
-291 9
-292 2606
-293 12
-294 961
-295 9
-296 478
-297 13
-298 1252
-299 10
-300 736
-301 19
-302 466
-303 13
-304 2254
-305 12
-306 486
-307 14
-308 1145
-309 13
-310 955
-311 13
-312 1235
-313 13
-314 931
-315 14
-316 1768
-317 11
-318 330
-319 10
-320 539
-321 23
-322 570
-323 12
-324 1789
-325 13
-326 884
-327 5
-328 1422
-329 14
-330 317
-331 11
-332 509
-333 13
-334 1062
-335 12
-336 577
-337 27
-338 378
-339 10
-340 2313
-341 9
-342 391
-343 13
-344 894
-345 17
-346 664
-347 9
-348 453
-349 6
-350 363
-351 15
-352 1115
-353 13
-354 1054
-355 8
-356 1108
-357 12
-358 354
-359 7
-360 363
-361 16
-362 344
-363 11
-364 1734
-365 12
-366 265
-367 10
-368 969
-369 16
-370 316
-371 12
-372 757
-373 7
-374 563
-375 15
-376 857
-377 9
-378 469
-379 9
-380 385
-381 12
-382 921
-383 15
-384 764
-385 14
-386 246
-387 6
-388 1108
-389 14
-390 230
-391 8
-392 266
-393 11
-394 641
-395 8
-396 719
-397 9
-398 243
-399 4
-400 1108
-401 7
-402 229
-403 7
-404 903
-405 7
-406 257
-407 12
-408 244
-409 3
-410 541
-411 6
-412 744
-413 8
-414 419
-415 8
-416 388
-417 19
-418 470
-419 14
-420 612
-421 6
-422 342
-423 3
-424 1179
-425 3
-426 116
-427 14
-428 207
-429 6
-430 255
-431 4
-432 288
-433 12
-434 343
-435 6
-436 1015
-437 3
-438 538
-439 10
-440 194
-441 6
-442 188
-443 15
-444 524
-445 7
-446 214
-447 7
-448 574
-449 6
-450 214
-451 5
-452 635
-453 9
-454 464
-455 5
-456 205
-457 9
-458 163
-459 2
-460 558
-461 4
-462 171
-463 14
-464 444
-465 11
-466 543
-467 5
-468 388
-469 6
-470 141
-471 4
-472 647
-473 3
-474 210
-475 4
-476 193
-477 7
-478 195
-479 7
-480 443
-481 10
-482 198
-483 3
-484 816
-485 6
-486 128
-487 9
-488 215
-489 9
-490 328
-491 7
-492 158
-493 11
-494 335
-495 8
-496 435
-497 6
-498 174
-499 1
-500 373
-501 5
-502 140
-503 7
-504 330
-505 9
-506 149
-507 5
-508 642
-509 3
-510 179
-511 3
-512 159
-513 8
-514 204
-515 7
-516 306
-517 4
-518 110
-519 5
-520 326
-521 6
-522 305
-523 6
-524 294
-525 7
-526 268
-527 5
-528 149
-529 4
-530 133
-531 2
-532 513
-533 10
-534 116
-535 5
-536 258
-537 4
-538 113
-539 4
-540 138
-541 6
-542 116
-544 485
-545 4
-546 93
-547 9
-548 299
-549 3
-550 256
-551 6
-552 92
-553 3
-554 175
-555 6
-556 253
-557 7
-558 95
-559 2
-560 128
-561 4
-562 206
-563 2
-564 465
-565 3
-566 69
-567 3
-568 157
-569 7
-570 97
-571 8
-572 118
-573 5
-574 130
-575 4
-576 301
-577 6
-578 177
-579 2
-580 397
-581 3
-582 80
-583 1
-584 128
-585 5
-586 52
-587 2
-588 72
-589 1
-590 84
-591 6
-592 323
-593 11
-594 77
-595 5
-596 205
-597 1
-598 244
-599 4
-600 69
-601 3
-602 89
-603 5
-604 254
-605 6
-606 147
-607 3
-608 83
-609 3
-610 77
-611 3
-612 194
-613 1
-614 98
-615 3
-616 243
-617 3
-618 50
-619 8
-620 188
-621 4
-622 67
-623 4
-624 123
-625 2
-626 50
-627 1
-628 239
-629 2
-630 51
-631 4
-632 65
-633 5
-634 188
-636 81
-637 3
-638 46
-639 3
-640 103
-641 1
-642 136
-643 3
-644 188
-645 3
-646 58
-648 122
-649 4
-650 47
-651 2
-652 155
-653 4
-654 71
-655 1
-656 71
-657 3
-658 50
-659 2
-660 177
-661 5
-662 66
-663 2
-664 183
-665 3
-666 50
-667 2
-668 53
-669 2
-670 115
-672 66
-673 2
-674 47
-675 1
-676 197
-677 2
-678 46
-679 3
-680 95
-681 3
-682 46
-683 3
-684 107
-685 1
-686 86
-687 2
-688 158
-689 4
-690 51
-691 1
-692 80
-694 56
-695 4
-696 40
-698 43
-699 3
-700 95
-701 2
-702 51
-703 2
-704 133
-705 1
-706 100
-707 2
-708 121
-709 2
-710 15
-711 3
-712 35
-713 2
-714 20
-715 3
-716 37
-717 2
-718 78
-720 55
-721 1
-722 42
-723 2
-724 218
-725 3
-726 23
-727 2
-728 26
-729 1
-730 64
-731 2
-732 65
-734 24
-735 2
-736 53
-737 1
-738 32
-739 1
-740 60
-742 81
-743 1
-744 77
-745 1
-746 47
-747 1
-748 62
-749 1
-750 19
-751 1
-752 86
-753 3
-754 40
-756 55
-757 2
-758 38
-759 1
-760 101
-761 1
-762 22
-764 67
-765 2
-766 35
-767 1
-768 38
-769 1
-770 22
-771 1
-772 82
-773 1
-774 73
-776 29
-777 1
-778 55
-780 23
-781 1
-782 16
-784 84
-785 3
-786 28
-788 59
-789 1
-790 33
-791 3
-792 24
-794 13
-795 1
-796 110
-797 2
-798 15
-800 22
-801 3
-802 29
-803 1
-804 87
-806 21
-808 29
-810 48
-812 28
-813 1
-814 58
-815 1
-816 48
-817 1
-818 31
-819 1
-820 66
-822 17
-823 2
-824 58
-826 10
-827 2
-828 25
-829 1
-830 29
-831 1
-832 63
-833 1
-834 26
-835 3
-836 52
-837 1
-838 18
-840 27
-841 2
-842 12
-843 1
-844 83
-845 1
-846 7
-847 1
-848 10
-850 26
-852 25
-853 1
-854 15
-856 27
-858 32
-859 1
-860 15
-862 43
-864 32
-865 1
-866 6
-868 39
-870 11
-872 25
-873 1
-874 10
-875 1
-876 20
-877 2
-878 19
-879 1
-880 30
-882 11
-884 53
-886 25
-887 1
-888 28
-890 6
-892 36
-894 10
-896 13
-898 14
-900 31
-902 14
-903 2
-904 43
-906 25
-908 9
-910 11
-911 1
-912 16
-913 1
-914 24
-916 27
-918 6
-920 15
-922 27
-923 1
-924 23
-926 13
-928 42
-929 1
-930 3
-932 27
-934 17
-936 8
-937 1
-938 11
-940 33
-942 4
-943 1
-944 18
-946 15
-948 13
-950 18
-952 12
-954 11
-956 21
-958 10
-960 13
-962 5
-964 32
-966 13
-968 8
-970 8
-971 1
-972 23
-973 2
-974 12
-975 1
-976 22
-978 7
-979 1
-980 14
-982 8
-984 22
-985 1
-986 6
-988 17
-989 1
-990 6
-992 13
-994 19
-996 11
-998 4
-1000 9
-1002 2
-1004 14
-1006 5
-1008 3
-1010 9
-1012 29
-1014 6
-1016 22
-1017 1
-1018 8
-1019 1
-1020 7
-1022 6
-1023 1
-1024 10
-1026 2
-1028 8
-1030 11
-1031 2
-1032 8
-1034 9
-1036 13
-1038 12
-1040 12
-1042 3
-1044 12
-1046 3
-1048 11
-1050 2
-1051 1
-1052 2
-1054 11
-1056 6
-1058 8
-1059 1
-1060 23
-1062 6
-1063 1
-1064 8
-1066 3
-1068 6
-1070 8
-1071 1
-1072 5
-1074 3
-1076 5
-1078 3
-1080 11
-1081 1
-1082 7
-1084 18
-1086 4
-1087 1
-1088 3
-1090 3
-1092 7
-1094 3
-1096 12
-1098 6
-1099 1
-1100 2
-1102 6
-1104 14
-1106 3
-1108 6
-1110 5
-1112 2
-1114 8
-1116 3
-1118 3
-1120 7
-1122 10
-1124 6
-1126 8
-1128 1
-1130 4
-1132 3
-1134 2
-1136 5
-1138 5
-1140 8
-1142 3
-1144 7
-1146 3
-1148 11
-1150 1
-1152 5
-1154 1
-1156 5
-1158 1
-1160 5
-1162 3
-1164 6
-1165 1
-1166 1
-1168 4
-1169 1
-1170 3
-1171 1
-1172 2
-1174 5
-1176 3
-1177 1
-1180 8
-1182 2
-1184 4
-1186 2
-1188 3
-1190 2
-1192 5
-1194 6
-1196 1
-1198 2
-1200 2
-1204 10
-1206 2
-1208 9
-1210 1
-1214 6
-1216 3
-1218 4
-1220 9
-1221 2
-1222 1
-1224 5
-1226 4
-1228 8
-1230 1
-1232 1
-1234 3
-1236 5
-1240 3
-1242 1
-1244 3
-1245 1
-1246 4
-1248 6
-1250 2
-1252 7
-1256 3
-1258 2
-1260 2
-1262 3
-1264 4
-1265 1
-1266 1
-1270 1
-1271 1
-1272 2
-1274 3
-1276 3
-1278 1
-1280 3
-1284 1
-1286 1
-1290 1
-1292 3
-1294 1
-1296 7
-1300 2
-1302 4
-1304 3
-1306 2
-1308 2
-1312 1
-1314 1
-1316 3
-1318 2
-1320 1
-1324 8
-1326 1
-1330 1
-1331 1
-1336 2
-1338 1
-1340 3
-1341 1
-1344 1
-1346 2
-1347 1
-1348 3
-1352 1
-1354 2
-1356 1
-1358 1
-1360 3
-1362 1
-1364 4
-1366 1
-1370 1
-1372 3
-1380 2
-1384 2
-1388 2
-1390 2
-1392 2
-1394 1
-1396 1
-1398 1
-1400 2
-1402 1
-1404 1
-1406 1
-1410 1
-1412 5
-1418 1
-1420 1
-1424 1
-1432 2
-1434 2
-1442 3
-1444 5
-1448 1
-1454 1
-1456 1
-1460 3
-1462 4
-1468 1
-1474 1
-1476 1
-1478 2
-1480 1
-1486 2
-1488 1
-1492 1
-1496 1
-1500 3
-1503 1
-1506 1
-1512 2
-1516 1
-1522 1
-1524 2
-1534 4
-1536 1
-1538 1
-1540 2
-1544 2
-1548 1
-1556 1
-1560 1
-1562 1
-1564 2
-1566 1
-1568 1
-1570 1
-1572 1
-1576 1
-1590 1
-1594 1
-1604 1
-1608 1
-1614 1
-1622 1
-1624 2
-1628 1
-1629 1
-1636 1
-1642 1
-1654 2
-1660 1
-1664 1
-1670 1
-1684 4
-1698 1
-1732 3
-1742 1
-1752 1
-1760 1
-1764 1
-1772 2
-1798 1
-1808 1
-1820 1
-1852 1
-1856 1
-1874 1
-1902 1
-1908 1
-1952 1
-2004 1
-2018 1
-2020 1
-2028 1
-2174 1
-2233 1
-2244 1
-2280 1
-2290 1
-2352 1
-2604 1
-4190 1
+ rowspan="10"
\ No newline at end of file
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index b09f1db6e938e8eb99148d69efce016f1cbe8628..3647111fddaa848a75873ab689559c63dd6d4814 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -177,9 +177,9 @@ def save_model(model,
model.backbone.model.save_pretrained(model_prefix)
metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config
+ with open(metric_prefix + '.states', 'wb') as f:
+ pickle.dump(kwargs, f, protocol=2)
if is_best:
- with open(metric_prefix + '.states', 'wb') as f:
- pickle.dump(kwargs, f, protocol=2)
logger.info('save best model is to {}'.format(model_prefix))
else:
logger.info("save model in {}".format(model_prefix))
diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py
index 4a25ff8b2fa182faaf4f4ce8909c9ec2e9b55ccc..b881fcab20bc5ca076a0002bd72349768c7d881a 100755
--- a/ppocr/utils/utility.py
+++ b/ppocr/utils/utility.py
@@ -91,18 +91,19 @@ def check_and_read_gif(img_path):
def load_vqa_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
- lines = [line.strip() for line in lines]
- if "O" not in lines:
- lines.insert(0, "O")
- labels = []
- for line in lines:
- if line == "O":
- labels.append("O")
- else:
- labels.append("B-" + line)
- labels.append("I-" + line)
- label2id_map = {label: idx for idx, label in enumerate(labels)}
- id2label_map = {idx: label for idx, label in enumerate(labels)}
+ old_lines = [line.strip() for line in lines]
+ lines = ["O"]
+ for line in old_lines:
+ # "O" has already been in lines
+ if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
+ continue
+ lines.append(line)
+ labels = ["O"]
+ for line in lines[1:]:
+ labels.append("B-" + line)
+ labels.append("I-" + line)
+ label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
+ id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
return label2id_map, id2label_map
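The rewritten `load_vqa_bio_label_maps` now folds any `OTHER`/`OTHERS`/`IGNORE` entry into the single `O` tag and upper-cases the map keys. A quick sketch of the resulting maps for a hypothetical class list:

```python
# Assumed content of a SER class list file; only the shape of the output matters here.
lines = ["OTHER", "QUESTION", "ANSWER", "HEADER"]

labels = ["O"]
for line in lines:
    # "OTHER"-style classes collapse into the single "O" tag.
    if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
        continue
    labels.append("B-" + line)
    labels.append("I-" + line)

label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
print(label2id_map)
# {'O': 0, 'B-QUESTION': 1, 'I-QUESTION': 2, 'B-ANSWER': 3, 'I-ANSWER': 4, 'B-HEADER': 5, 'I-HEADER': 6}
```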
diff --git a/ppocr/utils/visual.py b/ppocr/utils/visual.py
index 7a8c1674a74f89299de59f7cd120b4577a7499d8..e0fbf06abb471c294cb268520fb99bca1a6b1d61 100644
--- a/ppocr/utils/visual.py
+++ b/ppocr/utils/visual.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import cv2
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont
@@ -19,7 +20,7 @@ from PIL import Image, ImageDraw, ImageFont
def draw_ser_results(image,
ocr_results,
font_path="doc/fonts/simfang.ttf",
- font_size=18):
+ font_size=14):
np.random.seed(2021)
color = (np.random.permutation(range(255)),
np.random.permutation(range(255)),
@@ -40,9 +41,15 @@ def draw_ser_results(image,
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
- text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
+ text = "{}: {}".format(ocr_info["pred"], ocr_info["transcription"])
- draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
+ if "bbox" in ocr_info:
+ # draw with ocr engine
+ bbox = ocr_info["bbox"]
+ else:
+ # draw with ocr groundtruth
+ bbox = trans_poly_to_bbox(ocr_info["points"])
+ draw_box_txt(bbox, text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
@@ -62,6 +69,14 @@ def draw_box_txt(bbox, text, draw, font, font_size, color):
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
+def trans_poly_to_bbox(poly):
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+
+
def draw_re_results(image,
result,
font_path="doc/fonts/simfang.ttf",
@@ -80,10 +95,10 @@ def draw_re_results(image,
color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result:
- draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
- font_size, color_head)
- draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
- font_size, color_tail)
+ draw_box_txt(ocr_info_head["bbox"], ocr_info_head["transcription"],
+ draw, font, font_size, color_head)
+ draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["transcription"],
+ draw, font, font_size, color_tail)
center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
@@ -96,3 +111,16 @@ def draw_re_results(image,
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
+
+
+def draw_rectangle(img_path, boxes, use_xywh=False):
+ img = cv2.imread(img_path)
+ img_show = img.copy()
+ for box in boxes.astype(int):
+ if use_xywh:
+ x, y, w, h = box
+ x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
+ else:
+ x1, y1, x2, y2 = box
+ cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
+ return img_show
\ No newline at end of file
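The two geometry helpers added above are easy to sanity-check: `trans_poly_to_bbox` collapses a 4-point polygon into an axis-aligned box, and `draw_rectangle` with `use_xywh=True` converts a center-based box into corner coordinates before drawing. A short sketch with made-up coordinates:

```python
import numpy as np

poly = [[10, 12], [48, 10], [50, 30], [9, 32]]          # 4-point text polygon
x1, y1 = np.min(poly, axis=0)
x2, y2 = np.max(poly, axis=0)
print([int(x1), int(y1), int(x2), int(y2)])             # [9, 10, 50, 32]

x, y, w, h = 30, 20, 40, 22                              # center-based box (TableMaster style)
print(x - w // 2, y - h // 2, x + w // 2, y + h // 2)    # 10 9 50 31
```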
diff --git a/ppstructure/docs/kie.md b/ppstructure/docs/kie.md
index 35498b33478d1010fd2548dfcb8586b4710723a1..315dd9f7bafa6b6160489eab330e8d278b2d119d 100644
--- a/ppstructure/docs/kie.md
+++ b/ppstructure/docs/kie.md
@@ -16,7 +16,7 @@ SDMGR是一个关键信息提取算法,将每个检测到的文本区域分类
训练和测试的数据采用wildreceipt数据集,通过如下指令下载数据集:
```
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
```
执行预测:
diff --git a/ppstructure/docs/kie_en.md b/ppstructure/docs/kie_en.md
index 1fe38b0b399e9290526dafa5409673dc87026db7..7b3752223dd765e780d56d146c90bd0f892aac7b 100644
--- a/ppstructure/docs/kie_en.md
+++ b/ppstructure/docs/kie_en.md
@@ -15,7 +15,7 @@ This section provides a tutorial example on how to quickly use, train, and evalu
[Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used for this tutorial. It contains 1765 photos, with 25 classes, and 50000 text boxes, which can be downloaded by wget:
```shell
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
```
Download the pretrained model and predict the result:
diff --git a/ppstructure/docs/models_list.md b/ppstructure/docs/models_list.md
index c7dab999ff6e370c56c5495e22e91f117b3d1275..42d44009dad1ba1b07bb410c199993c6f79f3d5d 100644
--- a/ppstructure/docs/models_list.md
+++ b/ppstructure/docs/models_list.md
@@ -1,11 +1,11 @@
# PP-Structure 系列模型列表
-- [1. 版面分析模型](#1)
-- [2. OCR和表格识别模型](#2)
- - [2.1 OCR](#21)
- - [2.2 表格识别模型](#22)
-- [3. VQA模型](#3)
-- [4. KIE模型](#4)
+- [1. 版面分析模型](#1-版面分析模型)
+- [2. OCR和表格识别模型](#2-ocr和表格识别模型)
+ - [2.1 OCR](#21-ocr)
+ - [2.2 表格识别模型](#22-表格识别模型)
+- [3. VQA模型](#3-vqa模型)
+- [4. KIE模型](#4-kie模型)
@@ -35,18 +35,18 @@
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
-|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
+|en_ppocr_mobile_v2.0_table_structure|PubTabNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
## 3. VQA模型
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
-|ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
-|re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
-|ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
+|ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+|re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+|ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
-|ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
+|ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
## 4. KIE模型
diff --git a/ppstructure/docs/models_list_en.md b/ppstructure/docs/models_list_en.md
index b92c10c241df72c85649b64f915b4266cd3fe410..e133a0bb2a9b017207b5e92ea444aba4633a7457 100644
--- a/ppstructure/docs/models_list_en.md
+++ b/ppstructure/docs/models_list_en.md
@@ -1,11 +1,11 @@
# PP-Structure Model list
-- [1. Layout Analysis](#1)
-- [2. OCR and Table Recognition](#2)
- - [2.1 OCR](#21)
- - [2.2 Table Recognition](#22)
-- [3. VQA](#3)
-- [4. KIE](#4)
+- [1. Layout Analysis](#1-layout-analysis)
+- [2. OCR and Table Recognition](#2-ocr-and-table-recognition)
+ - [2.1 OCR](#21-ocr)
+ - [2.2 Table Recognition](#22-table-recognition)
+- [3. VQA](#3-vqa)
+- [4. KIE](#4-kie)
@@ -42,11 +42,11 @@ If you need to use other OCR models, you can download the model in [PP-OCR model
|model| description |inference model size|download|
| --- |----------------------------------------------------------------| --- | --- |
-|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
-|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
-|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
+|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLMv2 |765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
-|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
+|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
## 4. KIE
diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md
index d21ef4aa3813b4ff49dc0580be35c5e2e0483c8f..b6804c6f09b4ee3d17cd2b81e6cc6642c1c1be9a 100644
--- a/ppstructure/table/README.md
+++ b/ppstructure/table/README.md
@@ -18,7 +18,7 @@ The table recognition mainly contains three models
The table recognition flow chart is as follows
-![tableocr_pipeline](../../doc/table/tableocr_pipeline_en.jpg)
+![tableocr_pipeline](../docs/table/tableocr_pipeline_en.jpg)
1. The coordinates of single-line text is detected by DB model, and then sends it to the recognition model to get the recognition result.
2. The table structure and cell coordinates is predicted by RARE model.
diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py
index 0179c614ae4864677576f6073f291282fb772988..7a7d3169d567493b4707b63c75cec07485cf5acb 100755
--- a/ppstructure/table/predict_structure.py
+++ b/ppstructure/table/predict_structure.py
@@ -23,43 +23,63 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import time
+import json
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.utils.visual import draw_rectangle
from ppstructure.utility import parse_args
logger = get_logger()
+def build_pre_process_list(args):
+ resize_op = {'ResizeTableImage': {'max_len': args.table_max_len, }}
+ pad_op = {
+ 'PaddingTableImage': {
+ 'size': [args.table_max_len, args.table_max_len]
+ }
+ }
+ normalize_op = {
+ 'NormalizeImage': {
+ 'std': [0.229, 0.224, 0.225] if
+ args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
+ 'mean': [0.485, 0.456, 0.406] if
+ args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }
+ to_chw_op = {'ToCHWImage': None}
+ keep_keys_op = {'KeepKeys': {'keep_keys': ['image', 'shape']}}
+ if args.table_algorithm not in ['TableMaster']:
+ pre_process_list = [
+ resize_op, normalize_op, pad_op, to_chw_op, keep_keys_op
+ ]
+ else:
+ pre_process_list = [
+ resize_op, pad_op, normalize_op, to_chw_op, keep_keys_op
+ ]
+ return pre_process_list
+
+
class TableStructurer(object):
def __init__(self, args):
- pre_process_list = [{
- 'ResizeTableImage': {
- 'max_len': args.table_max_len
- }
- }, {
- 'NormalizeImage': {
- 'std': [0.229, 0.224, 0.225],
- 'mean': [0.485, 0.456, 0.406],
- 'scale': '1./255.',
- 'order': 'hwc'
+ pre_process_list = build_pre_process_list(args)
+ if args.table_algorithm not in ['TableMaster']:
+ postprocess_params = {
+ 'name': 'TableLabelDecode',
+ "character_dict_path": args.table_char_dict_path,
}
- }, {
- 'PaddingTableImage': None
- }, {
- 'ToCHWImage': None
- }, {
- 'KeepKeys': {
- 'keep_keys': ['image']
+ else:
+ postprocess_params = {
+ 'name': 'TableMasterLabelDecode',
+ "character_dict_path": args.table_char_dict_path,
+ 'box_shape': 'pad'
}
- }]
- postprocess_params = {
- 'name': 'TableLabelDecode',
- "character_dict_path": args.table_char_dict_path,
- }
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
@@ -88,27 +108,17 @@ class TableStructurer(object):
preds['structure_probs'] = outputs[1]
preds['loc_preds'] = outputs[0]
- post_result = self.postprocess_op(preds)
-
- structure_str_list = post_result['structure_str_list']
- res_loc = post_result['res_loc']
- imgh, imgw = ori_im.shape[0:2]
- res_loc_final = []
- for rno in range(len(res_loc[0])):
- x0, y0, x1, y1 = res_loc[0][rno]
- left = max(int(imgw * x0), 0)
- top = max(int(imgh * y0), 0)
- right = min(int(imgw * x1), imgw - 1)
- bottom = min(int(imgh * y1), imgh - 1)
- res_loc_final.append([left, top, right, bottom])
-
- structure_str_list = structure_str_list[0][:-1]
+ shape_list = np.expand_dims(data[-1], axis=0)
+ post_result = self.postprocess_op(preds, [shape_list])
+
+ structure_str_list = post_result['structure_batch_list'][0]
+ bbox_list = post_result['bbox_batch_list'][0]
+ structure_str_list = structure_str_list[0]
structure_str_list = [
             '<html>', '<body>', '<table>'
         ] + structure_str_list + ['</table>', '</body>', '</html>']
-
elapse = time.time() - starttime
- return (structure_str_list, res_loc_final), elapse
+ return (structure_str_list, bbox_list), elapse
def main(args):
@@ -116,21 +126,35 @@ def main(args):
table_structurer = TableStructurer(args)
count = 0
total_time = 0
- for image_file in image_file_list:
- img, flag = check_and_read_gif(image_file)
- if not flag:
- img = cv2.imread(image_file)
- if img is None:
- logger.info("error in loading image:{}".format(image_file))
- continue
- structure_res, elapse = table_structurer(img)
-
- logger.info("result: {}".format(structure_res))
-
- if count > 0:
- total_time += elapse
- count += 1
- logger.info("Predict time of {}: {}".format(image_file, elapse))
+ use_xywh = args.table_algorithm in ['TableMaster']
+ os.makedirs(args.output, exist_ok=True)
+ with open(
+ os.path.join(args.output, 'infer.txt'), mode='w',
+ encoding='utf-8') as f_w:
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+ structure_res, elapse = table_structurer(img)
+ structure_str_list, bbox_list = structure_res
+ bbox_list_str = json.dumps(bbox_list.tolist())
+ logger.info("result: {}, {}".format(structure_str_list,
+ bbox_list_str))
+ f_w.write("result: {}, {}\n".format(structure_str_list,
+ bbox_list_str))
+
+ img = draw_rectangle(image_file, bbox_list, use_xywh)
+ img_save_path = os.path.join(args.output,
+ os.path.basename(image_file))
+ cv2.imwrite(img_save_path, img)
+ logger.info("save vis result to {}".format(img_save_path))
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
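`build_pre_process_list` keeps the same five operators but reorders them per algorithm: TableMaster pads before normalizing (and uses 0.5 means/stds), while the default attention model normalizes first. A sketch of the resulting op order, using `argparse.Namespace` as a stand-in for the parsed args:

```python
from argparse import Namespace

def op_names(args):
    # Mirrors the branching in build_pre_process_list, names only.
    resize, pad = 'ResizeTableImage', 'PaddingTableImage'
    norm, chw, keep = 'NormalizeImage', 'ToCHWImage', 'KeepKeys'
    if args.table_algorithm not in ['TableMaster']:
        return [resize, norm, pad, chw, keep]
    return [resize, pad, norm, chw, keep]

print(op_names(Namespace(table_algorithm='TableAttn')))    # normalize, then pad
print(op_names(Namespace(table_algorithm='TableMaster')))  # pad, then normalize
```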
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index 1ad902e7e6be95a6901e3774420fad337f594861..af0616239b167ff9ca5f6e1222015d51338d6bab 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -25,6 +25,7 @@ def init_args():
parser.add_argument("--output", type=str, default='./output')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
+ parser.add_argument("--table_algorithm", type=str, default='TableAttn')
parser.add_argument("--table_model_dir", type=str)
parser.add_argument(
"--table_char_dict_path",
@@ -40,6 +41,13 @@ def init_args():
type=ast.literal_eval,
default=None,
help='label map according to ppstructure/layout/README_ch.md')
+ # params for vqa
+ parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
+ parser.add_argument("--ser_model_dir", type=str)
+ parser.add_argument(
+ "--ser_dict_path",
+ type=str,
+ default="../train_data/XFUND/class_list_xfun.txt")
# params for inference
parser.add_argument(
"--mode",
@@ -65,7 +73,7 @@ def init_args():
"--recovery",
type=bool,
default=False,
- help='Whether to enable layout of recovery')
+ help='Whether to enable layout of recovery')
return parser
diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
index e3a10671ddb6494eb15073e7ac007aa1e8e6a32a..05635265b5e5eff18429e2d595fc4195381299f5 100644
--- a/ppstructure/vqa/README.md
+++ b/ppstructure/vqa/README.md
@@ -1,19 +1,15 @@
English | [简体中文](README_ch.md)
-- [Document Visual Question Answering (Doc-VQA)](#Document-Visual-Question-Answering)
- - [1. Introduction](#1-Introduction)
- - [2. Performance](#2-performance)
- - [3. Effect demo](#3-Effect-demo)
- - [3.1 SER](#31-ser)
- - [3.2 RE](#32-re)
- - [4. Install](#4-Install)
- - [4.1 Installation dependencies](#41-Install-dependencies)
- - [4.2 Install PaddleOCR](#42-Install-PaddleOCR)
- - [5. Usage](#5-Usage)
- - [5.1 Data and Model Preparation](#51-Data-and-Model-Preparation)
- - [5.2 SER](#52-ser)
- - [5.3 RE](#53-re)
- - [6. Reference](#6-Reference-Links)
+- [1 Introduction](#1-introduction)
+- [2. Performance](#2-performance)
+- [3. Effect demo](#3-effect-demo)
+ - [3.1 SER](#31-ser)
+ - [3.2 RE](#32-re)
+- [4. Install](#4-install)
+ - [4.1 Install dependencies](#41-install-dependencies)
+  - [4.2 Install PaddleOCR](#42-install-paddleocr)
+- [5. Usage](#5-usage)
+  - [5.1 Data and Model Preparation](#51-data-and-model-preparation)
+  - [5.2 SER](#52-ser)
+  - [5.3 RE](#53-re)
+- [6. Reference Links](#6-reference-links)
+- [License](#license)
# Document Visual Question Answering
@@ -125,13 +121,13 @@ If you want to experience the prediction process directly, you can download the
* Download the processed dataset
-The download address of the processed XFUND Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar).
+The download address of the processed XFUND Chinese dataset: [link](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar).
Download and unzip the dataset, and place the dataset in the current directory after unzipping.
```shell
-wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
````
* Convert the dataset
@@ -187,17 +183,17 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
````
Finally, `precision`, `recall`, `hmean` and other indicators will be printed
-* Use `OCR engine + SER` tandem prediction
+* `OCR + SER` tandem prediction based on training engine
-Use the following command to complete the series prediction of `OCR engine + SER`, taking the pretrained SER model as an example:
+Use the following command to complete the series prediction of `OCR engine + SER`, taking the SER model based on LayoutXLM as an example:
```shell
-CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_42.jpg
+CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
-* End-to-end evaluation of `OCR engine + SER` prediction system
+* End-to-end evaluation of `OCR + SER` prediction system
First use the `tools/infer_vqa_token_ser.py` script to complete the prediction of the dataset, then use the following command to evaluate.
@@ -205,6 +201,24 @@ First use the `tools/infer_vqa_token_ser.py` script to complete the prediction o
export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
````
+* export model
+
+Use the following command to complete the model export of the SER model, taking the SER model based on LayoutXLM as an example:
+
+```shell
+python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
+```
+The converted model will be stored in the directory specified by the `Global.save_inference_dir` field.
+
+* `OCR + SER` tandem prediction based on prediction engine
+
+Use the following command to complete the tandem prediction of `OCR + SER` based on the prediction engine, taking the SER model based on LayoutXLM as an example:
+
+```shell
+cd ppstructure
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+```
+After the prediction is successful, the visualization images and results will be saved in the directory specified by the `output` field
### 5.3 RE
@@ -247,11 +261,19 @@ Finally, `precision`, `recall`, `hmean` and other indicators will be printed
Use the following command to complete the series prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example:
```shell
export CUDA_VISIBLE_DEVICES=0
-python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm. yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
+python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
+* export model
+
+coming soon
+
+* `OCR + SER + RE` tandem prediction based on prediction engine
+
+coming soon
+
## 6. Reference Links
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
diff --git a/ppstructure/vqa/README_ch.md b/ppstructure/vqa/README_ch.md
index b677dc07bce6c1a752d753b6a1c538b4d3f99271..b421a82d3a1cbe39f5c740bea486ec26593ab20f 100644
--- a/ppstructure/vqa/README_ch.md
+++ b/ppstructure/vqa/README_ch.md
@@ -1,19 +1,19 @@
[English](README.md) | 简体中文
-- [文档视觉问答(DOC-VQA)](#文档视觉问答doc-vqa)
- - [1. 简介](#1-简介)
- - [2. 性能](#2-性能)
- - [3. 效果演示](#3-效果演示)
- - [3.1 SER](#31-ser)
- - [3.2 RE](#32-re)
- - [4. 安装](#4-安装)
- - [4.1 安装依赖](#41-安装依赖)
- - [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
- - [5. 使用](#5-使用)
- - [5.1 数据和预训练模型准备](#51-数据和预训练模型准备)
- - [5.2 SER](#52-ser)
- - [5.3 RE](#53-re)
- - [6. 参考链接](#6-参考链接)
+- [1. 简介](#1-简介)
+- [2. 性能](#2-性能)
+- [3. 效果演示](#3-效果演示)
+ - [3.1 SER](#31-ser)
+ - [3.2 RE](#32-re)
+- [4. 安装](#4-安装)
+ - [4.1 安装依赖](#41-安装依赖)
+ - [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
+- [5. 使用](#5-使用)
+ - [5.1 数据和预训练模型准备](#51-数据和预训练模型准备)
+ - [5.2 SER](#52-ser)
+ - [5.3 RE](#53-re)
+- [6. 参考链接](#6-参考链接)
+- [License](#license)
# 文档视觉问答(DOC-VQA)
@@ -122,13 +122,13 @@ python3 -m pip install -r ppstructure/vqa/requirements.txt
* 下载处理好的数据集
-处理好的XFUND中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)。
+处理好的XFUND中文数据集下载地址:[链接](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar)。
下载并解压该数据集,解压后将数据集放置在当前目录下。
```shell
-wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
```
* 转换数据集
@@ -183,16 +183,16 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
```
最终会打印出`precision`, `recall`, `hmean`等指标
-* 使用`OCR引擎 + SER`串联预测
+* 基于训练引擎的`OCR + SER`串联预测
-使用如下命令即可完成`OCR引擎 + SER`的串联预测, 以SER预训练模型为例:
+使用如下命令即可完成基于训练引擎的`OCR + SER`的串联预测, 以基于LayoutXLM的SER模型为例:
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`。
-* 对`OCR引擎 + SER`预测系统进行端到端评估
+* 对`OCR + SER`预测系统进行端到端评估
首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。
@@ -200,6 +200,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/l
export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```
+* 模型导出
+
+使用如下命令即可完成SER模型的模型导出, 以基于LayoutXLM的SER模型为例:
+
+```shell
+python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
+```
+转换后的模型会存放在`Global.save_inference_dir`字段指定的目录下。
+
+* 基于预测引擎的`OCR + SER`串联预测
+
+使用如下命令即可完成基于预测引擎的`OCR + SER`的串联预测, 以基于LayoutXLM的SER模型为例:
+
+```shell
+cd ppstructure
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+```
+预测成功后,可视化图片和结果会保存在`output`字段指定的目录下
### 5.3 RE
@@ -236,16 +254,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o
```
最终会打印出`precision`, `recall`, `hmean`等指标
-* 使用`OCR引擎 + SER + RE`串联预测
+* 基于训练引擎的`OCR + SER + RE`串联预测
-使用如下命令即可完成`OCR引擎 + SER + RE`的串联预测, 以预训练SER和RE模型为例:
+使用如下命令即可完成基于训练引擎的`OCR + SER + RE`的串联预测, 以基于LayoutXLM的SER和RE模型为例:
```shell
export CUDA_VISIBLE_DEVICES=0
-python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
+python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`。
+* 模型导出
+
+coming soon
+
+* 基于预测引擎的`OCR + SER + RE`串联预测
+
+coming soon
+
## 6. 参考链接
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
diff --git a/ppstructure/vqa/labels/labels_ser.txt b/ppstructure/vqa/labels/labels_ser.txt
deleted file mode 100644
index 508e48112412f62538baf0c78bcf99ec8945196e..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/labels/labels_ser.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-QUESTION
-ANSWER
-HEADER
diff --git a/ppstructure/vqa/predict_vqa_token_ser.py b/ppstructure/vqa/predict_vqa_token_ser.py
new file mode 100644
index 0000000000000000000000000000000000000000..de0bbfe72d80d9a16de8b09657a98dc5285bb348
--- /dev/null
+++ b/ppstructure/vqa/predict_vqa_token_ser.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import json
+import numpy as np
+import time
+
+import tools.infer.utility as utility
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+from ppocr.utils.logging import get_logger
+from ppocr.utils.visual import draw_ser_results
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppstructure.utility import parse_args
+
+from paddleocr import PaddleOCR
+
+logger = get_logger()
+
+
+class SerPredictor(object):
+ def __init__(self, args):
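+        # The PP-OCR engine below supplies detection/recognition results that
+        # VQATokenLabelEncode consumes when running in infer mode.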
+ self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
+
+ pre_process_list = [{
+ 'VQATokenLabelEncode': {
+ 'algorithm': args.vqa_algorithm,
+ 'class_path': args.ser_dict_path,
+ 'contains_re': False,
+ 'ocr_engine': self.ocr_engine
+ }
+ }, {
+ 'VQATokenPad': {
+ 'max_seq_len': 512,
+ 'return_attention_mask': True
+ }
+ }, {
+ 'VQASerTokenChunk': {
+ 'max_seq_len': 512,
+ 'return_attention_mask': True
+ }
+ }, {
+ 'Resize': {
+ 'size': [224, 224]
+ }
+ }, {
+ 'NormalizeImage': {
+ 'std': [58.395, 57.12, 57.375],
+ 'mean': [123.675, 116.28, 103.53],
+ 'scale': '1',
+ 'order': 'hwc'
+ }
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': [
+ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
+ 'image', 'labels', 'segment_offset_id', 'ocr_info',
+ 'entities'
+ ]
+ }
+ }]
+ postprocess_params = {
+ 'name': 'VQASerTokenLayoutLMPostProcess',
+ "class_path": args.ser_dict_path,
+ }
+
+ self.preprocess_op = create_operators(pre_process_list,
+ {'infer_mode': True})
+ self.postprocess_op = build_post_process(postprocess_params)
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
+ utility.create_predictor(args, 'ser', logger)
+
+ def __call__(self, img):
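+        # Preprocess the image (OCR + tokenization + resize/normalize), run the
+        # inference predictor, then decode per-token SER labels.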
+ ori_im = img.copy()
+ data = {'image': img}
+ data = transform(data, self.preprocess_op)
+ img = data[0]
+ if img is None:
+ return None, 0
+ img = np.expand_dims(img, axis=0)
+ img = img.copy()
+ starttime = time.time()
+
+ for idx in range(len(self.input_tensor)):
+ expand_input = np.expand_dims(data[idx], axis=0)
+ self.input_tensor[idx].copy_from_cpu(expand_input)
+
+ self.predictor.run()
+
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ preds = outputs[0]
+
+ post_result = self.postprocess_op(
+ preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
+ elapse = time.time() - starttime
+ return post_result, elapse
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ ser_predictor = SerPredictor(args)
+ count = 0
+ total_time = 0
+
+ os.makedirs(args.output, exist_ok=True)
+ with open(
+ os.path.join(args.output, 'infer.txt'), mode='w',
+ encoding='utf-8') as f_w:
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+            if not flag:
+                img = cv2.imread(image_file)
+                if img is not None:
+                    # cv2.imread returns BGR; convert to RGB before the SER pipeline
+                    img = img[:, :, ::-1]
+            if img is None:
+                logger.info("error in loading image:{}".format(image_file))
+                continue
+ ser_res, elapse = ser_predictor(img)
+ ser_res = ser_res[0]
+
+ res_str = '{}\t{}\n'.format(
+ image_file,
+ json.dumps(
+ {
+ "ocr_info": ser_res,
+ }, ensure_ascii=False))
+ f_w.write(res_str)
+
+ img_res = draw_ser_results(
+ image_file,
+ ser_res,
+ font_path="../doc/fonts/simfang.ttf", )
+
+ img_save_path = os.path.join(args.output,
+ os.path.basename(image_file))
+ cv2.imwrite(img_save_path, img_res)
+ logger.info("save vis result to {}".format(img_save_path))
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
+
+
+if __name__ == "__main__":
+ main(parse_args())
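For reference, a hypothetical programmatic use of the `SerPredictor` class added above, outside its CLI entry point; the model/dict paths and the `ppstructure.vqa.predict_vqa_token_ser` import path are assumptions based on this diff, not a documented API:

```python
import cv2

from ppstructure.utility import parse_args
from ppstructure.vqa.predict_vqa_token_ser import SerPredictor

# parse_args() provides defaults; override the fields this sketch assumes
args = parse_args()
args.vqa_algorithm = "LayoutXLM"
args.ser_model_dir = "./output/ser/infer"
args.ser_dict_path = "./train_data/XFUND/class_list_xfun.txt"

ser = SerPredictor(args)
img = cv2.imread("ppstructure/docs/vqa/input/zh_val_42.jpg")[:, :, ::-1]  # BGR -> RGB
result, elapse = ser(img)
print(result[0], "({:.3f}s)".format(elapse))
```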
diff --git a/ppstructure/vqa/requirements.txt b/ppstructure/vqa/requirements.txt
index 0042ec0baedcc3e7bbecb922d10b93c95219219d..fcd882274c4402ba2a1d34f20ee6e2befa157121 100644
--- a/ppstructure/vqa/requirements.txt
+++ b/ppstructure/vqa/requirements.txt
@@ -1,4 +1,7 @@
sentencepiece
yacs
seqeval
-paddlenlp>=2.2.1
\ No newline at end of file
+paddlenlp>=2.2.1
+pypandoc
+attrdict
+python_docx
\ No newline at end of file
diff --git a/ppstructure/vqa/tools/trans_funsd_label.py b/ppstructure/vqa/tools/trans_funsd_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef7d1db010a925b37d285befe77aa202db2141d9
--- /dev/null
+++ b/ppstructure/vqa/tools/trans_funsd_label.py
@@ -0,0 +1,151 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import sys
+import cv2
+import numpy as np
+from copy import deepcopy
+
+
+def trans_poly_to_bbox(poly):
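+    # Axis-aligned bounding box [x1, y1, x2, y2] of a polygon given as a list of [x, y] points.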
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+
+
+def get_outer_poly(bbox_list):
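+    # Smallest enclosing rectangle, as a 4-point polygon, of a list of [x1, y1, x2, y2] boxes.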
+ x1 = min([bbox[0] for bbox in bbox_list])
+ y1 = min([bbox[1] for bbox in bbox_list])
+ x2 = max([bbox[2] for bbox in bbox_list])
+ y2 = max([bbox[3] for bbox in bbox_list])
+ return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
+
+
+def load_funsd_label(image_dir, anno_dir):
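+    # Convert FUNSD annotations into PaddleOCR VQA label records (transcription, label,
+    # points, linking, id), splitting multi-line entries at line wraps and re-mapping link ids.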
+ imgs = os.listdir(image_dir)
+ annos = os.listdir(anno_dir)
+
+ imgs = [img.replace(".png", "") for img in imgs]
+ annos = [anno.replace(".json", "") for anno in annos]
+
+ fn_info_map = dict()
+ for anno_fn in annos:
+ res = []
+ with open(os.path.join(anno_dir, anno_fn + ".json"), "r") as fin:
+ infos = json.load(fin)
+ infos = infos["form"]
+ old_id2new_id_map = dict()
+ global_new_id = 0
+ for info in infos:
+ if info["text"] is None:
+ continue
+ words = info["words"]
+ if len(words) <= 0:
+ continue
+ word_idx = 1
+ curr_bboxes = [words[0]["box"]]
+ curr_texts = [words[0]["text"]]
+ while word_idx < len(words):
+ # switch to a new link
+ if words[word_idx]["box"][0] + 10 <= words[word_idx - 1][
+ "box"][2]:
+ if len("".join(curr_texts[0])) > 0:
+ res.append({
+ "transcription": " ".join(curr_texts),
+ "label": info["label"],
+ "points": get_outer_poly(curr_bboxes),
+ "linking": info["linking"],
+ "id": global_new_id,
+ })
+ if info["id"] not in old_id2new_id_map:
+ old_id2new_id_map[info["id"]] = []
+ old_id2new_id_map[info["id"]].append(global_new_id)
+ global_new_id += 1
+ curr_bboxes = [words[word_idx]["box"]]
+ curr_texts = [words[word_idx]["text"]]
+ else:
+ curr_bboxes.append(words[word_idx]["box"])
+ curr_texts.append(words[word_idx]["text"])
+ word_idx += 1
+ if len("".join(curr_texts[0])) > 0:
+ res.append({
+ "transcription": " ".join(curr_texts),
+ "label": info["label"],
+ "points": get_outer_poly(curr_bboxes),
+ "linking": info["linking"],
+ "id": global_new_id,
+ })
+ if info["id"] not in old_id2new_id_map:
+ old_id2new_id_map[info["id"]] = []
+ old_id2new_id_map[info["id"]].append(global_new_id)
+ global_new_id += 1
+ res = sorted(
+ res, key=lambda r: (r["points"][0][1], r["points"][0][0]))
+ for i in range(len(res) - 1):
+ for j in range(i, 0, -1):
+ if abs(res[j + 1]["points"][0][1] - res[j]["points"][0][1]) < 20 and \
+ (res[j + 1]["points"][0][0] < res[j]["points"][0][0]):
+ tmp = deepcopy(res[j])
+ res[j] = deepcopy(res[j + 1])
+ res[j + 1] = deepcopy(tmp)
+ else:
+ break
+ # re-generate unique ids
+ for idx, r in enumerate(res):
+ new_links = []
+ for link in r["linking"]:
+ # illegal links will be removed
+ if link[0] not in old_id2new_id_map or link[
+ 1] not in old_id2new_id_map:
+ continue
+ for src in old_id2new_id_map[link[0]]:
+ for dst in old_id2new_id_map[link[1]]:
+ new_links.append([src, dst])
+ res[idx]["linking"] = deepcopy(new_links)
+
+ fn_info_map[anno_fn] = res
+
+ return fn_info_map
+
+
+def main():
+ test_image_dir = "train_data/FUNSD/testing_data/images/"
+ test_anno_dir = "train_data/FUNSD/testing_data/annotations/"
+ test_output_dir = "train_data/FUNSD/test.json"
+
+ fn_info_map = load_funsd_label(test_image_dir, test_anno_dir)
+ with open(test_output_dir, "w") as fout:
+ for fn in fn_info_map:
+ fout.write(fn + ".png" + "\t" + json.dumps(
+ fn_info_map[fn], ensure_ascii=False) + "\n")
+
+ train_image_dir = "train_data/FUNSD/training_data/images/"
+ train_anno_dir = "train_data/FUNSD/training_data/annotations/"
+ train_output_dir = "train_data/FUNSD/train.json"
+
+ fn_info_map = load_funsd_label(train_image_dir, train_anno_dir)
+ with open(train_output_dir, "w") as fout:
+ for fn in fn_info_map:
+ fout.write(fn + ".png" + "\t" + json.dumps(
+ fn_info_map[fn], ensure_ascii=False) + "\n")
+ print("====ok====")
+ return
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ppstructure/vqa/tools/trans_xfun_data.py b/ppstructure/vqa/tools/trans_xfun_data.py
index 93ec98163c6cec96ec93399c1d41524200ddc499..11d221bea40367f091b3e09dde42e87f2217a617 100644
--- a/ppstructure/vqa/tools/trans_xfun_data.py
+++ b/ppstructure/vqa/tools/trans_xfun_data.py
@@ -21,26 +21,22 @@ def transfer_xfun_data(json_path=None, output_file=None):
json_info = json.loads(lines[0])
documents = json_info["documents"]
- label_info = {}
with open(output_file, "w", encoding='utf-8') as fout:
for idx, document in enumerate(documents):
+ label_info = []
img_info = document["img"]
document = document["document"]
image_path = img_info["fname"]
- label_info["height"] = img_info["height"]
- label_info["width"] = img_info["width"]
-
- label_info["ocr_info"] = []
-
for doc in document:
- label_info["ocr_info"].append({
- "text": doc["text"],
+ x1, y1, x2, y2 = doc["box"]
+ points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
+ label_info.append({
+ "transcription": doc["text"],
"label": doc["label"],
- "bbox": doc["box"],
+ "points": points,
"id": doc["id"],
- "linking": doc["linking"],
- "words": doc["words"]
+ "linking": doc["linking"]
})
fout.write(image_path + "\t" + json.dumps(
diff --git a/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_PP-OCRv2_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 76%
rename from test_tipc/configs/ch_PP-OCRv2_det_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_PP-OCRv2_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 3afc0acb799153b44bfddf713c1057f06ce525dc..91a6288eb0d4f3d2a8c968a65916295d25024c32 100644
--- a/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,9 +1,9 @@
===========================train_params===========================
-model_name:ch_PP-OCRv2_det_PACT
+model_name:ch_PP-OCRv2_det
python:python3.7
-gpu_list:0|0,1
-Global.use_gpu:True|True
-Global.auto_cast:amp
+gpu_list:192.168.0.1,192.168.0.2;0,1
+Global.use_gpu:True
+Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=50
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
@@ -12,9 +12,9 @@ train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
-trainer:pact_train
-norm_train:null
-pact_train:deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
+trainer:norm_train
+norm_train:tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
+pact_train:null
fpgm_train:null
distill_train:null
null:null
@@ -27,8 +27,8 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:null
-quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
+norm_export:tools/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
+quant_export:null
fpgm_export:
distill_export:null
export1:null
@@ -38,7 +38,7 @@ infer_model:./inference/ch_PP-OCRv2_det_infer/
infer_export:null
infer_quant:False
inference:tools/infer/predict_det.py
---use_gpu:True|False
+--use_gpu:False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
diff --git a/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_det/train_pact_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt
rename to test_tipc/configs/ch_PP-OCRv2_det/train_pact_infer_python.txt
diff --git a/test_tipc/configs/ch_PP-OCRv2_det_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_PP-OCRv2_det/train_ptq_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_PP-OCRv2_det_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_PP-OCRv2_det/train_ptq_infer_python.txt
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_PP-OCRv2_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 60%
rename from test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_PP-OCRv2_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index e0f86d9020fb77aec5fe051dffdb4bc1018f907e..5795bc27e686164578fc246e1fa467efdc52f71f 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,20 +1,20 @@
===========================train_params===========================
-model_name:ch_PP-OCRv2_rec_PACT
+model_name:ch_PP-OCRv2_rec
python:python3.7
-gpu_list:0|0,1
-Global.use_gpu:True|True
-Global.auto_cast:amp
-Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=50
+gpu_list:192.168.0.1,192.168.0.2;0,1
+Global.use_gpu:True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=50
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
-Global.pretrained_model:pretrain_models/ch_PP-OCRv2_rec_train/best_accuracy
+Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
-trainer:pact_train
-norm_train:null
-pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
+pact_train:null
fpgm_train:null
distill_train:null
null:null
@@ -27,18 +27,18 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:null
-quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
-fpgm_export: null
+norm_export:tools/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
+quant_export:
+fpgm_export:
distill_export:null
export1:null
export2:null
inference_dir:Student
-infer_model:./inference/ch_PP-OCRv2_rec_slim_quant_infer
+infer_model:./inference/ch_PP-OCRv2_rec_infer
infer_export:null
-infer_quant:True
+infer_quant:False
inference:tools/infer/predict_rec.py
---use_gpu:True|False
+--use_gpu:False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1|6
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_rec/train_pact_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt
rename to test_tipc/configs/ch_PP-OCRv2_rec/train_pact_infer_python.txt
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_PP-OCRv2_rec/train_ptq_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_PP-OCRv2_rec_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_PP-OCRv2_rec/train_ptq_infer_python.txt
diff --git a/test_tipc/configs/ch_PP-OCRv3_det_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_PP-OCRv3_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 76%
rename from test_tipc/configs/ch_PP-OCRv3_det_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_PP-OCRv3_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 252378e4dc022be24f581f319801f4fbf2a5d0eb..7e987125a6681629a592d43f05c2ecfe51dac3f1 100644
--- a/test_tipc/configs/ch_PP-OCRv3_det_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_PP-OCRv3_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,9 +1,9 @@
===========================train_params===========================
-model_name:ch_PP-OCRv3_det_PACT
+model_name:ch_PP-OCRv3_det
python:python3.7
-gpu_list:0|0,1
-Global.use_gpu:True|True
-Global.auto_cast:amp
+gpu_list:192.168.0.1,192.168.0.2;0,1
+Global.use_gpu:True
+Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=50
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
@@ -12,9 +12,9 @@ train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
-trainer:pact_train
-norm_train:null
-pact_train:deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o
+trainer:norm_train
+norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o
+pact_train:null
fpgm_train:null
distill_train:null
null:null
@@ -27,8 +27,8 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:null
-quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o
+norm_export:tools/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o
+quant_export:null
fpgm_export:
distill_export:null
export1:null
@@ -38,7 +38,7 @@ infer_model:./inference/ch_PP-OCRv3_det_infer/
infer_export:null
infer_quant:False
inference:tools/infer/predict_det.py
---use_gpu:True|False
+--use_gpu:False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
diff --git a/test_tipc/configs/ch_PP-OCRv3_det_PACT/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv3_det/train_pact_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_PP-OCRv3_det_PACT/train_infer_python.txt
rename to test_tipc/configs/ch_PP-OCRv3_det/train_pact_infer_python.txt
diff --git a/test_tipc/configs/ch_PP-OCRv3_det_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_PP-OCRv3_det/train_ptq_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_PP-OCRv3_det_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_PP-OCRv3_det/train_ptq_infer_python.txt
diff --git a/test_tipc/configs/ch_PP-OCRv3_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_PP-OCRv3_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 08e1fe9ba0aba4e3ab358be188aeed0212ad08ff..7fcc8b4418c65b0f98624d92bd3896518f2ed465 100644
--- a/test_tipc/configs/ch_PP-OCRv3_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_PP-OCRv3_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -38,7 +38,7 @@ infer_model:./inference/ch_PP-OCRv3_rec_infer
infer_export:null
infer_quant:False
inference:tools/infer/predict_rec.py --rec_image_shape="3,48,320"
---use_gpu:True|False
+--use_gpu:False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1|6
diff --git a/test_tipc/configs/ch_PP-OCRv3_rec_PACT/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv3_rec/train_pact_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_PP-OCRv3_rec_PACT/train_infer_python.txt
rename to test_tipc/configs/ch_PP-OCRv3_rec/train_pact_infer_python.txt
diff --git a/test_tipc/configs/ch_PP-OCRv3_rec_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_PP-OCRv3_rec/train_ptq_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_PP-OCRv3_rec_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_PP-OCRv3_rec/train_ptq_infer_python.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 62%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 3a5d8faf3fd60b4d94030e488538a0ba12345ee1..5271f78bb778f9e419da7f9bbbb6b4a6fafb305b 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,10 +1,10 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_PACT
+model_name:ch_ppocr_mobile_v2.0_det
python:python3.7
-gpu_list:0|0,1
-Global.use_gpu:True|True
-Global.auto_cast:amp
-Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=50
+gpu_list:192.168.0.1,192.168.0.2;0,1
+Global.use_gpu:True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=100|whole_train_whole_infer=50
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
@@ -12,9 +12,9 @@ train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
-trainer:pact_train
-norm_train:null
-pact_train:deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
+trainer:norm_train
+norm_train:tools/train.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
+pact_train:null
fpgm_train:null
distill_train:null
null:null
@@ -27,18 +27,18 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:null
-quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
+norm_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
+quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:null
-train_model:./inference/ch_ppocr_mobile_v2.0_det_prune_infer/
-infer_export:null
+train_model:./inference/ch_ppocr_mobile_v2.0_det_train/best_accuracy
+infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
---use_gpu:True|False
+--use_gpu:False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
@@ -50,4 +50,4 @@ null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
-random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
+random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_pact_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_pact_infer_python.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_ptq_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_ptq_infer_python.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..631118c0a9ab98c10129f12ec1c1cf2bbac46115
--- /dev/null
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:ch_ppocr_mobile_v2.0_rec
+python:python3.7
+gpu_list:192.168.0.1,192.168.0.2;0,1
+Global.use_gpu:True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=50
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c configs/rec/rec_icdar15_train.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c configs/rec/rec_icdar15_train.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c configs/rec/rec_icdar15_train.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
+infer_export:tools/export_model.py -c configs/rec/rec_icdar15_train.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py
+--use_gpu:False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1|6
+--use_tensorrt:False
+--precision:fp32
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,32,100]}]
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_pact_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_pact_infer_python.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_ptq_infer_python.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_ptq_infer_python.txt
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..12388d967755c54a46efdb915ef047896dddaef7
--- /dev/null
+++ b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:ch_ppocr_server_v2.0_det
+python:python3.7
+gpu_list:192.168.0.1,192.168.0.2;0,1
+Global.use_gpu:True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=50
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+quant_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/ch_ppocr_server_v2.0_det_train/best_accuracy
+infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml -o
+infer_quant:False
+inference:tools/infer/predict_det.py
+--use_gpu:False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--det_model_dir:
+--image_dir:./inference/ch_det_data_50/all-sum-510/
+--save_log_path:null
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9884ab247b80de4ca700bf084cea4faa89c86396
--- /dev/null
+++ b/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:ch_ppocr_server_v2.0_rec
+python:python3.7
+gpu_list:192.168.0.1,192.168.0.2;0,1
+Global.use_gpu:True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=5|whole_train_whole_infer=100
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/ch_ppocr_server_v2.0_rec_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py
+--use_gpu:False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1|6
+--use_tensorrt:False
+--precision:fp32
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,32,100]}]
diff --git a/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt b/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt
index ab3aa59b601db58b48cf18de79f77710611e2596..2c8aa953449c4b97790842bb90256280b8b20d9a 100644
--- a/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt
+++ b/test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt
@@ -54,6 +54,6 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8|16
fp_items:fp32|fp16
-epoch:2
+epoch:15
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt b/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt
index 33e4dbf2337f3799328516119a213bc0f14af9fe..df88c0e5434511fb48deac699e8f67fc535765d3 100644
--- a/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt
+++ b/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt
@@ -54,5 +54,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8|16
fp_items:fp32|fp16
-epoch:2
+epoch:15
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
diff --git a/test_tipc/configs/det_r50_db_plusplus/train_infer_python.txt b/test_tipc/configs/det_r50_db_plusplus/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..04a3e845859167f78d5b3dd799236f8b8a051e81
--- /dev/null
+++ b/test_tipc/configs/det_r50_db_plusplus/train_infer_python.txt
@@ -0,0 +1,59 @@
+===========================train_params===========================
+model_name:det_r50_db_plusplus
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c configs/det/det_r50_db++_ic15.yml -o Global.pretrained_model=./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c configs/det/det_r50_db++_ic15.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+inference_dir:null
+train_model:./inference/det_r50_db++_train/best_accuracy
+infer_export:tools/export_model.py -c configs/det/det_r50_db++_ic15.yml -o
+infer_quant:False
+inference:tools/infer/predict_det.py --det_algorithm="DB++"
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--det_model_dir:
+--image_dir:./inference/ch_det_data_50/all-sum-510/
+null:null
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
+===========================train_benchmark_params==========================
+batch_size:8|16
+fp_items:fp32|fp16
+epoch:2
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt
index 8477a4fa74f7a0617104aa83617fc6f61b8234b3..24e4d760c37828c213741b9ff127d55df2f9335a 100644
--- a/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_east_v2_0/train_infer_python.txt
@@ -1,13 +1,13 @@
===========================train_params===========================
model_name:det_r50_vd_east_v2_0
python:python3.7
-gpu_list:0
+gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
-Global.pretrained_model:null
+Global.pretrained_model:./pretrain_models/det_r50_vd_east_v2.0_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
diff --git a/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt
index 62da89fe1c8e3a7c2b7586eae6b2589f94237a2e..53511e6ae21003cb9df6a92d3931577fbbef5b18 100644
--- a/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_pse_v2_0/train_infer_python.txt
@@ -54,5 +54,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8
fp_items:fp32|fp16
-epoch:2
+epoch:10
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
diff --git a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
index 11444a3ac1b99c54dae31d28b83ffe14269599d9..b70ef46b4afb3a39f3bbd3d6274f0135a0646a37 100644
--- a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
@@ -4,16 +4,16 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
-Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=5000
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
-Global.pretrained_model:null
+Global.pretrained_model:./pretrain_models/det_r50_vd_sast_icdar15_v2.0_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml -o Global.pretrained_model=./pretrain_models/ResNet50_vd_ssld_pretrained
+norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -45,7 +45,7 @@ inference:tools/infer/predict_det.py
--use_tensorrt:False
--precision:fp32
--det_model_dir:
---image_dir:./inference/ch_det_data_50/all-sum-510/
+--image_dir:./inference/ch_det_data_50/all-sum-510/00008790.jpg
null:null
--benchmark:True
--det_algorithm:SAST
diff --git a/test_tipc/configs/en_table_structure/table_mv3.yml b/test_tipc/configs/en_table_structure/table_mv3.yml
new file mode 100755
index 0000000000000000000000000000000000000000..281038b968a5bf829483882117d779ec7de1976d
--- /dev/null
+++ b/test_tipc/configs/en_table_structure/table_mv3.yml
@@ -0,0 +1,124 @@
+Global:
+ use_gpu: true
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 5
+ save_model_dir: ./output/table_mv3/
+ save_epoch_step: 3
+ # evaluation is run every 400 iterations after the 0th iteration
+ eval_batch_step: [0, 400]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: ppstructure/docs/table/table.jpg
+ save_res_path: output/table_mv3
+ # for data or label process
+ character_dict_path: ppocr/utils/dict/table_structure_dict.txt
+ character_type: en
+ max_text_length: 800
+ infer_mode: False
+ process_total_num: 0
+ process_cut_num: 0
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ clip_norm: 5.0
+ lr:
+ learning_rate: 0.001
+ regularizer:
+ name: 'L2'
+ factor: 0.00000
+
+Architecture:
+ model_type: table
+ algorithm: TableAttn
+ Backbone:
+ name: MobileNetV3
+ scale: 1.0
+ model_name: large
+ Head:
+ name: TableAttentionHead
+ hidden_size: 256
+ loc_type: 2
+ max_text_length: 800
+
+Loss:
+ name: TableAttentionLoss
+ structure_weight: 100.0
+ loc_weight: 10000.0
+
+PostProcess:
+ name: TableLabelDecode
+
+Metric:
+ name: TableMetric
+ main_indicator: acc
+ compute_bbox_metric: false # cost many time, set False for training
+
+Train:
+ dataset:
+ name: PubTabDataSet
+ data_dir: ./train_data/pubtabnet/train
+ label_file_list: [./train_data/pubtabnet/train.jsonl]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - TableLabelEncode:
+ learn_empty_box: False
+ merge_no_span_structure: False
+ replace_empty_cell_token: False
+ - TableBoxEncode:
+ - ResizeTableImage:
+ max_len: 488
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - PaddingTableImage:
+ size: [488, 488]
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+ loader:
+ shuffle: True
+ batch_size_per_card: 32
+ drop_last: True
+ num_workers: 1
+
+Eval:
+ dataset:
+ name: PubTabDataSet
+ data_dir: ./train_data/pubtabnet/test/
+ label_file_list: [./train_data/pubtabnet/test.jsonl]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - TableLabelEncode:
+ learn_empty_box: False
+ merge_no_span_structure: False
+ replace_empty_cell_token: False
+ - TableBoxEncode:
+ - ResizeTableImage:
+ max_len: 488
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - PaddingTableImage:
+ size: [488, 488]
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 16
+ num_workers: 1
diff --git a/test_tipc/configs/en_table_structure/train_infer_python.txt b/test_tipc/configs/en_table_structure/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d9f3b30e16c75281a929130d877b947a23c16190
--- /dev/null
+++ b/test_tipc/configs/en_table_structure/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:en_table_structure
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=50
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
+Global.pretrained_model:./pretrain_models/en_ppocr_mobile_v2.0_table_structure_train/best_accuracy
+train_model_name:latest
+train_infer_img_dir:./ppstructure/docs/table/table.jpg
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+quant_export:
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:./inference/en_ppocr_mobile_v2.0_table_structure_infer
+infer_export:null
+infer_quant:False
+inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--table_model_dir:
+--image_dir:./ppstructure/docs/table/table.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,488,488]}]
diff --git a/test_tipc/configs/en_table_structure/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/en_table_structure/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..41d236c3765fbf6a711c6739d8dee4f41a147039
--- /dev/null
+++ b/test_tipc/configs/en_table_structure/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:en_table_structure
+python:python3.7
+gpu_list:192.168.0.1,192.168.0.2;0,1
+Global.use_gpu:True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=50
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
+Global.pretrained_model:./pretrain_models/en_ppocr_mobile_v2.0_table_structure_train/best_accuracy
+train_model_name:latest
+train_infer_img_dir:./ppstructure/docs/table/table.jpg
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+quant_export:
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:./inference/en_ppocr_mobile_v2.0_table_structure_infer
+infer_export:null
+infer_quant:False
+inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
+--use_gpu:False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--table_model_dir:
+--image_dir:./ppstructure/docs/table/table.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,488,488]}]
diff --git a/test_tipc/configs/en_table_structure/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/en_table_structure/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..31ac1ed53f2adc9810bc4fd2cf4f874d89d49606
--- /dev/null
+++ b/test_tipc/configs/en_table_structure/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:en_table_structure
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:amp
+Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=50
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
+Global.pretrained_model:./pretrain_models/en_ppocr_mobile_v2.0_table_structure_train/best_accuracy
+train_model_name:latest
+train_infer_img_dir:./ppstructure/docs/table/table.jpg
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+quant_export:
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:./inference/en_ppocr_mobile_v2.0_table_structure_infer
+infer_export:null
+infer_quant:False
+inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--table_model_dir:
+--image_dir:./ppstructure/docs/table/table.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,488,488]}]
diff --git a/test_tipc/configs/en_table_structure/train_pact_infer_python.txt b/test_tipc/configs/en_table_structure/train_pact_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f62e8b68bc6c1af06a65a8dfb438d5d63576e123
--- /dev/null
+++ b/test_tipc/configs/en_table_structure/train_pact_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:en_table_structure_PACT
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=50
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
+Global.pretrained_model:./pretrain_models/en_ppocr_mobile_v2.0_table_structure_train/best_accuracy
+train_model_name:latest
+train_infer_img_dir:./ppstructure/docs/table/table.jpg
+null:null
+##
+trainer:pact_train
+norm_train:null
+pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:null
+quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:./inference/en_ppocr_mobile_v2.0_table_structure_infer
+infer_export:null
+infer_quant:True
+inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--table_model_dir:
+--image_dir:./ppstructure/docs/table/table.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,488,488]}]
diff --git a/test_tipc/configs/en_table_structure/train_ptq_infer_python.txt b/test_tipc/configs/en_table_structure/train_ptq_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e8f7bbaa50417b97f79596634677fff0a95cb47f
--- /dev/null
+++ b/test_tipc/configs/en_table_structure/train_ptq_infer_python.txt
@@ -0,0 +1,21 @@
+===========================train_params===========================
+model_name:en_table_structure_KL
+python:python3.7
+Global.pretrained_model:
+Global.save_inference_dir:null
+infer_model:./inference/en_ppocr_mobile_v2.0_table_structure_infer/
+infer_export:deploy/slim/quantization/quant_kl.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
+infer_quant:True
+inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:int8
+--table_model_dir:
+--image_dir:./ppstructure/docs/table/table.jpg
+null:null
+--benchmark:False
+null:null
+null:null
diff --git a/test_tipc/configs/rec_mtb_nrtr/rec_mtb_nrtr.yml b/test_tipc/configs/rec_mtb_nrtr/rec_mtb_nrtr.yml
index 15119bb2a9de02c19684d21ad5a1859db94895ce..8118d587248b7e4797e3a75c897e7b0a3d71b364 100644
--- a/test_tipc/configs/rec_mtb_nrtr/rec_mtb_nrtr.yml
+++ b/test_tipc/configs/rec_mtb_nrtr/rec_mtb_nrtr.yml
@@ -49,7 +49,7 @@ Architecture:
Loss:
- name: NRTRLoss
+ name: CELoss
smoothing: True
PostProcess:
@@ -69,7 +69,7 @@ Train:
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- - NRTRRecResizeImg:
+ - GrayRecResizeImg:
image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV
- KeepKeys:
@@ -90,7 +90,7 @@ Eval:
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- - NRTRRecResizeImg:
+ - GrayRecResizeImg:
image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV
- KeepKeys:
@@ -99,5 +99,5 @@ Eval:
shuffle: False
drop_last: False
batch_size_per_card: 256
- num_workers: 1
+ num_workers: 4
use_shared_memory: False
diff --git a/test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml b/test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml
new file mode 100644
index 0000000000000000000000000000000000000000..5b5890e7728b9a1cb629744bd5d56488657c73f3
--- /dev/null
+++ b/test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml
@@ -0,0 +1,106 @@
+Global:
+ use_gpu: True
+ epoch_num: 10
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/r45_abinet/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_10.png
+ # for data or label process
+ character_dict_path:
+ character_type: en
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_abinet.txt
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.99
+ clip_norm: 20.0
+ lr:
+ name: Piecewise
+ decay_epochs: [6]
+ values: [0.0001, 0.00001]
+ regularizer:
+ name: 'L2'
+ factor: 0.
+
+Architecture:
+ model_type: rec
+ algorithm: ABINet
+ in_channels: 3
+ Transform:
+ Backbone:
+ name: ResNet45
+
+ Head:
+ name: ABINetHead
+ use_lang: True
+ iter_size: 3
+
+
+Loss:
+ name: CELoss
+ ignore_index: &ignore_index 100 # Must be greater than the number of character classes
+
+PostProcess:
+ name: ABINetLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: RGB
+ channel_first: False
+ - ABINetRecAug:
+ - ABINetLabelEncode: # Class handling label
+ ignore_index: *ignore_index
+ - ABINetRecResizeImg:
+ image_shape: [3, 32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 96
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: RGB
+ channel_first: False
+ - ABINetLabelEncode: # Class handling label
+ ignore_index: *ignore_index
+ - ABINetRecResizeImg:
+ image_shape: [3, 32, 128]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 4
+ use_shared_memory: False
diff --git a/test_tipc/configs/rec_r45_abinet/train_infer_python.txt b/test_tipc/configs/rec_r45_abinet/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ecab1bcbbde11fc6d14357b6715033704c2c3316
--- /dev/null
+++ b/test_tipc/configs/rec_r45_abinet/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_abinet
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_r45_abinet_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,128" --rec_algorithm="ABINet"
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1|6
+--use_tensorrt:False
+--precision:fp32
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,32,128]}]
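Assuming the standard TIPC entry points, the new ABINet chain can be exercised end to end like this (the same two commands apply to the rec_svtrnet and rec_vitstr_none_ce configs below, with their own txt files):

    bash test_tipc/prepare.sh test_tipc/configs/rec_r45_abinet/train_infer_python.txt lite_train_lite_infer
    bash test_tipc/test_train_inference_python.sh test_tipc/configs/rec_r45_abinet/train_infer_python.txt lite_train_lite_infer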
diff --git a/test_tipc/configs/rec_svtrnet/rec_svtrnet.yml b/test_tipc/configs/rec_svtrnet/rec_svtrnet.yml
new file mode 100644
index 0000000000000000000000000000000000000000..140b17e0e79f9895167e9c51d86ced173e44a541
--- /dev/null
+++ b/test_tipc/configs/rec_svtrnet/rec_svtrnet.yml
@@ -0,0 +1,117 @@
+Global:
+ use_gpu: True
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/svtr/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations after the 0th iteration
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_10.png
+ # for data or label process
+ character_dict_path:
+ character_type: en
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_svtr_tiny.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.99
+ epsilon: 8.e-8
+ weight_decay: 0.05
+ no_weight_decay_name: norm pos_embed
+ one_dim_param_no_weight_decay: true
+ lr:
+ name: Cosine
+ learning_rate: 0.0005
+ warmup_epoch: 2
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ name: STN_ON
+ tps_inputsize: [32, 64]
+ tps_outputsize: [32, 100]
+ num_control_points: 20
+ tps_margins: [0.05,0.05]
+ stn_activation: none
+ Backbone:
+ name: SVTRNet
+ img_size: [32, 100]
+ out_char_num: 25
+ out_channels: 192
+ patch_merging: 'Conv'
+ embed_dim: [64, 128, 256]
+ depth: [3, 6, 3]
+ num_heads: [2, 4, 8]
+ mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[7, 11], [7, 11], [7, 11]]
+ last_stage: True
+ prenorm: false
+ Neck:
+ name: SequenceEncoder
+ encoder_type: reshape
+ Head:
+ name: CTCHead
+
+Loss:
+ name: CTCLoss
+
+PostProcess:
+ name: CTCLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - CTCLabelEncode: # Class handling label
+ - SVTRRecResizeImg:
+ image_shape: [3, 64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 512
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - CTCLabelEncode: # Class handling label
+ - SVTRRecResizeImg:
+ image_shape: [3, 64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/rec_svtrnet/train_infer_python.txt
similarity index 51%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/rec_svtrnet/train_infer_python.txt
index a1a2a8e63dae5122d540d96a802fdebd75f554ea..a7e4a24063b2e248f2ab92d5efd257a2837c0a34 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/rec_svtrnet/train_infer_python.txt
@@ -1,43 +1,43 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_PACT
+model_name:rec_svtrnet
python:python3.7
-gpu_list:0
+gpu_list:0|0,1
Global.use_gpu:True|True
-Global.auto_cast:amp
-Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=50
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
-Global.checkpoints:null
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
train_model_name:latest
-train_infer_img_dir:./train_data/ic15_data/test/word_1.png
+train_infer_img_dir:./inference/rec_inference
null:null
##
-trainer:pact_train
-norm_train:null
-pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
+pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
-===========================eval_params===========================
-eval:null
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:null
-quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
+quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
-inference_dir:null
-infer_model:./inference/ch_ppocr_mobile_v2.0_rec_slim_infer/
-infer_export:null
+##
+train_model:./inference/rec_svtrnet_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
infer_quant:False
-inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ppocr_keys_v1.txt --rec_image_shape="3,32,100"
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,64,256" --rec_algorithm="SVTR"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
@@ -50,4 +50,4 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ppocr_ke
--benchmark:True
null:null
===========================infer_benchmark_params==========================
-random_infer_input:[{float32,[3,32,320]}]
+random_infer_input:[{float32,[3,64,256]}]
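With an exported model in hand, the renamed SVTR config reduces to a predict_rec.py call along these lines; the model and image directories here are illustrative, since the harness substitutes the values parsed from the file:

    python3.7 tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt \
        --rec_image_shape="3,64,256" --rec_algorithm="SVTR" \
        --rec_model_dir=./output/ --image_dir=./inference/rec_inference --use_gpu=True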
diff --git a/test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml b/test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a0aed488755f7cb6fed18a5747e9b7f62f57da86
--- /dev/null
+++ b/test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml
@@ -0,0 +1,104 @@
+Global:
+ use_gpu: True
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/vitstr_none_ce/
+ save_epoch_step: 1
+  # evaluation is run every 2000 iterations after the 0th iteration
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_10.png
+ # for data or label process
+ character_dict_path: ppocr/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_vitstr.txt
+
+
+Optimizer:
+ name: Adadelta
+ epsilon: 1.e-8
+ rho: 0.95
+ clip_norm: 5.0
+ lr:
+ learning_rate: 1.0
+
+Architecture:
+ model_type: rec
+ algorithm: ViTSTR
+ in_channels: 1
+ Transform:
+ Backbone:
+ name: ViTSTR
+ Neck:
+ name: SequenceEncoder
+ encoder_type: reshape
+ Head:
+ name: CTCHead
+
+Loss:
+ name: CELoss
+ smoothing: False
+ with_all: True
+ ignore_index: &ignore_index 0 # Must be zero or greater than the number of character classes
+
+PostProcess:
+ name: ViTSTRLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ViTSTRLabelEncode: # Class handling label
+ ignore_index: *ignore_index
+ - GrayRecResizeImg:
+ image_shape: [224, 224] # W H
+ resize_type: PIL # PIL or OpenCV
+ inter_type: 'Image.BICUBIC'
+ scale: false
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 48
+ drop_last: True
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ViTSTRLabelEncode: # Class handling label
+ ignore_index: *ignore_index
+ - GrayRecResizeImg:
+ image_shape: [224, 224] # W H
+ resize_type: PIL # PIL or OpenCV
+ inter_type: 'Image.BICUBIC'
+ scale: false
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/test_tipc/configs/rec_vitstr_none_ce/train_infer_python.txt b/test_tipc/configs/rec_vitstr_none_ce/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..04c5742ea2ddaf01e782d8b39c21bcbcfa0a7ce7
--- /dev/null
+++ b/test_tipc/configs/rec_vitstr_none_ce/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_vitstr
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_vitstr_none_ce_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbol_dict.txt --rec_image_shape="1,224,224" --rec_algorithm="ViTSTR"
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1|6
+--use_tensorrt:False
+--precision:fp32
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[1,224,224]}]
diff --git a/test_tipc/configs/table_master/table_master.yml b/test_tipc/configs/table_master/table_master.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c519b5b8f464d8843888155387b74a8416821f2f
--- /dev/null
+++ b/test_tipc/configs/table_master/table_master.yml
@@ -0,0 +1,136 @@
+Global:
+ use_gpu: true
+ epoch_num: 17
+ log_smooth_window: 20
+ print_batch_step: 100
+ save_model_dir: ./output/table_master/
+ save_epoch_step: 17
+ eval_batch_step: [0, 6259]
+ cal_metric_during_train: true
+ pretrained_model: null
+ checkpoints:
+ save_inference_dir: output/table_master/infer
+ use_visualdl: false
+ infer_img: ppstructure/docs/table/table.jpg
+ save_res_path: ./output/table_master
+ character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
+ infer_mode: false
+ max_text_length: 500
+ process_total_num: 0
+ process_cut_num: 0
+
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ name: MultiStepDecay
+ learning_rate: 0.001
+ milestones: [12, 15]
+ gamma: 0.1
+ warmup_epoch: 0.02
+ regularizer:
+ name: L2
+ factor: 0.0
+
+Architecture:
+ model_type: table
+ algorithm: TableMaster
+ Backbone:
+ name: TableResNetExtra
+ gcb_config:
+ ratio: 0.0625
+ headers: 1
+ att_scale: False
+ fusion_type: channel_add
+ layers: [False, True, True, True]
+    layers: [1,2,5,3]
+ Head:
+ name: TableMasterHead
+ hidden_size: 512
+ headers: 8
+ dropout: 0
+ d_ff: 2024
+ max_text_length: 500
+
+Loss:
+ name: TableMasterLoss
+ ignore_index: 42 # set to len of dict + 3
+
+PostProcess:
+ name: TableMasterLabelDecode
+ box_shape: pad
+
+Metric:
+ name: TableMetric
+ main_indicator: acc
+ compute_bbox_metric: False
+
+Train:
+ dataset:
+ name: PubTabDataSet
+ data_dir: ./train_data/pubtabnet/train
+ label_file_list: [./train_data/pubtabnet/train.jsonl]
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: False
+ - TableMasterLabelEncode:
+ learn_empty_box: False
+ merge_no_span_structure: True
+ replace_empty_cell_token: True
+ - ResizeTableImage:
+ max_len: 480
+ resize_bboxes: True
+ - PaddingTableImage:
+ size: [480, 480]
+ - TableBoxEncode:
+ use_xywh: True
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.5, 0.5, 0.5]
+ std: [0.5, 0.5, 0.5]
+ order: hwc
+ - ToCHWImage: null
+ - KeepKeys:
+ keep_keys: [image, structure, bboxes, bbox_masks, shape]
+ loader:
+ shuffle: True
+ batch_size_per_card: 10
+ drop_last: True
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: PubTabDataSet
+ data_dir: ./train_data/pubtabnet/test/
+ label_file_list: [./train_data/pubtabnet/test.jsonl]
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: False
+ - TableMasterLabelEncode:
+ learn_empty_box: False
+ merge_no_span_structure: True
+ replace_empty_cell_token: True
+ - ResizeTableImage:
+ max_len: 480
+ resize_bboxes: True
+ - PaddingTableImage:
+ size: [480, 480]
+ - TableBoxEncode:
+ use_xywh: True
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.5, 0.5, 0.5]
+ std: [0.5, 0.5, 0.5]
+ order: hwc
+ - ToCHWImage: null
+ - KeepKeys:
+ keep_keys: [image, structure, bboxes, bbox_masks, shape]
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 10
+ num_workers: 8
\ No newline at end of file
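The ignore_index above is documented as len(dict) + 3, which can be checked against the structure dictionary directly; a quick sketch, assuming one token per line in the dict file:

    # should print 42 if the dictionary matches the comment in the Loss section
    echo $(( $(wc -l < ppocr/utils/dict/table_master_structure_dict.txt) + 3 ))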
diff --git a/test_tipc/configs/table_master/train_infer_python.txt b/test_tipc/configs/table_master/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..56b8e636026939ae8cd700308690010e1300d8f6
--- /dev/null
+++ b/test_tipc/configs/table_master/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:table_master
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
+Global.pretrained_model:./pretrain_models/table_structure_tablemaster_train/best_accuracy
+train_model_name:latest
+train_infer_img_dir:./ppstructure/docs/table/table.jpg
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/table_master/table_master.yml -o Global.print_batch_step=10
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/table_master/table_master.yml -o
+quant_export:
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_master_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --output ./output/table --table_algorithm=TableMaster --table_max_len=480
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--table_model_dir:
+--image_dir:./ppstructure/docs/table/table.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,480,480]}]
diff --git a/test_tipc/docs/test_train_fleet_inference_python.md b/test_tipc/docs/test_train_fleet_inference_python.md
index 4479a47da83a951eeed9d7d0e8f9077fc0a9fed4..9fddb5d1634b452f1906a83bca4157dbaec47c81 100644
--- a/test_tipc/docs/test_train_fleet_inference_python.md
+++ b/test_tipc/docs/test_train_fleet_inference_python.md
@@ -15,7 +15,7 @@ Linux GPU/CPU 多机多卡训练推理测试的主程序为`test_train_inference
| 算法名称 | 模型名称 | device_CPU | device_GPU | batchsize |
| :----: | :----: | :----: | :----: | :----: |
-| PP-OCRv3 | ch_PP-OCRv3_rec | 支持 | 支持 | 1 |
+| PP-OCRv3 | ch_PP-OCRv3_rec | 支持 | - | 1/6 |
## 2. 测试流程
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 65ac2bbebbff9f2132a46d28c2cccccb2864f183..8e1758abb8adb3b120704d590e77e05476fb9d4e 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -22,13 +22,19 @@ trainer_list=$(func_parser_value "${lines[14]}")
if [ ${MODE} = "benchmark_train" ];then
pip install -r requirements.txt
- if [[ ${model_name} =~ "det_mv3_db_v2_0" || ${model_name} =~ "det_r50_vd_east_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" || ${model_name} =~ "det_r18_db_v2_0" ]];then
+ if [[ ${model_name} =~ "det_mv3_db_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" || ${model_name} =~ "det_r18_db_v2_0" ]];then
rm -rf ./train_data/icdar2015
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015.tar && cd ../
fi
- if [[ ${model_name} =~ "det_r50_vd_east_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
+ if [[ ${model_name} =~ "det_r50_vd_east_v2_0" ]]; then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
+ cd ./train_data/ && tar xf icdar2015.tar && cd ../
+ fi
+ if [[ ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015.tar && cd ../
@@ -55,6 +61,16 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
if [ ${model_name} == "en_table_structure" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
+ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
+ fi
+ if [[ ${model_name} =~ "det_r50_db_plusplus" ]];then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams --no-check-certificate
+ fi
+ if [ ${model_name} == "table_master" ];then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf table_structure_tablemaster_train.tar && cd ../
fi
cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
rm -rf ./train_data/icdar2015
@@ -95,8 +111,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
fi
if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
+ cd ./pretrain_models && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && cd ../
fi
if [ ${model_name} == "det_mv3_db_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
@@ -115,6 +133,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_mv3_east_v2.0_train.tar && cd ../
fi
+ if [ ${model_name} == "det_r50_vd_east_v2_0" ]; then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
+ fi
elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
@@ -147,9 +169,12 @@ elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
cd ./train_data && tar xf total_text.tar && ln -s total_text_lite total_text && cd ../
fi
- if [ ${model_name} == "en_table_structure" ];then
+ if [[ ${model_name} =~ "en_table_structure" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
+ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
fi
elif [ ${MODE} = "lite_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
@@ -172,9 +197,12 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_PP-OCRv3_det_distill_train.tar && cd ../
fi
- if [ ${model_name} == "en_table_structure" ];then
+ if [[ ${model_name} =~ "en_table_structure" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
+ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
fi
elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
@@ -335,13 +363,15 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
         cd ./inference/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "en_table_structure" ];then
+ if [[ ${model_name} =~ "en_table_structure" ]];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
- cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && cd ../
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
+ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
fi
fi
-if [ ${MODE} = "klquant_whole_infer" ]; then
+if [[ ${model_name} =~ "KL" ]]; then
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_lite.tar && rm -rf ./icdar2015 && ln -s ./icdar2015_lite ./icdar2015 && cd ../
if [ ${model_name} = "ch_ppocr_mobile_v2.0_det_KL" ]; then
@@ -382,7 +412,15 @@ if [ ${MODE} = "klquant_whole_infer" ]; then
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar --no-check-certificate
cd ./train_data/ && tar xf ic15_data.tar && cd ../
cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf rec_inference.tar && cd ../
- fi
+ fi
+ if [ ${model_name} = "en_table_structure_KL" ];then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/pubtabnet.tar --no-check-certificate
+ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
+ cd ./train_data/ && tar xf pubtabnet.tar && cd ../
+ fi
fi
if [ ${MODE} = "cpp_infer" ];then
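The post-training-quantization data setup is now keyed off the model name instead of a dedicated klquant_whole_infer mode. With a quoted right-hand side, bash's =~ does a literal substring match, so any config whose model_name contains "KL" (for example en_table_structure_KL above) takes this branch; a small sketch:

    model_name=en_table_structure_KL
    if [[ ${model_name} =~ "KL" ]]; then
        echo "downloading PTQ data and inference models for ${model_name}"
    fi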
diff --git a/test_tipc/readme.md b/test_tipc/readme.md
index effb2f168b6cc91012bef3de120de9e98a21dbda..1c637d76f99fffdfdc5a053fa0c5b9336fe4b731 100644
--- a/test_tipc/readme.md
+++ b/test_tipc/readme.md
@@ -54,6 +54,7 @@
| NRTR |rec_mtb_nrtr | 识别 | 支持 | 多机多卡 混合精度 | - | - |
| SAR |rec_r31_sar | 识别 | 支持 | 多机多卡 混合精度 | - | - |
| PGNet |rec_r34_vd_none_none_ctc_v2.0 | 端到端| 支持 | 多机多卡 混合精度 | - | - |
+| TableMaster |table_structure_tablemaster_train | 表格识别| 支持 | 多机多卡 混合精度 | - | - |
diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh
index 4ad2035f48ea43227caa16900bd186bfd62f11e6..356bc98041fffa8f0437c6419fc72c06d5e719f7 100644
--- a/test_tipc/test_paddle2onnx.sh
+++ b/test_tipc/test_paddle2onnx.sh
@@ -62,7 +62,8 @@ function func_paddle2onnx(){
set_save_model=$(func_set_params "--save_file" "${det_save_file_value}")
set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}")
set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}")
- trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker}"
+ trans_det_log="${LOG_PATH}/trans_model_det.log"
+ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
@@ -73,7 +74,8 @@ function func_paddle2onnx(){
set_save_model=$(func_set_params "--save_file" "${rec_save_file_value}")
set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}")
set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}")
- trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker}"
+ trans_rec_log="${LOG_PATH}/trans_model_rec.log"
+ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
@@ -85,7 +87,8 @@ function func_paddle2onnx(){
set_save_model=$(func_set_params "--save_file" "${det_save_file_value}")
set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}")
set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}")
- trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker}"
+ trans_det_log="${LOG_PATH}/trans_model_det.log"
+ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
@@ -97,7 +100,8 @@ function func_paddle2onnx(){
set_save_model=$(func_set_params "--save_file" "${rec_save_file_value}")
set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}")
set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}")
- trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker}"
+ trans_rec_log="${LOG_PATH}/trans_model_rec.log"
+ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
diff --git a/test_tipc/test_ptq_inference_python.sh b/test_tipc/test_ptq_inference_python.sh
new file mode 100644
index 0000000000000000000000000000000000000000..288e6098966be4aaf2953d627e7890963100cb6e
--- /dev/null
+++ b/test_tipc/test_ptq_inference_python.sh
@@ -0,0 +1,158 @@
+#!/bin/bash
+source test_tipc/common_func.sh
+
+FILENAME=$1
+# MODE be one of ['whole_infer']
+MODE=$2
+
+IFS=$'\n'
+# parser klquant_infer params
+
+dataline=$(awk 'NR==1, NR==17{print}' $FILENAME)
+lines=(${dataline})
+model_name=$(func_parser_value "${lines[1]}")
+python=$(func_parser_value "${lines[2]}")
+export_weight=$(func_parser_key "${lines[3]}")
+save_infer_key=$(func_parser_key "${lines[4]}")
+# parser inference model
+infer_model_dir_list=$(func_parser_value "${lines[5]}")
+infer_export_list=$(func_parser_value "${lines[6]}")
+infer_is_quant=$(func_parser_value "${lines[7]}")
+# parser inference
+inference_py=$(func_parser_value "${lines[8]}")
+use_gpu_key=$(func_parser_key "${lines[9]}")
+use_gpu_list=$(func_parser_value "${lines[9]}")
+use_mkldnn_key=$(func_parser_key "${lines[10]}")
+use_mkldnn_list=$(func_parser_value "${lines[10]}")
+cpu_threads_key=$(func_parser_key "${lines[11]}")
+cpu_threads_list=$(func_parser_value "${lines[11]}")
+batch_size_key=$(func_parser_key "${lines[12]}")
+batch_size_list=$(func_parser_value "${lines[12]}")
+use_trt_key=$(func_parser_key "${lines[13]}")
+use_trt_list=$(func_parser_value "${lines[13]}")
+precision_key=$(func_parser_key "${lines[14]}")
+precision_list=$(func_parser_value "${lines[14]}")
+infer_model_key=$(func_parser_key "${lines[15]}")
+image_dir_key=$(func_parser_key "${lines[16]}")
+infer_img_dir=$(func_parser_value "${lines[16]}")
+save_log_key=$(func_parser_key "${lines[17]}")
+save_log_value=$(func_parser_value "${lines[17]}")
+benchmark_key=$(func_parser_key "${lines[18]}")
+benchmark_value=$(func_parser_value "${lines[18]}")
+infer_key1=$(func_parser_key "${lines[19]}")
+infer_value1=$(func_parser_value "${lines[19]}")
+
+
+LOG_PATH="./test_tipc/output/${model_name}/${MODE}"
+mkdir -p ${LOG_PATH}
+status_log="${LOG_PATH}/results_python.log"
+
+
+function func_inference(){
+ IFS='|'
+ _python=$1
+ _script=$2
+ _model_dir=$3
+ _log_path=$4
+ _img_dir=$5
+ _flag_quant=$6
+ # inference
+ for use_gpu in ${use_gpu_list[*]}; do
+ if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
+ for use_mkldnn in ${use_mkldnn_list[*]}; do
+ for threads in ${cpu_threads_list[*]}; do
+ for batch_size in ${batch_size_list[*]}; do
+ for precision in ${precision_list[*]}; do
+ if [ ${use_mkldnn} = "False" ] && [ ${precision} = "fp16" ]; then
+ continue
+ fi # skip when enable fp16 but disable mkldnn
+ if [ ${_flag_quant} = "True" ] && [ ${precision} != "int8" ]; then
+ continue
+ fi # skip when quant model inference but precision is not int8
+ set_precision=$(func_set_params "${precision_key}" "${precision}")
+
+ _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
+ set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
+ set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+ set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+ set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
+ set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
+ set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+ set_infer_params0=$(func_set_params "${save_log_key}" "${save_log_value}")
+ set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+ command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+ eval $command
+ last_status=${PIPESTATUS[0]}
+ eval "cat ${_save_log_path}"
+ status_check $last_status "${command}" "${status_log}" "${model_name}"
+ done
+ done
+ done
+ done
+ elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
+ for use_trt in ${use_trt_list[*]}; do
+ for precision in ${precision_list[*]}; do
+ if [ ${_flag_quant} = "True" ] && [ ${precision} != "int8" ]; then
+ continue
+ fi # skip when quant model inference but precision is not int8
+ for batch_size in ${batch_size_list[*]}; do
+ _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
+ set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
+ set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+ set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+ set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
+ set_precision=$(func_set_params "${precision_key}" "${precision}")
+ set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+ set_infer_params0=$(func_set_params "${save_log_key}" "${save_log_value}")
+ set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+ command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} ${set_infer_params0} > ${_save_log_path} 2>&1 "
+ eval $command
+ last_status=${PIPESTATUS[0]}
+ eval "cat ${_save_log_path}"
+ status_check $last_status "${command}" "${status_log}" "${model_name}"
+
+ done
+ done
+ done
+ else
+            echo "Does not support hardware other than CPU and GPU currently!"
+ fi
+ done
+}
+
+if [ ${MODE} = "whole_infer" ]; then
+ GPUID=$3
+ if [ ${#GPUID} -le 0 ];then
+ env=" "
+ else
+ env="export CUDA_VISIBLE_DEVICES=${GPUID}"
+ fi
+ # set CUDA_VISIBLE_DEVICES
+ eval $env
+ export Count=0
+ IFS="|"
+ infer_run_exports=(${infer_export_list})
+ infer_quant_flag=(${infer_is_quant})
+ for infer_model in ${infer_model_dir_list[*]}; do
+ # run export
+ if [ ${infer_run_exports[Count]} != "null" ];then
+ save_infer_dir="${infer_model}_klquant"
+ set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
+ set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
+ export_log_path="${LOG_PATH}_export_${Count}.log"
+ export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
+ echo ${infer_run_exports[Count]}
+ echo $export_cmd
+ eval $export_cmd
+ status_export=$?
+ status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
+ else
+ save_infer_dir=${infer_model}
+ fi
+ #run inference
+ is_quant="True"
+ func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}
+ Count=$(($Count + 1))
+ done
+fi
+
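Since the parser above reads the short whole_infer layout, the new runner is invoked like the other TIPC scripts; a minimal sketch, with the optional third argument pinning a GPU:

    bash test_tipc/prepare.sh test_tipc/configs/en_table_structure/train_ptq_infer_python.txt whole_infer
    bash test_tipc/test_ptq_inference_python.sh test_tipc/configs/en_table_structure/train_ptq_infer_python.txt whole_infer
    bash test_tipc/test_ptq_inference_python.sh test_tipc/configs/en_table_structure/train_ptq_infer_python.txt whole_infer 0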
diff --git a/test_tipc/test_serving_infer_cpp.sh b/test_tipc/test_serving_infer_cpp.sh
index f9f7ac1aa554312052ca22876558e58629342549..0be6a45adf3105f088a96336dddfbe9ac612f19b 100644
--- a/test_tipc/test_serving_infer_cpp.sh
+++ b/test_tipc/test_serving_infer_cpp.sh
@@ -70,7 +70,8 @@ function func_serving(){
set_serving_server=$(func_set_params "--serving_server" "${det_serving_server_value}")
set_serving_client=$(func_set_params "--serving_client" "${det_serving_client_value}")
python_list=(${python_list})
- trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
+ trans_det_log="${LOG_PATH}/cpp_trans_model_det.log"
+ trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
cp "deploy/pdserving/serving_client_conf.prototxt" ${det_serving_client_value}
# trans rec
@@ -78,32 +79,37 @@ function func_serving(){
set_serving_server=$(func_set_params "--serving_server" "${rec_serving_server_value}")
set_serving_client=$(func_set_params "--serving_client" "${rec_serving_client_value}")
python_list=(${python_list})
- trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
+ trans_rec_log="${LOG_PATH}/cpp_trans_model_rec.log"
+ trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
python_list=(${python_list})
cd ${serving_dir_value}
+
# cpp serving
for gpu_id in ${gpu_value[*]}; do
if [ ${gpu_id} = "null" ]; then
- web_service_cpp_cmd="${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > serving_log_cpu.log &"
+ server_log_path="${LOG_PATH}/cpp_server_cpu.log"
+ web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 &"
eval $web_service_cpp_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}"
sleep 5s
- _save_log_path="${LOG_PATH}/server_infer_cpp_cpu.log"
+ _save_log_path="${LOG_PATH}/cpp_client_cpu.log"
cpp_client_cmd="${python_list[0]} ${cpp_client_py} ${det_client_value} ${rec_client_value} > ${_save_log_path} 2>&1"
eval $cpp_client_cmd
last_status=${PIPESTATUS[0]}
+ eval "cat ${_save_log_path}"
status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}"
ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
else
- web_service_cpp_cmd="${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} ${gpu_key} ${gpu_id} > serving_log_gpu.log &"
+ server_log_path="${LOG_PATH}/cpp_server_gpu.log"
+ web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} ${gpu_key} ${gpu_id} > ${server_log_path} 2>&1 &"
eval $web_service_cpp_cmd
sleep 5s
- _save_log_path="${LOG_PATH}/server_infer_cpp_gpu.log"
+ _save_log_path="${LOG_PATH}/cpp_client_gpu.log"
cpp_client_cmd="${python_list[0]} ${cpp_client_py} ${det_client_value} ${rec_client_value} > ${_save_log_path} 2>&1"
eval $cpp_client_cmd
last_status=${PIPESTATUS[0]}
diff --git a/test_tipc/test_serving_infer_python.sh b/test_tipc/test_serving_infer_python.sh
index c76d6f5d19e00c729953dc0df95cbfc20b6494a8..4ccccc06e23ce086e7dac1f3446aae9130605444 100644
--- a/test_tipc/test_serving_infer_python.sh
+++ b/test_tipc/test_serving_infer_python.sh
@@ -77,14 +77,16 @@ function func_serving(){
set_serving_server=$(func_set_params "--serving_server" "${det_serving_server_value}")
set_serving_client=$(func_set_params "--serving_client" "${det_serving_client_value}")
python_list=(${python_list})
- trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
+ trans_det_log="${LOG_PATH}/python_trans_model_det.log"
+ trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
# trans rec
set_dirname=$(func_set_params "--dirname" "${rec_infer_model_dir_value}")
set_serving_server=$(func_set_params "--serving_server" "${rec_serving_server_value}")
set_serving_client=$(func_set_params "--serving_client" "${rec_serving_client_value}")
python_list=(${python_list})
- trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
+ trans_rec_log="${LOG_PATH}/python_trans_model_rec.log"
+ trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
elif [[ ${model_name} =~ "det" ]]; then
# trans det
@@ -92,7 +94,8 @@ function func_serving(){
set_serving_server=$(func_set_params "--serving_server" "${det_serving_server_value}")
set_serving_client=$(func_set_params "--serving_client" "${det_serving_client_value}")
python_list=(${python_list})
- trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
+ trans_det_log="${LOG_PATH}/python_trans_model_det.log"
+ trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
elif [[ ${model_name} =~ "rec" ]]; then
# trans rec
@@ -100,7 +103,8 @@ function func_serving(){
set_serving_server=$(func_set_params "--serving_server" "${rec_serving_server_value}")
set_serving_client=$(func_set_params "--serving_client" "${rec_serving_client_value}")
python_list=(${python_list})
- trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
+ trans_rec_log="${LOG_PATH}/python_trans_model_rec.log"
+ trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
fi
set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
@@ -108,36 +112,37 @@ function func_serving(){
cd ${serving_dir_value}
python=${python_list[0]}
-
+
# python serving
for use_gpu in ${web_use_gpu_list[*]}; do
if [ ${use_gpu} = "null" ]; then
for use_mkldnn in ${web_use_mkldnn_list[*]}; do
for threads in ${web_cpu_threads_list[*]}; do
set_cpu_threads=$(func_set_params "${web_cpu_threads_key}" "${threads}")
+ server_log_path="${LOG_PATH}/python_server_cpu_usemkldnn_${use_mkldnn}_threads_${threads}.log"
if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
- web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} &"
+ web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
- web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} &"
+ web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
- web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} &"
+ web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
fi
sleep 2s
for pipeline in ${pipeline_py[*]}; do
- _save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_1.log"
+ _save_log_path="${LOG_PATH}/python_client_cpu_${pipeline%_client*}_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_1.log"
pipeline_cmd="${python} ${pipeline} ${set_image_dir} > ${_save_log_path} 2>&1 "
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
@@ -151,6 +156,7 @@ function func_serving(){
elif [ ${use_gpu} = "gpu" ]; then
for use_trt in ${web_use_trt_list[*]}; do
for precision in ${web_precision_list[*]}; do
+ server_log_path="${LOG_PATH}/python_server_gpu_usetrt_${use_trt}_precision_${precision}.log"
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
continue
fi
@@ -168,26 +174,26 @@ function func_serving(){
if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
- web_service_cmd="${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} &"
+ web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
- web_service_cmd="${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} &"
+ web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
- web_service_cmd="${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} &"
+ web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
fi
sleep 2s
for pipeline in ${pipeline_py[*]}; do
- _save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_usetrt_${use_trt}_precision_${precision}_batchsize_1.log"
+ _save_log_path="${LOG_PATH}/python_client_gpu_${pipeline%_client*}_usetrt_${use_trt}_precision_${precision}_batchsize_1.log"
pipeline_cmd="${python} ${pipeline} ${set_image_dir}> ${_save_log_path} 2>&1"
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh
index 62a56a32ceb15e387f568f1b9857bced95166be3..907efcec9008f89740971bb6d4253bafb44938c4 100644
--- a/test_tipc/test_train_inference_python.sh
+++ b/test_tipc/test_train_inference_python.sh
@@ -2,7 +2,7 @@
source test_tipc/common_func.sh
FILENAME=$1
-# MODE be one of ['lite_train_lite_infer' 'lite_train_whole_infer' 'whole_train_whole_infer', 'whole_infer', 'klquant_whole_infer']
+# MODE be one of ['lite_train_lite_infer', 'lite_train_whole_infer', 'whole_train_whole_infer', 'whole_infer']
MODE=$2
dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
@@ -88,43 +88,6 @@ benchmark_value=$(func_parser_value "${lines[49]}")
infer_key1=$(func_parser_key "${lines[50]}")
infer_value1=$(func_parser_value "${lines[50]}")
-# parser klquant_infer
-if [ ${MODE} = "klquant_whole_infer" ]; then
- dataline=$(awk 'NR==1, NR==17{print}' $FILENAME)
- lines=(${dataline})
- model_name=$(func_parser_value "${lines[1]}")
- python=$(func_parser_value "${lines[2]}")
- export_weight=$(func_parser_key "${lines[3]}")
- save_infer_key=$(func_parser_key "${lines[4]}")
- # parser inference model
- infer_model_dir_list=$(func_parser_value "${lines[5]}")
- infer_export_list=$(func_parser_value "${lines[6]}")
- infer_is_quant=$(func_parser_value "${lines[7]}")
- # parser inference
- inference_py=$(func_parser_value "${lines[8]}")
- use_gpu_key=$(func_parser_key "${lines[9]}")
- use_gpu_list=$(func_parser_value "${lines[9]}")
- use_mkldnn_key=$(func_parser_key "${lines[10]}")
- use_mkldnn_list=$(func_parser_value "${lines[10]}")
- cpu_threads_key=$(func_parser_key "${lines[11]}")
- cpu_threads_list=$(func_parser_value "${lines[11]}")
- batch_size_key=$(func_parser_key "${lines[12]}")
- batch_size_list=$(func_parser_value "${lines[12]}")
- use_trt_key=$(func_parser_key "${lines[13]}")
- use_trt_list=$(func_parser_value "${lines[13]}")
- precision_key=$(func_parser_key "${lines[14]}")
- precision_list=$(func_parser_value "${lines[14]}")
- infer_model_key=$(func_parser_key "${lines[15]}")
- image_dir_key=$(func_parser_key "${lines[16]}")
- infer_img_dir=$(func_parser_value "${lines[16]}")
- save_log_key=$(func_parser_key "${lines[17]}")
- save_log_value=$(func_parser_value "${lines[17]}")
- benchmark_key=$(func_parser_key "${lines[18]}")
- benchmark_value=$(func_parser_value "${lines[18]}")
- infer_key1=$(func_parser_key "${lines[19]}")
- infer_value1=$(func_parser_value "${lines[19]}")
-fi
-
LOG_PATH="./test_tipc/output/${model_name}/${MODE}"
mkdir -p ${LOG_PATH}
status_log="${LOG_PATH}/results_python.log"
@@ -211,7 +174,7 @@ function func_inference(){
done
}
-if [ ${MODE} = "whole_infer" ] || [ ${MODE} = "klquant_whole_infer" ]; then
+if [ ${MODE} = "whole_infer" ]; then
GPUID=$3
if [ ${#GPUID} -le 0 ];then
env=" "
@@ -226,16 +189,12 @@ if [ ${MODE} = "whole_infer" ] || [ ${MODE} = "klquant_whole_infer" ]; then
infer_quant_flag=(${infer_is_quant})
for infer_model in ${infer_model_dir_list[*]}; do
# run export
- if [ ${infer_run_exports[Count]} != "null" ];then
- if [ ${MODE} = "klquant_whole_infer" ]; then
- save_infer_dir="${infer_model}_klquant"
- fi
- if [ ${MODE} = "whole_infer" ]; then
- save_infer_dir="${infer_model}"
- fi
+ if [ ${infer_run_exports[Count]} != "null" ];then
+ save_infer_dir="${infer_model}"
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
- export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}"
+ export_log_path="${LOG_PATH}_export_${Count}.log"
+ export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
echo ${infer_run_exports[Count]}
echo $export_cmd
eval $export_cmd
@@ -246,9 +205,6 @@ if [ ${MODE} = "whole_infer" ] || [ ${MODE} = "klquant_whole_infer" ]; then
fi
#run inference
is_quant=${infer_quant_flag[Count]}
- if [ ${MODE} = "klquant_whole_infer" ]; then
- is_quant="True"
- fi
func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}
Count=$(($Count + 1))
done
@@ -339,6 +295,7 @@ else
fi
# run train
eval $cmd
+ eval "cat ${save_log}/train.log >> ${save_log}.log"
status_check $? "${cmd}" "${status_log}" "${model_name}"
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
@@ -347,7 +304,8 @@ else
if [ ${eval_py} != "null" ]; then
eval ${env}
set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}")
- eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1}"
+ eval_log_path="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_eval.log"
+ eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1} > ${eval_log_path} 2>&1 "
eval $eval_cmd
status_check $? "${eval_cmd}" "${status_log}" "${model_name}"
fi
@@ -355,9 +313,10 @@ else
if [ ${run_export} != "null" ]; then
# run export model
save_infer_path="${save_log}"
+ export_log_path="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_export.log"
set_export_weight=$(func_set_params "${export_weight}" "${save_log}/${train_model_name}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}")
- export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key}"
+ export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
eval $export_cmd
status_check $? "${export_cmd}" "${status_log}" "${model_name}"
diff --git a/tools/export_model.py b/tools/export_model.py
index 07a7f3e2bc52612533054f0b56f11d7bfdea1967..11794f74201b3d98dc6cdf90095bc6e3c1a30449 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -31,7 +31,12 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
-def export_single_model(model, arch_config, save_path, logger, quanter=None):
+def export_single_model(model,
+ arch_config,
+ save_path,
+ logger,
+ input_shape=None,
+ quanter=None):
if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
@@ -64,7 +69,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
else:
other_shape = [
paddle.static.InputSpec(
- shape=[None, 3, 64, 256], dtype="float32"),
+ shape=[None] + input_shape, dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "PREN":
@@ -89,6 +94,41 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
]
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "ViTSTR":
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 1, 224, 224], dtype="float32"),
+ ]
+ model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "ABINet":
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 3, 32, 128], dtype="float32"),
+ ]
+ # print([None, 3, 32, 128])
+ model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "NRTR":
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 1, 32, 100], dtype="float32"),
+ ]
+ model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
+ input_spec = [
+ paddle.static.InputSpec(
+ shape=[None, 512], dtype="int64"), # input_ids
+ paddle.static.InputSpec(
+ shape=[None, 512, 4], dtype="int64"), # bbox
+ paddle.static.InputSpec(
+ shape=[None, 512], dtype="int64"), # attention_mask
+ paddle.static.InputSpec(
+ shape=[None, 512], dtype="int64"), # token_type_ids
+ paddle.static.InputSpec(
+ shape=[None, 3, 224, 224], dtype="int64"), # image
+ ]
+ if arch_config["algorithm"] == "LayoutLM":
+ input_spec.pop(4)
+ model = to_static(model, input_spec=[input_spec])
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
@@ -100,10 +140,10 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
- if arch_config["algorithm"] == "NRTR":
- infer_shape = [1, 32, 100]
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
+ if arch_config["algorithm"] == "TableMaster":
+ infer_shape = [3, 480, 480]
model = to_static(
model,
input_spec=[
@@ -166,13 +206,20 @@ def main():
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
- load_model(config, model)
+ load_model(config, model, model_type=config['Architecture']["model_type"])
model.eval()
save_path = config["Global"]["save_inference_dir"]
arch_config = config["Architecture"]
+ if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
+ "name"] != 'MultiHead':
+ input_shape = config["Eval"]["dataset"]["transforms"][-2][
+ 'SVTRRecResizeImg']['image_shape']
+ else:
+ input_shape = None
+
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
@@ -181,7 +228,8 @@ def main():
sub_model_save_path, logger)
else:
save_path = os.path.join(save_path, "inference")
- export_single_model(model, arch_config, save_path, logger)
+ export_single_model(
+ model, arch_config, save_path, logger, input_shape=input_shape)
if __name__ == "__main__":
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 5f2675d667c2aab8186886a60d8d447f43419954..394a48948b1f284bd405532769b76eeb298668bd 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -67,6 +67,23 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
+ elif self.det_algorithm == "DB++":
+ postprocess_params['name'] = 'DBPostProcess'
+ postprocess_params["thresh"] = args.det_db_thresh
+ postprocess_params["box_thresh"] = args.det_db_box_thresh
+ postprocess_params["max_candidates"] = 1000
+ postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
+ postprocess_params["use_dilation"] = args.use_dilation
+ postprocess_params["score_mode"] = args.det_db_score_mode
+ pre_process_list[1] = {
+ 'NormalizeImage': {
+ 'std': [1.0, 1.0, 1.0],
+ 'mean':
+ [0.48109378172549, 0.45752457890196, 0.40787054090196],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }
elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh
@@ -154,9 +171,10 @@ class TextDetector(object):
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
- diff = np.diff(pts, axis=1)
- rect[1] = pts[np.argmin(diff)]
- rect[3] = pts[np.argmax(diff)]
+ tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
+ diff = np.diff(np.array(tmp), axis=1)
+ rect[1] = tmp[np.argmin(diff)]
+ rect[3] = tmp[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
@@ -230,7 +248,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3]
- elif self.det_algorithm in ['DB', 'PSE']:
+ elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
preds['maps'] = outputs[0]
elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs):
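
Editor's note: the `order_points_clockwise` change above removes the two corners already assigned via the coordinate sum before applying the coordinate difference, so a near-square box can no longer pick the same point twice. A self-contained NumPy sketch of the resulting ordering (the sample box is illustrative only):

    import numpy as np

    def order_points_clockwise(pts):
        # pts: (4, 2) array of box corners in arbitrary order
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]      # top-left: smallest x + y
        rect[2] = pts[np.argmax(s)]      # bottom-right: largest x + y
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(np.array(tmp), axis=1)
        rect[1] = tmp[np.argmin(diff)]   # top-right: smallest y - x of the remainder
        rect[3] = tmp[np.argmax(diff)]   # bottom-left: largest y - x
        return rect

    box = np.array([[10, 90], [10, 10], [90, 90], [90, 10]], dtype="float32")
    print(order_points_clockwise(box))
    # [[10. 10.] [90. 10.] [90. 90.] [10. 90.]] -> TL, TR, BR, BL
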
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 6647c9ab524bba50f50c47c14ac509f8073b7923..5131903dc4ce5f907bd2a3ad3f0afbc93b1350ef 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -75,7 +75,19 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char,
"rm_symbol": True
+ }
+ elif self.rec_algorithm == 'ViTSTR':
+ postprocess_params = {
+ 'name': 'ViTSTRLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
+ elif self.rec_algorithm == 'ABINet':
+ postprocess_params = {
+ 'name': 'ABINetLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
@@ -104,15 +116,22 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
- if self.rec_algorithm == 'NRTR':
+ if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
image_pil = Image.fromarray(np.uint8(img))
- img = image_pil.resize([100, 32], Image.ANTIALIAS)
+ if self.rec_algorithm == 'ViTSTR':
+ img = image_pil.resize([imgW, imgH], Image.BICUBIC)
+ else:
+ img = image_pil.resize([imgW, imgH], Image.ANTIALIAS)
img = np.array(img)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
- return norm_img.astype(np.float32) / 128. - 1.
+ if self.rec_algorithm == 'ViTSTR':
+ norm_img = norm_img.astype(np.float32) / 255.
+ else:
+ norm_img = norm_img.astype(np.float32) / 128. - 1.
+ return norm_img
assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio))
@@ -140,17 +159,6 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
- def resize_norm_img_svtr(self, img, image_shape):
-
- imgC, imgH, imgW = image_shape
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
- resized_image = resized_image.astype('float32')
- resized_image = resized_image.transpose((2, 0, 1)) / 255
- resized_image -= 0.5
- resized_image /= 0.5
- return resized_image
-
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
@@ -258,6 +266,35 @@ class TextRecognizer(object):
return padding_im, resize_shape, pad_shape, valid_ratio
+ def resize_norm_img_svtr(self, img, image_shape):
+
+ imgC, imgH, imgW = image_shape
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ return resized_image
+
+ def resize_norm_img_abinet(self, img, image_shape):
+
+ imgC, imgH, imgW = image_shape
+
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image / 255.
+
+ mean = np.array([0.485, 0.456, 0.406])
+ std = np.array([0.229, 0.224, 0.225])
+ resized_image = (
+ resized_image - mean[None, None, ...]) / std[None, None, ...]
+ resized_image = resized_image.transpose((2, 0, 1))
+ resized_image = resized_image.astype('float32')
+
+ return resized_image
+
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@@ -309,6 +346,11 @@ class TextRecognizer(object):
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
+ elif self.rec_algorithm == "ABINet":
+ norm_img = self.resize_norm_img_abinet(
+ img_list[indices[ino]], self.rec_image_shape)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
elif self.rec_algorithm == "RobustScanner":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape, width_downsample_ratio=0.25)
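
Editor's note: a short standalone check of the ABINet preprocessing added above — resize to H x W = 32 x 128, scale to [0, 1], apply ImageNet mean/std in HWC order, then transpose to CHW. The function body mirrors the patch; the dummy input is illustrative only.

    import cv2
    import numpy as np

    def resize_norm_img_abinet(img, image_shape):
        imgC, imgH, imgW = image_shape
        resized = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized = resized.astype('float32') / 255.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        resized = (resized - mean[None, None, ...]) / std[None, None, ...]
        return resized.transpose((2, 0, 1)).astype('float32')

    dummy = np.random.randint(0, 255, (48, 320, 3), dtype=np.uint8)  # fake text crop
    out = resize_norm_img_abinet(dummy, (3, 32, 128))
    print(out.shape)  # (3, 32, 128)
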
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 366212f228eec33f11c825bfaf1e360258af9b2e..7eb77dec74bf283936e1143edcb5b5dfc28365bd 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -153,6 +153,8 @@ def create_predictor(args, mode, logger):
model_dir = args.rec_model_dir
elif mode == 'table':
model_dir = args.table_model_dir
+ elif mode == 'ser':
+ model_dir = args.ser_model_dir
else:
model_dir = args.e2e_model_dir
@@ -316,8 +318,13 @@ def create_predictor(args, mode, logger):
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
- for name in input_names:
- input_tensor = predictor.get_input_handle(name)
+ if mode in ['ser', 're']:
+ input_tensor = []
+ for name in input_names:
+ input_tensor.append(predictor.get_input_handle(name))
+ else:
+ for name in input_names:
+ input_tensor = predictor.get_input_handle(name)
output_tensors = get_output_tensors(args, mode, predictor)
return predictor, input_tensor, output_tensors, config
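
Editor's note: for 'ser'/'re' the predictor now returns a list of input handles (one per named input) instead of a single handle. A hedged sketch of feeding such a list through the standard paddle.inference API; the model paths, the four-input layout, and the zero feeds are hypothetical (a LayoutLM-style model is assumed).

    import numpy as np
    from paddle import inference

    config = inference.Config("ser_infer/inference.pdmodel",
                              "ser_infer/inference.pdiparams")  # hypothetical paths
    predictor = inference.create_predictor(config)

    input_names = predictor.get_input_names()
    input_tensors = [predictor.get_input_handle(name) for name in input_names]

    # One array per named input, in the same order as input_names (hypothetical data).
    feeds = [np.zeros((1, 512), dtype="int64"),      # input_ids
             np.zeros((1, 512, 4), dtype="int64"),   # bbox
             np.zeros((1, 512), dtype="int64"),      # attention_mask
             np.zeros((1, 512), dtype="int64")]      # token_type_ids

    for tensor, feed in zip(input_tensors, feeds):
        tensor.copy_from_cpu(feed)
    predictor.run()
    outputs = [predictor.get_output_handle(name).copy_to_cpu()
               for name in predictor.get_output_names()]
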
diff --git a/tools/infer_det.py b/tools/infer_det.py
index 1acecedf3e42fe67a93644a7f06c07c8b6bea2e3..df346523896c9c3f82d254600986e0eb221e3c9f 100755
--- a/tools/infer_det.py
+++ b/tools/infer_det.py
@@ -44,7 +44,7 @@ def draw_det_res(dt_boxes, config, img, img_name, save_path):
import cv2
src_im = img
for box in dt_boxes:
- box = box.astype(np.int32).reshape((-1, 1, 2))
+ box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
if not os.path.exists(save_path):
os.makedirs(save_path)
@@ -106,7 +106,7 @@ def main():
dt_boxes_list = []
for box in boxes:
tmp_json = {"transcription": ""}
- tmp_json['points'] = box.tolist()
+ tmp_json['points'] = list(box)
dt_boxes_list.append(tmp_json)
det_box_json[k] = dt_boxes_list
save_det_path = os.path.dirname(config['Global'][
@@ -118,7 +118,7 @@ def main():
# write result
for box in boxes:
tmp_json = {"transcription": ""}
- tmp_json['points'] = box.tolist()
+ tmp_json['points'] = list(box)
dt_boxes_json.append(tmp_json)
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results/"
diff --git a/tools/infer_kie.py b/tools/infer_kie.py
index 0cb0b8702cbd7ea74a7b7fcff69122731578a1bd..346e2e0aeeee695ab49577b6b13dcc058150df1a 100755
--- a/tools/infer_kie.py
+++ b/tools/infer_kie.py
@@ -39,13 +39,12 @@ import time
def read_class_list(filepath):
- dict = {}
+ ret = {}
with open(filepath, "r") as f:
lines = f.readlines()
- for line in lines:
- key, value = line.split(" ")
- dict[key] = value.rstrip()
- return dict
+ for idx, line in enumerate(lines):
+ ret[idx] = line.strip("\n")
+ return ret
def draw_kie_result(batch, node, idx_to_cls, count):
@@ -71,7 +70,7 @@ def draw_kie_result(batch, node, idx_to_cls, count):
x_min = int(min([point[0] for point in new_box]))
y_min = int(min([point[1] for point in new_box]))
- pred_label = str(node_pred_label[i])
+ pred_label = node_pred_label[i]
if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label]
pred_score = '{:.2f}'.format(node_pred_score[i])
@@ -109,8 +108,7 @@ def main():
save_res_path = config['Global']['save_res_path']
class_path = config['Global']['class_path']
idx_to_cls = read_class_list(class_path)
- if not os.path.exists(os.path.dirname(save_res_path)):
- os.makedirs(os.path.dirname(save_res_path))
+ os.makedirs(os.path.dirname(save_res_path), exist_ok=True)
model.eval()
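
Editor's note: the reworked `read_class_list` maps line index to class name (one label per line) rather than parsing "key value" pairs. A minimal sketch of the behaviour, with a hypothetical class file:

    def read_class_list(filepath):
        ret = {}
        with open(filepath, "r") as f:
            for idx, line in enumerate(f.readlines()):
                ret[idx] = line.strip("\n")
        return ret

    # Hypothetical "classes.txt" containing:
    #   other
    #   question
    #   answer
    # read_class_list("classes.txt") -> {0: 'other', 1: 'question', 2: 'answer'}
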
diff --git a/tools/infer_table.py b/tools/infer_table.py
index 66c2da4421a313c634d27eb7a1013638a7c005ed..6c02dd8640c9345c267e56d6e5a0c14bde121b7e 100644
--- a/tools/infer_table.py
+++ b/tools/infer_table.py
@@ -36,10 +36,12 @@ from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
+from ppocr.utils.visual import draw_rectangle
import tools.program as program
import cv2
+@paddle.no_grad()
def main(config, device, logger, vdl_writer):
global_config = config['Global']
@@ -53,53 +55,61 @@ def main(config, device, logger, vdl_writer):
getattr(post_process_class, 'character'))
model = build_model(config['Architecture'])
+ algorithm = config['Architecture']['algorithm']
+ use_xywh = algorithm in ['TableMaster']
load_model(config, model)
# create data ops
transforms = []
- use_padding = False
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
- if 'Label' in op_name:
+ if 'Encode' in op_name:
continue
if op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = ['image']
- if op_name == "ResizeTableImage":
- use_padding = True
- padding_max_len = op['ResizeTableImage']['max_len']
+ op[op_name]['keep_keys'] = ['image', 'shape']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
+ save_res_path = config['Global']['save_res_path']
+ os.makedirs(save_res_path, exist_ok=True)
+
model.eval()
- for file in get_image_file_list(config['Global']['infer_img']):
- logger.info("infer_img: {}".format(file))
- with open(file, 'rb') as f:
- img = f.read()
- data = {'image': img}
- batch = transform(data, ops)
- images = np.expand_dims(batch[0], axis=0)
- images = paddle.to_tensor(images)
- preds = model(images)
- post_result = post_process_class(preds)
- res_html_code = post_result['res_html_code']
- res_loc = post_result['res_loc']
- img = cv2.imread(file)
- imgh, imgw = img.shape[0:2]
- res_loc_final = []
- for rno in range(len(res_loc[0])):
- x0, y0, x1, y1 = res_loc[0][rno]
- left = max(int(imgw * x0), 0)
- top = max(int(imgh * y0), 0)
- right = min(int(imgw * x1), imgw - 1)
- bottom = min(int(imgh * y1), imgh - 1)
- cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
- res_loc_final.append([left, top, right, bottom])
- res_loc_str = json.dumps(res_loc_final)
- logger.info("result: {}, {}".format(res_html_code, res_loc_final))
- logger.info("success!")
+ with open(
+ os.path.join(save_res_path, 'infer.txt'), mode='w',
+ encoding='utf-8') as f_w:
+ for file in get_image_file_list(config['Global']['infer_img']):
+ logger.info("infer_img: {}".format(file))
+ with open(file, 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ batch = transform(data, ops)
+ images = np.expand_dims(batch[0], axis=0)
+ shape_list = np.expand_dims(batch[1], axis=0)
+
+ images = paddle.to_tensor(images)
+ preds = model(images)
+ post_result = post_process_class(preds, [shape_list])
+
+ structure_str_list = post_result['structure_batch_list'][0]
+ bbox_list = post_result['bbox_batch_list'][0]
+ structure_str_list = structure_str_list[0]
+ structure_str_list = [
+ '<html>', '<body>', '<table>'
+ ] + structure_str_list + ['</table>', '</body>', '</html>']
+ bbox_list_str = json.dumps(bbox_list.tolist())
+
+ logger.info("result: {}, {}".format(structure_str_list,
+ bbox_list_str))
+ f_w.write("result: {}, {}\n".format(structure_str_list,
+ bbox_list_str))
+
+ img = draw_rectangle(file, bbox_list, use_xywh)
+ cv2.imwrite(
+ os.path.join(save_res_path, os.path.basename(file)), img)
+ logger.info("success!")
if __name__ == '__main__':
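
Editor's note: the `use_xywh` flag passed to `draw_rectangle` distinguishes TableMaster's center-based (x, y, w, h) boxes from corner-based (x1, y1, x2, y2) boxes. The real helper lives in ppocr.utils.visual; the sketch below only illustrates the conversion it is assumed to perform, with a hypothetical function name.

    import cv2
    import numpy as np

    def draw_boxes(img_path, boxes, use_xywh=False):
        # Assumed behaviour: convert each box to corner form, then draw it.
        img = cv2.imread(img_path)
        for box in np.array(boxes).astype(int):
            if use_xywh:
                cx, cy, w, h = [int(v) for v in box]
                x1, y1, x2, y2 = cx - w // 2, cy - h // 2, cx + w // 2, cy + h // 2
            else:
                x1, y1, x2, y2 = [int(v) for v in box]
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
        return img
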
diff --git a/tools/infer_vqa_token_ser.py b/tools/infer_vqa_token_ser.py
index 83ed72b392e627c161903c3945f57be0abfabc2b..0173a554cace31e20ab47dbe36d132a4dbb2127b 100755
--- a/tools/infer_vqa_token_ser.py
+++ b/tools/infer_vqa_token_ser.py
@@ -44,6 +44,7 @@ def to_tensor(data):
from collections import defaultdict
data_dict = defaultdict(list)
to_tensor_idxs = []
+
for idx, v in enumerate(data):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
@@ -57,6 +58,7 @@ def to_tensor(data):
class SerPredictor(object):
def __init__(self, config):
global_config = config['Global']
+ self.algorithm = config['Architecture']["algorithm"]
# build post process
self.post_process_class = build_post_process(config['PostProcess'],
@@ -70,7 +72,10 @@ class SerPredictor(object):
from paddleocr import PaddleOCR
- self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
+ self.ocr_engine = PaddleOCR(
+ use_angle_cls=False,
+ show_log=False,
+ use_gpu=global_config['use_gpu'])
# create data ops
transforms = []
@@ -80,29 +85,30 @@ class SerPredictor(object):
op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [
- 'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
- 'token_type_ids', 'segment_offset_id', 'ocr_info',
+ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
+ 'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities'
]
transforms.append(op)
- global_config['infer_mode'] = True
+ if config["Global"].get("infer_mode", None) is None:
+ global_config['infer_mode'] = True
self.ops = create_operators(config['Eval']['dataset']['transforms'],
global_config)
self.model.eval()
- def __call__(self, img_path):
- with open(img_path, 'rb') as f:
+ def __call__(self, data):
+ with open(data["img_path"], 'rb') as f:
img = f.read()
- data = {'image': img}
+ data["image"] = img
batch = transform(data, self.ops)
batch = to_tensor(batch)
preds = self.model(batch)
+ if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
+ preds = preds[0]
+
post_result = self.post_process_class(
- preds,
- attention_masks=batch[4],
- segment_offset_ids=batch[6],
- ocr_infos=batch[7])
+ preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
return post_result, batch
@@ -112,20 +118,33 @@ if __name__ == '__main__':
ser_engine = SerPredictor(config)
- infer_imgs = get_image_file_list(config['Global']['infer_img'])
+ if config["Global"].get("infer_mode", None) is False:
+ data_dir = config['Eval']['dataset']['data_dir']
+ with open(config['Global']['infer_img'], "rb") as f:
+ infer_imgs = f.readlines()
+ else:
+ infer_imgs = get_image_file_list(config['Global']['infer_img'])
+
with open(
os.path.join(config['Global']['save_res_path'],
"infer_results.txt"),
"w",
encoding='utf-8') as fout:
- for idx, img_path in enumerate(infer_imgs):
+ for idx, info in enumerate(infer_imgs):
+ if config["Global"].get("infer_mode", None) is False:
+ data_line = info.decode('utf-8')
+ substr = data_line.strip("\n").split("\t")
+ img_path = os.path.join(data_dir, substr[0])
+ data = {'img_path': img_path, 'label': substr[1]}
+ else:
+ img_path = info
+ data = {'img_path': img_path}
+
save_img_path = os.path.join(
config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
- logger.info("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
- result, _ = ser_engine(img_path)
+ result, _ = ser_engine(data)
result = result[0]
fout.write(img_path + "\t" + json.dumps(
{
@@ -133,3 +152,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
+
+ logger.info("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
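
Editor's note: when `infer_mode` is False the script reads an annotation file instead of an image directory; each line is expected to be a relative image path and a label, separated by a tab. A tiny sketch of that parsing step with a hypothetical data_dir and annotation line:

    import os

    data_dir = "train_data/XFUND/zh_val/image"   # hypothetical
    data_line = 'zh_val_0.jpg\t[{"transcription": "...", "label": "question"}]\n'

    substr = data_line.strip("\n").split("\t")
    img_path = os.path.join(data_dir, substr[0])
    data = {"img_path": img_path, "label": substr[1]}
    print(data["img_path"])  # train_data/XFUND/zh_val/image/zh_val_0.jpg
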
diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py
index 6210f7f3c24227c9d366b08ce93ccfe4df849ce1..20ab1fe176c3be75f7a7b01a8d77df6419c58c75 100755
--- a/tools/infer_vqa_token_ser_re.py
+++ b/tools/infer_vqa_token_ser_re.py
@@ -38,7 +38,7 @@ from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
-from tools.program import ArgsParser, load_config, merge_config, check_gpu
+from tools.program import ArgsParser, load_config, merge_config
from tools.infer_vqa_token_ser import SerPredictor
@@ -107,7 +107,7 @@ def make_input(ser_inputs, ser_results):
# remove ocr_info segment_offset_id and label in ser input
ser_inputs.pop(7)
ser_inputs.pop(6)
- ser_inputs.pop(1)
+ ser_inputs.pop(5)
return ser_inputs, entity_idx_dict_batch
@@ -131,9 +131,7 @@ class SerRePredictor(object):
self.model.eval()
def __call__(self, img_path):
- ser_results, ser_inputs = self.ser_engine(img_path)
- paddle.save(ser_inputs, 'ser_inputs.npy')
- paddle.save(ser_results, 'ser_results.npy')
+ ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
preds = self.model(re_input)
post_result = self.post_process_class(
@@ -155,7 +153,6 @@ def preprocess():
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
- check_gpu(use_gpu)
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
@@ -185,9 +182,7 @@ if __name__ == '__main__':
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(
config['Global']['save_res_path'],
- os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
- logger.info("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
+ os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
result = ser_re_engine(img_path)
result = result[0]
@@ -197,3 +192,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
+
+ logger.info("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
diff --git a/tools/program.py b/tools/program.py
index 0f9d09d8e17f4c2604693a0e69964e5811f5f23c..335ceb08a83fea468df278633b35ac3bc57ee2ed 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -255,6 +255,8 @@ def train(config,
with paddle.amp.auto_cast():
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
+ elif model_type in ["kie", 'vqa']:
+ preds = model(batch)
else:
preds = model(images)
else:
@@ -279,8 +281,11 @@ def train(config,
if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
batch = [item.numpy() for item in batch]
- if model_type in ['table', 'kie']:
+ if model_type in ['kie']:
eval_class(preds, batch)
+ elif model_type in ['table']:
+ post_result = post_process_class(preds, batch)
+ eval_class(post_result, batch)
else:
if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
]: # for multi head loss
@@ -307,7 +312,8 @@ def train(config,
train_stats.update(stats)
if log_writer is not None and dist.get_rank() == 0:
- log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
+ log_writer.log_metrics(
+ metrics=train_stats.get(), prefix="TRAIN", step=global_step)
if dist.get_rank() == 0 and (
(global_step > 0 and global_step % print_batch_step == 0) or
@@ -354,7 +360,8 @@ def train(config,
# logger metric
if log_writer is not None:
- log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
+ log_writer.log_metrics(
+ metrics=cur_metric, prefix="EVAL", step=global_step)
if cur_metric[main_indicator] >= best_model_dict[
main_indicator]:
@@ -377,11 +384,18 @@ def train(config,
logger.info(best_str)
# logger best metric
if log_writer is not None:
- log_writer.log_metrics(metrics={
- "best_{}".format(main_indicator): best_model_dict[main_indicator]
- }, prefix="EVAL", step=global_step)
-
- log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
+ log_writer.log_metrics(
+ metrics={
+ "best_{}".format(main_indicator):
+ best_model_dict[main_indicator]
+ },
+ prefix="EVAL",
+ step=global_step)
+
+ log_writer.log_model(
+ is_best=True,
+ prefix="best_accuracy",
+ metadata=best_model_dict)
reader_start = time.time()
if dist.get_rank() == 0:
@@ -413,7 +427,8 @@ def train(config,
epoch=epoch,
global_step=global_step)
if log_writer is not None:
- log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
+ log_writer.log_model(
+ is_best=False, prefix='iter_epoch_{}'.format(epoch))
best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
@@ -451,7 +466,6 @@ def eval(model,
preds = model(batch)
else:
preds = model(images)
-
batch_numpy = []
for item in batch:
if isinstance(item, paddle.Tensor):
@@ -461,9 +475,9 @@ def eval(model,
# Obtain usable results from post-processing methods
total_time += time.time() - start
# Evaluate the results of the current batch
- if model_type in ['table', 'kie']:
+ if model_type in ['kie']:
eval_class(preds, batch_numpy)
- elif model_type in ['vqa']:
+ elif model_type in ['table', 'vqa']:
post_result = post_process_class(preds, batch_numpy)
eval_class(post_result, batch_numpy)
else:
@@ -564,8 +578,8 @@ def preprocess(is_train=False):
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
- 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
- 'RobustScanner'
+ 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
+ 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'RobustScanner'
]
if use_xpu:
@@ -586,7 +600,8 @@ def preprocess(is_train=False):
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
log_writer = VDLLogger(save_model_dir)
loggers.append(log_writer)
- if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
+ if ('use_wandb' in config['Global'] and
+ config['Global']['use_wandb']) or 'wandb' in config:
save_dir = config['Global']['save_model_dir']
wandb_writer_path = "{}/wandb".format(save_dir)
if "wandb" in config: