# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.

import distance
from apted import APTED, Config
from apted.helpers import Tree
from lxml import etree, html
from collections import deque
from .parallel import parallel_process
from tqdm import tqdm


class TableTree(Tree):
    """Tree node describing one element of an HTML table.

    Attributes:
        tag: Tag name of the element.
        colspan: Column span of the node (``None`` unless supplied).
        rowspan: Row span of the node (``None`` unless supplied).
        content: Cell content of the node (``None`` unless supplied).
        children: List of child nodes.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == 'td':
            # NOTE: the %d conversions require numeric colspan/rowspan on
            # 'td' nodes.
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
                (self.tag, self.colspan, self.rowspan, self.content)
        else:
            result = '"tag": %s' % self.tag
        # Children are rendered recursively and appended in order.
        result += ''.join(child.bracket() for child in self.children)
        return "{{{}}}".format(result)


class CustomConfig(Config):
    """Edit-cost configuration that compares tag, spans and cell content."""

    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value (the length of the longest sequence)"""
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1

        Returns the Levenshtein distance between the sequences divided by
        the length of the longest one. When every sequence is empty the
        sequences are identical, so 0.0 is returned instead of dividing
        by zero.
        """
        longest = self.maximum(*sequences)
        if longest == 0:
            return 0.
        return float(distance.levenshtein(*sequences)) / longest

    def rename(self, node1, node2):
        """Compares attributes of trees

        Returns 1.0 when tag, colspan or rowspan differ. For matching
        'td' nodes with any content, returns the normalized distance
        between the two contents. Returns 0.0 otherwise.
        """
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) \
                or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                return self.normalized_distance(node1.content, node2.content)
        return 0.
class CustomConfig_del_short(Config):
    """Edit-cost configuration that treats very short cell contents alike.

    Cell contents with fewer than three tokens are replaced by a single
    placeholder token before comparison, so two short cells compare equal.
    """

    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value (the length of the longest sequence)"""
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1

        Returns the Levenshtein distance between the sequences divided by
        the length of the longest one. When every sequence is empty the
        sequences are identical, so 0.0 is returned instead of dividing
        by zero.
        """
        longest = self.maximum(*sequences)
        if longest == 0:
            return 0.
        return float(distance.levenshtein(*sequences)) / longest

    def rename(self, node1, node2):
        """Compares attributes of trees

        Returns 1.0 when tag, colspan or rowspan differ. For matching
        'td' nodes with any content, returns the normalized distance
        between the contents after short contents have been replaced by
        the placeholder. Returns 0.0 otherwise.
        """
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) \
                or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                node1_content = node1.content
                node2_content = node2.content
                # Contents shorter than three tokens are swapped for one
                # placeholder token; the node itself is left untouched.
                if len(node1_content) < 3:
                    node1_content = ['####']
                if len(node2_content) < 3:
                    node2_content = ['####']
                return self.normalized_distance(node1_content, node2_content)
        return 0.


class CustomConfig_del_block(Config):
    """Edit-cost configuration that ignores space tokens in cell content."""

    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value (the length of the longest sequence)"""
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1

        Returns the Levenshtein distance between the sequences divided by
        the length of the longest one. When every sequence is empty the
        sequences are identical, so 0.0 is returned instead of dividing
        by zero.
        """
        longest = self.maximum(*sequences)
        if longest == 0:
            return 0.
        return float(distance.levenshtein(*sequences)) / longest

    def rename(self, node1, node2):
        """Compares attributes of trees

        Returns 1.0 when tag, colspan or rowspan differ. For matching
        'td' nodes with any content, returns the normalized distance
        between the contents with every ' ' token removed. Returns 0.0
        otherwise.
        """
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) \
                or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                # Filter into new lists rather than popping from the
                # nodes' own lists: rename() is a cost function and must
                # not alter the trees being compared.
                node1_content = [tok for tok in node1.content if tok != ' ']
                node2_content = [tok for tok in node2.content if tok != ' ']
                # Both lists may now be empty (cells holding only spaces);
                # normalized_distance() returns 0.0 for that case.
                return self.normalized_distance(node1_content, node2_content)
        return 0.
class TEDS(object):
    ''' Tree Edit Distance based Similarity

    Scores a predicted HTML table against a ground-truth HTML table as
    1 - (tree edit distance / node count of the larger table).
    '''

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        ''' Stores the evaluation options
            @params structure_only: when true, cell content is ignored and
                                    only the table structure is compared
            @params n_jobs: n_jobs value handed to parallel_process by the
                            batch methods; 1 evaluates sequentially
            @params ignore_nodes: tag names stripped from both tables
                                  before comparison (or None)
        '''
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        ''' Tokenizes table cells

            Appends to self.__tokens__ an opening tag token, the node text
            one character at a time, the tokens of all children, and a
            closing tag token (omitted for 'unk' nodes). Tail text is
            appended for every node except 'td'.
        '''
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        # Iterating an element yields its children.
        for n in node:
            self.tokenize(n)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        ''' Converts HTML tree to the format required by apted

            Returns the converted root when called without a parent;
            recursive calls attach the new node to parent and return None.
        '''
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the enclosing <td> ... </td> tokens of the cell.
                cell = self.__tokens__[1:-1]
            new_node = TableTree(node.tag,
                                 int(node.attrib.get('colspan', '1')),
                                 int(node.attrib.get('rowspan', '1')),
                                 cell)
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        # 'td' nodes are leaves: their markup is captured in the content.
        if node.tag != 'td':
            for n in node:
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        ''' Computes TEDS score between the prediction and the ground truth of a
            given sample

            Returns 0.0 when either input is empty or when either document
            has no body/table element.
        '''
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        pred_tables = pred.xpath('body/table')
        true_tables = true.xpath('body/table')
        if not (pred_tables and true_tables):
            return 0.0
        pred = pred_tables[0]
        true = true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred, *self.ignore_nodes)
            etree.strip_tags(true, *self.ignore_nodes)
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        n_nodes = max(n_nodes_pred, n_nodes_true)
        if n_nodes == 0:
            # Both tables have no descendant elements, so they are
            # identical; avoid dividing by zero below.
            return 1.0
        tree_pred = self.load_html_tree(pred)
        tree_true = self.load_html_tree(true)
        edit_distance = APTED(
            tree_pred, tree_true, CustomConfig()).compute_edit_distance()
        return 1.0 - (float(edit_distance) / n_nodes)

    def batch_evaluate(self, pred_json, true_json):
        ''' Computes TEDS score between the prediction and the ground truth of
            a batch of samples
            @params pred_json: {'FILENAME': 'HTML CODE', ...}
            @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
            @output: {'FILENAME': 'TEDS SCORE', ...}

            Samples missing from pred_json are evaluated with an empty
            prediction.
        '''
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_json.get(filename, ''),
                                    true_json[filename]['html'])
                      for filename in tqdm(samples)]
        else:
            inputs = [{'pred': pred_json.get(filename, ''),
                       'true': true_json[filename]['html']}
                      for filename in samples]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        ''' Computes TEDS score between the prediction and the ground truth of
            a batch of samples

            pred_htmls and true_htmls are paired positionally with zip.
            With n_jobs == 1 a list of scores is returned, one per pair;
            otherwise the result of parallel_process is returned.
        '''
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_html, true_html)
                      for (pred_html, true_html) in zip(pred_htmls, true_htmls)]
        else:
            inputs = [{"pred": pred_html, "true": true_html}
                      for (pred_html, true_html) in zip(pred_htmls, true_htmls)]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        return scores


if __name__ == '__main__':
    import json
    import pprint

    # Score the sample predictions against the sample ground truth and
    # pretty-print the per-file results.
    with open('sample_pred.json') as fp:
        pred_json = json.load(fp)
    with open('sample_gt.json') as fp:
        true_json = json.load(fp)
    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)