# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

'''bert clue evaluation'''

import json
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from src import tokenization
from src.sample_process import label_generation, process_one_example_p
from src.CRF import postprocess
from src.finetune_eval_config import bert_net_cfg


def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
    """
    Run NER inference on a single piece of text.

    Args:
        model: trained network exposing a ``predict`` method.
        text (str): raw input sentence to label.
        tokenizer_: tokenizer used to convert the text into model features.
        use_crf (str): "true" (case-insensitive) to decode with the CRF
            postprocessing path; any other value uses plain argmax decoding.
        label2id_file (str): path to the label-to-id mapping file.

    Returns:
        Entity labels produced by ``label_generation`` for ``text``.
    """
    # Build the single feature triple for this text; the original code
    # looped over a one-element list and accumulated an unused `features`
    # list, both of which are removed here (no behavior change).
    feature = process_one_example_p(tokenizer_, text, max_seq_len=bert_net_cfg.seq_length)
    input_ids, input_mask, token_type_id = feature
    input_ids = Tensor(np.array(input_ids), mstype.int32)
    input_mask = Tensor(np.array(input_mask), mstype.int32)
    token_type_id = Tensor(np.array(token_type_id), mstype.int32)
    if use_crf.lower() == "true":
        # CRF path: decode the best tag sequence from the backpointers.
        backpointers, best_tag_id = model.predict(input_ids, input_mask, token_type_id, Tensor(1))
        best_path = postprocess(backpointers, best_tag_id)
        ids = []
        for ele in best_path:
            ids.extend(ele)
    else:
        # Plain softmax path: take the argmax tag id per token.
        logits = model.predict(input_ids, input_mask, token_type_id, Tensor(1))
        ids = list(np.argmax(logits.asnumpy(), axis=-1))
    res = label_generation(text=text, probs=ids, label2id_file=label2id_file)
    return res

def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""):
    """
    Run NER prediction over a CLUENER-style file and write the results.

    Reads one JSON object per non-empty line from ``path`` (each record
    must contain a "text" field), predicts entity labels for each line via
    ``process``, and writes the predictions — one JSON object per line —
    to "ner_predict.json" in the current directory.

    Args:
        model: trained network, passed through to ``process``.
        path (str): input file with one JSON record per line.
        vocab_file (str): vocabulary file used to build the tokenizer.
        use_crf (str): "true" to use CRF decoding (see ``process``).
        label2id_file (str): path to the label-to-id mapping file.
    """
    tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file)
    data = []
    # Fix: the original iterated `open(path)` and called
    # `open(...).write(...)` without ever closing either handle; use
    # context managers and an explicit encoding for the Chinese text.
    with open(path, encoding="utf-8") as reader:
        for line in reader:
            if not line.strip():
                continue
            oneline = json.loads(line.strip())
            res = process(model=model, text=oneline["text"], tokenizer_=tokenizer_,
                          use_crf=use_crf, label2id_file=label2id_file)
            print("text", oneline["text"])
            print("res:", res)
            data.append(json.dumps({"label": res}, ensure_ascii=False))
    with open("ner_predict.json", "w", encoding="utf-8") as writer:
        writer.write("\n".join(data))