Commit 334bbabf, authored by misite_J

initial commit

.idea
__pycache__
/ckpts/*
/logs/*
/output/*
tmp.py
# Contents
- [Introduction](#introduction)
- [Installation](#installation)
- [Directory Structure](#directory-structure)
- [Model](#model)
- [Usage](#usage)
- [References](#references)
# Introduction
A PyTorch implementation of Chinese named entity recognition (NER) with a BERT + BiLSTM + CRF model.
# Installation
Install the following dependencies (also listed in `requirements.txt`):
```sh
tqdm==4.62.2
torch==1.8.2+cu102
transformers==4.11.3
torchcrf==1.1.0
```
# Directory Structure
```python
.
├── data                 # dataset: train.txt / test.txt and the cached label_list.pkl
├── ckpts                # pretrained bert-base-chinese and saved model checkpoints
├── logs                 # training logs
├── output               # test output and TensorBoard files
├── config.py            # training configuration (Config class)
├── dataloader.py        # NERDataset and feature conversion
├── models.py            # BERT_BiLSTM_CRF model
├── evaluator.py         # precision / recall / F1 metrics and confusion matrix
├── trainer.py           # training / evaluation / test loop
├── utils.py             # logging, pickle and label helpers
├── run.py               # entry point
├── requirements.txt
```
# Model
## BERT_BiLSTM_CRF
```python
"""
1. BERT:
    outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=input_mask)
    sequence_output = outputs[0]
    Inputs:
        input_ids: torch.Size([batch_size, seq_len]), token ids of the input examples
        token_type_ids: torch.Size([batch_size, seq_len]), segment ids; an example may contain two sentences and this marks which one each token belongs to
        attention_mask: torch.Size([batch_size, seq_len]), marks which tokens take part in self-attention (padding is masked out)
    Outputs:
        sequence_output: torch.Size([batch_size, seq_len, hidden_size]), the output sequence (last-layer hidden state of every token)
        pooled_output: torch.Size([batch_size, hidden_size]), the pooled representation of the output sequence
        (hidden_states): tuple of 13 tensors of torch.Size([batch_size, seq_len, hidden_size]), the hidden states of every layer; only returned if output_hidden_states is set in the config
        (attentions): tuple of 12 tensors of torch.Size([batch_size, 12, seq_len, seq_len]), the attention weights; only returned if output_attentions is set in the config
2. BiLSTM:
    self.birnn = nn.LSTM(input_size=config.hidden_size, hidden_size=rnn_dim, num_layers=1, bidirectional=True, batch_first=True)
    Args:
        input_size: feature dimension of the input
        hidden_size: dimension of the LSTM hidden state
        num_layers: number of stacked recurrent layers
        bias: whether to use bias terms, default=True
        batch_first: input data is usually shaped (batch_size, seq_length, input_size); the default is False, which expects batch_size and seq_length swapped, i.e. (seq_length, batch_size, input_size)
        dropout: default 0, i.e. no dropout
        bidirectional: default False, i.e. no bidirectional LSTM
    sequence_output, _ = self.birnn(sequence_output)
    Inputs:
        input: a tensor of shape (batch_size, seq_length, input_size) here, since batch_first=True
        (h_0, c_0): h_0.shape=(num_directions*num_layers, batch, hidden_size), the initial hidden state of every sequence in the batch; num_layers is the number of LSTM layers, and num_directions is 2 if bidirectional=True, otherwise 1 (a single direction). c_0 has the same shape as h_0 and holds the initial cell state of every sequence. If (h_0, c_0) is not provided, both default to zeros.
    Outputs:
        output: shape (batch_size, seq_length, num_directions*hidden_size) with batch_first=True; the output features (h_t) of the last LSTM layer for every time step
        (h_n, c_n): h_n.shape=(num_directions*num_layers, batch, hidden_size), and c_n has the same shape. h_n holds the hidden state at the last token of every sequence and c_n the corresponding cell state, so neither depends on seq_length. For a unidirectional LSTM the last time step of output equals h_n, because it is exactly the hidden state of the last token of each sequence. Note that the LSTM's hidden state is also its output, while the cell state is the internal memory that carries information along the sequence.
3. Linear layer:
    self.hidden2tag = nn.Linear(in_features=out_dim, out_features=config.num_labels)
    Args:
        in_features: size of the input 2-D tensor, i.e. the size in an input of shape [batch_size, size]
        out_features: size of the output 2-D tensor, i.e. the output has shape [batch_size, out_features]; it also equals the number of neurons in this layer
    In short:
        seen from the tensor shapes, an input of shape [batch_size, in_features] is transformed into an output of shape [batch_size, out_features]
4. CRF:
    self.crf = CRF(num_tags=config.num_labels, batch_first=True)
    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.
    loss = -1 * self.crf(emissions, tags, mask=input_mask.byte())
    Inputs:
        emissions (`~torch.Tensor`): Emission score tensor of size ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, ``(batch_size, seq_length, num_tags)`` otherwise.
        tags (`~torch.LongTensor`): Sequence of tags tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
        mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
        reduction: Specifies the reduction to apply to the output: ``none|sum|mean|token_mean``. ``none``: no reduction will be applied; ``sum``: the output will be summed over batches; ``mean``: the output will be averaged over batches; ``token_mean``: the output will be averaged over tokens.
    Returns:
        `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if reduction is ``none``, ``()`` otherwise.
"""
```
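
Putting the four pieces together, the following minimal sketch (not part of this repository) runs the whole BERT → BiLSTM → Linear → CRF pipeline on dummy tensors with a randomly initialised BERT, assuming the `transformers` and `torchcrf` APIs used in `models.py`; the dimensions (`batch_size=2`, `seq_len=16`, `rnn_dim=128`, `num_labels=5`) are made up for illustration:
```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel
from torchcrf import CRF

batch_size, seq_len, rnn_dim, num_labels = 2, 16, 128, 5
bert_config = BertConfig()                      # hidden_size defaults to 768
bert = BertModel(bert_config)                   # randomly initialised, no download needed
birnn = nn.LSTM(bert_config.hidden_size, rnn_dim, num_layers=1,
                bidirectional=True, batch_first=True)
hidden2tag = nn.Linear(rnn_dim * 2, num_labels)
crf = CRF(num_tags=num_labels, batch_first=True)

input_ids = torch.randint(0, bert_config.vocab_size, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
tags = torch.randint(0, num_labels, (batch_size, seq_len))

sequence_output = bert(input_ids, attention_mask=attention_mask)[0]  # (2, 16, 768)
sequence_output, _ = birnn(sequence_output)                          # (2, 16, 256)
emissions = hidden2tag(sequence_output)                              # (2, 16, 5)
loss = -1 * crf(emissions, tags, mask=attention_mask.byte())         # scalar training loss
best_paths = crf.decode(emissions, attention_mask.byte())            # list of label-id sequences
print(float(loss), best_paths[0])
```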
# Usage
The training parameters live in the `Config` class in `config.py`; adjust them as needed and then run `run.py`. The key fields (see `config.py` for the full list) are:
```python
model = 'bert_bilstm_crf'      # one of ['bert_bilstm_crf', 'hmm', 'bilstm_crf']
need_birnn = True              # add a BiLSTM layer between BERT and the CRF
rnn_dim = 128
max_seq_length = 128
batch_size = 16
num_train_epochs = 5
learning_rate = 3e-5
```
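
Once a model has been trained and saved to `ckpts/bert_bilstm_crf`, inference on a new sentence can be sketched roughly as follows (not part of the repository; it assumes the checkpoint and tokenizer written by `trainer.py` and the label cache produced by `utils.get_labels`):
```python
import torch
from transformers import BertTokenizer, BertConfig
from config import Config
from models import BERT_BiLSTM_CRF
from utils import get_labels

cfg = Config()
labels = get_labels(cfg)                         # same label order as during training (cached pickle)
id2label = {i: label for i, label in enumerate(labels)}

tokenizer = BertTokenizer.from_pretrained(cfg.trained_model_path)
bert_config = BertConfig.from_pretrained(cfg.trained_model_path, num_labels=len(labels))
model = BERT_BiLSTM_CRF.from_pretrained(cfg.trained_model_path, config=bert_config,
                                        need_birnn=cfg.need_birnn, rnn_dim=cfg.rnn_dim)
model.eval()

text = "王辉生前驾驶机械洒药消毒"                  # example sentence from the dataset
enc = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    pred = model.predict(enc["input_ids"], enc["token_type_ids"], enc["attention_mask"])
# drop the predictions for [CLS] and [SEP] before mapping ids back to tag names
print([id2label[i] for i in pred[0][1:-1]])
```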
# References
1. [bert_bilstm_crf_ner_pytorch](https://gitee.com/chenzhouwy/bert_bilstm_crf_ner_pytorch/tree/master)
import datetime
import os
import threading
class Config(object):
_instance_lock = threading.Lock()
_init_flag = False
def __init__(self):
if not Config._init_flag:
Config._init_flag = True
self.base_path = os.path.abspath('./')
self._init_train_config()
def _init_train_config(self):
self.label_list = []
self.use_gpu = True
self.device = "cuda"
self.checkpoints = True
        self.model = 'bert_bilstm_crf'  # one of ['bert_bilstm_crf', 'hmm', 'bilstm_crf']
        # input dataset, log and output directories
self.train_file = os.path.join(self.base_path, 'data/train.txt')
self.test_file = os.path.join(self.base_path, 'data/test.txt')
self.log_path = os.path.join(self.base_path, 'logs')
# self.output_path = os.path.join(self.base_path, 'output', datetime.datetime.now().strftime('%Y%m%d%H%M%S'))
self.output_path = os.path.join(self.base_path, 'output', self.model)
self.trained_model_path = os.path.join(self.base_path, 'ckpts', self.model)
self.model_name_or_path = os.path.join(self.base_path, 'ckpts', 'bert-base-chinese') if not self.checkpoints \
else self.trained_model_path
        # training hyper-parameters
self.do_train = True
self.do_eval = False
self.need_birnn = True
self.do_lower_case = True
self.rnn_dim = 128
self.max_seq_length = 128
self.batch_size = 16
self.num_train_epochs = 5
self.ckpts_epoch = 5
self.gradient_accumulation_steps = 1
self.learning_rate = 3e-5
self.adam_epsilon = 1e-8
self.warmup_steps = 0
self.logging_steps = 50
(Several data files added in this commit are too large to display; their diffs are collapsed in the web view.)
import os
import logging
import torch
from torch.utils.data import Dataset, TensorDataset
# from config import Config
from utils import load_pkl, save_pkl, load_file
class InputData(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text, label=None):
self.guid = guid
self.text = text
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, token_type_ids, attention_mask, label_id):
"""
        :param input_ids: ids of the tokens in the vocabulary
        :param token_type_ids: segment ids distinguishing the two sentences (all 0 for the first, all 1 for the second)
        :param attention_mask: marks which tokens take part in self-attention (1) and which are padding (0)
        :param label_id: ids of the labels
"""
self.input_ids = input_ids
self.token_type_ids = token_type_ids
self.attention_mask = attention_mask
self.label_id = label_id
class NERDataset(Dataset):
def __init__(self, config, tokenizer, mode="train"):
# text: a list of words, all text from the training dataset
super(NERDataset, self).__init__()
self.config = config
self.tokenizer = tokenizer
if mode == "train":
self.file_path = config.train_file
elif mode == "test":
self.file_path = config.test_file
elif mode == "eval":
self.file_path = config.dev_file
else:
            raise ValueError("mode must be one of 'train', 'eval' or 'test'")
self.tdt_data = self.get_data()
self.len = len(self.tdt_data)
def __len__(self):
return self.len
def __getitem__(self, idx):
"""
        Preprocess one example of the specified dataset and wrap it up:
            tdt_data: [InputData(guid=index, text=text, label=label)]
            feature: BatchEncoding(input_ids=input_ids,
                                   token_type_ids=token_type_ids,
                                   attention_mask=attention_mask,
                                   label_id=label_ids)
            data_f: the fully processed dataset, TensorDataset(all_input_ids, all_token_type_ids, all_attention_mask, all_label_ids)
"""
label_map = {label: i for i, label in enumerate(self.config.label_list)}
max_seq_length = self.config.max_seq_length
data = self.tdt_data[idx]
data_text_list = data.text.split(" ")
data_label_list = data.label.split(" ")
assert len(data_text_list) == len(data_label_list)
features = self.tokenizer(''.join(data_text_list), padding='max_length', max_length=max_seq_length, truncation=True)
        label_ids = [label_map[label] for label in data_label_list]
        # keep the labels aligned with the truncated input: reserve two positions for [CLS] and [SEP]
        label_ids = label_ids[:max_seq_length - 2]
        label_ids = [label_map["<START>"]] + label_ids + [label_map["<END>"]]
        while len(label_ids) < max_seq_length:
            label_ids.append(-1)
        features.data['label_ids'] = label_ids
return features
def read_file(self):
        with open(self.file_path, "r", encoding="utf-8") as f:
lines, words, labels = [], [], []
for line in f.readlines():
contends = line.strip()
tokens = line.strip().split()
if len(tokens) == 2:
words.append(tokens[0])
labels.append(tokens[1])
else:
if len(contends) == 0 and len(words) > 0:
label, word = [], []
for l, w in zip(labels, words):
if len(l) > 0 and len(w) > 0:
label.append(l)
word.append(w)
lines.append([' '.join(label), ' '.join(word)])
words, labels = [], []
return lines
def get_data(self):
        '''Preprocess the raw file and return the wrapped examples'''
lines = self.read_file()
tdt_data = []
for i, line in enumerate(lines):
guid = str(i)
text = line[1]
word_piece = self.word_piece_bool(text)
if word_piece:
continue
label = line[0]
tdt_data.append(InputData(guid=guid, text=text, label=label))
return tdt_data
def word_piece_bool(self, text):
word_piece = False
data_text_list = text.split(' ')
for i, word in enumerate(data_text_list):
            # guard against WordPiece splits (they should not occur with single Chinese characters)
token = self.tokenizer.tokenize(word)
            # a single character kept as one token means no WordPiece split
if len(token) != 1:
word_piece = True
return word_piece
    def convert_data_to_features(self, tdt_data):
"""
        Convert the input examples into features.
        For example:
guid: 0
tokens: [CLS] 王 辉 生 前 驾 驶 机 械 洒 药 消 毒 9 0 后 王 辉 , 2 0 1 0 年 1 2 月 参 军 , 2 0 1 5 年 1 2 月 退 伍 后 , 先 是 应 聘 当 辅 警 , 后 来 在 父 亲 成 立 的 扶 风 恒 盛 科 [SEP]
input_ids: 101 4374 6778 4495 1184 7730 7724 3322 3462 3818 5790 3867 3681 130 121 1400 4374 6778 8024 123 121 122 121 2399 122 123 3299 1346 1092 8024 123 121 122 126 2399 122 123 3299 6842 824 1400 8024 1044 3221 2418 5470 2496 6774 6356 8024 1400 3341 1762 4266 779 2768 4989 4638 2820 7599 2608 4670 4906 102
token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
label_ids: 2 5 3 2 2 2 2 2 2 2 2 2 2 4 11 11 5 3 2 4 11 11 11 11 11 11 11 2 2 2 4 11 11 11 11 11 11 11 2 2 2 2 2 2 2 2 2 0 14 2 2 2 2 2 2 2 2 2 12 7 7 7 7 2
"""
label_map = {label: i for i, label in enumerate(self.config.label_list)}
max_seq_length = self.config.max_seq_length
features = []
for data in tdt_data:
data_text_list = data.text.split(" ")
data_label_list = data.label.split(" ")
assert len(data_text_list) == len(data_label_list)
tokens, labels, ori_tokens = [], [], []
word_piece = False
for i, word in enumerate(data_text_list):
                # guard against WordPiece splits (they should not occur here)
token = self.tokenizer.tokenize(word)
tokens.extend(token)
label = data_label_list[i]
ori_tokens.append(word)
                # a single character kept as one token means no WordPiece split
if len(token) == 1:
labels.append(label)
else:
word_piece = True
if word_piece:
logging.info("Error tokens!!! skip this lines, the content is: %s" % " ".join(data_text_list))
continue
assert len(tokens) == len(ori_tokens)
# feature = self.tokenizer(''.join(tokens), padding='max_length', max_length=max_seq_length, truncation=True)
# label_ids = [label_map[label] for label in labels]
# label_ids = [label_map["<START>"]] + label_ids + [label_map["<END>"]]
# while len(label_ids) < max_seq_length:
# label_ids.append(-1)
# feature.data['label_ids'] = label_ids
# features.append(feature)
if len(tokens) >= max_seq_length - 1:
                # -2 because the sequence needs the leading [CLS] and trailing [SEP] markers
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
label_ids = [label_map[label] for label in labels]
new_tokens = ["[CLS]"] + tokens + ["[SEP]"]
input_ids = self.tokenizer.convert_tokens_to_ids(new_tokens)
token_type_ids = [0] * len(input_ids)
attention_mask = [1] * len(input_ids)
label_ids = [label_map["<START>"]] + label_ids + [label_map["<END>"]]
while len(input_ids) < max_seq_length:
input_ids.append(0)
attention_mask.append(0)
token_type_ids.append(0)
label_ids.append(0)
features.append(InputFeatures(input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
label_id=label_ids))
return features
import logging
from collections import Counter
class Metrics(object):
"""用于评价模型,计算每个标签的精确率,召回率,F1分数"""
def __init__(self, golden_tags, predict_tags, remove_O=False):
        # flatten the per-sentence tag lists [[t1, t2], [t3, t4], ...] --> [t1, t2, t3, t4, ...]
self.golden_tags = self.flatten_lists(golden_tags)
self.predict_tags = self.flatten_lists(predict_tags)
        if remove_O:  # drop the O tags and only evaluate entity tags
self._remove_Otags()
        # helper variables for the calculations below
self.tagset = set(self.golden_tags)
self.correct_tags_number = self.count_correct_tags()
self.predict_tags_counter = Counter(self.predict_tags)
self.golden_tags_counter = Counter(self.golden_tags)
self.precision_scores = self.cal_precision()
self.recall_scores = self.cal_recall()
self.f1_scores = self.cal_f1()
def flatten_lists(self, lists):
flatten_list = []
for l in lists:
if type(l) == list:
flatten_list += l
else:
flatten_list.append(l)
return flatten_list
def cal_precision(self):
precision_scores = {}
for tag in self.tagset:
            precision_scores[tag] = self.correct_tags_number.get(tag, 0) / \
                                    (self.predict_tags_counter[tag] + 1e-10)  # tiny constant avoids division by zero for tags that are never predicted
return precision_scores
def cal_recall(self):
recall_scores = {}
for tag in self.tagset:
recall_scores[tag] = self.correct_tags_number.get(tag, 0) / \
self.golden_tags_counter[tag]
return recall_scores
def cal_f1(self):
f1_scores = {}
for tag in self.tagset:
p, r = self.precision_scores[tag], self.recall_scores[tag]
            f1_scores[tag] = 2*p*r / (p+r+1e-10)  # add a tiny constant to avoid a zero denominator
return f1_scores
def report_scores(self):
"""将结果用表格的形式打印出来,像这个样子:
precision recall f1-score support
B-LOC 0.775 0.757 0.766 1084
I-LOC 0.601 0.631 0.616 325
B-MISC 0.698 0.499 0.582 339
I-MISC 0.644 0.567 0.603 557
B-ORG 0.795 0.801 0.798 1400
I-ORG 0.831 0.773 0.801 1104
B-PER 0.812 0.876 0.843 735
I-PER 0.873 0.931 0.901 634
avg/total 0.779 0.764 0.770 6178
"""
        # print the header row
header_format = '{:>9s} {:>9} {:>9} {:>9} {:>9}'
header = ['precision', 'recall', 'f1-score', 'support']
logging.info(header_format.format('', *header))
        # print precision, recall and f1 for every tag
row_format = '{:>9s} {:>9.4f} {:>9.4f} {:>9.4f} {:>9}'
for tag in self.tagset:
logging.info(row_format.format(
tag,
self.precision_scores[tag],
self.recall_scores[tag],
self.f1_scores[tag],
self.golden_tags_counter[tag]
))
        # compute and print the weighted averages
avg_metrics = self.cal_avg_metrics()
logging.info(row_format.format(
'avg/total',
avg_metrics['precision'],
avg_metrics['recall'],
avg_metrics['f1_score'],
len(self.golden_tags)
))
def count_correct_tags(self):
"""计算每种标签预测正确的个数(对应精确率、召回率计算公式上的tp),用于后面精确率以及召回率的计算"""
correct_dict = {}
for gold_tag, predict_tag in zip(self.golden_tags, self.predict_tags):
if gold_tag == predict_tag:
if gold_tag not in correct_dict:
correct_dict[gold_tag] = 1
else:
correct_dict[gold_tag] += 1
return correct_dict
def cal_avg_metrics(self):
avg_metrics = {}
total = len(self.golden_tags)
avg_metrics['precision'] = 0.
avg_metrics['recall'] = 0.
avg_metrics['f1_score'] = 0.
for tag in self.tagset:
size = self.golden_tags_counter[tag]
avg_metrics['precision'] += self.precision_scores[tag] * size
avg_metrics['recall'] += self.recall_scores[tag] * size
avg_metrics['f1_score'] += self.f1_scores[tag] * size
for metric in avg_metrics.keys():
avg_metrics[metric] /= total
return avg_metrics
def _remove_Otags(self):
length = len(self.golden_tags)
O_tag_indices = [i for i in range(length) if self.golden_tags[i] == 'O']
self.golden_tags = [tag for i, tag in enumerate(self.golden_tags) if i not in O_tag_indices]
self.predict_tags = [tag for i, tag in enumerate(self.predict_tags) if i not in O_tag_indices]
logging.info("原总标记数为{},移除了{}个O标记,占比{:.2f}%".format(
length,
len(O_tag_indices),
len(O_tag_indices) / length * 100
))
def report_confusion_matrix(self):
"""计算混淆矩阵"""
logging.info("Confusion Matrix:")
tag_list = list(self.tagset)
        # initialise the confusion matrix: matrix[i][j] counts how often tag i was predicted as tag j
tags_size = len(tag_list)
matrix = []
for i in range(tags_size):
matrix.append([0] * tags_size)
for golden_tag, predict_tag in zip(self.golden_tags, self.predict_tags):
try:
row = tag_list.index(golden_tag)
col = tag_list.index(predict_tag)
matrix[row][col] += 1
            except ValueError:  # a few tags appear in predict_tags but never in golden_tags; skip them
continue
row_format_ = '{:>7} ' * (tags_size+1)
logging.info(row_format_.format("", *tag_list))
for i, row in enumerate(matrix):
logging.info(row_format_.format(tag_list[i], *row))
import torch
from torch.utils.tensorboard import SummaryWriter
from config import Config
from utils import *
from trainer import Bert_Bilstm_Crf
def main():
config = Config()
set_logger(config)
writer = SummaryWriter(log_dir=os.path.join(config.output_path, "visual"), comment="ner")
if config.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(config.gradient_accumulation_steps))
use_gpu = torch.cuda.is_available() and config.use_gpu
device = torch.device('cuda' if use_gpu else 'cpu')
config.device = device
n_gpu = torch.cuda.device_count()
logging.info(f"available device: {device},count_gpu: {n_gpu}")
config.label_list = get_labels(config)
label2id = {label: i for i, label in enumerate(config.label_list)}
id2label = {i: label for label, i in label2id.items()}
logging.info("loading label2id and id2label dictionary successful!")
    # training and testing of the Bert_Bilstm_Crf model
trainer_bbc = Bert_Bilstm_Crf(config, device, use_gpu, n_gpu, writer, id2label)
# trainer_bbc.train()
trainer_bbc.test()
if __name__ == '__main__':
main()
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel
from torchcrf import CRF
class BERT_BiLSTM_CRF(BertPreTrainedModel):
def __init__(self, config, need_birnn=False, rnn_dim=128):
super(BERT_BiLSTM_CRF, self).__init__(config)
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
        out_dim = config.hidden_size  # 768 for bert-base
        self.need_birnn = need_birnn  # set unconditionally so forward()/predict() can check it even without a BiLSTM
        if need_birnn:
self.birnn = nn.LSTM(input_size=config.hidden_size, hidden_size=rnn_dim, num_layers=1, bidirectional=True,
batch_first=True)
out_dim = rnn_dim * 2
self.hidden2tag = nn.Linear(in_features=out_dim, out_features=config.num_labels)
self.crf = CRF(num_tags=config.num_labels, batch_first=True)
def forward(self, input_ids, tags, token_type_ids=None, attention_mask=None):
"""
        :param input_ids: torch.Size([batch_size, seq_len]), token ids of the input examples
        :param token_type_ids: torch.Size([batch_size, seq_len]), segment ids (an example may contain two sentences)
        :param attention_mask: torch.Size([batch_size, seq_len]), marks which tokens take part in self-attention
        :param tags: torch.Size([batch_size, seq_len]), gold label ids used for the CRF loss
        :return: the negative CRF log-likelihood of the gold tag sequence (training loss)
"""
outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
sequence_output = outputs[0] # torch.Size([batch_size,seq_len,hidden_size])
if self.need_birnn:
            sequence_output, _ = self.birnn(sequence_output)  # (batch_size, seq_len, 2 * rnn_dim) since batch_first=True
sequence_output = self.dropout(sequence_output)
        emissions = self.hidden2tag(sequence_output)  # (batch_size, seq_len, num_labels)
loss = -1 * self.crf(emissions, tags, mask=attention_mask.byte())
return loss
def predict(self, input_ids, token_type_ids=None, attention_mask=None):
outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
sequence_output = outputs[0]
if self.need_birnn:
sequence_output, _ = self.birnn(sequence_output)
sequence_output = self.dropout(sequence_output)
emissions = self.hidden2tag(sequence_output)
return self.crf.decode(emissions, attention_mask.byte())
tqdm==4.62.2
torch==1.8.2+cu102
transformers==4.11.3
torchcrf==1.1.0
import torch
from tqdm import tqdm, trange
from torch.utils.data import DataLoader, SequentialSampler
from transformers import AdamW, get_linear_schedule_with_warmup, BertTokenizer, BertConfig
from utils import *
from dataloader import NERDataset
from models import BERT_BiLSTM_CRF
from evaluator import Metrics
class Bert_Bilstm_Crf():
def __init__(self, config, device, use_gpu, n_gpu, writer, id2label):
self.config = config
self.device = device
self.use_gpu = use_gpu
self.n_gpu = n_gpu
self.writer = writer
self.id2label = id2label
self.tokenizer = BertTokenizer.from_pretrained(config.model_name_or_path,
do_lower_case=config.do_lower_case)
bert_config = BertConfig.from_pretrained(config.model_name_or_path, num_labels=len(config.label_list))
self.model = BERT_BiLSTM_CRF.from_pretrained(config.model_name_or_path, config=bert_config,
need_birnn=config.need_birnn, rnn_dim=config.rnn_dim)
self.model.to(device)
logging.info("loading tokenizer、bert_config and bert_bilstm_crf model successful!")
def train(self):
if self.use_gpu and self.n_gpu > 1:
self.model = torch.nn.DataParallel(self.model)
logging.info("starting load train data and data_loader...")
dataset = NERDataset(self.config, self.tokenizer, mode='train')
dataloader = DataLoader(dataset, self.config.batch_size, shuffle=True)
logging.info("loading train data_set and data_loader successful!")
        # initialise the parameter optimizer
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01},
{'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.learning_rate, eps=self.config.adam_epsilon)
        # initialise the learning-rate scheduler
t_total = len(dataloader) // self.config.gradient_accumulation_steps * self.config.num_train_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.config.warmup_steps,
num_training_steps=t_total)
logging.info("loading AdamW optimizer、Warmup LinearSchedule and calculate optimizer parameter successful!")
logging.info("====================== Running training ======================")
logging.info(
f"Num Examples: {len(dataset)}, Num Batch Step: {len(dataloader)}, "
f"Num Epochs: {self.config.num_train_epochs}, Num scheduler steps:{t_total}")
        # enable BatchNormalization and Dropout (training mode)
self.model.train()
global_step, tr_loss, logging_loss, best_f1 = 0, 0.0, 0.0, 0.0
for epoch in range(int(self.config.num_train_epochs)):
# model.train()
for batch, batch_data in enumerate(tqdm(dataloader, desc="Train_DataLoader")):
# input_ids = torch.tensor(batch_data['input_ids'], dtype=torch.long)
# token_type_ids = torch.tensor(batch_data['token_type_ids'], dtype=torch.long)
# attention_mask = torch.tensor(batch_data['attention_mask'], dtype=torch.long)
# label_ids = torch.tensor(batch_data['label_ids'], dtype=torch.long)
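                # the default DataLoader collate turns the BatchEncoding dict into per-position tensors of shape
                # (batch_size,); stacking and transposing restores tensors of shape (batch_size, seq_len)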
batch_data = tuple(torch.stack(batch_data[k]).T.to(self.device) for k in batch_data.keys())
input_ids, token_type_ids, attention_mask, label_ids = batch_data
outputs = self.model(input_ids, label_ids, token_type_ids, attention_mask)
loss = outputs
if self.use_gpu and self.n_gpu > 1:
loss = loss.mean()
if self.config.gradient_accumulation_steps > 1:
loss = loss / self.config.gradient_accumulation_steps
logging.info(f"Epoch: {epoch}/{int(self.config.num_train_epochs)}\tBatch: {batch}/{len(dataloader)}\tLoss:{loss}")
                # back-propagation
loss.backward()
tr_loss += loss.item()
                # total number of parameter updates, matching t_total above
if (batch + 1) % self.config.gradient_accumulation_steps == 0:
                    # update the parameters
optimizer.step()
scheduler.step()
                    # reset the gradients
self.model.zero_grad()
global_step += 1
if self.config.logging_steps > 0 and global_step % self.config.logging_steps == 0:
tr_loss_avg = (tr_loss - logging_loss) / self.config.logging_steps
self.writer.add_scalar("Train/loss", tr_loss_avg, global_step)
logging_loss = tr_loss
if self.config.do_eval:
logging.info("====================== Running Eval ======================")
eval_data = NERDataset(self.config, self.tokenizer, mode="eval")
avg_metrics, cal_indicators, eval_sens = self.evaluate(
self.config, self.tokenizer, eval_data, self.model, self.id2label, self.device, tqdm_desc="Eval_DataLoader")
f1_score = avg_metrics['f1_score']
self.writer.add_scalar("Eval/precision", avg_metrics['precision'], epoch)
self.writer.add_scalar("Eval/recall", avg_metrics['recall'], epoch)
self.writer.add_scalar("Eval/f1_score", avg_metrics['f1_score'], epoch)
# save the best performs model
if f1_score > best_f1:
logging.info(f"******** the best f1 is {f1_score}, save model !!! ********")
best_f1 = f1_score
# Take care of distributed/parallel training
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
model_to_save.save_pretrained(self.config.trained_model_path)
self.tokenizer.save_pretrained(self.config.trained_model_path)
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
model_to_save.save_pretrained(os.path.join(self.config.trained_model_path, 'checkpoints'))
self.tokenizer.save_pretrained(os.path.join(self.config.trained_model_path, 'checkpoints'))
            # # (if config.do_eval=False, comment out the checkpoint-saving steps below)
            # # when the dataset is too large and training is split into stages/sessions, save checkpoints every ckpts_epoch epochs
# if (epoch + 1) % self.config.ckpts_epoch == 0:
# model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
# model_to_save.save_pretrained(os.path.join(self.config.trained_model_path, 'checkpoints'))
# self.tokenizer.save_pretrained(os.path.join(self.config.trained_model_path, 'checkpoints'))
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
model_to_save.save_pretrained(os.path.join(self.config.trained_model_path, 'checkpoints'))
self.tokenizer.save_pretrained(os.path.join(self.config.trained_model_path, 'checkpoints'))
# torch.save(self.config, os.path.join(self.config.trained_model_path, 'training_config.bin'))
# torch.save(self.model, os.path.join(self.config.trained_model_path, 'ner_model.ckpt'))
# logging.info("training_args.bin and ner_model.ckpt save successful!")
self.writer.close()
logging.info("NER model training successful!!!")
@staticmethod
def evaluate(config, tokenizer, dataset, model, id2label, device, tqdm_desc):
sampler = SequentialSampler(dataset)
data_loader = DataLoader(dataset, sampler=sampler, batch_size=config.batch_size)
if isinstance(model, torch.nn.DataParallel):
model = model.module
model.eval()
        id2label[-1] = 'NULL'  # temporary entry for the -1 padding label, added for decoding
ori_tokens = [tokenizer.decode(tdt['input_ids']).split(" ") for tdt in dataset]
ori_labels = [[id2label[idx] for idx in tdt['label_ids']] for tdt in dataset]
pred_labels = []
for b_i, batch_data in enumerate(tqdm(data_loader, desc=tqdm_desc)):
batch_data = tuple(torch.stack(batch_data[k]).T.to(device) for k in batch_data.keys())
input_ids, token_type_ids, attention_mask, label_ids = batch_data
with torch.no_grad():
logits = model.predict(input_ids, token_type_ids, attention_mask)
for logit in logits:
pred_labels.append([id2label[idx] for idx in logit])
assert len(pred_labels) == len(ori_tokens) == len(ori_labels)
eval_sens = []
for ori_token, ori_label, pred_label in zip(ori_tokens, ori_labels, pred_labels):
sen_tll = []
for ot, ol, pl in zip(ori_token, ori_label, pred_label):
if ot in ["[CLS]", "[SEP]", "[PAD]"]:
continue
sen_tll.append((ot, ol, pl))
eval_sens.append(sen_tll)
golden_tags = [[ttl[1] for ttl in sen] for sen in eval_sens]
predict_tags = [[ttl[2] for ttl in sen] for sen in eval_sens]
cal_indicators = Metrics(golden_tags, predict_tags)
avg_metrics = cal_indicators.cal_avg_metrics() # avg_metrics['precision'], avg_metrics['recall'], avg_metrics['f1_score']
return avg_metrics, cal_indicators, eval_sens
def test(self):
logging.info("====================== Running test ======================")
dataset = NERDataset(self.config, self.tokenizer, mode='test')
avg_metrics, cal_indicators, eval_sens = self.evaluate(
self.config, self.tokenizer, dataset, self.model, self.id2label, self.device, tqdm_desc="Test_DataLoader")
cal_indicators.report_scores() # avg_metrics['precision'], avg_metrics['recall'], avg_metrics['f1_score']
cal_indicators.report_confusion_matrix()
        # write the test results to disk
with open(os.path.join(self.config.output_path, "token_labels_test.txt"), "w", encoding="utf-8") as f:
for sen in eval_sens:
for ttl in sen:
f.write(f"{ttl[0]}\t{ttl[1]}\t{ttl[2]}\n")
f.write("\n")
# sampler = SequentialSampler(dataset)
# data_loader = DataLoader(dataset, sampler=sampler, batch_size=self.config.batch_size)
# self.model.eval()
#
# id2label = self.id2label
        # id2label[-1] = 'NULL'  # temporary entry for the -1 padding label, added for decoding
# ori_tokens = [self.tokenizer.decode(tdt['input_ids']).split(" ") for tdt in dataset]
# ori_labels = [[id2label[idx] for idx in tdt['label_ids']] for tdt in dataset]
# pred_labels = []
#
# for b_i, batch_data in enumerate(tqdm(data_loader, desc="Test_DataLoader")):
# batch_data = tuple(torch.stack(batch_data[k]).T.to(self.device) for k in batch_data.keys())
# input_ids, token_type_ids, attention_mask, label_ids = batch_data
#
# with torch.no_grad():
# logits = self.model.predict(input_ids, token_type_ids, attention_mask)
#
# for logit in logits:
# pred_label = []
# for idx in logit:
# pred_label.append(id2label[idx])
# pred_labels.append(pred_label)
#
# assert len(pred_labels) == len(ori_tokens) == len(ori_labels)
# eval_sens = []
# for ori_token, ori_label, pred_label in zip(ori_tokens, ori_labels, pred_labels):
# sen_tll = []
# for ot, ol, pl in zip(ori_token, ori_label, pred_label):
# if ot in ["[CLS]", "[SEP]", "[PAD]"]:
# continue
# sen_tll.append((ot, ol, pl))
# eval_sens.append(sen_tll)
#
# golden_tags = [[ttl[1] for ttl in sen] for sen in eval_sens]
# predict_tags = [[ttl[2] for ttl in sen] for sen in eval_sens]
# cal_indicators = Metrics(golden_tags, predict_tags)
# avg_metrics = cal_indicators.cal_avg_metrics()
import os
import pickle
import logging
def set_logger(config):
if not os.path.exists(config.log_path):
os.mkdir(config.log_path)
logging.basicConfig(
format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.INFO,
datefmt='%Y-%m-%d %H:%M:%S',
filename=os.path.join(config.log_path, '{}.log'.format(config.model)),
filemode='a'
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)
def load_pkl(fp):
"""加载pkl文件"""
with open(fp, 'rb') as f:
data = pickle.load(f)
return data
def save_pkl(data, fp):
"""保存pkl文件,数据序列化"""
with open(fp, 'wb') as f:
pickle.dump(data, f)
def load_file(fp: str, sep: str = None):
"""
    Read a file.
    If sep is None, read line by line and return the list of lines: [xxx, xxx, xxx, ...]
    If sep is given, split each line on it and return a list of lists: [[xxx, xxx], [xxx, xxx], ...]
"""
with open(fp, "r", encoding="utf-8") as f:
lines = f.readlines()
if sep:
return [line.strip().split(sep) for line in lines]
else:
return lines
def get_labels(config):
"""读取训练数据获取标签"""
label_pkl_path = os.path.join(config.base_path, "data/label_list.pkl")
if os.path.exists(label_pkl_path):
logging.info(f"loading labels info from {os.path.join(config.base_path, 'data')}")
labels = load_pkl(label_pkl_path)
else:
logging.info(f"loading labels info from train file and dump in {os.path.join(config.base_path, 'data')}")
tokens_list = load_file(config.train_file, sep=' ')
labels = list(set([tokens[1] for tokens in tokens_list if len(tokens) == 2]))
        # add the start and end markers
labels.extend(['<START>', '<END>'])
save_pkl(labels, label_pkl_path)
return labels