Unverified commit 4eb48457, authored by kinghuin, committed by GitHub

[Cherry-pick] fix crf bug in paddlenlp (#5242)

Parent 6f30ec2a
@@ -161,11 +161,11 @@ def parse_lac_result(words, preds, lengths, word_vocab, label_vocab):
     for sent_index in range(len(lengths)):
         sent = [
             id2word_dict[index]
-            for index in words[sent_index][:lengths[sent_index] - 1]
+            for index in words[sent_index][:lengths[sent_index]]
         ]
         tags = [
             id2label_dict[index]
-            for index in preds[sent_index][:lengths[sent_index] - 1]
+            for index in preds[sent_index][:lengths[sent_index]]
         ]
         sent_out = []
...
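The fix above is an off-by-one: `lengths` holds the true token counts here, so slicing with `lengths[sent_index] - 1` silently dropped the last word and tag of every sentence. A minimal sketch with hypothetical values:

    # Hypothetical toy data: one sentence with 3 valid tokens.
    words = [[11, 12, 13]]
    lengths = [3]
    assert words[0][:lengths[0] - 1] == [11, 12]      # old slice: last token lost
    assert words[0][:lengths[0]] == [11, 12, 13]      # fixed slice: all tokens kept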
@@ -56,7 +56,7 @@ def evaluate(args):
         dataset=test_dataset,
         batch_size=args.batch_size,
         shuffle=False,
-        drop_last=True)
+        drop_last=False)
     test_loader = paddle.io.DataLoader(
         dataset=test_dataset,
         batch_sampler=test_sampler,
...
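Switching `drop_last` off matters for evaluation: `drop_last=True` silently discards the final incomplete batch, so up to `batch_size - 1` test samples never reach the metric. A small sketch of the effect, assuming a hypothetical 10-sample dataset:

    import paddle

    data = list(range(10))  # hypothetical dataset, batch_size=3
    drop = paddle.io.BatchSampler(data, batch_size=3, shuffle=False, drop_last=True)
    keep = paddle.io.BatchSampler(data, batch_size=3, shuffle=False, drop_last=False)
    print(len(drop))  # 3 batches -> only 9 samples evaluated
    print(len(keep))  # 4 batches -> all 10 samples evaluated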
@@ -39,7 +39,8 @@ class BiGruCrf(nn.Layer):
                  vocab_size,
                  num_labels,
                  emb_lr=2.0,
-                 crf_lr=0.2):
+                 crf_lr=0.2,
+                 with_start_stop_tag=True):
         super(BiGruCrf, self).__init__()
         self.word_emb_dim = word_emb_dim
         self.vocab_size = vocab_size
@@ -73,14 +74,17 @@ class BiGruCrf(nn.Layer):
         self.fc = nn.Linear(
             in_features=self.hidden_size * 2,
-            out_features=self.num_labels + 2,
+            out_features=self.num_labels + 2 \
+                if with_start_stop_tag else self.num_labels,
             weight_attr=paddle.ParamAttr(
                 initializer=nn.initializer.Uniform(
                     low=-self.init_bound, high=self.init_bound),
                 regularizer=paddle.regularizer.L2Decay(coeff=1e-4)))
-        self.crf = LinearChainCrf(self.num_labels, self.crf_lr)
-        self.viterbi_decoder = ViterbiDecoder(self.crf.transitions)
+        self.crf = LinearChainCrf(self.num_labels, self.crf_lr,
+                                  with_start_stop_tag)
+        self.viterbi_decoder = ViterbiDecoder(self.crf.transitions,
+                                              with_start_stop_tag)

     def forward(self, inputs, lengths):
         word_embed = self.word_embedding(inputs)
...
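The new `with_start_stop_tag` flag keeps the emission width consistent with the CRF: with the virtual START/STOP tags the model scores `num_labels + 2` tags, otherwise just `num_labels`. The sizing rule, as a standalone sketch mirroring the constructor above:

    def fc_out_features(num_labels, with_start_stop_tag=True):
        # Two extra tags for the virtual START and STOP positions.
        return num_labels + 2 if with_start_stop_tag else num_labels

    assert fc_out_features(5) == 7
    assert fc_out_features(5, with_start_stop_tag=False) == 5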
@@ -55,7 +55,7 @@ def infer(args):
         dataset=infer_dataset,
         batch_size=args.batch_size,
         shuffle=False,
-        drop_last=True)
+        drop_last=False)
     infer_loader = paddle.io.DataLoader(
         dataset=infer_dataset,
         batch_sampler=infer_sampler,
@@ -75,7 +75,7 @@ def infer(args):
         test_data=infer_loader, batch_size=args.batch_size)

     # Post-processing the lexical analysis results
-    lengths = np.array(lengths).reshape([-1])
+    lengths = np.array([l for lens in lengths for l in lens]).reshape([-1])
     preds = np.array(
         [pred for batch_pred in crf_decodes for pred in batch_pred])
...
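With `drop_last=False`, the last batch can be smaller than the rest, so `lengths` comes back as a ragged list of per-batch sequences; calling `np.array` directly on it no longer yields a flat integer array, hence the explicit flatten. A sketch with hypothetical batch outputs:

    import numpy as np

    lengths = [[5, 7, 6], [4]]  # hypothetical per-batch lengths; last batch is short
    flat = np.array([l for lens in lengths for l in lens]).reshape([-1])
    print(flat)  # [5 7 6 4]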
@@ -77,7 +77,7 @@ def train(args):
         dataset=test_dataset,
         batch_size=args.batch_size,
         shuffle=False,
-        drop_last=True)
+        drop_last=False)
     test_loader = paddle.io.DataLoader(
         dataset=test_dataset,
         batch_sampler=test_sampler,
@@ -93,7 +93,7 @@ def train(args):
     # Prepare optimizer, loss and metric evaluator
     optimizer = paddle.optimizer.Adam(
         learning_rate=args.base_lr, parameters=model.parameters())
-    crf_loss = LinearChainCrfLoss(network.crf.transitions)
+    crf_loss = LinearChainCrfLoss(network.crf)
     chunk_evaluator = ChunkEvaluator(
         label_list=train_dataset.label_vocab.keys(), suffix=True)
     model.prepare(optimizer, crf_loss, chunk_evaluator)
@@ -101,7 +101,6 @@ def train(args):
         model.load(args.init_checkpoint)

     # Start training
-    callback = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3)
     model.fit(train_data=train_loader,
               eval_data=test_loader,
               batch_size=args.batch_size,
@@ -110,9 +109,7 @@ def train(args):
               log_freq=10,
               save_dir=args.model_save_dir,
               save_freq=1,
-              drop_last=True,
-              shuffle=True,
-              callbacks=callback)
+              shuffle=True)

 if __name__ == "__main__":
...
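The `ProgBarLogger` removal is safe because `Model.fit` installs its own progress logging driven by the `log_freq` and `verbose` arguments, so the explicit callback only duplicated output. A sketch of the slimmer call (arguments not shown in the diff, such as `epochs`, are hypothetical):

    model.fit(train_data=train_loader,
              eval_data=test_loader,
              batch_size=args.batch_size,
              epochs=args.epochs,        # hypothetical; elided in the diff above
              log_freq=10,               # built-in progress logging
              save_dir=args.model_save_dir,
              save_freq=1,
              shuffle=True)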
@@ -164,7 +164,7 @@ if __name__ == '__main__':
     optimizer = paddle.optimizer.Adam(
         learning_rate=0.001, parameters=model.parameters())
-    crf_loss = LinearChainCrfLoss(network.crf.transitions)
+    crf_loss = LinearChainCrfLoss(network.crf)
     chunk_evaluator = ChunkEvaluator(
         label_list=train_ds.label_vocab.keys(), suffix=True)
     model.prepare(optimizer, crf_loss, chunk_evaluator)
...
@@ -58,7 +58,8 @@ class LinearChainCrf(nn.Layer):
     def _initialize_alpha(self, batch_size):
         # alpha accumulates the path value to get the different next tag
-        if self._initial_alpha is None:
+        if self._initial_alpha is None or batch_size > self._initial_alpha.shape[
+                0]:
             # Initialized by a small value.
             initial_alpha = paddle.full(
                 (batch_size, self.num_tags - 1),
@@ -69,7 +70,7 @@ class LinearChainCrf(nn.Layer):
                 (batch_size, 1), dtype='float32', fill_value=0.)
             self._initial_alpha = paddle.concat(
                 [initial_alpha, alpha_start], axis=1)
-        return self._initial_alpha
+        return self._initial_alpha[:batch_size, :]
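The cached `_initial_alpha` used to be built once for the first batch size seen; a later, larger batch (now possible, since partial batches are kept) would index past it, while a smaller one returned stale extra rows. The fix grows the cache on demand and slices on return. The pattern as a standalone sketch (the fill value for the "small value" is elided in the diff, so -1e4 below is an assumption):

    import paddle

    class AlphaCache:
        # Minimal sketch of the grow-and-slice caching pattern used above.
        def __init__(self, num_tags):
            self.num_tags = num_tags
            self._initial_alpha = None

        def get(self, batch_size):
            # Rebuild only when the cache is missing or too small.
            if self._initial_alpha is None or batch_size > self._initial_alpha.shape[0]:
                initial_alpha = paddle.full(
                    (batch_size, self.num_tags - 1), dtype='float32',
                    fill_value=-1e4)  # assumed "small value"
                alpha_start = paddle.full(
                    (batch_size, 1), dtype='float32', fill_value=0.)
                self._initial_alpha = paddle.concat(
                    [initial_alpha, alpha_start], axis=1)
            # Slice so smaller batches never see stale rows.
            return self._initial_alpha[:batch_size, :]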
     def forward(self, inputs, lengths):
         """
@@ -99,27 +100,20 @@ class LinearChainCrf(nn.Layer):
         all_alpha = []
         if self.with_start_stop_tag:
-            alpha = self._initialize_alpha(batch_size).detach()
-            for i, input_exp in enumerate(inputs_t_exp):
-                # input_exp: batch_size, num_tags, num_tags
-                # alpha_exp: batch_size, num_tags, num_tags
-                alpha_exp = alpha.unsqueeze(1).expand(
-                    [batch_size, n_labels, n_labels])
-                # F(n) = logsumexp(F(n-1) + p(y_n) + T(y_{n-1}, y_n))
-                mat = input_exp + trans_exp + alpha_exp
-                alpha = paddle.logsumexp(mat, 2)
-                all_alpha.append(alpha)
-        else:
-            for i, input_exp in enumerate(inputs_t_exp):
-                if i == 0:
-                    alpha = inputs.transpose([1, 0, 2])[0]
-                else:
-                    alpha_exp = alpha.unsqueeze(1).expand(
-                        [batch_size, n_labels, n_labels])
-                    # F(n) = logsumexp(F(n-1) + p(y_n) + T(y_{n-1}, y_n))
-                    mat = input_exp + trans_exp + alpha_exp
-                    alpha = paddle.logsumexp(mat, 2)
-                all_alpha.append(alpha)
+            alpha = self._initialize_alpha(batch_size)
+
+        for i, input_exp in enumerate(inputs_t_exp):
+            # input_exp: batch_size, num_tags, num_tags
+            # alpha_exp: batch_size, num_tags, num_tags
+            if i == 0 and not self.with_start_stop_tag:
+                mat = input_exp
+            else:
+                alpha_exp = alpha.unsqueeze(1).expand(
+                    [batch_size, n_labels, n_labels])
+                # F(n) = logsumexp(F(n-1) + p(y_n) + T(y_{n-1}, y_n))
+                mat = input_exp + trans_exp + alpha_exp
+            alpha = paddle.logsumexp(mat, 2)
+            all_alpha.append(alpha)

         # Get the valid alpha
         all_alpha = paddle.stack(all_alpha).transpose([1, 0, 2])
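The two loops collapse into one because the only real difference was the first step without START/STOP tags, where no previous alpha exists and the emissions are used directly. The recursion itself, F(n) = logsumexp(F(n-1) + p(y_n) + T(y_{n-1}, y_n)), can be sanity-checked against a brute-force sum over all tag paths on a toy problem; a NumPy sketch with hypothetical scores:

    import itertools
    import numpy as np

    def logsumexp(x, axis):
        m = x.max(axis=axis, keepdims=True)
        return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

    p = np.array([[0.1, 0.4], [0.2, 0.3], [0.5, 0.1]])  # emissions: step x tag
    T = np.array([[0.3, 0.1], [0.2, 0.4]])              # T[prev, next]

    F = p[0].copy()
    for n in range(1, len(p)):
        # F[n, i] = p[n, i] + logsumexp_j(F[n-1, j] + T[j, i])
        F = p[n] + logsumexp(F[:, None] + T, axis=0)
    log_Z = logsumexp(F, axis=0)

    # Brute force over all 2**3 tag paths gives the same log partition value.
    scores = [sum(p[n, y] for n, y in enumerate(path)) +
              sum(T[path[n - 1], path[n]] for n in range(1, 3))
              for path in itertools.product([0, 1], repeat=3)]
    assert np.isclose(log_Z, logsumexp(np.array(scores), axis=0))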
@@ -166,8 +160,7 @@ class LinearChainCrf(nn.Layer):
             sequence_mask(
                 self._get_batch_seq_index(batch_size, seq_len), lengths),
             'float32')
-        if self.with_start_stop_tag:
-            mask = mask[:, :seq_len]
+        mask = mask[:, :seq_len]
         mask_scores = scores * mask
         score = paddle.sum(mask_scores, 1)
@@ -191,6 +184,10 @@ class LinearChainCrf(nn.Layer):
                 fill_value=self.stop_idx)
             labels_ext = (1 - mask) * pad_stop + mask * labels_ext
         else:
+            mask = paddle.cast(
+                sequence_mask(
+                    self._get_batch_seq_index(batch_size, seq_len), lengths),
+                'int32')
             labels_ext = labels
         start_tag_indices = labels_ext[:, :-1]
@@ -212,7 +209,8 @@ class LinearChainCrf(nn.Layer):
         return score

     def _get_start_stop_tensor(self, batch_size):
-        if self._start_tensor is None or self._stop_tensor is None:
+        if self._start_tensor is None or self._stop_tensor is None or batch_size != self._start_tensor.shape[
+                0]:
             self._start_tensor = paddle.full(
                 (batch_size, 1), dtype='int64', fill_value=self.start_idx)
             self._stop_tensor = paddle.full(
@@ -220,7 +218,8 @@ class LinearChainCrf(nn.Layer):
         return self._start_tensor, self._stop_tensor

     def _get_batch_index(self, batch_size):
-        if self._batch_index is None:
+        if self._batch_index is None or batch_size != self._batch_index.shape[
+                0]:
             self._batch_index = paddle.arange(end=batch_size, dtype="int64")
         return self._batch_index
@@ -231,36 +230,39 @@ class LinearChainCrf(nn.Layer):
     def _get_batch_seq_index(self, batch_size, length):
         if self._batch_seq_index is None or length + 2 > self._batch_seq_index.shape[
-                1]:
+                1] or batch_size > self._batch_seq_index.shape[0]:
             self._batch_seq_index = paddle.cumsum(
                 paddle.ones([batch_size, length + 2], "int64"), axis=1) - 1
         if self.with_start_stop_tag:
-            return self._batch_seq_index[:, :length + 2]
+            return self._batch_seq_index[:batch_size, :length + 2]
         else:
-            return self._batch_seq_index[:, :length]
+            return self._batch_seq_index[:batch_size, :length]

-class LinearChainCrfLoss(LinearChainCrf):
+class LinearChainCrfLoss(nn.Layer):
     """The negative log-likelihood for linear chain Conditional Random Field (CRF).

     let $$ Z(x) = \\sum_{y'}exp(score(x,y')) $$, means the sum of all path scores,
     then we have $$ loss = -logp(y|x) = -log(exp(score(x,y))/Z(x)) = -score(x,y) + logZ(x) $$

     Args:
-        transitions (Tensor): The transition matrix.
+        crf (LinearChainCrf): The LinearChainCrf network.
     """

-    def __init__(self, transitions):
-        num_labels = transitions.shape[0] - 2
-        super(LinearChainCrfLoss, self).__init__(num_labels)
-        self.transitions.set_value(transitions)
+    def __init__(self, crf):
+        super(LinearChainCrfLoss, self).__init__()
+        self.crf = crf
+        if isinstance(crf, paddle.fluid.framework.ParamBase):
+            raise ValueError(
+                "From paddlenlp >= 2.0.0b4, the first parameter of "
+                "LinearChainCrfLoss should be a LinearChainCrf object. "
+                "If you were passing 'crf.transitions', pass 'crf' instead.")

     def forward(self, inputs, lengths, predictions, labels):
         # Note: when close to convergence, the loss can be a small negative number, possibly caused by underflow when computing exp in logsumexp.
         # We add relu here to avoid a negative loss. In theory the CRF loss is greater than or equal to 0, so the relu does not affect it.
         return nn.functional.relu(
-            super(LinearChainCrfLoss, self).forward(inputs, lengths) -
-            self.gold_score(inputs, labels, lengths))
+            self.crf.forward(inputs, lengths) - self.crf.gold_score(
+                inputs, labels, lengths))
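`LinearChainCrfLoss` now wraps the live `LinearChainCrf` layer instead of inheriting from it and copying the transition matrix once via `set_value`; the loss therefore always reads the current trainable transitions and inherits the layer's `with_start_stop_tag` setting. Migration for callers is mechanical:

    # Before: crf_loss = LinearChainCrfLoss(network.crf.transitions)
    # After (passing the parameter itself now raises the ValueError above):
    crf_loss = LinearChainCrfLoss(network.crf)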

class ViterbiDecoder(nn.Layer):
@@ -278,7 +280,9 @@ class ViterbiDecoder(nn.Layer):
         self.transitions = transitions
         self.with_start_stop_tag = with_start_stop_tag
         # If with_start_stop_tag is True, -1 should be START and -2 should be STOP.
-        self.stop_idx = -2
+        if with_start_stop_tag:
+            self.start_idx = -1
+            self.stop_idx = -2
         self.num_tags = transitions.shape[0]

         self._initial_alpha = None
@@ -287,7 +291,8 @@ class ViterbiDecoder(nn.Layer):
     def _initialize_alpha(self, batch_size):
         # alpha accumulates the path value to get the different next tag
-        if self._initial_alpha is None:
+        if self._initial_alpha is None or batch_size > self._initial_alpha.shape[
+                0]:
             # Initialized by a small value.
             initial_alpha = paddle.full(
                 (batch_size, self.num_tags - 1),
@@ -298,7 +303,7 @@ class ViterbiDecoder(nn.Layer):
                 (batch_size, 1), dtype='float32', fill_value=0.)
             self._initial_alpha = paddle.concat(
                 [initial_alpha, alpha_start], axis=1)
-        return self._initial_alpha
+        return self._initial_alpha[:batch_size, :]
     def forward(self, inputs, lengths):
         """
@@ -313,32 +318,34 @@ class ViterbiDecoder(nn.Layer):
         """
         batch_size, seq_len, n_labels = inputs.shape
         inputs_t = inputs.transpose([1, 0, 2])
-        trn_exp = self.transitions.unsqueeze(0).expand(
+        trans_exp = self.transitions.unsqueeze(0).expand(
             [batch_size, n_labels, n_labels])

         all_alpha = []
         historys = []

-        alpha = self._initialize_alpha(batch_size).detach(
-        ) if self.with_start_stop_tag else None
-        # inputs_t: seq_len, batch_size, n_labels
-        # logit: batch_size, n_labels
-        for i, logit in enumerate(inputs_t):
-            if alpha is not None:
-                alpha_exp = alpha.unsqueeze(1).expand(
-                    [batch_size, n_labels, n_labels])
-                # alpha_trn_sum: batch_size, n_labels, n_labels
-                alpha_trn_sum = alpha_exp + trn_exp
-                # alpha_max: batch_size, n_labels
-                # We don't include the emission scores here because the max does not depend on them (we add them below)
-                alpha_max = alpha_trn_sum.max(2)
-                alpha_argmax = alpha_trn_sum.argmax(2)
-                historys.append(alpha_argmax)
-                # Now add the emission scores
-                alpha = alpha_max + logit
-            else:
-                alpha = logit
-            all_alpha.append(alpha)
+        if self.with_start_stop_tag:
+            alpha = self._initialize_alpha(batch_size)
+        else:
+            alpha = paddle.zeros((batch_size, self.num_tags), dtype='float32')
+        for i, logit in enumerate(inputs_t):
+            alpha_exp = alpha.unsqueeze(1).expand(
+                [batch_size, n_labels, n_labels])
+            # alpha_trn_sum: batch_size, n_labels, n_labels
+            alpha_trn_sum = alpha_exp + trans_exp
+            # alpha_max: batch_size, n_labels
+            # We don't include the emission scores here because the max does not depend on them (we add them below)
+            alpha_max = alpha_trn_sum.max(2)
+            if i == 0:
+                # With with_start_stop_tag, the first antecedent tag must be START, so drop it;
+                # otherwise the first label has no antecedent tag, so skip it.
+                pass
+            else:
+                alpha_argmax = alpha_trn_sum.argmax(2)
+                historys.append(alpha_argmax)
+            # Now add the emission scores
+            alpha = alpha_max + logit
+            all_alpha.append(alpha)

         # Get the valid alpha
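The decoder gets the same single-loop shape: Viterbi is the forward recursion with logsumexp replaced by max, plus an argmax history for backtracking, and step 0 records no backpointer because the first tag has no real predecessor. One step of the max/argmax update as a NumPy sketch with hypothetical scores:

    import numpy as np

    alpha = np.array([0.5, 1.2])   # best score ending in each of 2 tags so far
    T = np.array([[0.3, 0.1],
                  [0.2, 0.4]])     # T[next, prev], matching the layer's layout
    logit = np.array([0.7, 0.2])   # emission scores at the current step

    alpha_trn_sum = alpha[None, :] + T           # candidate scores: [next, prev]
    alpha_max = alpha_trn_sum.max(axis=1)        # best score reaching each next tag
    backpointer = alpha_trn_sum.argmax(axis=1)   # which prev tag achieved it
    alpha = alpha_max + logit                    # emissions added after the max
    print(alpha, backpointer)                    # [2.  1.8] [1 1]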
@@ -358,6 +365,7 @@ class ViterbiDecoder(nn.Layer):
         historys = paddle.stack(historys).numpy()
         lengths_np = lengths.numpy()
         batch_path = []
+        max_len = 0
         for batch_id in range(batch_size):
             best_last_tag = last_ids[batch_id]
             path = [best_last_tag]
@@ -365,17 +373,16 @@ class ViterbiDecoder(nn.Layer):
                 # hist: batch_size, n_labels
                 best_last_tag = hist[batch_id][best_last_tag]
                 path.append(best_last_tag)
-            if self.with_start_stop_tag:
-                # the first one is start
-                start = path.pop()
             path.reverse()
+            max_len = max(max_len, len(path))
             # Pad to the max sequence length, so that the ChunkEvaluator can compute it
-            path += [0] * (seq_len - len(path))
             batch_path.append(path)
+        batch_path = [path + [0] * (max_len - len(path)) for path in batch_path]
         batch_path = paddle.to_tensor(batch_path)
         return scores, batch_path
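Tracking `max_len` and padding afterwards keeps every decoded row the same length without assuming `seq_len` is the right width (with START/STOP tags, `seq_len` counts virtual positions the decoded paths never contain). The padding step in isolation:

    # Hypothetical decoded paths of uneven length.
    batch_path = [[3, 1, 2], [4, 2], [1]]
    max_len = max(len(p) for p in batch_path)
    batch_path = [p + [0] * (max_len - len(p)) for p in batch_path]
    print(batch_path)  # [[3, 1, 2], [4, 2, 0], [1, 0, 0]]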
     def _get_batch_index(self, batch_size):
-        if self._batch_index is None:
+        if self._batch_index is None or batch_size != self._batch_index.shape[
+                0]:
             self._batch_index = paddle.arange(end=batch_size, dtype="int64")
         return self._batch_index
...
@@ -112,12 +112,12 @@ class ChunkEvaluator(paddle.metric.Metric):
             float: mean precision, recall and f1 score.
         """
-        precision = float(
-            self.num_correct_chunks
-        ) / self.num_infer_chunks if self.num_infer_chunks else 0
-        recall = float(self.num_correct_chunks
-                       ) / self.num_label_chunks if self.num_label_chunks else 0
-        f1_score = float(2 * precision * recall) / (
-            precision + recall) if self.num_correct_chunks else 0
+        precision = float(
+            self.num_correct_chunks /
+            self.num_infer_chunks) if self.num_infer_chunks else 0.
+        recall = float(self.num_correct_chunks /
+                       self.num_label_chunks) if self.num_label_chunks else 0.
+        f1_score = float(2 * precision * recall / (
+            precision + recall)) if self.num_correct_chunks else 0.
         return precision, recall, f1_score

     def reset(self):
...
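The metric arithmetic is unchanged in substance; dividing inside `float(...)` and returning `0.` literals just guarantees plain Python floats in every branch. For reference, precision = correct/inferred, recall = correct/labeled, F1 = 2PR/(P+R); with hypothetical counts:

    num_correct, num_infer, num_label = 8, 10, 16
    precision = float(num_correct / num_infer) if num_infer else 0.
    recall = float(num_correct / num_label) if num_label else 0.
    f1 = float(2 * precision * recall / (precision + recall)) if num_correct else 0.
    print(precision, recall, f1)  # 0.8 0.5 0.6153846153846154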