Unverified commit 4eb48457, authored by kinghuin, committed by GitHub

[Cherry-pick] fix crf bug in paddlenlp (#5242)

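For downstream users, the visible API change in this patch is that LinearChainCrfLoss now takes the LinearChainCrf layer itself rather than its transitions tensor. A minimal before/after sketch (here network stands for any model exposing a .crf sub-layer, as in the LAC examples below; the import path assumes the paddlenlp.layers.crf module shown in this diff):

from paddlenlp.layers.crf import LinearChainCrfLoss

# Before this patch (passing the transition matrix):
#   crf_loss = LinearChainCrfLoss(network.crf.transitions)
# After this patch (passing the LinearChainCrf layer itself):
crf_loss = LinearChainCrfLoss(network.crf)

Passing the raw transitions parameter now raises a ValueError, as the guard added in crf.py below shows.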
Parent 6f30ec2a
@@ -161,11 +161,11 @@ def parse_lac_result(words, preds, lengths, word_vocab, label_vocab):
     for sent_index in range(len(lengths)):
         sent = [
             id2word_dict[index]
-            for index in words[sent_index][:lengths[sent_index] - 1]
+            for index in words[sent_index][:lengths[sent_index]]
         ]
         tags = [
             id2label_dict[index]
-            for index in preds[sent_index][:lengths[sent_index] - 1]
+            for index in preds[sent_index][:lengths[sent_index]]
         ]
         sent_out = []
......
@@ -56,7 +56,7 @@ def evaluate(args):
         dataset=test_dataset,
         batch_size=args.batch_size,
         shuffle=False,
-        drop_last=True)
+        drop_last=False)
     test_loader = paddle.io.DataLoader(
         dataset=test_dataset,
         batch_sampler=test_sampler,
......
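The drop_last change above (and the matching ones in the infer and train scripts below) is the evaluation-correctness half of the fix: with drop_last=True, a final batch smaller than batch_size is silently skipped, so some samples are never evaluated. A minimal sketch with a hypothetical 10-sample dataset:

import numpy as np
import paddle

class ToyDataset(paddle.io.Dataset):
    # Hypothetical 10-sample dataset, only for illustrating drop_last.
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return np.array([idx], dtype='float32')

dataset = ToyDataset()
sampler = paddle.io.BatchSampler(
    dataset=dataset, batch_size=4, shuffle=False, drop_last=False)
loader = paddle.io.DataLoader(dataset=dataset, batch_sampler=sampler)
# 3 batches of sizes 4, 4, 2: all 10 samples are seen.
# With drop_last=True there would be only 2 batches and 8 samples.
print(sum(batch[0].shape[0] for batch in loader))  # 10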
@@ -39,7 +39,8 @@ class BiGruCrf(nn.Layer):
                  vocab_size,
                  num_labels,
                  emb_lr=2.0,
-                 crf_lr=0.2):
+                 crf_lr=0.2,
+                 with_start_stop_tag=True):
         super(BiGruCrf, self).__init__()
         self.word_emb_dim = word_emb_dim
         self.vocab_size = vocab_size
@@ -73,14 +74,17 @@ class BiGruCrf(nn.Layer):
         self.fc = nn.Linear(
             in_features=self.hidden_size * 2,
-            out_features=self.num_labels + 2,
+            out_features=self.num_labels + 2 \
+                if with_start_stop_tag else self.num_labels,
             weight_attr=paddle.ParamAttr(
                 initializer=nn.initializer.Uniform(
                     low=-self.init_bound, high=self.init_bound),
                 regularizer=paddle.regularizer.L2Decay(coeff=1e-4)))
-        self.crf = LinearChainCrf(self.num_labels, self.crf_lr)
-        self.viterbi_decoder = ViterbiDecoder(self.crf.transitions)
+        self.crf = LinearChainCrf(self.num_labels, self.crf_lr,
+                                  with_start_stop_tag)
+        self.viterbi_decoder = ViterbiDecoder(self.crf.transitions,
+                                              with_start_stop_tag)

     def forward(self, inputs, lengths):
         word_embed = self.word_embedding(inputs)
......
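For reference, a sketch of constructing the model with the new flag; the argument values and the word_vocab/label_vocab names are illustrative, not part of this patch:

# with_start_stop_tag=True (the default) reserves two extra bound tags, so the
# output layer has num_labels + 2 features; with False it stays at num_labels.
model = BiGruCrf(
    word_emb_dim=128,             # illustrative value
    hidden_size=128,              # illustrative value
    vocab_size=len(word_vocab),   # word_vocab is assumed to exist
    num_labels=len(label_vocab),  # label_vocab is assumed to exist
    with_start_stop_tag=True)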
@@ -55,7 +55,7 @@ def infer(args):
         dataset=infer_dataset,
         batch_size=args.batch_size,
         shuffle=False,
-        drop_last=True)
+        drop_last=False)
     infer_loader = paddle.io.DataLoader(
         dataset=infer_dataset,
         batch_sampler=infer_sampler,
@@ -75,7 +75,7 @@ def infer(args):
         test_data=infer_loader, batch_size=args.batch_size)

     # Post-processing the lexical analysis results
-    lengths = np.array(lengths).reshape([-1])
+    lengths = np.array([l for lens in lengths for l in lens]).reshape([-1])
     preds = np.array(
         [pred for batch_pred in crf_decodes for pred in batch_pred])
......
@@ -77,7 +77,7 @@ def train(args):
         dataset=test_dataset,
         batch_size=args.batch_size,
         shuffle=False,
-        drop_last=True)
+        drop_last=False)
     test_loader = paddle.io.DataLoader(
         dataset=test_dataset,
         batch_sampler=test_sampler,
@@ -93,7 +93,7 @@ def train(args):
     # Prepare optimizer, loss and metric evaluator
     optimizer = paddle.optimizer.Adam(
         learning_rate=args.base_lr, parameters=model.parameters())
-    crf_loss = LinearChainCrfLoss(network.crf.transitions)
+    crf_loss = LinearChainCrfLoss(network.crf)
     chunk_evaluator = ChunkEvaluator(
         label_list=train_dataset.label_vocab.keys(), suffix=True)
     model.prepare(optimizer, crf_loss, chunk_evaluator)
@@ -101,7 +101,6 @@ def train(args):
         model.load(args.init_checkpoint)

     # Start training
-    callback = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3)
     model.fit(train_data=train_loader,
               eval_data=test_loader,
               batch_size=args.batch_size,
@@ -110,9 +109,7 @@ def train(args):
               log_freq=10,
               save_dir=args.model_save_dir,
               save_freq=1,
-              drop_last=True,
-              shuffle=True,
-              callbacks=callback)
+              shuffle=True)


 if __name__ == "__main__":
......
@@ -164,7 +164,7 @@ if __name__ == '__main__':
     optimizer = paddle.optimizer.Adam(
         learning_rate=0.001, parameters=model.parameters())
-    crf_loss = LinearChainCrfLoss(network.crf.transitions)
+    crf_loss = LinearChainCrfLoss(network.crf)
     chunk_evaluator = ChunkEvaluator(
         label_list=train_ds.label_vocab.keys(), suffix=True)
     model.prepare(optimizer, crf_loss, chunk_evaluator)
......
@@ -58,7 +58,8 @@ class LinearChainCrf(nn.Layer):
     def _initialize_alpha(self, batch_size):
         # alpha accumulates the path value to get the different next tag
-        if self._initial_alpha is None:
+        if self._initial_alpha is None or batch_size > self._initial_alpha.shape[
+                0]:
             # Initialized by a small value.
             initial_alpha = paddle.full(
                 (batch_size, self.num_tags - 1),
@@ -69,7 +70,7 @@ class LinearChainCrf(nn.Layer):
                 (batch_size, 1), dtype='float32', fill_value=0.)
             self._initial_alpha = paddle.concat(
                 [initial_alpha, alpha_start], axis=1)
-        return self._initial_alpha
+        return self._initial_alpha[:batch_size, :]

     def forward(self, inputs, lengths):
         """
@@ -99,20 +100,13 @@ class LinearChainCrf(nn.Layer):
         all_alpha = []
         if self.with_start_stop_tag:
-            alpha = self._initialize_alpha(batch_size).detach()
-            for i, input_exp in enumerate(inputs_t_exp):
-                # input_exp: batch_size, num_tags, num_tags
-                # alpha_exp: batch_size, num_tags, num_tags
-                alpha_exp = alpha.unsqueeze(1).expand(
-                    [batch_size, n_labels, n_labels])
-                # F(n) = logsumexp(F(n-1) + p(y_n) + T(y_{n-1}, y_n))
-                mat = input_exp + trans_exp + alpha_exp
-                alpha = paddle.logsumexp(mat, 2)
-                all_alpha.append(alpha)
-        else:
-            for i, input_exp in enumerate(inputs_t_exp):
-                if i == 0:
-                    alpha = inputs.transpose([1, 0, 2])[0]
+            alpha = self._initialize_alpha(batch_size)
+        for i, input_exp in enumerate(inputs_t_exp):
+            # input_exp: batch_size, num_tags, num_tags
+            # alpha_exp: batch_size, num_tags, num_tags
+            if i == 0 and not self.with_start_stop_tag:
+                mat = input_exp
             else:
                 alpha_exp = alpha.unsqueeze(1).expand(
                     [batch_size, n_labels, n_labels])
@@ -166,7 +160,6 @@ class LinearChainCrf(nn.Layer):
             sequence_mask(
                 self._get_batch_seq_index(batch_size, seq_len), lengths),
             'float32')
-        if self.with_start_stop_tag:
-            mask = mask[:, :seq_len]
+        mask = mask[:, :seq_len]
         mask_scores = scores * mask
@@ -191,6 +184,10 @@ class LinearChainCrf(nn.Layer):
                 fill_value=self.stop_idx)
             labels_ext = (1 - mask) * pad_stop + mask * labels_ext
         else:
+            mask = paddle.cast(
+                sequence_mask(
+                    self._get_batch_seq_index(batch_size, seq_len), lengths),
+                'int32')
             labels_ext = labels
         start_tag_indices = labels_ext[:, :-1]
@@ -212,7 +209,8 @@ class LinearChainCrf(nn.Layer):
         return score

     def _get_start_stop_tensor(self, batch_size):
-        if self._start_tensor is None or self._stop_tensor is None:
+        if self._start_tensor is None or self._stop_tensor is None or batch_size != self._start_tensor.shape[
+                0]:
             self._start_tensor = paddle.full(
                 (batch_size, 1), dtype='int64', fill_value=self.start_idx)
             self._stop_tensor = paddle.full(
@@ -220,7 +218,8 @@ class LinearChainCrf(nn.Layer):
         return self._start_tensor, self._stop_tensor

     def _get_batch_index(self, batch_size):
-        if self._batch_index is None:
+        if self._batch_index is None or batch_size != self._batch_index.shape[
+                0]:
             self._batch_index = paddle.arange(end=batch_size, dtype="int64")
         return self._batch_index
@@ -231,36 +230,39 @@ class LinearChainCrf(nn.Layer):
     def _get_batch_seq_index(self, batch_size, length):
         if self._batch_seq_index is None or length + 2 > self._batch_seq_index.shape[
-                1]:
+                1] or batch_size > self._batch_seq_index.shape[0]:
             self._batch_seq_index = paddle.cumsum(
                 paddle.ones([batch_size, length + 2], "int64"), axis=1) - 1
         if self.with_start_stop_tag:
-            return self._batch_seq_index[:, :length + 2]
+            return self._batch_seq_index[:batch_size, :length + 2]
         else:
-            return self._batch_seq_index[:, :length]
+            return self._batch_seq_index[:batch_size, :length]


-class LinearChainCrfLoss(LinearChainCrf):
+class LinearChainCrfLoss(nn.Layer):
     """The negative log-likelihood for a linear chain Conditional Random Field (CRF).

     Let $$ Z(x) = \\sum_{y'} \\exp(score(x, y')) $$ denote the sum of all path scores;
     then $$ loss = -\\log p(y|x) = -\\log(\\exp(score(x, y)) / Z(x)) = -score(x, y) + \\log Z(x) $$.

     Args:
-        transitions (Tensor): The transition matrix.
+        crf (LinearChainCrf): The LinearChainCrf network.
     """

-    def __init__(self, transitions):
-        num_labels = transitions.shape[0] - 2
-        super(LinearChainCrfLoss, self).__init__(num_labels)
-        self.transitions.set_value(transitions)
+    def __init__(self, crf):
+        super(LinearChainCrfLoss, self).__init__()
+        self.crf = crf
+        if isinstance(crf, paddle.fluid.framework.ParamBase):
+            raise ValueError(
+                "From paddlenlp >= 2.0.0b4, the first param of LinearChainCrfLoss should be a LinearChainCrf object. For the input parameter 'crf.transitions', you can remove '.transitions' and pass 'crf' directly."
+            )

     def forward(self, inputs, lengths, predictions, labels):
         # Note: when close to convergence, the loss can be a small negative number,
         # which may be caused by underflow when computing exp in logsumexp.
         # We add relu to avoid a negative loss. In theory the CRF loss is always
         # greater than or equal to 0, so relu does not affect it.
         return nn.functional.relu(
-            super(LinearChainCrfLoss, self).forward(inputs, lengths) -
-            self.gold_score(inputs, labels, lengths))
+            self.crf.forward(inputs, lengths) - self.crf.gold_score(
+                inputs, labels, lengths))
class ViterbiDecoder(nn.Layer):
@@ -278,6 +280,8 @@ class ViterbiDecoder(nn.Layer):
         self.transitions = transitions
         self.with_start_stop_tag = with_start_stop_tag
-        self.start_idx = -1
-        self.stop_idx = -2
+        # If start and stop tags are considered, -1 should be START and -2 should be STOP.
+        if with_start_stop_tag:
+            self.start_idx = -1
+            self.stop_idx = -2
         self.num_tags = transitions.shape[0]
@@ -287,7 +291,8 @@ class ViterbiDecoder(nn.Layer):
     def _initialize_alpha(self, batch_size):
         # alpha accumulates the path value to get the different next tag
-        if self._initial_alpha is None:
+        if self._initial_alpha is None or batch_size > self._initial_alpha.shape[
+                0]:
             # Initialized by a small value.
             initial_alpha = paddle.full(
                 (batch_size, self.num_tags - 1),
@@ -298,7 +303,7 @@ class ViterbiDecoder(nn.Layer):
                 (batch_size, 1), dtype='float32', fill_value=0.)
             self._initial_alpha = paddle.concat(
                 [initial_alpha, alpha_start], axis=1)
-        return self._initial_alpha
+        return self._initial_alpha[:batch_size, :]

     def forward(self, inputs, lengths):
         """
@@ -313,32 +318,34 @@ class ViterbiDecoder(nn.Layer):
         """
         batch_size, seq_len, n_labels = inputs.shape
         inputs_t = inputs.transpose([1, 0, 2])
-        trn_exp = self.transitions.unsqueeze(0).expand(
+        trans_exp = self.transitions.unsqueeze(0).expand(
             [batch_size, n_labels, n_labels])

         all_alpha = []
         historys = []
-        alpha = self._initialize_alpha(batch_size).detach(
-        ) if self.with_start_stop_tag else None
+        if self.with_start_stop_tag:
+            alpha = self._initialize_alpha(batch_size)
+        else:
+            alpha = paddle.zeros((batch_size, self.num_tags), dtype='float32')

+        # inputs_t: seq_len, batch_size, n_labels
+        # logit: batch_size, n_labels
         for i, logit in enumerate(inputs_t):
-            if alpha is not None:
-                alpha_exp = alpha.unsqueeze(1).expand(
-                    [batch_size, n_labels, n_labels])
-                # alpha_trn_sum: batch_size, n_labels, n_labels
-                alpha_trn_sum = alpha_exp + trn_exp
-                # alpha_max: batch_size, n_labels
-                alpha_max = alpha_trn_sum.max(2)
-                alpha_argmax = alpha_trn_sum.argmax(2)
-                historys.append(alpha_argmax)
-                # Now add in the emission scores
-                alpha = alpha_max + logit
-            else:
-                alpha = logit
+            alpha_exp = alpha.unsqueeze(1).expand(
+                [batch_size, n_labels, n_labels])
+            # alpha_trn_sum: batch_size, n_labels, n_labels
+            alpha_trn_sum = alpha_exp + trans_exp
+            # alpha_max: batch_size, n_labels
+            # We don't include the emission scores here because the max does not depend on them (we add them below)
+            alpha_max = alpha_trn_sum.max(2)
+            if i == 0:
+                # If self.with_start_stop_tag, the first antecedent tag must be START, so drop it;
+                # otherwise the first label has no antecedent tag, so skip it.
+                pass
+            else:
+                alpha_argmax = alpha_trn_sum.argmax(2)
+                historys.append(alpha_argmax)
+            # Now add the emission scores
+            alpha = alpha_max + logit
             all_alpha.append(alpha)

         # Get the valid alpha
@@ -358,6 +365,7 @@ class ViterbiDecoder(nn.Layer):
         historys = paddle.stack(historys).numpy()
         lengths_np = lengths.numpy()
         batch_path = []
+        max_len = 0
         for batch_id in range(batch_size):
             best_last_tag = last_ids[batch_id]
             path = [best_last_tag]
@@ -365,17 +373,16 @@ class ViterbiDecoder(nn.Layer):
                 # hist: batch_size, n_labels
                 best_last_tag = hist[batch_id][best_last_tag]
                 path.append(best_last_tag)
             if self.with_start_stop_tag:
                 # the first one is START, drop it
                 start = path.pop()
             path.reverse()
+            max_len = max(max_len, len(path))
-            # Pad to the max sequence length, so that the ChunkEvaluator can compute it
-            path += [0] * (seq_len - len(path))
             batch_path.append(path)
+        # Pad to the max sequence length, so that the ChunkEvaluator can compute it
+        batch_path = [path + [0] * (max_len - len(path)) for path in batch_path]
         batch_path = paddle.to_tensor(batch_path)
         return scores, batch_path

     def _get_batch_index(self, batch_size):
-        if self._batch_index is None:
+        if self._batch_index is None or batch_size != self._batch_index.shape[
+                0]:
             self._batch_index = paddle.arange(end=batch_size, dtype="int64")
         return self._batch_index
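The comment F(n) = logsumexp(F(n-1) + p(y_n) + T(y_{n-1}, y_n)) in LinearChainCrf.forward above is the standard CRF forward recurrence for computing log Z(x). A minimal single-sequence numpy sketch of the same computation, without batching or START/STOP handling:

import numpy as np

def crf_log_partition(emissions, transitions):
    # emissions: [seq_len, num_tags] unary scores p(y_n)
    # transitions: [num_tags, num_tags] pairwise scores T(y_{n-1}, y_n)
    alpha = emissions[0]  # F(0)
    for n in range(1, len(emissions)):
        # mat[i, j] = F(n-1)[i] + T(i, j) + p(y_n = j)
        mat = alpha[:, None] + transitions + emissions[n][None, :]
        # F(n) = logsumexp over the previous tag i (stabilized by the max)
        m = mat.max(axis=0)
        alpha = m + np.log(np.exp(mat - m).sum(axis=0))
    m = alpha.max()
    return m + np.log(np.exp(alpha - m).sum())  # log Z(x)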
@@ -112,12 +112,12 @@ class ChunkEvaluator(paddle.metric.Metric):
             float: mean precision, recall and f1 score.
         """
         precision = float(
-            self.num_correct_chunks
-        ) / self.num_infer_chunks if self.num_infer_chunks else 0
-        recall = float(self.num_correct_chunks
-                       ) / self.num_label_chunks if self.num_label_chunks else 0
-        f1_score = float(2 * precision * recall) / (
-            precision + recall) if self.num_correct_chunks else 0
+            self.num_correct_chunks /
+            self.num_infer_chunks) if self.num_infer_chunks else 0.
+        recall = float(self.num_correct_chunks /
+                       self.num_label_chunks) if self.num_label_chunks else 0.
+        f1_score = float(2 * precision * recall / (
+            precision + recall)) if self.num_correct_chunks else 0.
         return precision, recall, f1_score

     def reset(self):
......
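As a quick sanity check of the fixed metric arithmetic above (the chunk counts are made up for illustration):

num_correct_chunks, num_infer_chunks, num_label_chunks = 6, 10, 8

precision = float(num_correct_chunks / num_infer_chunks) if num_infer_chunks else 0.
recall = float(num_correct_chunks / num_label_chunks) if num_label_chunks else 0.
f1_score = float(2 * precision * recall / (precision + recall)) if num_correct_chunks else 0.
print(precision, recall, f1_score)  # 0.6 0.75 0.666...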