diff --git a/PaddleNLP/examples/lexical_analysis/README.md b/PaddleNLP/examples/lexical_analysis/README.md
index debbafc525e301d8ada04bc5f319e4f8f6a4076e..98cbfa1f867d930e1f2b6d29b8cf81b1a74a0783 100644
--- a/PaddleNLP/examples/lexical_analysis/README.md
+++ b/PaddleNLP/examples/lexical_analysis/README.md
@@ -18,9 +18,9 @@
 
 - Python >= 3.6
 
-- PaddlePaddle >= 2.0.0rc1; see [Quick Install](https://www.paddlepaddle.org.cn/install/quick) for installation instructions.
+- paddlepaddle >= 2.0.0rc1; see [Quick Install](https://www.paddlepaddle.org.cn/install/quick) for installation instructions.
 
-- PaddleNLP >= 2.0.0b; install with `pip install paddlenlp>=2.0.0b`
+- paddlenlp >= 2.0.0b; install with `pip install paddlenlp>=2.0.0b`
 
 ### 2.2 Data Preparation
diff --git a/PaddleNLP/examples/text_generation/ernie-gen/README.md b/PaddleNLP/examples/text_generation/ernie-gen/README.md
index d05295e68503903d82b78e36094781d59ad769e1..af31ffa4967e68ebff91496a9344b80b370f0d7c 100644
--- a/PaddleNLP/examples/text_generation/ernie-gen/README.md
+++ b/PaddleNLP/examples/text_generation/ernie-gen/README.md
@@ -1,3 +1,125 @@
-# ERNIE-Gen
+# ERNIE-Gen: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation
 
-TBD
+## 1. Introduction
+
+**ERNIE-GEN is a pre-training and fine-tuning framework for generation tasks.** It is the first to add a **span-by-span generation** task to the pre-training stage, so that the model generates a semantically complete span at each step. During both pre-training and fine-tuning, an **infilling generation mechanism** and a **noise-aware mechanism** are used to alleviate the exposure bias problem. In addition, ERNIE-GEN adopts a **multi-span, multi-granularity target sampling** strategy that strengthens the correlation between source and target texts and enhances the interaction between the encoder and the decoder.
+
+![multi-flow-attention](https://github.com/PaddlePaddle/ERNIE/raw/repro/ernie-gen/.meta/multi-flow-attention.png)
+
+## 2. Quick Start
+
+### 2.1 Environment Setup
+
+- Python >= 3.6
+
+- paddlepaddle >= 2.0.0rc1; see [Quick Install](https://www.paddlepaddle.org.cn/install/quick) for installation instructions.
+
+- paddlenlp >= 2.0.0b; install with `pip install paddlenlp>=2.0.0b`
+
+### 2.2 Data Preparation
+
+This example ships with a classical Chinese poetry dataset. A sample line looks like this:
+
+```text
+画\002精\002禅\002室\002冷\002,\002方\002暑\002久\002徘\002徊\002。	不\002尽\002林\002端\002雪\002,\002长\002青\002石\002上\002苔\002。\002心\002闲\002对\002岩\002岫\002,\002目\002浄\002失\002尘\002埃\002。\002坐\002久\002清\002风\002至\002,\002疑\002从\002翠\002涧\002来\002。
+```
+
+Each line consists of two columns separated by a tab: the first column is the input (the opening lines of a poem) and the second is the output to generate (the following lines). Within each column, characters are separated by `\002` (see the parsing sketch after this README).
+
+The full dataset can be downloaded and extracted with:
+
+```bash
+wget --no-check-certificate https://paddlenlp.bj.bcebos.com/datasets/poetry.tar.gz
+tar xvf poetry.tar.gz
+```
+
+### 2.3 Fine-tuning
+
+Training runs on both CPU and GPU. When using GPUs, specify the visible cards first:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2 # multi-GPU training is supported
+```
+
+Launch training as follows:
+
+```bash
+python -u ./train.py \
+    --model_name_or_path ernie-1.0 \
+    --max_encode_len 24 \
+    --max_decode_len 72 \
+    --batch_size 48 \
+    --learning_rate 2e-5 \
+    --num_epochs 12 \
+    --logging_steps 1 \
+    --save_steps 1000 \
+    --output_dir ./tmp/ \
+    --n_gpu 3 \
+    # --init_checkpoint ./tmp/model_10000/model_state.pdparams
+```
+
+The parameters are:
+- `model_name_or_path`: selects a model configuration, together with its pre-trained weights and the tokenizer used during pre-training. A local directory containing the model files can also be given.
+- `max_encode_len`: the maximum input length; longer inputs are truncated.
+- `max_decode_len`: the maximum output length; longer outputs are truncated.
+- `batch_size`: the number of samples **per card** in each iteration.
+- `learning_rate`: the base learning rate; it is multiplied by the value produced by the learning rate scheduler to obtain the current learning rate.
+- `num_epochs`: the number of training epochs.
+- `logging_steps`: the logging interval.
+- `save_steps`: the interval for saving checkpoints and running evaluation.
+- `output_dir`: the directory where models are saved.
+- `n_gpu`: the number of GPUs to use. Set it to the desired number for multi-GPU training; set it to 0 to train on CPU.
+- `init_checkpoint`: the path of a checkpoint to load; setting it enables incremental training.
+
+### 2.4 Evaluation
+
+A checkpoint saved during training can be loaded and evaluated on the dev set as follows:
+
+```bash
+python -u ./eval.py \
+    --model_name_or_path ernie-1.0 \
+    --max_encode_len 24 \
+    --max_decode_len 72 \
+    --batch_size 48 \
+    --init_checkpoint ./tmp/model_10000/model_state.pdparams \
+    --use_gpu
+```
+
+The parameters are:
+- `model_name_or_path`: selects a model configuration, together with its pre-trained weights and the tokenizer used during pre-training. A local directory containing the model files can also be given.
+- `max_encode_len`: the maximum input length; longer inputs are truncated.
+- `max_decode_len`: the maximum output length; longer outputs are truncated.
+- `batch_size`: the number of samples **per card** in each iteration.
+- `init_checkpoint`: the path of the checkpoint to load.
+- `use_gpu`: run on GPU.
+
+### 2.5 Prediction
+
+Prediction on unlabeled data can be launched with:
+
+```bash
+python -u ./predict.py \
+    --model_name_or_path ernie-1.0 \
+    --max_encode_len 24 \
+    --max_decode_len 72 \
+    --batch_size 48 \
+    --init_checkpoint ./tmp/model_10000/model_state.pdparams \
+    --use_gpu
+```
+
+## Citation
+
+You can cite the ERNIE-Gen paper in the following format:
+
+```
+@article{xiao2020ernie-gen,
+  title={ERNIE-GEN: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation},
+  author={Xiao, Dongling and Zhang, Han and Li, Yukun and Sun, Yu and Tian, Hao and Wu, Hua and Wang, Haifeng},
+  journal={arXiv preprint arXiv:2001.11314},
+  year={2020}
+}
+```
+
+## How to Contribute
+
+If you can fix an issue or add a new feature, feel free to submit a PR. If the PR is accepted, we will score the contribution by quality and difficulty (0-5 points, the higher the better). Once you accumulate 10 points, you can contact us for an interview opportunity or a recommendation letter.
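The `\002`-separated format described in section 2.2 above can be checked with a few lines of plain Python. A minimal parsing sketch (the helper name `parse_poetry_line` is ours, not part of the example code; it assumes the tab-separated layout the README describes):

```python
def parse_poetry_line(line):
    """Split one dataset line: two tab-separated columns whose
    characters are joined by the '\x02' separator."""
    first, second = line.rstrip('\n').split('\t')
    return first.split('\x02'), second.split('\x02')


src, tgt = parse_poetry_line('画\x02精\x02禅\x02室\x02冷\t不\x02尽\x02林\x02端\x02雪')
print(src)  # ['画', '精', '禅', '室', '冷']
print(tgt)  # ['不', '尽', '林', '端', '雪']
```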
diff --git a/PaddleNLP/examples/text_generation/ernie-gen/decode.py b/PaddleNLP/examples/text_generation/ernie-gen/decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..c705de152cfc2054a4bf79a43ce037d98d9f0df7
--- /dev/null
+++ b/PaddleNLP/examples/text_generation/ernie-gen/decode.py
@@ -0,0 +1,340 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import re
+import argparse
+import logging
+import json
+from collections import namedtuple
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddlenlp.utils.log import logger
+
+
+def gen_bias(encoder_inputs, decoder_inputs, step):
+    decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2]
+    encoder_bsz, encoder_seqlen = encoder_inputs.shape[:2]
+    attn_bias = paddle.reshape(
+        paddle.arange(
+            0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1])
+    decoder_bias = paddle.cast(
+        (paddle.matmul(
+            attn_bias, 1. / attn_bias, transpose_y=True) >= 1.),
+        'float32')  #[1, decoderlen, decoderlen]
+    encoder_bias = paddle.unsqueeze(
+        paddle.cast(paddle.ones_like(encoder_inputs), 'float32'),
+        [1])  #[bsz, 1, encoderlen]
+    encoder_bias = paddle.expand(
+        encoder_bias, [encoder_bsz, decoder_seqlen,
+                       encoder_seqlen])  #[bsz, decoderlen, encoderlen]
+    decoder_bias = paddle.expand(
+        decoder_bias, [decoder_bsz, decoder_seqlen,
+                       decoder_seqlen])  #[bsz, decoderlen, decoderlen]
+    if step > 0:
+        bias = paddle.concat([
+            encoder_bias, paddle.ones([decoder_bsz, decoder_seqlen, step],
+                                      'float32'), decoder_bias
+        ], -1)
+    else:
+        bias = paddle.concat([encoder_bias, decoder_bias], -1)
+    return bias
+
+
+@paddle.no_grad()
+def greedy_search_infilling(model,
+                            q_ids,
+                            q_sids,
+                            sos_id,
+                            eos_id,
+                            attn_id,
+                            pad_id,
+                            unk_id,
+                            vocab_size,
+                            max_encode_len=640,
+                            max_decode_len=100,
+                            tgt_type_id=3):
+    _, logits, info = model(q_ids, q_sids)
+    d_batch, d_seqlen = q_ids.shape
+    seqlen = paddle.sum(paddle.cast(q_ids != 0, 'int64'), 1, keepdim=True)
+    has_stopped = np.zeros([d_batch], dtype=np.bool)
+    gen_seq_len = np.zeros([d_batch], dtype=np.int64)
+    output_ids = []
+
+    past_cache = info['caches']
+
+    cls_ids = paddle.ones([d_batch], dtype='int64') * sos_id
+    attn_ids = paddle.ones([d_batch], dtype='int64') * attn_id
+    ids = paddle.stack([cls_ids, attn_ids], -1)
+    for step in range(max_decode_len):
+        bias = gen_bias(q_ids, ids, step)
+        pos_ids = paddle.to_tensor(
+            np.tile(
+                np.array(
+                    [[step, step + 1]], dtype=np.int64), [d_batch, 1]))
+        pos_ids += seqlen
+        _, logits, info = model(
+            ids,
+            paddle.ones_like(ids) * tgt_type_id,
+            pos_ids=pos_ids,
+            attn_bias=bias,
+            past_cache=past_cache)
+
+        # never generate out-of-vocab, pad, unk or [ATTN] tokens
+        if logits.shape[-1] > vocab_size:
+            logits[:, :, vocab_size:] = 0
+        logits[:, :, pad_id] = 0
+        logits[:, :, unk_id] = 0
+        logits[:, :, attn_id] = 0
+
+        gen_ids = paddle.argmax(logits, -1)
+
+        past_cached_k, past_cached_v = past_cache
+        cached_k, cached_v = info['caches']
+        cached_k = [
+            paddle.concat([pk, k[:, :1, :]], 1)
+            for pk, k in zip(past_cached_k, cached_k)
+        ]  # concat cached
+        cached_v = [
+            paddle.concat([pv, v[:, :1, :]], 1)
+            for pv, v in zip(past_cached_v, cached_v)
+        ]
+        past_cache = (cached_k, cached_v)
+
+        gen_ids = gen_ids[:, 1]
+        ids = paddle.stack([gen_ids, attn_ids], 1)
+
+        gen_ids = gen_ids.numpy()
+        has_stopped |= (gen_ids == eos_id).astype(np.bool)
+        gen_seq_len += (1 - has_stopped.astype(np.int64))
+        output_ids.append(gen_ids.tolist())
+        if has_stopped.all():
+            break
+    output_ids = np.array(output_ids).transpose([1, 0])
+    return output_ids
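A quick shape check of `gen_bias` above (a sketch; it assumes this `decode.py` is on the import path, and the sizes are arbitrary). At step 0 the two decoder slots (`[last_token, ATTN]`) attend to the encoder plus the causal decoder positions; each later step widens the bias by the cached slots:

```python
import paddle

from decode import gen_bias

enc = paddle.ones([2, 4], dtype='int64')  # batch 2, encoder length 4
dec = paddle.ones([2, 2], dtype='int64')  # the two-slot decoder input
print(gen_bias(enc, dec, step=0).shape)  # [2, 2, 6]: 4 encoder + 2 decoder positions
print(gen_bias(enc, dec, step=3).shape)  # [2, 2, 9]: 3 more columns for cached steps
```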
+
+
+BeamSearchState = namedtuple('BeamSearchState',
+                             ['log_probs', 'lengths', 'finished'])
+BeamSearchOutput = namedtuple('BeamSearchOutput',
+                              ['scores', 'predicted_ids', 'beam_parent_ids'])
+
+
+def log_softmax(x):
+    e_x = np.exp(x - np.max(x))
+    return np.log(e_x / e_x.sum())
+
+
+def mask_prob(p, onehot_eos, finished):
+    is_finished = paddle.cast(paddle.reshape(finished, [-1, 1]) != 0, 'float32')
+    p = is_finished * (1. - paddle.cast(onehot_eos, 'float32')) * -9999. + (
+        1. - is_finished) * p
+    return p
+
+
+def hyp_score(log_probs, length, length_penalty):
+    lp = paddle.pow((5. + paddle.cast(length, 'float32')) / 6., length_penalty)
+    return log_probs / lp
+
+
+def beam_search_step(state, logits, eos_id, beam_width, is_first_step,
+                     length_penalty):
+    """logits.shape == [B*W, V]"""
+    _, vocab_size = logits.shape
+
+    bsz, beam_width = state.log_probs.shape
+    onehot_eos = paddle.cast(
+        nn.functional.one_hot(paddle.ones([1], 'int64') * eos_id, vocab_size),
+        'int64')  #[1, V]
+
+    probs = paddle.log(nn.functional.softmax(logits))  #[B*W, V]
+    probs = mask_prob(probs, onehot_eos, state.finished)  #[B*W, V]
+    allprobs = paddle.reshape(state.log_probs, [-1, 1]) + probs  #[B*W, V]
+
+    not_finished = 1 - paddle.reshape(state.finished, [-1, 1])  #[B*W,1]
+    not_eos = 1 - onehot_eos
+    length_to_add = not_finished * not_eos  #[B*W,V]
+    alllen = paddle.reshape(state.lengths, [-1, 1]) + length_to_add
+
+    allprobs = paddle.reshape(allprobs, [-1, beam_width * vocab_size])
+    alllen = paddle.reshape(alllen, [-1, beam_width * vocab_size])
+    allscore = hyp_score(allprobs, alllen, length_penalty)
+    if is_first_step:
+        allscore = paddle.reshape(
+            allscore,
+            [bsz, beam_width, -1])[:, 0, :]  # first step only considers beam 0
+    scores, idx = paddle.topk(allscore, k=beam_width)  #[B, W]
+    next_beam_id = idx // vocab_size  #[B, W]
+    next_word_id = idx % vocab_size
+
+    gather_idx = paddle.concat(
+        [paddle.nonzero(idx != -1)[:, :1], paddle.reshape(idx, [-1, 1])], 1)
+    next_probs = paddle.reshape(
+        paddle.gather_nd(allprobs, gather_idx), idx.shape)
+    next_len = paddle.reshape(paddle.gather_nd(alllen, gather_idx), idx.shape)
+
+    gather_idx = paddle.concat([
+        paddle.nonzero(next_beam_id != -1)[:, :1], paddle.reshape(next_beam_id,
+                                                                  [-1, 1])
+    ], 1)
+    next_finished = paddle.reshape(
+        paddle.gather_nd(state.finished, gather_idx),
+        state.finished.shape)  # gather new beam state according to new beam id
+
+    next_finished += paddle.cast(next_word_id == eos_id, 'int64')
+    next_finished = paddle.cast(next_finished > 0, 'int64')
+
+    next_state = BeamSearchState(
+        log_probs=next_probs, lengths=next_len, finished=next_finished)
+    output = BeamSearchOutput(
+        scores=scores, predicted_ids=next_word_id, beam_parent_ids=next_beam_id)
+
+    return output, next_state
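`hyp_score` above normalizes a hypothesis' log-probability by the GNMT-style length penalty `((5 + len) / 6) ** length_penalty`. A small numeric check (a sketch, assuming this `decode.py` is importable):

```python
import paddle

from decode import hyp_score

log_probs = paddle.to_tensor([[-4.0]])
lengths = paddle.to_tensor([[8]])
# ((5 + 8) / 6) ** 1.0 ≈ 2.167, so the normalized score is ≈ -1.846
print(hyp_score(log_probs, lengths, length_penalty=1.0))
# with length_penalty=0.0 the divisor is 1 and the raw log-prob comes back
print(hyp_score(log_probs, lengths, length_penalty=0.0))
```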
+
+
+@paddle.no_grad()
+def beam_search_infilling(model,
+                          q_ids,
+                          q_sids,
+                          sos_id,
+                          eos_id,
+                          attn_id,
+                          pad_id,
+                          unk_id,
+                          vocab_size,
+                          max_encode_len=640,
+                          max_decode_len=100,
+                          beam_width=5,
+                          tgt_type_id=3,
+                          length_penalty=1.0):
+    _, __, info = model(q_ids, q_sids)
+    d_batch, d_seqlen = q_ids.shape
+
+    state = BeamSearchState(
+        log_probs=paddle.zeros([d_batch, beam_width], 'float32'),
+        lengths=paddle.zeros([d_batch, beam_width], 'int64'),
+        finished=paddle.zeros([d_batch, beam_width], 'int64'))
+    outputs = []
+
+    def reorder_(t, parent_id):
+        """reorder cache according to parent beam id"""
+        gather_idx = paddle.nonzero(
+            parent_id != -1)[:, 0] * beam_width + paddle.reshape(parent_id,
+                                                                 [-1])
+        t = paddle.gather(t, gather_idx)
+        return t
+
+    def tile_(t, times):
+        """repeat t along a new beam axis, then flatten back into the batch"""
+        _shapes = list(t.shape[1:])
+        new_shape = [t.shape[0], times] + list(t.shape[1:])
+        ret = paddle.reshape(
+            paddle.expand(paddle.unsqueeze(t, [1]), new_shape),
+            [-1, ] + _shapes)
+        return ret
+
+    cached_k, cached_v = info['caches']
+    cached_k = [tile_(k, beam_width) for k in cached_k]
+    cached_v = [tile_(v, beam_width) for v in cached_v]
+    past_cache = (cached_k, cached_v)
+
+    q_ids = tile_(q_ids, beam_width)
+    seqlen = paddle.sum(paddle.cast(q_ids != 0, 'int64'), 1, keepdim=True)
+
+    cls_ids = paddle.ones([d_batch * beam_width], dtype='int64') * sos_id
+    attn_ids = paddle.ones(
+        [d_batch * beam_width], dtype='int64') * attn_id  # SOS
+    ids = paddle.stack([cls_ids, attn_ids], -1)
+    for step in range(max_decode_len):
+        bias = gen_bias(q_ids, ids, step)
+        pos_ids = paddle.to_tensor(
+            np.tile(
+                np.array(
+                    [[step, step + 1]], dtype=np.int64),
+                [d_batch * beam_width, 1]))
+        pos_ids += seqlen
+        _, logits, info = model(
+            ids,
+            paddle.ones_like(ids) * tgt_type_id,
+            pos_ids=pos_ids,
+            attn_bias=bias,
+            past_cache=past_cache)
+        if logits.shape[-1] > vocab_size:
+            logits[:, :, vocab_size:] = 0
+        logits[:, :, pad_id] = 0
+        logits[:, :, unk_id] = 0
+        logits[:, :, attn_id] = 0
+
+        output, state = beam_search_step(
+            state,
+            logits[:, 1],
+            eos_id=eos_id,
+            beam_width=beam_width,
+            is_first_step=(step == 0),
+            length_penalty=length_penalty)
+        outputs.append(output)
+
+        past_cached_k, past_cached_v = past_cache
+        cached_k, cached_v = info['caches']
+        cached_k = [
+            reorder_(
+                paddle.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids)
+            for pk, k in zip(past_cached_k, cached_k)
+        ]  # concat cached
+        cached_v = [
+            reorder_(
+                paddle.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids)
+            for pv, v in zip(past_cached_v, cached_v)
+        ]
+        past_cache = (cached_k, cached_v)
+
+        pred_ids_flatten = paddle.reshape(output.predicted_ids,
+                                          [d_batch * beam_width])
+        ids = paddle.stack([pred_ids_flatten, attn_ids], 1)
+
+        if state.finished.numpy().all():
+            break
+
+    final_ids = paddle.stack([o.predicted_ids for o in outputs], 0)
+    final_parent_ids = paddle.stack([o.beam_parent_ids for o in outputs], 0)
+    final_ids = nn.functional.gather_tree(
+        final_ids, final_parent_ids)[:, :, 0]  # pick best beam
+    final_ids = paddle.transpose(
+        paddle.reshape(final_ids, [-1, d_batch * 1]), [1, 0])
+
+    return final_ids.numpy()
+
+
+en_pattern = re.compile(r'^[a-zA-Z0-9]*$')
+
+
+def post_process(token):
+    """Convert one decoded wordpiece back to display text."""
+    if token.startswith('##'):
+        ret = token[2:]
+    elif token in ['[CLS]', '[SEP]', '[PAD]']:
+        ret = ''
+    else:
+        if en_pattern.match(token):
+            ret = ' ' + token
+        else:
+            ret = token
+    return ret
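`post_process` at the end of `decode.py` turns decoded wordpieces back into display text: it merges `##`-continuations, drops special tokens, and re-inserts spaces before alphanumeric tokens. For example:

```python
from decode import post_process

tokens = ['[CLS]', 'poem', '##s', ',', '好', '[SEP]']
print(''.join(map(post_process, tokens)))  # ' poems,好'
```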
diff --git a/PaddleNLP/examples/text_generation/ernie-gen/encode.py b/PaddleNLP/examples/text_generation/ernie-gen/encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..5deb0a74ffee443214299078353d467ac06e8135
--- /dev/null
+++ b/PaddleNLP/examples/text_generation/ernie-gen/encode.py
@@ -0,0 +1,123 @@
+from copy import deepcopy
+
+import numpy as np
+
+
+def convert_example(tokenizer,
+                    attn_id,
+                    tgt_type_id=3,
+                    max_encode_len=512,
+                    max_decode_len=128,
+                    is_test=False,
+                    noise_prob=0.,
+                    use_random_noise=False):
+    def wrapper(example):
+        """convert an example into necessary features"""
+        encoded_src = tokenizer.encode(
+            example[0], max_seq_len=max_encode_len, pad_to_max_seq_len=False)
+        src_ids, src_sids = encoded_src["input_ids"], encoded_src["segment_ids"]
+        src_pids = np.arange(len(src_ids))
+
+        if not is_test:
+            encoded_tgt = tokenizer.encode(
+                example[1],
+                max_seq_len=max_decode_len,
+                pad_to_max_seq_len=False)
+            tgt_ids, tgt_sids = encoded_tgt["input_ids"], encoded_tgt[
+                "segment_ids"]
+            tgt_ids = np.array(tgt_ids)
+            tgt_sids = np.array(tgt_sids) + tgt_type_id
+            tgt_pids = np.arange(len(tgt_ids)) + len(src_ids)
+
+            attn_ids = np.ones_like(tgt_ids) * attn_id
+            if noise_prob > 0.:
+                tgt_labels = deepcopy(tgt_ids)
+                if use_random_noise:
+                    noise_ids = np.random.randint(
+                        1, len(tokenizer.vocab), size=tgt_ids.shape)
+                else:
+                    noise_ids = np.ones_like(tgt_ids) * tokenizer.vocab[
+                        '[NOISE]']
+                # replace a noise_prob fraction of target tokens
+                pos, = np.where(np.ones_like(tgt_ids))
+                np.random.shuffle(pos)
+                pos = pos[:int(noise_prob * len(pos))]
+                tgt_ids[pos, ] = noise_ids[pos, ]
+            else:
+                tgt_labels = tgt_ids
+
+        return (src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids,
+                attn_ids, tgt_labels)
+
+    return wrapper
+
+
+def gen_mask(batch_ids, mask_type='bidi', query_len=None, pad_value=0):
+    if query_len is None:
+        query_len = batch_ids.shape[1]
+    if mask_type != 'empty':
+        mask = (batch_ids != pad_value).astype(np.float32)
+        mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
+        if mask_type == 'causal':
+            assert query_len == batch_ids.shape[1]
+            mask = np.tril(mask)
+        elif mask_type == 'causal_without_diag':
+            assert query_len == batch_ids.shape[1]
+            mask = np.tril(mask, -1)
+        elif mask_type == 'diag':
+            assert query_len == batch_ids.shape[1]
+            mask = np.stack([np.diag(np.diag(m)) for m in mask], 0)
+    else:
+        mask = np.zeros_like(batch_ids).astype(np.float32)
+        mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
+    return mask
+
+
+def after_padding(args):
+    '''
+    attention mask:
+    ***  src,  tgt, attn
+    src  00,   01,   11
+    tgt  10,   11,   12
+    attn 20,   21,   22
+
+    ***   s1, s2 | t1 t2 t3| attn1 attn2 attn3
+    s1    1,  1  | 0, 0, 0,| 0,    0,    0,
+    s2    1,  1  | 0, 0, 0,| 0,    0,    0,
+    -
+    t1    1,  1, | 1, 0, 0,| 0,    0,    0,
+    t2    1,  1, | 1, 1, 0,| 0,    0,    0,
+    t3    1,  1, | 1, 1, 1,| 0,    0,    0,
+    -
+    attn1 1,  1, | 0, 0, 0,| 1,    0,    0,
+    attn2 1,  1, | 1, 0, 0,| 0,    1,    0,
+    attn3 1,  1, | 1, 1, 0,| 0,    0,    1,
+
+    for details, see Fig3. https://arxiv.org/abs/2001.11314
+    '''
+    src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids, attn_ids, tgt_labels = args
+    src_len = src_ids.shape[1]
+    tgt_len = tgt_ids.shape[1]
+    mask_00 = gen_mask(src_ids, 'bidi', query_len=src_len)
+    mask_01 = gen_mask(tgt_ids, 'empty', query_len=src_len)
+    mask_02 = gen_mask(attn_ids, 'empty', query_len=src_len)
+
+    mask_10 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
+    mask_11 = gen_mask(tgt_ids, 'causal', query_len=tgt_len)
+    mask_12 = gen_mask(attn_ids, 'empty', query_len=tgt_len)
+
+    mask_20 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
+    mask_21 = gen_mask(tgt_ids, 'causal_without_diag', query_len=tgt_len)
+    mask_22 = gen_mask(attn_ids, 'diag', query_len=tgt_len)
+
+    mask_src_2_src = mask_00
+    mask_tgt_2_srctgt = np.concatenate([mask_10, mask_11], 2)
+    mask_attn_2_srctgtattn = np.concatenate([mask_20, mask_21, mask_22], 2)
+
+    raw_tgt_labels = deepcopy(tgt_labels)
+    tgt_labels = tgt_labels[np.where(tgt_labels != 0)]
+    return (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids, attn_ids,
+            mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
+            tgt_labels, raw_tgt_labels)
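`gen_mask` above can be exercised in isolation. A small sketch contrasting `bidi` and `causal` masking on one padded sequence (pad id 0), assuming this `encode.py` is importable:

```python
import numpy as np

from encode import gen_mask

batch = np.array([[101, 7, 8, 0]])
print(gen_mask(batch, 'bidi')[0])    # every query row sees [1. 1. 1. 0.]
print(gen_mask(batch, 'causal')[0])  # additionally lower-triangular
```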
diff --git a/PaddleNLP/examples/text_generation/ernie-gen/eval.py b/PaddleNLP/examples/text_generation/ernie-gen/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..5773dcbb525b386cf49822228f07fb14ace2f4f8
--- /dev/null
+++ b/PaddleNLP/examples/text_generation/ernie-gen/eval.py
@@ -0,0 +1,137 @@
+import os
+import ast
+import time
+import argparse
+import logging
+
+import paddle
+import paddle.nn as nn
+from tqdm import tqdm
+from paddle.io import DataLoader
+from paddlenlp.transformers import ErnieForGeneration
+from paddlenlp.transformers import ErnieTokenizer, ErnieTinyTokenizer, BertTokenizer, ElectraTokenizer, RobertaTokenizer
+from paddlenlp.datasets import Poetry
+from paddlenlp.data import Stack, Tuple, Pad
+from paddlenlp.metrics import Rouge1, Rouge2
+from paddlenlp.utils.log import logger
+
+from encode import convert_example, after_padding
+from decode import beam_search_infilling, post_process, greedy_search_infilling
+
+# yapf: disable
+parser = argparse.ArgumentParser('seq2seq model with ERNIE-GEN')
+parser.add_argument("--model_name_or_path", default=None, type=str, required=True, help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(list(ErnieTokenizer.pretrained_init_configuration.keys())))
+parser.add_argument('--max_encode_len', type=int, default=24, help="The max encoding sentence length")
+parser.add_argument('--max_decode_len', type=int, default=72, help="The max decoding sentence length")
+parser.add_argument("--batch_size", default=50, type=int, help="Batch size per GPU/CPU for evaluation.")
+parser.add_argument('--beam_width', type=int, default=1, help="Beam search width")
+parser.add_argument('--length_penalty', type=float, default=1.0, help="The length penalty during decoding")
+parser.add_argument('--init_checkpoint', type=str, default=None, help='Checkpoint to warm start from')
+parser.add_argument('--use_gpu', action='store_true', help='If set, use GPU to execute')
+# yapf: enable
+
+args = parser.parse_args()
+
+
+def evaluate():
+    paddle.set_device("gpu" if args.use_gpu else "cpu")
+
+    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
+    if "ernie-tiny" in args.model_name_or_path:
+        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
+    elif "ernie" in args.model_name_or_path:
+        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
+    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
+        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
+    elif "electra" in args.model_name_or_path:
+        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
+    else:
+        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
+
+    dev_dataset = Poetry.get_datasets(['dev'])
+    attn_id = tokenizer.vocab[
+        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
+    tgt_type_id = model.sent_emb.weight.shape[0] - 1
+
+    trans_func = convert_example(
+        tokenizer=tokenizer,
+        attn_id=attn_id,
+        tgt_type_id=tgt_type_id,
+        max_encode_len=args.max_encode_len,
+        max_decode_len=args.max_decode_len)
+
+    batchify_fn = lambda samples, fn=Tuple(
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_sids
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_sids
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
+    ): after_padding(fn(samples))
+
+    dev_dataset = dev_dataset.apply(trans_func, lazy=True)
+    dev_batch_sampler = paddle.io.BatchSampler(
+        dev_dataset, batch_size=args.batch_size, shuffle=False)
+    data_loader = DataLoader(
+        dataset=dev_dataset,
+        batch_sampler=dev_batch_sampler,
+        collate_fn=batchify_fn,
+        num_workers=0,
+        return_list=True)
+
+    rouge1 = Rouge1()
+    rouge2 = Rouge2()
+
+    if args.init_checkpoint:
+        model_state = paddle.load(args.init_checkpoint)
+        model.set_state_dict(model_state)
+
+    model.eval()
+    vocab = tokenizer.vocab
+    eos_id = vocab[tokenizer.sep_token]
+    sos_id = vocab[tokenizer.cls_token]
+    pad_id = vocab[tokenizer.pad_token]
+    unk_id = vocab[tokenizer.unk_token]
+    vocab_size = len(vocab)
+    evaluated_sentences_ids = []
+    reference_sentences_ids = []
+    logger.info("Evaluating...")
+    for data in tqdm(data_loader):
+        (src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _,
+         raw_tgt_labels) = data  # never use target when infer
+        # Use greedy_search_infilling or beam_search_infilling to get predictions
+        output_ids = beam_search_infilling(
+            model,
+            src_ids,
+            src_sids,
+            eos_id=eos_id,
+            sos_id=sos_id,
+            attn_id=attn_id,
+            pad_id=pad_id,
+            unk_id=unk_id,
+            vocab_size=vocab_size,
+            max_decode_len=args.max_decode_len,
+            max_encode_len=args.max_encode_len,
+            beam_width=args.beam_width,
+            length_penalty=args.length_penalty,
+            tgt_type_id=tgt_type_id)
+
+        for ids in output_ids.tolist():
+            if eos_id in ids:
+                ids = ids[:ids.index(eos_id)]
+            evaluated_sentences_ids.append(ids)
+
+        for ids in raw_tgt_labels.numpy().tolist():
+            ids = ids[:ids.index(eos_id)]
+            reference_sentences_ids.append(ids)
+
+    score1 = rouge1.score(evaluated_sentences_ids, reference_sentences_ids)
+    score2 = rouge2.score(evaluated_sentences_ids, reference_sentences_ids)
+
+    logger.info("Rouge-1: %.5f, Rouge-2: %.5f" % (score1 * 100, score2 * 100))
+
+
+if __name__ == "__main__":
+    evaluate()
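As the inline comment in the loop notes, `greedy_search_infilling` is a drop-in alternative to beam search here; the call would simply lose the beam-specific arguments:

```python
output_ids = greedy_search_infilling(
    model,
    src_ids,
    src_sids,
    eos_id=eos_id,
    sos_id=sos_id,
    attn_id=attn_id,
    pad_id=pad_id,
    unk_id=unk_id,
    vocab_size=vocab_size,
    max_decode_len=args.max_decode_len,
    max_encode_len=args.max_encode_len,
    tgt_type_id=tgt_type_id)
```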
"roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path: + tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path) + elif "electra" in args.model_name_or_path: + tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path) + else: + tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path) + + dev_dataset = Poetry.get_datasets(['test']) + attn_id = tokenizer.vocab[ + '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]'] + tgt_type_id = model.sent_emb.weight.shape[0] - 1 + + trans_func = convert_example( + tokenizer=tokenizer, + attn_id=attn_id, + tgt_type_id=tgt_type_id, + max_encode_len=args.max_encode_len, + max_decode_len=args.max_decode_len) + + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_ids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_pids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_sids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_ids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_pids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_sids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # attn_ids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_labels + ): after_padding(fn(samples)) + + dev_dataset = dev_dataset.apply(trans_func, lazy=True) + test_batch_sampler = paddle.io.BatchSampler( + dev_dataset, batch_size=args.batch_size, shuffle=False) + data_loader = DataLoader( + dataset=dev_dataset, + batch_sampler=test_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + if args.init_checkpoint: + model_state = paddle.load(args.init_checkpoint) + model.set_state_dict(model_state) + + model.eval() + vocab = tokenizer.vocab + eos_id = vocab[tokenizer.sep_token] + sos_id = vocab[tokenizer.cls_token] + pad_id = vocab[tokenizer.pad_token] + unk_id = vocab[tokenizer.unk_token] + vocab_size = len(vocab) + evaluated_sentences = [] + evaluated_sentences_ids = [] + logger.info("Predicting...") + for data in data_loader: + (src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _, + raw_tgt_labels) = data # never use target when infer + # Use greedy_search_infilling or beam_search_infilling to get predictions + output_ids = beam_search_infilling( + model, + src_ids, + src_sids, + eos_id=eos_id, + sos_id=sos_id, + attn_id=attn_id, + pad_id=pad_id, + unk_id=unk_id, + vocab_size=vocab_size, + max_decode_len=args.max_decode_len, + max_encode_len=args.max_encode_len, + beam_width=args.beam_width, + length_penalty=args.length_penalty, + tgt_type_id=tgt_type_id) + + for source_ids, target_ids, predict_ids in zip( + src_ids.numpy().tolist(), + raw_tgt_labels.numpy().tolist(), output_ids.tolist()): + if eos_id in predict_ids: + predict_ids = predict_ids[:predict_ids.index(eos_id)] + source_sentence = ''.join( + map(post_process, + vocab.to_tokens(source_ids[1:source_ids.index(eos_id)]))) + tgt_sentence = ''.join( + map(post_process, + vocab.to_tokens(target_ids[1:target_ids.index(eos_id)]))) + predict_ids = ''.join( + map(post_process, vocab.to_tokens(predict_ids))) + print("source :%s\ntarget :%s\npredict:%s\n" % + (source_sentence, tgt_sentence, predict_ids)) + break + + +if __name__ == "__main__": + predict() diff --git a/PaddleNLP/examples/text_generation/ernie-gen/train.py b/PaddleNLP/examples/text_generation/ernie-gen/train.py new file mode 100644 index 0000000000000000000000000000000000000000..507ac86098fa26a9f1057332c4f3530984dda147 --- /dev/null +++ b/PaddleNLP/examples/text_generation/ernie-gen/train.py @@ -0,0 +1,275 @@ 
+import os +import ast +import time +import argparse +import logging + +import paddle +from tqdm import tqdm +import paddle.nn as nn +from paddle.io import DataLoader +from paddlenlp.transformers import ErnieForGeneration +from paddlenlp.transformers import ErnieTokenizer, ErnieTinyTokenizer, BertTokenizer, ElectraTokenizer, RobertaTokenizer +from paddlenlp.datasets import Poetry +from paddlenlp.data import Stack, Tuple, Pad +from paddlenlp.metrics import Rouge1, Rouge2 +from paddlenlp.utils.log import logger + +from encode import convert_example, after_padding +from decode import beam_search_infilling, post_process, greedy_search_infilling + +# yapf: disable +parser = argparse.ArgumentParser('seq2seq model with ERNIE-GEN') +parser.add_argument("--model_name_or_path", default=None, type=str, required=True, help="Path to pre-trained model or shortcut name selected in the list: "+ ", ".join(list(ErnieTokenizer.pretrained_init_configuration.keys()))) +parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.",) +parser.add_argument('--max_encode_len', type=int, default=5, help="The max encoding sentence length") +parser.add_argument('--max_decode_len', type=int, default=5, help="The max decoding sentence length") +parser.add_argument("--batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.", ) +parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") +parser.add_argument("--weight_decay", default=0.1, type=float, help="Weight decay if we apply some.") +parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") +parser.add_argument("--num_epochs", default=3, type=int, help="Total number of training epochs to perform.", ) +parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. 
Override num_epochs.",) +parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Linear warmup proportion.") +parser.add_argument("--logging_steps", type=int, default=1, help="Log every X updates steps.") +parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every X updates steps.") +parser.add_argument("--n_gpu", type=int, default=1, help="Number of gpus to use, 0 for cpu.") +parser.add_argument('--beam_width', type=int, default=1, help="Beam search width") +parser.add_argument('--noise_prob', type=float, default=0., help='Probability of token be repalced') +parser.add_argument('--use_random_noice', action='store_true', help='If set, replace target tokens with random token from vocabulary, else replace with `[NOISE]`') +parser.add_argument('--label_smooth', type=float, default=0., help="The soft label smooth rate") +parser.add_argument('--length_penalty', type=float, default=1.0, help="The length penalty during decoding") +parser.add_argument('--init_checkpoint', type=str, default=None, help='Checkpoint to warm start from') +parser.add_argument('--save_dir', type=str, default=None, help='Model output directory') +# yapf: enable + +args = parser.parse_args() + + +def evaluate(model, data_loader, tokenizer, rouge1, rouge2, attn_id, + tgt_type_id, args): + model.eval() + + vocab = tokenizer.vocab + eos_id = vocab[tokenizer.sep_token] + sos_id = vocab[tokenizer.cls_token] + pad_id = vocab[tokenizer.pad_token] + unk_id = vocab[tokenizer.unk_token] + vocab_size = len(vocab) + evaluated_sentences_ids = [] + reference_sentences_ids = [] + logger.info("Evaluating...") + for data in tqdm(data_loader): + (src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _, + raw_tgt_labels) = data # never use target when infer + # Use greedy_search_infilling or beam_search_infilling to get predictions + output_ids = beam_search_infilling( + model, + src_ids, + src_sids, + eos_id=eos_id, + sos_id=sos_id, + attn_id=attn_id, + pad_id=pad_id, + unk_id=unk_id, + vocab_size=vocab_size, + max_decode_len=args.max_decode_len, + max_encode_len=args.max_encode_len, + beam_width=args.beam_width, + length_penalty=args.length_penalty, + tgt_type_id=tgt_type_id) + + for ids in output_ids.tolist(): + if eos_id in ids: + ids = ids[:ids.index(eos_id)] + evaluated_sentences_ids.append(ids) + + for ids in raw_tgt_labels.numpy().tolist(): + ids = ids[:ids.index(eos_id)] + reference_sentences_ids.append(ids) + + score1 = rouge1.score(evaluated_sentences_ids, reference_sentences_ids) + score2 = rouge2.score(evaluated_sentences_ids, reference_sentences_ids) + + logger.info("Rouge-1: %.5f ,Rouge-2: %.5f" % (score1 * 100, score2 * 100)) + + evaluated_sentences = [] + reference_sentences = [] + for ids in reference_sentences_ids[:5]: + reference_sentences.append(''.join( + map(post_process, vocab.to_tokens(ids)))) + for ids in evaluated_sentences_ids[:5]: + evaluated_sentences.append(''.join( + map(post_process, vocab.to_tokens(ids)))) + logger.debug(reference_sentences) + logger.debug(evaluated_sentences) + + model.train() + + +def train(): + paddle.set_device("gpu" if args.n_gpu else "cpu") + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + model = ErnieForGeneration.from_pretrained(args.model_name_or_path) + if "ernie-tiny" in args.model_name_or_path: + tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path) + elif "ernie" in args.model_name_or_path: + tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path) + elif "roberta" in 
args.model_name_or_path or "rbt" in args.model_name_or_path: + tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path) + elif "electra" in args.model_name_or_path: + tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path) + else: + tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path) + if args.init_checkpoint: + model_state = paddle.load(args.init_checkpoint) + model.set_state_dict(model_state) + + train_dataset, dev_dataset = Poetry.get_datasets(['train', 'dev']) + attn_id = tokenizer.vocab[ + '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]'] + tgt_type_id = model.sent_emb.weight.shape[0] - 1 + + trans_func = convert_example( + tokenizer=tokenizer, + attn_id=attn_id, + tgt_type_id=tgt_type_id, + max_encode_len=args.max_encode_len, + max_decode_len=args.max_decode_len, + noise_prob=args.noise_prob, + use_random_noice=args.use_random_noice) + + train_dataset = train_dataset.apply(trans_func, lazy=True) + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_dataset, batch_size=args.batch_size, shuffle=True) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_ids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_pids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_sids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_ids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_pids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_sids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # attn_ids + Pad(axis=0, pad_val=tokenizer.pad_token_id), # tgt_labels + ): after_padding(fn(samples)) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + dev_dataset = dev_dataset.apply(trans_func, lazy=True) + dev_batch_sampler = paddle.io.BatchSampler( + dev_dataset, batch_size=args.batch_size, shuffle=False) + dev_data_loader = DataLoader( + dataset=dev_dataset, + batch_sampler=dev_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + label_num = model.word_emb.weight.shape[0] + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + max_steps = (len(train_data_loader) * args.num_epochs) + lr_scheduler = paddle.optimizer.lr.LambdaDecay( + args.learning_rate, + lambda current_step, num_warmup_steps=max_steps*args.warmup_proportion, + num_training_steps=max_steps: float( + current_step) / float(max(1, num_warmup_steps)) + if current_step < num_warmup_steps else max( + 0.0, + float(num_training_steps - current_step) / float( + max(1, num_training_steps - num_warmup_steps)))) + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + epsilon=args.adam_epsilon, + parameters=model.parameters(), + weight_decay=args.weight_decay, + grad_clip=nn.ClipGradByGlobalNorm(1.0), + apply_decay_param_fun=lambda x: x in [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ]) + + rouge1 = Rouge1() + rouge2 = Rouge2() + + global_step = 1 + tic_train = time.time() + for epoch in range(args.num_epochs): + for step, batch in enumerate(train_data_loader, start=1): + (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids, attn_ids, + mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn, + tgt_labels, _) = batch + # import pdb; pdb.set_trace() + _, __, info = model( + src_ids, + sent_ids=src_sids, + pos_ids=src_pids, + attn_bias=mask_src_2_src, + encode_only=True) + cached_k, 
cached_v = info['caches'] + _, __, info = model( + tgt_ids, + sent_ids=tgt_sids, + pos_ids=tgt_pids, + attn_bias=mask_tgt_2_srctgt, + past_cache=(cached_k, cached_v), + encode_only=True) + cached_k2, cached_v2 = info['caches'] + past_cache_k = [ + paddle.concat([k, k2], 1) for k, k2 in zip(cached_k, cached_k2) + ] + past_cache_v = [ + paddle.concat([v, v2], 1) for v, v2 in zip(cached_v, cached_v2) + ] + if args.label_smooth > 0.: + tgt_labels = nn.functional.label_smooth( + nn.functional.one_hot(tgt_labels, label_num), + epsilon=args.label_smooth) + loss, _, __ = model( + attn_ids, + sent_ids=tgt_sids, + pos_ids=tgt_pids, + attn_bias=mask_attn_2_srctgtattn, + past_cache=(past_cache_k, past_cache_v), + tgt_labels=tgt_labels, + tgt_pos=paddle.nonzero(attn_ids == attn_id)) + if global_step % args.logging_steps == 0: + if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + logger.info( + "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, lr: %.3e" + % (global_step, epoch, step, loss, args.logging_steps / + (time.time() - tic_train), lr_scheduler.get_lr())) + tic_train = time.time() + + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_gradients() + if global_step % args.save_steps == 0 and ( + (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0): + evaluate(model, dev_data_loader, tokenizer, rouge1, rouge2, + attn_id, tgt_type_id, args) + output_dir = os.path.join(args.output_dir, + "model_%d" % global_step) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + global_step += 1 + + +if __name__ == "__main__": + if args.n_gpu > 1: + paddle.distributed.spawn(train, nprocs=args.n_gpu) + else: + train() diff --git a/PaddleNLP/paddlenlp/datasets/__init__.py b/PaddleNLP/paddlenlp/datasets/__init__.py index 8480c1c6b756a30a3170370b0a83c68da6eaf625..a899b1d0fd1db3000e96cd9f39ab3cc7e8a26c99 100644 --- a/PaddleNLP/paddlenlp/datasets/__init__.py +++ b/PaddleNLP/paddlenlp/datasets/__init__.py @@ -21,3 +21,5 @@ from .ptb import * from .squad import * from .translation import * from .dureader import * +from .cnndm import * +from .poetry import * \ No newline at end of file diff --git a/PaddleNLP/paddlenlp/datasets/cnndm.py b/PaddleNLP/paddlenlp/datasets/cnndm.py new file mode 100644 index 0000000000000000000000000000000000000000..0f9e7f1c0964e03603adcf6c7504f75a54c151dd --- /dev/null +++ b/PaddleNLP/paddlenlp/datasets/cnndm.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
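Once exported here, the new datasets load through the same class-level interface the example scripts rely on (the archive is downloaded and md5-checked on first access):

```python
from paddlenlp.datasets import Poetry

train_ds, dev_ds = Poetry.get_datasets(['train', 'dev'])
print(train_ds[0])  # a (source, target) field pair from poetry/train.tsv
```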
diff --git a/PaddleNLP/paddlenlp/datasets/cnndm.py b/PaddleNLP/paddlenlp/datasets/cnndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f9e7f1c0964e03603adcf6c7504f75a54c151dd
--- /dev/null
+++ b/PaddleNLP/paddlenlp/datasets/cnndm.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import os
+import warnings
+
+from paddle.io import Dataset
+from paddle.dataset.common import DATA_HOME, md5file
+from paddle.utils.download import get_path_from_url
+
+from .dataset import TSVDataset
+
+__all__ = ['CnnDm']
+
+
+class CnnDm(TSVDataset):
+    URL = "https://ernie-github.cdn.bcebos.com/data-cnndm.tar.gz"
+    MD5 = None
+    SEGMENT_INFO = collections.namedtuple(
+        'SEGMENT_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
+    SEGMENTS = {
+        'train': SEGMENT_INFO(
+            os.path.join('cnndm', 'train', '1'),
+            '8b10ed0ae31e71e8cd9105a6978d8970', (1, 2), 0),
+        'dev': SEGMENT_INFO(
+            os.path.join('cnndm', 'dev', '1'),
+            '7cb22f9cac04a285790a91cebba75260', (1, 2), 0),
+    }
+
+    def __init__(self, segment='train', root=None, **kwargs):
+        default_root = DATA_HOME
+        filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
+            segment]
+        fullname = os.path.join(default_root,
+                                filename) if root is None else os.path.join(
+                                    os.path.expanduser(root), filename)
+        if not os.path.exists(fullname) or (data_hash and
+                                            not md5file(fullname) == data_hash):
+            if root is not None:  # warn if a user-specified root failed the check
+                warnings.warn(
+                    'md5 check failed for {}, download {} data to {}'.format(
+                        filename, self.__class__.__name__, default_root))
+            path = get_path_from_url(self.URL, default_root, self.MD5)
+            fullname = os.path.join(default_root, filename)
+        super(CnnDm, self).__init__(
+            fullname,
+            field_indices=field_indices,
+            num_discard_samples=num_discard_samples,
+            **kwargs)
diff --git a/PaddleNLP/paddlenlp/datasets/poetry.py b/PaddleNLP/paddlenlp/datasets/poetry.py
new file mode 100644
index 0000000000000000000000000000000000000000..6375e19d3aa73f0bb9fc9d4f6c7b0fa1e56a5df7
--- /dev/null
+++ b/PaddleNLP/paddlenlp/datasets/poetry.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import os
+import warnings
+
+from paddle.io import Dataset
+from paddle.dataset.common import DATA_HOME, md5file
+from paddle.utils.download import get_path_from_url
+
+from .dataset import TSVDataset
+
+__all__ = ['Poetry']
+
+
+class Poetry(TSVDataset):
+    URL = "https://paddlenlp.bj.bcebos.com/datasets/poetry.tar.gz"
+    MD5 = '8edd7eda1b273145b70ef29c82cd622b'
+    SEGMENT_INFO = collections.namedtuple(
+        'SEGMENT_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
+    SEGMENTS = {
+        'train': SEGMENT_INFO(
+            os.path.join('poetry', 'train.tsv'),
+            '176c6202b5e71656ae7e7848eec4c54f', (0, 1), 0),
+        'dev': SEGMENT_INFO(
+            os.path.join('poetry', 'dev.tsv'),
+            '737e4b6da5facdc0ac33fe688df19931', (0, 1), 0),
+        'test': SEGMENT_INFO(
+            os.path.join('poetry', 'test.tsv'),
+            '1dca907b2d712730c7c828f8acee7431', (0, 1), 0),
+    }
+
+    def __init__(self, segment='train', root=None, **kwargs):
+        default_root = os.path.join(DATA_HOME, 'poetry')
+        filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
+            segment]
+        fullname = os.path.join(default_root,
+                                filename) if root is None else os.path.join(
+                                    os.path.expanduser(root), filename)
+        if not os.path.exists(fullname) or (data_hash and
+                                            not md5file(fullname) == data_hash):
+            if root is not None:  # warn if a user-specified root failed the check
+                warnings.warn(
+                    'md5 check failed for {}, download {} data to {}'.format(
+                        filename, self.__class__.__name__, default_root))
+            path = get_path_from_url(self.URL, default_root, self.MD5)
+            fullname = os.path.join(default_root, filename)
+        super(Poetry, self).__init__(
+            fullname,
+            field_indices=field_indices,
+            num_discard_samples=num_discard_samples,
+            **kwargs)
diff --git a/PaddleNLP/paddlenlp/metrics/__init__.py b/PaddleNLP/paddlenlp/metrics/__init__.py
index b2fc2fe7f389b649436b661fa31f7c59ddaa329a..a7197f355e233453e29c44db58b51a299262ec69 100644
--- a/PaddleNLP/paddlenlp/metrics/__init__.py
+++ b/PaddleNLP/paddlenlp/metrics/__init__.py
@@ -15,5 +15,5 @@
 from .perplexity import Perplexity
 from .chunk import ChunkEvaluator
 from .bleu import BLEU, BLEUForDuReader
-from .rouge import RougeL, RougeLForDuReader
+from .rouge import RougeL, RougeLForDuReader, RougeN, Rouge1, Rouge2
 from .glue import AccuracyAndF1, Mcc, PearsonAndSpearman
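A toy check of the corpus-level n-gram overlap these classes compute, on token-id lists as the example scripts use them (`Rouge2` is defined in the diff below):

```python
from paddlenlp.metrics import Rouge2

rouge2 = Rouge2()
evaluated = [[1, 2, 3, 4]]  # bigrams {(1,2), (2,3), (3,4)}
reference = [[1, 2, 3, 5]]  # bigrams {(1,2), (2,3), (3,5)}
print(rouge2.score(evaluated, reference))  # 2 overlapping / 3 reference ≈ 0.667
```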
diff --git a/PaddleNLP/paddlenlp/metrics/rouge.py b/PaddleNLP/paddlenlp/metrics/rouge.py
index 769c5a004962bab2a79f91083610c4f36d750ded..8e7562a67608a0731589e13d1e529dfd26c9e7bd 100644
--- a/PaddleNLP/paddlenlp/metrics/rouge.py
+++ b/PaddleNLP/paddlenlp/metrics/rouge.py
@@ -20,6 +20,94 @@ from .utils import default_trans_func
 
 __all__ = ['RougeL', 'RougeLForDuReader']
 
+class RougeN():
+    def __init__(self, n):
+        self.n = n
+        self.reset()
+
+    def _get_ngrams(self, words):
+        """Calculate word n-grams for multiple sentences.
+        """
+        ngram_set = set()
+        max_index_ngram_start = len(words) - self.n
+        for i in range(max_index_ngram_start + 1):
+            ngram_set.add(tuple(words[i:i + self.n]))
+        return ngram_set
+
+    def score(self, evaluated_sentences_ids, reference_sentences_ids):
+        overlapping_count, reference_count = self.compute(
+            evaluated_sentences_ids, reference_sentences_ids)
+        return overlapping_count / reference_count
+
+    def compute(self, evaluated_sentences_ids, reference_sentences_ids):
+        """
+        Args:
+            evaluated_sentences_ids (list): the sentence ids predicted by the model.
+            reference_sentences_ids (list): the reference sentence ids. Its size should be same as evaluated_sentences_ids.
+
+        Returns:
+            overlapping_count (int): the overlapping n-gram count.
+            reference_count (int): the reference sentences' n-gram count.
+        """
+        if len(evaluated_sentences_ids) <= 0 or len(
+                reference_sentences_ids) <= 0:
+            raise ValueError("Collections must contain at least 1 sentence.")
+
+        reference_count = 0
+        overlapping_count = 0
+
+        for evaluated_sentence_ids, reference_sentence_ids in zip(
+                evaluated_sentences_ids, reference_sentences_ids):
+            evaluated_ngrams = self._get_ngrams(evaluated_sentence_ids)
+            reference_ngrams = self._get_ngrams(reference_sentence_ids)
+            reference_count += len(reference_ngrams)
+
+            # Gets the overlapping ngrams between evaluated and reference
+            overlapping_ngrams = evaluated_ngrams.intersection(
+                reference_ngrams)
+            overlapping_count += len(overlapping_ngrams)
+
+        return overlapping_count, reference_count
+
+    def accumulate(self):
+        """
+        Returns the Rouge-N score accumulated over all minibatches fed to `update`.
+
+        Returns:
+            float: the accumulated Rouge-N score.
+        """
+        rouge_score = self.overlapping_count / self.reference_count
+        return rouge_score
+
+    def reset(self):
+        """
+        Empties the evaluation memory of previous minibatches.
+        """
+        self.overlapping_count = 0
+        self.reference_count = 0
+
+    def name(self):
+        """
+        Returns the name of the metric instance.
+        """
+        return "Rouge-%s" % self.n
+
+    def update(self, overlapping_count, reference_count):
+        """
+        Args:
+            overlapping_count (int): the overlapping n-gram count of a minibatch.
+            reference_count (int): the reference n-gram count of a minibatch.
+        """
+        self.overlapping_count += overlapping_count
+        self.reference_count += reference_count
+
+
+class Rouge1(RougeN):
+    def __init__(self):
+        super(Rouge1, self).__init__(n=1)
+
+
+class Rouge2(RougeN):
+    def __init__(self):
+        super(Rouge2, self).__init__(n=2)
+
+
 class RougeL(paddle.metric.Metric):
     r'''
     Rouge-L is Recall-Oriented Understudy for Gisting Evaluation based on Longest Common Subsequence (LCS).
diff --git a/PaddleNLP/paddlenlp/transformers/__init__.py b/PaddleNLP/paddlenlp/transformers/__init__.py
index c3d4bd614da67a4e10ab50d570477ee9224ad07d..179c3b4d2b3fe3c5beab15d84214555971a833f7 100644
--- a/PaddleNLP/paddlenlp/transformers/__init__.py
+++ b/PaddleNLP/paddlenlp/transformers/__init__.py
@@ -24,3 +24,4 @@ from .roberta.tokenizer import *
 from .electra.modeling import *
 from .electra.tokenizer import *
 from .transformer.modeling import *
+from .ernie_gen.modeling import ErnieForGeneration
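With this export in place, the generation model is loadable by shortcut name exactly as the example scripts do (weights are downloaded on first use):

```python
from paddlenlp.transformers import ErnieForGeneration

model = ErnieForGeneration.from_pretrained('ernie-1.0')
```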
diff --git a/PaddleNLP/paddlenlp/transformers/electra/modeling.py b/PaddleNLP/paddlenlp/transformers/electra/modeling.py
index deb0b4dc821d8dc91c01aaa6b37631eb3c349c2a..196066cbc38d12fd44544c7f2ae6fa0f3434f212 100644
--- a/PaddleNLP/paddlenlp/transformers/electra/modeling.py
+++ b/PaddleNLP/paddlenlp/transformers/electra/modeling.py
@@ -24,8 +24,8 @@ import paddle.nn.functional as F
 from .. import PretrainedModel, register_base_model
 
 __all__ = [
-    'ElectraModel', 'ElectraForTotalPretraining', 'ElectraDiscriminator',
-    'ElectraGenerator', 'ElectraClassificationHead',
+    'ElectraModel', 'ElectraPretrainedModel', 'ElectraForTotalPretraining',
+    'ElectraDiscriminator', 'ElectraGenerator', 'ElectraClassificationHead',
     'ElectraForSequenceClassification', 'ElectraForTokenClassification',
     'ElectraPretrainingCriterion'
 ]
diff --git a/PaddleNLP/paddlenlp/transformers/ernie/tokenizer.py b/PaddleNLP/paddlenlp/transformers/ernie/tokenizer.py
index e3c83b5db4f914080787669df9de0cf2c7d4bece..32f59b11a520864ca422731755927b06ab895f45 100644
--- a/PaddleNLP/paddlenlp/transformers/ernie/tokenizer.py
+++ b/PaddleNLP/paddlenlp/transformers/ernie/tokenizer.py
@@ -59,6 +59,12 @@ class ErnieTokenizer(PretrainedTokenizer):
             "https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_base/vocab.txt",
             "ernie-2.0-large-en":
             "https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_large/vocab.txt",
+            "ernie-gen-base-en":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-base-en/vocab.txt",
+            "ernie-gen-large-en":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-large/vocab.txt",
+            "ernie-gen-large-430g-en":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-large-430g/vocab.txt",
         }
     }
     pretrained_init_configuration = {
@@ -71,6 +77,15 @@ class ErnieTokenizer(PretrainedTokenizer):
         "ernie-2.0-large-en": {
             "do_lower_case": True
         },
+        "ernie-gen-base-en": {
+            "do_lower_case": True
+        },
+        "ernie-gen-large-en": {
+            "do_lower_case": True
+        },
+        "ernie-gen-large-430g-en": {
+            "do_lower_case": True
+        },
     }
 
     def __init__(self,
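The new vocab entries make the ERNIE-GEN shortcuts resolvable by the stock tokenizer; a sketch (the vocab file is fetched on first use, and `encode` returns the `input_ids`/`segment_ids` dict consumed by `encode.py` above):

```python
from paddlenlp.transformers import ErnieTokenizer

tokenizer = ErnieTokenizer.from_pretrained('ernie-gen-base-en')
print(tokenizer.encode('an enhanced multi-flow framework')['input_ids'])
```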
+import os +import io +import copy +import logging +import six +import json + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddlenlp.utils.env import MODEL_HOME +from paddle.utils.download import get_path_from_url +from paddlenlp.utils.log import logger +from paddlenlp.transformers import BertPretrainedModel, ElectraPretrainedModel, RobertaPretrainedModel, ErniePretrainedModel + +from ..utils import InitTrackerMeta, fn_args_to_dict + +__all__ = ["ErnieGenPretrainedModel", "ErnieForGeneration"] + + +def _build_linear(n_in, n_out, name, init): + return nn.Linear( + n_in, + n_out, + weight_attr=paddle.ParamAttr( + name='%s.w_0' % name if name is not None else None, + initializer=init), + bias_attr='%s.b_0' % name if name is not None else None, ) + + +def _build_ln(n_in, name): + return nn.LayerNorm( + normalized_shape=n_in, + weight_attr=paddle.ParamAttr( + name='%s_layer_norm_scale' % name if name is not None else None, + initializer=nn.initializer.Constant(1.)), + bias_attr=paddle.ParamAttr( + name='%s_layer_norm_bias' % name if name is not None else None, + initializer=nn.initializer.Constant(1.)), ) + + +def append_name(name, postfix): + if name is None: + ret = None + elif name == '': + ret = postfix + else: + ret = '%s_%s' % (name, postfix) + return ret + + +class AttentionLayer(nn.Layer): + def __init__(self, cfg, name=None): + super(AttentionLayer, self).__init__() + initializer = nn.initializer.TruncatedNormal( + std=cfg['initializer_range']) + d_model = cfg['hidden_size'] + n_head = cfg['num_attention_heads'] + assert d_model % n_head == 0 + d_model_q = cfg.get('query_hidden_size_per_head', + d_model // n_head) * n_head + d_model_v = cfg.get('value_hidden_size_per_head', + d_model // n_head) * n_head + self.n_head = n_head + self.d_key = d_model_q // n_head + self.q = _build_linear(d_model, d_model_q, + append_name(name, 'query_fc'), initializer) + self.k = _build_linear(d_model, d_model_q, + append_name(name, 'key_fc'), initializer) + self.v = _build_linear(d_model, d_model_v, + append_name(name, 'value_fc'), initializer) + self.o = _build_linear(d_model_v, d_model, + append_name(name, 'output_fc'), initializer) + self.dropout = nn.Dropout(p=cfg['attention_probs_dropout_prob']) + + def forward(self, queries, keys, values, attn_bias, past_cache): + assert len(queries.shape) == len(keys.shape) == len(values.shape) == 3 + #bsz, q_len, q_dim = queries.shape + #bsz, k_len, k_dim = keys.shape + #bsz, v_len, v_dim = values.shape + #assert k_len == v_len + + q = self.q(queries) + k = self.k(keys) + v = self.v(values) + + cache = (k, v) + if past_cache is not None: + cached_k, cached_v = past_cache + k = paddle.concat([cached_k, k], 1) + v = paddle.concat([cached_v, v], 1) + + q = q.reshape( + [0, 0, self.n_head, q.shape[-1] // self.n_head]).transpose( + [0, 2, 1, 3]) #[batch, head, seq, dim] + k = k.reshape( + [0, 0, self.n_head, k.shape[-1] // self.n_head]).transpose( + [0, 2, 1, 3]) #[batch, head, seq, dim] + v = v.reshape( + [0, 0, self.n_head, v.shape[-1] // self.n_head]).transpose( + [0, 2, 1, 3]) #[batch, head, seq, dim] + + q = q.scale(self.d_key**-0.5) + score = q.matmul(k, transpose_y=True) + if attn_bias is not None: + score += attn_bias + score = F.softmax(score) + score = self.dropout(score) + + out = score.matmul(v).transpose([0, 2, 1, 3]) + out = out.reshape([0, 0, out.shape[2] * out.shape[3]]) + out = self.o(out) + return out, cache + + +class PositionwiseFeedForwardLayer(nn.Layer): + def __init__(self, cfg, name=None): + 
+class PositionwiseFeedForwardLayer(nn.Layer):
+    def __init__(self, cfg, name=None):
+        super(PositionwiseFeedForwardLayer, self).__init__()
+        initializer = nn.initializer.TruncatedNormal(
+            std=cfg['initializer_range'])
+        d_model = cfg['hidden_size']
+        d_ffn = cfg.get('intermediate_size', 4 * d_model)
+        self.act = getattr(paddle.nn.functional, cfg['hidden_act'])
+        self.i = _build_linear(
+            d_model,
+            d_ffn,
+            append_name(name, 'fc_0'),
+            initializer, )
+        self.o = _build_linear(d_ffn, d_model,
+                               append_name(name, 'fc_1'), initializer)
+        prob = cfg.get('intermediate_dropout_prob', 0.)
+        self.dropout = nn.Dropout(p=prob)
+
+    def forward(self, inputs):
+        hidden = self.act(self.i(inputs))
+        hidden = self.dropout(hidden)
+        out = self.o(hidden)
+        return out
+
+
+class ErnieEncoderLayer(nn.Layer):
+    def __init__(self, cfg, name=None):
+        super(ErnieEncoderLayer, self).__init__()
+        d_model = cfg['hidden_size']
+        self.attn = AttentionLayer(
+            cfg, name=append_name(name, 'multi_head_att'))
+        self.ln1 = _build_ln(d_model, name=append_name(name, 'post_att'))
+        self.ffn = PositionwiseFeedForwardLayer(
+            cfg, name=append_name(name, 'ffn'))
+        self.ln2 = _build_ln(d_model, name=append_name(name, 'post_ffn'))
+        prob = cfg.get('intermediate_dropout_prob', cfg['hidden_dropout_prob'])
+        self.dropout = nn.Dropout(p=prob)
+
+    def forward(self, inputs, attn_bias=None, past_cache=None):
+        attn_out, cache = self.attn(
+            inputs, inputs, inputs, attn_bias,
+            past_cache=past_cache)  # self attention
+        attn_out = self.dropout(attn_out)
+        hidden = attn_out + inputs
+        hidden = self.ln1(hidden)  # dropout / add / norm
+
+        ffn_out = self.ffn(hidden)
+        ffn_out = self.dropout(ffn_out)
+        hidden = ffn_out + hidden
+        hidden = self.ln2(hidden)
+        return hidden, cache
+
+
+class ErnieEncoderStack(nn.Layer):
+    def __init__(self, cfg, name=None):
+        super(ErnieEncoderStack, self).__init__()
+        n_layers = cfg['num_hidden_layers']
+        self.block = nn.LayerList([
+            ErnieEncoderLayer(cfg, append_name(name, 'layer_%d' % i))
+            for i in range(n_layers)
+        ])
+
+    def forward(self, inputs, attn_bias=None, past_cache=None):
+        if past_cache is not None:
+            assert isinstance(
+                past_cache, tuple
+            ), 'unknown type of `past_cache`, expect tuple or list, got %s' % repr(
+                type(past_cache))
+            past_cache = list(zip(*past_cache))
+        else:
+            past_cache = [None] * len(self.block)
+        cache_list_k, cache_list_v, hidden_list = [], [], [inputs]
+
+        for b, p in zip(self.block, past_cache):
+            inputs, cache = b(inputs, attn_bias=attn_bias, past_cache=p)
+            cache_k, cache_v = cache
+            cache_list_k.append(cache_k)
+            cache_list_v.append(cache_v)
+            hidden_list.append(inputs)
+
+        return inputs, hidden_list, (cache_list_k, cache_list_v)
+
+
+@six.add_metaclass(InitTrackerMeta)
+class ErnieGenPretrainedModel(object):
+    model_config_file = "model_config.json"
+    ernie_gen_pretrained_init_configuration = {
+        "ernie-gen-base-en": {
+            "attention_probs_dropout_prob": 0.1,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "hidden_size": 768,
+            "initializer_range": 0.02,
+            "intermediate_size": 3072,
+            "max_position_embeddings": 1024,
+            "num_attention_heads": 12,
+            "num_hidden_layers": 12,
+            "type_vocab_size": 4,
+            "vocab_size": 30522,
+            "pad_token_id": 0,
+        },
+        "ernie-gen-large-en": {
+            "attention_probs_dropout_prob": 0.1,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "hidden_size": 1024,
+            "initializer_range": 0.02,
+            "intermediate_size": 4096,
+            "max_position_embeddings": 1024,
+            "num_attention_heads": 16,
+            "num_hidden_layers": 24,
+            "type_vocab_size": 4,
+            "vocab_size": 30522,
+            "pad_token_id": 0,
+        },
+        "ernie-gen-large-430g-en": {
+            "attention_probs_dropout_prob": 0.1,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "hidden_size": 1024,
+            "initializer_range": 0.02,
+            "intermediate_size": 4096,
+            "max_position_embeddings": 1024,
+            "num_attention_heads": 16,
+            "num_hidden_layers": 24,
+            "type_vocab_size": 4,
+            "vocab_size": 30522,
+            "pad_token_id": 0,
+        },
+    }
+    resource_files_names = {"model_state": "model_state.pdparams"}
+    ernie_gen_pretrained_resource_files_map = {
+        "model_state": {
+            "ernie-gen-base-en":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-base/ernie_gen_base.pdparams",
+            "ernie-gen-large-en":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-large/ernie_gen_large.pdparams",
+            "ernie-gen-large-430g-en":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-large-430g/ernie_gen_large_430g.pdparams",
+        }
+    }
+
+    # Support more models to warm start.
+    pretrained_init_configuration = {
+        **ernie_gen_pretrained_init_configuration,
+        **BertPretrainedModel.pretrained_init_configuration,
+        **ElectraPretrainedModel.pretrained_init_configuration,
+        **RobertaPretrainedModel.pretrained_init_configuration,
+        **ErniePretrainedModel.pretrained_init_configuration
+    }
+    pretrained_resource_files_map = {
+        "model_state": {
+            **ernie_gen_pretrained_resource_files_map["model_state"],
+            **BertPretrainedModel.pretrained_resource_files_map["model_state"],
+            **ElectraPretrainedModel.pretrained_resource_files_map[
+                "model_state"],
+            **RobertaPretrainedModel.pretrained_resource_files_map[
+                "model_state"],
+            **ErniePretrainedModel.pretrained_resource_files_map["model_state"]
+        }
+    }
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        pretrained_models = list(cls.pretrained_init_configuration.keys())
+        resource_files = {}
+        init_configuration = {}
+        if pretrained_model_name_or_path in pretrained_models:
+            for file_id, map_list in cls.pretrained_resource_files_map.items():
+                resource_files[file_id] = map_list[
+                    pretrained_model_name_or_path]
+            init_configuration = copy.deepcopy(
+                cls.pretrained_init_configuration[
+                    pretrained_model_name_or_path])
+        else:
+            if os.path.isdir(pretrained_model_name_or_path):
+                for file_id, file_name in cls.resource_files_names.items():
+                    full_file_name = os.path.join(
+                        pretrained_model_name_or_path, file_name)
+                    resource_files[file_id] = full_file_name
+                resource_files["model_config_file"] = os.path.join(
+                    pretrained_model_name_or_path, cls.model_config_file)
+            else:
+                raise ValueError(
+                    "{}.from_pretrained() expects a model identifier or the "
+                    "path to a directory. The supported model identifiers "
+                    "are as follows: {}".format(
+                        cls.__name__,
+                        cls.pretrained_init_configuration.keys()))
+
+        default_root = os.path.join(MODEL_HOME, pretrained_model_name_or_path)
+        resolved_resource_files = {}
+        for file_id, file_path in resource_files.items():
+            path = os.path.join(default_root, file_path.split('/')[-1])
+            if file_path is None or os.path.isfile(file_path):
+                resolved_resource_files[file_id] = file_path
+            elif os.path.exists(path):
+                logger.info("Already cached %s" % path)
+                resolved_resource_files[file_id] = path
+            else:
+                logger.info("Downloading %s and saving to %s" %
+                            (file_path, default_root))
+                resolved_resource_files[file_id] = get_path_from_url(
+                    file_path, default_root)
+
+        # Prepare model initialization kwargs.
+        # Did we save some inputs and kwargs to reload?
+        model_config_file = resolved_resource_files.pop("model_config_file",
+                                                        None)
+        if model_config_file is not None:
+            with io.open(model_config_file, encoding="utf-8") as f:
+                init_kwargs = json.load(f)
+        else:
+            init_kwargs = init_configuration
+
+        if not os.path.exists(resolved_resource_files[file_id]):
+            raise ValueError('pretrain dir not found: %s' %
+                             resolved_resource_files[file_id])
+
+        name_prefix = kwargs.pop('name', None)
+        model = cls(init_kwargs, name=name_prefix)
+
+        weight_path = list(resolved_resource_files.values())[0]
+        logger.info('loading pretrained model from %s' % weight_path)
+
+        if os.path.exists(weight_path):
+            m = paddle.load(weight_path)
+            params_name = list(m.keys())
+            if 'mlm.weight' not in params_name:
+                # ernie_gen is not implemented with paddle.transformer, so
+                # param names saved by paddle.transformer-based models must be
+                # converted before loading. We will update ernie_gen with
+                # paddle.transformer in the future.
+                name_index_begin = params_name[0].index('.') + 1
+                for old_name in params_name:
+                    new_name = old_name[name_index_begin:] \
+                        .replace("embeddings.word_embeddings", "word_emb") \
+                        .replace("embeddings.position_embeddings", "pos_emb") \
+                        .replace("embeddings.token_type_embeddings", "sent_emb") \
+                        .replace("embeddings.layer_norm", "ln") \
+                        .replace("encoder.layers", "encoder_stack.block") \
+                        .replace("self_attn", "attn") \
+                        .replace("k_proj", "k") \
+                        .replace("q_proj", "q") \
+                        .replace("v_proj", "v") \
+                        .replace("out_proj", "o") \
+                        .replace("linear1", "ffn.i") \
+                        .replace("linear2", "ffn.o") \
+                        .replace("norm1", "ln1") \
+                        .replace("norm2", "ln2") \
+                        .replace("pooler.dense", "pooler")
+                    m[new_name] = m.pop(old_name)
+            for k, v in model.state_dict().items():
+                if k not in m:
+                    logger.info('param:%s not set in pretrained model, skip' %
+                                k)
+                    m[k] = v  # FIXME: no need to do this in the future
+            model.set_state_dict(m)
+        else:
+            raise ValueError('weight file not found in pretrain dir: %s' %
+                             weight_path)
+        return model
+
+    def save_pretrained(self, save_directory):
+        """
+        Save model configuration and related resources (model state) to files
+        under `save_directory`.
+        Args:
+            save_directory (str): Directory to save files into.
+        """
+        assert os.path.isdir(
+            save_directory
+        ), "Saving directory ({}) should be a directory".format(save_directory)
+        # Save the model config.
+        model_config_file = os.path.join(save_directory,
+                                         self.model_config_file)
+        model_config = self.init_config
+        # If init_config contains a Layer, use the Layer's init_config to save.
+        for key, value in model_config.items():
+            if key == "init_args":
+                args = []
+                for arg in value:
+                    args.append(arg.init_config
+                                if isinstance(arg, ErnieGenPretrainedModel)
+                                else arg)
+                model_config[key] = tuple(args)
+            elif isinstance(value, ErnieGenPretrainedModel):
+                model_config[key] = value.init_config
+        with io.open(model_config_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(model_config, ensure_ascii=False))
+        # Save the model weights.
+        file_name = os.path.join(save_directory,
+                                 list(self.resource_files_names.values())[0])
+        paddle.save(self.state_dict(), file_name)
+
+    def _wrap_init(self, original_init, *args, **kwargs):
+        """
+        Hooked after `__init__` to store a dict of the `__init__` arguments as
+        an attribute named `init_config` on the pretrained model instance.
+ """ + init_dict = fn_args_to_dict(original_init, *args, **kwargs) + self.config = init_dict + + +class ErnieModel(nn.Layer, ErnieGenPretrainedModel): + def __init__(self, cfg, name=None): + """ + Fundamental pretrained Ernie model + """ + logger.debug('init ErnieModel with config: %s' % repr(cfg)) + nn.Layer.__init__(self) + d_model = cfg['hidden_size'] + d_emb = cfg.get('emb_size', cfg['hidden_size']) + d_vocab = cfg['vocab_size'] + d_pos = cfg['max_position_embeddings'] + d_sent = cfg.get("sent_type_vocab_size") or cfg['type_vocab_size'] + self.n_head = cfg['num_attention_heads'] + self.return_additional_info = cfg.get('return_additional_info', False) + initializer = nn.initializer.TruncatedNormal( + std=cfg['initializer_range']) + + self.ln = _build_ln(d_model, name=append_name(name, 'pre_encoder')) + self.word_emb = nn.Embedding( + d_vocab, + d_emb, + weight_attr=paddle.ParamAttr( + name=append_name(name, 'word_embedding'), + initializer=initializer)) + self.pos_emb = nn.Embedding( + d_pos, + d_emb, + weight_attr=paddle.ParamAttr( + name=append_name(name, 'pos_embedding'), + initializer=initializer)) + self.sent_emb = nn.Embedding( + d_sent, + d_emb, + weight_attr=paddle.ParamAttr( + name=append_name(name, 'sent_embedding'), + initializer=initializer)) + prob = cfg['hidden_dropout_prob'] + self.dropout = nn.Dropout(p=prob) + + self.encoder_stack = ErnieEncoderStack(cfg, + append_name(name, 'encoder')) + + def forward(self, + src_ids, + sent_ids=None, + pos_ids=None, + input_mask=None, + attn_bias=None, + past_cache=None, + use_causal_mask=False): + """ + Args: + src_ids (`Variable` of shape `[batch_size, seq_len]`): + Indices of input sequence tokens in the vocabulary. + sent_ids (optional, `Variable` of shape `[batch_size, seq_len]`): + aka token_type_ids, Segment token indices to indicate first and second portions of the inputs. + if None, assume all tokens come from `segment_a` + pos_ids(optional, `Variable` of shape `[batch_size, seq_len]`): + Indices of positions of each input sequence tokens in the position embeddings. + input_mask(optional `Variable` of shape `[batch_size, seq_len]`): + Mask to avoid performing attention on the padding token indices of the encoder input. + attn_bias(optional, `Variable` of shape `[batch_size, seq_len, seq_len] or False`): + 3D version of `input_mask`, if set, overrides `input_mask`; if set not False, will not apply attention mask + past_cache(optional, tuple of two lists: cached key and cached value, + each is a list of `Variable`s of shape `[batch_size, seq_len, hidden_size]`): + cached key/value tensor that will be concated to generated key/value when performing self attention. + if set, `attn_bias` should not be None. + Returns: + pooled (`Variable` of shape `[batch_size, hidden_size]`): + output logits of pooler classifier + encoded(`Variable` of shape `[batch_size, seq_len, hidden_size]`): + output logits of transformer stack + info (Dictionary): + addtional middle level info, inclues: all hidden stats, k/v caches. + """ + assert len( + src_ids. 
+ shape) == 2, 'expect src_ids.shape = [batch, sequecen], got %s' % ( + repr(src_ids.shape)) + assert attn_bias is not None if past_cache else True, 'if `past_cache` is specified; attn_bias should not be None' + d_seqlen = paddle.shape(src_ids)[1] + if pos_ids is None: + pos_ids = paddle.arange( + 0, d_seqlen, 1, dtype='int32').reshape([1, -1]).cast('int64') + if attn_bias is None: + if input_mask is None: + input_mask = paddle.cast(src_ids != 0, 'float32') + assert len(input_mask.shape) == 2 + input_mask = input_mask.unsqueeze(-1) + attn_bias = input_mask.matmul(input_mask, transpose_y=True) + if use_causal_mask: + sequence = paddle.reshape( + paddle.arange( + 0, d_seqlen, 1, dtype='float32') + 1., [1, 1, -1, 1]) + causal_mask = (sequence.matmul( + 1. / sequence, transpose_y=True) >= 1.).cast('float32') + attn_bias *= causal_mask + else: + assert len( + attn_bias.shape + ) == 3, 'expect attn_bias tobe rank 3, got %r' % attn_bias.shape + attn_bias = (1. - attn_bias) * -10000.0 + attn_bias = attn_bias.unsqueeze(1).tile( + [1, self.n_head, 1, 1]) # avoid broadcast =_= + + if sent_ids is None: + sent_ids = paddle.zeros_like(src_ids) + + src_embedded = self.word_emb(src_ids) + pos_embedded = self.pos_emb(pos_ids) + sent_embedded = self.sent_emb(sent_ids) + embedded = src_embedded + pos_embedded + sent_embedded + + embedded = self.dropout(self.ln(embedded)) + + encoded, hidden_list, cache_list = self.encoder_stack( + embedded, attn_bias, past_cache=past_cache) + + additional_info = { + 'hiddens': hidden_list, + 'caches': cache_list, + } + + return encoded, additional_info + + +class ErnieForGeneration(ErnieModel): + """ + Ernie Model for sequence to sequence generation. + """ + + def __init__(self, cfg, name=None): + super(ErnieForGeneration, self).__init__(cfg, name=name) + initializer = nn.initializer.TruncatedNormal( + std=cfg['initializer_range']) + d_model = cfg['hidden_size'] + d_vocab = cfg['vocab_size'] + + self.mlm = _build_linear( + d_model, + d_model, + append_name(name, 'mask_lm_trans_fc'), + initializer, ) + self.act = getattr(paddle.nn.functional, cfg['hidden_act']) + self.mlm_ln = _build_ln( + d_model, name=append_name(name, 'mask_lm_trans')) + self.mlm_bias = paddle.create_parameter( + dtype='float32', + shape=[d_vocab], + attr=paddle.ParamAttr( + name=append_name(name, 'mask_lm_out_fc.b_0'), + initializer=nn.initializer.Constant(value=0.0)), + is_bias=True, ) + + def forward(self, *args, **kwargs): + """ + Args + tgt_labels(`Variable` of shape [batch_size, seqlen] or [batch, seqlen, vocab_size]): + ground trouth target sequence id (hard label) or distribution (soft label) + tgt_pos(`Variable` of shape [n_targets, 2]): + index of tgt_labels in `src_ids`, can be obtained from `fluid.layers.where(src_ids==mask_id)` + encoder_only(Bool): + if set, will not return loss, logits_2d + Returns: + loss(`Variable` of shape []): + cross entropy loss mean over every target label. if `encode_only`, returns None. + logits(`Variable` of shape [n_targets, vocab_size]): + logits for every targets. if `encode_only`, returns None. 
+ info(Dictionary): see `ErnieModel` + """ + tgt_labels = kwargs.pop('tgt_labels', None) + tgt_pos = kwargs.pop('tgt_pos', None) + encode_only = kwargs.pop('encode_only', False) + encoded, info = ErnieModel.forward(self, *args, **kwargs) + if encode_only: + return None, None, info + if tgt_labels is None or tgt_pos is None: + encoded = self.act(self.mlm(encoded)) + encoded = self.mlm_ln(encoded) + logits = encoded.matmul( + self.word_emb.weight, transpose_y=True) + self.mlm_bias + output_ids = logits.argmax(-1) + return output_ids, logits, info + else: + encoded_2d = encoded.gather_nd(tgt_pos) + encoded_2d = self.act(self.mlm(encoded_2d)) + encoded_2d = self.mlm_ln(encoded_2d) + logits_2d = encoded_2d.matmul( + self.word_emb.weight, transpose_y=True) + self.mlm_bias + if len(tgt_labels.shape) == 1: + tgt_labels = paddle.reshape(tgt_labels, [-1, 1]) + + loss = paddle.nn.functional.cross_entropy( + logits_2d, tgt_labels, soft_label=(tgt_labels.shape[-1] != 1)) + + return loss, logits_2d, info diff --git a/PaddleNLP/paddlenlp/transformers/ernie_gen/params_map.json b/PaddleNLP/paddlenlp/transformers/ernie_gen/params_map.json new file mode 100644 index 0000000000000000000000000000000000000000..320e940ee69bf395e05d07e3e92c7bafcd49c94b --- /dev/null +++ b/PaddleNLP/paddlenlp/transformers/ernie_gen/params_map.json @@ -0,0 +1 @@ +{"embeddings.word_embeddings.weight": "word_emb.weight", "embeddings.position_embeddings.weight": "pos_emb.weight", "embeddings.token_type_embeddings.weight": "sent_emb.weight", "embeddings.layer_norm.weight": "ln.weight", "embeddings.layer_norm.bias": "ln.bias", "encoder.layers.0.self_attn.q_proj.weight": "encoder_stack.block.0.attn.q.weight", "encoder.layers.0.self_attn.q_proj.bias": "encoder_stack.block.0.attn.q.bias", "encoder.layers.0.self_attn.k_proj.weight": "encoder_stack.block.0.attn.k.weight", "encoder.layers.0.self_attn.k_proj.bias": "encoder_stack.block.0.attn.k.bias", "encoder.layers.0.self_attn.v_proj.weight": "encoder_stack.block.0.attn.v.weight", "encoder.layers.0.self_attn.v_proj.bias": "encoder_stack.block.0.attn.v.bias", "encoder.layers.0.self_attn.out_proj.weight": "encoder_stack.block.0.attn.o.weight", "encoder.layers.0.self_attn.out_proj.bias": "encoder_stack.block.0.attn.o.bias", "encoder.layers.1.self_attn.q_proj.weight": "encoder_stack.block.1.attn.q.weight", "encoder.layers.1.self_attn.q_proj.bias": "encoder_stack.block.1.attn.q.bias", "encoder.layers.1.self_attn.k_proj.weight": "encoder_stack.block.1.attn.k.weight", "encoder.layers.1.self_attn.k_proj.bias": "encoder_stack.block.1.attn.k.bias", "encoder.layers.1.self_attn.v_proj.weight": "encoder_stack.block.1.attn.v.weight", "encoder.layers.1.self_attn.v_proj.bias": "encoder_stack.block.1.attn.v.bias", "encoder.layers.1.self_attn.out_proj.weight": "encoder_stack.block.1.attn.o.weight", "encoder.layers.1.self_attn.out_proj.bias": "encoder_stack.block.1.attn.o.bias", "encoder.layers.2.self_attn.q_proj.weight": "encoder_stack.block.2.attn.q.weight", "encoder.layers.2.self_attn.q_proj.bias": "encoder_stack.block.2.attn.q.bias", "encoder.layers.2.self_attn.k_proj.weight": "encoder_stack.block.2.attn.k.weight", "encoder.layers.2.self_attn.k_proj.bias": "encoder_stack.block.2.attn.k.bias", "encoder.layers.2.self_attn.v_proj.weight": "encoder_stack.block.2.attn.v.weight", "encoder.layers.2.self_attn.v_proj.bias": "encoder_stack.block.2.attn.v.bias", "encoder.layers.2.self_attn.out_proj.weight": "encoder_stack.block.2.attn.o.weight", "encoder.layers.2.self_attn.out_proj.bias": 
"encoder_stack.block.2.attn.o.bias", "encoder.layers.3.self_attn.q_proj.weight": "encoder_stack.block.3.attn.q.weight", "encoder.layers.3.self_attn.q_proj.bias": "encoder_stack.block.3.attn.q.bias", "encoder.layers.3.self_attn.k_proj.weight": "encoder_stack.block.3.attn.k.weight", "encoder.layers.3.self_attn.k_proj.bias": "encoder_stack.block.3.attn.k.bias", "encoder.layers.3.self_attn.v_proj.weight": "encoder_stack.block.3.attn.v.weight", "encoder.layers.3.self_attn.v_proj.bias": "encoder_stack.block.3.attn.v.bias", "encoder.layers.3.self_attn.out_proj.weight": "encoder_stack.block.3.attn.o.weight", "encoder.layers.3.self_attn.out_proj.bias": "encoder_stack.block.3.attn.o.bias", "encoder.layers.4.self_attn.q_proj.weight": "encoder_stack.block.4.attn.q.weight", "encoder.layers.4.self_attn.q_proj.bias": "encoder_stack.block.4.attn.q.bias", "encoder.layers.4.self_attn.k_proj.weight": "encoder_stack.block.4.attn.k.weight", "encoder.layers.4.self_attn.k_proj.bias": "encoder_stack.block.4.attn.k.bias", "encoder.layers.4.self_attn.v_proj.weight": "encoder_stack.block.4.attn.v.weight", "encoder.layers.4.self_attn.v_proj.bias": "encoder_stack.block.4.attn.v.bias", "encoder.layers.4.self_attn.out_proj.weight": "encoder_stack.block.4.attn.o.weight", "encoder.layers.4.self_attn.out_proj.bias": "encoder_stack.block.4.attn.o.bias", "encoder.layers.5.self_attn.q_proj.weight": "encoder_stack.block.5.attn.q.weight", "encoder.layers.5.self_attn.q_proj.bias": "encoder_stack.block.5.attn.q.bias", "encoder.layers.5.self_attn.k_proj.weight": "encoder_stack.block.5.attn.k.weight", "encoder.layers.5.self_attn.k_proj.bias": "encoder_stack.block.5.attn.k.bias", "encoder.layers.5.self_attn.v_proj.weight": "encoder_stack.block.5.attn.v.weight", "encoder.layers.5.self_attn.v_proj.bias": "encoder_stack.block.5.attn.v.bias", "encoder.layers.5.self_attn.out_proj.weight": "encoder_stack.block.5.attn.o.weight", "encoder.layers.5.self_attn.out_proj.bias": "encoder_stack.block.5.attn.o.bias", "encoder.layers.6.self_attn.q_proj.weight": "encoder_stack.block.6.attn.q.weight", "encoder.layers.6.self_attn.q_proj.bias": "encoder_stack.block.6.attn.q.bias", "encoder.layers.6.self_attn.k_proj.weight": "encoder_stack.block.6.attn.k.weight", "encoder.layers.6.self_attn.k_proj.bias": "encoder_stack.block.6.attn.k.bias", "encoder.layers.6.self_attn.v_proj.weight": "encoder_stack.block.6.attn.v.weight", "encoder.layers.6.self_attn.v_proj.bias": "encoder_stack.block.6.attn.v.bias", "encoder.layers.6.self_attn.out_proj.weight": "encoder_stack.block.6.attn.o.weight", "encoder.layers.6.self_attn.out_proj.bias": "encoder_stack.block.6.attn.o.bias", "encoder.layers.7.self_attn.q_proj.weight": "encoder_stack.block.7.attn.q.weight", "encoder.layers.7.self_attn.q_proj.bias": "encoder_stack.block.7.attn.q.bias", "encoder.layers.7.self_attn.k_proj.weight": "encoder_stack.block.7.attn.k.weight", "encoder.layers.7.self_attn.k_proj.bias": "encoder_stack.block.7.attn.k.bias", "encoder.layers.7.self_attn.v_proj.weight": "encoder_stack.block.7.attn.v.weight", "encoder.layers.7.self_attn.v_proj.bias": "encoder_stack.block.7.attn.v.bias", "encoder.layers.7.self_attn.out_proj.weight": "encoder_stack.block.7.attn.o.weight", "encoder.layers.7.self_attn.out_proj.bias": "encoder_stack.block.7.attn.o.bias", "encoder.layers.8.self_attn.q_proj.weight": "encoder_stack.block.8.attn.q.weight", "encoder.layers.8.self_attn.q_proj.bias": "encoder_stack.block.8.attn.q.bias", "encoder.layers.8.self_attn.k_proj.weight": "encoder_stack.block.8.attn.k.weight", 
"encoder.layers.8.self_attn.k_proj.bias": "encoder_stack.block.8.attn.k.bias", "encoder.layers.8.self_attn.v_proj.weight": "encoder_stack.block.8.attn.v.weight", "encoder.layers.8.self_attn.v_proj.bias": "encoder_stack.block.8.attn.v.bias", "encoder.layers.8.self_attn.out_proj.weight": "encoder_stack.block.8.attn.o.weight", "encoder.layers.8.self_attn.out_proj.bias": "encoder_stack.block.8.attn.o.bias", "encoder.layers.9.self_attn.q_proj.weight": "encoder_stack.block.9.attn.q.weight", "encoder.layers.9.self_attn.q_proj.bias": "encoder_stack.block.9.attn.q.bias", "encoder.layers.9.self_attn.k_proj.weight": "encoder_stack.block.9.attn.k.weight", "encoder.layers.9.self_attn.k_proj.bias": "encoder_stack.block.9.attn.k.bias", "encoder.layers.9.self_attn.v_proj.weight": "encoder_stack.block.9.attn.v.weight", "encoder.layers.9.self_attn.v_proj.bias": "encoder_stack.block.9.attn.v.bias", "encoder.layers.9.self_attn.out_proj.weight": "encoder_stack.block.9.attn.o.weight", "encoder.layers.9.self_attn.out_proj.bias": "encoder_stack.block.9.attn.o.bias", "encoder.layers.10.self_attn.q_proj.weight": "encoder_stack.block.10.attn.q.weight", "encoder.layers.10.self_attn.q_proj.bias": "encoder_stack.block.10.attn.q.bias", "encoder.layers.10.self_attn.k_proj.weight": "encoder_stack.block.10.attn.k.weight", "encoder.layers.10.self_attn.k_proj.bias": "encoder_stack.block.10.attn.k.bias", "encoder.layers.10.self_attn.v_proj.weight": "encoder_stack.block.10.attn.v.weight", "encoder.layers.10.self_attn.v_proj.bias": "encoder_stack.block.10.attn.v.bias", "encoder.layers.10.self_attn.out_proj.weight": "encoder_stack.block.10.attn.o.weight", "encoder.layers.10.self_attn.out_proj.bias": "encoder_stack.block.10.attn.o.bias", "encoder.layers.11.self_attn.q_proj.weight": "encoder_stack.block.11.attn.q.weight", "encoder.layers.11.self_attn.q_proj.bias": "encoder_stack.block.11.attn.q.bias", "encoder.layers.11.self_attn.k_proj.weight": "encoder_stack.block.11.attn.k.weight", "encoder.layers.11.self_attn.k_proj.bias": "encoder_stack.block.11.attn.k.bias", "encoder.layers.11.self_attn.v_proj.weight": "encoder_stack.block.11.attn.v.weight", "encoder.layers.11.self_attn.v_proj.bias": "encoder_stack.block.11.attn.v.bias", "encoder.layers.11.self_attn.out_proj.weight": "encoder_stack.block.11.attn.o.weight", "encoder.layers.11.self_attn.out_proj.bias": "encoder_stack.block.11.attn.o.bias", "encoder.layers.0.linear1.weight": "encoder_stack.block.0.ffn.i.weight", "encoder.layers.0.linear1.bias": "encoder_stack.block.0.ffn.i.bias", "encoder.layers.0.linear2.weight": "encoder_stack.block.0.ffn.o.weight", "encoder.layers.0.linear2.bias": "encoder_stack.block.0.ffn.o.bias", "encoder.layers.1.linear1.weight": "encoder_stack.block.1.ffn.i.weight", "encoder.layers.1.linear1.bias": "encoder_stack.block.1.ffn.i.bias", "encoder.layers.1.linear2.weight": "encoder_stack.block.1.ffn.o.weight", "encoder.layers.1.linear2.bias": "encoder_stack.block.1.ffn.o.bias", "encoder.layers.2.linear1.weight": "encoder_stack.block.2.ffn.i.weight", "encoder.layers.2.linear1.bias": "encoder_stack.block.2.ffn.i.bias", "encoder.layers.2.linear2.weight": "encoder_stack.block.2.ffn.o.weight", "encoder.layers.2.linear2.bias": "encoder_stack.block.2.ffn.o.bias", "encoder.layers.3.linear1.weight": "encoder_stack.block.3.ffn.i.weight", "encoder.layers.3.linear1.bias": "encoder_stack.block.3.ffn.i.bias", "encoder.layers.3.linear2.weight": "encoder_stack.block.3.ffn.o.weight", "encoder.layers.3.linear2.bias": "encoder_stack.block.3.ffn.o.bias", 
"encoder.layers.4.linear1.weight": "encoder_stack.block.4.ffn.i.weight", "encoder.layers.4.linear1.bias": "encoder_stack.block.4.ffn.i.bias", "encoder.layers.4.linear2.weight": "encoder_stack.block.4.ffn.o.weight", "encoder.layers.4.linear2.bias": "encoder_stack.block.4.ffn.o.bias", "encoder.layers.5.linear1.weight": "encoder_stack.block.5.ffn.i.weight", "encoder.layers.5.linear1.bias": "encoder_stack.block.5.ffn.i.bias", "encoder.layers.5.linear2.weight": "encoder_stack.block.5.ffn.o.weight", "encoder.layers.5.linear2.bias": "encoder_stack.block.5.ffn.o.bias", "encoder.layers.6.linear1.weight": "encoder_stack.block.6.ffn.i.weight", "encoder.layers.6.linear1.bias": "encoder_stack.block.6.ffn.i.bias", "encoder.layers.6.linear2.weight": "encoder_stack.block.6.ffn.o.weight", "encoder.layers.6.linear2.bias": "encoder_stack.block.6.ffn.o.bias", "encoder.layers.7.linear1.weight": "encoder_stack.block.7.ffn.i.weight", "encoder.layers.7.linear1.bias": "encoder_stack.block.7.ffn.i.bias", "encoder.layers.7.linear2.weight": "encoder_stack.block.7.ffn.o.weight", "encoder.layers.7.linear2.bias": "encoder_stack.block.7.ffn.o.bias", "encoder.layers.8.linear1.weight": "encoder_stack.block.8.ffn.i.weight", "encoder.layers.8.linear1.bias": "encoder_stack.block.8.ffn.i.bias", "encoder.layers.8.linear2.weight": "encoder_stack.block.8.ffn.o.weight", "encoder.layers.8.linear2.bias": "encoder_stack.block.8.ffn.o.bias", "encoder.layers.9.linear1.weight": "encoder_stack.block.9.ffn.i.weight", "encoder.layers.9.linear1.bias": "encoder_stack.block.9.ffn.i.bias", "encoder.layers.9.linear2.weight": "encoder_stack.block.9.ffn.o.weight", "encoder.layers.9.linear2.bias": "encoder_stack.block.9.ffn.o.bias", "encoder.layers.10.linear1.weight": "encoder_stack.block.10.ffn.i.weight", "encoder.layers.10.linear1.bias": "encoder_stack.block.10.ffn.i.bias", "encoder.layers.10.linear2.weight": "encoder_stack.block.10.ffn.o.weight", "encoder.layers.10.linear2.bias": "encoder_stack.block.10.ffn.o.bias", "encoder.layers.11.linear1.weight": "encoder_stack.block.11.ffn.i.weight", "encoder.layers.11.linear1.bias": "encoder_stack.block.11.ffn.i.bias", "encoder.layers.11.linear2.weight": "encoder_stack.block.11.ffn.o.weight", "encoder.layers.11.linear2.bias": "encoder_stack.block.11.ffn.o.bias", "encoder.layers.0.norm1.weight": "encoder_stack.block.0.ln1.weight", "encoder.layers.0.norm1.bias": "encoder_stack.block.0.ln1.bias", "encoder.layers.1.norm1.weight": "encoder_stack.block.1.ln1.weight", "encoder.layers.1.norm1.bias": "encoder_stack.block.1.ln1.bias", "encoder.layers.2.norm1.weight": "encoder_stack.block.2.ln1.weight", "encoder.layers.2.norm1.bias": "encoder_stack.block.2.ln1.bias", "encoder.layers.3.norm1.weight": "encoder_stack.block.3.ln1.weight", "encoder.layers.3.norm1.bias": "encoder_stack.block.3.ln1.bias", "encoder.layers.4.norm1.weight": "encoder_stack.block.4.ln1.weight", "encoder.layers.4.norm1.bias": "encoder_stack.block.4.ln1.bias", "encoder.layers.5.norm1.weight": "encoder_stack.block.5.ln1.weight", "encoder.layers.5.norm1.bias": "encoder_stack.block.5.ln1.bias", "encoder.layers.6.norm1.weight": "encoder_stack.block.6.ln1.weight", "encoder.layers.6.norm1.bias": "encoder_stack.block.6.ln1.bias", "encoder.layers.7.norm1.weight": "encoder_stack.block.7.ln1.weight", "encoder.layers.7.norm1.bias": "encoder_stack.block.7.ln1.bias", "encoder.layers.8.norm1.weight": "encoder_stack.block.8.ln1.weight", "encoder.layers.8.norm1.bias": "encoder_stack.block.8.ln1.bias", "encoder.layers.9.norm1.weight": 
"encoder_stack.block.9.ln1.weight", "encoder.layers.9.norm1.bias": "encoder_stack.block.9.ln1.bias", "encoder.layers.10.norm1.weight": "encoder_stack.block.10.ln1.weight", "encoder.layers.10.norm1.bias": "encoder_stack.block.10.ln1.bias", "encoder.layers.11.norm1.weight": "encoder_stack.block.11.ln1.weight", "encoder.layers.11.norm1.bias": "encoder_stack.block.11.ln1.bias", "encoder.layers.0.norm2.weight": "encoder_stack.block.0.ln2.weight", "encoder.layers.0.norm2.bias": "encoder_stack.block.0.ln2.bias", "encoder.layers.1.norm2.weight": "encoder_stack.block.1.ln2.weight", "encoder.layers.1.norm2.bias": "encoder_stack.block.1.ln2.bias", "encoder.layers.2.norm2.weight": "encoder_stack.block.2.ln2.weight", "encoder.layers.2.norm2.bias": "encoder_stack.block.2.ln2.bias", "encoder.layers.3.norm2.weight": "encoder_stack.block.3.ln2.weight", "encoder.layers.3.norm2.bias": "encoder_stack.block.3.ln2.bias", "encoder.layers.4.norm2.weight": "encoder_stack.block.4.ln2.weight", "encoder.layers.4.norm2.bias": "encoder_stack.block.4.ln2.bias", "encoder.layers.5.norm2.weight": "encoder_stack.block.5.ln2.weight", "encoder.layers.5.norm2.bias": "encoder_stack.block.5.ln2.bias", "encoder.layers.6.norm2.weight": "encoder_stack.block.6.ln2.weight", "encoder.layers.6.norm2.bias": "encoder_stack.block.6.ln2.bias", "encoder.layers.7.norm2.weight": "encoder_stack.block.7.ln2.weight", "encoder.layers.7.norm2.bias": "encoder_stack.block.7.ln2.bias", "encoder.layers.8.norm2.weight": "encoder_stack.block.8.ln2.weight", "encoder.layers.8.norm2.bias": "encoder_stack.block.8.ln2.bias", "encoder.layers.9.norm2.weight": "encoder_stack.block.9.ln2.weight", "encoder.layers.9.norm2.bias": "encoder_stack.block.9.ln2.bias", "encoder.layers.10.norm2.weight": "encoder_stack.block.10.ln2.weight", "encoder.layers.10.norm2.bias": "encoder_stack.block.10.ln2.bias", "encoder.layers.11.norm2.weight": "encoder_stack.block.11.ln2.weight", "encoder.layers.11.norm2.bias": "encoder_stack.block.11.ln2.bias", "pooler.dense.weight": "pooler.weight", "pooler.dense.bias": "pooler.bias"} \ No newline at end of file diff --git a/PaddleNLP/paddlenlp/transformers/roberta/modeling.py b/PaddleNLP/paddlenlp/transformers/roberta/modeling.py index 9e340ff4dbfe2e1001db4a9c2f53a078907310e0..fc797cfee5ff2ddca44257b503906c769cdf9791 100644 --- a/PaddleNLP/paddlenlp/transformers/roberta/modeling.py +++ b/PaddleNLP/paddlenlp/transformers/roberta/modeling.py @@ -19,6 +19,7 @@ from .. import PretrainedModel, register_base_model __all__ = [ 'RobertaModel', + 'RobertaPretrainedModel', 'RobertaForSequenceClassification', 'RobertaForTokenClassification', 'RobertaForQuestionAnswering',