From eb7ce442a3adbd8899b2f357583847d1b237d88b Mon Sep 17 00:00:00 2001
From: WenmuZhou
Date: Thu, 3 Jun 2021 16:43:29 +0800
Subject: [PATCH] add table eval and predict script

---
 .../table}/matcher.py                  |   0
 .../table/table_metric/table_metric.py | 247 ++++++++++++++++++
 2 files changed, 247 insertions(+)
 rename {ppocr/utils/table_utils => ppstructure/table}/matcher.py (100%)
 create mode 100755 ppstructure/table/table_metric/table_metric.py

diff --git a/ppocr/utils/table_utils/matcher.py b/ppstructure/table/matcher.py
similarity index 100%
rename from ppocr/utils/table_utils/matcher.py
rename to ppstructure/table/matcher.py
diff --git a/ppstructure/table/table_metric/table_metric.py b/ppstructure/table/table_metric/table_metric.py
new file mode 100755
index 00000000..9aca98ad
--- /dev/null
+++ b/ppstructure/table/table_metric/table_metric.py
@@ -0,0 +1,247 @@
+# Copyright 2020 IBM
+# Author: peter.zhong@au1.ibm.com
+#
+# This is free software; you can redistribute it and/or modify
+# it under the terms of the Apache 2.0 License.
+#
+# This software is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# Apache 2.0 License for more details.
+
+import distance
+from apted import APTED, Config
+from apted.helpers import Tree
+from lxml import etree, html
+from collections import deque
+from .parallel import parallel_process
+from tqdm import tqdm
+
+
+class TableTree(Tree):
+    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
+        self.tag = tag
+        self.colspan = colspan
+        self.rowspan = rowspan
+        self.content = content
+        self.children = list(children)
+
+    def bracket(self):
+        """Show tree using brackets notation"""
+        if self.tag == 'td':
+            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
+                     (self.tag, self.colspan, self.rowspan, self.content)
+        else:
+            result = '"tag": %s' % self.tag
+        for child in self.children:
+            result += child.bracket()
+        return "{{{}}}".format(result)
+
+
+class CustomConfig(Config):
+    @staticmethod
+    def maximum(*sequences):
+        """Get maximum possible value"""
+        return max(map(len, sequences))
+
+    def normalized_distance(self, *sequences):
+        """Get distance from 0 to 1"""
+        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
+
+    def rename(self, node1, node2):
+        """Compares attributes of trees"""
+        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
+            return 1.
+        if node1.tag == 'td':
+            if node1.content or node2.content:
+                return self.normalized_distance(node1.content, node2.content)
+        return 0.
+
+
+class CustomConfig_del_short(Config):
+    @staticmethod
+    def maximum(*sequences):
+        """Get maximum possible value"""
+        return max(map(len, sequences))
+
+    def normalized_distance(self, *sequences):
+        """Get distance from 0 to 1"""
+        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
+
+    def rename(self, node1, node2):
+        """Compares attributes of trees"""
+        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
+            return 1.
+        if node1.tag == 'td':
+            if node1.content or node2.content:
+                node1_content = node1.content
+                node2_content = node2.content
+                # cells with fewer than 3 content tokens are replaced by the same
+                # placeholder, so very short cell texts do not affect the score
+                if len(node1_content) < 3:
+                    node1_content = ['####']
+                if len(node2_content) < 3:
+                    node2_content = ['####']
+                return self.normalized_distance(node1_content, node2_content)
+        return 0.
+
+
+class CustomConfig_del_block(Config):
+    @staticmethod
+    def maximum(*sequences):
+        """Get maximum possible value"""
+        return max(map(len, sequences))
+
+    def normalized_distance(self, *sequences):
+        """Get distance from 0 to 1"""
+        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
+
+    def rename(self, node1, node2):
+        """Compares attributes of trees"""
+        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
+            return 1.
+        if node1.tag == 'td':
+            if node1.content or node2.content:
+                node1_content = node1.content
+                node2_content = node2.content
+                # drop whitespace tokens from the cell content before comparing
+                while ' ' in node1_content:
+                    node1_content.pop(node1_content.index(' '))
+                while ' ' in node2_content:
+                    node2_content.pop(node2_content.index(' '))
+                return self.normalized_distance(node1_content, node2_content)
+        return 0.
+
+
+class TEDS(object):
+    ''' Tree Edit Distance based Similarity
+    '''
+
+    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
+        assert isinstance(n_jobs, int) and (
+            n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
+        self.structure_only = structure_only
+        self.n_jobs = n_jobs
+        self.ignore_nodes = ignore_nodes
+        self.__tokens__ = []
+
+    def tokenize(self, node):
+        ''' Tokenizes table cells
+        '''
+        self.__tokens__.append('<%s>' % node.tag)
+        if node.text is not None:
+            self.__tokens__ += list(node.text)
+        for n in node.getchildren():
+            self.tokenize(n)
+        if node.tag != 'unk':
+            self.__tokens__.append('</%s>' % node.tag)
+        if node.tag != 'td' and node.tail is not None:
+            self.__tokens__ += list(node.tail)
+
+    def load_html_tree(self, node, parent=None):
+        ''' Converts HTML tree to the format required by apted
+        '''
+        if node.tag == 'td':
+            if self.structure_only:
+                cell = []
+            else:
+                self.__tokens__ = []
+                self.tokenize(node)
+                cell = self.__tokens__[1:-1].copy()
+            new_node = TableTree(node.tag,
+                                 int(node.attrib.get('colspan', '1')),
+                                 int(node.attrib.get('rowspan', '1')),
+                                 cell, *deque())
+        else:
+            new_node = TableTree(node.tag, None, None, None, *deque())
+        if parent is not None:
+            parent.children.append(new_node)
+        if node.tag != 'td':
+            for n in node.getchildren():
+                self.load_html_tree(n, new_node)
+        if parent is None:
+            return new_node
+
+    def evaluate(self, pred, true):
+        ''' Computes TEDS score between the prediction and the ground truth of a
+            given sample
+        '''
+        if (not pred) or (not true):
+            return 0.0
+        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
+        pred = html.fromstring(pred, parser=parser)
+        true = html.fromstring(true, parser=parser)
+        if pred.xpath('body/table') and true.xpath('body/table'):
+            pred = pred.xpath('body/table')[0]
+            true = true.xpath('body/table')[0]
+            if self.ignore_nodes:
+                etree.strip_tags(pred, *self.ignore_nodes)
+                etree.strip_tags(true, *self.ignore_nodes)
+            n_nodes_pred = len(pred.xpath(".//*"))
+            n_nodes_true = len(true.xpath(".//*"))
+            n_nodes = max(n_nodes_pred, n_nodes_true)
+            tree_pred = self.load_html_tree(pred)
+            tree_true = self.load_html_tree(true)
+            edit_distance = APTED(tree_pred, tree_true,
+                                  CustomConfig()).compute_edit_distance()
+            return 1.0 - (float(edit_distance) / n_nodes)
+        else:
+            return 0.0
+
+    def batch_evaluate(self, pred_json, true_json):
+        ''' Computes TEDS score between the prediction and the ground truth of
+            a batch of samples
+            @params pred_json: {'FILENAME': 'HTML CODE', ...}
+            @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
+            @output: {'FILENAME': 'TEDS SCORE', ...}
+        '''
+        samples = true_json.keys()
+        if self.n_jobs == 1:
+            scores = [self.evaluate(pred_json.get(
+                filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
+        else:
+            inputs = [{'pred': pred_json.get(
+                filename, ''), 'true': true_json[filename]['html']} for filename in samples]
+            scores = parallel_process(
+                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
+        scores = dict(zip(samples, scores))
+        return scores
+
+    def batch_evaluate_html(self, pred_htmls, true_htmls):
+        ''' Computes TEDS score between the prediction and the ground truth of
+            a batch of samples
+        '''
+        if self.n_jobs == 1:
+            scores = [self.evaluate(pred_html, true_html) for (
+                pred_html, true_html) in zip(pred_htmls, true_htmls)]
+        else:
+            inputs = [{"pred": pred_html, "true": true_html} for (
+                pred_html, true_html) in zip(pred_htmls, true_htmls)]
+            scores = parallel_process(
+                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
+        return scores
+
+
+if __name__ == '__main__':
+    import json
+    import pprint
+    with open('sample_pred.json') as fp:
+        pred_json = json.load(fp)
+    with open('sample_gt.json') as fp:
+        true_json = json.load(fp)
+    teds = TEDS(n_jobs=4)
+    scores = teds.batch_evaluate(pred_json, true_json)
+    pp = pprint.PrettyPrinter()
+    pp.pprint(scores)
-- 
GitLab
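
Not part of the patch itself: a minimal usage sketch for trying the metric on a single predicted table. The import path and the two HTML strings are illustrative assumptions (they rely on the new module being importable as a package from the repository root), not something the patch adds.

    # hypothetical import path; table_metric.py uses "from .parallel import ...",
    # so it must be imported as a package module, not run as a loose script
    from ppstructure.table.table_metric.table_metric import TEDS

    # invented placeholder tables, only for illustration
    pred_html = '<html><body><table><tr><td>1</td><td>2</td></tr></table></body></html>'
    true_html = '<html><body><table><tr><td>1</td><td>3</td></tr></table></body></html>'

    teds = TEDS(structure_only=False, n_jobs=1)
    score = teds.evaluate(pred_html, true_html)  # float in [0, 1]; 1.0 means the tables match exactly
    print(score)

Passing structure_only=True would compare only the table topology (tags, rowspan/colspan) and ignore cell text, which is how the structure-only variant of the score is obtained.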