Unverified commit b725cbd0, authored by kinghuin, committed by GitHub

add ernie_gen for the dygraph mode
Parent 30ccfc67
......@@ -18,9 +18,9 @@
- Python >= 3.6
- paddlepaddle >= 2.0.0rc1; see [Quick Install](https://www.paddlepaddle.org.cn/install/quick) for installation instructions
- paddlenlp >= 2.0.0b; install with `pip install paddlenlp>=2.0.0b`
### 2.2 Data Preparation
......
# ERNIE-Gen: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation
## 1. Introduction
**ERNIE-GEN is a pre-training/fine-tuning framework for generation tasks.** It is the first to add a **span-by-span generation** task to pre-training, so the model learns to generate a semantically complete span at each step. In both pre-training and fine-tuning, an **infilling generation mechanism** and a **noise-aware mechanism** are used to mitigate the exposure-bias problem. In addition, ERNIE-GEN adopts a **multi-fragment, multi-granularity target sampling** strategy, which strengthens the correlation between the source and target texts and enhances the interaction between the encoder and the decoder.
![multi-flow-attention](https://github.com/PaddlePaddle/ERNIE/raw/repro/ernie-gen/.meta/multi-flow-attention.png)
## 2. Quick Start
### 2.1 Environment
- Python >= 3.6
- paddlepaddle >= 2.0.0rc1; see [Quick Install](https://www.paddlepaddle.org.cn/install/quick) for installation instructions
- paddlenlp >= 2.0.0b; install with `pip install paddlenlp>=2.0.0b`
### 2.2 Data Preparation
This example ships with a classical Chinese poetry dataset. A sample looks like this:
```text
画\002精\002禅\002室\002冷\002,\002方\002暑\002久\002徘\002徊\002。 不\002尽\002林\002端\002雪\002,\002长\002青\002石\002上\002苔\002。\002心\002闲\002对\002岩\002岫\002,\002目\002浄\002失\002尘\002埃\002。\002坐\002久\002清\002风\002至\002,\002疑\002从\002翠\002涧\002来\002。
```
Each line consists of two columns separated by a tab: the first column is the input (the leading lines of a poem) and the second is the expected output (the following lines). Characters within each column are separated by `\002`.
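A minimal parsing sketch, assuming the extracted `poetry/train.tsv` (the variable names are illustrative):
```python
# Split one line of the dataset: a tab separates input from output,
# and '\002' separates the characters within each column.
with open('poetry/train.tsv', encoding='utf-8') as f:
    line = f.readline()
src, tgt = line.rstrip('\n').split('\t')
src_chars = src.split('\x02')  # e.g. ['画', '精', '禅', '室', '冷', ...]
tgt_chars = tgt.split('\x02')
```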
The full dataset can be downloaded and extracted with:
```bash
wget --no-check-certificate https://paddlenlp.bj.bcebos.com/datasets/poetry.tar.gz
tar xvf poetry.tar.gz
```
### 2.3 Fine-tuning
Training runs on both CPU and GPU. Before using GPUs, specify the visible devices:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2 # multi-GPU training is supported
```
Start training with:
```bash
python -u ./train.py \
--model_name_or_path ernie-1.0 \
--max_encode_len 24 \
--max_decode_len 72 \
--batch_size 48 \
--learning_rate 2e-5 \
--num_epochs 12 \
--logging_steps 1 \
--save_steps 1000 \
--output_dir ./tmp/ \
--n_gpu 3 \
# --init_checkpoint ./tmp/model_10000/model_state.pdparams
```
Argument descriptions:
- `model_name_or_path`: the name of a model with a specific configuration, which determines the pre-trained weights and the tokenizer used during pre-training. A local directory holding the model files may also be given.
- `max_encode_len`: maximum input length; longer inputs are truncated.
- `max_decode_len`: maximum output length; longer outputs are truncated.
- `batch_size`: number of samples **per card** per iteration.
- `learning_rate`: the base learning rate; the current learning rate is this value multiplied by the factor produced by the learning rate scheduler (see the formula after this list).
- `num_epochs`: number of training epochs.
- `logging_steps`: interval (in steps) between log outputs.
- `save_steps`: interval (in steps) between model checkpoints and evaluations.
- `output_dir`: directory where checkpoints are saved.
- `n_gpu`: number of GPUs to use. Set it to the desired number for multi-GPU training; 0 means CPU.
- `init_checkpoint`: path of a checkpoint to load; setting it enables incremental training.
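For reference, the scheduler built in `train.py` below performs linear warmup followed by linear decay. With `T` the total number of training steps and `w = T * warmup_proportion`, the factor applied to `learning_rate` at step `s` is:
```text
factor(s) = s / w                  if s < w
factor(s) = (T - s) / (T - w)      otherwise
```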
### 2.4 Evaluation
Load a checkpoint saved during training to evaluate on the dev set:
```bash
python -u ./eval.py \
--model_name_or_path ernie-1.0 \
--max_encode_len 24 \
--max_decode_len 72 \
--batch_size 48 \
--init_checkpoint ./tmp/model_10000/model_state.pdparams \
--use_gpu
```
Argument descriptions:
- `model_name_or_path`: the name of a model with a specific configuration, which determines the pre-trained weights and the tokenizer used during pre-training. A local directory holding the model files may also be given.
- `max_encode_len`: maximum input length; longer inputs are truncated.
- `max_decode_len`: maximum output length; longer outputs are truncated.
- `batch_size`: number of samples **per card** per iteration.
- `init_checkpoint`: path of the checkpoint to load.
- `use_gpu`: run on GPU.
### 2.5 Prediction
Run prediction on unlabeled data:
```bash
python -u ./predict.py \
--model_name_or_path ernie-1.0 \
--max_encode_len 24 \
--max_decode_len 72 \
--batch_size 48 \
--init_checkpoint ./tmp/model_10000/model_state.pdparams \
--use_gpu
```
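Each prediction is printed in the following form (placeholders instead of real text):
```text
source :<input text>
target :<reference text>
predict:<generated text>
```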
## Citation
You can cite the ERNIE-Gen paper as follows:
```
@article{xiao2020ernie-gen,
title={ERNIE-GEN: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation},
author={Xiao, Dongling and Zhang, Han and Li, Yukun and Sun, Yu and Tian, Hao and Wu, Hua and Wang, Haifeng},
journal={arXiv preprint arXiv:2001.11314},
year={2020}
}
```
## How to Contribute
If you can fix an issue or add a new feature, feel free to submit a PR. If the PR is accepted, we will score it by quality and difficulty (0-5 points, higher is better). Once you accumulate 10 points, you may contact us for an interview opportunity or a recommendation letter.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import sys
import re
import argparse
import logging
import json
import numpy as np
from collections import namedtuple
import paddle
import paddle.nn as nn
from paddlenlp.utils.log import logger
def gen_bias(encoder_inputs, decoder_inputs, step):
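    """Build the attention bias (mask) used at decoding step `step`.

    The returned float tensor has shape [batch, decoder_len,
    encoder_len + step + decoder_len]: every decoder position may attend
    to the whole encoder output and to all `step` previously generated
    tokens, and attends causally within the current decoder slice.
    """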
decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2]
encoder_bsz, encoder_seqlen = encoder_inputs.shape[:2]
attn_bias = paddle.reshape(
paddle.arange(
0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1])
decoder_bias = paddle.cast(
(paddle.matmul(
attn_bias, 1. / attn_bias, transpose_y=True) >= 1.),
'float32') #[1, decoderlen, decoderlen]
encoder_bias = paddle.unsqueeze(
paddle.cast(paddle.ones_like(encoder_inputs), 'float32'),
[1]) #[bsz, 1, encoderlen]
encoder_bias = paddle.expand(
encoder_bias, [encoder_bsz, decoder_seqlen,
encoder_seqlen]) #[bsz,decoderlen, encoderlen]
decoder_bias = paddle.expand(
decoder_bias, [decoder_bsz, decoder_seqlen,
decoder_seqlen]) #[bsz, decoderlen, decoderlen]
if step > 0:
bias = paddle.concat([
encoder_bias, paddle.ones([decoder_bsz, decoder_seqlen, step],
'float32'), decoder_bias
], -1)
else:
bias = paddle.concat([encoder_bias, decoder_bias], -1)
return bias
@paddle.no_grad()
def greedy_search_infilling(model,
q_ids,
q_sids,
sos_id,
eos_id,
attn_id,
pad_id,
unk_id,
vocab_size,
max_encode_len=640,
max_decode_len=100,
tgt_type_id=3):
_, logits, info = model(q_ids, q_sids)
d_batch, d_seqlen = q_ids.shape
seqlen = paddle.sum(paddle.cast(q_ids != 0, 'int64'), 1, keepdim=True)
    has_stopped = np.zeros([d_batch], dtype=bool)
gen_seq_len = np.zeros([d_batch], dtype=np.int64)
output_ids = []
past_cache = info['caches']
cls_ids = paddle.ones([d_batch], dtype='int64') * sos_id
attn_ids = paddle.ones([d_batch], dtype='int64') * attn_id
ids = paddle.stack([cls_ids, attn_ids], -1)
for step in range(max_decode_len):
bias = gen_bias(q_ids, ids, step)
pos_ids = paddle.to_tensor(
np.tile(
np.array(
[[step, step + 1]], dtype=np.int64), [d_batch, 1]))
pos_ids += seqlen
_, logits, info = model(
ids,
paddle.ones_like(ids) * tgt_type_id,
pos_ids=pos_ids,
attn_bias=bias,
past_cache=past_cache)
if logits.shape[-1] > vocab_size:
logits[:, :, vocab_size:] = 0
logits[:, :, pad_id] = 0
logits[:, :, unk_id] = 0
logits[:, :, attn_id] = 0
gen_ids = paddle.argmax(logits, -1)
past_cached_k, past_cached_v = past_cache
cached_k, cached_v = info['caches']
cached_k = [
paddle.concat([pk, k[:, :1, :]], 1)
for pk, k in zip(past_cached_k, cached_k)
] # concat cached
cached_v = [
paddle.concat([pv, v[:, :1, :]], 1)
for pv, v in zip(past_cached_v, cached_v)
]
past_cache = (cached_k, cached_v)
gen_ids = gen_ids[:, 1]
ids = paddle.stack([gen_ids, attn_ids], 1)
gen_ids = gen_ids.numpy()
        has_stopped |= (gen_ids == eos_id).astype(bool)
gen_seq_len += (1 - has_stopped.astype(np.int64))
output_ids.append(gen_ids.tolist())
if has_stopped.all():
break
output_ids = np.array(output_ids).transpose([1, 0])
return output_ids
BeamSearchState = namedtuple('BeamSearchState',
['log_probs', 'lengths', 'finished'])
BeamSearchOutput = namedtuple('BeamSearchOutput',
['scores', 'predicted_ids', 'beam_parent_ids'])
def log_softmax(x):
e_x = np.exp(x - np.max(x))
return np.log(e_x / e_x.sum())
def mask_prob(p, onehot_eos, finished):
is_finished = paddle.cast(paddle.reshape(finished, [-1, 1]) != 0, 'float32')
p = is_finished * (1. - paddle.cast(onehot_eos, 'float32')) * -9999. + (
1. - is_finished) * p
return p
def hyp_score(log_probs, length, length_penalty):
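    # GNMT-style length normalization: score = log_probs / ((5 + length) / 6) ** alpha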
lp = paddle.pow((5. + paddle.cast(length, 'float32')) / 6., length_penalty)
return log_probs / lp
def beam_search_step(state, logits, eos_id, beam_width, is_first_step,
length_penalty):
"""logits.shape == [B*W, V]"""
_, vocab_size = logits.shape
bsz, beam_width = state.log_probs.shape
onehot_eos = paddle.cast(
nn.functional.one_hot(paddle.ones([1], 'int64') * eos_id, vocab_size),
'int64') #[1, V]
probs = paddle.log(nn.functional.softmax(logits)) #[B*W, V]
probs = mask_prob(probs, onehot_eos, state.finished) #[B*W, V]
allprobs = paddle.reshape(state.log_probs, [-1, 1]) + probs #[B*W, V]
not_finished = 1 - paddle.reshape(state.finished, [-1, 1]) #[B*W,1]
not_eos = 1 - onehot_eos
length_to_add = not_finished * not_eos #[B*W,V]
alllen = paddle.reshape(state.lengths, [-1, 1]) + length_to_add
allprobs = paddle.reshape(allprobs, [-1, beam_width * vocab_size])
alllen = paddle.reshape(alllen, [-1, beam_width * vocab_size])
allscore = hyp_score(allprobs, alllen, length_penalty)
if is_first_step:
allscore = paddle.reshape(
allscore,
            [bsz, beam_width, -1])[:, 0, :]  # first step only considers beam 0
scores, idx = paddle.topk(allscore, k=beam_width) #[B, W]
next_beam_id = idx // vocab_size #[B, W]
next_word_id = idx % vocab_size
gather_idx = paddle.concat(
[paddle.nonzero(idx != -1)[:, :1], paddle.reshape(idx, [-1, 1])], 1)
next_probs = paddle.reshape(
paddle.gather_nd(allprobs, gather_idx), idx.shape)
next_len = paddle.reshape(paddle.gather_nd(alllen, gather_idx), idx.shape)
gather_idx = paddle.concat([
paddle.nonzero(next_beam_id != -1)[:, :1], paddle.reshape(next_beam_id,
[-1, 1])
], 1)
next_finished = paddle.reshape(
paddle.gather_nd(state.finished, gather_idx),
state.finished.shape) #[gather new beam state according to new beam id]
next_finished += paddle.cast(next_word_id == eos_id, 'int64')
next_finished = paddle.cast(next_finished > 0, 'int64')
next_state = BeamSearchState(
log_probs=next_probs, lengths=next_len, finished=next_finished)
output = BeamSearchOutput(
scores=scores, predicted_ids=next_word_id, beam_parent_ids=next_beam_id)
return output, next_state
@paddle.no_grad()
def beam_search_infilling(model,
q_ids,
q_sids,
sos_id,
eos_id,
attn_id,
pad_id,
unk_id,
vocab_size,
max_encode_len=640,
max_decode_len=100,
beam_width=5,
tgt_type_id=3,
length_penalty=1.0):
_, __, info = model(q_ids, q_sids)
d_batch, d_seqlen = q_ids.shape
state = BeamSearchState(
log_probs=paddle.zeros([d_batch, beam_width], 'float32'),
lengths=paddle.zeros([d_batch, beam_width], 'int64'),
finished=paddle.zeros([d_batch, beam_width], 'int64'))
outputs = []
def reorder_(t, parent_id):
"""reorder cache according to parent beam id"""
gather_idx = paddle.nonzero(
parent_id != -1)[:, 0] * beam_width + paddle.reshape(parent_id,
[-1])
t = paddle.gather(t, gather_idx)
return t
def tile_(t, times):
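        """Repeat each example `times` times along a new beam axis: [B, ...] -> [B * times, ...]."""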
_shapes = list(t.shape[1:])
new_shape = [t.shape[0], times] + list(t.shape[1:])
ret = paddle.reshape(
paddle.expand(paddle.unsqueeze(t, [1]), new_shape),
[-1, ] + _shapes)
return ret
cached_k, cached_v = info['caches']
cached_k = [tile_(k, beam_width) for k in cached_k]
cached_v = [tile_(v, beam_width) for v in cached_v]
past_cache = (cached_k, cached_v)
q_ids = tile_(q_ids, beam_width)
seqlen = paddle.sum(paddle.cast(q_ids != 0, 'int64'), 1, keepdim=True)
#log.debug(q_ids.shape)
cls_ids = paddle.ones([d_batch * beam_width], dtype='int64') * sos_id
attn_ids = paddle.ones(
[d_batch * beam_width], dtype='int64') * attn_id # SOS
ids = paddle.stack([cls_ids, attn_ids], -1)
for step in range(max_decode_len):
#log.debug('decode step %d' % step)
bias = gen_bias(q_ids, ids, step)
pos_ids = paddle.to_tensor(
np.tile(
np.array(
[[step, step + 1]], dtype=np.int64),
[d_batch * beam_width, 1]))
pos_ids += seqlen
_, logits, info = model(
ids,
paddle.ones_like(ids) * tgt_type_id,
pos_ids=pos_ids,
attn_bias=bias,
past_cache=past_cache)
if logits.shape[-1] > vocab_size:
logits[:, :, vocab_size:] = 0
logits[:, :, pad_id] = 0
logits[:, :, unk_id] = 0
logits[:, :, attn_id] = 0
output, state = beam_search_step(
state,
logits[:, 1],
eos_id=eos_id,
beam_width=beam_width,
is_first_step=(step == 0),
length_penalty=length_penalty)
outputs.append(output)
past_cached_k, past_cached_v = past_cache
cached_k, cached_v = info['caches']
cached_k = [
reorder_(
paddle.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids)
for pk, k in zip(past_cached_k, cached_k)
] # concat cached
cached_v = [
reorder_(
paddle.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids)
for pv, v in zip(past_cached_v, cached_v)
]
past_cache = (cached_k, cached_v)
pred_ids_flatten = paddle.reshape(output.predicted_ids,
[d_batch * beam_width])
ids = paddle.stack([pred_ids_flatten, attn_ids], 1)
if state.finished.numpy().all():
break
final_ids = paddle.stack([o.predicted_ids for o in outputs], 0)
final_parent_ids = paddle.stack([o.beam_parent_ids for o in outputs], 0)
final_ids = nn.functional.gather_tree(
final_ids, final_parent_ids)[:, :, 0] #pick best beam
final_ids = paddle.transpose(
paddle.reshape(final_ids, [-1, d_batch * 1]), [1, 0])
return final_ids.numpy()
en_pattern = re.compile(r'^[a-zA-Z0-9]*$')
def post_process(token):
if token.startswith('##'):
ret = token[2:]
elif token in ['[CLS]', '[SEP]', '[PAD]']:
ret = ''
else:
        if en_pattern.match(token):
ret = ' ' + token
else:
ret = token
return ret
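# A quick sanity check (hypothetical tokens): WordPiece '##' pieces are glued
# to the previous token, special tokens are dropped, and a space is inserted
# before purely alphanumeric tokens:
#   ''.join(map(post_process, ['[CLS]', 'er', '##nie', 'gen', '[SEP]']))
#   -> ' ernie gen'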
from copy import deepcopy
import numpy as np
def convert_example(tokenizer,
attn_id,
tgt_type_id=3,
max_encode_len=512,
max_decode_len=128,
is_test=False,
noise_prob=0.,
                    use_random_noise=False):
    def wrapper(example):
"""convert an example into necessary features"""
encoded_src = tokenizer.encode(
example[0], max_seq_len=max_encode_len, pad_to_max_seq_len=False)
src_ids, src_sids = encoded_src["input_ids"], encoded_src["segment_ids"]
src_pids = np.arange(len(src_ids))
if not is_test:
encoded_tgt = tokenizer.encode(
example[1],
max_seq_len=max_decode_len,
pad_to_max_seq_len=False)
tgt_ids, tgt_sids = encoded_tgt["input_ids"], encoded_tgt[
"segment_ids"]
tgt_ids = np.array(tgt_ids)
tgt_sids = np.array(tgt_sids) + tgt_type_id
tgt_pids = np.arange(len(tgt_ids)) + len(src_ids)
attn_ids = np.ones_like(tgt_ids) * attn_id
if noise_prob > 0.:
tgt_labels = deepcopy(tgt_ids)
                if use_random_noise:
                    noise_ids = np.random.randint(
                        1, len(tokenizer.vocab), size=tgt_ids.shape)
                else:
                    noise_ids = np.ones_like(tgt_ids) * tokenizer.vocab['[NOISE]']
pos, = np.where(np.ones_like(tgt_ids))
np.random.shuffle(pos)
pos = pos[:int(noise_prob * len(pos))]
                tgt_ids[pos] = noise_ids[pos]
else:
tgt_labels = tgt_ids
return (src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids,
attn_ids, tgt_labels)
    return wrapper
def gen_mask(batch_ids, mask_type='bidi', query_len=None, pad_value=0):
if query_len is None:
query_len = batch_ids.shape[1]
if mask_type != 'empty':
mask = (batch_ids != pad_value).astype(np.float32)
mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
if mask_type == 'causal':
assert query_len == batch_ids.shape[1]
mask = np.tril(mask)
elif mask_type == 'causal_without_diag':
assert query_len == batch_ids.shape[1]
mask = np.tril(mask, -1)
elif mask_type == 'diag':
assert query_len == batch_ids.shape[1]
mask = np.stack([np.diag(np.diag(m)) for m in mask], 0)
    else:
        # mask_type == 'empty': attend to nothing
        mask = np.zeros_like(batch_ids).astype(np.float32)
        mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
return mask
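# For example (toy ids, 0 = pad): gen_mask(np.array([[7, 8, 0]]), 'bidi') lets
# every query position attend to the two non-pad keys, while
# mask_type='causal' additionally lower-triangulates the mask so that
# position i only attends to keys at positions <= i.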
def after_padding(args):
'''
attention mask:
*** src, tgt, attn
src 00, 01, 11
tgt 10, 11, 12
attn 20, 21, 22
*** s1, s2 | t1 t2 t3| attn1 attn2 attn3
s1 1, 1 | 0, 0, 0,| 0, 0, 0,
s2 1, 1 | 0, 0, 0,| 0, 0, 0,
-
t1 1, 1, | 1, 0, 0,| 0, 0, 0,
t2 1, 1, | 1, 1, 0,| 0, 0, 0,
t3 1, 1, | 1, 1, 1,| 0, 0, 0,
-
attn1 1, 1, | 0, 0, 0,| 1, 0, 0,
attn2 1, 1, | 1, 0, 0,| 0, 1, 0,
attn3 1, 1, | 1, 1, 0,| 0, 0, 1,
for details, see Fig3. https://arxiv.org/abs/2001.11314
'''
src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids, attn_ids, tgt_labels = args
src_len = src_ids.shape[1]
tgt_len = tgt_ids.shape[1]
mask_00 = gen_mask(src_ids, 'bidi', query_len=src_len)
mask_01 = gen_mask(tgt_ids, 'empty', query_len=src_len)
mask_02 = gen_mask(attn_ids, 'empty', query_len=src_len)
mask_10 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
mask_11 = gen_mask(tgt_ids, 'causal', query_len=tgt_len)
mask_12 = gen_mask(attn_ids, 'empty', query_len=tgt_len)
mask_20 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
mask_21 = gen_mask(tgt_ids, 'causal_without_diag', query_len=tgt_len)
mask_22 = gen_mask(attn_ids, 'diag', query_len=tgt_len)
mask_src_2_src = mask_00
mask_tgt_2_srctgt = np.concatenate([mask_10, mask_11], 2)
mask_attn_2_srctgtattn = np.concatenate([mask_20, mask_21, mask_22], 2)
raw_tgt_labels = deepcopy(tgt_labels)
tgt_labels = tgt_labels[np.where(tgt_labels != 0)]
return (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids, attn_ids,
mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
tgt_labels, raw_tgt_labels)
import os
import ast
import time
import argparse
import logging
import paddle
import paddle.nn as nn
from tqdm import tqdm
from paddle.io import DataLoader
from paddlenlp.transformers import ErnieForGeneration
from paddlenlp.transformers import ErnieTokenizer, ErnieTinyTokenizer, BertTokenizer, ElectraTokenizer, RobertaTokenizer
from paddlenlp.datasets import Poetry
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.metrics import Rouge1, Rouge2
from paddlenlp.utils.log import logger
from encode import convert_example, after_padding
from decode import beam_search_infilling, post_process, greedy_search_infilling
# yapf: disable
parser = argparse.ArgumentParser('seq2seq model with ERNIE-GEN')
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, help="Path to pre-trained model or shortcut name selected in the list: "+ ", ".join(list(ErnieTokenizer.pretrained_init_configuration.keys())))
parser.add_argument('--max_encode_len', type=int, default=24, help="The max encoding sentence length")
parser.add_argument('--max_decode_len', type=int, default=72, help="The max decoding sentence length")
parser.add_argument("--batch_size", default=50, type=int, help="Batch size per GPU/CPU for training.", )
parser.add_argument('--beam_width', type=int, default=1, help="Beam search width")
parser.add_argument('--length_penalty', type=float, default=1.0, help="The length penalty during decoding")
parser.add_argument('--init_checkpoint', type=str, default=None, help='Checkpoint to warm start from')
parser.add_argument('--use_gpu', action='store_true', help='If set, use gpu to execute')
# yapf: enable
args = parser.parse_args()
def evaluate():
paddle.set_device("gpu" if args.use_gpu else "cpu")
model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
if "ernie-tiny" in args.model_name_or_path:
tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
elif "ernie" in args.model_name_or_path:
tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
elif "electra" in args.model_name_or_path:
tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
else:
tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
dev_dataset = Poetry.get_datasets(['dev'])
attn_id = tokenizer.vocab[
'[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
tgt_type_id = model.sent_emb.weight.shape[0] - 1
trans_func = convert_example(
tokenizer=tokenizer,
attn_id=attn_id,
tgt_type_id=tgt_type_id,
max_encode_len=args.max_encode_len,
max_decode_len=args.max_decode_len)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_ids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_pids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_sids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_ids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_pids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_sids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # attn_ids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_labels
): after_padding(fn(samples))
dev_dataset = dev_dataset.apply(trans_func, lazy=True)
dev_batch_sampler = paddle.io.BatchSampler(
dev_dataset, batch_size=args.batch_size, shuffle=False)
data_loader = DataLoader(
dataset=dev_dataset,
batch_sampler=dev_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
rouge1 = Rouge1()
rouge2 = Rouge2()
if args.init_checkpoint:
model_state = paddle.load(args.init_checkpoint)
model.set_state_dict(model_state)
model.eval()
vocab = tokenizer.vocab
eos_id = vocab[tokenizer.sep_token]
sos_id = vocab[tokenizer.cls_token]
pad_id = vocab[tokenizer.pad_token]
unk_id = vocab[tokenizer.unk_token]
vocab_size = len(vocab)
evaluated_sentences_ids = []
reference_sentences_ids = []
logger.info("Evaluating...")
for data in tqdm(data_loader):
(src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _,
raw_tgt_labels) = data # never use target when infer
# Use greedy_search_infilling or beam_search_infilling to get predictions
output_ids = beam_search_infilling(
model,
src_ids,
src_sids,
eos_id=eos_id,
sos_id=sos_id,
attn_id=attn_id,
pad_id=pad_id,
unk_id=unk_id,
vocab_size=vocab_size,
max_decode_len=args.max_decode_len,
max_encode_len=args.max_encode_len,
beam_width=args.beam_width,
length_penalty=args.length_penalty,
tgt_type_id=tgt_type_id)
for ids in output_ids.tolist():
if eos_id in ids:
ids = ids[:ids.index(eos_id)]
evaluated_sentences_ids.append(ids)
for ids in raw_tgt_labels.numpy().tolist():
ids = ids[:ids.index(eos_id)]
reference_sentences_ids.append(ids)
score1 = rouge1.score(evaluated_sentences_ids, reference_sentences_ids)
score2 = rouge2.score(evaluated_sentences_ids, reference_sentences_ids)
logger.info("Rouge-1: %.5f ,Rouge-2: %.5f" % (score1 * 100, score2 * 100))
if __name__ == "__main__":
evaluate()
import os
import ast
import time
import argparse
import logging
import paddle
import paddle.nn as nn
from paddle.io import DataLoader
from paddlenlp.transformers import ErnieForGeneration
from paddlenlp.transformers import ErnieTokenizer, ErnieTinyTokenizer, BertTokenizer, ElectraTokenizer, RobertaTokenizer
from paddlenlp.datasets import Poetry
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.metrics import Rouge1, Rouge2
from paddlenlp.utils.log import logger
from encode import convert_example, after_padding
from decode import beam_search_infilling, post_process, greedy_search_infilling
# yapf: disable
parser = argparse.ArgumentParser('seq2seq model with ERNIE-GEN')
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, help="Path to pre-trained model or shortcut name selected in the list: "+ ", ".join(list(ErnieTokenizer.pretrained_init_configuration.keys())))
parser.add_argument('--max_encode_len', type=int, default=24, help="The max encoding sentence length")
parser.add_argument('--max_decode_len', type=int, default=72, help="The max decoding sentence length")
parser.add_argument("--batch_size", default=50, type=int, help="Batch size per GPU/CPU for training.", )
parser.add_argument('--beam_width', type=int, default=3, help="Beam search width")
parser.add_argument('--length_penalty', type=float, default=1.0, help="The length penalty during decoding")
parser.add_argument('--init_checkpoint', type=str, default=None, help='Checkpoint to warm start from')
parser.add_argument('--use_gpu', action='store_true', help='If set, use gpu to execute')
# yapf: enable
args = parser.parse_args()
def predict():
paddle.set_device("gpu" if args.use_gpu else "cpu")
model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
if "ernie-tiny" in args.model_name_or_path:
tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
elif "ernie" in args.model_name_or_path:
tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
elif "electra" in args.model_name_or_path:
tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
else:
tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
dev_dataset = Poetry.get_datasets(['test'])
attn_id = tokenizer.vocab[
'[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
tgt_type_id = model.sent_emb.weight.shape[0] - 1
trans_func = convert_example(
tokenizer=tokenizer,
attn_id=attn_id,
tgt_type_id=tgt_type_id,
max_encode_len=args.max_encode_len,
max_decode_len=args.max_decode_len)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_ids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_pids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_sids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_ids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_pids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_sids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # attn_ids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_labels
): after_padding(fn(samples))
dev_dataset = dev_dataset.apply(trans_func, lazy=True)
test_batch_sampler = paddle.io.BatchSampler(
dev_dataset, batch_size=args.batch_size, shuffle=False)
data_loader = DataLoader(
dataset=dev_dataset,
batch_sampler=test_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
if args.init_checkpoint:
model_state = paddle.load(args.init_checkpoint)
model.set_state_dict(model_state)
model.eval()
vocab = tokenizer.vocab
eos_id = vocab[tokenizer.sep_token]
sos_id = vocab[tokenizer.cls_token]
pad_id = vocab[tokenizer.pad_token]
unk_id = vocab[tokenizer.unk_token]
vocab_size = len(vocab)
evaluated_sentences = []
evaluated_sentences_ids = []
logger.info("Predicting...")
for data in data_loader:
(src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _,
raw_tgt_labels) = data # never use target when infer
# Use greedy_search_infilling or beam_search_infilling to get predictions
output_ids = beam_search_infilling(
model,
src_ids,
src_sids,
eos_id=eos_id,
sos_id=sos_id,
attn_id=attn_id,
pad_id=pad_id,
unk_id=unk_id,
vocab_size=vocab_size,
max_decode_len=args.max_decode_len,
max_encode_len=args.max_encode_len,
beam_width=args.beam_width,
length_penalty=args.length_penalty,
tgt_type_id=tgt_type_id)
for source_ids, target_ids, predict_ids in zip(
src_ids.numpy().tolist(),
raw_tgt_labels.numpy().tolist(), output_ids.tolist()):
if eos_id in predict_ids:
predict_ids = predict_ids[:predict_ids.index(eos_id)]
source_sentence = ''.join(
map(post_process,
vocab.to_tokens(source_ids[1:source_ids.index(eos_id)])))
tgt_sentence = ''.join(
map(post_process,
vocab.to_tokens(target_ids[1:target_ids.index(eos_id)])))
predict_ids = ''.join(
map(post_process, vocab.to_tokens(predict_ids)))
print("source :%s\ntarget :%s\npredict:%s\n" %
(source_sentence, tgt_sentence, predict_ids))
break
if __name__ == "__main__":
predict()
import os
import ast
import time
import argparse
import logging
import paddle
from tqdm import tqdm
import paddle.nn as nn
from paddle.io import DataLoader
from paddlenlp.transformers import ErnieForGeneration
from paddlenlp.transformers import ErnieTokenizer, ErnieTinyTokenizer, BertTokenizer, ElectraTokenizer, RobertaTokenizer
from paddlenlp.datasets import Poetry
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.metrics import Rouge1, Rouge2
from paddlenlp.utils.log import logger
from encode import convert_example, after_padding
from decode import beam_search_infilling, post_process, greedy_search_infilling
# yapf: disable
parser = argparse.ArgumentParser('seq2seq model with ERNIE-GEN')
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, help="Path to pre-trained model or shortcut name selected in the list: "+ ", ".join(list(ErnieTokenizer.pretrained_init_configuration.keys())))
parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.",)
parser.add_argument('--max_encode_len', type=int, default=5, help="The max encoding sentence length")
parser.add_argument('--max_decode_len', type=int, default=5, help="The max decoding sentence length")
parser.add_argument("--batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.", )
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.1, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--num_epochs", default=3, type=int, help="Total number of training epochs to perform.", )
parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_epochs.",)
parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Linear warmup proportion.")
parser.add_argument("--logging_steps", type=int, default=1, help="Log every X updates steps.")
parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every X updates steps.")
parser.add_argument("--n_gpu", type=int, default=1, help="Number of gpus to use, 0 for cpu.")
parser.add_argument('--beam_width', type=int, default=1, help="Beam search width")
parser.add_argument('--noise_prob', type=float, default=0., help='Probability of a token being replaced')
parser.add_argument('--use_random_noise', action='store_true', help='If set, replace target tokens with a random token from the vocabulary, else replace with `[NOISE]`')
parser.add_argument('--label_smooth', type=float, default=0., help="The soft label smooth rate")
parser.add_argument('--length_penalty', type=float, default=1.0, help="The length penalty during decoding")
parser.add_argument('--init_checkpoint', type=str, default=None, help='Checkpoint to warm start from')
parser.add_argument('--save_dir', type=str, default=None, help='Model output directory')
# yapf: enable
args = parser.parse_args()
def evaluate(model, data_loader, tokenizer, rouge1, rouge2, attn_id,
tgt_type_id, args):
model.eval()
vocab = tokenizer.vocab
eos_id = vocab[tokenizer.sep_token]
sos_id = vocab[tokenizer.cls_token]
pad_id = vocab[tokenizer.pad_token]
unk_id = vocab[tokenizer.unk_token]
vocab_size = len(vocab)
evaluated_sentences_ids = []
reference_sentences_ids = []
logger.info("Evaluating...")
for data in tqdm(data_loader):
(src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _,
raw_tgt_labels) = data # never use target when infer
# Use greedy_search_infilling or beam_search_infilling to get predictions
output_ids = beam_search_infilling(
model,
src_ids,
src_sids,
eos_id=eos_id,
sos_id=sos_id,
attn_id=attn_id,
pad_id=pad_id,
unk_id=unk_id,
vocab_size=vocab_size,
max_decode_len=args.max_decode_len,
max_encode_len=args.max_encode_len,
beam_width=args.beam_width,
length_penalty=args.length_penalty,
tgt_type_id=tgt_type_id)
for ids in output_ids.tolist():
if eos_id in ids:
ids = ids[:ids.index(eos_id)]
evaluated_sentences_ids.append(ids)
for ids in raw_tgt_labels.numpy().tolist():
ids = ids[:ids.index(eos_id)]
reference_sentences_ids.append(ids)
score1 = rouge1.score(evaluated_sentences_ids, reference_sentences_ids)
score2 = rouge2.score(evaluated_sentences_ids, reference_sentences_ids)
logger.info("Rouge-1: %.5f ,Rouge-2: %.5f" % (score1 * 100, score2 * 100))
evaluated_sentences = []
reference_sentences = []
for ids in reference_sentences_ids[:5]:
reference_sentences.append(''.join(
map(post_process, vocab.to_tokens(ids))))
for ids in evaluated_sentences_ids[:5]:
evaluated_sentences.append(''.join(
map(post_process, vocab.to_tokens(ids))))
logger.debug(reference_sentences)
logger.debug(evaluated_sentences)
model.train()
def train():
paddle.set_device("gpu" if args.n_gpu else "cpu")
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
if "ernie-tiny" in args.model_name_or_path:
tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
elif "ernie" in args.model_name_or_path:
tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
elif "electra" in args.model_name_or_path:
tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
else:
tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
if args.init_checkpoint:
model_state = paddle.load(args.init_checkpoint)
model.set_state_dict(model_state)
train_dataset, dev_dataset = Poetry.get_datasets(['train', 'dev'])
attn_id = tokenizer.vocab[
'[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
tgt_type_id = model.sent_emb.weight.shape[0] - 1
trans_func = convert_example(
tokenizer=tokenizer,
attn_id=attn_id,
tgt_type_id=tgt_type_id,
max_encode_len=args.max_encode_len,
max_decode_len=args.max_decode_len,
noise_prob=args.noise_prob,
        use_random_noise=args.use_random_noise)
train_dataset = train_dataset.apply(trans_func, lazy=True)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.batch_size, shuffle=True)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_ids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_pids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_sids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_ids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_pids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_sids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # attn_ids
Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_labels
): after_padding(fn(samples))
train_data_loader = DataLoader(
dataset=train_dataset,
batch_sampler=train_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
dev_dataset = dev_dataset.apply(trans_func, lazy=True)
dev_batch_sampler = paddle.io.BatchSampler(
dev_dataset, batch_size=args.batch_size, shuffle=False)
dev_data_loader = DataLoader(
dataset=dev_dataset,
batch_sampler=dev_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
label_num = model.word_emb.weight.shape[0]
if paddle.distributed.get_world_size() > 1:
model = paddle.DataParallel(model)
max_steps = (len(train_data_loader) * args.num_epochs)
lr_scheduler = paddle.optimizer.lr.LambdaDecay(
args.learning_rate,
lambda current_step, num_warmup_steps=max_steps*args.warmup_proportion,
num_training_steps=max_steps: float(
current_step) / float(max(1, num_warmup_steps))
if current_step < num_warmup_steps else max(
0.0,
float(num_training_steps - current_step) / float(
max(1, num_training_steps - num_warmup_steps))))
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
epsilon=args.adam_epsilon,
parameters=model.parameters(),
weight_decay=args.weight_decay,
grad_clip=nn.ClipGradByGlobalNorm(1.0),
apply_decay_param_fun=lambda x: x in [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
])
rouge1 = Rouge1()
rouge2 = Rouge2()
global_step = 1
tic_train = time.time()
for epoch in range(args.num_epochs):
for step, batch in enumerate(train_data_loader, start=1):
(src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids, attn_ids,
mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
tgt_labels, _) = batch
_, __, info = model(
src_ids,
sent_ids=src_sids,
pos_ids=src_pids,
attn_bias=mask_src_2_src,
encode_only=True)
cached_k, cached_v = info['caches']
_, __, info = model(
tgt_ids,
sent_ids=tgt_sids,
pos_ids=tgt_pids,
attn_bias=mask_tgt_2_srctgt,
past_cache=(cached_k, cached_v),
encode_only=True)
cached_k2, cached_v2 = info['caches']
past_cache_k = [
paddle.concat([k, k2], 1) for k, k2 in zip(cached_k, cached_k2)
]
past_cache_v = [
paddle.concat([v, v2], 1) for v, v2 in zip(cached_v, cached_v2)
]
if args.label_smooth > 0.:
tgt_labels = nn.functional.label_smooth(
nn.functional.one_hot(tgt_labels, label_num),
epsilon=args.label_smooth)
loss, _, __ = model(
attn_ids,
sent_ids=tgt_sids,
pos_ids=tgt_pids,
attn_bias=mask_attn_2_srctgtattn,
past_cache=(past_cache_k, past_cache_v),
tgt_labels=tgt_labels,
tgt_pos=paddle.nonzero(attn_ids == attn_id))
if global_step % args.logging_steps == 0:
if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
logger.info(
"global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, lr: %.3e"
% (global_step, epoch, step, loss, args.logging_steps /
(time.time() - tic_train), lr_scheduler.get_lr()))
tic_train = time.time()
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_gradients()
if global_step % args.save_steps == 0 and (
(not args.n_gpu > 1) or paddle.distributed.get_rank() == 0):
evaluate(model, dev_data_loader, tokenizer, rouge1, rouge2,
attn_id, tgt_type_id, args)
output_dir = os.path.join(args.output_dir,
"model_%d" % global_step)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = model._layers if isinstance(
model, paddle.DataParallel) else model
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
global_step += 1
if __name__ == "__main__":
if args.n_gpu > 1:
paddle.distributed.spawn(train, nprocs=args.n_gpu)
else:
train()
......@@ -21,3 +21,5 @@ from .ptb import *
from .squad import *
from .translation import *
from .dureader import *
from .cnndm import *
from .poetry import *
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import warnings
from paddle.io import Dataset
from paddle.dataset.common import DATA_HOME, md5file
from paddle.utils.download import get_path_from_url
from .dataset import TSVDataset
__all__ = ['CnnDm']
class CnnDm(TSVDataset):
URL = "https://ernie-github.cdn.bcebos.com/data-cnndm.tar.gz"
MD5 = None
SEGMENT_INFO = collections.namedtuple(
'SEGMENT_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SEGMENTS = {
'train': SEGMENT_INFO(
os.path.join('cnndm', 'train', '1'),
'8b10ed0ae31e71e8cd9105a6978d8970', (1, 2), 0),
'dev': SEGMENT_INFO(
os.path.join('cnndm', 'dev', '1'),
'7cb22f9cac04a285790a91cebba75260', (1, 2), 0),
}
def __init__(self, segment='train', root=None, **kwargs):
default_root = os.path.join(DATA_HOME)
filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
segment]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
if not os.path.exists(fullname) or (data_hash and
not md5file(fullname) == data_hash):
            if root is not None:  # warn only when a root was explicitly given
warnings.warn(
'md5 check failed for {}, download {} data to {}'.format(
filename, self.__class__.__name__, default_root))
path = get_path_from_url(self.URL, default_root, self.MD5)
fullname = os.path.join(default_root, filename)
super(CnnDm, self).__init__(
fullname,
field_indices=field_indices,
num_discard_samples=num_discard_samples,
**kwargs)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import warnings
from paddle.io import Dataset
from paddle.dataset.common import DATA_HOME, md5file
from paddle.utils.download import get_path_from_url
from .dataset import TSVDataset
__all__ = ['Poetry']
class Poetry(TSVDataset):
URL = "https://paddlenlp.bj.bcebos.com/datasets/poetry.tar.gz"
MD5 = '8edd7eda1b273145b70ef29c82cd622b'
SEGMENT_INFO = collections.namedtuple(
'SEGMENT_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SEGMENTS = {
'train': SEGMENT_INFO(
os.path.join('poetry', 'train.tsv'),
'176c6202b5e71656ae7e7848eec4c54f', (0, 1), 0),
'dev': SEGMENT_INFO(
os.path.join('poetry', 'dev.tsv'),
'737e4b6da5facdc0ac33fe688df19931', (0, 1), 0),
'test': SEGMENT_INFO(
os.path.join('poetry', 'test.tsv'),
'1dca907b2d712730c7c828f8acee7431', (0, 1), 0),
}
def __init__(self, segment='train', root=None, **kwargs):
default_root = os.path.join(DATA_HOME, 'poetry')
filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
segment]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
if not os.path.exists(fullname) or (data_hash and
not md5file(fullname) == data_hash):
            if root is not None:  # warn only when a root was explicitly given
warnings.warn(
'md5 check failed for {}, download {} data to {}'.format(
filename, self.__class__.__name__, default_root))
path = get_path_from_url(self.URL, default_root, self.MD5)
fullname = os.path.join(default_root, filename)
super(Poetry, self).__init__(
fullname,
field_indices=field_indices,
num_discard_samples=num_discard_samples,
**kwargs)
......@@ -15,5 +15,5 @@
from .perplexity import Perplexity
from .chunk import ChunkEvaluator
from .bleu import BLEU, BLEUForDuReader
from .rouge import RougeL, RougeLForDuReader, RougeN, Rouge1, Rouge2
from .glue import AccuracyAndF1, Mcc, PearsonAndSpearman
......@@ -20,6 +20,94 @@ from .utils import default_trans_func
__all__ = ['RougeL', 'RougeLForDuReader', 'RougeN', 'Rouge1', 'Rouge2']
class RougeN():
def __init__(self, n):
self.n = n
    def _get_ngrams(self, words):
        """Calculate the set of n-grams for a single sentence (a list of ids).
        """
ngram_set = set()
max_index_ngram_start = len(words) - self.n
for i in range(max_index_ngram_start + 1):
ngram_set.add(tuple(words[i:i + self.n]))
return ngram_set
def score(self, evaluated_sentences_ids, reference_sentences_ids):
overlapping_count, reference_count = self.compute(
evaluated_sentences_ids, reference_sentences_ids)
return overlapping_count / reference_count
def compute(self, evaluated_sentences_ids, reference_sentences_ids):
"""
Args:
evaluated_sentences (list): the sentences ids predicted by the model.
reference_sentences (list): the referenced sentences ids. Its size should be same as evaluated_sentences.
Returns:
overlapping_count (int): the overlapping n-gram count.
reference_count (int): the reference sentences n-gram count.
"""
if len(evaluated_sentences_ids) <= 0 or len(
reference_sentences_ids) <= 0:
raise ValueError("Collections must contain at least 1 sentence.")
reference_count = 0
overlapping_count = 0
for evaluated_sentence_ids, reference_sentence_ids in zip(
evaluated_sentences_ids, reference_sentences_ids):
evaluated_ngrams = self._get_ngrams(evaluated_sentence_ids)
reference_ngrams = self._get_ngrams(reference_sentence_ids)
reference_count += len(reference_ngrams)
# Gets the overlapping ngrams between evaluated and reference
overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
overlapping_count += len(overlapping_ngrams)
return overlapping_count, reference_count
    def accumulate(self):
        """
        Return the accumulated n-gram recall over all mini-batches, i.e.
        the overlapping n-gram count divided by the reference n-gram count.
        Returns:
            float: accumulated ROUGE-N score (recall).
        """
        rouge_score = self.overlapping_count / self.reference_count
        return rouge_score
def reset(self):
"""
Reset function empties the evaluation memory for previous mini-batches.
"""
self.overlapping_count = 0
self.reference_count = 0
def name(self):
"""
Return name of metric instance.
"""
return "Rouge-%s" % self.n
    def update(self, overlapping_count, reference_count):
        """
        Args:
            overlapping_count (int): the overlapping n-gram count of one mini-batch.
            reference_count (int): the reference n-gram count of one mini-batch.
        """
        self.overlapping_count += overlapping_count
        self.reference_count += reference_count
class Rouge1(RougeN):
def __init__(self):
super(Rouge1, self).__init__(n=1)
class Rouge2(RougeN):
def __init__(self):
super(Rouge2, self).__init__(n=2)
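# A toy check (hypothetical ids): RougeN.score is n-gram recall, i.e. the
# number of overlapping n-grams divided by the reference n-gram count:
#   Rouge1().score([[1, 2, 3]], [[1, 2, 4]])  -> 2 / 3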
class RougeL(paddle.metric.Metric):
r'''
Rouge-L is Recall-Oriented Understudy for Gisting Evaluation based on Longest Common Subsequence (LCS).
......
......@@ -24,3 +24,4 @@ from .roberta.tokenizer import *
from .electra.modeling import *
from .electra.tokenizer import *
from .transformer.modeling import *
from .ernie_gen.modeling import ErnieForGeneration
......@@ -24,8 +24,8 @@ import paddle.nn.functional as F
from .. import PretrainedModel, register_base_model
__all__ = [
    'ElectraModel', 'ElectraPretrainedModel', 'ElectraForTotalPretraining',
    'ElectraDiscriminator', 'ElectraGenerator', 'ElectraClassificationHead',
'ElectraForSequenceClassification', 'ElectraForTokenClassification',
'ElectraPretrainingCriterion'
]
......
......@@ -59,6 +59,12 @@ class ErnieTokenizer(PretrainedTokenizer):
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_base/vocab.txt",
"ernie-2.0-large-en":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_large/vocab.txt",
"ernie-gen-base-en":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-base-en/vocab.txt",
"ernie-gen-large-en":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-large/vocab.txt",
"ernie-gen-large-430g-en":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-large-430g/vocab.txt",
}
}
pretrained_init_configuration = {
......@@ -71,6 +77,15 @@ class ErnieTokenizer(PretrainedTokenizer):
"ernie-2.0-large-en": {
"do_lower_case": True
},
"ernie-gen-base-en": {
"do_lower_case": True
},
"ernie-gen-large-en": {
"do_lower_case": True
},
"ernie-gen-large-430g-en": {
"do_lower_case": True
},
}
def __init__(self,
......
This diff is collapsed.
{"embeddings.word_embeddings.weight": "word_emb.weight", "embeddings.position_embeddings.weight": "pos_emb.weight", "embeddings.token_type_embeddings.weight": "sent_emb.weight", "embeddings.layer_norm.weight": "ln.weight", "embeddings.layer_norm.bias": "ln.bias", "encoder.layers.0.self_attn.q_proj.weight": "encoder_stack.block.0.attn.q.weight", "encoder.layers.0.self_attn.q_proj.bias": "encoder_stack.block.0.attn.q.bias", "encoder.layers.0.self_attn.k_proj.weight": "encoder_stack.block.0.attn.k.weight", "encoder.layers.0.self_attn.k_proj.bias": "encoder_stack.block.0.attn.k.bias", "encoder.layers.0.self_attn.v_proj.weight": "encoder_stack.block.0.attn.v.weight", "encoder.layers.0.self_attn.v_proj.bias": "encoder_stack.block.0.attn.v.bias", "encoder.layers.0.self_attn.out_proj.weight": "encoder_stack.block.0.attn.o.weight", "encoder.layers.0.self_attn.out_proj.bias": "encoder_stack.block.0.attn.o.bias", "encoder.layers.1.self_attn.q_proj.weight": "encoder_stack.block.1.attn.q.weight", "encoder.layers.1.self_attn.q_proj.bias": "encoder_stack.block.1.attn.q.bias", "encoder.layers.1.self_attn.k_proj.weight": "encoder_stack.block.1.attn.k.weight", "encoder.layers.1.self_attn.k_proj.bias": "encoder_stack.block.1.attn.k.bias", "encoder.layers.1.self_attn.v_proj.weight": "encoder_stack.block.1.attn.v.weight", "encoder.layers.1.self_attn.v_proj.bias": "encoder_stack.block.1.attn.v.bias", "encoder.layers.1.self_attn.out_proj.weight": "encoder_stack.block.1.attn.o.weight", "encoder.layers.1.self_attn.out_proj.bias": "encoder_stack.block.1.attn.o.bias", "encoder.layers.2.self_attn.q_proj.weight": "encoder_stack.block.2.attn.q.weight", "encoder.layers.2.self_attn.q_proj.bias": "encoder_stack.block.2.attn.q.bias", "encoder.layers.2.self_attn.k_proj.weight": "encoder_stack.block.2.attn.k.weight", "encoder.layers.2.self_attn.k_proj.bias": "encoder_stack.block.2.attn.k.bias", "encoder.layers.2.self_attn.v_proj.weight": "encoder_stack.block.2.attn.v.weight", "encoder.layers.2.self_attn.v_proj.bias": "encoder_stack.block.2.attn.v.bias", "encoder.layers.2.self_attn.out_proj.weight": "encoder_stack.block.2.attn.o.weight", "encoder.layers.2.self_attn.out_proj.bias": "encoder_stack.block.2.attn.o.bias", "encoder.layers.3.self_attn.q_proj.weight": "encoder_stack.block.3.attn.q.weight", "encoder.layers.3.self_attn.q_proj.bias": "encoder_stack.block.3.attn.q.bias", "encoder.layers.3.self_attn.k_proj.weight": "encoder_stack.block.3.attn.k.weight", "encoder.layers.3.self_attn.k_proj.bias": "encoder_stack.block.3.attn.k.bias", "encoder.layers.3.self_attn.v_proj.weight": "encoder_stack.block.3.attn.v.weight", "encoder.layers.3.self_attn.v_proj.bias": "encoder_stack.block.3.attn.v.bias", "encoder.layers.3.self_attn.out_proj.weight": "encoder_stack.block.3.attn.o.weight", "encoder.layers.3.self_attn.out_proj.bias": "encoder_stack.block.3.attn.o.bias", "encoder.layers.4.self_attn.q_proj.weight": "encoder_stack.block.4.attn.q.weight", "encoder.layers.4.self_attn.q_proj.bias": "encoder_stack.block.4.attn.q.bias", "encoder.layers.4.self_attn.k_proj.weight": "encoder_stack.block.4.attn.k.weight", "encoder.layers.4.self_attn.k_proj.bias": "encoder_stack.block.4.attn.k.bias", "encoder.layers.4.self_attn.v_proj.weight": "encoder_stack.block.4.attn.v.weight", "encoder.layers.4.self_attn.v_proj.bias": "encoder_stack.block.4.attn.v.bias", "encoder.layers.4.self_attn.out_proj.weight": "encoder_stack.block.4.attn.o.weight", "encoder.layers.4.self_attn.out_proj.bias": "encoder_stack.block.4.attn.o.bias", 
"encoder.layers.5.self_attn.q_proj.weight": "encoder_stack.block.5.attn.q.weight", "encoder.layers.5.self_attn.q_proj.bias": "encoder_stack.block.5.attn.q.bias", "encoder.layers.5.self_attn.k_proj.weight": "encoder_stack.block.5.attn.k.weight", "encoder.layers.5.self_attn.k_proj.bias": "encoder_stack.block.5.attn.k.bias", "encoder.layers.5.self_attn.v_proj.weight": "encoder_stack.block.5.attn.v.weight", "encoder.layers.5.self_attn.v_proj.bias": "encoder_stack.block.5.attn.v.bias", "encoder.layers.5.self_attn.out_proj.weight": "encoder_stack.block.5.attn.o.weight", "encoder.layers.5.self_attn.out_proj.bias": "encoder_stack.block.5.attn.o.bias", "encoder.layers.6.self_attn.q_proj.weight": "encoder_stack.block.6.attn.q.weight", "encoder.layers.6.self_attn.q_proj.bias": "encoder_stack.block.6.attn.q.bias", "encoder.layers.6.self_attn.k_proj.weight": "encoder_stack.block.6.attn.k.weight", "encoder.layers.6.self_attn.k_proj.bias": "encoder_stack.block.6.attn.k.bias", "encoder.layers.6.self_attn.v_proj.weight": "encoder_stack.block.6.attn.v.weight", "encoder.layers.6.self_attn.v_proj.bias": "encoder_stack.block.6.attn.v.bias", "encoder.layers.6.self_attn.out_proj.weight": "encoder_stack.block.6.attn.o.weight", "encoder.layers.6.self_attn.out_proj.bias": "encoder_stack.block.6.attn.o.bias", "encoder.layers.7.self_attn.q_proj.weight": "encoder_stack.block.7.attn.q.weight", "encoder.layers.7.self_attn.q_proj.bias": "encoder_stack.block.7.attn.q.bias", "encoder.layers.7.self_attn.k_proj.weight": "encoder_stack.block.7.attn.k.weight", "encoder.layers.7.self_attn.k_proj.bias": "encoder_stack.block.7.attn.k.bias", "encoder.layers.7.self_attn.v_proj.weight": "encoder_stack.block.7.attn.v.weight", "encoder.layers.7.self_attn.v_proj.bias": "encoder_stack.block.7.attn.v.bias", "encoder.layers.7.self_attn.out_proj.weight": "encoder_stack.block.7.attn.o.weight", "encoder.layers.7.self_attn.out_proj.bias": "encoder_stack.block.7.attn.o.bias", "encoder.layers.8.self_attn.q_proj.weight": "encoder_stack.block.8.attn.q.weight", "encoder.layers.8.self_attn.q_proj.bias": "encoder_stack.block.8.attn.q.bias", "encoder.layers.8.self_attn.k_proj.weight": "encoder_stack.block.8.attn.k.weight", "encoder.layers.8.self_attn.k_proj.bias": "encoder_stack.block.8.attn.k.bias", "encoder.layers.8.self_attn.v_proj.weight": "encoder_stack.block.8.attn.v.weight", "encoder.layers.8.self_attn.v_proj.bias": "encoder_stack.block.8.attn.v.bias", "encoder.layers.8.self_attn.out_proj.weight": "encoder_stack.block.8.attn.o.weight", "encoder.layers.8.self_attn.out_proj.bias": "encoder_stack.block.8.attn.o.bias", "encoder.layers.9.self_attn.q_proj.weight": "encoder_stack.block.9.attn.q.weight", "encoder.layers.9.self_attn.q_proj.bias": "encoder_stack.block.9.attn.q.bias", "encoder.layers.9.self_attn.k_proj.weight": "encoder_stack.block.9.attn.k.weight", "encoder.layers.9.self_attn.k_proj.bias": "encoder_stack.block.9.attn.k.bias", "encoder.layers.9.self_attn.v_proj.weight": "encoder_stack.block.9.attn.v.weight", "encoder.layers.9.self_attn.v_proj.bias": "encoder_stack.block.9.attn.v.bias", "encoder.layers.9.self_attn.out_proj.weight": "encoder_stack.block.9.attn.o.weight", "encoder.layers.9.self_attn.out_proj.bias": "encoder_stack.block.9.attn.o.bias", "encoder.layers.10.self_attn.q_proj.weight": "encoder_stack.block.10.attn.q.weight", "encoder.layers.10.self_attn.q_proj.bias": "encoder_stack.block.10.attn.q.bias", "encoder.layers.10.self_attn.k_proj.weight": "encoder_stack.block.10.attn.k.weight", "encoder.layers.10.self_attn.k_proj.bias": 
"encoder_stack.block.10.attn.k.bias", "encoder.layers.10.self_attn.v_proj.weight": "encoder_stack.block.10.attn.v.weight", "encoder.layers.10.self_attn.v_proj.bias": "encoder_stack.block.10.attn.v.bias", "encoder.layers.10.self_attn.out_proj.weight": "encoder_stack.block.10.attn.o.weight", "encoder.layers.10.self_attn.out_proj.bias": "encoder_stack.block.10.attn.o.bias", "encoder.layers.11.self_attn.q_proj.weight": "encoder_stack.block.11.attn.q.weight", "encoder.layers.11.self_attn.q_proj.bias": "encoder_stack.block.11.attn.q.bias", "encoder.layers.11.self_attn.k_proj.weight": "encoder_stack.block.11.attn.k.weight", "encoder.layers.11.self_attn.k_proj.bias": "encoder_stack.block.11.attn.k.bias", "encoder.layers.11.self_attn.v_proj.weight": "encoder_stack.block.11.attn.v.weight", "encoder.layers.11.self_attn.v_proj.bias": "encoder_stack.block.11.attn.v.bias", "encoder.layers.11.self_attn.out_proj.weight": "encoder_stack.block.11.attn.o.weight", "encoder.layers.11.self_attn.out_proj.bias": "encoder_stack.block.11.attn.o.bias", "encoder.layers.0.linear1.weight": "encoder_stack.block.0.ffn.i.weight", "encoder.layers.0.linear1.bias": "encoder_stack.block.0.ffn.i.bias", "encoder.layers.0.linear2.weight": "encoder_stack.block.0.ffn.o.weight", "encoder.layers.0.linear2.bias": "encoder_stack.block.0.ffn.o.bias", "encoder.layers.1.linear1.weight": "encoder_stack.block.1.ffn.i.weight", "encoder.layers.1.linear1.bias": "encoder_stack.block.1.ffn.i.bias", "encoder.layers.1.linear2.weight": "encoder_stack.block.1.ffn.o.weight", "encoder.layers.1.linear2.bias": "encoder_stack.block.1.ffn.o.bias", "encoder.layers.2.linear1.weight": "encoder_stack.block.2.ffn.i.weight", "encoder.layers.2.linear1.bias": "encoder_stack.block.2.ffn.i.bias", "encoder.layers.2.linear2.weight": "encoder_stack.block.2.ffn.o.weight", "encoder.layers.2.linear2.bias": "encoder_stack.block.2.ffn.o.bias", "encoder.layers.3.linear1.weight": "encoder_stack.block.3.ffn.i.weight", "encoder.layers.3.linear1.bias": "encoder_stack.block.3.ffn.i.bias", "encoder.layers.3.linear2.weight": "encoder_stack.block.3.ffn.o.weight", "encoder.layers.3.linear2.bias": "encoder_stack.block.3.ffn.o.bias", "encoder.layers.4.linear1.weight": "encoder_stack.block.4.ffn.i.weight", "encoder.layers.4.linear1.bias": "encoder_stack.block.4.ffn.i.bias", "encoder.layers.4.linear2.weight": "encoder_stack.block.4.ffn.o.weight", "encoder.layers.4.linear2.bias": "encoder_stack.block.4.ffn.o.bias", "encoder.layers.5.linear1.weight": "encoder_stack.block.5.ffn.i.weight", "encoder.layers.5.linear1.bias": "encoder_stack.block.5.ffn.i.bias", "encoder.layers.5.linear2.weight": "encoder_stack.block.5.ffn.o.weight", "encoder.layers.5.linear2.bias": "encoder_stack.block.5.ffn.o.bias", "encoder.layers.6.linear1.weight": "encoder_stack.block.6.ffn.i.weight", "encoder.layers.6.linear1.bias": "encoder_stack.block.6.ffn.i.bias", "encoder.layers.6.linear2.weight": "encoder_stack.block.6.ffn.o.weight", "encoder.layers.6.linear2.bias": "encoder_stack.block.6.ffn.o.bias", "encoder.layers.7.linear1.weight": "encoder_stack.block.7.ffn.i.weight", "encoder.layers.7.linear1.bias": "encoder_stack.block.7.ffn.i.bias", "encoder.layers.7.linear2.weight": "encoder_stack.block.7.ffn.o.weight", "encoder.layers.7.linear2.bias": "encoder_stack.block.7.ffn.o.bias", "encoder.layers.8.linear1.weight": "encoder_stack.block.8.ffn.i.weight", "encoder.layers.8.linear1.bias": "encoder_stack.block.8.ffn.i.bias", "encoder.layers.8.linear2.weight": "encoder_stack.block.8.ffn.o.weight", 
"encoder.layers.8.linear2.bias": "encoder_stack.block.8.ffn.o.bias", "encoder.layers.9.linear1.weight": "encoder_stack.block.9.ffn.i.weight", "encoder.layers.9.linear1.bias": "encoder_stack.block.9.ffn.i.bias", "encoder.layers.9.linear2.weight": "encoder_stack.block.9.ffn.o.weight", "encoder.layers.9.linear2.bias": "encoder_stack.block.9.ffn.o.bias", "encoder.layers.10.linear1.weight": "encoder_stack.block.10.ffn.i.weight", "encoder.layers.10.linear1.bias": "encoder_stack.block.10.ffn.i.bias", "encoder.layers.10.linear2.weight": "encoder_stack.block.10.ffn.o.weight", "encoder.layers.10.linear2.bias": "encoder_stack.block.10.ffn.o.bias", "encoder.layers.11.linear1.weight": "encoder_stack.block.11.ffn.i.weight", "encoder.layers.11.linear1.bias": "encoder_stack.block.11.ffn.i.bias", "encoder.layers.11.linear2.weight": "encoder_stack.block.11.ffn.o.weight", "encoder.layers.11.linear2.bias": "encoder_stack.block.11.ffn.o.bias", "encoder.layers.0.norm1.weight": "encoder_stack.block.0.ln1.weight", "encoder.layers.0.norm1.bias": "encoder_stack.block.0.ln1.bias", "encoder.layers.1.norm1.weight": "encoder_stack.block.1.ln1.weight", "encoder.layers.1.norm1.bias": "encoder_stack.block.1.ln1.bias", "encoder.layers.2.norm1.weight": "encoder_stack.block.2.ln1.weight", "encoder.layers.2.norm1.bias": "encoder_stack.block.2.ln1.bias", "encoder.layers.3.norm1.weight": "encoder_stack.block.3.ln1.weight", "encoder.layers.3.norm1.bias": "encoder_stack.block.3.ln1.bias", "encoder.layers.4.norm1.weight": "encoder_stack.block.4.ln1.weight", "encoder.layers.4.norm1.bias": "encoder_stack.block.4.ln1.bias", "encoder.layers.5.norm1.weight": "encoder_stack.block.5.ln1.weight", "encoder.layers.5.norm1.bias": "encoder_stack.block.5.ln1.bias", "encoder.layers.6.norm1.weight": "encoder_stack.block.6.ln1.weight", "encoder.layers.6.norm1.bias": "encoder_stack.block.6.ln1.bias", "encoder.layers.7.norm1.weight": "encoder_stack.block.7.ln1.weight", "encoder.layers.7.norm1.bias": "encoder_stack.block.7.ln1.bias", "encoder.layers.8.norm1.weight": "encoder_stack.block.8.ln1.weight", "encoder.layers.8.norm1.bias": "encoder_stack.block.8.ln1.bias", "encoder.layers.9.norm1.weight": "encoder_stack.block.9.ln1.weight", "encoder.layers.9.norm1.bias": "encoder_stack.block.9.ln1.bias", "encoder.layers.10.norm1.weight": "encoder_stack.block.10.ln1.weight", "encoder.layers.10.norm1.bias": "encoder_stack.block.10.ln1.bias", "encoder.layers.11.norm1.weight": "encoder_stack.block.11.ln1.weight", "encoder.layers.11.norm1.bias": "encoder_stack.block.11.ln1.bias", "encoder.layers.0.norm2.weight": "encoder_stack.block.0.ln2.weight", "encoder.layers.0.norm2.bias": "encoder_stack.block.0.ln2.bias", "encoder.layers.1.norm2.weight": "encoder_stack.block.1.ln2.weight", "encoder.layers.1.norm2.bias": "encoder_stack.block.1.ln2.bias", "encoder.layers.2.norm2.weight": "encoder_stack.block.2.ln2.weight", "encoder.layers.2.norm2.bias": "encoder_stack.block.2.ln2.bias", "encoder.layers.3.norm2.weight": "encoder_stack.block.3.ln2.weight", "encoder.layers.3.norm2.bias": "encoder_stack.block.3.ln2.bias", "encoder.layers.4.norm2.weight": "encoder_stack.block.4.ln2.weight", "encoder.layers.4.norm2.bias": "encoder_stack.block.4.ln2.bias", "encoder.layers.5.norm2.weight": "encoder_stack.block.5.ln2.weight", "encoder.layers.5.norm2.bias": "encoder_stack.block.5.ln2.bias", "encoder.layers.6.norm2.weight": "encoder_stack.block.6.ln2.weight", "encoder.layers.6.norm2.bias": "encoder_stack.block.6.ln2.bias", "encoder.layers.7.norm2.weight": 
"encoder_stack.block.7.ln2.weight", "encoder.layers.7.norm2.bias": "encoder_stack.block.7.ln2.bias", "encoder.layers.8.norm2.weight": "encoder_stack.block.8.ln2.weight", "encoder.layers.8.norm2.bias": "encoder_stack.block.8.ln2.bias", "encoder.layers.9.norm2.weight": "encoder_stack.block.9.ln2.weight", "encoder.layers.9.norm2.bias": "encoder_stack.block.9.ln2.bias", "encoder.layers.10.norm2.weight": "encoder_stack.block.10.ln2.weight", "encoder.layers.10.norm2.bias": "encoder_stack.block.10.ln2.bias", "encoder.layers.11.norm2.weight": "encoder_stack.block.11.ln2.weight", "encoder.layers.11.norm2.bias": "encoder_stack.block.11.ln2.bias", "pooler.dense.weight": "pooler.weight", "pooler.dense.bias": "pooler.bias"}
\ No newline at end of file
......@@ -19,6 +19,7 @@ from .. import PretrainedModel, register_base_model
__all__ = [
'RobertaModel',
'RobertaPretrainedModel',
'RobertaForSequenceClassification',
'RobertaForTokenClassification',
'RobertaForQuestionAnswering',
......