Unverified commit 8fa0291f authored by jm_12138, committed by GitHub

Add the CPM-LM Module

Parent 27525955
from .model import GPT2Model
from .tokenization import GPT2Tokenizer
This diff is collapsed.
import math

import paddle
import paddle.nn as nn


class MLP(nn.Layer):
    def __init__(self, embedding_size):
        super(MLP, self).__init__()
        # Feed-forward block: project to 4x the hidden size, apply GELU, project back.
        self.dense_h_to_4h = nn.Linear(embedding_size, embedding_size * 4)
        self.dense_4h_to_h = nn.Linear(embedding_size * 4, embedding_size)
        self.act = nn.functional.gelu

    def forward(self, x):
        h = self.act(self.dense_h_to_4h(x))
        h2 = self.dense_4h_to_h(h)
        return h2
class Attention(nn.Layer):
    def __init__(self,
                 embedding_size,
                 num_attention_heads,
                 attention_dropout,
                 residual_dropout):
        super(Attention, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.size_per_head = embedding_size // num_attention_heads
        self.embedding_size = embedding_size

        # Fused projection that produces Q, K and V in a single matmul.
        self.query_key_value = nn.Linear(embedding_size, embedding_size * 3)
        self.attn_drop = nn.Dropout(attention_dropout)
        self.resid_drop = nn.Dropout(residual_dropout)
        self.dense = nn.Linear(embedding_size, embedding_size)

    def split_heads(self, x):
        # [B, L, E] -> [B, N, L, E/N]
        x = x.reshape([-1, self.seq_len, self.num_attention_heads, self.size_per_head])
        return x.transpose((0, 2, 1, 3))

    def forward(self, x, kv_cache=None):
        self.seq_len = x.shape[1]
        x = self.query_key_value(x)
        q, k, v = x.split(num_or_sections=3, axis=2)

        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)

        if kv_cache is not None:
            # Prepend the cached keys/values from previous steps.
            pk, pv = paddle.unstack(kv_cache, axis=1)
            k = paddle.concat([pk, k], axis=-2)
            v = paddle.concat([pv, v], axis=-2)
        cached_kv = paddle.stack([k, v], axis=1)

        attn = paddle.matmul(q, k, transpose_y=True)  # [B, N, L, S]
        attn = attn / math.sqrt(self.size_per_head)

        # Causal mask: [L, S] lower-triangular matrix of ones.
        # When decoding with a cache, the model is fed one token per step,
        # so this degenerates to a single 1 and all cached positions stay visible.
        attention_mask = paddle.tril(paddle.ones([self.seq_len, self.seq_len], 'float32'))
        attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len])

        # Adding a large negative value before the softmax effectively
        # removes the masked positions from the attention distribution.
        attn = attn * attention_mask - 10000.0 * (1.0 - attention_mask)
        attn = nn.Softmax(axis=-1)(attn)
        attn = self.attn_drop(attn)

        y = paddle.matmul(attn, v)

        # [B, N, L, S] -> [B, L, N, S]
        y = y.transpose((0, 2, 1, 3))
        y = paddle.reshape(y, [-1, self.seq_len, self.embedding_size])
        y = self.resid_drop(self.dense(y))
        return y, cached_kv
class Block(nn.Layer):
    def __init__(self,
                 embedding_size,
                 num_attention_heads,
                 attention_dropout,
                 residual_dropout):
        super(Block, self).__init__()
        self.input_layernorm = nn.LayerNorm(embedding_size, epsilon=1e-5)
        self.attention = Attention(embedding_size, num_attention_heads, attention_dropout, residual_dropout)
        self.post_attention_layernorm = nn.LayerNorm(embedding_size, epsilon=1e-5)
        self.mlp = MLP(embedding_size)

    def forward(self, x, kv_cache=None):
        attn, cached_kv = self.attention(self.input_layernorm(x), kv_cache=kv_cache)
        x = x + attn
        z = self.post_attention_layernorm(x)
        z = self.mlp(z)
        x = x + z
        return x, cached_kv
class Transformer(nn.Layer):
    def __init__(self,
                 layer_size,
                 embedding_size,
                 num_attention_heads,
                 attention_dropout,
                 residual_dropout):
        super(Transformer, self).__init__()
        self.layers = nn.LayerList([Block(
            embedding_size,
            num_attention_heads,
            attention_dropout,
            residual_dropout)
            for _ in range(layer_size)])
        self.final_layernorm = nn.LayerNorm(embedding_size, epsilon=1e-5)

    def forward(self, x, kv_cache=None):
        cached_kvs = []
        for i, layer in enumerate(self.layers):
            x, cached_kv = layer(x, kv_cache=kv_cache[i] if kv_cache is not None else None)
            cached_kvs.append(cached_kv)
        x = self.final_layernorm(x)
        return x, paddle.stack(cached_kvs)
class GPT2Model(nn.Layer):
    def __init__(self,
                 vocab_size,
                 layer_size,
                 block_size,
                 embedding_dropout,
                 embedding_size,
                 num_attention_heads,
                 attention_dropout,
                 residual_dropout):
        super(GPT2Model, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.position_embeddings = nn.Embedding(block_size, embedding_size)
        self.emb_drop = nn.Dropout(embedding_dropout)
        self.transformer = Transformer(
            layer_size,
            embedding_size,
            num_attention_heads,
            attention_dropout,
            residual_dropout)

    def forward(self, x, kv_cache=None, use_cache=False):
        # Offset the position ids by the length of the cached prefix.
        if kv_cache is None:
            past_length = 0
        else:
            past_length = kv_cache[0][0].shape[-2]
        position_ids = paddle.arange(past_length, x.shape[-1] + past_length, dtype='int64')
        position_ids = position_ids.unsqueeze(0).expand_as(x)

        x = self.word_embeddings(x)
        x = self.emb_drop(x + self.position_embeddings(position_ids))
        x, cached_kvs = self.transformer(x, kv_cache)

        # Output logits are tied to the input word embeddings.
        x = paddle.matmul(x, self.word_embeddings.weight, transpose_y=True)

        if use_cache:
            return x, cached_kvs
        return x
if __name__ == '__main__':
    # Smoke test: a 2-layer model with the CPM-LM hidden size (2560, 32 heads).
    gpt = GPT2Model(
        vocab_size=30000,
        layer_size=2,
        block_size=1024,
        embedding_dropout=0.0,
        embedding_size=2560,
        num_attention_heads=32,
        attention_dropout=0.0,
        residual_dropout=0.0)
    gpt.eval()
    # Dummy cache shaped [layers, batch, (k, v), heads, past_len, head_dim];
    # only the first `layer_size` entries are consumed.
    out, cached_kvs = gpt(paddle.ones([1, 1], 'int64'),
                          paddle.randn([32, 1, 2, 32, 9, 80], 'float32'),
                          use_cache=True)
    print(out.shape, cached_kvs.shape)
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import json
from io import open

import sentencepiece as spm
import jieba

try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on Python 2,
    # because honestly I don't want to support a byte-level unicode BPE
    # tokenizer on Python 2 right now.
    def lru_cache():
        return lambda func: func
class GPT2Tokenizer(object):
    def __init__(self, vocab_file, model_file, max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file, encoding='utf-8'))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.sp = spm.SentencePieceProcessor(model_file=model_file)
        # SentencePiece does not preserve spaces and newlines, so map them
        # to the placeholder characters \u2582 and \u2583 before encoding.
        self.translator = str.maketrans(" \n", "\u2582\u2583")
        self.eod_id = self.encoder['<eod>']

    @property
    def vocab_size(self):
        return len(self.encoder)

    def __len__(self):
        return len(self.encoder)

    @property
    def eod(self):
        return self.eod_id

    def tokenize(self, text):
        """Tokenize a string: jieba word segmentation, then SentencePiece."""
        seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
        new_seg = " ".join(seg_list)
        return self.sp.encode(new_seg)

    def encode(self, text):
        return self.tokenize(text)

    def decode(self, tokens):
        text = self.sp.decode(tokens)
        # Undo the placeholder mapping applied in tokenize().
        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
        return text
## Overview
CPM-LM is a GPT-2-based pre-trained generative model with 2.6 billion parameters, pre-trained on 100 GB of Chinese data using 64 V100 GPUs for about three weeks. It achieves good results on a variety of NLP tasks in zero-shot and few-shot settings. Given a prompt, the model continues the text with strong consistency and readability, reaching state-of-the-art quality among Chinese generative models.

The parameters are converted from the official open-source release. Since the model is large, running on GPU is recommended; make sure the environment has more than 20 GB of RAM and more than 12 GB of GPU memory, otherwise the model may fail to run.

For more details, see the [CPM official website](https://cpm.baai.ac.cn) and the [GitHub project](https://github.com/TsinghuaAI/CPM-Generate).
## API
```python
def predict(text, max_len=32, end_word=None):
```

Prediction API: generates a continuation of the input text, decoded with greedy search.

**Parameters**

* text (str): input text
* max_len (int): maximum length of the generated text
* end_word (str or None): stop word that terminates generation

**Returns**

* results (str): the generated text
```python
def tokenizer.encode(text):
```

Encoding API.

**Parameters**

* text (str): input text

**Returns**

* results (list[int]): encoded token ids
```python
def tokenizer.decode(ids):
```

Decoding API.

**Parameters**

* ids (list[int]): input token ids

**Returns**

* results (str): decoded text
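A quick round trip through the tokenizer, as a sketch: it assumes the module has been loaded as `model = hub.Module(name='CPM_LM')` (as in the code examples below) and that the tokenizer is reachable via the module's `tokenizer` attribute, as in `module.py`.

```python
# Encode Chinese text to token ids: jieba segmentation + SentencePiece.
ids = model.tokenizer.encode('你好,世界!')
print(ids)
# decode() restores the spaces/newlines that encode() mapped to placeholders.
print(model.tokenizer.decode(ids))
```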
```python
def model(x, kv_cache=None, use_cache=False):
```

Forward pass of the underlying network.

**Parameters**

* x (tensor): input token ids
* kv_cache (tensor): key/value cache from previous forward passes
* use_cache (bool): whether to return the updated cache

**Returns**

* results (tensor): output logits; when `use_cache=True`, the updated cache is returned as well
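Below is a minimal incremental-decoding sketch using the cache, mirroring the `predict` implementation in `module.py`; it assumes the network and tokenizer are exposed as `model.model` and `model.tokenizer` on the loaded module.

```python
import numpy as np
import paddle
import paddlehub as hub

model = hub.Module(name='CPM_LM')
ids = model.tokenizer.encode('白日依山尽,')
x = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))

# Prime the cache with the full prompt, then feed one token per step.
logits, kv_cache = model.model(x, use_cache=True)
out = [int(np.argmax(logits[0, -1].numpy()))]  # greedy pick of the next token
for _ in range(9):
    x = paddle.to_tensor(np.array([out[-1]]).reshape(1, -1).astype('int64'))
    logits, kv_cache = model.model(x, kv_cache, use_cache=True)
    out.append(int(np.argmax(logits[0, -1].numpy())))
print(model.tokenizer.decode(out))
```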
**Code Examples**
```python
import paddlehub as hub
model = hub.Module(name='CPM_LM')
```
```python
inputs = '''默写古诗:
日照香炉生紫烟,遥看瀑布挂前川。
飞流直下三千尺,'''
outputs = model.predict(inputs, max_len=10, end_word='\n')
print(inputs+outputs)
```
> 默写古诗:
日照香炉生紫烟,遥看瀑布挂前川。
飞流直下三千尺,疑是银河落九天。
```python
inputs = '''问题:西游记是谁写的?
答案:'''
outputs = model.predict(inputs, max_len=10, end_word='\n')
print(inputs+outputs)
```
> 问题:西游记是谁写的?
答案:吴承恩。
```python
inputs = '''小明决定去吃饭,小红继续写作业
问题:去吃饭的人是谁?
答案:'''
outputs = model.predict(inputs, max_len=10, end_word='\n')
print(inputs+outputs)
```
> 小明决定去吃饭,小红继续写作业
问题:去吃饭的人是谁?
答案:小明
```python
inputs = '''默写英文:
狗:dog
猫:'''
outputs = model.predict(inputs, max_len=10, end_word='\n')
print(inputs+outputs)
```
> 默写英文:
狗:dog
猫:cat
## Code
https://github.com/jm12138/CPM-Generate-Paddle
## Dependencies
paddlepaddle >= 2.0.0rc0
paddlehub >= 2.0.0b1
import os

import numpy as np
import paddle
from paddle.nn import Layer
from paddlehub.module.module import moduleinfo

from CPM_LM.GPT2 import GPT2Model, GPT2Tokenizer


@moduleinfo(
    name="CPM_LM",                  # module name
    type="NLP/NLG",                 # module type
    author="jm12138",               # author
    author_email="jm12138@qq.com",  # author email
    summary="CPM_LM",               # module summary
    version="1.0.0"                 # version
)
class CPM_LM(Layer):
    def __init__(self, max_len=512):
        super(CPM_LM, self).__init__()
        # Build the network.
        self.model = GPT2Model(
            vocab_size=30000,
            layer_size=32,
            block_size=1024,
            embedding_dropout=0.0,
            embedding_size=2560,
            num_attention_heads=32,
            attention_dropout=0.0,
            residual_dropout=0.0)

        # Load the CPM-LM parameters (stored in FP16).
        state_dict = paddle.load(os.path.join(self.directory, 'CPM-LM.pdparams'))

        # Cast FP16 -> FP32.
        for param in state_dict:
            state_dict[param] = state_dict[param].astype('float32')

        # Load the parameters and switch to evaluation mode.
        self.model.set_dict(state_dict)
        self.model.eval()

        # Load the tokenizer.
        self.tokenizer = GPT2Tokenizer(
            vocab_file=os.path.join(self.directory, 'GPT2/bpe/vocab.json'),
            model_file=os.path.join(self.directory, 'GPT2/bpe/chinese_vocab.model'),
            max_len=max_len)

        # Warm up the tokenizer (jieba builds its dictionary on first use).
        _ = self.tokenizer.encode('_')

    # Basic prediction function: greedy decoding with a KV cache.
    def predict(self, text, max_len=32, end_word=None):
        # Termination condition.
        if end_word is not None:
            end_id = self.tokenizer.encode(end_word)
            length = len(end_id)
        else:
            end_id = self.tokenizer.eod_id

        # First step: run the whole prompt through the model.
        ids = self.tokenizer.encode(text)
        input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
        output, cached_kvs = self.model(input_id, use_cache=True)
        nid = int(np.argmax(output[0, -1].numpy()))
        out = [nid]

        # Subsequent steps: feed one token at a time, reusing the cache.
        for i in range(max_len - 1):
            input_id = paddle.to_tensor(np.array([nid]).reshape(1, -1).astype('int64'))
            output, cached_kvs = self.model(input_id, cached_kvs, use_cache=True)
            nid = int(np.argmax(output[0, -1].numpy()))

            # Stop once the stop word (or the end-of-document token) is produced.
            if (end_word is not None) and ((out + [nid])[-length:] == end_id):
                break
            elif (end_word is None) and (nid == end_id):
                break

            out.append(nid)
        return self.tokenizer.decode(out)