Unverified · Commit 6f8e0ab5, authored by kinghuin, committed by GitHub

Fix ernie_gen evaluation bug

Parent cf5bb9f1
@@ -170,9 +170,9 @@ https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
 ### Dependencies
-paddlepaddle >= 1.8.2
-paddlehub >= 1.7.0
+paddlepaddle >= 2.0.0
+paddlehub >= 2.0.0
 paddlenlp >= 2.0.0
...
@@ -42,7 +42,7 @@ from .model import StackModel
     author_email="",
     type="nlp/text_generation",
 )
-class ErnieGen(hub.Module):
+class ErnieGen():
     def __init__(self):
         """
         initialize with the necessary elements
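The class-definition change above follows the PaddleHub 2.x convention: a module is a plain class registered through the `@moduleinfo` decorator rather than a subclass of `hub.Module`. A minimal sketch of that pattern, where the `name`, `version`, `summary`, and `author` values are illustrative rather than taken from this commit:

```python
# Minimal sketch of a PaddleHub 2.x module; the field values are hypothetical.
from paddlehub.module.module import moduleinfo


@moduleinfo(
    name="ernie_gen",                  # name used with hub.Module(name=...)
    version="1.0.0",                   # hypothetical version
    summary="ERNIE-GEN text generation module",
    author="example-author",           # hypothetical author
    author_email="",
    type="nlp/text_generation",
)
class ErnieGen():
    def __init__(self):
        # the real module builds its tokenizer and model here
        pass
```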
@@ -119,8 +119,7 @@ class ErnieGen(hub.Module):
         train_dataset = self._load_dataset(train_path)
         attn_id = self.tokenizer.vocab['[MASK]']
-        trans_func = convert_example(
-            tokenizer=self.tokenizer,
+        trans_func = convert_example(tokenizer=self.tokenizer,
             attn_id=attn_id,
             tgt_type_id=1,
             max_encode_len=max_encode_len,
@@ -139,8 +138,7 @@ class ErnieGen(hub.Module):
             Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # attn_ids
             Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # tgt_labels
         ): after_padding(fn(samples))
-        train_data_loader = DataLoader(
-            dataset=train_dataset,
+        train_data_loader = DataLoader(dataset=train_dataset,
             batch_sampler=train_batch_sampler,
             collate_fn=batchify_fn,
             num_workers=0,
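For context, the `batchify_fn` built from paddlenlp's `Tuple` and `Pad` pads each field of a sample tuple to the longest sequence in the batch. A self-contained sketch with two fields and an assumed `pad_val` of 0 standing in for `tokenizer.pad_token_id`:

```python
# Sketch of paddlenlp's Pad/Tuple collate: each Pad handles one field
# of the sample tuple and pads it along axis 0 to the batch maximum.
import numpy as np
from paddlenlp.data import Pad, Tuple

batchify_fn = Tuple(
    Pad(axis=0, pad_val=0),  # e.g. src_ids
    Pad(axis=0, pad_val=0),  # e.g. src_tids
)
samples = [
    (np.array([1, 2, 3]), np.array([0, 0, 1])),
    (np.array([4, 5]), np.array([0, 1])),
]
src_ids, src_tids = batchify_fn(samples)
print(src_ids)  # [[1 2 3] [4 5 0]] -- the shorter sample is padded with 0
```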
@@ -149,8 +147,11 @@ class ErnieGen(hub.Module):
         if dev_path:
             dev_dataset = self._load_dataset(dev_path)
             dev_dataset = dev_dataset.map(trans_func)
-            dev_data_loader = DataLoader(
-                dataset=dev_dataset, batch_size=batch_size, collate_fn=batchify_fn, num_workers=0, return_list=True)
+            dev_data_loader = DataLoader(dataset=dev_dataset,
+                                         batch_size=batch_size,
+                                         collate_fn=batchify_fn,
+                                         num_workers=0,
+                                         return_list=True)
         label_num = self.model.word_emb.weight.shape[0]
         train_model = StackModel(self.model)
@@ -158,8 +159,7 @@ class ErnieGen(hub.Module):
         # Generate parameter names needed to perform weight decay.
         # All bias and LayerNorm parameters are excluded.
         decay_params = [p.name for n, p in self.model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])]
-        optimizer = paddle.optimizer.AdamW(
-            learning_rate=lr_scheduler,
+        optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler,
             parameters=self.model.parameters(),
             weight_decay=weight_decay,
             grad_clip=nn.ClipGradByGlobalNorm(1.0),
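The `decay_params` list is the standard Paddle idiom for excluding bias and LayerNorm weights from weight decay; it is typically wired into `AdamW` through `apply_decay_param_fun`, an argument that sits outside the visible hunk. A minimal sketch, with a tiny stand-in model whose attribute names (`linear`, `norm`) make the filter behave as intended:

```python
# Sketch: restrict weight decay to decay_params via apply_decay_param_fun.
import paddle


class TinyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(4, 4)
        self.norm = paddle.nn.LayerNorm(4)

    def forward(self, x):
        return self.norm(self.linear(x))


model = TinyModel()
# keep only parameters whose structured names contain neither "bias" nor "norm"
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    weight_decay=0.01,
    grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
    apply_decay_param_fun=lambda name: name in decay_params,
)
```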
@@ -175,8 +175,8 @@ class ErnieGen(hub.Module):
                 (src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
                  mask_attn_2_srctgtattn, tgt_labels, _) = batch
                 if label_smooth > 0.:
-                    tgt_labels = nn.functional.label_smooth(
-                        nn.functional.one_hot(tgt_labels, label_num), epsilon=label_smooth)
+                    tgt_labels = nn.functional.label_smooth(nn.functional.one_hot(tgt_labels, label_num),
+                                                            epsilon=label_smooth)
                 tgt_pos = paddle.nonzero(attn_ids == attn_id)
                 loss = train_model(src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src,
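The reformatted call combines two Paddle ops: `one_hot` turns the hard target ids into one-hot rows, and `label_smooth` mixes in a uniform epsilon of probability mass. A tiny self-contained example, with an assumed vocabulary size of 5:

```python
# Sketch of the smoothing step; label_num=5 and epsilon=0.1 are illustrative.
import paddle
import paddle.nn.functional as F

tgt_labels = paddle.to_tensor([1, 3])
soft = F.label_smooth(F.one_hot(tgt_labels, num_classes=5), epsilon=0.1)
# the true class gets 0.9 + 0.1/5 = 0.92; every other class gets 0.1/5 = 0.02
print(soft.numpy())
```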
@@ -190,8 +190,8 @@ class ErnieGen(hub.Module):
                 if global_step % log_interval == 0 and paddle.distributed.get_rank() == 0:
                     loss_np = loss.numpy()
                     ppl = np.exp(loss_np)
-                    logger.info('[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e' % (global_step, max_steps, loss_np,
-                                                                                       ppl, lr_scheduler.get_lr()))
+                    logger.info('[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e' %
+                                (global_step, max_steps, loss_np, ppl, lr_scheduler.get_lr()))
                 if save_dir and global_step % save_interval == 0 and global_step > 0:
                     loss_np = loss.numpy()
                     ppl = np.exp(loss_np)
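The logged `ppl` is token-level perplexity: the exponential of the mean cross-entropy loss, which tracks the loss on a more interpretable scale.

```python
# Perplexity from a mean cross-entropy value (the loss number is illustrative).
import numpy as np

loss_np = 1.5
ppl = np.exp(loss_np)  # ~4.48: the model is about as uncertain as a 4.5-way choice
```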
@@ -214,8 +214,8 @@ class ErnieGen(hub.Module):
         if global_step % save_interval != 0:
             loss_np = loss.numpy()
             ppl = np.exp(loss_np)
-            logger.info('[final step %d]train loss %.5f, ppl %.5f, elr %.3e' % (global_step, loss_np, ppl,
-                                                                                lr_scheduler.get_lr()))
+            logger.info('[final step %d]train loss %.5f, ppl %.5f, elr %.3e' %
+                        (global_step, loss_np, ppl, lr_scheduler.get_lr()))
             if save_dir:
                 save_name = "step_%s_ppl_%.5f.pdparams" % (global_step, ppl)
                 save_path = os.path.join(save_dir, save_name)
@@ -291,6 +291,7 @@ class ErnieGen(hub.Module):
     def _evaluate(self, model, data_loader, tokenizer, rouge1, rouge2, attn_id, max_decode_len, max_encode_len,
                   beam_width, length_penalty):
+        paddle.disable_static()
         model.eval()
         vocab = tokenizer.vocab
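This one added line is the evaluation fix the commit title refers to: `_evaluate` relies on imperative (dynamic-graph) APIs such as `model.eval()` and direct `DataLoader` iteration, so it presumably failed when a caller had left Paddle in static-graph mode; calling `paddle.disable_static()` on entry makes the method self-contained. A minimal sketch of the behavior being guarded against:

```python
# Sketch: paddle.disable_static() guarantees dynamic (imperative) mode,
# which the eager-style evaluation loop relies on.
import paddle

paddle.enable_static()                 # simulate a caller switching to static mode
assert not paddle.in_dynamic_mode()
paddle.disable_static()                # the guard added by this commit
assert paddle.in_dynamic_mode()
```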
@@ -305,8 +306,7 @@ class ErnieGen(hub.Module):
         for data in data_loader:
             (src_ids, src_tids, src_pids, _, _, _, _, _, _, _, _, raw_tgt_labels) = data  # never use target when infer
             # Use greedy_search_infilling or beam_search_infilling to get predictions
-            output_ids = beam_search_infilling(
-                model,
+            output_ids = beam_search_infilling(model,
                 src_ids,
                 src_tids,
                 eos_id=eos_id,
@@ -361,8 +361,7 @@ class ErnieGen(hub.Module):
 if __name__ == "__main__":
     module = ErnieGen()
-    result = module.finetune(
-        train_path='test_data/train.txt',
+    result = module.finetune(train_path='test_data/train.txt',
         dev_path='test_data/dev.txt',
         max_steps=30,
         batch_size=2,
...
@@ -99,9 +99,9 @@ https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
 ### Dependencies
-paddlepaddle >= 1.8.2
-paddlehub >= 1.7.0
+paddlepaddle >= 2.0.0
+paddlehub >= 2.0.0
 paddlenlp >= 2.0.0
...
@@ -87,9 +87,9 @@ https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
 ### Dependencies
-paddlepaddle >= 1.8.2
-paddlehub >= 1.7.0
+paddlepaddle >= 2.0.0
+paddlehub >= 2.0.0
 paddlenlp >= 2.0.0
...
@@ -87,9 +87,9 @@ https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
 ### Dependencies
-paddlepaddle >= 1.8.2
-paddlehub >= 1.7.0
+paddlepaddle >= 2.0.0
+paddlehub >= 2.0.0
 paddlenlp >= 2.0.0
...
@@ -87,11 +87,11 @@ https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
 ### Dependencies
-paddlepaddle >= 1.8.2
-paddlehub >= 1.7.0
-PaddleNLP >= 2.0.0
+paddlepaddle >= 2.0.0
+paddlehub >= 2.0.0
+paddlenlp >= 2.0.0
 ## Update history
...