Unverified commit 26bc357c authored by jm_12138, committed by GitHub

Update CPM-LM module and add a new text generation module GPT2_Base_CN (#1316)

* Update GPT2 Text Generation Module
Parent: 9299ce78
from .model import GPT2Model
from .tokenization import GPT2Tokenizer
import math
import paddle
import paddle.nn as nn
class MLP(nn.Layer):
def __init__(self, embedding_size):
super(MLP, self).__init__()
self.dense_h_to_4h = nn.Linear(embedding_size, embedding_size * 4)
self.dense_4h_to_h = nn.Linear(embedding_size * 4, embedding_size)
self.act = nn.functional.gelu
def forward(self, x):
h = self.act(self.dense_h_to_4h(x))
h2 = self.dense_4h_to_h(h)
return h2
class Attention(nn.Layer):
def __init__(self, embedding_size, num_attention_heads, attention_dropout, residual_dropout):
super(Attention, self).__init__()
self.num_attention_heads = num_attention_heads
self.size_per_head = embedding_size // num_attention_heads
self.embedding_size = embedding_size
self.query_key_value = nn.Linear(embedding_size, embedding_size * 3)
self.attn_drop = nn.Dropout(attention_dropout)
self.resid_drop = nn.Dropout(residual_dropout)
self.dense = nn.Linear(embedding_size, embedding_size)
def split_heads(self, x):
x = x.reshape([-1, self.seq_len, self.num_attention_heads, self.size_per_head])
return x.transpose((0, 2, 1, 3))
def forward(self, x, kv_cache=None):
self.seq_len = x.shape[1]
x = self.query_key_value(x)
q, k, v = x.split(num_or_sections=3, axis=2)
q = self.split_heads(q)
k = self.split_heads(k)
v = self.split_heads(v)
if kv_cache is not None:
pk, pv = paddle.unstack(kv_cache, axis=1)
k = paddle.concat([pk, k], axis=-2)
v = paddle.concat([pv, v], axis=-2)
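# stash the full key/value history for the next step: [batch, 2 (k/v), num_heads, seq_so_far, size_per_head]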
cached_kv = paddle.stack([k, v], axis=1)
attn = paddle.matmul(q, k, transpose_y=True) # [B, N, L, S]
attn = attn / math.sqrt(self.size_per_head)
# [L, S]
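# NOTE: this causal mask is built for the no-cache case; with kv_cache it only broadcasts correctly for single-token steps (seq_len == 1)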
attention_mask = paddle.tril(paddle.ones([self.seq_len, self.seq_len], 'float32'))
attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len])
# masked positions receive a large negative bias before the softmax, which effectively removes them
attn = attn * attention_mask - 10000.0 * (1.0 - attention_mask)
attn = nn.Softmax(axis=-1)(attn)
attn = self.attn_drop(attn)
y = paddle.matmul(attn, v)
# [B, N, L, S] -> [B, L, N, S]
y = y.transpose((0, 2, 1, 3))
y = paddle.reshape(y, [-1, self.seq_len, self.embedding_size])
y = self.resid_drop(self.dense(y))
return y, cached_kv
class Block(nn.Layer):
def __init__(self, embedding_size, num_attention_heads, attention_dropout, residual_dropout):
super(Block, self).__init__()
self.input_layernorm = nn.LayerNorm(embedding_size, epsilon=1e-5)
self.attention = Attention(embedding_size, num_attention_heads, attention_dropout, residual_dropout)
self.post_attention_layernorm = nn.LayerNorm(embedding_size, epsilon=1e-5)
self.mlp = MLP(embedding_size)
def forward(self, x, kv_cache=None):
attn, cached_kv = self.attention(self.input_layernorm(x), kv_cache=kv_cache)
x = x + attn
z = self.post_attention_layernorm(x)
z = self.mlp(z)
x = x + z
return x, cached_kv
class Transformer(nn.Layer):
def __init__(self, layer_size, embedding_size, num_attention_heads, attention_dropout, residual_dropout):
super(Transformer, self).__init__()
self.layers = nn.LayerList([
Block(embedding_size, num_attention_heads, attention_dropout, residual_dropout) for _ in range(layer_size)
])
self.final_layernorm = nn.LayerNorm(embedding_size, epsilon=1e-5)
def forward(self, x, kv_cache=None):
cached_kvs = []
for i, layer in enumerate(self.layers):
x, cached_kv = layer(x, kv_cache=kv_cache[i] if kv_cache is not None else None)
cached_kvs.append(cached_kv)
x = self.final_layernorm(x)
return x, paddle.stack(cached_kvs)
class GPT2Model(nn.Layer):
def __init__(self, vocab_size, layer_size, block_size, embedding_dropout, embedding_size, num_attention_heads,
attention_dropout, residual_dropout):
super(GPT2Model, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
self.position_embeddings = nn.Embedding(block_size, embedding_size)
self.emb_drop = nn.Dropout(embedding_dropout)
self.transformer = Transformer(layer_size, embedding_size, num_attention_heads, attention_dropout,
residual_dropout)
def forward(self, x, kv_cache=None, use_cache=False):
if kv_cache is None:
past_length = 0
else:
past_length = kv_cache[0][0].shape[-2]
position_ids = paddle.arange(past_length, x.shape[-1] + past_length, dtype='int64')
position_ids = position_ids.unsqueeze(0).expand_as(x)
x = self.word_embeddings(x)
x = self.emb_drop(x + self.position_embeddings(position_ids))
x, cached_kvs = self.transformer(x, kv_cache)
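# weight tying: project hidden states back onto the vocabulary with the (transposed) word embedding matrix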
x = paddle.matmul(x, self.word_embeddings.weight, transpose_y=True)
if use_cache:
return x, cached_kvs
else:
return x
if __name__ == '__main__':
gpt = GPT2Model(
vocab_size=30000,
layer_size=2,
block_size=1024,
embedding_dropout=0.0,
embedding_size=2560,
num_attention_heads=32,
attention_dropout=0.0,
residual_dropout=0.0)
gpt.eval()
out, cached_kvs = gpt(paddle.ones([1, 1], 'int64'), paddle.randn([2, 1, 2, 32, 9, 80], 'float32'), use_cache=True) # cache layout: [layer_size, batch, 2 (k/v), num_heads, past_len, head_dim]
print(out.shape, cached_kvs.shape)
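# A minimal incremental-decoding sketch built on the smoke test above.
# Illustrative only: the prompt ids are arbitrary; a real run would obtain them from GPT2Tokenizer.
ids = paddle.ones([1, 4], 'int64')
logits, kvs = gpt(ids, use_cache=True)  # prime the cache with the whole prompt
next_id = logits[:, -1].argmax(axis=-1).reshape([1, 1])
for _ in range(8):
    logits, kvs = gpt(next_id, kvs, use_cache=True)  # feed only the newest token
    next_id = logits[:, -1].argmax(axis=-1).reshape([1, 1])
print(next_id)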
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function, unicode_literals)
import json
from io import open
import sentencepiece as spm
import jieba
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
class GPT2Tokenizer(object):
def __init__(self, vocab_file, model_file, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v: k for k, v in self.encoder.items()}
self.sp = spm.SentencePieceProcessor(model_file=model_file)
self.translator = str.maketrans(" \n", "\u2582\u2583")
self.eod_id = self.encoder['<eod>']
@property
def vocab_size(self):
return len(self.encoder)
def __len__(self):
return len(self.encoder)  # no separate special-token table is kept in this trimmed tokenizer
@property
def eod(self):
return self.eod_id
def tokenize(self, text):
""" Tokenize a string. """
seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
new_seg = " ".join(seg_list)
return self.sp.encode(new_seg)
def encode(self, text):
res = self.tokenize(text)
return res
def decode(self, tokens):
text = self.sp.decode(tokens)
text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
return text
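if __name__ == '__main__':
    # Minimal round-trip sketch. The file paths below are placeholders for the
    # bpe assets shipped with the module (e.g. GPT2/bpe/ inside CPM_LM).
    tokenizer = GPT2Tokenizer(vocab_file='bpe/vocab.json', model_file='bpe/chinese_vocab.model')
    ids = tokenizer.encode('今天是个好日子')  # jieba segmentation, then sentencepiece ids
    print(ids)
    print(tokenizer.decode(ids))  # decode maps \u2582/\u2583 back to space/newline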
## Overview
GPT2_Base_CN is a pretrained text generation model and one of the models built into PaddleNLP.
## API
```python
def greedy_search(
    text,
    max_len=32,
    end_word=None):
```
A text generation API. It continues the input text, decoding with greedy search; the output is single-valued and deterministic, which suits question-answering style generation.
**Parameters**
* text (str): input text
* max_len (int): maximum length of the generated text
* end_word (str or None): word that terminates generation
**Returns**
* results (str): the generated text
```python
def nucleus_sample(
    text,
    max_len=32,
    end_word=None,
    repitition_penalty=1.0,
    temperature=1.0,
    top_k=0,
    top_p=1.0):
```
A text generation API. It continues the input text, decoding by sampling; the output is more varied, which suits article-style generation.
**Parameters**
* text (str): input text
* max_len (int): maximum length of the generated text
* end_word (str or None): word that terminates generation
* repitition_penalty (float): repetition penalty; values greater than 1 suppress repeated tokens, values less than 1 encourage them
* temperature (float): a lower temperature makes the model increasingly confident in its top choices, while a value greater than 1 reduces that confidence; 0 is equivalent to argmax/max, and inf approaches uniform sampling (see the sketch after this list)
* top_k (int): keep only the k highest-probability tokens; effective when greater than 0
* top_p (float): keep only the smallest set of tokens whose cumulative probability reaches top_p; effective when less than 1.0
**Returns**
* results (str): the generated text
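As a quick illustration of what `temperature` does, here is a standalone numpy sketch (independent of this module; the logits are made up):
```python
import numpy as np

logits = np.array([2.0, 1.0, 0.1])

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

for t in (0.5, 1.0, 2.0):
    # lower temperature sharpens the distribution toward the argmax,
    # higher temperature flattens it toward uniform
    print(t, softmax(logits / t).round(3))
```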
**Code examples**
* Load the model:
```python
import paddlehub as hub
model = hub.Module(name='GPT2_Base_CN')
```
* Generate text with greedy search:
```python
inputs = '''默写古诗:
日照香炉生紫烟,遥看瀑布挂前川。
飞流直下三千尺,'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 默写古诗:
日照香炉生紫烟,遥看瀑布挂前川。
飞流直下三千尺,疑是银河落九天。
```python
inputs = '''问题:西游记是谁写的?
答案:'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 问题:西游记是谁写的?
答案:吴承恩。
```python
inputs = '''小明决定去吃饭,小红继续写作业
问题:去吃饭的人是谁?
答案:'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 小明决定去吃饭,小红继续写作业
问题:去吃饭的人是谁?
答案:小明
```python
inputs = '''默写英文:
狗:dog
猫:'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 默写英文:
狗:dog
猫:cat
* Generate text with sampling:
```python
inputs = '''在此处输入文本的开头'''
outputs = model.nucleus_sample(
    inputs,
    max_len=32,
    end_word='。',
    repitition_penalty=1.0,
    temperature=1.0,
    top_k=5,
    top_p=1.0
)
print(outputs)
```
> 在此处输入文本的开头字母。
```python
inputs = '''《乡土中国》是费孝通先生在社区研究的基础上从宏观角度探讨中国社会结构的著作,'''
outputs = model.nucleus_sample(
    inputs,
    max_len=32,
    end_word='。',
    repitition_penalty=1.0,
    temperature=1.0,
    top_k=3000,
    top_p=1.0
)
print(outputs)
```
> 《乡土中国》是费孝通先生在社区研究的基础上从宏观角度探讨中国社会结构的著作,肯定了集体所有制在提升中国中低山地区农村生活水平方面所起的积极作用。
## Service Deployment
PaddleHub Serving can deploy an online text generation service.
## Step 1: Start PaddleHub Serving
Run the start command:
```shell
$ hub serving start --modules GPT2_Base_CN
```
This deploys an online text generation service API; the default port is 8866.
**NOTE:** To run inference on GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; on CPU it does not need to be set.
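For example, to serve on the first GPU (the device index is illustrative; pick one available on your machine):
```shell
$ export CUDA_VISIBLE_DEVICES=0
$ hub serving start --modules GPT2_Base_CN
```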
## Step 2: Send a Prediction Request
With the server up, the few lines of code below send a prediction request and retrieve the result:
```python
import requests
import json
text = "今天是个好日子"
data = {
    "text": text,
    "mode": "sample",  # 'search' or 'sample'
    # other parameters from the API above can be set here as needed
}
url = "http://127.0.0.1:8866/predict/GPT2_Base_CN"
headers = {"Content-Type": "application/json"}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
```
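The reply is JSON. Assuming the usual PaddleHub Serving envelope, where the generated text sits under a `results` field (the exact key may vary across versions), it can be read like this:
```python
print(r.json()["results"])
```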
## Code
https://github.com/PaddlePaddle/PaddleNLP
## Dependencies
paddlepaddle >= 2.0.0
paddlehub >= 2.0.0
sentencepiece==0.1.92
import paddle
import numpy as np
import paddle.nn as nn
from paddlehub.module.module import moduleinfo, serving
from paddlenlp.transformers import GPT2ForPretraining, GPT2ChineseTokenizer
@moduleinfo(
name="GPT2_Base_CN", # 模型名称
type="NLP/NLG", # 模型类型
author="jm12138", # 作者名称
author_email="jm12138@qq.com", # 作者邮箱
summary="GPT2_Base_CN", # 模型介绍
version="1.0.0" # 版本号
)
class GPT2_Base_CN(nn.Layer):
def __init__(self):
super(GPT2_Base_CN, self).__init__()
# load the pretrained Chinese GPT2 model shipped with PaddleNLP
self.model = GPT2ForPretraining.from_pretrained('gpt2-base-cn')
# switch the model to evaluation mode
self.model.eval()
# load the tokenizer
self.tokenizer = GPT2ChineseTokenizer.from_pretrained('gpt2-base-cn')
# warm up the tokenizer
_ = self.tokenizer.encode('_')
# Greedy Search
def greedy_search(self, text, max_len=32, end_word=None):
with paddle.no_grad():
# stop marker
if end_word is not None:
stop_id = self.tokenizer.encode(end_word)
length = len(stop_id)
# initial forward pass over the full prompt
ids = self.tokenizer.encode(text)
input_id = paddle.to_tensor(
np.array(ids).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(input_id, use_cache=True)
next_token = int(np.argmax(output[0, -1].numpy()))
ids.append(next_token)
# continue decoding with the KV cache
for i in range(max_len-1):
input_id = paddle.to_tensor(
np.array([next_token]).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(
input_id, use_cache=True, cache=cached_kvs)
next_token = int(np.argmax(output[0, -1].numpy()))
ids.append(next_token)
# stop once the end word has been generated
if (end_word is not None) and (ids[-length:] == stop_id):
break
return self.tokenizer.decode(ids)
@staticmethod
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (vocabulary size)
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
top_k = min(top_k, logits.shape[-1]) # Safety check
logits_np = logits.numpy()
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits_np < np.sort(logits_np)[-top_k]
logits_np[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits = paddle.sort(logits, descending=True)
sorted_indices = paddle.argsort(logits, descending=True).numpy()
cumulative_probs = paddle.cumsum(paddle.nn.functional.softmax(
sorted_logits, axis=-1), axis=-1).numpy()
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[...,
1:] = sorted_indices_to_remove[..., :-1]
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits_np[indices_to_remove] = filter_value
return paddle.to_tensor(logits_np)
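# Illustrative use (not executed here): keep the 40 most likely tokens and the
# p=0.9 nucleus before sampling, e.g.
#   filtered = self.top_k_top_p_filtering(next_token_logits, top_k=40, top_p=0.9)
#   probs = paddle.nn.functional.softmax(filtered, axis=-1)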
def nucleus_sample(self, text, max_len=32, end_word=None, repitition_penalty=1.0, temperature=1.0, top_k=0, top_p=1.0):
with paddle.no_grad():
# stop marker
if end_word is not None:
stop_id = self.tokenizer.encode(end_word)
length = len(stop_id)
# initial forward pass over the full prompt
ids = self.tokenizer.encode(text)
input_id = paddle.to_tensor(
np.array(ids).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(input_id, use_cache=True)
next_token_logits = output[0, -1, :]
for id in set(ids):
next_token_logits[id] /= repitition_penalty
next_token_logits = next_token_logits / temperature
filtered_logits = self.top_k_top_p_filtering(
next_token_logits, top_k=top_k, top_p=top_p)
next_token = paddle.multinomial(paddle.nn.functional.softmax(
filtered_logits, axis=-1), num_samples=1).numpy()
ids += [int(next_token)]
# continue decoding with the KV cache
for i in range(max_len-1):
input_id = paddle.to_tensor(
np.array([next_token]).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(
input_id, use_cache=True, cache=cached_kvs)
next_token_logits = output[0, -1, :]
for id in set(ids):
next_token_logits[id] /= repitition_penalty
next_token_logits = next_token_logits / temperature
filtered_logits = self.top_k_top_p_filtering(
next_token_logits, top_k=top_k, top_p=top_p)
next_token = paddle.multinomial(paddle.nn.functional.softmax(
filtered_logits, axis=-1), num_samples=1).numpy()
ids += [int(next_token)]
# stop once the end word has been generated
if (end_word is not None) and (ids[-length:] == stop_id):
break
return self.tokenizer.decode(ids)
# Hub Serving
@serving
def serving_method(self, text, mode='search', **kwargs):
if mode == 'search':
results = self.greedy_search(text, **kwargs)
else:
results = self.nucleus_sample(text, **kwargs)
return results
## Overview
CPM-LM is a GPT-2 based pretrained generative model with 2.6 billion parameters, pretrained on 100 GB of Chinese data using 64 V100 GPUs over roughly three weeks. It can perform zero-shot and few-shot learning on a range of natural language processing tasks with strong results. Given a prefix, the model continues it with highly coherent, readable text, achieving leading results among existing Chinese generative models.
The model parameters were converted from the official open-source release. Because the model is large, running on GPU is recommended; make sure the environment has more than 20 GB of RAM and more than 12 GB of GPU memory, otherwise the model may fail to run.
For more details, see the [CPM homepage](https://cpm.baai.ac.cn) and its [GitHub project page](https://github.com/TsinghuaAI/CPM-Generate).
## API
```python
def greedy_search(
    text,
    max_len=32,
    end_word=None):
```
A text generation API. It continues the input text, decoding with greedy search; the output is single-valued and deterministic, which suits question-answering style generation.
**Parameters**
* text (str): input text
* max_len (int): maximum length of the generated text
* end_word (str or None): word that terminates generation
**Returns**
* results (str): the generated text
```python
def sample(
    text,
    max_len=32,
    end_word=None,
    repitition_penalty=1.0,
    temperature=1.0,
    top_k=0,
    top_p=1.0):
```
A text generation API. It continues the input text, decoding by sampling; the output is more varied, which suits article-style generation.
**Parameters**
* text (str): input text
* max_len (int): maximum length of the generated text
* end_word (str or None): word that terminates generation
* repitition_penalty (float): repetition penalty; values greater than 1 suppress repeated tokens, values less than 1 encourage them
* temperature (float): a lower temperature makes the model increasingly confident in its top choices, while a value greater than 1 reduces that confidence; 0 is equivalent to argmax/max, and inf approaches uniform sampling
* top_k (int): keep only the k highest-probability tokens; effective when greater than 0
* top_p (float): keep only the smallest set of tokens whose cumulative probability reaches top_p; effective when less than 1.0
**Returns**
* results (str): the generated text
**Code examples**
* Load the model:
```python
import paddlehub as hub
model = hub.Module(name='CPM_LM')
```
* Generate text with greedy search:
```python
inputs = '''默写古诗:
日照香炉生紫烟,遥看瀑布挂前川。
飞流直下三千尺,'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 默写古诗:
日照香炉生紫烟,遥看瀑布挂前川。
飞流直下三千尺,疑是银河落九天。
```python
inputs = '''问题:西游记是谁写的?
答案:'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 问题:西游记是谁写的?
答案:吴承恩。
```python
inputs = '''小明决定去吃饭,小红继续写作业
问题:去吃饭的人是谁?
答案:'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 小明决定去吃饭,小红继续写作业
问题:去吃饭的人是谁?
答案:小明
```python
inputs = '''默写英文:
狗:dog
猫:'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 默写英文:
狗:dog
猫:cat
* Generate text with sampling:
```python
inputs = '''在此处输入文本的开头'''
outputs = model.sample(
    inputs,
    max_len=32,
    end_word='。',
    repitition_penalty=1.0,
    temperature=1.0,
    top_k=5,
    top_p=1.0
)
print(outputs)
```
> 此处输入文本的开头,然后再输入下一个字符 ,就可以得到一个"HelloWorld!"的文本。
```python
inputs = '''方平带众人骑马出了城,残雪点缀原本泛黄的大地。他一身黑衣在一群铁甲士兵中尤其显眼。'''
outputs = model.sample(
    inputs,
    max_len=128,
    end_word=None,
    repitition_penalty=1.0,
    temperature=1.0,
    top_k=3000,
    top_p=1.0
)
print(outputs)
```
> 方平带众人骑马出了城,残雪点缀原本泛黄的大地。他一身黑衣在一群铁甲士兵中尤其显眼。他负手彷徨,曾经在铜宫带领大军,横扫天下的军师如今只是位数趾高气扬的小卒,如今自己,连他身边的一个随从的支使就算不会武功也是位高权重。横刀立马,换来的是什么?他不知道,今天他走路都有些飘摇。
蓦然回眼看向熟悉而熟悉的军队,心中一阵轻松,走向来时军队时。忽然看见那位太史
## Code
https://github.com/jm12138/CPM-Generate-Paddle
## Dependencies
paddlepaddle >= 2.0.0rc0
paddlehub >= 2.0.0b1
## Overview
CPM-LM is a GPT-2 based pretrained generative model with 2.6 billion parameters, pretrained on 100 GB of Chinese data using 64 V100 GPUs over roughly three weeks. It can perform zero-shot and few-shot learning on a range of natural language processing tasks with strong results. Given a prefix, the model continues it with highly coherent, readable text, achieving leading results among existing Chinese generative models.
The model parameters were converted from the official open-source release. Because the model is large, running on GPU is recommended; make sure the environment has more than 20 GB of RAM and more than 12 GB of GPU memory, otherwise the model may fail to run.
For more details, see the [CPM homepage](https://cpm.baai.ac.cn) and its [GitHub project page](https://github.com/TsinghuaAI/CPM-Generate).
## API
```python
def greedy_search(
    text,
    max_len=32,
    end_word=None):
```
A text generation API. It continues the input text, decoding with greedy search; the output is single-valued and deterministic, which suits question-answering style generation.
**Parameters**
* text (str): input text
* max_len (int): maximum length of the generated text
* end_word (str or None): word that terminates generation
**Returns**
* results (str): the generated text
```python
def nucleus_sample(
    text,
    max_len=32,
    end_word=None,
    repitition_penalty=1.0,
    temperature=1.0,
    top_k=0,
    top_p=1.0):
```
A text generation API. It continues the input text, decoding by sampling; the output is more varied, which suits article-style generation.
**Parameters**
* text (str): input text
* max_len (int): maximum length of the generated text
* end_word (str or None): word that terminates generation
* repitition_penalty (float): repetition penalty; values greater than 1 suppress repeated tokens, values less than 1 encourage them
* temperature (float): a lower temperature makes the model increasingly confident in its top choices, while a value greater than 1 reduces that confidence; 0 is equivalent to argmax/max, and inf approaches uniform sampling
* top_k (int): keep only the k highest-probability tokens; effective when greater than 0
* top_p (float): keep only the smallest set of tokens whose cumulative probability reaches top_p; effective when less than 1.0
**Returns**
* results (str): the generated text
**Code examples**
* Load the model:
```python
import paddlehub as hub
model = hub.Module(name='GPT2_CPM_LM')
```
* Generate text with greedy search:
```python
inputs = '''默写古诗:
日照香炉生紫烟,遥看瀑布挂前川。
飞流直下三千尺,'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 默写古诗:
日照香炉生紫烟,遥看瀑布挂前川。
飞流直下三千尺,疑是银河落九天。
```python
inputs = '''问题:西游记是谁写的?
答案:'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 问题:西游记是谁写的?
答案:吴承恩。
```python
inputs = '''小明决定去吃饭,小红继续写作业
问题:去吃饭的人是谁?
答案:'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 小明决定去吃饭,小红继续写作业
问题:去吃饭的人是谁?
答案:小明
```python
inputs = '''默写英文:
狗:dog
猫:'''
outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
print(outputs)
```
> 默写英文:
狗:dog
猫:cat
* Generate text with sampling:
```python
inputs = '''在此处输入文本的开头'''
outputs = model.nucleus_sample(
    inputs,
    max_len=32,
    end_word='。',
    repitition_penalty=1.0,
    temperature=1.0,
    top_k=5,
    top_p=1.0
)
print(outputs)
```
> 在此处输入文本的开头、结尾或中间部分,然后按下回车键。
```python
inputs = '''《乡土中国》是费孝通先生在社区研究的基础上从宏观角度探讨中国社会结构的著作,'''
outputs = model.nucleus_sample(
    inputs,
    max_len=32,
    end_word='。',
    repitition_penalty=1.0,
    temperature=1.0,
    top_k=3000,
    top_p=1.0
)
print(outputs)
```
> 《乡土中国》是费孝通先生在社区研究的基础上从宏观角度探讨中国社会结构的著作,在书中,他试图向读者说明,中国社会的结构是如何形成、演变的,家庭伦理是如何形成、
## Service Deployment
PaddleHub Serving can deploy an online text generation service.
## Step 1: Start PaddleHub Serving
Run the start command:
```shell
$ hub serving start --modules GPT2_CPM_LM
```
This deploys an online text generation service API; the default port is 8866.
**NOTE:** To run inference on GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; on CPU it does not need to be set.
## Step 2: Send a Prediction Request
With the server up, the few lines of code below send a prediction request and retrieve the result:
```python
import requests
import json
text = "今天是个好日子"
data = {
    "text": text,
    "mode": "sample",  # 'search' or 'sample'
    # other parameters from the API above can be set here as needed
}
url = "http://127.0.0.1:8866/predict/GPT2_CPM_LM"
headers = {"Content-Type": "application/json"}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
```
## Code
https://github.com/jm12138/CPM-Generate-Paddle
## Dependencies
paddlepaddle >= 2.0.0
paddlehub >= 2.0.0
sentencepiece==0.1.92
import os
import paddle
import numpy as np
from paddle.nn import Layer
from paddlehub.module.module import moduleinfo
from CPM_LM.GPT2 import GPT2Model, GPT2Tokenizer
@moduleinfo(
name="CPM_LM", # 模型名称
type="NLP/NLG", # 模型类型
author="jm12138", # 作者名称
author_email="jm12138@qq.com", # 作者邮箱
summary="CPM_LM", # 模型介绍
version="1.0.1" # 版本号
)
class CPM_LM(Layer):
def __init__(self, max_len=512):
super(CPM_LM, self).__init__()
# build the model
self.model = GPT2Model(
vocab_size=30000,
layer_size=32,
block_size=1024,
embedding_dropout=0.0,
embedding_size=2560,
num_attention_heads=32,
attention_dropout=0.0,
residual_dropout=0.0)
# read the CPM-LM model parameters (FP16)
state_dict = paddle.load(os.path.join(self.directory, 'CPM-LM.pdparams'))
# FP16 -> FP32
for param in state_dict:
state_dict[param] = state_dict[param].astype('float32')
# load the parameters into the model
self.model.set_dict(state_dict)
# switch the model to evaluation mode
self.model.eval()
# load the tokenizer
self.tokenizer = GPT2Tokenizer(
vocab_file=os.path.join(self.directory, 'GPT2/bpe/vocab.json'),
model_file=os.path.join(self.directory, 'GPT2/bpe/chinese_vocab.model'),
max_len=max_len)
# warm up the tokenizer
_ = self.tokenizer.encode('_')
# greedy_search
def greedy_search(self, text, max_len=32, end_word=None):
# stop marker
end_id = self.tokenizer.eod_id
if end_word is not None:
stop_id = self.tokenizer.encode(end_word)
length = len(stop_id)
# initial forward pass over the full prompt
ids = self.tokenizer.encode(text)
input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(input_id, use_cache=True)
next_token = int(np.argmax(output[0, -1].numpy()))
ids.append(next_token)
# continue decoding with the KV cache
for i in range(max_len - 1):
input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(input_id, cached_kvs, use_cache=True)
next_token = int(np.argmax(output[0, -1].numpy()))
ids.append(next_token)
if next_token == end_id:
break
# stop once the end word has been generated
if (end_word is not None) and (ids[-length:] == stop_id):
break
return self.tokenizer.decode(ids)
@staticmethod
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (vocabulary size)
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
top_k = min(top_k, logits.shape[-1]) # Safety check
logits_np = logits.numpy()
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits_np < np.sort(logits_np)[-top_k]
logits_np[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits = paddle.sort(logits, descending=True)
sorted_indices = paddle.argsort(logits, descending=True).numpy()
cumulative_probs = paddle.cumsum(paddle.nn.functional.softmax(sorted_logits, axis=-1), axis=-1).numpy()
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits_np[indices_to_remove] = filter_value
return paddle.to_tensor(logits_np)
# sample
def sample(self, text, max_len=32, end_word=None, repitition_penalty=1.0, temperature=1.0, top_k=0, top_p=1.0):
# stop marker
end_id = self.tokenizer.eod_id
if end_word is not None:
stop_id = self.tokenizer.encode(end_word)
length = len(stop_id)
# initial forward pass over the full prompt
ids = self.tokenizer.encode(text)
input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(input_id, use_cache=True)
next_token_logits = output[0, -1, :]
for id in set(ids):
next_token_logits[id] /= repitition_penalty
next_token_logits = next_token_logits / temperature
next_token_logits[self.tokenizer.encoder['<unk>']] = -float('Inf')
filtered_logits = self.top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
ids += [int(next_token)]
# continue decoding with the KV cache
for i in range(max_len - 1):
input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(input_id, cached_kvs, use_cache=True)
next_token_logits = output[0, -1, :]
for id in set(ids):
next_token_logits[id] /= repitition_penalty
next_token_logits = next_token_logits / temperature
next_token_logits[self.tokenizer.encoder['<unk>']] = -float('Inf')
filtered_logits = self.top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
next_token = paddle.multinomial(
paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
ids += [int(next_token)]
if next_token == end_id:
break
# stop once the end word has been generated
if (end_word is not None) and (ids[-length:] == stop_id):
break
return self.tokenizer.decode(ids)
import os
import paddle
import numpy as np
import paddle.nn as nn
from paddlehub.module.module import moduleinfo, serving
from paddlenlp.transformers import GPT2ForPretraining, GPT2ChineseTokenizer, GPT2Model
@moduleinfo(
name="GPT2_CPM_LM", # 模型名称
type="NLP/NLG", # 模型类型
author="jm12138", # 作者名称
author_email="jm12138@qq.com", # 作者邮箱
summary="GPT2_CPM_LM", # 模型介绍
version="1.0.0" # 版本号
)
class GPT2_CPM_LM(nn.Layer):
def __init__(self):
super(GPT2_CPM_LM, self).__init__()
# build the model
gpt2 = GPT2Model(
vocab_size=30000,
hidden_size=2560,
num_hidden_layers=32,
num_attention_heads=32,
intermediate_size=10240,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0
)
self.model = GPT2ForPretraining(gpt2)
# read the CPM-LM model parameters (FP16)
state_dict = paddle.load(os.path.join(
self.directory, 'CPM-LM.pdparams'))
# FP16 -> FP32
for param in state_dict:
state_dict[param] = state_dict[param].astype('float32')
# load the parameters into the model
self.model.set_dict(state_dict)
# switch the model to evaluation mode
self.model.eval()
# load the tokenizer
self.tokenizer = GPT2ChineseTokenizer(
vocab_file=os.path.join(self.directory, 'vocab.json'),
model_file=os.path.join(self.directory, 'chinese_vocab.model')
)
# warm up the tokenizer
_ = self.tokenizer.encode('_')
# Greedy Search
def greedy_search(self, text, max_len=32, end_word=None):
with paddle.no_grad():
# stop marker
if end_word is not None:
stop_id = self.tokenizer.encode(end_word)
length = len(stop_id)
# initial forward pass over the full prompt
ids = self.tokenizer.encode(text)
input_id = paddle.to_tensor(
np.array(ids).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(input_id, use_cache=True)
next_token = int(np.argmax(output[0, -1].numpy()))
ids.append(next_token)
# continue decoding with the KV cache
for i in range(max_len-1):
input_id = paddle.to_tensor(
np.array([next_token]).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(
input_id, use_cache=True, cache=cached_kvs)
next_token = int(np.argmax(output[0, -1].numpy()))
ids.append(next_token)
# stop once the end word has been generated
if (end_word is not None) and (ids[-length:] == stop_id):
break
return self.tokenizer.decode(ids)
@staticmethod
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (vocabulary size)
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
top_k = min(top_k, logits.shape[-1]) # Safety check
logits_np = logits.numpy()
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits_np < np.sort(logits_np)[-top_k]
logits_np[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits = paddle.sort(logits, descending=True)
sorted_indices = paddle.argsort(logits, descending=True).numpy()
cumulative_probs = paddle.cumsum(paddle.nn.functional.softmax(
sorted_logits, axis=-1), axis=-1).numpy()
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[...,
1:] = sorted_indices_to_remove[..., :-1]
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits_np[indices_to_remove] = filter_value
return paddle.to_tensor(logits_np)
def nucleus_sample(self, text, max_len=32, end_word=None, repitition_penalty=1.0, temperature=1.0, top_k=0, top_p=1.0):
with paddle.no_grad():
# stop marker
if end_word is not None:
stop_id = self.tokenizer.encode(end_word)
length = len(stop_id)
# initial forward pass over the full prompt
ids = self.tokenizer.encode(text)
input_id = paddle.to_tensor(
np.array(ids).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(input_id, use_cache=True)
next_token_logits = output[0, -1, :]
for id in set(ids):
next_token_logits[id] /= repitition_penalty
next_token_logits = next_token_logits / temperature
next_token_logits[self.tokenizer.encoder['<unk>']] = -float('Inf')
filtered_logits = self.top_k_top_p_filtering(
next_token_logits, top_k=top_k, top_p=top_p)
next_token = paddle.multinomial(paddle.nn.functional.softmax(
filtered_logits, axis=-1), num_samples=1).numpy()
ids += [int(next_token)]
# continue decoding with the KV cache
for i in range(max_len-1):
input_id = paddle.to_tensor(
np.array([next_token]).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(
input_id, use_cache=True, cache=cached_kvs)
next_token_logits = output[0, -1, :]
for id in set(ids):
next_token_logits[id] /= repitition_penalty
next_token_logits = next_token_logits / temperature
next_token_logits[self.tokenizer.encoder['<unk>']] = -float('Inf')
filtered_logits = self.top_k_top_p_filtering(
next_token_logits, top_k=top_k, top_p=top_p)
next_token = paddle.multinomial(paddle.nn.functional.softmax(
filtered_logits, axis=-1), num_samples=1).numpy()
ids += [int(next_token)]
# stop once the end word has been generated
if (end_word is not None) and (ids[-length:] == stop_id):
break
return self.tokenizer.decode(ids)
# Hub Serving
@serving
def serving_method(self, text, mode='search', **kwargs):
if mode == 'search':
results = self.greedy_search(text, **kwargs)
else:
results = self.nucleus_sample(text, **kwargs)
return results