Unverified commit d77986a8, authored by jm_12138, committed by GitHub

Update CPM-LM module (#1141)

* update CPM-LM module

* update CPM-LM module

* update CPM-LM module
Parent 91f3dd35
@@ -7,9 +7,12 @@ CPM-LM is a GPT-2-based pre-trained generative model with 2.6 billion parameters
 ## API
 ```python
-def predict(text, max_len=32, end_word=None):
+def greedy_search(
+        text,
+        max_len=32,
+        end_word=None):
 ```
-Prediction API: generates text from the input text, decoding with Greedy Search.
+Text-generation API: generates text from the input text, decoding with Greedy Search; the output is deterministic and single-valued, which suits question-answering style generation.
 **Parameters**
 * text (str): the input text
@@ -20,52 +23,43 @@ def predict(text, max_len=32, end_word=None):
 * results (str): the generated text
 ```python
-def tokenizer.encode(text):
+def sample(
+        text,
+        max_len=32,
+        end_word=None,
+        repitition_penalty=1.0,
+        temperature=1.0,
+        top_k=0,
+        top_p=1.0):
 ```
-Encoding API.
+Text-generation API: generates text from the input text, decoding by sampling; the output is fairly diverse, which suits article-style generation.
 **Parameters**
 * text (str): the input text
+* max_len (int): the maximum length of the generated text
+* end_word (str or None): the word that terminates generation
+* repitition_penalty (float): repetition penalty; values greater than 1 suppress repeated tokens, values less than 1 encourage them
+* temperature (float): a lower temperature makes the model more and more confident in its top choices, while values greater than 1 lower that confidence; 0 is equivalent to argmax/max, and inf to uniform sampling
+* top_k (int): keep only the top-k highest-probability tokens; effective when greater than 0
+* top_p (float): keep only the tokens within cumulative probability top_p; effective when less than 1.0
 **Returns**
-* results (list[int]): the encoded token ids
-```python
-def tokenizer.decode(ids):
-```
-Decoding API.
-**Parameters**
-* ids (list[int]): the input token ids
-**Returns**
-* results (str): the decoded text
-```python
-def model(x, kv_cache=None, use_cache=False):
-```
-Model forward-computation API.
-**Parameters**
-* x (tensor): the input token ids
-* kv_cache (tensor): the input key/value cache
-* use_cache (bool): whether to use the cache
-**Returns**
-* results (tensor): the model output
+* results (str): the generated text
 **Code examples**
 * Load the model:
 ```python
 import paddlehub as hub
 model = hub.Module(name='CPM_LM')
 ```
 * Generate text with Greedy Search:
 ```python
 inputs = '''默写古诗:
 日照香炉生紫烟,遥看瀑布挂前川。
 飞流直下三千尺,'''
-outputs = model.predict(inputs, max_len=10, end_word='\n')
-print(inputs+outputs)
+outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
+print(outputs)
 ```
 > 默写古诗:
 日照香炉生紫烟,遥看瀑布挂前川。
@@ -73,8 +67,8 @@ print(inputs+outputs)
 ```python
 inputs = '''问题:西游记是谁写的?
 答案:'''
-outputs = model.predict(inputs, max_len=10, end_word='\n')
-print(inputs+outputs)
+outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
+print(outputs)
 ```
 > 问题:西游记是谁写的?
 答案:吴承恩。
@@ -82,8 +76,8 @@ print(inputs+outputs)
 inputs = '''小明决定去吃饭,小红继续写作业
 问题:去吃饭的人是谁?
 答案:'''
-outputs = model.predict(inputs, max_len=10, end_word='\n')
-print(inputs+outputs)
+outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
+print(outputs)
 ```
 > 小明决定去吃饭,小红继续写作业
 问题:去吃饭的人是谁?
@@ -92,13 +86,52 @@ print(inputs+outputs)
 inputs = '''默写英文:
 狗:dog
 猫:'''
-outputs = model.predict(inputs, max_len=10, end_word='\n')
-print(inputs+outputs)
+outputs = model.greedy_search(inputs, max_len=10, end_word='\n')
+print(outputs)
 ```
 > 默写英文:
 狗:dog
 猫:cat
+* Generate text by sampling:
+```python
+inputs = '''在此处输入文本的开头'''
+outputs = model.sample(
+    inputs,
+    max_len=32,
+    end_word='。',
+    repitition_penalty=1.0,
+    temperature=1.0,
+    top_k=5,
+    top_p=1.0
+)
+print(outputs)
+```
+> 此处输入文本的开头,然后再输入下一个字符 ,就可以得到一个"HelloWorld!"的文本。
+```python
+inputs = '''方平带众人骑马出了城,残雪点缀原本泛黄的大地。他一身黑衣在一群铁甲士兵中尤其显眼。'''
+outputs = model.sample(
+    inputs,
+    max_len=128,
+    end_word=None,
+    repitition_penalty=1.0,
+    temperature=1.0,
+    top_k=3000,
+    top_p=1.0
+)
+print(outputs)
+```
+> 方平带众人骑马出了城,残雪点缀原本泛黄的大地。他一身黑衣在一群铁甲士兵中尤其显眼。他负手彷徨,曾经在铜宫带领大军,横扫天下的军师如今只是位数趾高气扬的小卒,如今自己,连他身边的一个随从的支使就算不会武功也是位高权重。横刀立马,换来的是什么?他不知道,今天他走路都有些飘摇。
+蓦然回眼看向熟悉而熟悉的军队,心中一阵轻松,走向来时军队时。忽然看见那位太史
 ## View the code
 https://github.com/jm12138/CPM-Generate-Paddle
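
For intuition about how the sampling parameters documented above interact, here is a minimal, self-contained NumPy sketch (an editorial illustration with made-up logits, not code from this module) of how repetition penalty, temperature, top-k, and top-p reshape the next-token distribution before sampling:

```python
import numpy as np

# Toy next-token logits over a 6-token vocabulary (made-up values).
logits = np.array([2.0, 1.0, 0.5, 0.2, -0.5, -1.0])
generated = [0]  # suppose token 0 has already been generated

# Repetition penalty: values > 1 shrink the logits of already-seen tokens.
for t in set(generated):
    logits[t] /= 1.2

# Temperature: < 1 sharpens the distribution, > 1 flattens it.
logits = logits / 0.7

# Top-k: keep only the k highest logits, mask the rest.
top_k = 3
logits[logits < np.sort(logits)[-top_k]] = -np.inf

# Top-p (nucleus): keep the smallest prefix of the sorted distribution
# whose cumulative probability exceeds p; drop everything after it.
top_p = 0.85
probs = np.exp(logits - logits.max())
probs /= probs.sum()
order = np.argsort(-probs)
to_remove = order[np.cumsum(probs[order]) > top_p][1:]
probs[to_remove] = 0.0
probs /= probs.sum()

# Sample the next token id from the filtered distribution.
next_token = np.random.choice(len(probs), p=probs)
print(probs.round(3), next_token)
```

With top_k=0 and top_p=1.0 both filters are no-ops, which is why those are the defaults of sample.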
@@ -11,7 +11,7 @@ from CPM_LM.GPT2 import GPT2Model, GPT2Tokenizer
     author="jm12138",                # author name
     author_email="jm12138@qq.com",   # author email
     summary="CPM_LM",                # model summary
-    version="1.0.0"                  # version number
+    version="1.0.1"                  # version number
 )
 class CPM_LM(Layer):
     def __init__(self, max_len=512):
@@ -49,34 +49,109 @@ class CPM_LM(Layer):
         # Initialize the tokenizer
         _ = self.tokenizer.encode('_')
-    # Basic prediction function
-    def predict(self, text, max_len=32, end_word=None):
+    # greedy_search
+    def greedy_search(self, text, max_len=32, end_word=None):
         # Termination token
+        end_id = self.tokenizer.eod_id
         if end_word is not None:
-            end_id = self.tokenizer.encode(end_word)
-            length = len(end_id)
-        else:
-            end_id = self.tokenizer.eod_id
+            stop_id = self.tokenizer.encode(end_word)
+            length = len(stop_id)
         # Initial prediction
         ids = self.tokenizer.encode(text)
         input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
         output, cached_kvs = self.model(input_id, use_cache=True)
-        nid = int(np.argmax(output[0, -1].numpy()))
-        out = [nid]
+        next_token = int(np.argmax(output[0, -1].numpy()))
+        ids.append(next_token)
         # Continue predicting with the key/value cache
         for i in range(max_len-1):
-            input_id = paddle.to_tensor(np.array([nid]).reshape(1, -1).astype('int64'))
+            input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
             output, cached_kvs = self.model(input_id, cached_kvs, use_cache=True)
-            nid = int(np.argmax(output[0, -1].numpy()))
-            # Stop once the termination token has been produced
-            if (end_word is not None) and (out[-length+1:]+[nid]==end_id):
+            next_token = int(np.argmax(output[0, -1].numpy()))
+            ids.append(next_token)
+            if next_token==end_id:
                 break
-            elif (end_word is None) and (nid==end_id):
+            # Stop once the termination word has been produced
+            if (end_word is not None) and (ids[-length:]==stop_id):
                 break
+        return self.tokenizer.decode(ids)
+    @staticmethod
+    def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
+        """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+            Args:
+                logits: logits distribution shape (vocabulary size)
+                top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+                top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+                    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+            From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+        """
+        top_k = min(top_k, logits.shape[-1])  # Safety check
+        logits_np = logits.numpy()
+        if top_k > 0:
+            # Remove all tokens with a probability less than the last token of the top-k
+            indices_to_remove = logits_np < np.sort(logits_np)[-top_k]
+            logits_np[indices_to_remove] = filter_value
+        if top_p < 1.0:
+            sorted_logits = paddle.sort(logits, descending=True)
+            sorted_indices = paddle.argsort(logits, descending=True).numpy()
+            cumulative_probs = paddle.cumsum(paddle.nn.functional.softmax(sorted_logits, axis=-1), axis=-1).numpy()
+            # Remove tokens with cumulative probability above the threshold
+            sorted_indices_to_remove = cumulative_probs > top_p
+            # Shift the indices to the right to keep also the first token above the threshold
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
+            sorted_indices_to_remove[..., 0] = 0
+            indices_to_remove = sorted_indices[sorted_indices_to_remove]
+            logits_np[indices_to_remove] = filter_value
+        return paddle.to_tensor(logits_np)
+    # sample
+    def sample(self, text, max_len=32, end_word=None, repitition_penalty=1.0, temperature=1.0, top_k=0, top_p=1.0):
+        # Termination token
+        end_id = self.tokenizer.eod_id
+        if end_word is not None:
+            stop_id = self.tokenizer.encode(end_word)
+            length = len(stop_id)
+        # Initial prediction
+        ids = self.tokenizer.encode(text)
+        input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
+        output, cached_kvs = self.model(input_id, use_cache=True)
+        next_token_logits = output[0, -1, :]
+        # Penalize tokens that already appeared, then apply temperature
+        for id in set(ids):
+            next_token_logits[id] /= repitition_penalty
+        next_token_logits = next_token_logits / temperature
+        next_token_logits[self.tokenizer.encoder['<unk>']] = -float('Inf')
+        filtered_logits = self.top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+        next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
+        ids += [int(next_token)]
+        # Continue predicting with the key/value cache
+        for i in range(max_len-1):
+            input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
+            output, cached_kvs = self.model(input_id, cached_kvs, use_cache=True)
+            next_token_logits = output[0, -1, :]
+            for id in set(ids):
+                next_token_logits[id] /= repitition_penalty
+            next_token_logits = next_token_logits / temperature
+            next_token_logits[self.tokenizer.encoder['<unk>']] = -float('Inf')
+            filtered_logits = self.top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+            next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
+            ids += [int(next_token)]
-            out.append(nid)
-        return self.tokenizer.decode(out)
\ No newline at end of file
+            if next_token==end_id:
+                break
+            # Stop once the termination word has been produced
+            if (end_word is not None) and (ids[-length:]==stop_id):
+                break
+        return self.tokenizer.decode(ids)
\ No newline at end of file
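
Both greedy_search and sample above follow the same incremental decoding pattern: the prompt is encoded once with use_cache=True, and afterwards only the newest token is fed back while cached_kvs carries the attention state, so each step costs one token rather than the whole sequence. A condensed sketch of that pattern (a paraphrase of the committed greedy_search, assuming the same model/tokenizer interfaces):

```python
import numpy as np
import paddle

def generate_greedy(model, tokenizer, text, max_len=32):
    """Greedy decoding with an incremental key/value cache (sketch)."""
    ids = tokenizer.encode(text)
    # First pass: run the full prompt once and keep the attention cache.
    input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
    output, cached_kvs = model(input_id, use_cache=True)
    next_token = int(np.argmax(output[0, -1].numpy()))
    ids.append(next_token)
    # Later passes: feed only the newest token and reuse the cache.
    for _ in range(max_len - 1):
        input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
        output, cached_kvs = model(input_id, cached_kvs, use_cache=True)
        next_token = int(np.argmax(output[0, -1].numpy()))
        ids.append(next_token)
        if next_token == tokenizer.eod_id:  # stop at the end-of-document token
            break
    return tokenizer.decode(ids)
```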