Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
余衫马
ner_zh
提交
ccee4fa1
N
ner_zh
项目概览
余衫马
/
ner_zh
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
N
ner_zh
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
ccee4fa1
编写于
11月 16, 2021
作者:
misite_J
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
readme updated
上级
334bbabf
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
70 addition
and
28 deletion
+70
-28
README.md
README.md
+63
-24
config.py
config.py
+3
-2
trainer.py
trainer.py
+4
-2
未找到文件。
README.md
浏览文件 @
ccee4fa1
...
...
@@ -18,7 +18,7 @@
# 简介
使用PyTorch实现中文NER模型,拟提供BERT_BiLSTM_CRF、BiLSTM_CRF、CRF和HMM四种模型,目前实现BERT_BiLSTM_CRF。
...
...
@@ -36,14 +36,23 @@ torchcrf==1.1.0
# 目录结构
```
python
Word2Vec
├──
Data
# 数据集
│
├──
en
.
txt
│
├──
zh
.
txt
├──
log
# 训练日志
├──
model
# 保存模型
NER_ZH
├──
ckpts
│
├──
bert
-
base
-
chinese
# 预训练模型BERT
│
├──
bert_bilstm_crf
# 自己训练的bert_bilstm_crf模型
│
├──
...
├──
data
# 数据集
│
├──
test
.
txt
│
├──
train
.
txt
│
├──
...
├──
logs
# 训练日志
│
├──
bert_bilstm_crf
.
log
# bert_bilstm_crf训练测试日志
├──
output
# 预测结果、可视乎结果等
├──
config
.
py
├──
dataloader
.
py
├──
model
.
py
├──
evaluator
.
py
├──
main
.
py
├──
models
.
py
├──
trainer
.
py
├──
utils
.
py
```
...
...
@@ -54,8 +63,6 @@ Word2Vec
## BERT_BiLSTM_CRF
```
python
"""
1. BERT:
...
...
@@ -121,27 +128,59 @@ Word2Vec
"""
```
## 模型评估
-
**精确率P、召回率R以及F1值**
# 运行方式
$P=
\f
rac{TP}{TP+FP}$
`run.py`
文件内设定以下参数后,运行该py文件即可。
$R=
\f
rac{TP}{TP+FN}$
```
python
language
=
'zh'
neg_sample
=
True
# 是否负采样
embed_dim
=
300
C
=
3
# 窗口大小
K
=
15
# 负采样大小
num_epochs
=
100
batch_size
=
32
learning_rate
=
0.025
```
$F1=
\f
rac{2PR}{P+R}$
-
```python
'''
通过'evaluator.py'计算每个标签的精确率、召回率和F1分数,输出如下格式:
precision recall f1-score support
O 0.9999 0.9999 0.9999 150935
I-ORG 0.9984 0.9991 0.9988 5640
I-LOC 0.9968 0.9970 0.9969 4370
B-ORG 0.9955 0.9985 0.9970 1327
I-PER 1.0000 0.9995 0.9997 3845
B-LOC 0.9990 0.9962 0.9976 2871
B-PER 0.9995 1.0000 0.9997 1972
avg/total 0.9997 0.9997 0.9997 170960
'''
```
-
由于标签中 “O”占非常大的比例,因此在计算指标时,采用两种方式:一是直接计算所有的指标,二是去掉“O”这个类别后再计算所有指标。修改
`config.py`
中的
`self.remove_O = False/True`
实现不同计算方式。
-
**混淆矩阵**
-
```python
'''
Confusion Matrix:
O I-ORG I-LOC B-ORG I-PER B-LOC B-PER
O 150921 6 7 1 0 0 0
I-ORG 1 5635 0 4 0 0 0
I-LOC 8 2 4357 0 0 3 0
B-ORG 1 1 0 1325 0 0 0
I-PER 1 0 0 0 3843 0 1
B-LOC 3 0 7 1 0 2860 0
B-PER 0 0 0 0 0 0 1972
'''
```
# 运行方式
根据需要自行设置
`config.py`
文件参数,然后运行
`main.py`
文件即可。
# 参考
1.
[
bert_bilstm_crf_ner_pytorch
](
https://gitee.com/chenzhouwy/bert_bilstm_crf_ner_pytorch/tree/master
)
2.
2.
[
named_entity_recognition
](
https://github.com/luopeixiang/named_entity_recognition
)
config.py
浏览文件 @
ccee4fa1
...
...
@@ -18,8 +18,8 @@ class Config(object):
self
.
label_list
=
[]
self
.
use_gpu
=
True
self
.
device
=
"cuda"
self
.
checkpoints
=
True
self
.
model
=
'bert_bilstm_crf'
# 可选['bert_bilstm_crf','
hmm','bilstm_crf
]
self
.
checkpoints
=
True
# 使用预训练模型时设置为False
self
.
model
=
'bert_bilstm_crf'
# 可选['bert_bilstm_crf','
bilstm_crf','bilstm','crf','hmm'
]
# 输入数据集、日志、输出目录
self
.
train_file
=
os
.
path
.
join
(
self
.
base_path
,
'data/train.txt'
)
...
...
@@ -46,3 +46,4 @@ class Config(object):
self
.
adam_epsilon
=
1e-8
self
.
warmup_steps
=
0
self
.
logging_steps
=
50
self
.
remove_O
=
False
trainer.py
浏览文件 @
ccee4fa1
import
os
import
torch
import
logging
from
tqdm
import
tqdm
,
trange
from
torch.utils.data
import
DataLoader
,
SequentialSampler
...
...
@@ -173,7 +175,7 @@ class Bert_Bilstm_Crf():
golden_tags
=
[[
ttl
[
1
]
for
ttl
in
sen
]
for
sen
in
eval_sens
]
predict_tags
=
[[
ttl
[
2
]
for
ttl
in
sen
]
for
sen
in
eval_sens
]
cal_indicators
=
Metrics
(
golden_tags
,
predict_tags
)
cal_indicators
=
Metrics
(
golden_tags
,
predict_tags
,
remove_O
=
config
.
remove_O
)
avg_metrics
=
cal_indicators
.
cal_avg_metrics
()
# avg_metrics['precision'], avg_metrics['recall'], avg_metrics['f1_score']
return
avg_metrics
,
cal_indicators
,
eval_sens
...
...
@@ -229,7 +231,7 @@ class Bert_Bilstm_Crf():
#
# golden_tags = [[ttl[1] for ttl in sen] for sen in eval_sens]
# predict_tags = [[ttl[2] for ttl in sen] for sen in eval_sens]
# cal_indicators = Metrics(golden_tags, predict_tags)
# cal_indicators = Metrics(golden_tags, predict_tags
, remove_O=self.config.remove_O
)
# avg_metrics = cal_indicators.cal_avg_metrics()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录