diff --git a/README.md b/README.md index 0d2b70c2d0b1118dbb0ee44482ae58ba3ce15b54..6c89ecde582430d89626f8b5ea7d6a7e2ec9b739 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ # 简介 - +使用PyTorch实现中文NER模型,拟提供BERT_BiLSTM_CRF、BiLSTM_CRF、CRF和HMM四种模型,目前实现BERT_BiLSTM_CRF。 @@ -36,14 +36,23 @@ torchcrf==1.1.0 # 目录结构 ```python -Word2Vec -├── Data # 数据集 -│ ├── en.txt -│ ├── zh.txt -├── log # 训练日志 -├── model # 保存模型 +NER_ZH +├── ckpts +│ ├── bert-base-chinese # 预训练模型BERT +│ ├── bert_bilstm_crf # 自己训练的bert_bilstm_crf模型 +│ ├── ... +├── data # 数据集 +│ ├── test.txt +│ ├── train.txt +│ ├── ... +├── logs # 训练日志 +│ ├── bert_bilstm_crf.log # bert_bilstm_crf训练测试日志 +├── output # 预测结果、可视化结果等 +├── config.py ├── dataloader.py -├── model.py +├── evaluator.py +├── main.py +├── models.py ├── trainer.py ├── utils.py ``` @@ -54,8 +63,6 @@ Word2Vec ## BERT_BiLSTM_CRF - - ```python """ 1. BERT: @@ -121,27 +128,59 @@ Word2Vec """ ``` - +## 模型评估 + +- **精确率P、召回率R以及F1值** + + $P=\frac{TP}{TP+FP}$ + + $R=\frac{TP}{TP+FN}$ + + $F1=\frac{2PR}{P+R}$ + + - ```python + ''' + 通过'evaluator.py'计算每个标签的精确率、召回率和F1分数,输出如下格式: + precision recall f1-score support + O 0.9999 0.9999 0.9999 150935 + I-ORG 0.9984 0.9991 0.9988 5640 + I-LOC 0.9968 0.9970 0.9969 4370 + B-ORG 0.9955 0.9985 0.9970 1327 + I-PER 1.0000 0.9995 0.9997 3845 + B-LOC 0.9990 0.9962 0.9976 2871 + B-PER 0.9995 1.0000 0.9997 1972 + avg/total 0.9997 0.9997 0.9997 170960 + ''' + ``` + + - 由于标签中 “O”占非常大的比例,因此在计算指标时,采用两种方式:一是直接计算所有的指标,二是去掉“O”这个类别后再计算所有指标。修改`config.py`中的`self.remove_O = False/True`实现不同计算方式。 + +- **混淆矩阵** + + - ```python + ''' + Confusion Matrix: + O I-ORG I-LOC B-ORG I-PER B-LOC B-PER + O 150921 6 7 1 0 0 0 + I-ORG 1 5635 0 4 0 0 0 + I-LOC 8 2 4357 0 0 3 0 + B-ORG 1 1 0 1325 0 0 0 + I-PER 1 0 0 0 3843 0 1 + B-LOC 3 0 7 1 0 2860 0 + B-PER 0 0 0 0 0 0 1972 + ''' + ``` + + # 运行方式 -`run.py`文件内设定以下参数后,运行该py文件即可。 - -```python -language = 'zh' -neg_sample = True # 是否负采样 -embed_dim = 300 -C = 3 # 窗口大小 -K = 15 # 负采样大小 -num_epochs = 100 
-batch_size = 32 -learning_rate = 0.025 -``` +根据需要自行设置`config.py`文件参数,然后运行`main.py`文件即可。 # 参考 1. [bert_bilstm_crf_ner_pytorch](https://gitee.com/chenzhouwy/bert_bilstm_crf_ner_pytorch/tree/master) -2. +2. [named_entity_recognition](https://github.com/luopeixiang/named_entity_recognition) diff --git a/config.py b/config.py index abf216f9dfa0db2d8995cb796898c4e5b4c50e65..dc78272606bc3b2416f35586644d0eb7dbdfab4c 100644 --- a/config.py +++ b/config.py @@ -18,8 +18,8 @@ class Config(object): self.label_list = [] self.use_gpu = True self.device = "cuda" - self.checkpoints = True - self.model = 'bert_bilstm_crf' # 可选['bert_bilstm_crf','hmm','bilstm_crf] + self.checkpoints = True # 使用预训练模型时设置为False + self.model = 'bert_bilstm_crf' # 可选['bert_bilstm_crf','bilstm_crf','bilstm','crf','hmm'] # 输入数据集、日志、输出目录 self.train_file = os.path.join(self.base_path, 'data/train.txt') @@ -46,3 +46,4 @@ class Config(object): self.adam_epsilon = 1e-8 self.warmup_steps = 0 self.logging_steps = 50 + self.remove_O = False diff --git a/trainer.py b/trainer.py index 63a97bb7b50dbc7686e6723e875935f7472d2a5f..2cae99a46f00c85ca7116d39a8b260ce3dfc1e89 100644 --- a/trainer.py +++ b/trainer.py @@ -1,4 +1,6 @@ +import os import torch +import logging from tqdm import tqdm, trange from torch.utils.data import DataLoader, SequentialSampler @@ -173,7 +175,7 @@ class Bert_Bilstm_Crf(): golden_tags = [[ttl[1] for ttl in sen] for sen in eval_sens] predict_tags = [[ttl[2] for ttl in sen] for sen in eval_sens] - cal_indicators = Metrics(golden_tags, predict_tags) + cal_indicators = Metrics(golden_tags, predict_tags, remove_O=config.remove_O) avg_metrics = cal_indicators.cal_avg_metrics() # avg_metrics['precision'], avg_metrics['recall'], avg_metrics['f1_score'] return avg_metrics, cal_indicators, eval_sens @@ -229,7 +231,7 @@ class Bert_Bilstm_Crf(): # # golden_tags = [[ttl[1] for ttl in sen] for sen in eval_sens] # predict_tags = [[ttl[2] for ttl in sen] for sen in eval_sens] - # cal_indicators = 
Metrics(golden_tags, predict_tags) + # cal_indicators = Metrics(golden_tags, predict_tags, remove_O=self.config.remove_O) # avg_metrics = cal_indicators.cal_avg_metrics()