Commit ccee4fa1 authored by misite_J

readme updated

Parent 334bbabf
@@ -18,7 +18,7 @@
# Introduction
Chinese NER models implemented in PyTorch. Four models are planned: BERT_BiLSTM_CRF, BiLSTM_CRF, CRF, and HMM; BERT_BiLSTM_CRF is implemented so far.
@@ -36,14 +36,23 @@ torchcrf==1.1.0
# Directory Structure
```python
NER_ZH
├── ckpts
    ├── bert-base-chinese    # pretrained BERT model
    ├── bert_bilstm_crf      # self-trained bert_bilstm_crf model
    ├── ...
├── data                     # datasets
    ├── test.txt
    ├── train.txt
    ├── ...
├── logs                     # training logs
    ├── bert_bilstm_crf.log  # bert_bilstm_crf training/test log
├── output                   # predictions, visualization results, etc.
├── config.py
├── dataloader.py
├── evaluator.py
├── main.py
├── models.py
├── trainer.py
├── utils.py
```
@@ -54,8 +63,6 @@ Word2Vec
## BERT_BiLSTM_CRF
```python
"""
1. BERT:
@@ -121,27 +128,59 @@ Word2Vec
"""
```
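As an orientation aid, here is a minimal, self-contained sketch of a BERT + BiLSTM + CRF tagger. It is not the repository's `models.py`; it assumes the Hugging Face `transformers` `BertModel` and a `pytorch-crf`-style `CRF` layer (the CRF API of other `torchcrf` releases may differ), and all names and hyperparameters are illustrative.

```python
# Illustrative sketch of a BERT + BiLSTM + CRF tagger (not the repo's models.py).
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF  # assumes the pytorch-crf interface

class BertBilstmCrf(nn.Module):
    def __init__(self, num_tags, bert_path='ckpts/bert-base-chinese',
                 lstm_hidden=128, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_path)      # contextual token embeddings
        self.lstm = nn.LSTM(self.bert.config.hidden_size, lstm_hidden,
                            batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.emission = nn.Linear(2 * lstm_hidden, num_tags)  # per-token tag scores
        self.crf = CRF(num_tags, batch_first=True)            # models tag transitions

    def forward(self, input_ids, attention_mask, tags=None):
        hidden = self.bert(input_ids, attention_mask=attention_mask)[0]
        lstm_out, _ = self.lstm(hidden)
        emissions = self.emission(self.dropout(lstm_out))
        mask = attention_mask.bool()
        if tags is not None:
            # training: negative log-likelihood of the golden tag sequence
            return -self.crf(emissions, tags, mask=mask, reduction='mean')
        # inference: best tag path per sentence (list of tag-id lists)
        return self.crf.decode(emissions, mask=mask)
```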
## Model Evaluation
- **Precision (P), recall (R), and F1 score**
  $P=\frac{TP}{TP+FP}$
  $R=\frac{TP}{TP+FN}$
  $F1=\frac{2PR}{P+R}$

- ```python
  '''
  Per-label precision, recall, and F1 score are computed by 'evaluator.py'; the output looks like:
               precision    recall  f1-score   support
          O       0.9999    0.9999    0.9999    150935
      I-ORG       0.9984    0.9991    0.9988      5640
      I-LOC       0.9968    0.9970    0.9969      4370
      B-ORG       0.9955    0.9985    0.9970      1327
      I-PER       1.0000    0.9995    0.9997      3845
      B-LOC       0.9990    0.9962    0.9976      2871
      B-PER       0.9995    1.0000    0.9997      1972
  avg/total       0.9997    0.9997    0.9997    170960
  '''
  ```
- Because the "O" tag accounts for a very large share of the labels, metrics are computed in two ways: either directly over all tags, or with the "O" class removed before computing them. Switch between the two by setting `self.remove_O = False/True` in `config.py` (a minimal sketch of the computation is given after this list).
- **Confusion matrix**
- ```python
  '''
  Confusion Matrix:
              O   I-ORG   I-LOC   B-ORG   I-PER   B-LOC   B-PER
  O      150921       6       7       1       0       0       0
  I-ORG       1    5635       0       4       0       0       0
  I-LOC       8       2    4357       0       0       3       0
  B-ORG       1       1       0    1325       0       0       0
  I-PER       1       0       0       0    3843       0       1
  B-LOC       3       0       7       1       0    2860       0
  B-PER       0       0       0       0       0       0    1972
  '''
  ```
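The per-tag metrics and the confusion matrix above can be reproduced with a short self-contained sketch. This is not the repository's `evaluator.py`; the function names and the exact `remove_O` handling are illustrative assumptions.

```python
# Illustrative sketch, not the repo's evaluator.py: per-tag precision/recall/F1
# and a confusion matrix computed from flat golden/predicted tag lists.
from collections import Counter, defaultdict

def per_tag_metrics(golden, predict, remove_O=False):
    pairs = list(zip(golden, predict))
    if remove_O:
        # one possible interpretation: drop positions whose golden tag is 'O'
        pairs = [(g, p) for g, p in pairs if g != 'O']
    tp = Counter(g for g, p in pairs if g == p)        # correct predictions per tag
    gold = Counter(g for g, _ in pairs)                # TP + FN (support)
    pred = Counter(p for _, p in pairs)                # TP + FP
    report = {}
    for tag in sorted(set(gold) | set(pred)):
        p = tp[tag] / pred[tag] if pred[tag] else 0.0  # P = TP / (TP + FP)
        r = tp[tag] / gold[tag] if gold[tag] else 0.0  # R = TP / (TP + FN)
        f1 = 2 * p * r / (p + r) if p + r else 0.0     # F1 = 2PR / (P + R)
        report[tag] = {'precision': p, 'recall': r, 'f1_score': f1, 'support': gold[tag]}
    return report

def confusion_matrix(golden, predict):
    # rows: golden tag, columns: predicted tag
    matrix = defaultdict(Counter)
    for g, p in zip(golden, predict):
        matrix[g][p] += 1
    return matrix

# usage with flattened tag sequences
gold = ['B-PER', 'I-PER', 'O', 'B-LOC', 'O']
pred = ['B-PER', 'I-PER', 'O', 'B-ORG', 'O']
print(per_tag_metrics(gold, pred, remove_O=True))
print(confusion_matrix(gold, pred))
```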
# How to Run
Set the parameters in `config.py` as needed, then run `main.py`.
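For orientation, a hypothetical sketch of what that flow might look like, using the `Config` class from `config.py` and the trainer class visible in the diff below; the constructor and method names are assumptions, not taken from the actual `main.py`.

```python
# Hypothetical sketch of the main.py flow; names other than Config (config.py)
# and Bert_Bilstm_Crf (seen in the trainer diff below) are assumptions.
from config import Config
from trainer import Bert_Bilstm_Crf

config = Config()
config.model = 'bert_bilstm_crf'  # one of the options listed in config.py
config.checkpoints = True         # False when using only the pretrained model
config.remove_O = False           # True: exclude the 'O' tag from the metrics

if config.model == 'bert_bilstm_crf':
    trainer = Bert_Bilstm_Crf(config)  # hypothetical constructor signature
    trainer.train()                    # hypothetical method name
    trainer.test()                     # hypothetical method name
```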
# References
1. [bert_bilstm_crf_ner_pytorch](https://gitee.com/chenzhouwy/bert_bilstm_crf_ner_pytorch/tree/master)
2. [named_entity_recognition](https://github.com/luopeixiang/named_entity_recognition)
@@ -18,8 +18,8 @@ class Config(object):
self.label_list = []
self.use_gpu = True
self.device = "cuda"
self.checkpoints = True  # set to False when using the pretrained model
self.model = 'bert_bilstm_crf'  # options: ['bert_bilstm_crf','bilstm_crf','bilstm','crf','hmm']
# input dataset, log, and output directories
self.train_file = os.path.join(self.base_path, 'data/train.txt')
@@ -46,3 +46,4 @@ class Config(object):
self.adam_epsilon = 1e-8
self.warmup_steps = 0
self.logging_steps = 50
self.remove_O = False
import os
import torch
import logging
from tqdm import tqdm, trange
from torch.utils.data import DataLoader, SequentialSampler
@@ -173,7 +175,7 @@ class Bert_Bilstm_Crf():
golden_tags = [[ttl[1] for ttl in sen] for sen in eval_sens]
predict_tags = [[ttl[2] for ttl in sen] for sen in eval_sens]
cal_indicators = Metrics(golden_tags, predict_tags, remove_O=config.remove_O)
avg_metrics = cal_indicators.cal_avg_metrics()  # avg_metrics['precision'], avg_metrics['recall'], avg_metrics['f1_score']
return avg_metrics, cal_indicators, eval_sens
@@ -229,7 +231,7 @@ class Bert_Bilstm_Crf():
#
# golden_tags = [[ttl[1] for ttl in sen] for sen in eval_sens]
# predict_tags = [[ttl[2] for ttl in sen] for sen in eval_sens]
# cal_indicators = Metrics(golden_tags, predict_tags, remove_O=self.config.remove_O)
# avg_metrics = cal_indicators.cal_avg_metrics()
...