Commit ccee4fa1, authored by misite_J

readme updated

Parent 334bbabf
......@@ -18,7 +18,7 @@
# Introduction
A Chinese NER model implemented in PyTorch. Four models are planned: BERT_BiLSTM_CRF, BiLSTM_CRF, CRF, and HMM. Only BERT_BiLSTM_CRF is implemented so far.
......@@ -36,14 +36,23 @@ torchcrf==1.1.0
# Directory structure
```python
Word2Vec
├── Data # dataset
├── en.txt
├── zh.txt
├── log # training logs
├── model # saved models
NER_ZH
├── ckpts
├── bert-base-chinese # pretrained BERT model
├── bert_bilstm_crf # bert_bilstm_crf model trained in this project
├── ...
├── data # dataset
├── test.txt
├── train.txt
├── ...
├── logs # training logs
├── bert_bilstm_crf.log # bert_bilstm_crf training and test log
├── output # prediction results, visualizations, etc.
├── config.py
├── dataloader.py
├── model.py
├── evaluator.py
├── main.py
├── models.py
├── trainer.py
├── utils.py
```
......@@ -54,8 +63,6 @@ Word2Vec
## BERT_BiLSTM_CRF
```python
"""
1. BERT:
......@@ -121,27 +128,59 @@ Word2Vec
"""
```
## Model evaluation
- **Precision P, recall R, and F1 score**
$P=\frac{TP}{TP+FP}$
$R=\frac{TP}{TP+FN}$
$F1=\frac{2PR}{P+R}$
- ```python
'''
'evaluator.py' computes the precision, recall, and F1 score of each label and prints them in the following format:
            precision    recall  f1-score   support
        O      0.9999    0.9999    0.9999    150935
    I-ORG      0.9984    0.9991    0.9988      5640
    I-LOC      0.9968    0.9970    0.9969      4370
    B-ORG      0.9955    0.9985    0.9970      1327
    I-PER      1.0000    0.9995    0.9997      3845
    B-LOC      0.9990    0.9962    0.9976      2871
    B-PER      0.9995    1.0000    0.9997      1972
avg/total      0.9997    0.9997    0.9997    170960
'''
```
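- Because the "O" tag accounts for a very large share of all labels, the metrics can be computed in two ways: over all labels, or after removing the "O" class. Set `self.remove_O = False/True` in `config.py` to switch between the two.
- **Confusion matrix**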
- ```python
'''
Confusion Matrix:
            O   I-ORG   I-LOC   B-ORG   I-PER   B-LOC   B-PER
    O  150921       6       7       1       0       0       0
I-ORG       1    5635       0       4       0       0       0
I-LOC       8       2    4357       0       0       3       0
B-ORG       1       1       0    1325       0       0       0
I-PER       1       0       0       0    3843       0       1
B-LOC       3       0       7       1       0    2860       0
B-PER       0       0       0       0       0       0    1972
'''
```
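The evaluation described above can be sketched roughly as follows. This is a minimal illustration, not the code in `evaluator.py`: the function names `per_label_metrics` and `confusion_matrix` are hypothetical, and the handling of `remove_O` (dropping every token whose gold tag is "O" before counting) is an assumption about what that option does. In the matrix, rows are gold tags and columns are predicted tags, which matches the sample output above, where each row sums to that label's support.

```python
from collections import Counter


def per_label_metrics(golden_tags, predict_tags, remove_O=False):
    """Token-level precision, recall and F1 per label (hypothetical helper)."""
    # Flatten the per-sentence tag lists into aligned (gold, predicted) pairs.
    pairs = [
        (g, p)
        for gold_sen, pred_sen in zip(golden_tags, predict_tags)
        for g, p in zip(gold_sen, pred_sen)
    ]
    if remove_O:
        # Assumption: "removing O" means dropping tokens whose gold tag is "O".
        pairs = [(g, p) for g, p in pairs if g != "O"]

    gold_count = Counter(g for g, _ in pairs)          # TP + FN per label
    pred_count = Counter(p for _, p in pairs)          # TP + FP per label
    correct = Counter(g for g, p in pairs if g == p)   # TP per label

    metrics = {}
    for label, support in gold_count.items():
        tp = correct[label]
        precision = tp / pred_count[label] if pred_count[label] else 0.0
        recall = tp / support
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics[label] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "support": support,
        }
    return metrics


def confusion_matrix(golden_tags, predict_tags, labels):
    """Rows are gold tags, columns are predicted tags (hypothetical helper)."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for gold_sen, pred_sen in zip(golden_tags, predict_tags):
        for g, p in zip(gold_sen, pred_sen):
            matrix[index[g]][index[p]] += 1
    return matrix


# Toy data, invented for illustration only.
golden = [["B-PER", "I-PER", "O", "B-LOC"], ["O", "B-ORG", "I-ORG", "O"]]
predicted = [["B-PER", "I-PER", "O", "O"], ["O", "B-ORG", "O", "O"]]
labels = ["O", "B-PER", "I-PER", "B-LOC", "B-ORG", "I-ORG"]

all_metrics = per_label_metrics(golden, predicted)
no_o_metrics = per_label_metrics(golden, predicted, remove_O=True)
matrix = confusion_matrix(golden, predicted, labels)

for label, m in all_metrics.items():
    print(f"{label:>6} {m['precision']:.4f} {m['recall']:.4f} {m['f1']:.4f} {m['support']}")
```

With "O" included, a model that predicts "O" too often still scores well on the dominant class, which is why the second mode can give a more informative picture of entity-tag quality.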
# How to run
Set the following parameters in `run.py`, then run that file.
```python
language = 'zh'
neg_sample = True # whether to use negative sampling
embed_dim = 300
C = 3 # window size
K = 15 # number of negative samples
num_epochs = 100
batch_size = 32
learning_rate = 0.025
```
Set the parameters in `config.py` as needed, then run `main.py`.
# References
1. [bert_bilstm_crf_ner_pytorch](https://gitee.com/chenzhouwy/bert_bilstm_crf_ner_pytorch/tree/master)
2.
2. [named_entity_recognition](https://github.com/luopeixiang/named_entity_recognition)
......@@ -18,8 +18,8 @@ class Config(object):
self.label_list = []
self.use_gpu = True
self.device = "cuda"
self.checkpoints = True
self.model = 'bert_bilstm_crf' # options: ['bert_bilstm_crf','hmm','bilstm_crf]
self.checkpoints = True # set to False when using the pretrained model
self.model = 'bert_bilstm_crf' # options: ['bert_bilstm_crf','bilstm_crf','bilstm','crf','hmm']
# input dataset, log, and output directories
self.train_file = os.path.join(self.base_path, 'data/train.txt')
......@@ -46,3 +46,4 @@ class Config(object):
self.adam_epsilon = 1e-8
self.warmup_steps = 0
self.logging_steps = 50
self.remove_O = False
import os
import torch
import logging
from tqdm import tqdm, trange
from torch.utils.data import DataLoader, SequentialSampler
......@@ -173,7 +175,7 @@ class Bert_Bilstm_Crf():
golden_tags = [[ttl[1] for ttl in sen] for sen in eval_sens]
predict_tags = [[ttl[2] for ttl in sen] for sen in eval_sens]
cal_indicators = Metrics(golden_tags, predict_tags)
cal_indicators = Metrics(golden_tags, predict_tags, remove_O=config.remove_O)
avg_metrics = cal_indicators.cal_avg_metrics() # avg_metrics['precision'], avg_metrics['recall'], avg_metrics['f1_score']
return avg_metrics, cal_indicators, eval_sens
......@@ -229,7 +231,7 @@ class Bert_Bilstm_Crf():
#
# golden_tags = [[ttl[1] for ttl in sen] for sen in eval_sens]
# predict_tags = [[ttl[2] for ttl in sen] for sen in eval_sens]
# cal_indicators = Metrics(golden_tags, predict_tags)
# cal_indicators = Metrics(golden_tags, predict_tags, remove_O=self.config.remove_O)
# avg_metrics = cal_indicators.cal_avg_metrics()
......