提交 3b028ad8 编写于 作者: M MRXLT

add lac and senta reader

上级 82e925ca
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import os
import io
def load_kv_dict(dict_path,
reverse=False,
delimiter="\t",
key_func=None,
value_func=None):
result_dict = {}
for line in io.open(dict_path, "r", encoding="utf8"):
terms = line.strip("\n").split(delimiter)
if len(terms) != 2:
continue
if reverse:
value, key = terms
else:
key, value = terms
if key in result_dict:
raise KeyError("key duplicated with [%s]" % (key))
if key_func:
key = key_func(key)
if value_func:
value = value_func(value)
result_dict[key] = value
return result_dict
class LACReader(object):
"""data reader"""
def __init__(self, dict_folder):
# read dict
#basepath = os.path.abspath(__file__)
#folder = os.path.dirname(basepath)
word_dict_path = os.path.join(dict_folder, "word.dic")
label_dict_path = os.path.join(dict_folder, "tag.dic")
self.word2id_dict = load_kv_dict(
word_dict_path, reverse=True, value_func=int)
self.id2word_dict = load_kv_dict(word_dict_path)
self.label2id_dict = load_kv_dict(
label_dict_path, reverse=True, value_func=int)
self.id2label_dict = load_kv_dict(label_dict_path)
@property
def vocab_size(self):
"""vocabulary size"""
return max(self.word2id_dict.values()) + 1
@property
def num_labels(self):
"""num_labels"""
return max(self.label2id_dict.values()) + 1
def word_to_ids(self, words):
"""convert word to word index"""
word_ids = []
idx = 0
try:
words = unicode(words, 'utf-8')
except:
pass
for word in words:
if word not in self.word2id_dict:
word = "OOV"
word_id = self.word2id_dict[word]
word_ids.append(word_id)
return word_ids
def label_to_ids(self, labels):
"""convert label to label index"""
label_ids = []
for label in labels:
if label not in self.label2id_dict:
label = "O"
label_id = self.label2id_dict[label]
label_ids.append(label_id)
return label_ids
def process(self, sent):
words = sent.strip()
word_ids = self.word_to_ids(words)
return word_ids
def parse_result(self, words, crf_decode):
tags = [self.id2label_dict[str(x[0])] for x in crf_decode]
sent_out = []
tags_out = []
partial_word = ""
for ind, tag in enumerate(tags):
if partial_word == "":
partial_word = words[ind]
tags_out.append(tag.split('-')[0])
continue
if tag.endswith("-B") or (tag == "O" and tag[ind - 1] != "O"):
sent_out.append(partial_word)
tags_out.append(tag.split('-')[0])
partial_word = words[ind]
continue
partial_word += words[ind]
if len(sent_out) < len(tags_out):
sent_out.append(partial_word)
return sent_out
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import io
class SentaReader():
def __init__(self, vocab_path, max_seq_len=20):
self.max_seq_len = max_seq_len
self.word_dict = self.load_vocab(vocab_path)
def load_vocab(self, vocab_path):
"""
load the given vocabulary
"""
vocab = {}
with io.open(vocab_path, 'r', encoding='utf8') as f:
wid = 0
for line in f:
if line.strip() not in vocab:
vocab[line.strip()] = wid
wid += 1
vocab["<unk>"] = len(vocab)
return vocab
def process(self, cols):
unk_id = len(self.word_dict)
pad_id = 0
wids = [
self.word_dict[x] if x in self.word_dict else unk_id for x in cols
]
seq_len = len(wids)
if seq_len < self.max_seq_len:
for i in range(self.max_seq_len - seq_len):
wids.append(pad_id)
else:
wids = wids[:self.max_seq_len]
seq_len = self.max_seq_len
return wids
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册