diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index d353391c9af2b85bd01ba659f541fa1791461f68..594197a6cd862664b17ed8d84c2d7cd908332386 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -44,16 +44,16 @@ class BaseRecLabelDecode(object): self.character_str = string.printable[:-6] dict_character = list(self.character_str) elif character_type in support_character_type: - self.character_str = "" + self.character_str = [] assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format( character_type) with open(character_dict_path, "rb") as fin: lines = fin.readlines() for line in lines: line = line.decode('utf-8').strip("\n").strip("\r\n") - self.character_str += line + self.character_str.append(line) if use_space_char: - self.character_str += " " + self.character_str.append(" ") dict_character = list(self.character_str) else: @@ -288,3 +288,172 @@ class SRNLabelDecode(BaseRecLabelDecode): assert False, "unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + + +class TableLabelDecode(object): + """ """ + + def __init__(self, + max_text_length, + max_elem_length, + max_cell_num, + character_dict_path, + **kwargs): + self.max_text_length = max_text_length + self.max_elem_length = max_elem_length + self.max_cell_num = max_cell_num + list_character, list_elem = self.load_char_elem_dict(character_dict_path) + list_character = self.add_special_char(list_character) + list_elem = self.add_special_char(list_elem) + self.dict_character = {} + self.dict_idx_character = {} + for i, char in enumerate(list_character): + self.dict_idx_character[i] = char + self.dict_character[char] = i + self.dict_elem = {} + self.dict_idx_elem = {} + for i, elem in enumerate(list_elem): + self.dict_idx_elem[i] = elem + self.dict_elem[elem] = i + + def load_char_elem_dict(self, character_dict_path): + list_character = [] + list_elem = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + substr = lines[0].decode('utf-8').strip("\n").split("\t") + character_num = int(substr[0]) + elem_num = int(substr[1]) + for cno in range(1, 1 + character_num): + character = lines[cno].decode('utf-8').strip("\n") + list_character.append(character) + for eno in range(1 + character_num, 1 + character_num + elem_num): + elem = lines[eno].decode('utf-8').strip("\n") + list_elem.append(elem) + return list_character, list_elem + + def add_special_char(self, list_character): + self.beg_str = "sos" + self.end_str = "eos" + list_character = [self.beg_str] + list_character + [self.end_str] + return list_character + + def get_sp_tokens(self): + char_beg_idx = self.get_beg_end_flag_idx('beg', 'char') + char_end_idx = self.get_beg_end_flag_idx('end', 'char') + elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem') + elem_end_idx = self.get_beg_end_flag_idx('end', 'elem') + elem_char_idx1 = self.dict_elem[''] + elem_char_idx2 = self.dict_elem['', '' or elem == ''\ + # # or 'rowspan' in elem or 'colspan' in elem: + # if elem == '' or elem == '': + # select_td_tokens.append(self.dict_elem[elem]) + # if 'rowspan' in elem or 'colspan' in elem: + # select_span_tokens.append(self.dict_elem[elem]) + result_list = [] + result_pos_list = [] + result_score_list = [] + result_elem_idx_list = [] + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + elem_pos_list = [] + elem_idx_list = [] + score_list = [] + for idx in range(len(text_index[batch_idx])): + tmp_elem_idx = int(text_index[batch_idx][idx]) + if idx > 0 and tmp_elem_idx == end_idx: + break + if tmp_elem_idx in ignored_tokens: + continue + # if tmp_elem_idx in select_td_tokens: + # total_td_score += structure_probs[batch_idx, idx] + # total_td_num += 1 + # if tmp_elem_idx in select_span_tokens: + # total_span_score += structure_probs[batch_idx, idx] + # total_span_num += 1 + char_list.append(current_dict[tmp_elem_idx]) + elem_pos_list.append(idx) + score_list.append(structure_probs[batch_idx, idx]) + elem_idx_list.append(tmp_elem_idx) + result_list.append(char_list) + result_pos_list.append(elem_pos_list) + result_score_list.append(score_list) + result_elem_idx_list.append(elem_idx_list) + return result_list, result_pos_list, result_score_list, result_elem_idx_list + + def get_ignored_tokens(self, char_or_elem): + beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) + end_idx = self.get_beg_end_flag_idx("end", char_or_elem) + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): + if char_or_elem == "char": + if beg_or_end == "beg": + idx = self.dict_character[self.beg_str] + elif beg_or_end == "end": + idx = self.dict_character[self.end_str] + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ + % beg_or_end + elif char_or_elem == "elem": + if beg_or_end == "beg": + idx = self.dict_elem[self.beg_str] + elif beg_or_end == "end": + idx = self.dict_elem[self.end_str] + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ + % beg_or_end + else: + assert False, "Unsupport type %s in char_or_elem" \ + % char_or_elem + return idx diff --git a/ppocr/utils/dict/table_dict.txt b/ppocr/utils/dict/table_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..804f3e31bfe21ef13b4f4e97e3746ab627fbe705 --- /dev/null +++ b/ppocr/utils/dict/table_dict.txt @@ -0,0 +1,278 @@ +← + +☆ +─ +α + + +⋅ +$ +ω +ψ +χ +( +υ +≥ +σ +, +ρ +ε +0 +■ +4 +8 +✗ +b +< +✓ +Ψ +Ω +€ +D +3 +Π +H +║ + +L +Φ +Χ +θ +P +κ +λ +μ +T +ξ +X +β +γ +δ +\ +ζ +η +` +d + +h +f +l +Θ +p +√ +t + +x +Β +Γ +Δ +| +ǂ +ɛ +j +̧ +➢ +⁡ +̌ +′ +« +△ +▲ +# + +' +Ι ++ +¶ +/ +▼ +⇑ +□ +· +7 +▪ +; +? +➔ +∩ +C +÷ +G +⇒ +K + +O +S +С +W +Α +[ +○ +_ +● +‡ +c +z +g + +o + +〈 +〉 +s +⩽ +w +φ +ʹ +{ +» +∣ +̆ +e +ˆ +∈ +τ +◆ +ι +∅ +∆ +∙ +∘ +Ø +ß +✔ +∞ +∑ +− +× +◊ +∗ +∖ +˃ +˂ +∫ +" +i +& +π +↔ +* +∥ +æ +∧ +. +⁄ +ø +Q +∼ +6 +⁎ +: +★ +> +a +B +≈ +F +J +̄ +N +♯ +R +V + +― +Z +♣ +^ +¤ +¥ +§ + +¢ +£ +≦ +­ +≤ +‖ +Λ +© +n +↓ +→ +↑ +r +° +± +v + +♂ +k +♀ +~ +ᅟ +̇ +@ +” +♦ +ł +® +⊕ +„ +! + +% +⇓ +) +- +1 +5 +9 += +А +A +‰ +⋆ +Σ +E +◦ +I +※ +M +m +̨ +⩾ +† + +• +U +Y +
 +] +̸ +2 +‐ +– +‒ +̂ +— +̀ +́ +’ +‘ +⋮ +⋯ +̊ +“ +̈ +≧ +q +u +ı +y + +​ +̃ +} +ν diff --git a/ppocr/utils/dict/table_structure_dict.txt b/ppocr/utils/dict/table_structure_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c4531e5f3b8c498e70d3c2ea0471e5e746a2c30 --- /dev/null +++ b/ppocr/utils/dict/table_structure_dict.txt @@ -0,0 +1,2759 @@ +277 28 1267 1186 + +V +a +r +i +b +l +e + +H +z +d + +t +o +9 +5 +% +C +I + +p + +v +u +* +A +g +( +m +n +) +0 +. +7 +1 +6 +≤ +> +8 +3 +– +2 +G +4 +M +F +T +y +f +s +L +w +c +U +h +D +S +Q +R +x +P +- +E +O +/ +k +, ++ +N +K +q +′ +[ +] +< +≥ + +− + +μ +± +J +j +W +_ +Δ +B +“ +: +Y +α +λ +; + + +? +∼ += +° +# +̊ +̈ +̂ +’ +Z +X +∗ +— +β +' +† +~ +@ +" +γ +↓ +↑ +& +‡ +χ +” +σ +§ +| +¶ +‐ +× +$ +→ +√ +✓ +‘ +\ +∞ +π +• +® +^ +∆ +≧ + + +́ +♀ +♂ +‒ +⁎ +▲ +· +£ +φ +Ψ +ß +△ +☆ +▪ +η +€ +∧ +̃ +Φ +ρ +̄ +δ +‰ +̧ +Ω +♦ +{ +} +̀ +∑ +∫ +ø +κ +ε +¥ +※ +` +ω +Σ +➔ +‖ +Β +̸ +
 +─ +● +⩾ +Χ +Α +⋅ +◆ +★ +■ +ψ +ǂ +□ +ζ +! +Γ +↔ +θ +⁄ +〈 +〉 +― +υ +τ +⋆ +Ø +© +∥ +С +˂ +➢ +ɛ +⁡ +✗ +← +○ +¢ +⩽ +∖ +˃ +­ +≈ +Π +̌ +≦ +∅ +ᅟ + + +∣ +¤ +♯ +̆ +ξ +÷ +▼ + +ι +ν +║ + + +◦ +​ +◊ +∙ +« +» +ł +ı +Θ +∈ +„ +∘ +✔ +̇ +æ +ʹ +ˆ +♣ +⇓ +∩ +⊕ +⇒ +⇑ +̨ +Ι +Λ +⋯ +А +⋮ + + + + + + + + + + colspan="2" + colspan="3" + rowspan="2" + colspan="4" + colspan="6" + rowspan="3" + colspan="9" + colspan="10" + colspan="7" + rowspan="4" + rowspan="5" + rowspan="9" + colspan="8" + rowspan="8" + rowspan="6" + rowspan="7" + rowspan="10" +0 2924682 +1 3405345 +2 2363468 +3 2709165 +4 4078680 +5 3250792 +6 1923159 +7 1617890 +8 1450532 +9 1717624 +10 1477550 +11 1489223 +12 915528 +13 819193 +14 593660 +15 518924 +16 682065 +17 494584 +18 400591 +19 396421 +20 340994 +21 280688 +22 250328 +23 226786 +24 199927 +25 182707 +26 164629 +27 141613 +28 127554 +29 116286 +30 107682 +31 96367 +32 88002 +33 79234 +34 72186 +35 65921 +36 60374 +37 55976 +38 52166 +39 47414 +40 44932 +41 41279 +42 38232 +43 35463 +44 33703 +45 30557 +46 29639 +47 27000 +48 25447 +49 23186 +50 22093 +51 20412 +52 19844 +53 18261 +54 17561 +55 16499 +56 15597 +57 14558 +58 14372 +59 13445 +60 13514 +61 12058 +62 11145 +63 10767 +64 10370 +65 9630 +66 9337 +67 8881 +68 8727 +69 8060 +70 7994 +71 7740 +72 7189 +73 6729 +74 6749 +75 6548 +76 6321 +77 5957 +78 5740 +79 5407 +80 5370 +81 5035 +82 4921 +83 4656 +84 4600 +85 4519 +86 4277 +87 4023 +88 3939 +89 3910 +90 3861 +91 3560 +92 3483 +93 3406 +94 3346 +95 3229 +96 3122 +97 3086 +98 3001 +99 2884 +100 2822 +101 2677 +102 2670 +103 2610 +104 2452 +105 2446 +106 2400 +107 2300 +108 2316 +109 2196 +110 2089 +111 2083 +112 2041 +113 1881 +114 1838 +115 1896 +116 1795 +117 1786 +118 1743 +119 1765 +120 1750 +121 1683 +122 1563 +123 1499 +124 1513 +125 1462 +126 1388 +127 1441 +128 1417 +129 1392 +130 1306 +131 1321 +132 1274 +133 1294 +134 1240 +135 1126 +136 1157 +137 1130 +138 1084 +139 1130 +140 1083 +141 1040 +142 980 +143 1031 +144 974 +145 980 +146 932 +147 898 +148 960 +149 907 +150 852 +151 912 +152 859 +153 847 +154 876 +155 792 +156 791 +157 765 +158 788 +159 787 +160 744 +161 673 +162 683 +163 697 +164 666 +165 680 +166 632 +167 677 +168 657 +169 618 +170 587 +171 585 +172 567 +173 549 +174 562 +175 548 +176 542 +177 539 +178 542 +179 549 +180 547 +181 526 +182 525 +183 514 +184 512 +185 505 +186 515 +187 467 +188 475 +189 458 +190 435 +191 443 +192 427 +193 424 +194 404 +195 389 +196 429 +197 404 +198 386 +199 351 +200 388 +201 408 +202 361 +203 346 +204 324 +205 361 +206 363 +207 364 +208 323 +209 336 +210 342 +211 315 +212 325 +213 328 +214 314 +215 327 +216 320 +217 300 +218 295 +219 315 +220 310 +221 295 +222 275 +223 248 +224 274 +225 232 +226 293 +227 259 +228 286 +229 263 +230 242 +231 214 +232 261 +233 231 +234 211 +235 250 +236 233 +237 206 +238 224 +239 210 +240 233 +241 223 +242 216 +243 222 +244 207 +245 212 +246 196 +247 205 +248 201 +249 202 +250 211 +251 201 +252 215 +253 179 +254 163 +255 179 +256 191 +257 188 +258 196 +259 150 +260 154 +261 176 +262 211 +263 166 +264 171 +265 165 +266 149 +267 182 +268 159 +269 161 +270 164 +271 161 +272 141 +273 151 +274 127 +275 129 +276 142 +277 158 +278 148 +279 135 +280 127 +281 134 +282 138 +283 131 +284 126 +285 125 +286 130 +287 126 +288 135 +289 125 +290 135 +291 131 +292 95 +293 135 +294 106 +295 117 +296 136 +297 128 +298 128 +299 118 +300 109 +301 112 +302 117 +303 108 +304 120 +305 100 +306 95 +307 108 +308 112 +309 77 +310 120 +311 104 +312 109 +313 89 +314 98 +315 82 +316 98 +317 93 +318 77 +319 93 +320 77 +321 98 +322 93 +323 86 +324 89 +325 73 +326 70 +327 71 +328 77 +329 87 +330 77 +331 93 +332 100 +333 83 +334 72 +335 74 +336 69 +337 77 +338 68 +339 78 +340 90 +341 98 +342 75 +343 80 +344 63 +345 71 +346 83 +347 66 +348 71 +349 70 +350 62 +351 62 +352 59 +353 63 +354 62 +355 52 +356 64 +357 64 +358 56 +359 49 +360 57 +361 63 +362 60 +363 68 +364 62 +365 55 +366 54 +367 40 +368 75 +369 70 +370 53 +371 58 +372 57 +373 55 +374 69 +375 57 +376 53 +377 43 +378 45 +379 47 +380 56 +381 51 +382 59 +383 51 +384 43 +385 34 +386 57 +387 49 +388 39 +389 46 +390 48 +391 43 +392 40 +393 54 +394 50 +395 41 +396 43 +397 33 +398 27 +399 49 +400 44 +401 44 +402 38 +403 30 +404 32 +405 37 +406 39 +407 42 +408 53 +409 39 +410 34 +411 31 +412 32 +413 52 +414 27 +415 41 +416 34 +417 36 +418 50 +419 35 +420 32 +421 33 +422 45 +423 35 +424 40 +425 29 +426 41 +427 40 +428 39 +429 32 +430 31 +431 34 +432 29 +433 27 +434 26 +435 22 +436 34 +437 28 +438 30 +439 38 +440 35 +441 36 +442 36 +443 27 +444 24 +445 33 +446 31 +447 25 +448 33 +449 27 +450 32 +451 46 +452 31 +453 35 +454 35 +455 34 +456 26 +457 21 +458 25 +459 26 +460 24 +461 27 +462 33 +463 30 +464 35 +465 21 +466 32 +467 19 +468 27 +469 16 +470 28 +471 26 +472 27 +473 26 +474 25 +475 25 +476 27 +477 20 +478 28 +479 22 +480 23 +481 16 +482 25 +483 27 +484 19 +485 23 +486 19 +487 15 +488 15 +489 23 +490 24 +491 19 +492 20 +493 18 +494 17 +495 30 +496 28 +497 20 +498 29 +499 17 +500 19 +501 21 +502 15 +503 24 +504 15 +505 19 +506 25 +507 16 +508 23 +509 26 +510 21 +511 15 +512 12 +513 16 +514 18 +515 24 +516 26 +517 18 +518 8 +519 25 +520 14 +521 8 +522 24 +523 20 +524 18 +525 15 +526 13 +527 17 +528 18 +529 22 +530 21 +531 9 +532 16 +533 17 +534 13 +535 17 +536 15 +537 13 +538 20 +539 13 +540 19 +541 29 +542 10 +543 8 +544 18 +545 13 +546 9 +547 18 +548 10 +549 18 +550 18 +551 9 +552 9 +553 15 +554 13 +555 15 +556 14 +557 14 +558 18 +559 8 +560 13 +561 9 +562 7 +563 12 +564 6 +565 9 +566 9 +567 18 +568 9 +569 10 +570 13 +571 14 +572 13 +573 21 +574 8 +575 16 +576 12 +577 9 +578 16 +579 17 +580 22 +581 6 +582 14 +583 13 +584 15 +585 11 +586 13 +587 5 +588 12 +589 13 +590 15 +591 13 +592 15 +593 12 +594 7 +595 18 +596 12 +597 13 +598 13 +599 13 +600 12 +601 12 +602 10 +603 11 +604 6 +605 6 +606 2 +607 9 +608 8 +609 12 +610 9 +611 12 +612 13 +613 12 +614 14 +615 9 +616 8 +617 9 +618 14 +619 13 +620 12 +621 6 +622 8 +623 8 +624 8 +625 12 +626 8 +627 7 +628 5 +629 8 +630 12 +631 6 +632 10 +633 10 +634 7 +635 8 +636 9 +637 6 +638 9 +639 4 +640 12 +641 4 +642 3 +643 11 +644 10 +645 6 +646 12 +647 12 +648 4 +649 4 +650 9 +651 8 +652 6 +653 5 +654 14 +655 10 +656 11 +657 8 +658 5 +659 5 +660 9 +661 13 +662 4 +663 5 +664 9 +665 11 +666 12 +667 7 +668 13 +669 2 +670 1 +671 7 +672 7 +673 7 +674 10 +675 9 +676 6 +677 5 +678 7 +679 6 +680 3 +681 3 +682 4 +683 9 +684 8 +685 5 +686 3 +687 11 +688 9 +689 2 +690 6 +691 5 +692 9 +693 5 +694 6 +695 5 +696 9 +697 8 +698 3 +699 7 +700 5 +701 9 +702 8 +703 7 +704 2 +705 3 +706 7 +707 6 +708 6 +709 10 +710 2 +711 10 +712 6 +713 7 +714 5 +715 6 +716 4 +717 6 +718 8 +719 4 +720 6 +721 7 +722 5 +723 7 +724 3 +725 10 +726 10 +727 3 +728 7 +729 7 +730 5 +731 2 +732 1 +733 5 +734 1 +735 5 +736 6 +737 2 +738 2 +739 3 +740 7 +741 2 +742 7 +743 4 +744 5 +745 4 +746 5 +747 3 +748 1 +749 4 +750 4 +751 2 +752 4 +753 6 +754 6 +755 6 +756 3 +757 2 +758 5 +759 5 +760 3 +761 4 +762 2 +763 1 +764 8 +765 3 +766 4 +767 3 +768 1 +769 5 +770 3 +771 3 +772 4 +773 4 +774 1 +775 3 +776 2 +777 2 +778 3 +779 3 +780 1 +781 4 +782 3 +783 4 +784 6 +785 3 +786 5 +787 4 +788 2 +789 4 +790 5 +791 4 +792 6 +794 4 +795 1 +796 1 +797 4 +798 2 +799 3 +800 3 +801 1 +802 5 +803 5 +804 3 +805 3 +806 3 +807 4 +808 4 +809 2 +811 5 +812 4 +813 6 +814 3 +815 2 +816 2 +817 3 +818 5 +819 3 +820 1 +821 1 +822 4 +823 3 +824 4 +825 8 +826 3 +827 5 +828 5 +829 3 +830 6 +831 3 +832 4 +833 8 +834 5 +835 3 +836 3 +837 2 +838 4 +839 2 +840 1 +841 3 +842 2 +843 1 +844 3 +846 4 +847 4 +848 3 +849 3 +850 2 +851 3 +853 1 +854 4 +855 4 +856 2 +857 4 +858 1 +859 2 +860 5 +861 1 +862 1 +863 4 +864 2 +865 2 +867 5 +868 1 +869 4 +870 1 +871 1 +872 1 +873 2 +875 5 +876 3 +877 1 +878 3 +879 3 +880 3 +881 2 +882 1 +883 6 +884 2 +885 2 +886 1 +887 1 +888 3 +889 2 +890 2 +891 3 +892 1 +893 3 +894 1 +895 5 +896 1 +897 3 +899 2 +900 2 +902 1 +903 2 +904 4 +905 4 +906 3 +907 1 +908 1 +909 2 +910 5 +911 2 +912 3 +914 1 +915 1 +916 2 +918 2 +919 2 +920 4 +921 4 +922 1 +923 1 +924 4 +925 5 +926 1 +928 2 +929 1 +930 1 +931 1 +932 1 +933 1 +934 2 +935 1 +936 1 +937 1 +938 2 +939 1 +941 1 +942 4 +944 2 +945 2 +946 2 +947 1 +948 1 +950 1 +951 2 +953 1 +954 2 +955 1 +956 1 +957 2 +958 1 +960 3 +962 4 +963 1 +964 1 +965 3 +966 2 +967 2 +968 1 +969 3 +970 3 +972 1 +974 4 +975 3 +976 3 +977 2 +979 2 +980 1 +981 1 +983 5 +984 1 +985 3 +986 1 +987 2 +988 4 +989 2 +991 2 +992 2 +993 1 +994 1 +996 2 +997 2 +998 1 +999 3 +1000 2 +1001 1 +1002 3 +1003 3 +1004 2 +1005 3 +1006 1 +1007 2 +1009 1 +1011 1 +1013 3 +1014 1 +1016 2 +1017 1 +1018 1 +1019 1 +1020 4 +1021 1 +1022 2 +1025 1 +1026 1 +1027 2 +1028 1 +1030 1 +1031 2 +1032 4 +1034 3 +1035 2 +1036 1 +1038 1 +1039 1 +1040 1 +1041 1 +1042 2 +1043 1 +1044 2 +1045 4 +1048 1 +1050 1 +1051 1 +1052 2 +1054 1 +1055 3 +1056 2 +1057 1 +1059 1 +1061 2 +1063 1 +1064 1 +1065 1 +1066 1 +1067 1 +1068 1 +1069 2 +1074 1 +1075 1 +1077 1 +1078 1 +1079 1 +1082 1 +1085 1 +1088 1 +1090 1 +1091 1 +1092 2 +1094 2 +1097 2 +1098 1 +1099 2 +1101 2 +1102 1 +1104 1 +1105 1 +1107 1 +1109 1 +1111 2 +1112 1 +1114 2 +1115 2 +1116 2 +1117 1 +1118 1 +1119 1 +1120 1 +1122 1 +1123 1 +1127 1 +1128 3 +1132 2 +1138 3 +1142 1 +1145 4 +1150 1 +1153 2 +1154 1 +1158 1 +1159 1 +1163 1 +1165 1 +1169 2 +1174 1 +1176 1 +1177 1 +1178 2 +1179 1 +1180 2 +1181 1 +1182 1 +1183 2 +1185 1 +1187 1 +1191 2 +1193 1 +1195 3 +1196 1 +1201 3 +1203 1 +1206 1 +1210 1 +1213 1 +1214 1 +1215 2 +1218 1 +1220 1 +1221 1 +1225 1 +1226 1 +1233 2 +1241 1 +1243 1 +1249 1 +1250 2 +1251 1 +1254 1 +1255 2 +1260 1 +1268 1 +1270 1 +1273 1 +1274 1 +1277 1 +1284 1 +1287 1 +1291 1 +1292 2 +1294 1 +1295 2 +1297 1 +1298 1 +1301 1 +1307 1 +1308 3 +1311 2 +1313 1 +1316 1 +1321 1 +1324 1 +1325 1 +1330 1 +1333 1 +1334 1 +1338 2 +1340 1 +1341 1 +1342 1 +1343 1 +1345 1 +1355 1 +1357 1 +1360 2 +1375 1 +1376 1 +1380 1 +1383 1 +1387 1 +1389 1 +1393 1 +1394 1 +1396 1 +1398 1 +1410 1 +1414 1 +1419 1 +1425 1 +1434 1 +1435 1 +1438 1 +1439 1 +1447 1 +1455 2 +1460 1 +1461 1 +1463 1 +1466 1 +1470 1 +1473 1 +1478 1 +1480 1 +1483 1 +1484 1 +1485 2 +1492 2 +1499 1 +1509 1 +1512 1 +1513 1 +1523 1 +1524 1 +1525 2 +1529 1 +1539 1 +1544 1 +1568 1 +1584 1 +1591 1 +1598 1 +1600 1 +1604 1 +1614 1 +1617 1 +1621 1 +1622 1 +1626 1 +1638 1 +1648 1 +1658 1 +1661 1 +1679 1 +1682 1 +1693 1 +1700 1 +1705 1 +1707 1 +1722 1 +1728 1 +1758 1 +1762 1 +1763 1 +1775 1 +1776 1 +1801 1 +1810 1 +1812 1 +1827 1 +1834 1 +1846 1 +1847 1 +1848 1 +1851 1 +1862 1 +1866 1 +1877 2 +1884 1 +1888 1 +1903 1 +1912 1 +1925 1 +1938 1 +1955 1 +1998 1 +2054 1 +2058 1 +2065 1 +2069 1 +2076 1 +2089 1 +2104 1 +2111 1 +2133 1 +2138 1 +2156 1 +2204 1 +2212 1 +2237 1 +2246 2 +2298 1 +2304 1 +2360 1 +2400 1 +2481 1 +2544 1 +2586 1 +2622 1 +2666 1 +2682 1 +2725 1 +2920 1 +3997 1 +4019 1 +5211 1 +12 19 +14 1 +16 401 +18 2 +20 421 +22 557 +24 625 +26 50 +28 4481 +30 52 +32 550 +34 5840 +36 4644 +38 87 +40 5794 +41 33 +42 571 +44 11805 +46 4711 +47 7 +48 597 +49 12 +50 678 +51 2 +52 14715 +53 3 +54 7322 +55 3 +56 508 +57 39 +58 3486 +59 11 +60 8974 +61 45 +62 1276 +63 4 +64 15693 +65 15 +66 657 +67 13 +68 6409 +69 10 +70 3188 +71 25 +72 1889 +73 27 +74 10370 +75 9 +76 12432 +77 23 +78 520 +79 15 +80 1534 +81 29 +82 2944 +83 23 +84 12071 +85 36 +86 1502 +87 10 +88 10978 +89 11 +90 889 +91 16 +92 4571 +93 17 +94 7855 +95 21 +96 2271 +97 33 +98 1423 +99 15 +100 11096 +101 21 +102 4082 +103 13 +104 5442 +105 25 +106 2113 +107 26 +108 3779 +109 43 +110 1294 +111 29 +112 7860 +113 29 +114 4965 +115 22 +116 7898 +117 25 +118 1772 +119 28 +120 1149 +121 38 +122 1483 +123 32 +124 10572 +125 25 +126 1147 +127 31 +128 1699 +129 22 +130 5533 +131 22 +132 4669 +133 34 +134 3777 +135 10 +136 5412 +137 21 +138 855 +139 26 +140 2485 +141 46 +142 1970 +143 27 +144 6565 +145 40 +146 933 +147 15 +148 7923 +149 16 +150 735 +151 23 +152 1111 +153 33 +154 3714 +155 27 +156 2445 +157 30 +158 3367 +159 10 +160 4646 +161 27 +162 990 +163 23 +164 5679 +165 25 +166 2186 +167 17 +168 899 +169 32 +170 1034 +171 22 +172 6185 +173 32 +174 2685 +175 17 +176 1354 +177 38 +178 1460 +179 15 +180 3478 +181 20 +182 958 +183 20 +184 6055 +185 23 +186 2180 +187 15 +188 1416 +189 30 +190 1284 +191 22 +192 1341 +193 21 +194 2413 +195 18 +196 4984 +197 13 +198 830 +199 22 +200 1834 +201 19 +202 2238 +203 9 +204 3050 +205 22 +206 616 +207 17 +208 2892 +209 22 +210 711 +211 30 +212 2631 +213 19 +214 3341 +215 21 +216 987 +217 26 +218 823 +219 9 +220 3588 +221 20 +222 692 +223 7 +224 2925 +225 31 +226 1075 +227 16 +228 2909 +229 18 +230 673 +231 20 +232 2215 +233 14 +234 1584 +235 21 +236 1292 +237 29 +238 1647 +239 25 +240 1014 +241 30 +242 1648 +243 19 +244 4465 +245 10 +246 787 +247 11 +248 480 +249 25 +250 842 +251 15 +252 1219 +253 23 +254 1508 +255 8 +256 3525 +257 16 +258 490 +259 12 +260 1678 +261 14 +262 822 +263 16 +264 1729 +265 28 +266 604 +267 11 +268 2572 +269 7 +270 1242 +271 15 +272 725 +273 18 +274 1983 +275 13 +276 1662 +277 19 +278 491 +279 12 +280 1586 +281 14 +282 563 +283 10 +284 2363 +285 10 +286 656 +287 14 +288 725 +289 28 +290 871 +291 9 +292 2606 +293 12 +294 961 +295 9 +296 478 +297 13 +298 1252 +299 10 +300 736 +301 19 +302 466 +303 13 +304 2254 +305 12 +306 486 +307 14 +308 1145 +309 13 +310 955 +311 13 +312 1235 +313 13 +314 931 +315 14 +316 1768 +317 11 +318 330 +319 10 +320 539 +321 23 +322 570 +323 12 +324 1789 +325 13 +326 884 +327 5 +328 1422 +329 14 +330 317 +331 11 +332 509 +333 13 +334 1062 +335 12 +336 577 +337 27 +338 378 +339 10 +340 2313 +341 9 +342 391 +343 13 +344 894 +345 17 +346 664 +347 9 +348 453 +349 6 +350 363 +351 15 +352 1115 +353 13 +354 1054 +355 8 +356 1108 +357 12 +358 354 +359 7 +360 363 +361 16 +362 344 +363 11 +364 1734 +365 12 +366 265 +367 10 +368 969 +369 16 +370 316 +371 12 +372 757 +373 7 +374 563 +375 15 +376 857 +377 9 +378 469 +379 9 +380 385 +381 12 +382 921 +383 15 +384 764 +385 14 +386 246 +387 6 +388 1108 +389 14 +390 230 +391 8 +392 266 +393 11 +394 641 +395 8 +396 719 +397 9 +398 243 +399 4 +400 1108 +401 7 +402 229 +403 7 +404 903 +405 7 +406 257 +407 12 +408 244 +409 3 +410 541 +411 6 +412 744 +413 8 +414 419 +415 8 +416 388 +417 19 +418 470 +419 14 +420 612 +421 6 +422 342 +423 3 +424 1179 +425 3 +426 116 +427 14 +428 207 +429 6 +430 255 +431 4 +432 288 +433 12 +434 343 +435 6 +436 1015 +437 3 +438 538 +439 10 +440 194 +441 6 +442 188 +443 15 +444 524 +445 7 +446 214 +447 7 +448 574 +449 6 +450 214 +451 5 +452 635 +453 9 +454 464 +455 5 +456 205 +457 9 +458 163 +459 2 +460 558 +461 4 +462 171 +463 14 +464 444 +465 11 +466 543 +467 5 +468 388 +469 6 +470 141 +471 4 +472 647 +473 3 +474 210 +475 4 +476 193 +477 7 +478 195 +479 7 +480 443 +481 10 +482 198 +483 3 +484 816 +485 6 +486 128 +487 9 +488 215 +489 9 +490 328 +491 7 +492 158 +493 11 +494 335 +495 8 +496 435 +497 6 +498 174 +499 1 +500 373 +501 5 +502 140 +503 7 +504 330 +505 9 +506 149 +507 5 +508 642 +509 3 +510 179 +511 3 +512 159 +513 8 +514 204 +515 7 +516 306 +517 4 +518 110 +519 5 +520 326 +521 6 +522 305 +523 6 +524 294 +525 7 +526 268 +527 5 +528 149 +529 4 +530 133 +531 2 +532 513 +533 10 +534 116 +535 5 +536 258 +537 4 +538 113 +539 4 +540 138 +541 6 +542 116 +544 485 +545 4 +546 93 +547 9 +548 299 +549 3 +550 256 +551 6 +552 92 +553 3 +554 175 +555 6 +556 253 +557 7 +558 95 +559 2 +560 128 +561 4 +562 206 +563 2 +564 465 +565 3 +566 69 +567 3 +568 157 +569 7 +570 97 +571 8 +572 118 +573 5 +574 130 +575 4 +576 301 +577 6 +578 177 +579 2 +580 397 +581 3 +582 80 +583 1 +584 128 +585 5 +586 52 +587 2 +588 72 +589 1 +590 84 +591 6 +592 323 +593 11 +594 77 +595 5 +596 205 +597 1 +598 244 +599 4 +600 69 +601 3 +602 89 +603 5 +604 254 +605 6 +606 147 +607 3 +608 83 +609 3 +610 77 +611 3 +612 194 +613 1 +614 98 +615 3 +616 243 +617 3 +618 50 +619 8 +620 188 +621 4 +622 67 +623 4 +624 123 +625 2 +626 50 +627 1 +628 239 +629 2 +630 51 +631 4 +632 65 +633 5 +634 188 +636 81 +637 3 +638 46 +639 3 +640 103 +641 1 +642 136 +643 3 +644 188 +645 3 +646 58 +648 122 +649 4 +650 47 +651 2 +652 155 +653 4 +654 71 +655 1 +656 71 +657 3 +658 50 +659 2 +660 177 +661 5 +662 66 +663 2 +664 183 +665 3 +666 50 +667 2 +668 53 +669 2 +670 115 +672 66 +673 2 +674 47 +675 1 +676 197 +677 2 +678 46 +679 3 +680 95 +681 3 +682 46 +683 3 +684 107 +685 1 +686 86 +687 2 +688 158 +689 4 +690 51 +691 1 +692 80 +694 56 +695 4 +696 40 +698 43 +699 3 +700 95 +701 2 +702 51 +703 2 +704 133 +705 1 +706 100 +707 2 +708 121 +709 2 +710 15 +711 3 +712 35 +713 2 +714 20 +715 3 +716 37 +717 2 +718 78 +720 55 +721 1 +722 42 +723 2 +724 218 +725 3 +726 23 +727 2 +728 26 +729 1 +730 64 +731 2 +732 65 +734 24 +735 2 +736 53 +737 1 +738 32 +739 1 +740 60 +742 81 +743 1 +744 77 +745 1 +746 47 +747 1 +748 62 +749 1 +750 19 +751 1 +752 86 +753 3 +754 40 +756 55 +757 2 +758 38 +759 1 +760 101 +761 1 +762 22 +764 67 +765 2 +766 35 +767 1 +768 38 +769 1 +770 22 +771 1 +772 82 +773 1 +774 73 +776 29 +777 1 +778 55 +780 23 +781 1 +782 16 +784 84 +785 3 +786 28 +788 59 +789 1 +790 33 +791 3 +792 24 +794 13 +795 1 +796 110 +797 2 +798 15 +800 22 +801 3 +802 29 +803 1 +804 87 +806 21 +808 29 +810 48 +812 28 +813 1 +814 58 +815 1 +816 48 +817 1 +818 31 +819 1 +820 66 +822 17 +823 2 +824 58 +826 10 +827 2 +828 25 +829 1 +830 29 +831 1 +832 63 +833 1 +834 26 +835 3 +836 52 +837 1 +838 18 +840 27 +841 2 +842 12 +843 1 +844 83 +845 1 +846 7 +847 1 +848 10 +850 26 +852 25 +853 1 +854 15 +856 27 +858 32 +859 1 +860 15 +862 43 +864 32 +865 1 +866 6 +868 39 +870 11 +872 25 +873 1 +874 10 +875 1 +876 20 +877 2 +878 19 +879 1 +880 30 +882 11 +884 53 +886 25 +887 1 +888 28 +890 6 +892 36 +894 10 +896 13 +898 14 +900 31 +902 14 +903 2 +904 43 +906 25 +908 9 +910 11 +911 1 +912 16 +913 1 +914 24 +916 27 +918 6 +920 15 +922 27 +923 1 +924 23 +926 13 +928 42 +929 1 +930 3 +932 27 +934 17 +936 8 +937 1 +938 11 +940 33 +942 4 +943 1 +944 18 +946 15 +948 13 +950 18 +952 12 +954 11 +956 21 +958 10 +960 13 +962 5 +964 32 +966 13 +968 8 +970 8 +971 1 +972 23 +973 2 +974 12 +975 1 +976 22 +978 7 +979 1 +980 14 +982 8 +984 22 +985 1 +986 6 +988 17 +989 1 +990 6 +992 13 +994 19 +996 11 +998 4 +1000 9 +1002 2 +1004 14 +1006 5 +1008 3 +1010 9 +1012 29 +1014 6 +1016 22 +1017 1 +1018 8 +1019 1 +1020 7 +1022 6 +1023 1 +1024 10 +1026 2 +1028 8 +1030 11 +1031 2 +1032 8 +1034 9 +1036 13 +1038 12 +1040 12 +1042 3 +1044 12 +1046 3 +1048 11 +1050 2 +1051 1 +1052 2 +1054 11 +1056 6 +1058 8 +1059 1 +1060 23 +1062 6 +1063 1 +1064 8 +1066 3 +1068 6 +1070 8 +1071 1 +1072 5 +1074 3 +1076 5 +1078 3 +1080 11 +1081 1 +1082 7 +1084 18 +1086 4 +1087 1 +1088 3 +1090 3 +1092 7 +1094 3 +1096 12 +1098 6 +1099 1 +1100 2 +1102 6 +1104 14 +1106 3 +1108 6 +1110 5 +1112 2 +1114 8 +1116 3 +1118 3 +1120 7 +1122 10 +1124 6 +1126 8 +1128 1 +1130 4 +1132 3 +1134 2 +1136 5 +1138 5 +1140 8 +1142 3 +1144 7 +1146 3 +1148 11 +1150 1 +1152 5 +1154 1 +1156 5 +1158 1 +1160 5 +1162 3 +1164 6 +1165 1 +1166 1 +1168 4 +1169 1 +1170 3 +1171 1 +1172 2 +1174 5 +1176 3 +1177 1 +1180 8 +1182 2 +1184 4 +1186 2 +1188 3 +1190 2 +1192 5 +1194 6 +1196 1 +1198 2 +1200 2 +1204 10 +1206 2 +1208 9 +1210 1 +1214 6 +1216 3 +1218 4 +1220 9 +1221 2 +1222 1 +1224 5 +1226 4 +1228 8 +1230 1 +1232 1 +1234 3 +1236 5 +1240 3 +1242 1 +1244 3 +1245 1 +1246 4 +1248 6 +1250 2 +1252 7 +1256 3 +1258 2 +1260 2 +1262 3 +1264 4 +1265 1 +1266 1 +1270 1 +1271 1 +1272 2 +1274 3 +1276 3 +1278 1 +1280 3 +1284 1 +1286 1 +1290 1 +1292 3 +1294 1 +1296 7 +1300 2 +1302 4 +1304 3 +1306 2 +1308 2 +1312 1 +1314 1 +1316 3 +1318 2 +1320 1 +1324 8 +1326 1 +1330 1 +1331 1 +1336 2 +1338 1 +1340 3 +1341 1 +1344 1 +1346 2 +1347 1 +1348 3 +1352 1 +1354 2 +1356 1 +1358 1 +1360 3 +1362 1 +1364 4 +1366 1 +1370 1 +1372 3 +1380 2 +1384 2 +1388 2 +1390 2 +1392 2 +1394 1 +1396 1 +1398 1 +1400 2 +1402 1 +1404 1 +1406 1 +1410 1 +1412 5 +1418 1 +1420 1 +1424 1 +1432 2 +1434 2 +1442 3 +1444 5 +1448 1 +1454 1 +1456 1 +1460 3 +1462 4 +1468 1 +1474 1 +1476 1 +1478 2 +1480 1 +1486 2 +1488 1 +1492 1 +1496 1 +1500 3 +1503 1 +1506 1 +1512 2 +1516 1 +1522 1 +1524 2 +1534 4 +1536 1 +1538 1 +1540 2 +1544 2 +1548 1 +1556 1 +1560 1 +1562 1 +1564 2 +1566 1 +1568 1 +1570 1 +1572 1 +1576 1 +1590 1 +1594 1 +1604 1 +1608 1 +1614 1 +1622 1 +1624 2 +1628 1 +1629 1 +1636 1 +1642 1 +1654 2 +1660 1 +1664 1 +1670 1 +1684 4 +1698 1 +1732 3 +1742 1 +1752 1 +1760 1 +1764 1 +1772 2 +1798 1 +1808 1 +1820 1 +1852 1 +1856 1 +1874 1 +1902 1 +1908 1 +1952 1 +2004 1 +2018 1 +2020 1 +2028 1 +2174 1 +2233 1 +2244 1 +2280 1 +2290 1 +2352 1 +2604 1 +4190 1 diff --git a/ppocr/utils/table_utils/matcher.py b/ppocr/utils/table_utils/matcher.py new file mode 100755 index 0000000000000000000000000000000000000000..711806aa872f80aba2428627c93d0086ec0372ca --- /dev/null +++ b/ppocr/utils/table_utils/matcher.py @@ -0,0 +1,214 @@ +import json +def distance(box_1, box_2): + x1, y1, x2, y2 = box_1 + x3, y3, x4, y4 = box_2 + # min_x = (x1 + x2) / 2 + # min_y = (y1 + y2) / 2 + # max_x = (x3 + x4) / 2 + # max_y = (y3 + y4) / 2 + dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2) + dis_2 = abs(x3 - x1) + abs(y3 - y1) + dis_3 = abs(x4- x2) + abs(y4 - y2) + #dis = pow(min_x - max_x, 2) + pow(min_y - max_y, 2) + pow(x3 - x1, 2) + pow(y3 - y1, 2) + pow(x4- x2, 2) + pow(y4 - y2, 2) + abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2) + return dis + min(dis_2, dis_3) + +def compute_iou(rec1, rec2): + """ + computing IoU + :param rec1: (y0, x0, y1, x1), which reflects + (top, left, bottom, right) + :param rec2: (y0, x0, y1, x1) + :return: scala value of IoU + """ + # computing area of each rectangles + rec1, rec2 = rec1 * 1000, rec2 * 1000 + S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) + S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) + + # computing the sum_area + sum_area = S_rec1 + S_rec2 + + # find the each edge of intersect rectangle + left_line = max(rec1[1], rec2[1]) + right_line = min(rec1[3], rec2[3]) + top_line = max(rec1[0], rec2[0]) + bottom_line = min(rec1[2], rec2[2]) + + # judge if there is an intersect + if left_line >= right_line or top_line >= bottom_line: + return 0 + else: + intersect = (right_line - left_line) * (bottom_line - top_line) + return (intersect / (sum_area - intersect))*1.0 + + + +def matcher_merge(ocr_bboxes, pred_bboxes): # ocr_bboxes: OCR pred_bboxes:端到端 + all_dis = [] + ious = [] + matched = {} + for i, gt_box in enumerate(ocr_bboxes): + distances = [] + for j, pred_box in enumerate(pred_bboxes): + distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) #获取两两cell之间的L1距离和 1- IOU + sorted_distances = distances.copy() + # 根据距离和IOU挑选最"近"的cell + sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0])) + if distances.index(sorted_distances[0]) not in matched.keys(): + matched[distances.index(sorted_distances[0])] = [i] + else: + matched[distances.index(sorted_distances[0])].append(i) + return matched#, sum(ious) / len(ious) +def complex_num(pred_bboxes): + complex_nums = [] + for bbox in pred_bboxes: + distances = [] + temp_ious = [] + for pred_bbox in pred_bboxes: + if bbox != pred_bbox: + distances.append(distance(bbox, pred_bbox)) + temp_ious.append(compute_iou(bbox, pred_bbox)) + complex_nums.append(temp_ious[distances.index(min(distances))]) + return sum(complex_nums) / len(complex_nums) + +def get_rows(pred_bboxes): + pre_bbox = pred_bboxes[0] + res = [] + step = 0 + for i in range(len(pred_bboxes)): + bbox = pred_bboxes[i] + if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0: + break + else: + res.append(bbox) + step += 1 + for i in range(step): + pred_bboxes.pop(0) + return res, pred_bboxes +def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上 + ys_1 = [] + ys_2 = [] + for box in pred_bboxes: + ys_1.append(box[1]) + ys_2.append(box[3]) + min_y_1 = sum(ys_1) / len(ys_1) + min_y_2 = sum(ys_2) / len(ys_2) + re_boxes = [] + for box in pred_bboxes: + box[1] = min_y_1 + box[3] = min_y_2 + re_boxes.append(box) + return re_boxes + +def matcher_refine_row(gt_bboxes, pred_bboxes): + before_refine_pred_bboxes = pred_bboxes.copy() + pred_bboxes = [] + while(len(before_refine_pred_bboxes) != 0): + row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes) + print(row_bboxes) + pred_bboxes.extend(refine_rows(row_bboxes)) + all_dis = [] + ious = [] + matched = {} + for i, gt_box in enumerate(gt_bboxes): + distances = [] + #temp_ious = [] + for j, pred_box in enumerate(pred_bboxes): + distances.append(distance(gt_box, pred_box)) + #temp_ious.append(compute_iou(gt_box, pred_box)) + #all_dis.append(min(distances)) + #ious.append(temp_ious[distances.index(min(distances))]) + if distances.index(min(distances)) not in matched.keys(): + matched[distances.index(min(distances))] = [i] + else: + matched[distances.index(min(distances))].append(i) + return matched#, sum(ious) / len(ious) + + + +#先挑选出一行,再进行匹配 +def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes): + gt_box_index = 0 + delete_gt_bboxes = gt_bboxes.copy() + match_bboxes_ready = [] + matched = {} + while(len(delete_gt_bboxes) != 0): + row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes) + row_bboxes = sorted(row_bboxes, key = lambda key: key[0]) + if len(pred_bboxes_rows) > 0: + match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) + print(row_bboxes) + for i, gt_box in enumerate(row_bboxes): + #print(gt_box) + pred_distances = [] + distances = [] + for pred_bbox in pred_bboxes: + pred_distances.append(distance(gt_box, pred_bbox)) + for j, pred_box in enumerate(match_bboxes_ready): + distances.append(distance(gt_box, pred_box)) + index = pred_distances.index(min(distances)) + #print('index', index) + if index not in matched.keys(): + matched[index] = [gt_box_index] + else: + matched[index].append(gt_box_index) + gt_box_index += 1 + return matched + +def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes): + ''' + gt_bboxes: 排序后 + pred_bboxes: + ''' + pre_bbox = gt_bboxes[0] + matched = {} + match_bboxes_ready = [] + match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) + for i, gt_box in enumerate(gt_bboxes): + + pred_distances = [] + for pred_bbox in pred_bboxes: + pred_distances.append(distance(gt_box, pred_bbox)) + distances = [] + gap_pre = gt_box[1] - pre_bbox[1] + gap_pre_1 = gt_box[0] - pre_bbox[2] + #print(gap_pre, len(pred_bboxes_rows)) + if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0): + match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) + if len(pred_bboxes_rows) == 1: + match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) + if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0: + match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) + if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0: + break + #print(match_bboxes_ready) + for j, pred_box in enumerate(match_bboxes_ready): + distances.append(distance(gt_box, pred_box)) + index = pred_distances.index(min(distances)) + #print(gt_box, index) + #match_bboxes_ready.pop(distances.index(min(distances))) + print(gt_box, match_bboxes_ready[distances.index(min(distances))]) + if index not in matched.keys(): + matched[index] = [i] + else: + matched[index].append(i) + pre_bbox = gt_box + return matched + + +def main(): + detect_bboxes = json.load(open('./f_detecion_bbox.json')) + gt_bboxes = json.load(open('./f_gt_bbox.json')) + all_node = 0 + matched_right = 0 + key = 'PMC4796501_003_00.png' + print(key) + gt_bbox = gt_bboxes[key] + pred_bbox = detect_bboxes[key] + matched = matcher(gt_bbox, pred_bbox) + print(matched) + + +if __name__ == "__main__": + main() + diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..cd2ff0fbfc370e8d673b575686cd1fdea8594a7f 100644 --- a/ppstructure/predict_system.py +++ b/ppstructure/predict_system.py @@ -0,0 +1,123 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import subprocess + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' +import cv2 +import copy +import numpy as np +import time +import tools.infer.utility as utility +from tools.infer.predict_system import TextSystem +from ppstructure.table.predict_table import TableSystem, to_excel +from ppstructure.layout.predict_layout import LayoutDetector +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.utils.logging import get_logger + +logger = get_logger() + + +def parse_args(): + parser = utility.init_args() + + # params for table structure + parser.add_argument("--table_max_len", type=int, default=488) + parser.add_argument("--table_max_text_length", type=int, default=100) + parser.add_argument("--table_max_elem_length", type=int, default=800) + parser.add_argument("--table_max_cell_num", type=int, default=500) + parser.add_argument("--table_model_dir", type=str) + parser.add_argument("--table_char_type", type=str, default='en') + parser.add_argument("--table_char_dict_path", type=str, default="./ppocr/utils/dict/table_structure_dict.txt") + + # params for layout detector + parser.add_argument("--layout_model_dir", type=str) + return parser.parse_args() + + +class OCRSystem(): + def __init__(self, args): + self.text_system = TextSystem(args) + self.table_system = TableSystem(args) + self.table_layout = LayoutDetector(args) + self.use_angle_cls = args.use_angle_cls + self.drop_score = args.drop_score + + def __call__(self, img): + ori_im = img.copy() + layout_res = self.table_layout(copy.deepcopy(img)) + for region in layout_res: + x1, y1, x2, y2 = region['bbox'] + roi_img = ori_im[y1:y2, x1:x2,:] + if region['label'] == 'table': + res = self.table_system(roi_img) + else: + res = self.text_system(roi_img) + region['res'] = res + return layout_res + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + image_file_list = image_file_list[args.process_id::args.total_process_num] + excel_save_folder = 'output/table' + os.makedirs(excel_save_folder, exist_ok=True) + + text_sys = OCRSystem(args) + img_num = len(image_file_list) + for i, image_file in enumerate(image_file_list): + logger.info("[{}/{}] {}".format(i, img_num, image_file)) + img, flag = check_and_read_gif(image_file) + imgname = os.path.basename(image_file).split('.')[0] + # excel_path = os.path.join(excel_save_folder, + '.xlsx') + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + starttime = time.time() + res = text_sys(img) + + for region in res: + if region['label'] == 'table': + # x1, y1, x2, y2 = region['bbox'] + excel_path = os.path.join(excel_save_folder, '{}_{}.xlsx'.format(imgname,region['bbox'])) + to_excel(region['res'],excel_path) + logger.info(res) + elapse = time.time() - starttime + logger.info("Predict time : {:.3f}s".format(elapse)) + + +if __name__ == "__main__": + args = parse_args() + if args.use_mp: + p_list = [] + total_process_num = args.total_process_num + for process_id in range(total_process_num): + cmd = [sys.executable, "-u"] + sys.argv + [ + "--process_id={}".format(process_id), + "--use_mp={}".format(False) + ] + p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout) + p_list.append(p) + for p in p_list: + p.wait() + else: + main(args) diff --git a/ppstructure/table/__init__.py b/ppstructure/table/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d11e265597c7c8e39098a228108da3bb954b892 --- /dev/null +++ b/ppstructure/table/__init__.py @@ -0,0 +1,13 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py new file mode 100755 index 0000000000000000000000000000000000000000..baa701772410043d9e6e941b7775707a4f65df12 --- /dev/null +++ b/ppstructure/table/eval_table.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +import cv2 +import json +from tqdm import tqdm +from ppstructure.table.table_metric import TEDS +from ppstructure.table.predict_table import TableSystem, utility + + +def main(gt_path, img_root, args): + teds = TEDS(n_jobs=16) + + text_sys = TableSystem(args) + jsons_gt = json.load(open(gt_path)) # gt + pred_htmls = [] + gt_htmls = [] + for img_name in tqdm(jsons_gt): + if img_name != 'PMC1064865_002_00.png': + continue + # 读取信息 + img = cv2.imread(os.path.join(img_root,img_name)) + pred_html = text_sys(img) + pred_htmls.append(pred_html) + + gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name] + gt_html, gt = get_gt_html(gt_structures, contents_with_block) # 获取HTMLgt + gt_htmls.append(gt_html) + scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) # 计算teds + print('teds:', sum(scores) / len(scores)) + + +def get_gt_html(gt_structures, contents_with_block): + end_html = [] + td_index = 0 + for tag in gt_structures: + if '' in tag: + if contents_with_block[td_index] != []: + end_html.extend(contents_with_block[td_index]) + end_html.append(tag) + td_index += 1 + else: + end_html.append(tag) + return ''.join(end_html), end_html + + +if __name__ == '__main__': + args = utility.parse_args() + gt_path = 'table/match_code/f_gt_bbox.json' + img_root = 'table/imgs' + main(gt_path,img_root, args) diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py new file mode 100755 index 0000000000000000000000000000000000000000..fd00dfd13679ba6c6367e2d7076b33eb0a6db5d5 --- /dev/null +++ b/ppstructure/table/predict_structure.py @@ -0,0 +1,141 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import numpy as np +import math +import time +import traceback +import paddle + +import tools.infer.utility as utility +from ppocr.data import create_operators, transform +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif + +logger = get_logger() + + +class TableStructurer(object): + def __init__(self, args): + pre_process_list = [{ + 'ResizeTableImage': { + 'max_len': args.table_max_len + } + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'PaddingTableImage': None + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image'] + } + }] + postprocess_params = { + 'name': 'TableLabelDecode', + "character_type": args.table_char_type, + "character_dict_path": args.table_char_dict_path, + "max_text_length": args.table_max_text_length, + "max_elem_length": args.table_max_elem_length, + "max_cell_num": args.table_max_cell_num + } + + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors = \ + utility.create_predictor(args, 'table', logger) + + def __call__(self, img): + ori_im = img.copy() + data = {'image': img} + data = transform(data, self.preprocess_op) + img = data[0] + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + img = img.copy() + starttime = time.time() + + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + + preds = {} + preds['structure_probs'] = outputs[1] + preds['loc_preds'] = outputs[0] + + post_result = self.postprocess_op(preds) + + structure_str_list = post_result['structure_str_list'] + res_loc = post_result['res_loc'] + imgh, imgw = ori_im.shape[0:2] + res_loc_final = [] + for rno in range(len(res_loc[0])): + x0, y0, x1, y1 = res_loc[0][rno] + left = max(int(imgw * x0), 0) + top = max(int(imgh * y0), 0) + right = min(int(imgw * x1), imgw - 1) + bottom = min(int(imgh * y1), imgh - 1) + res_loc_final.append([left, top, right, bottom]) + + structure_str_list = structure_str_list[0][:-1] + structure_str_list = ['', '', ''] + structure_str_list + ['
', '', ''] + + elapse = time.time() - starttime + return (structure_str_list, res_loc_final), elapse + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + table_structurer = TableStructurer(args) + count = 0 + total_time = 0 + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + structure_res, elapse = table_structurer(img) + + logger.info("result: {}".format(structure_res)) + + if count > 0: + total_time += elapse + count += 1 + logger.info("Predict time of {}: {}".format(image_file, elapse)) + + +if __name__ == "__main__": + main(utility.parse_args()) diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py new file mode 100644 index 0000000000000000000000000000000000000000..36cf4939ec272923120e8899660137206a0eaf86 --- /dev/null +++ b/ppstructure/table/predict_table.py @@ -0,0 +1,222 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import subprocess + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' +import cv2 +import copy +import numpy as np +import time +import tools.infer.utility as utility +import tools.infer.predict_rec as predict_rec +import tools.infer.predict_det as predict_det +import ppstructure.table.predict_structure as predict_strture +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.utils.logging import get_logger +from ppocr.utils.table_utils.matcher import distance, compute_iou + +logger = get_logger() + + +def expand(pix, det_box, shape): + x0, y0, x1, y1 = det_box + # print(shape) + h, w, c = shape + tmp_x0 = x0 - pix + tmp_x1 = x1 + pix + tmp_y0 = y0 - pix + tmp_y1 = y1 + pix + x0_ = tmp_x0 if tmp_x0 >= 0 else 0 + x1_ = tmp_x1 if tmp_x1 <= w else w + y0_ = tmp_y0 if tmp_y0 >= 0 else 0 + y1_ = tmp_y1 if tmp_y1 <= h else h + return x0_, y0_, x1_, y1_ + + +class TableSystem(object): + def __init__(self, args): + self.text_detector = predict_det.TextDetector(args) + self.text_recognizer = predict_rec.TextRecognizer(args) + self.table_structurer = predict_strture.TableStructurer(args) + self.use_angle_cls = args.use_angle_cls + self.drop_score = args.drop_score + + def __call__(self, img): + ori_im = img.copy() + structure_res, elapse = self.table_structurer(copy.deepcopy(img)) + dt_boxes, elapse = self.text_detector(copy.deepcopy(img)) + dt_boxes = sorted_boxes(dt_boxes) + + r_boxes = [] + for box in dt_boxes: + x_min = box[:, 0].min() - 1 + x_max = box[:, 0].max() + 1 + y_min = box[:, 1].min() - 1 + y_max = box[:, 1].max() + 1 + box = [x_min, y_min, x_max, y_max] + r_boxes.append(box) + dt_boxes = np.array(r_boxes) + + # logger.info("dt_boxes num : {}, elapse : {}".format( + # len(dt_boxes), elapse)) + if dt_boxes is None: + return None, None + img_crop_list = [] + + for i in range(len(dt_boxes)): + det_box = dt_boxes[i] + x0, y0, x1, y1 = expand(2, det_box, ori_im.shape) + text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :] + img_crop_list.append(text_rect) + rec_res, elapse = self.text_recognizer(img_crop_list) + # logger.info("rec_res num : {}, elapse : {}".format( + # len(rec_res), elapse)) + + pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res) + return pred_html + + def rebuild_table(self, structure_res, dt_boxes, rec_res): + pred_structures, pred_bboxes = structure_res + matched_index = self.match_result(dt_boxes, pred_bboxes) + pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) + return pred_html, pred + + def match_result(self, dt_boxes, pred_bboxes): + matched = {} + for i, gt_box in enumerate(dt_boxes): + # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])] + distances = [] + for j, pred_box in enumerate(pred_bboxes): + distances.append( + (distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) # 获取两两cell之间的L1距离和 1- IOU + sorted_distances = distances.copy() + # 根据距离和IOU挑选最"近"的cell + sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0])) + if distances.index(sorted_distances[0]) not in matched.keys(): + matched[distances.index(sorted_distances[0])] = [i] + else: + matched[distances.index(sorted_distances[0])].append(i) + return matched + + def get_pred_html(self, pred_structures, matched_index, ocr_contents): + end_html = [] + td_index = 0 + for tag in pred_structures: + if '' in tag: + if td_index in matched_index.keys(): + b_with = False + if '' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1: + b_with = True + end_html.extend('') + for i, td_index_index in enumerate(matched_index[td_index]): + content = ocr_contents[td_index_index][0] + if len(matched_index[td_index]) > 1: + if len(content) == 0: + continue + if content[0] == ' ': + content = content[1:] + if '' in content: + content = content[3:] + if '' in content: + content = content[:-4] + if len(content) == 0: + continue + if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]: + content += ' ' + end_html.extend(content) + if b_with: + end_html.extend('') + + end_html.append(tag) + td_index += 1 + else: + end_html.append(tag) + return ''.join(end_html), end_html + + +def sorted_boxes(dt_boxes): + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with shape [4, 2] + return: + sorted boxes(array) with shape [4, 2] + """ + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + + for i in range(num_boxes - 1): + if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + +def to_excel(html_table, excel_path): + from tablepyxl import tablepyxl + tablepyxl.document_to_xl(html_table, excel_path) + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + image_file_list = image_file_list[args.process_id::args.total_process_num] + excel_save_folder = 'output/table' + os.makedirs(excel_save_folder, exist_ok=True) + + text_sys = TableSystem(args) + img_num = len(image_file_list) + for i, image_file in enumerate(image_file_list): + logger.info("[{}/{}] {}".format(i, img_num, image_file)) + img, flag = check_and_read_gif(image_file) + excel_path = os.path.join(excel_save_folder, os.path.basename(image_file).split('.')[0] + '.xlsx') + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + starttime = time.time() + pred_html = text_sys(img) + + to_excel(pred_html, excel_path) + logger.info('excel saved to {}'.format(excel_path)) + logger.info(pred_html) + elapse = time.time() - starttime + logger.info("Predict time : {:.3f}s".format(elapse)) + + +if __name__ == "__main__": + args = utility.parse_args() + if args.use_mp: + p_list = [] + total_process_num = args.total_process_num + for process_id in range(total_process_num): + cmd = [sys.executable, "-u"] + sys.argv + [ + "--process_id={}".format(process_id), + "--use_mp={}".format(False) + ] + p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout) + p_list.append(p) + for p in p_list: + p.wait() + else: + main(args) diff --git a/ppstructure/table/table_metric/__init__.py b/ppstructure/table/table_metric/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..de2d307430f68881ece1e41357d3b2f423e07ddd --- /dev/null +++ b/ppstructure/table/table_metric/__init__.py @@ -0,0 +1,16 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['TEDS'] +from .table_metric import TEDS \ No newline at end of file diff --git a/ppstructure/table/table_metric/parallel.py b/ppstructure/table/table_metric/parallel.py new file mode 100755 index 0000000000000000000000000000000000000000..f7326a1f506ca5fb7b3e97b0d077dc016e7eb7c7 --- /dev/null +++ b/ppstructure/table/table_metric/parallel.py @@ -0,0 +1,51 @@ +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor, as_completed + + +def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0): + """ + A parallel version of the map function with a progress bar. + Args: + array (array-like): An array to iterate over. + function (function): A python function to apply to the elements of array + n_jobs (int, default=16): The number of cores to use + use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of + keyword arguments to function + front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job. + Useful for catching bugs + Returns: + [function(array[0]), function(array[1]), ...] + """ + # We run the first few iterations serially to catch bugs + if front_num > 0: + front = [function(**a) if use_kwargs else function(a) + for a in array[:front_num]] + else: + front = [] + # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging. + if n_jobs == 1: + return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])] + # Assemble the workers + with ProcessPoolExecutor(max_workers=n_jobs) as pool: + # Pass the elements of array into function + if use_kwargs: + futures = [pool.submit(function, **a) for a in array[front_num:]] + else: + futures = [pool.submit(function, a) for a in array[front_num:]] + kwargs = { + 'total': len(futures), + 'unit': 'it', + 'unit_scale': True, + 'leave': True + } + # Print out the progress as tasks complete + for f in tqdm(as_completed(futures), **kwargs): + pass + out = [] + # Get the results from the futures. + for i, future in tqdm(enumerate(futures)): + try: + out.append(future.result()) + except Exception as e: + out.append(e) + return front + out diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 3f0ff2ff64bff2c2e70be37a95b5449deaa90046..956df5ca49523c8bfc53a3b271e2f802314826bd 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -125,6 +125,8 @@ def create_predictor(args, mode, logger): model_dir = args.cls_model_dir elif mode == 'rec': model_dir = args.rec_model_dir + elif mode == 'table': + model_dir = args.table_model_dir else: model_dir = args.e2e_model_dir @@ -244,7 +246,8 @@ def create_predictor(args, mode, logger): config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.switch_use_feed_fetch_ops(False) - + if mode == 'table': + config.switch_ir_optim(False) # create predictor predictor = inference.create_predictor(config) input_names = predictor.get_input_names()