Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
4eb48457
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4eb48457
编写于
1月 28, 2021
作者:
K
kinghuin
提交者:
GitHub
1月 28, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Cherry-pick] fix crf bug in paddlenlp (#5242)
[Cherry-pick] fix crf bug in paddlenlp
上级
6f30ec2a
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
90 addition
and
82 deletion
+90
-82
PaddleNLP/examples/lexical_analysis/data.py
PaddleNLP/examples/lexical_analysis/data.py
+2
-2
PaddleNLP/examples/lexical_analysis/eval.py
PaddleNLP/examples/lexical_analysis/eval.py
+1
-1
PaddleNLP/examples/lexical_analysis/model.py
PaddleNLP/examples/lexical_analysis/model.py
+8
-4
PaddleNLP/examples/lexical_analysis/predict.py
PaddleNLP/examples/lexical_analysis/predict.py
+2
-2
PaddleNLP/examples/lexical_analysis/train.py
PaddleNLP/examples/lexical_analysis/train.py
+3
-6
PaddleNLP/examples/named_entity_recognition/express_ner/run_bigru_crf.py
...les/named_entity_recognition/express_ner/run_bigru_crf.py
+1
-1
PaddleNLP/paddlenlp/layers/crf.py
PaddleNLP/paddlenlp/layers/crf.py
+67
-60
PaddleNLP/paddlenlp/metrics/chunk.py
PaddleNLP/paddlenlp/metrics/chunk.py
+6
-6
未找到文件。
PaddleNLP/examples/lexical_analysis/data.py
浏览文件 @
4eb48457
...
...
@@ -161,11 +161,11 @@ def parse_lac_result(words, preds, lengths, word_vocab, label_vocab):
for
sent_index
in
range
(
len
(
lengths
)):
sent
=
[
id2word_dict
[
index
]
for
index
in
words
[
sent_index
][:
lengths
[
sent_index
]
-
1
]
for
index
in
words
[
sent_index
][:
lengths
[
sent_index
]]
]
tags
=
[
id2label_dict
[
index
]
for
index
in
preds
[
sent_index
][:
lengths
[
sent_index
]
-
1
]
for
index
in
preds
[
sent_index
][:
lengths
[
sent_index
]]
]
sent_out
=
[]
...
...
PaddleNLP/examples/lexical_analysis/eval.py
浏览文件 @
4eb48457
...
...
@@ -56,7 +56,7 @@ def evaluate(args):
dataset
=
test_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
drop_last
=
Tru
e
)
drop_last
=
Fals
e
)
test_loader
=
paddle
.
io
.
DataLoader
(
dataset
=
test_dataset
,
batch_sampler
=
test_sampler
,
...
...
PaddleNLP/examples/lexical_analysis/model.py
浏览文件 @
4eb48457
...
...
@@ -39,7 +39,8 @@ class BiGruCrf(nn.Layer):
vocab_size
,
num_labels
,
emb_lr
=
2.0
,
crf_lr
=
0.2
):
crf_lr
=
0.2
,
with_start_stop_tag
=
True
):
super
(
BiGruCrf
,
self
).
__init__
()
self
.
word_emb_dim
=
word_emb_dim
self
.
vocab_size
=
vocab_size
...
...
@@ -73,14 +74,17 @@ class BiGruCrf(nn.Layer):
self
.
fc
=
nn
.
Linear
(
in_features
=
self
.
hidden_size
*
2
,
out_features
=
self
.
num_labels
+
2
,
out_features
=
self
.
num_labels
+
2
\
if
with_start_stop_tag
else
self
.
num_labels
,
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
Uniform
(
low
=-
self
.
init_bound
,
high
=
self
.
init_bound
),
regularizer
=
paddle
.
regularizer
.
L2Decay
(
coeff
=
1e-4
)))
self
.
crf
=
LinearChainCrf
(
self
.
num_labels
,
self
.
crf_lr
)
self
.
viterbi_decoder
=
ViterbiDecoder
(
self
.
crf
.
transitions
)
self
.
crf
=
LinearChainCrf
(
self
.
num_labels
,
self
.
crf_lr
,
with_start_stop_tag
)
self
.
viterbi_decoder
=
ViterbiDecoder
(
self
.
crf
.
transitions
,
with_start_stop_tag
)
def
forward
(
self
,
inputs
,
lengths
):
word_embed
=
self
.
word_embedding
(
inputs
)
...
...
PaddleNLP/examples/lexical_analysis/predict.py
浏览文件 @
4eb48457
...
...
@@ -55,7 +55,7 @@ def infer(args):
dataset
=
infer_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
drop_last
=
Tru
e
)
drop_last
=
Fals
e
)
infer_loader
=
paddle
.
io
.
DataLoader
(
dataset
=
infer_dataset
,
batch_sampler
=
infer_sampler
,
...
...
@@ -75,7 +75,7 @@ def infer(args):
test_data
=
infer_loader
,
batch_size
=
args
.
batch_size
)
# Post-processing the lexical analysis results
lengths
=
np
.
array
(
lengths
).
reshape
([
-
1
])
lengths
=
np
.
array
(
[
l
for
lens
in
lengths
for
l
in
lens
]
).
reshape
([
-
1
])
preds
=
np
.
array
(
[
pred
for
batch_pred
in
crf_decodes
for
pred
in
batch_pred
])
...
...
PaddleNLP/examples/lexical_analysis/train.py
浏览文件 @
4eb48457
...
...
@@ -77,7 +77,7 @@ def train(args):
dataset
=
test_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
drop_last
=
Tru
e
)
drop_last
=
Fals
e
)
test_loader
=
paddle
.
io
.
DataLoader
(
dataset
=
test_dataset
,
batch_sampler
=
test_sampler
,
...
...
@@ -93,7 +93,7 @@ def train(args):
# Prepare optimizer, loss and metric evaluator
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
args
.
base_lr
,
parameters
=
model
.
parameters
())
crf_loss
=
LinearChainCrfLoss
(
network
.
crf
.
transitions
)
crf_loss
=
LinearChainCrfLoss
(
network
.
crf
)
chunk_evaluator
=
ChunkEvaluator
(
label_list
=
train_dataset
.
label_vocab
.
keys
(),
suffix
=
True
)
model
.
prepare
(
optimizer
,
crf_loss
,
chunk_evaluator
)
...
...
@@ -101,7 +101,6 @@ def train(args):
model
.
load
(
args
.
init_checkpoint
)
# Start training
callback
=
paddle
.
callbacks
.
ProgBarLogger
(
log_freq
=
10
,
verbose
=
3
)
model
.
fit
(
train_data
=
train_loader
,
eval_data
=
test_loader
,
batch_size
=
args
.
batch_size
,
...
...
@@ -110,9 +109,7 @@ def train(args):
log_freq
=
10
,
save_dir
=
args
.
model_save_dir
,
save_freq
=
1
,
drop_last
=
True
,
shuffle
=
True
,
callbacks
=
callback
)
shuffle
=
True
)
if
__name__
==
"__main__"
:
...
...
PaddleNLP/examples/named_entity_recognition/express_ner/run_bigru_crf.py
浏览文件 @
4eb48457
...
...
@@ -164,7 +164,7 @@ if __name__ == '__main__':
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.001
,
parameters
=
model
.
parameters
())
crf_loss
=
LinearChainCrfLoss
(
network
.
crf
.
transitions
)
crf_loss
=
LinearChainCrfLoss
(
network
.
crf
)
chunk_evaluator
=
ChunkEvaluator
(
label_list
=
train_ds
.
label_vocab
.
keys
(),
suffix
=
True
)
model
.
prepare
(
optimizer
,
crf_loss
,
chunk_evaluator
)
...
...
PaddleNLP/paddlenlp/layers/crf.py
浏览文件 @
4eb48457
...
...
@@ -58,7 +58,8 @@ class LinearChainCrf(nn.Layer):
def
_initialize_alpha
(
self
,
batch_size
):
# alpha accumulates the path value for each possible next tag
if
self
.
_initial_alpha
is
None
:
if
self
.
_initial_alpha
is
None
or
batch_size
>
self
.
_initial_alpha
.
shape
[
0
]:
# Initialized by a small value.
initial_alpha
=
paddle
.
full
(
(
batch_size
,
self
.
num_tags
-
1
),
...
...
@@ -69,7 +70,7 @@ class LinearChainCrf(nn.Layer):
(
batch_size
,
1
),
dtype
=
'float32'
,
fill_value
=
0.
)
self
.
_initial_alpha
=
paddle
.
concat
(
[
initial_alpha
,
alpha_start
],
axis
=
1
)
return
self
.
_initial_alpha
return
self
.
_initial_alpha
[:
batch_size
,
:]
def
forward
(
self
,
inputs
,
lengths
):
"""
...
...
@@ -99,27 +100,20 @@ class LinearChainCrf(nn.Layer):
all_alpha
=
[]
if
self
.
with_start_stop_tag
:
alpha
=
self
.
_initialize_alpha
(
batch_size
).
detach
()
for
i
,
input_exp
in
enumerate
(
inputs_t_exp
):
# input_exp: batch_size, num_tags, num_tags
# alpha_exp: batch_size, num_tags, num_tags
alpha
=
self
.
_initialize_alpha
(
batch_size
)
for
i
,
input_exp
in
enumerate
(
inputs_t_exp
):
# input_exp: batch_size, num_tags, num_tags
# alpha_exp: batch_size, num_tags, num_tags
if
i
==
0
and
not
self
.
with_start_stop_tag
:
mat
=
input_exp
else
:
alpha_exp
=
alpha
.
unsqueeze
(
1
).
expand
(
[
batch_size
,
n_labels
,
n_labels
])
# F(n) = logsumexp(F(n-1) + p(y_n) + T(y_{n-1}, y_n))
mat
=
input_exp
+
trans_exp
+
alpha_exp
alpha
=
paddle
.
logsumexp
(
mat
,
2
)
all_alpha
.
append
(
alpha
)
else
:
for
i
,
input_exp
in
enumerate
(
inputs_t_exp
):
if
i
==
0
:
alpha
=
inputs
.
transpose
([
1
,
0
,
2
])[
0
]
else
:
alpha_exp
=
alpha
.
unsqueeze
(
1
).
expand
(
[
batch_size
,
n_labels
,
n_labels
])
# F(n) = logsumexp(F(n-1) + p(y_n) + T(y_{n-1}, y_n))
mat
=
input_exp
+
trans_exp
+
alpha_exp
alpha
=
paddle
.
logsumexp
(
mat
,
2
)
all_alpha
.
append
(
alpha
)
alpha
=
paddle
.
logsumexp
(
mat
,
2
)
all_alpha
.
append
(
alpha
)
# Get the valid alpha
all_alpha
=
paddle
.
stack
(
all_alpha
).
transpose
([
1
,
0
,
2
])
...
...
@@ -166,8 +160,7 @@ class LinearChainCrf(nn.Layer):
sequence_mask
(
self
.
_get_batch_seq_index
(
batch_size
,
seq_len
),
lengths
),
'float32'
)
if
self
.
with_start_stop_tag
:
mask
=
mask
[:,
:
seq_len
]
mask
=
mask
[:,
:
seq_len
]
mask_scores
=
scores
*
mask
score
=
paddle
.
sum
(
mask_scores
,
1
)
...
...
@@ -191,6 +184,10 @@ class LinearChainCrf(nn.Layer):
fill_value
=
self
.
stop_idx
)
labels_ext
=
(
1
-
mask
)
*
pad_stop
+
mask
*
labels_ext
else
:
mask
=
paddle
.
cast
(
sequence_mask
(
self
.
_get_batch_seq_index
(
batch_size
,
seq_len
),
lengths
),
'int32'
)
labels_ext
=
labels
start_tag_indices
=
labels_ext
[:,
:
-
1
]
...
...
@@ -212,7 +209,8 @@ class LinearChainCrf(nn.Layer):
return
score
def
_get_start_stop_tensor
(
self
,
batch_size
):
if
self
.
_start_tensor
is
None
or
self
.
_stop_tensor
is
None
:
if
self
.
_start_tensor
is
None
or
self
.
_stop_tensor
is
None
or
batch_size
!=
self
.
_start_tensor
.
shape
[
0
]:
self
.
_start_tensor
=
paddle
.
full
(
(
batch_size
,
1
),
dtype
=
'int64'
,
fill_value
=
self
.
start_idx
)
self
.
_stop_tensor
=
paddle
.
full
(
...
...
@@ -220,7 +218,8 @@ class LinearChainCrf(nn.Layer):
return
self
.
_start_tensor
,
self
.
_stop_tensor
def
_get_batch_index
(
self
,
batch_size
):
if
self
.
_batch_index
is
None
:
if
self
.
_batch_index
is
None
or
batch_size
!=
self
.
_batch_index
.
shape
[
0
]:
self
.
_batch_index
=
paddle
.
arange
(
end
=
batch_size
,
dtype
=
"int64"
)
return
self
.
_batch_index
...
...
@@ -231,36 +230,39 @@ class LinearChainCrf(nn.Layer):
def
_get_batch_seq_index
(
self
,
batch_size
,
length
):
if
self
.
_batch_seq_index
is
None
or
length
+
2
>
self
.
_batch_seq_index
.
shape
[
1
]:
1
]
or
batch_size
>
self
.
_batch_seq_index
.
shape
[
0
]
:
self
.
_batch_seq_index
=
paddle
.
cumsum
(
paddle
.
ones
([
batch_size
,
length
+
2
],
"int64"
),
axis
=
1
)
-
1
if
self
.
with_start_stop_tag
:
return
self
.
_batch_seq_index
[:,
:
length
+
2
]
return
self
.
_batch_seq_index
[:
batch_size
,
:
length
+
2
]
else
:
return
self
.
_batch_seq_index
[:,
:
length
]
return
self
.
_batch_seq_index
[:
batch_size
,
:
length
]
class
LinearChainCrfLoss
(
LinearChainCrf
):
class
LinearChainCrfLoss
(
nn
.
Layer
):
"""The negative log-likelihood for linear chain Conditional Random Field (CRF).
let $$ Z(x) =
\\
sum_{y'}exp(score(x,y')) $$, means the sum of all path scores,
then we have $$ loss = -logp(y|x) = -log(exp(score(x,y))/Z(x)) = -score(x,y) + logZ(x) $$
Args:
transitions (Tensor): The transition matrix
.
crf (LinearChainCrf): The LinearChainCrf network
.
"""
def
__init__
(
self
,
transitions
):
num_labels
=
transitions
.
shape
[
0
]
-
2
super
(
LinearChainCrfLoss
,
self
).
__init__
(
num_labels
)
self
.
transitions
.
set_value
(
transitions
)
def
__init__
(
self
,
crf
):
super
(
LinearChainCrfLoss
,
self
).
__init__
()
self
.
crf
=
crf
if
isinstance
(
crf
,
paddle
.
fluid
.
framework
.
ParamBase
):
raise
ValueError
(
"From paddlenlp >= 2.0.0b4, the first param of LinearChainCrfLoss shoule be a LinearChainCrf object. For input parameter 'crf.transitions', you can remove '.transitions' to 'crf'"
)
def
forward
(
self
,
inputs
,
lengths
,
predictions
,
labels
):
# Note: When close to convergence, the loss could be a small negative number. This may be caused by underflow when calculating exp in logsumexp.
# We add relu here to avoid a negative loss. In theory, the crf loss must be greater than or equal to 0, so relu will not affect it.
return
nn
.
functional
.
relu
(
s
uper
(
LinearChainCrfLoss
,
self
).
forward
(
inputs
,
lengths
)
-
self
.
gold_score
(
inputs
,
labels
,
lengths
))
s
elf
.
crf
.
forward
(
inputs
,
lengths
)
-
self
.
crf
.
gold_score
(
inputs
,
labels
,
lengths
))
class
ViterbiDecoder
(
nn
.
Layer
):
...
...
@@ -278,7 +280,9 @@ class ViterbiDecoder(nn.Layer):
self
.
transitions
=
transitions
self
.
with_start_stop_tag
=
with_start_stop_tag
# If start and stop tags are considered, -1 should be START and -2 should be STOP.
self
.
stop_idx
=
-
2
if
with_start_stop_tag
:
self
.
start_idx
=
-
1
self
.
stop_idx
=
-
2
self
.
num_tags
=
transitions
.
shape
[
0
]
self
.
_initial_alpha
=
None
...
...
@@ -287,7 +291,8 @@ class ViterbiDecoder(nn.Layer):
def
_initialize_alpha
(
self
,
batch_size
):
# alpha accumulates the path value for each possible next tag
if
self
.
_initial_alpha
is
None
:
if
self
.
_initial_alpha
is
None
or
batch_size
>
self
.
_initial_alpha
.
shape
[
0
]:
# Initialized by a small value.
initial_alpha
=
paddle
.
full
(
(
batch_size
,
self
.
num_tags
-
1
),
...
...
@@ -298,7 +303,7 @@ class ViterbiDecoder(nn.Layer):
(
batch_size
,
1
),
dtype
=
'float32'
,
fill_value
=
0.
)
self
.
_initial_alpha
=
paddle
.
concat
(
[
initial_alpha
,
alpha_start
],
axis
=
1
)
return
self
.
_initial_alpha
return
self
.
_initial_alpha
[:
batch_size
,
:]
def
forward
(
self
,
inputs
,
lengths
):
"""
...
...
@@ -313,32 +318,34 @@ class ViterbiDecoder(nn.Layer):
"""
batch_size
,
seq_len
,
n_labels
=
inputs
.
shape
inputs_t
=
inputs
.
transpose
([
1
,
0
,
2
])
tr
n
_exp
=
self
.
transitions
.
unsqueeze
(
0
).
expand
(
tr
ans
_exp
=
self
.
transitions
.
unsqueeze
(
0
).
expand
(
[
batch_size
,
n_labels
,
n_labels
])
all_alpha
=
[]
historys
=
[]
alpha
=
self
.
_initialize_alpha
(
batch_size
).
detach
(
)
if
self
.
with_start_stop_tag
else
None
# inputs_t: seq_len, batch_size, n_labels
# logit: batch_size, n_labels
if
self
.
with_start_stop_tag
:
alpha
=
self
.
_initialize_alpha
(
batch_size
)
else
:
alpha
=
paddle
.
zeros
((
batch_size
,
self
.
num_tags
),
dtype
=
'float32'
)
for
i
,
logit
in
enumerate
(
inputs_t
):
if
alpha
is
not
None
:
alpha_exp
=
alpha
.
unsqueeze
(
1
).
expand
(
[
batch_size
,
n_labels
,
n_labels
])
# alpha_trn_sum: batch_size, n_labels, n_labels
alpha_trn_sum
=
alpha_exp
+
trn_exp
# alpha_max: batch_size, n_labels
# We don't include the emission scores here because the max does not depend on them (we add them in below)
alpha_max
=
alpha_trn_sum
.
max
(
2
)
alpha_exp
=
alpha
.
unsqueeze
(
1
).
expand
(
[
batch_size
,
n_labels
,
n_labels
])
# alpha_trn_sum: batch_size, n_labels, n_labels
alpha_trn_sum
=
alpha_exp
+
trans_exp
# alpha_max: batch_size, n_labels
# We don't include the emission scores here because the max does not depend on them (we add them in below)
alpha_max
=
alpha_trn_sum
.
max
(
2
)
if
i
==
0
:
# If self.with_start_stop_tag, the first antecedent tag must be START, so drop it.
# Otherwise, the first label has no antecedent tag, so skip it.
pass
else
:
alpha_argmax
=
alpha_trn_sum
.
argmax
(
2
)
historys
.
append
(
alpha_argmax
)
# Now add in the emission scores
alpha
=
alpha_max
+
logit
else
:
alpha
=
logit
# Now add the emission scores
alpha
=
alpha_max
+
logit
all_alpha
.
append
(
alpha
)
# Get the valid alpha
...
...
@@ -358,6 +365,7 @@ class ViterbiDecoder(nn.Layer):
historys
=
paddle
.
stack
(
historys
).
numpy
()
lengths_np
=
lengths
.
numpy
()
batch_path
=
[]
max_len
=
0
for
batch_id
in
range
(
batch_size
):
best_last_tag
=
last_ids
[
batch_id
]
path
=
[
best_last_tag
]
...
...
@@ -365,17 +373,16 @@ class ViterbiDecoder(nn.Layer):
# hist: batch_size, n_labels
best_last_tag
=
hist
[
batch_id
][
best_last_tag
]
path
.
append
(
best_last_tag
)
if
self
.
with_start_stop_tag
:
# the first one is start
start
=
path
.
pop
()
path
.
reverse
()
max_len
=
max
(
max_len
,
len
(
path
))
# Pad to the max sequence length, so that the ChunkEvaluator can compute it
path
+=
[
0
]
*
(
seq_len
-
len
(
path
))
batch_path
.
append
(
path
)
batch_path
=
[
path
+
[
0
]
*
(
max_len
-
len
(
path
))
for
path
in
batch_path
]
batch_path
=
paddle
.
to_tensor
(
batch_path
)
return
scores
,
batch_path
def
_get_batch_index
(
self
,
batch_size
):
if
self
.
_batch_index
is
None
:
if
self
.
_batch_index
is
None
or
batch_size
!=
self
.
_batch_index
.
shape
[
0
]:
self
.
_batch_index
=
paddle
.
arange
(
end
=
batch_size
,
dtype
=
"int64"
)
return
self
.
_batch_index
PaddleNLP/paddlenlp/metrics/chunk.py
浏览文件 @
4eb48457
...
...
@@ -112,12 +112,12 @@ class ChunkEvaluator(paddle.metric.Metric):
float: mean precision, recall and f1 score.
"""
precision
=
float
(
self
.
num_correct_chunks
)
/
self
.
num_infer_chunks
if
self
.
num_infer_chunks
else
0
recall
=
float
(
self
.
num_correct_chunks
)
/
self
.
num_label_chunks
if
self
.
num_label_chunks
else
0
f1_score
=
float
(
2
*
precision
*
recall
)
/
(
precision
+
recall
)
if
self
.
num_correct_chunks
else
0
self
.
num_correct_chunks
/
self
.
num_infer_chunks
)
if
self
.
num_infer_chunks
else
0.
recall
=
float
(
self
.
num_correct_chunks
/
self
.
num_label_chunks
)
if
self
.
num_label_chunks
else
0.
f1_score
=
float
(
2
*
precision
*
recall
/
(
precision
+
recall
)
)
if
self
.
num_correct_chunks
else
0.
return
precision
,
recall
,
f1_score
def
reset
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录