# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
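"""Evaluate a LayoutLM / LayoutXLM model for semantic entity recognition (SER)
on an XFUN-style dataset and report precision, recall and F1."""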

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import random
import time
import copy
import logging

import argparse
import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification

from xfun import XFUNDataset
from losses import SERLoss
from utils import parse_args, get_bio_label_maps, print_arguments

from ppocr.utils.logging import get_logger

MODELS = {
    'LayoutXLM':
    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
    'LayoutLM':
    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}


def eval(args):
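    """Load the SER model and tokenizer, build the XFUN evaluation dataset and
    dataloader, then run `evaluate` and log the resulting metrics."""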
    logger = get_logger()
    print_arguments(args, logger)

    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
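    # reuse the ignore_index of CrossEntropyLoss so that label positions
    # created by padding are excluded from the loss and the metrics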
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

    tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        pad_token_label_id=pad_token_label_id,
        contains_re=False,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=None, )

    loss_class = SERLoss(len(label2id_map))

    results, _ = evaluate(args, model, tokenizer, loss_class, eval_dataloader,
                          label2id_map, id2label_map, pad_token_label_id,
                          logger)

    logger.info(results)


def evaluate(args,
             model,
             tokenizer,
             loss_class,
             eval_dataloader,
             label2id_map,
             id2label_map,
             pad_token_label_id,
             logger,
             prefix=""):

    eval_loss = 0.0
    nb_eval_steps = 0
    preds = None
    out_label_ids = None
    model.eval()
    for idx, batch in enumerate(eval_dataloader):
        with paddle.no_grad():
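            # LayoutLM takes no visual features, so drop the image tensor;
            # the labels are split off so they are not passed to the model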
            if args.ser_model_type == 'LayoutLM':
                if 'image' in batch:
                    batch.pop('image')
            labels = batch.pop('labels')
            outputs = model(**batch)
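            # the LayoutXLM head returns a tuple; its first element holds the
            # per-token classification logits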
            if args.ser_model_type == 'LayoutXLM':
                outputs = outputs[0]
            loss = loss_class(labels, outputs, batch['attention_mask'])

            loss = loss.mean()

            if paddle.distributed.get_rank() == 0:
                logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
                    idx, len(eval_dataloader), loss.item()))

            eval_loss += loss.item()
        nb_eval_steps += 1
        if preds is None:
            preds = outputs.numpy()
            out_label_ids = labels.numpy()
        else:
            preds = np.append(preds, outputs.numpy(), axis=0)
            out_label_ids = np.append(out_label_ids, labels.numpy(), axis=0)

    eval_loss = eval_loss / nb_eval_steps
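    # argmax over the label dimension turns the stacked logits into
    # predicted label ids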
    preds = np.argmax(preds, axis=2)

    # drop positions labelled with pad_token_label_id and map the remaining
    # label ids back to their BIO tag strings for seqeval

    out_label_list = [[] for _ in range(out_label_ids.shape[0])]
    preds_list = [[] for _ in range(out_label_ids.shape[0])]

    for i in range(out_label_ids.shape[0]):
        for j in range(out_label_ids.shape[1]):
            if out_label_ids[i, j] != pad_token_label_id:
                out_label_list[i].append(id2label_map[out_label_ids[i][j]])
                preds_list[i].append(id2label_map[preds[i][j]])

    results = {
        "loss": eval_loss,
        "precision": precision_score(out_label_list, preds_list),
        "recall": recall_score(out_label_list, preds_list),
        "f1": f1_score(out_label_list, preds_list),
    }

    with open(
            os.path.join(args.output_dir, "test_gt.txt"), "w",
            encoding='utf-8') as fout:
        for lbl in out_label_list:
            for l in lbl:
                fout.write(l + "\t")
            fout.write("\n")
    with open(
            os.path.join(args.output_dir, "test_pred.txt"), "w",
            encoding='utf-8') as fout:
        for lbl in preds_list:
            for l in lbl:
                fout.write(l + "\t")
            fout.write("\n")

    report = classification_report(out_label_list, preds_list)
    logger.info("\n" + report)

    logger.info("***** Eval results %s *****", prefix)
    for key in sorted(results.keys()):
        logger.info("  %s = %s", key, str(results[key]))
    model.train()
    return results, preds_list


if __name__ == "__main__":
    args = parse_args()
    eval(args)
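
# Example invocation (a sketch only: the exact flag names are defined by
# parse_args() in utils.py and may differ, and all paths below are
# placeholders rather than shipped defaults):
#
#   python eval_ser.py \
#       --ser_model_type LayoutXLM \
#       --model_name_or_path path/to/ser_model \
#       --eval_data_dir path/to/eval/images \
#       --eval_label_path path/to/eval_label.json \
#       --label_map_path path/to/labels.txt \
#       --output_dir ./output/ser_eval/ \
#       --per_gpu_eval_batch_size 8 \
#       --num_workers 4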