pubtab_dataset.py 5.0 KB
Newer Older
M
MissPenguin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import random
from paddle.io import Dataset
import json

from .imaug import transform, create_operators

class PubTabDataSet(Dataset):
    def __init__(self, config, mode, logger, seed=None):
        super(PubTabDataSet, self).__init__()
        self.logger = logger

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        label_file_path = dataset_config.pop('label_file_path')

        self.data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']
        self.do_hard_select = False
        if 'hard_select' in loader_config:
            self.do_hard_select = loader_config['hard_select']
            self.hard_prob = loader_config['hard_prob']
        if self.do_hard_select:
            self.img_select_prob = self.load_hard_select_prob()
        self.table_select_type = None
        if 'table_select_type' in loader_config:
            self.table_select_type = loader_config['table_select_type']
            self.table_select_prob = loader_config['table_select_prob']

        self.seed = seed
        logger.info("Initialize indexs of datasets:%s" % label_file_path)
        with open(label_file_path, "rb") as f:
            self.data_lines = f.readlines()
        self.data_idx_order_list = list(range(len(self.data_lines)))
        if mode.lower() == "train":
            self.shuffle_data_random()
        self.ops = create_operators(dataset_config['transforms'], global_config)

    def shuffle_data_random(self):
        if self.do_shuffle:
            random.seed(self.seed)
            random.shuffle(self.data_lines)
        return
    
    def load_hard_select_prob(self):
        label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
        img_select_prob = {}
        with open(label_path, "rb") as fin:
            lines = fin.readlines()
            for lno in range(len(lines)):
                substr = lines[lno].decode('utf-8').strip("\n").split(" ")
                img_name = substr[0].strip(":")
                score = float(substr[1])
                if score <= 0.8:
                    img_select_prob[img_name] = self.hard_prob[0]
                elif score <= 0.98:
                    img_select_prob[img_name] = self.hard_prob[1]
                else:
                    img_select_prob[img_name] = self.hard_prob[2]
        return img_select_prob

    def __getitem__(self, idx):
        try:
            data_line = self.data_lines[idx]
            data_line = data_line.decode('utf-8').strip("\n")
            info = json.loads(data_line)
            file_name = info['filename']
            select_flag = True
            if self.do_hard_select:
                prob = self.img_select_prob[file_name]
                if prob < random.uniform(0, 1):
                    select_flag = False
            
            if self.table_select_type:
                structure = info['html']['structure']['tokens'].copy()
                structure_str = ''.join(structure)
                table_type = "simple"
                if 'colspan' in structure_str or 'rowspan' in structure_str:
                    table_type = "complex"
#                 if self.table_select_type != table_type:
#                     select_flag = False
                if table_type == "complex":
                    if self.table_select_prob < random.uniform(0, 1):
                        select_flag = False                    
            
            if select_flag:
                cells = info['html']['cells'].copy()
                structure = info['html']['structure'].copy()
                img_path = os.path.join(self.data_dir, file_name)
                data = {'img_path': img_path, 'cells': cells, 'structure':structure}
                if not os.path.exists(img_path):
                    raise Exception("{} does not exist!".format(img_path))
                with open(data['img_path'], 'rb') as f:
                    img = f.read()
                    data['image'] = img
                outs = transform(data, self.ops)
            else:
                outs = None
        except Exception as e:
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(
                    data_line, e))
            outs = None
        if outs is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        return len(self.data_idx_order_list)