Unverified commit b76f1591 authored by xiemoyuan, committed by GitHub

Upgrade plato2 using paddle2.0 (#5002)

* The first version of plato2. The network is not finished yet.

* Update decode strategy.

* Update decode strategy.

* Completed the encoder and decoder, but it runs out of memory (OOM).

* Completed the encoder and decoder.

* Back up the code.

* Completed only the networks of plato2 and nsp.

* Completed the development, but the results have not been verified yet.

* Add readme and remove the code for handling PY2 and PY3.

* Modify comment.

* Modify readme and add images.

* Delete the data folder.
Parent 3366cf65
# PLATO-2
## Model Introduction
Building a high-quality open-domain chatbot that can converse freely with humans in natural language has long been one of the ultimate goals of natural language processing.
To make building such a chatbot easy, this project implements the PLATO-2 inference model on Paddle 2.0 and provides a simple terminal-based human-machine interface. You can quickly build an open-domain chatbot by downloading the pre-trained models.
The network architecture and evaluation results of PLATO-2 are shown below:
![image](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/dialogue/plato-2/imgs/network.png)
![image](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/dialogue/plato-2/imgs/eval_en.png)
![image](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/dialogue/plato-2/imgs/eval_cn.png)
For the PLATO-2 training process and other details, see [Knover](https://github.com/PaddlePaddle/Knover).
## Quick Start
### Installation
* Install PaddlePaddle
This project requires PaddlePaddle 2.0 or later. Please refer to the [installation guide](http://www.paddlepaddle.org/#quick-start) for instructions.
* Install PaddleNLP
```shell
pip install "paddlenlp>=2.0.0b"
```
* Environment Dependencies
Python 3.6+ is required.
This project also depends on sentencepiece and termcolor; install them before running:
```shell
pip install sentencepiece termcolor
```
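To confirm the environment is ready, a minimal check can be run first. This is only a sketch; it assumes PaddlePaddle and the two dependencies above are already installed:
```python
# Minimal environment check; paddle.utils.run_check() verifies the (GPU) setup.
import paddle
import sentencepiece
import termcolor

print(paddle.__version__)  # expect 2.0.0 or later
paddle.utils.run_check()
```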
### Code Structure
The main code structure of this project is as follows:
```text
.
├── interaction.py # entry point for interactive inference
├── model.py # model definition
├── readers
│   ├── dialog_reader.py # base reader that builds model inputs
│   ├── nsp_reader.py # model inputs for the NSP (response ranking) model
│   └── plato_reader.py # model inputs for the PLATO generation model
├── utils
│   ├── __init__.py # basic utility functions
│   ├── args.py # runtime argument configuration
│   ├── masking.py # masking helpers
│   └── tokenization.py # tokenization helpers
├── imgs # example images
└── README.md # this document
```
### Data Preparation
You can download the pre-trained model files from the following locations:
* PLATO-2, 24-layers, 16-heads, 1024-hidden, EN: [pre-trained model](https://paddlenlp.bj.bcebos.com/models/transformers/plato2/24L.pdparams)
* PLATO-2, 32-layers, 32-heads, 2048-hidden, EN: [pre-trained model](https://paddlenlp.bj.bcebos.com/models/transformers/plato2/32L.pdparams)
Taking the 24-layer pre-trained model as an example:
```shell
wget https://paddlenlp.bj.bcebos.com/models/transformers/plato2/24L.pdparams
```
**NOTE:** PLATO-2 has a large number of parameters: the 24-layer network needs at least 16 GB of GPU memory and the 32-layer network at least 22 GB. Choose the network depth and pre-trained model that fit your hardware.
Download the pre-trained sentencepiece tokenizer model and the vocabulary file:
```shell
wget https://paddlenlp.bj.bcebos.com/models/transformers/plato2/data.tar.gz
tar -zxf data.tar.gz
```
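Before starting the interaction, you can sanity-check the extracted tokenizer assets. This is a minimal sketch; the paths come from the commands above, and 8001 is the default `vocab_size` used by this project:
```python
# Sanity-check the downloaded tokenizer assets.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("./data/spm.model")
print(sp.GetPieceSize())  # size of the subword model

with open("./data/vocab.txt") as f:
    num_tokens = sum(1 for _ in f)
print(num_tokens)  # expected to match vocab_size (8001 by default)
```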
### Human-Machine Interaction
Run the following command to start a simple English conversation with the chatbot:
```shell
export CUDA_VISIBLE_DEVICES=0
python interaction.py --vocab_path ./data/vocab.txt --spm_model_file ./data/spm.model --num_layers 24 --init_from_ckpt ./24L.pdparams
```
The arguments above are:
* vocab_path: path to the vocabulary file.
* spm_model_file: path to the pre-trained sentencepiece tokenizer model.
* num_layers: number of layers in the PLATO-2 network.
* init_from_ckpt: path to the PLATO-2 pre-trained model parameters.
An interaction example with the 32-layer PLATO-2 network:
![image](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/dialogue/plato-2/imgs/case.jpg)
import json
import argparse
from collections import namedtuple
from termcolor import colored, cprint
import paddle
from utils.args import parse_args, str2bool
from utils import gen_inputs
from readers.nsp_reader import NSPReader
from readers.plato_reader import PlatoReader
from model import Plato2InferModel
def setup_args():
"""Setup arguments."""
parser = argparse.ArgumentParser()
group = parser.add_argument_group("Model")
group.add_argument("--init_from_ckpt", type=str, default="")
group.add_argument("--vocab_size", type=int, default=8001)
group.add_argument("--latent_type_size", type=int, default=20)
group.add_argument("--num_layers", type=int, default=24)
group = parser.add_argument_group("Task")
group.add_argument("--is_cn", type=str2bool, default=False)
args, _ = parser.parse_known_args()
NSPReader.add_cmdline_args(parser)
args = parse_args(parser)
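    # Each latent value produces one candidate response, so the effective batch size grows with latent_type_size.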
args.batch_size *= args.latent_type_size
#print(json.dumps(args, indent=2))
return args
def load_params(model, init_from_ckpt):
state_dict = paddle.load(init_from_ckpt)
model.set_state_dict(state_dict)
def interact(args):
"""Inference main function."""
plato_reader = PlatoReader(args)
nsp_reader = NSPReader(args)
if args.num_layers == 24:
n_head = 16
hidden_size = 1024
elif args.num_layers == 32:
n_head = 32
hidden_size = 2048
else:
        raise ValueError('The pre-trained model only supports 24 or 32 layers, '
                         'but received num_layers=%d.' % args.num_layers)
model = Plato2InferModel(nsp_reader, args.num_layers, n_head, hidden_size)
load_params(model, args.init_from_ckpt)
model.eval()
Example = namedtuple("Example", ["src", "data_id"])
context = []
start_info = "Enter [EXIT] to quit the interaction, [NEXT] to start a new conversation."
cprint(start_info, "yellow", attrs=["bold"])
while True:
user_utt = input(colored("[Human]: ", "red", attrs=["bold"])).strip()
if user_utt == "[EXIT]":
break
elif user_utt == "[NEXT]":
context = []
cprint(start_info, "yellow", attrs=["bold"])
else:
context.append(user_utt)
example = Example(src=" [SEP] ".join(context), data_id=0)
record = plato_reader._convert_example_to_record(
example, is_infer=True)
data = plato_reader._pad_batch_records([record], is_infer=True)
inputs = gen_inputs(data, args.latent_type_size)
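            # One forward pass generates latent_type_size candidates and returns the top-ranked response.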
pred = model(inputs)[0]
bot_response = pred["response"]
print(
colored(
"[Bot]:", "blue", attrs=["bold"]),
colored(
bot_response, attrs=["bold"]))
context.append(bot_response)
return
if __name__ == "__main__":
args = setup_args()
interact(args)
from collections import namedtuple
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def post_process_context(token_ids, reader, merge=True):
"""Post-process the context sequence."""
context = []
utt = []
for tok_id in token_ids[1:]:
if tok_id == reader.eos_id:
utt = reader.tokenizer.convert_ids_to_tokens(utt)
if merge:
utt = reader.tokenizer.merge_subword(utt)
context.append(utt)
utt = []
else:
utt.append(tok_id)
return context
def post_process_response(token_ids, reader, merge=True):
"""
Post-process the decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(token_ids)
for i, tok_id in enumerate(token_ids):
if tok_id == reader.eos_id:
eos_pos = i
break
token_ids = token_ids[1:eos_pos]
response = reader.tokenizer.convert_ids_to_tokens(token_ids)
if merge:
response = reader.tokenizer.merge_subword(response)
return token_ids, response
def get_cross_turn_repetition(context, pred_tokens, eos_idx, is_cn=False):
"""Get cross-turn repetition."""
if len(pred_tokens) == 0:
return 1.0
if is_cn:
context = ["".join(utt) for utt in context]
pred_tokens = "".join(pred_tokens)
pred_tri_grams = set()
for i in range(len(pred_tokens) - 2):
tri_gram = tuple(pred_tokens[i:i + 3])
pred_tri_grams.add(tri_gram)
for utt in context:
for i in range(len(utt) - 2):
tri_gram = tuple(utt[i:i + 3])
if tri_gram in pred_tri_grams:
return 1.0
return 0.0
def get_in_turn_repetition(pred, is_cn=False):
"""Get in-turn repetition."""
if len(pred) == 0:
return 1.0
if isinstance(pred[0], str):
pred = [tok.lower() for tok in pred]
if is_cn:
pred = "".join(pred)
tri_grams = set()
for i in range(len(pred) - 2):
tri_gram = tuple(pred[i:i + 3])
if tri_gram in tri_grams:
return 1.0
tri_grams.add(tri_gram)
return 0.0
class Plato2EncoderLayer(nn.Layer):
def __init__(self, n_head, hidden_size, attn_dropout, act_dropout):
super(Plato2EncoderLayer, self).__init__()
self.self_attn = nn.MultiHeadAttention(hidden_size, n_head,
attn_dropout)
self.pre_norm_layer = nn.LayerNorm(hidden_size)
self.post_norm_layer = nn.LayerNorm(hidden_size)
self.fc1 = nn.Linear(hidden_size, hidden_size * 4)
self.fc2 = nn.Linear(hidden_size * 4, hidden_size)
self.dropout_layer = nn.Dropout(act_dropout)
self.gelu_layer = nn.GELU()
def forward(self, x, attn_mask, cache):
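        # Pre-norm transformer block: LayerNorm -> self-attention -> residual, then LayerNorm -> FFN -> residual.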
query = self.pre_norm_layer(x)
attn_output, new_cache = self.self_attn(query, None, None, attn_mask,
cache)
attn_output = self.dropout_layer(attn_output)
attn_output = attn_output + x
ffd_input = self.post_norm_layer(attn_output)
ffd_output = self.fc1(ffd_input)
ffd_output = self.gelu_layer(ffd_output)
ffd_output = self.dropout_layer(ffd_output)
ffd_output = self.fc2(ffd_output)
ffd_output = self.dropout_layer(ffd_output)
out = ffd_output + attn_output
return out, new_cache
def gen_cache(self, key):
return self.self_attn.gen_cache(key)
class Plato2Encoder(nn.Layer):
def __init__(self, vocab_size, type_size, max_position_seq_len, num_layers,
n_head, hidden_size, attn_dropout, act_dropout):
super(Plato2Encoder, self).__init__()
self.n_head = n_head
self.word_embedding_layer = nn.Embedding(vocab_size, hidden_size)
self.sent_embedding_layer = nn.Embedding(type_size, hidden_size)
self.pos_embedding_layer = nn.Embedding(max_position_seq_len,
hidden_size)
self.encoder_layers = []
for i in range(num_layers):
encoder_layer = Plato2EncoderLayer(n_head, hidden_size,
attn_dropout, act_dropout)
self.encoder_layers.append(encoder_layer)
self.add_sublayer('layers.' + str(i), encoder_layer)
self.post_encoder_layer_norm = nn.LayerNorm(hidden_size)
self.dropout_layer = nn.Dropout(act_dropout)
def forward(self,
caches,
token_ids,
type_ids,
pos_ids,
generation_mask,
aux_emb=None):
out, self_attn_mask = self.gen_input(token_ids, type_ids, pos_ids,
generation_mask, aux_emb)
new_caches = []
for i, encoder_layer in enumerate(self.encoder_layers):
out, new_cache = encoder_layer(out, self_attn_mask, caches[i])
new_caches.append(new_cache)
enc_output = self.post_encoder_layer_norm(out)
return enc_output, new_caches
def gen_input(self, token_ids, type_ids, pos_ids, input_mask, aux_emb=None):
token_emb_out = self.word_embedding_layer(token_ids)
type_emb_out = self.sent_embedding_layer(type_ids)
pos_emb_out = self.pos_embedding_layer(pos_ids)
emb_out = token_emb_out + type_emb_out + pos_emb_out
# auxiliary memory embeddings
if aux_emb is not None:
emb_out = paddle.concat([aux_emb, emb_out], axis=1)
emb_out = self.dropout_layer(emb_out)
# generate n-head self-attention mask
self_attn_mask = input_mask
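        # Convert the 0/1 visibility mask to an additive bias: (mask - 1) * 1e4 gives 0 for visible positions and -1e4 for masked ones.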
self_attn_mask = paddle.scale(
x=self_attn_mask, scale=1e4, bias=-1.0, bias_after_scale=False)
n_head_self_attn_mask = paddle.stack(
x=[self_attn_mask] * self.n_head, axis=1)
n_head_self_attn_mask.stop_gradient = True
return emb_out, n_head_self_attn_mask
def gen_caches(self, key):
caches = [
encoder_layer.gen_cache(key)
for encoder_layer in self.encoder_layers
]
return caches
class NSP(nn.Layer):
def __init__(self, vocab_size, type_size, max_position_seq_len, num_layers,
n_head, hidden_size, attn_dropout, act_dropout):
super(NSP, self).__init__()
self.n_head = n_head
self.hidden_size = hidden_size
self.word_embedding_layer = nn.Embedding(vocab_size, hidden_size)
self.sent_embedding_layer = nn.Embedding(type_size, hidden_size)
self.pos_embedding_layer = nn.Embedding(max_position_seq_len,
hidden_size)
        encoder_layer = nn.TransformerEncoderLayer(
            hidden_size, n_head, hidden_size * 4, act_dropout, 'gelu',
            attn_dropout, act_dropout, normalize_before=True)
encoder_norm = nn.LayerNorm(hidden_size)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers,
encoder_norm)
self.fc1 = nn.Linear(hidden_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, 2)
self.dropout_layer = nn.Dropout(act_dropout)
self.tanh_layer = nn.Tanh()
self.softmax = nn.Softmax()
def forward(self, inputs):
token_ids = inputs['token_ids']
type_ids = inputs['type_ids']
pos_ids = inputs['pos_ids']
attention_mask = inputs['attention_mask']
label_pos = inputs["label_pos"]
out, self_attn_mask = self.gen_input(token_ids, type_ids, pos_ids,
attention_mask)
# [-1, seq_len, hidden_size]
enc_out = self.encoder(out, self_attn_mask)
enc_out = paddle.reshape(enc_out, [-1, self.hidden_size])
label_pos = paddle.cast(label_pos, 'int64')
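        # Gather the hidden state at each example's classification position.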
out = paddle.gather(enc_out, label_pos)
pooled_out = self.fc1(out)
pooled_out = self.tanh_layer(pooled_out)
# [-1, 2]
logits = self.fc2(pooled_out)
probs = self.softmax(logits)
return probs
def gen_input(self, token_ids, type_ids, pos_ids, input_mask, aux_emb=None):
token_emb_out = self.word_embedding_layer(token_ids)
type_emb_out = self.sent_embedding_layer(type_ids)
pos_emb_out = self.pos_embedding_layer(pos_ids)
emb_out = token_emb_out + type_emb_out + pos_emb_out
# auxiliary memory embeddings
if aux_emb is not None:
emb_out = paddle.concat([aux_emb, emb_out], axis=1)
emb_out = self.dropout_layer(emb_out)
# generate n-head self-attention mask
self_attn_mask = input_mask
self_attn_mask = paddle.scale(
x=self_attn_mask, scale=1e4, bias=-1.0, bias_after_scale=False)
n_head_self_attn_mask = paddle.stack(
x=[self_attn_mask] * self.n_head, axis=1)
n_head_self_attn_mask.stop_gradient = True
return emb_out, n_head_self_attn_mask
class Plato2InferModel(nn.Layer):
def __init__(self,
nsp_reader,
num_layers,
n_head,
hidden_size,
vocab_size=8001,
type_size=2,
latent_type_size=20,
max_position_seq_len=256,
act_dropout=0.1,
attn_dropout=0.1,
max_dec_len=64,
min_dec_len=1,
topk=10):
super(Plato2InferModel, self).__init__()
self.nsp_reader = nsp_reader
self.num_layers = num_layers
self.latent_type_size = latent_type_size
self.max_dec_len = max_dec_len
self.min_dec_len = min_dec_len
self.topk = topk
self.unk_id = 0
self.bos_id = 1
self.eos_id = 2
self.mask_id = 8000
self.after_eos = paddle.ones([vocab_size]) * -1e9
self.after_eos[self.eos_id] = 0
self.is_cn = False
self.batch_size = 1
self.latent_weight = paddle.create_parameter(
[hidden_size, latent_type_size], 'float32')
self.plato2_encoder = Plato2Encoder(
vocab_size, type_size, max_position_seq_len, num_layers, n_head,
hidden_size, attn_dropout, act_dropout)
self.logits_fc_layer = nn.Linear(hidden_size, hidden_size)
self.logits_layer_norm = nn.LayerNorm(hidden_size)
self.logits_bias = paddle.create_parameter(
[vocab_size], 'float32', is_bias=True)
self.nsp_predictor = NSP(vocab_size, type_size, max_position_seq_len,
num_layers, n_head, hidden_size, attn_dropout,
act_dropout)
self.gelu_layer = nn.GELU()
self.softmax = nn.Softmax()
@paddle.no_grad()
def forward(self, inputs):
token_ids = inputs['token_ids']
type_ids = inputs['type_ids']
pos_ids = inputs['pos_ids']
generation_mask = inputs['generation_mask']
latent_id = inputs['latent_id']
data_id = inputs['data_id']
# [-1, 1, latent_type_size]
latent_id = F.one_hot(latent_id, self.latent_type_size)
# [-1, 1, hidden_size]
latent_emb = paddle.matmul(
latent_id, self.latent_weight, transpose_y=True)
caches = self.plato2_encoder.gen_caches(token_ids)
# [-1, seq_len + 1, hidden_size]
enc_out, new_caches = self.plato2_encoder(
caches, token_ids, type_ids, pos_ids, generation_mask, latent_emb)
pred_ids = self.decoder(inputs, new_caches)
nsp_inputs = self.gen_nsp_input(token_ids, pred_ids)
# [-1, 2]
probs = self.nsp_predictor(nsp_inputs)
return self.get_results(data_id, token_ids, pred_ids, probs)
def decoder(self, inputs, caches):
tgt_ids = inputs['tgt_ids']
tgt_pos = inputs['tgt_pos']
tgt_generation_mask = inputs['tgt_generation_mask']
predictions = tgt_ids
# TODO
step = 0
while step < self.max_dec_len:
# [-1, 1]
append_mask = paddle.cast(
tgt_ids != self.eos_id, dtype=tgt_generation_mask.dtype)
tgt_generation_mask = paddle.concat(
[tgt_generation_mask, paddle.unsqueeze(append_mask, 1)],
axis=-1)
tgt_sent = paddle.ones(
[tgt_generation_mask.shape[0], 1], dtype=tgt_ids.dtype)
# [-1, 1, hidden_size]
out, caches = self.plato2_encoder(caches, tgt_ids, tgt_sent,
tgt_pos, tgt_generation_mask)
out = paddle.squeeze(out, axis=1)
# [-1, hidden_size]
trans = self.logits_fc_layer(out)
trans = self.gelu_layer(trans)
trans = self.logits_layer_norm(trans)
# [-1, vocab_size]
logits = paddle.matmul(
trans,
self.plato2_encoder.word_embedding_layer.weight,
transpose_y=True) + self.logits_bias
logits[:, self.unk_id] = -1e9
logits[:, self.bos_id] = -1e9
logits[:, self.mask_id] = -1e9
if step < self.min_dec_len:
logits[:, self.eos_id] = -1e9
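            # Once a sequence has emitted <eos>, after_eos pins its distribution so it keeps emitting <eos>.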
logits = logits * append_mask + (1 - append_mask) * self.after_eos
probs = self.softmax(logits)
# [-1, topk]
topk_probs, _ = paddle.topk(probs, k=self.topk)
mask = paddle.cast(probs >= topk_probs[:, -1:], 'float32')
sums = paddle.sum(topk_probs, axis=-1, keepdim=True)
new_probs = probs * mask / sums
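            # Sample the next token id from the renormalized top-k distribution.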
# [-1, 1]
sampling_ids = paddle.multinomial(new_probs)
step = step + 1
tgt_ids = sampling_ids
tgt_pos = tgt_pos + 1
predictions = paddle.concat([predictions, tgt_ids], axis=1)
return predictions
def gen_nsp_input(self, token_ids, pred_ids):
token_ids = token_ids.numpy()
pred_ids = pred_ids.numpy()
def __reader__():
headers = ["src", "tgt", "data_id"]
Example = namedtuple("Example", headers)
for i, (raw, pred) in enumerate(zip(token_ids, pred_ids)):
context = post_process_context(
raw, self.nsp_reader, merge=False)
_, response = post_process_response(
pred, self.nsp_reader, merge=False)
context_tokenized_input = " [SEP] ".join(" ".join(utt)
for utt in context)
response_tokenized_input = " ".join(response)
example = Example(
src=context_tokenized_input,
tgt=response_tokenized_input,
data_id=i)
data = self.nsp_reader._convert_example_to_record(
example, is_infer=True)
yield data
return
generator = self.nsp_reader.data_generator(
reader=__reader__,
is_infer=True,
phase="test", )
inputs = next(generator())
#print('\nnsp_inputs:')
for key in inputs:
inputs[key] = paddle.to_tensor(inputs[key])
if key in ['token_ids', 'type_ids', 'pos_ids']:
inputs[key] = paddle.squeeze(inputs[key], axis=-1)
#print(key, inputs[key].shape)
#print(inputs[key])
return inputs
def get_results(self, data_id, token_ids, pred_ids, probs):
data_id = data_id.numpy()
token_ids = token_ids.numpy()
pred_ids = pred_ids.numpy()
probs = probs.numpy()
infos = []
for raw, pred, prob in zip(token_ids, pred_ids, probs):
tokens = post_process_context(raw, self.nsp_reader)
pred_token_ids, pred_tokens = post_process_response(pred,
self.nsp_reader)
info = {}
info['response'] = ' '.join(pred_tokens)
cross_turn_repetition = get_cross_turn_repetition(
tokens, pred_tokens, self.nsp_reader.eos_id, self.is_cn)
in_turn_repetition = max(
get_in_turn_repetition(pred_tokens, self.is_cn),
get_in_turn_repetition(pred_token_ids))
info['score'] = float(prob[1])
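            # Heavily penalize degenerate candidates: max-length generations and n-gram repetitions.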
if len(pred_token_ids) >= self.max_dec_len:
info['score'] -= 1e3
elif cross_turn_repetition > 0:
info['score'] -= 1e3
elif in_turn_repetition > 0:
info['score'] -= 1e3
infos.append(info)
results = []
pre_idx = 0
sample = []
for idx, info in zip(data_id, infos):
if idx != pre_idx:
sample = sorted(sample, key=lambda info: -info["score"])
result = sample[0]
result['data_id'] = pre_idx
                results.append(result)
sample = []
pre_idx = idx
sample.append(info)
if sample:
sample = sorted(sample, key=lambda info: -info["score"])
result = sample[0]
result['data_id'] = pre_idx
results.append(result)
return results
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dialogue Reader."""
import csv
from collections import namedtuple
from contextlib import contextmanager
import gzip
import numpy as np
from utils import pad_batch_data
from utils.args import str2bool
from utils.masking import mask
import utils.tokenization as tokenization
class DialogReader(object):
"""The implement of DialogReader."""
@classmethod
def add_cmdline_args(cls, parser):
"""Add cmdline argurments."""
group = parser.add_argument_group("Reader")
group.add_argument("--max_src_len", type=int, default=128)
group.add_argument("--max_tgt_len", type=int, default=128)
group.add_argument(
"--truncate_first_turn", type=str2bool, default=False)
group.add_argument(
"--file_format",
type=str,
default="file",
choices=["file", "filelist"])
group.add_argument(
"--data_format",
type=str,
default="raw",
choices=["raw", "tokenized", "numerical"])
group.add_argument("--in_tokens", type=str2bool, default=False)
group.add_argument("--batch_size", type=int, default=16)
group.add_argument("--continuous_position", type=str2bool, default=True)
group.add_argument("--random_seed", type=int, default=11)
group.add_argument("--sort_pool_size", type=int, default=2**16)
group = parser.add_argument_group("Tokenizer")
group.add_argument(
"--tokenizer", type=str, default="SentencePieceTokenizer")
args, _ = parser.parse_known_args()
tokenizer_cls = getattr(tokenization, args.tokenizer)
tokenizer_cls.add_cmdline_args(parser)
return group
def __init__(self, args):
tokenizer_cls = getattr(tokenization, args.tokenizer)
self.tokenizer = tokenizer_cls(args)
self.vocab = self.tokenizer.vocab
self.pad_id = args.pad_id = self.vocab["[PAD]"]
self.bos_id = args.bos_id = self.vocab["[CLS]"]
self.eos_id = args.eos_id = self.vocab["[SEP]"]
self.unk_id = args.unk_id = self.vocab["[UNK]"]
self.mask_id = args.mask_id = self.vocab["[MASK]"]
self.vocab_size = args.get("vocab_size", 0)
self.max_src_len = args.max_src_len
self.max_tgt_len = args.max_tgt_len
self.truncate_first_turn = args.truncate_first_turn
self.file_format = args.file_format
self.data_format = args.data_format
self.in_tokens = args.in_tokens
self.batch_size = args.batch_size
self.continuous_position = args.continuous_position
self.sort_pool_size = args.sort_pool_size
# random_seed must be set for data slicing when using multi-gpu
self.global_rng = np.random.RandomState(args.random_seed)
# training progress
self.current_example = 0
self.current_epoch = 0
self.num_examples = 0
# model related
self.fields = ["token_ids", "type_ids", "pos_ids"]
self.num_numerical_fields = len(self.fields)
self.fields += ["tgt_start_idx", "data_id"]
self.sort_key = lambda record: [len(record.token_ids)]
self.Record = namedtuple(
"Record", self.fields, defaults=(None, ) * len(self.fields))
self.features = {}
return
def get_train_progress(self):
"""Gets progress for training phase."""
return self.current_epoch, self.current_file_index, self.total_file
def _convert_example_to_record(self, example, is_infer):
# process src
src_token_ids = []
src_pos_ids = []
# tokenize src
s_token_ids_list = []
for s in example.src.split("[SEP]"):
s = tokenization.convert_to_unicode(s).strip()
if self.data_format == "tokenized":
s_tokens = s.split(" ")
else:
s_tokens = self.tokenizer.tokenize(s)
s_token_ids = self.tokenizer.convert_tokens_to_ids(
s_tokens) + [self.eos_id]
s_token_ids_list.append(s_token_ids)
# trim src
idx = len(s_token_ids_list) - 1
total_token_num = 1
while idx >= 0:
total_token_num += len(s_token_ids_list[idx])
if total_token_num > self.max_src_len:
if self.truncate_first_turn and idx == 0:
truncated_ids = s_token_ids_list[idx][:self.max_src_len -
total_token_num]
if len(truncated_ids) > 1:
s_token_ids_list[
idx] = truncated_ids[:-1] + [self.eos_id]
idx -= 1
break
idx -= 1
for i, s_token_ids in enumerate(s_token_ids_list[idx + 1:], idx + 1):
src_token_ids += s_token_ids
src_pos_ids += list(range(1, len(s_token_ids) + 1))
src_token_ids = [self.bos_id] + src_token_ids
src_type_ids = [0] * len(src_token_ids)
src_pos_ids = [0] + src_pos_ids
assert len(src_token_ids) == len(src_type_ids) == len(src_pos_ids), \
"not len(src_token_ids) == len(src_type_ids) == len(src_pos_ids)"
token_ids = src_token_ids
type_ids = src_type_ids
pos_ids = src_pos_ids
tgt_start_idx = len(token_ids)
if not is_infer:
# process tgt
# tokenize tgt
tgt = tokenization.convert_to_unicode(example.tgt).strip()
if self.data_format == "tokenized":
tgt_tokens = tgt.split(" ")
else:
tgt_tokens = self.tokenizer.tokenize(tgt)
tgt_token_ids = self.tokenizer.convert_tokens_to_ids(tgt_tokens)
tgt_token_ids.append(self.eos_id)
# trim tgt
if len(tgt_token_ids) > self.max_tgt_len - 1:
tgt_token_ids = tgt_token_ids[:self.max_tgt_len - 1]
tgt_token_ids = [self.bos_id] + tgt_token_ids
tgt_type_ids = [1] * len(tgt_token_ids)
tgt_pos_ids = list(range(1, len(tgt_token_ids) + 1))
assert len(tgt_token_ids) == len(tgt_type_ids) == len(tgt_pos_ids), \
"not len(tgt_token_ids) == len(tgt_type_ids) == len(tgt_pos_ids)"
token_ids += tgt_token_ids
type_ids += tgt_type_ids
pos_ids += tgt_pos_ids
assert len(token_ids) == len(type_ids) == len(pos_ids), \
"not len(token_ids) == len(type_ids) == len(pos_ids)"
if self.continuous_position:
src_pos_ids = list(range(len(src_token_ids)))
if not is_infer:
tgt_pos_ids = list(range(len(tgt_token_ids)))
pos_ids = list(range(len(token_ids)))
field_values = {
"token_ids": src_token_ids,
"type_ids": src_type_ids,
"pos_ids": src_pos_ids
}
field_values["tgt_start_idx"] = tgt_start_idx
field_values["data_id"] = example.data_id
record = self.Record(**field_values)
return record
def _read_tsv(self, fp, phase, is_infer, delimiter="\t", quotechar=None):
"""Reads a tab separated value file."""
csv.field_size_limit(2**20)
reader = csv.reader(fp, delimiter=delimiter, quotechar=quotechar)
headers = next(reader)
headers.append("data_id")
Example = namedtuple("Example", headers)
for i, line in enumerate(reader):
example = Example(*line, data_id=i)
if is_infer or phase.endswith("test"):
self.features[phase][i] = example
record = self._convert_example_to_record(example, is_infer)
yield record
def _read_numerical_file(self, fp, delimiter=";"):
for i, line in enumerate(fp):
cols = tokenization.convert_to_unicode(line).strip().split(
delimiter)
cols = list(map(lambda x: list(map(int, x.split(" "))), cols))
if len(cols) > self.num_numerical_fields:
cols = cols[:self.num_numerical_fields]
tgt_start_idx = cols[0].index(self.bos_id, 1)
record = self.Record(*cols, tgt_start_idx=tgt_start_idx, data_id=i)
yield record
def _read_file(self, input_file, phase, is_infer):
def __wrapper__():
with open_file(input_file) as fp:
if self.data_format == "numerical":
records = self._read_numerical_file(fp)
else:
records = self._read_tsv(fp, phase, is_infer)
for record in records:
yield record
return __wrapper__
def _read_files(self, filelist, phase, is_infer, shuffle_files):
input_files = open(filelist).readlines()
def __wrapper__():
if shuffle_files:
self.global_rng.shuffle(input_files)
if phase == "train":
self.total_file = len(input_files)
for file_index, input_file in enumerate(input_files, 1):
if phase == "train":
self.current_file_index = file_index
self.current_file = input_file
file_reader = self._read_file(input_file.strip(), phase,
is_infer)
for record in file_reader():
yield record
return __wrapper__
def _batch_reader(self,
reader,
phase=None,
is_infer=False,
sort_pool_size=2**16):
"""Construct a batch reader."""
def update_max_lens(max_lens, record):
"""Update max_lens."""
if max_lens is None:
return self.sort_key(record)
else:
return [
max(max_len, l) for max_len, l in zip(max_lens,
self.sort_key(record))
]
def get_batch(reader):
"""Generate batches from reader."""
batch, max_lens = [], None
for record in reader():
if record is None:
yield batch
batch, max_lens = [], None
continue
self.current_example += 1
max_lens = update_max_lens(max_lens, record)
if self.in_tokens:
to_append = (len(batch) + 1
) * sum(max_lens) <= self.batch_size
else:
to_append = len(batch) < self.batch_size
if to_append:
batch.append(record)
else:
yield batch
batch, max_lens = [record], self.sort_key(record)
if len(batch) > 0:
yield batch
def get_sorted_batch(pool):
"""Generate sorted batches from pool."""
pool = sorted(pool, key=self.sort_key)
batches = []
batch, max_lens = [], None
for record in pool:
self.current_example += 1
max_lens = update_max_lens(max_lens, record)
if self.in_tokens:
to_append = (len(batch) + 1
) * sum(max_lens) <= self.batch_size
else:
to_append = len(batch) < self.batch_size
if to_append:
batch.append(record)
else:
batches.append(batch)
batch, max_lens = [record], self.sort_key(record)
if len(batch) > 0:
batches.append(batch)
self.global_rng.shuffle(batches)
for batch in batches:
yield batch
def __wrapper__():
if sort_pool_size > 0:
pool = []
for record in reader():
pool.append(record)
if len(pool) == sort_pool_size:
for batch in get_sorted_batch(pool):
yield batch
pool = []
if len(pool) > 0:
for batch in get_sorted_batch(pool):
yield batch
else:
for batch in get_batch(reader):
yield batch
return __wrapper__
def _distributed_batch_reader(self,
batch_reader,
num_part,
part_id,
is_test=False):
def __wrapper__():
batches = []
for batch in batch_reader():
batches.append(batch)
if len(batches) == num_part:
yield batches[part_id]
batches = []
if is_test and 0 <= part_id < len(batches):
yield batches[part_id]
return
return __wrapper__
def data_generator(self,
input_file=None,
reader=None,
num_epochs=1,
num_part=1,
part_id=0,
phase=None,
is_infer=False):
"""Data generator."""
def __wrapper__():
if is_infer or phase.endswith("test"):
self.features[phase] = {}
nonlocal reader
if reader is None:
if self.file_format == "filelist":
reader = self._read_files(input_file, phase, is_infer,
not phase.endswith("test"))
else:
if phase == "train":
self.total_file = 1
self.current_file_index = 1
self.current_file = input_file
reader = self._read_file(input_file, phase, is_infer)
batch_reader = self._batch_reader(
reader,
phase,
is_infer,
sort_pool_size=self.sort_pool_size if not is_infer else 0)
if phase == "train":
batch_reader = self._distributed_batch_reader(batch_reader,
num_part, part_id)
elif phase.startswith("distributed"):
batch_reader = self._distributed_batch_reader(
batch_reader, num_part, part_id, is_test=True)
for epoch_index in range(num_epochs):
if phase == "train":
self.current_example = 0
self.current_epoch = epoch_index + 1
for batch in batch_reader():
yield self._pad_batch_records(batch, is_infer)
return __wrapper__
def _gen_self_attn_mask(self,
batch_token_ids,
batch_tgt_start_idx=None,
is_unidirectional=True,
shift_len=0):
max_len = max(map(len, batch_token_ids))
input_mask_data = np.zeros(
(len(batch_token_ids), max_len + shift_len, max_len + shift_len))
if is_unidirectional:
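            # Context tokens are visible to every position; target tokens get a causal (lower-triangular) mask.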
for index, mask_data in enumerate(input_mask_data):
start = 0 if batch_tgt_start_idx is None else batch_tgt_start_idx[
index]
end = len(batch_token_ids[index])
mask_data[:end + shift_len, :start + shift_len] = 1.0
# Generate the lower triangular matrix using the slice of matrix
b = np.tril(np.ones([end - start, end - start]), 0)
mask_data[start + shift_len:end + shift_len, start + shift_len:
end + shift_len] = b
else:
for index, token_ids in enumerate(batch_token_ids):
input_mask_data[index, :len(token_ids) + shift_len, :len(
token_ids) + shift_len] = 1.0
return input_mask_data.astype("float32")
def _pad_batch_records(self, batch_records, is_infer):
"""
Padding batch records and construct model's inputs.
"""
batch_size = len(batch_records)
batch = {}
batch_token_ids = [record.token_ids for record in batch_records]
batch_type_ids = [record.type_ids for record in batch_records]
batch_pos_ids = [record.pos_ids for record in batch_records]
batch["token_ids"] = pad_batch_data(batch_token_ids, pad_id=self.pad_id)
batch["type_ids"] = pad_batch_data(batch_type_ids, pad_id=self.pad_id)
batch["pos_ids"] = pad_batch_data(batch_pos_ids, pad_id=self.pad_id)
batch_tgt_start_idx = [record.tgt_start_idx for record in batch_records]
batch["generation_mask"] = self._gen_self_attn_mask(
batch_token_ids, batch_tgt_start_idx=batch_tgt_start_idx)
if is_infer:
tgt_ids = np.array(
[[[self.bos_id]]] * len(batch_token_ids), dtype="int64")
if self.continuous_position:
tgt_pos = np.array(batch_tgt_start_idx, dtype="int64")
else:
tgt_pos = np.zeros_like(batch_tgt_start_idx, dtype="int64")
tgt_pos = tgt_pos.reshape(-1, 1, 1)
batch["init_score"] = np.zeros_like(
tgt_ids, dtype="float32").reshape(-1, 1).tolist()
batch["tgt_ids"] = tgt_ids.tolist()
batch["tgt_pos"] = tgt_pos.tolist()
batch["tgt_generation_mask"] = batch[
"generation_mask"][:, 0:1, :].astype("float32")
else:
batch["tgt_label"], batch["tgt_pos"] = mask(
batch_tokens=batch_token_ids,
vocab_size=self.vocab_size,
sent_b_starts=batch_tgt_start_idx,
is_unidirectional=True)
batch_data_id = [record.data_id for record in batch_records]
batch["data_id"] = np.array(batch_data_id).astype("int64").reshape(
[-1, 1])
return batch
@contextmanager
def open_file(filename):
"""Open file."""
if filename.endswith(".gz"):
fp = gzip.open(filename, "rt")
else:
fp = open(filename)
yield fp
fp.close()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NSP Reader."""
from collections import namedtuple
import numpy as np
from readers.dialog_reader import DialogReader
from utils import pad_batch_data
from utils.args import str2bool
from utils.masking import mask
class NSPReader(DialogReader):
"""NSP Reader."""
@classmethod
def add_cmdline_args(cls, parser):
"""Add cmdline argurments."""
group = DialogReader.add_cmdline_args(parser)
group.add_argument(
"--attention_style",
type=str,
default="bidirectional",
choices=["bidirectional", "unidirectional"])
group.add_argument(
"--mix_negative_sample", type=str2bool, default=False)
return group
def __init__(self, args):
super(NSPReader, self).__init__(args)
self.fields.append("label")
self.Record = namedtuple(
"Record", self.fields, defaults=(None, ) * len(self.fields))
self.attention_style = args.attention_style
self.mix_negative_sample = args.mix_negative_sample
return
def _convert_example_to_record(self, example, is_infer):
record = super(NSPReader, self)._convert_example_to_record(example,
False)
if "label" in example._fields:
record = record._replace(label=int(example.label))
return record
def _mix_negative_sample(self, reader, neg_pool_size=2**16):
def gen_from_pool(pool):
num_samples = len(pool)
if num_samples == 1:
# only one sample: it is impossible to generate negative sample
yield pool[0]._replace(label=1)
return
self.global_rng.shuffle(pool)
for i in range(num_samples):
pool[i] = pool[i]._replace(label=1)
j = (i + 1) % num_samples
idx_i = pool[i].tgt_start_idx
idx_j = pool[j].tgt_start_idx
field_values = {}
field_values["token_ids"] = pool[i].token_ids[:idx_i] + pool[
j].token_ids[idx_j:]
field_values["type_ids"] = pool[i].type_ids[:idx_i] + pool[
j].type_ids[idx_j:]
field_values["pos_ids"] = list(
range(len(field_values["token_ids"])))
neg_record = self.Record(
**field_values, tgt_start_idx=idx_i, data_id=-1, label=0)
pool.append(neg_record)
assert len(neg_record.token_ids) <= self.max_seq_len
self.global_rng.shuffle(pool)
for record in pool:
yield record
def __wrapper__():
pool = []
for record in reader():
pool.append(record)
if len(pool) == neg_pool_size:
for record in gen_from_pool(pool):
yield record
pool = []
if len(pool) > 0:
for record in gen_from_pool(pool):
yield record
return __wrapper__
def _batch_reader(self,
reader,
phase=None,
is_infer=False,
sort_pool_size=2**16):
if self.mix_negative_sample:
reader = self._mix_negative_sample(reader)
return super(NSPReader, self)._batch_reader(
reader,
phase=phase,
is_infer=is_infer,
sort_pool_size=sort_pool_size)
def _pad_batch_records(self, batch_records, is_infer):
"""
Padding batch records and construct model's inputs.
"""
batch = {}
batch_token_ids = [record.token_ids for record in batch_records]
batch_type_ids = [record.type_ids for record in batch_records]
batch_pos_ids = [record.pos_ids for record in batch_records]
batch_tgt_start_idx = [record.tgt_start_idx for record in batch_records]
batch_label = [record.label for record in batch_records]
if self.attention_style == "unidirectional":
batch["token_ids"] = pad_batch_data(
batch_token_ids, pad_id=self.pad_id)
batch["type_ids"] = pad_batch_data(
batch_type_ids, pad_id=self.pad_id)
batch["pos_ids"] = pad_batch_data(batch_pos_ids, pad_id=self.pad_id)
tgt_label, tgt_pos, label_pos = mask(
batch_tokens=batch_token_ids,
vocab_size=self.vocab_size,
bos_id=self.bos_id,
sent_b_starts=batch_tgt_start_idx,
labels=batch_label,
is_unidirectional=True)
attention_mask = self._gen_self_attn_mask(batch_token_ids,
batch_tgt_start_idx)
else:
batch_mask_token_ids, tgt_label, tgt_pos, label_pos = mask(
batch_tokens=batch_token_ids,
vocab_size=self.vocab_size,
bos_id=self.bos_id,
eos_id=self.eos_id,
mask_id=self.mask_id,
sent_b_starts=batch_tgt_start_idx,
labels=batch_label,
is_unidirectional=False)
if not is_infer:
batch_token_ids = batch_mask_token_ids
batch["token_ids"] = pad_batch_data(
batch_token_ids, pad_id=self.pad_id)
batch["type_ids"] = pad_batch_data(
batch_type_ids, pad_id=self.pad_id)
batch["pos_ids"] = pad_batch_data(batch_pos_ids, pad_id=self.pad_id)
attention_mask = self._gen_self_attn_mask(
batch_token_ids, is_unidirectional=False)
batch["attention_mask"] = attention_mask
batch["label_pos"] = label_pos
if not is_infer:
batch_label = np.array(batch_label).astype("int64").reshape([-1, 1])
batch["label"] = batch_label
batch["tgt_label"] = tgt_label
batch["tgt_pos"] = tgt_pos
batch_data_id = [record.data_id for record in batch_records]
batch["data_id"] = np.array(batch_data_id).astype("int64").reshape(
[-1, 1])
return batch
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Plato Reader."""
import numpy as np
from readers.dialog_reader import DialogReader
from utils import pad_batch_data
from utils.masking import mask
class PlatoReader(DialogReader):
"""The implement of PlatoReader"""
def __init__(self, args):
super(PlatoReader, self).__init__(args)
self.latent_type_size = args.latent_type_size
self.use_bow = args.use_bow
def _pad_batch_records(self, batch_records, is_infer):
"""
Padding batch records and construct model's inputs.
"""
batch = {}
batch_token_ids = [record.token_ids for record in batch_records]
batch_type_ids = [record.type_ids for record in batch_records]
batch_pos_ids = [record.pos_ids for record in batch_records]
batch_tgt_start_idx = [record.tgt_start_idx for record in batch_records]
batch_size = len(batch_token_ids)
# padding
batch["token_ids"] = pad_batch_data(batch_token_ids, pad_id=self.pad_id)
batch["type_ids"] = pad_batch_data(batch_type_ids, pad_id=self.pad_id)
batch["pos_ids"] = pad_batch_data(batch_pos_ids, pad_id=self.pad_id)
batch["generation_mask"] = self._gen_self_attn_mask(
batch_token_ids,
batch_tgt_start_idx=batch_tgt_start_idx,
is_unidirectional=True,
shift_len=1)
if not is_infer:
batch["recognition_mask"] = self._gen_self_attn_mask(
batch_token_ids, is_unidirectional=False, shift_len=1)
if is_infer:
tgt_ids = np.array([[[self.bos_id]]] * batch_size, dtype="int64")
if self.continuous_position:
tgt_pos = np.array(batch_tgt_start_idx, dtype="int64")
else:
tgt_pos = np.zeros_like(batch_tgt_start_idx, dtype="int64")
tgt_pos = tgt_pos.reshape(-1, 1, 1)
batch["init_score"] = np.zeros_like(
tgt_ids, dtype="float32").reshape(-1, 1).tolist()
batch["tgt_ids"] = tgt_ids.tolist()
batch["tgt_pos"] = tgt_pos.tolist()
batch["parent_idx"] = np.array(range(batch_size), dtype="int32")
batch["tgt_generation_mask"] = batch[
"generation_mask"][:, 0:1, :].astype("float32")
else:
mask_return_list = mask(
batch_tokens=batch_token_ids,
vocab_size=self.vocab_size,
sent_b_starts=batch_tgt_start_idx,
is_unidirectional=True,
use_latent=True,
use_bow=self.use_bow)
batch["tgt_label"] = mask_return_list[0]
batch["tgt_pos"] = mask_return_list[1]
if self.use_bow:
batch["bow_label"] = mask_return_list[2]
batch["bow_pos"] = mask_return_list[3]
batch_data_id = [record.data_id for record in batch_records]
batch["data_id"] = np.array(batch_data_id).astype("int64").reshape(
[-1, 1])
return batch
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils."""
from itertools import chain
import numpy as np
import paddle
def repeat_array(array, times):
"""Repeate numpy array."""
if isinstance(array, list):
return list(chain(*([array] * times)))
else:
return np.concatenate([array] * times, axis=0)
def gen_inputs(inputs, latent_type_size):
batch_size = len(inputs["data_id"])
new_bsz = batch_size * latent_type_size
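    # Duplicate every field latent_type_size times so each latent value sees a full copy of the batch.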
inputs = {
name: repeat_array(array, latent_type_size)
for name, array in inputs.items()
}
# Add latent_id
inputs["latent_id"] = np.array(
[i for i in range(latent_type_size) for _ in range(batch_size)],
dtype="int64").reshape([-1, 1])
#print('\nplato_inputs:')
for key in inputs:
inputs[key] = paddle.to_tensor(inputs[key])
if key in [
'token_ids', 'type_ids', 'pos_ids', 'tgt_ids', 'tgt_pos',
'data_id'
]:
inputs[key] = paddle.squeeze(inputs[key], axis=-1)
#print(key, inputs[key].shape, inputs[key].dtype)
return inputs
def pad_batch_data(insts, pad_id=0):
"""Pad the instances to the max sequence length in batch. """
max_len = max(map(len, insts))
inst_data = np.array(
[list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])
return inst_data.astype("int64").reshape([-1, max_len, 1])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parse argument."""
import argparse
import json
def str2bool(v):
""" Support bool type for argparse. """
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Unsupported value encountered.")
class Args(dict):
""" Arguments class
Store arguments in training / infer / ... scripts.
"""
def __getattr__(self, name):
if name in self.keys():
return self[name]
for v in self.values():
if isinstance(v, Args):
if name in v:
return v[name]
return None
def get(self, key, default_value=None):
"""Get the value of corresponding key."""
if key in self.keys():
return self[key]
for v in self.values():
if isinstance(v, Args):
if key in v:
return v[key]
return default_value
def __setattr__(self, name, value):
self[name] = value
def save(self, filename):
with open(filename, "w") as fp:
json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)
def load(self, filename, group_name=None):
if group_name is not None:
if group_name not in self:
self[group_name] = Args()
self[group_name].load(filename)
return
with open(filename, "r") as fp:
params_dict = json.load(fp)
for k, v in params_dict.items():
if isinstance(v, dict):
self[k].update(Args(v))
else:
self[k] = v
def parse_args(parser: argparse.ArgumentParser, allow_unknown=False) -> Args:
""" Parse hyper-parameters from cmdline. """
if allow_unknown:
parsed, _ = parser.parse_known_args()
else:
parsed = parser.parse_args()
args = Args()
optional_args = parser._action_groups[1]
for action in optional_args._group_actions[1:]:
arg_name = action.dest
args[arg_name] = getattr(parsed, arg_name)
for group in parser._action_groups[2:]:
group_args = Args()
for action in group._group_actions:
arg_name = action.dest
group_args[arg_name] = getattr(parsed, arg_name)
if len(group_args) > 0:
if group.title in args:
args[group.title].update(group_args)
else:
args[group.title] = group_args
return args
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reader utils."""
import numpy as np
def mask(batch_tokens,
vocab_size,
bos_id=1,
eos_id=2,
mask_id=3,
sent_b_starts=None,
labels=None,
is_unidirectional=False,
use_latent=False,
use_bow=False):
"""
Add mask for batch_tokens, return out, mask_label, mask_pos;
Note: mask_pos responding the batch_tokens after padded;
"""
batch_tokens = np.copy(batch_tokens)
max_len = max(map(len, batch_tokens))
mask_label = []
mask_pos = []
if labels is not None:
label_pos = []
if is_unidirectional:
# unidirectional language model
if use_latent:
max_len += 1
shift_len = 1
else:
shift_len = 0
for sent_index, sent in enumerate(batch_tokens):
sent_b_index = sent_b_starts[
sent_index] if sent_b_starts is not None else 0
need_cal = True
if labels is not None:
label_pos.append(sent_index * max_len + len(sent) - 1 +
shift_len)
if labels[sent_index] == 0:
need_cal = False
mask_label.extend(sent[sent_b_index + 1:])
mask_pos.extend([
sent_index * max_len + i + shift_len
for i in range(sent_b_index, len(sent) - 1)
])
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return_list = [mask_label, mask_pos]
# latent related (bow label and pos)
if use_latent and use_bow:
bow_label = []
bow_pos = []
for sent_index, sent in enumerate(batch_tokens):
sent_b_index = sent_b_starts[
sent_index] if sent_b_starts is not None else 0
def __filter__(tok_id):
# TODO: exclude [EOS] from bow loss
return True
bow_pos.extend([
sent_index for i in range(sent_b_index + 1, len(sent))
if __filter__(sent[i])
])
bow_label.extend([
sent[i] for i in range(sent_b_index + 1, len(sent))
if __filter__(sent[i])
])
bow_label = np.array(bow_label).astype("int64").reshape([-1, 1])
bow_pos = np.array(bow_pos).astype("int64").reshape([-1, 1])
return_list += [bow_label, bow_pos]
else:
# bidirectional mask language model
total_token_num = sum(map(len, batch_tokens))
prob_mask = np.random.rand(total_token_num)
# TODO: fix replace_ids, include [UNK]
replace_ids = np.random.randint(
3, high=vocab_size, size=total_token_num)
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
# add pair label position
if labels is not None:
label_pos.append(sent_index * max_len)
# add mask label and position
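            # BERT-style corruption: ~15% of tokens are selected; of those, 80% -> [MASK], 10% -> random id, 10% kept.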
for token_index, token in enumerate(sent):
if token == eos_id or token == bos_id:
continue
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
mask_label.append(sent[token_index])
sent[token_index] = mask_id
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index + token_index]
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
prob_index += len(sent)
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return_list = [batch_tokens, mask_label, mask_pos]
if labels is not None:
label_pos = np.array(label_pos).astype("int64").reshape([-1, 1])
assert len(labels) == len(label_pos)
return_list.append(label_pos)
return return_list
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
import collections
import sentencepiece as spm
import unicodedata
from utils.args import str2bool
def clean_text(text):
"""Performs invalid character removal and whitespace cleanup on text."""
text = text.replace(u"“", u'"')\
.replace(u'”', u'"')\
.replace(u'‘', "'")\
.replace(u'’', u"'")\
.replace(u'—', u'-')
output = []
for char in text:
if _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
def preprocess_text(inputs, remove_space=True, lower=False):
"""preprocess data by removing extra space and normalize data."""
outputs = inputs
if remove_space:
outputs = " ".join(inputs.strip().split())
outputs = unicodedata.normalize("NFKD", outputs)
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
if lower:
outputs = outputs.lower()
return outputs
def encode_pieces(spm_model, text, return_unicode=True, sample=False):
"""turn sentences into word pieces."""
    # liujiaxiang: added for ernie-albert; mainly handles “ ” ‘ ’ — characters, which otherwise cause too many unk pieces
text = clean_text(text)
if not sample:
pieces = spm_model.EncodeAsPieces(text)
else:
pieces = spm_model.SampleEncodeAsPieces(text, 64, 0.1)
return pieces
def encode_ids(spm_model, text, sample=False):
"""turn sentences into word pieces."""
pieces = encode_pieces(spm_model, text, return_unicode=False, sample=sample)
ids = [spm_model.PieceToId(piece) for piece in pieces]
return ids
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = open(vocab_file)
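    # Each line is "token\tid" or just "token"; if the id is omitted it defaults to the line number.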
for num, line in enumerate(fin):
items = convert_to_unicode(line.rstrip()).split("\t")
if len(items) > 2:
break
token = items[0]
index = items[1] if len(items) == 2 else num
token = token.strip()
vocab[token] = int(index)
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
class SentencePieceTokenizer(object):
"""Runs end-to-end tokenziation."""
@classmethod
def add_cmdline_args(cls, parser):
"""Add cmdline argurments."""
group = parser.add_argument_group("Tokenizer")
group.add_argument("--vocab_path", type=str, required=True)
group.add_argument("--do_lower_case", type=str2bool, default=False)
group.add_argument("--spm_model_file", type=str, required=True)
return group
def __init__(self, args):
self.spm_model = spm.SentencePieceProcessor()
self.spm_model.Load(args.spm_model_file)
self.vocab = load_vocab(args.vocab_path)
self.do_lower_case = args.do_lower_case
self.inv_vocab = {v: k for k, v in self.vocab.items()}
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = preprocess_text(text, lower=self.do_lower_case)
return encode_pieces(self.spm_model, text, return_unicode=True)
def convert_tokens_to_ids(self, tokens):
"""Convert tokens to ids."""
ret = []
unk_id = self.vocab["<unk>"]
for token in tokens:
if token in self.vocab:
ret.append(self.vocab[token])
else:
ret.append(unk_id)
return ret
def convert_ids_to_tokens(self, ids):
"""Convert ids to tokens."""
return convert_by_vocab(self.inv_vocab, ids)
def merge_subword(self, tokens):
"""Merge subword."""
ret = []
for token in tokens:
if token.startswith(u"▁"):
ret.append(token[1:])
else:
if len(ret):
ret[-1] += token
else:
ret.append(token)
ret = [token for token in ret if token]
return ret
def convert_ids_to_str(self, ids):
"""Convert ids to string."""
tokens = self.convert_ids_to_tokens(ids)
tokens = self.merge_subword(tokens)
res = " ".join(tokens).replace("<s>", "")
res = res.replace("</s>", "\n").replace("\n ", "\n").strip()
return res
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False