diff --git a/PaddleNLP/examples/dialogue/plato-2/README.md b/PaddleNLP/examples/dialogue/plato-2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8ab6d446eed49e22f05951ab58e82fdc0c930e42
--- /dev/null
+++ b/PaddleNLP/examples/dialogue/plato-2/README.md
@@ -0,0 +1,104 @@
+# PLATO-2
+
+## Model Introduction
+
+Building a high-quality open-domain chatbot that can converse with humans freely in natural language has long been one of the ultimate goals of natural language processing.
+
+To make it easy to build such a chatbot, this project implements the PLATO-2 inference model on Paddle 2.0 and provides a simple terminal-based human-machine interaction. By downloading a pre-trained model, you can quickly set up an open-domain chatbot.
+
+The network architecture and evaluation results of PLATO-2 are shown below:
+
+![image](./imgs/network.png)
+
+![image](./imgs/eval_en.png)
+
+![image](./imgs/eval_cn.png)
+
+For the PLATO-2 training procedure and other details, see [Knover](https://github.com/PaddlePaddle/Knover).
+
+## Quick Start
+
+### Installation
+
+* PaddlePaddle installation
+
+    This project requires PaddlePaddle 2.0 or later; see the [installation guide](http://www.paddlepaddle.org/#quick-start) for instructions.
+
+* PaddleNLP installation
+
+    ```shell
+    pip install "paddlenlp>=2.0.0b"
+    ```
+
+* Other dependencies
+
+    Python 3.6+ is required.
+
+    This project depends on sentencepiece and termcolor; install them before running:
+
+    ```shell
+    pip install sentencepiece termcolor
+    ```
+
+### Code Structure
+
+The main files of this project are organized as follows:
+
+```text
+.
+├── interaction.py           # entry point for interactive inference
+├── model.py                 # model definition
+├── readers
+│   ├── dialog_reader.py     # base reader that builds model inputs
+│   ├── nsp_reader.py        # builds inputs for the response-scoring (NSP) model
+│   └── plato_reader.py      # builds inputs for the generation model
+├── utils
+│   ├── __init__.py          # shared helper functions
+│   ├── args.py              # command-line argument configuration
+│   ├── masking.py           # masking-related functions
+│   └── tokenization.py      # tokenization-related functions
+├── imgs                     # images used in this README
+└── README.md                # this document
+```
+
+### Data Preparation
+
+You can download the pre-trained model files from the following locations:
+
+* PLATO-2, 24-layers, 16-heads, 1024-hidden, EN: [pre-trained model](https://paddlenlp.bj.bcebos.com/models/transformers/plato2/24L.pdparams)
+* PLATO-2, 32-layers, 32-heads, 2048-hidden, EN: [pre-trained model](https://paddlenlp.bj.bcebos.com/models/transformers/plato2/32L.pdparams)
+
+Taking the 24-layer pre-trained model as an example:
+
+```shell
+wget https://paddlenlp.bj.bcebos.com/models/transformers/plato2/24L.pdparams
+```
+
+**NOTE:** PLATO-2 has a large number of parameters: the 24-layer network needs at least 16 GB of GPU memory and the 32-layer network at least 22 GB. Choose the network depth and pre-trained model that fit your hardware.
+
+Download the pre-trained sentencepiece model and the vocabulary file:
+
+```shell
+wget https://paddlenlp.bj.bcebos.com/models/transformers/plato2/data.tar.gz
+tar -zxf data.tar.gz
+```
+
+### Human-Machine Interaction
+
+Run the following command to start a simple English conversation with the chatbot:
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python interaction.py --vocab_path ./data/vocab.txt --spm_model_file ./data/spm.model --num_layers 24 --init_from_ckpt ./24L.pdparams
+```
+
+The arguments above are:
+
+* vocab_path: path to the vocabulary file.
+* spm_model_file: path to the pre-trained sentencepiece model.
+* num_layers: number of PLATO-2 layers.
+* init_from_ckpt: path to the pre-trained PLATO-2 checkpoint.
+
+An interaction example with the 32-layer PLATO-2 network:
+
+![image](./imgs/case.jpg)
diff --git a/PaddleNLP/examples/dialogue/plato-2/imgs/case.jpg b/PaddleNLP/examples/dialogue/plato-2/imgs/case.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1917204e228c060de36eef0f69136e656f48a94b
Binary files /dev/null and b/PaddleNLP/examples/dialogue/plato-2/imgs/case.jpg differ
diff --git a/PaddleNLP/examples/dialogue/plato-2/imgs/eval_cn.png b/PaddleNLP/examples/dialogue/plato-2/imgs/eval_cn.png
new file mode 100644
index 0000000000000000000000000000000000000000..29dc9dd3e2cb3e2fb02fc7a3ad976a8198dea847
Binary files /dev/null and b/PaddleNLP/examples/dialogue/plato-2/imgs/eval_cn.png differ
diff --git a/PaddleNLP/examples/dialogue/plato-2/imgs/eval_en.png
b/PaddleNLP/examples/dialogue/plato-2/imgs/eval_en.png new file mode 100644 index 0000000000000000000000000000000000000000..41d405a7141bec6a492822288ffa6ff1c54df203 Binary files /dev/null and b/PaddleNLP/examples/dialogue/plato-2/imgs/eval_en.png differ diff --git a/PaddleNLP/examples/dialogue/plato-2/imgs/network.png b/PaddleNLP/examples/dialogue/plato-2/imgs/network.png new file mode 100644 index 0000000000000000000000000000000000000000..fb8750ae63c1d9685418b6f847e3bc96d500498b Binary files /dev/null and b/PaddleNLP/examples/dialogue/plato-2/imgs/network.png differ diff --git a/PaddleNLP/examples/dialogue/plato-2/interaction.py b/PaddleNLP/examples/dialogue/plato-2/interaction.py new file mode 100644 index 0000000000000000000000000000000000000000..26813dacd4d61bc2cc8c3490204404b71067710b --- /dev/null +++ b/PaddleNLP/examples/dialogue/plato-2/interaction.py @@ -0,0 +1,92 @@ +import json +import argparse +from collections import namedtuple +from termcolor import colored, cprint + +import paddle + +from utils.args import parse_args, str2bool +from utils import gen_inputs +from readers.nsp_reader import NSPReader +from readers.plato_reader import PlatoReader +from model import Plato2InferModel + + +def setup_args(): + """Setup arguments.""" + parser = argparse.ArgumentParser() + group = parser.add_argument_group("Model") + group.add_argument("--init_from_ckpt", type=str, default="") + group.add_argument("--vocab_size", type=int, default=8001) + group.add_argument("--latent_type_size", type=int, default=20) + group.add_argument("--num_layers", type=int, default=24) + + group = parser.add_argument_group("Task") + group.add_argument("--is_cn", type=str2bool, default=False) + + args, _ = parser.parse_known_args() + NSPReader.add_cmdline_args(parser) + + args = parse_args(parser) + args.batch_size *= args.latent_type_size + + #print(json.dumps(args, indent=2)) + return args + + +def load_params(model, init_from_ckpt): + state_dict = paddle.load(init_from_ckpt) + model.set_state_dict(state_dict) + + +def interact(args): + """Inference main function.""" + plato_reader = PlatoReader(args) + nsp_reader = NSPReader(args) + + if args.num_layers == 24: + n_head = 16 + hidden_size = 1024 + elif args.num_layers == 32: + n_head = 32 + hidden_size = 2048 + else: + raise ValueError('The pre-trained model only support 24 or 32 layers, ' + 'but received num_layers=%d.' % args.num_layers) + + model = Plato2InferModel(nsp_reader, args.num_layers, n_head, hidden_size) + load_params(model, args.init_from_ckpt) + model.eval() + + Example = namedtuple("Example", ["src", "data_id"]) + context = [] + start_info = "Enter [EXIT] to quit the interaction, [NEXT] to start a new conversation." 
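+    # The context list accumulates both user and bot utterances across turns;
+    # typing [NEXT] clears it to start a new conversation, [EXIT] quits.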
+    cprint(start_info, "yellow", attrs=["bold"])
+    while True:
+        user_utt = input(colored("[Human]: ", "red", attrs=["bold"])).strip()
+        if user_utt == "[EXIT]":
+            break
+        elif user_utt == "[NEXT]":
+            context = []
+            cprint(start_info, "yellow", attrs=["bold"])
+        else:
+            context.append(user_utt)
+            example = Example(src=" [SEP] ".join(context), data_id=0)
+            record = plato_reader._convert_example_to_record(
+                example, is_infer=True)
+            data = plato_reader._pad_batch_records([record], is_infer=True)
+            inputs = gen_inputs(data, args.latent_type_size)
+            pred = model(inputs)[0]
+            bot_response = pred["response"]
+            print(
+                colored(
+                    "[Bot]:", "blue", attrs=["bold"]),
+                colored(
+                    bot_response, attrs=["bold"]))
+            context.append(bot_response)
+    return
+
+
+if __name__ == "__main__":
+    args = setup_args()
+    interact(args)
diff --git a/PaddleNLP/examples/dialogue/plato-2/model.py b/PaddleNLP/examples/dialogue/plato-2/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8edf5cc5ed1bb4dfb326c4134ddcb69663ab3bb1
--- /dev/null
+++ b/PaddleNLP/examples/dialogue/plato-2/model.py
@@ -0,0 +1,478 @@
+from collections import namedtuple
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+def post_process_context(token_ids, reader, merge=True):
+    """Post-process the context sequence."""
+    context = []
+    utt = []
+    for tok_id in token_ids[1:]:
+        if tok_id == reader.eos_id:
+            utt = reader.tokenizer.convert_ids_to_tokens(utt)
+            if merge:
+                utt = reader.tokenizer.merge_subword(utt)
+            context.append(utt)
+            utt = []
+        else:
+            utt.append(tok_id)
+    return context
+
+
+def post_process_response(token_ids, reader, merge=True):
+    """
+    Post-process the decoded sequence. Truncate from the first <eos> and
+    remove the <bos> and <eos> tokens currently.
+    """
+    eos_pos = len(token_ids)
+    for i, tok_id in enumerate(token_ids):
+        if tok_id == reader.eos_id:
+            eos_pos = i
+            break
+    token_ids = token_ids[1:eos_pos]
+    response = reader.tokenizer.convert_ids_to_tokens(token_ids)
+    if merge:
+        response = reader.tokenizer.merge_subword(response)
+    return token_ids, response
+
+
+def get_cross_turn_repetition(context, pred_tokens, eos_idx, is_cn=False):
+    """Get cross-turn repetition."""
+    if len(pred_tokens) == 0:
+        return 1.0
+    if is_cn:
+        context = ["".join(utt) for utt in context]
+        pred_tokens = "".join(pred_tokens)
+
+    pred_tri_grams = set()
+    for i in range(len(pred_tokens) - 2):
+        tri_gram = tuple(pred_tokens[i:i + 3])
+        pred_tri_grams.add(tri_gram)
+    for utt in context:
+        for i in range(len(utt) - 2):
+            tri_gram = tuple(utt[i:i + 3])
+            if tri_gram in pred_tri_grams:
+                return 1.0
+    return 0.0
+
+
+def get_in_turn_repetition(pred, is_cn=False):
+    """Get in-turn repetition."""
+    if len(pred) == 0:
+        return 1.0
+    if isinstance(pred[0], str):
+        pred = [tok.lower() for tok in pred]
+        if is_cn:
+            pred = "".join(pred)
+    tri_grams = set()
+    for i in range(len(pred) - 2):
+        tri_gram = tuple(pred[i:i + 3])
+        if tri_gram in tri_grams:
+            return 1.0
+        tri_grams.add(tri_gram)
+    return 0.0
+
+
+class Plato2EncoderLayer(nn.Layer):
+    def __init__(self, n_head, hidden_size, attn_dropout, act_dropout):
+        super(Plato2EncoderLayer, self).__init__()
+
+        self.self_attn = nn.MultiHeadAttention(hidden_size, n_head,
+                                               attn_dropout)
+        self.pre_norm_layer = nn.LayerNorm(hidden_size)
+        self.post_norm_layer = nn.LayerNorm(hidden_size)
+        self.fc1 = nn.Linear(hidden_size, hidden_size * 4)
+        self.fc2 = nn.Linear(hidden_size * 4, hidden_size)
+
+        self.dropout_layer = nn.Dropout(act_dropout)
+        self.gelu_layer = nn.GELU()
+
+    def
forward(self, x, attn_mask, cache): + query = self.pre_norm_layer(x) + attn_output, new_cache = self.self_attn(query, None, None, attn_mask, + cache) + attn_output = self.dropout_layer(attn_output) + attn_output = attn_output + x + ffd_input = self.post_norm_layer(attn_output) + + ffd_output = self.fc1(ffd_input) + ffd_output = self.gelu_layer(ffd_output) + ffd_output = self.dropout_layer(ffd_output) + + ffd_output = self.fc2(ffd_output) + ffd_output = self.dropout_layer(ffd_output) + out = ffd_output + attn_output + + return out, new_cache + + def gen_cache(self, key): + return self.self_attn.gen_cache(key) + + +class Plato2Encoder(nn.Layer): + def __init__(self, vocab_size, type_size, max_position_seq_len, num_layers, + n_head, hidden_size, attn_dropout, act_dropout): + super(Plato2Encoder, self).__init__() + + self.n_head = n_head + + self.word_embedding_layer = nn.Embedding(vocab_size, hidden_size) + self.sent_embedding_layer = nn.Embedding(type_size, hidden_size) + self.pos_embedding_layer = nn.Embedding(max_position_seq_len, + hidden_size) + + self.encoder_layers = [] + for i in range(num_layers): + encoder_layer = Plato2EncoderLayer(n_head, hidden_size, + attn_dropout, act_dropout) + self.encoder_layers.append(encoder_layer) + self.add_sublayer('layers.' + str(i), encoder_layer) + self.post_encoder_layer_norm = nn.LayerNorm(hidden_size) + + self.dropout_layer = nn.Dropout(act_dropout) + + def forward(self, + caches, + token_ids, + type_ids, + pos_ids, + generation_mask, + aux_emb=None): + out, self_attn_mask = self.gen_input(token_ids, type_ids, pos_ids, + generation_mask, aux_emb) + + new_caches = [] + for i, encoder_layer in enumerate(self.encoder_layers): + out, new_cache = encoder_layer(out, self_attn_mask, caches[i]) + new_caches.append(new_cache) + + enc_output = self.post_encoder_layer_norm(out) + return enc_output, new_caches + + def gen_input(self, token_ids, type_ids, pos_ids, input_mask, aux_emb=None): + token_emb_out = self.word_embedding_layer(token_ids) + type_emb_out = self.sent_embedding_layer(type_ids) + pos_emb_out = self.pos_embedding_layer(pos_ids) + emb_out = token_emb_out + type_emb_out + pos_emb_out + + # auxiliary memory embeddings + if aux_emb is not None: + emb_out = paddle.concat([aux_emb, emb_out], axis=1) + + emb_out = self.dropout_layer(emb_out) + + # generate n-head self-attention mask + self_attn_mask = input_mask + self_attn_mask = paddle.scale( + x=self_attn_mask, scale=1e4, bias=-1.0, bias_after_scale=False) + n_head_self_attn_mask = paddle.stack( + x=[self_attn_mask] * self.n_head, axis=1) + n_head_self_attn_mask.stop_gradient = True + + return emb_out, n_head_self_attn_mask + + def gen_caches(self, key): + caches = [ + encoder_layer.gen_cache(key) + for encoder_layer in self.encoder_layers + ] + return caches + + +class NSP(nn.Layer): + def __init__(self, vocab_size, type_size, max_position_seq_len, num_layers, + n_head, hidden_size, attn_dropout, act_dropout): + super(NSP, self).__init__() + + self.n_head = n_head + self.hidden_size = hidden_size + + self.word_embedding_layer = nn.Embedding(vocab_size, hidden_size) + self.sent_embedding_layer = nn.Embedding(type_size, hidden_size) + self.pos_embedding_layer = nn.Embedding(max_position_seq_len, + hidden_size) + + encoder_layer = nn.TransformerEncoderLayer( + hidden_size, n_head, hidden_size * 4, act_dropout, 'gelu', + attn_dropout, act_dropout, 'True') + encoder_norm = nn.LayerNorm(hidden_size) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers, + encoder_norm) + self.fc1 = 
nn.Linear(hidden_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, 2) + + self.dropout_layer = nn.Dropout(act_dropout) + self.tanh_layer = nn.Tanh() + self.softmax = nn.Softmax() + + def forward(self, inputs): + token_ids = inputs['token_ids'] + type_ids = inputs['type_ids'] + pos_ids = inputs['pos_ids'] + attention_mask = inputs['attention_mask'] + label_pos = inputs["label_pos"] + + out, self_attn_mask = self.gen_input(token_ids, type_ids, pos_ids, + attention_mask) + # [-1, seq_len, hidden_size] + enc_out = self.encoder(out, self_attn_mask) + + enc_out = paddle.reshape(enc_out, [-1, self.hidden_size]) + label_pos = paddle.cast(label_pos, 'int64') + out = paddle.gather(enc_out, label_pos) + pooled_out = self.fc1(out) + pooled_out = self.tanh_layer(pooled_out) + + # [-1, 2] + logits = self.fc2(pooled_out) + probs = self.softmax(logits) + + return probs + + def gen_input(self, token_ids, type_ids, pos_ids, input_mask, aux_emb=None): + token_emb_out = self.word_embedding_layer(token_ids) + type_emb_out = self.sent_embedding_layer(type_ids) + pos_emb_out = self.pos_embedding_layer(pos_ids) + emb_out = token_emb_out + type_emb_out + pos_emb_out + + # auxiliary memory embeddings + if aux_emb is not None: + emb_out = paddle.concat([aux_emb, emb_out], axis=1) + + emb_out = self.dropout_layer(emb_out) + + # generate n-head self-attention mask + self_attn_mask = input_mask + self_attn_mask = paddle.scale( + x=self_attn_mask, scale=1e4, bias=-1.0, bias_after_scale=False) + n_head_self_attn_mask = paddle.stack( + x=[self_attn_mask] * self.n_head, axis=1) + n_head_self_attn_mask.stop_gradient = True + + return emb_out, n_head_self_attn_mask + + +class Plato2InferModel(nn.Layer): + def __init__(self, + nsp_reader, + num_layers, + n_head, + hidden_size, + vocab_size=8001, + type_size=2, + latent_type_size=20, + max_position_seq_len=256, + act_dropout=0.1, + attn_dropout=0.1, + max_dec_len=64, + min_dec_len=1, + topk=10): + super(Plato2InferModel, self).__init__() + + self.nsp_reader = nsp_reader + self.num_layers = num_layers + self.latent_type_size = latent_type_size + self.max_dec_len = max_dec_len + self.min_dec_len = min_dec_len + self.topk = topk + self.unk_id = 0 + self.bos_id = 1 + self.eos_id = 2 + self.mask_id = 8000 + self.after_eos = paddle.ones([vocab_size]) * -1e9 + self.after_eos[self.eos_id] = 0 + self.is_cn = False + self.batch_size = 1 + + self.latent_weight = paddle.create_parameter( + [hidden_size, latent_type_size], 'float32') + + self.plato2_encoder = Plato2Encoder( + vocab_size, type_size, max_position_seq_len, num_layers, n_head, + hidden_size, attn_dropout, act_dropout) + + self.logits_fc_layer = nn.Linear(hidden_size, hidden_size) + self.logits_layer_norm = nn.LayerNorm(hidden_size) + self.logits_bias = paddle.create_parameter( + [vocab_size], 'float32', is_bias=True) + + self.nsp_predictor = NSP(vocab_size, type_size, max_position_seq_len, + num_layers, n_head, hidden_size, attn_dropout, + act_dropout) + + self.gelu_layer = nn.GELU() + self.softmax = nn.Softmax() + + @paddle.no_grad() + def forward(self, inputs): + token_ids = inputs['token_ids'] + type_ids = inputs['type_ids'] + pos_ids = inputs['pos_ids'] + generation_mask = inputs['generation_mask'] + latent_id = inputs['latent_id'] + data_id = inputs['data_id'] + + # [-1, 1, latent_type_size] + latent_id = F.one_hot(latent_id, self.latent_type_size) + # [-1, 1, hidden_size] + latent_emb = paddle.matmul( + latent_id, self.latent_weight, transpose_y=True) + + caches = self.plato2_encoder.gen_caches(token_ids) + + # [-1, 
seq_len + 1, hidden_size] + enc_out, new_caches = self.plato2_encoder( + caches, token_ids, type_ids, pos_ids, generation_mask, latent_emb) + + pred_ids = self.decoder(inputs, new_caches) + + nsp_inputs = self.gen_nsp_input(token_ids, pred_ids) + # [-1, 2] + probs = self.nsp_predictor(nsp_inputs) + + return self.get_results(data_id, token_ids, pred_ids, probs) + + def decoder(self, inputs, caches): + tgt_ids = inputs['tgt_ids'] + tgt_pos = inputs['tgt_pos'] + tgt_generation_mask = inputs['tgt_generation_mask'] + predictions = tgt_ids + + # TODO + step = 0 + while step < self.max_dec_len: + # [-1, 1] + append_mask = paddle.cast( + tgt_ids != self.eos_id, dtype=tgt_generation_mask.dtype) + tgt_generation_mask = paddle.concat( + [tgt_generation_mask, paddle.unsqueeze(append_mask, 1)], + axis=-1) + tgt_sent = paddle.ones( + [tgt_generation_mask.shape[0], 1], dtype=tgt_ids.dtype) + + # [-1, 1, hidden_size] + out, caches = self.plato2_encoder(caches, tgt_ids, tgt_sent, + tgt_pos, tgt_generation_mask) + out = paddle.squeeze(out, axis=1) + + # [-1, hidden_size] + trans = self.logits_fc_layer(out) + trans = self.gelu_layer(trans) + trans = self.logits_layer_norm(trans) + + # [-1, vocab_size] + logits = paddle.matmul( + trans, + self.plato2_encoder.word_embedding_layer.weight, + transpose_y=True) + self.logits_bias + logits[:, self.unk_id] = -1e9 + logits[:, self.bos_id] = -1e9 + logits[:, self.mask_id] = -1e9 + if step < self.min_dec_len: + logits[:, self.eos_id] = -1e9 + logits = logits * append_mask + (1 - append_mask) * self.after_eos + probs = self.softmax(logits) + + # [-1, topk] + topk_probs, _ = paddle.topk(probs, k=self.topk) + mask = paddle.cast(probs >= topk_probs[:, -1:], 'float32') + sums = paddle.sum(topk_probs, axis=-1, keepdim=True) + new_probs = probs * mask / sums + # [-1, 1] + sampling_ids = paddle.multinomial(new_probs) + + step = step + 1 + tgt_ids = sampling_ids + tgt_pos = tgt_pos + 1 + predictions = paddle.concat([predictions, tgt_ids], axis=1) + return predictions + + def gen_nsp_input(self, token_ids, pred_ids): + token_ids = token_ids.numpy() + pred_ids = pred_ids.numpy() + + def __reader__(): + headers = ["src", "tgt", "data_id"] + + Example = namedtuple("Example", headers) + + for i, (raw, pred) in enumerate(zip(token_ids, pred_ids)): + context = post_process_context( + raw, self.nsp_reader, merge=False) + _, response = post_process_response( + pred, self.nsp_reader, merge=False) + context_tokenized_input = " [SEP] ".join(" ".join(utt) + for utt in context) + response_tokenized_input = " ".join(response) + example = Example( + src=context_tokenized_input, + tgt=response_tokenized_input, + data_id=i) + data = self.nsp_reader._convert_example_to_record( + example, is_infer=True) + yield data + return + + generator = self.nsp_reader.data_generator( + reader=__reader__, + is_infer=True, + phase="test", ) + inputs = next(generator()) + + #print('\nnsp_inputs:') + for key in inputs: + inputs[key] = paddle.to_tensor(inputs[key]) + if key in ['token_ids', 'type_ids', 'pos_ids']: + inputs[key] = paddle.squeeze(inputs[key], axis=-1) + #print(key, inputs[key].shape) + #print(inputs[key]) + return inputs + + def get_results(self, data_id, token_ids, pred_ids, probs): + data_id = data_id.numpy() + token_ids = token_ids.numpy() + pred_ids = pred_ids.numpy() + probs = probs.numpy() + + infos = [] + for raw, pred, prob in zip(token_ids, pred_ids, probs): + tokens = post_process_context(raw, self.nsp_reader) + pred_token_ids, pred_tokens = post_process_response(pred, + self.nsp_reader) + 
info = {}
+            info['response'] = ' '.join(pred_tokens)
+            cross_turn_repetition = get_cross_turn_repetition(
+                tokens, pred_tokens, self.nsp_reader.eos_id, self.is_cn)
+            in_turn_repetition = max(
+                get_in_turn_repetition(pred_tokens, self.is_cn),
+                get_in_turn_repetition(pred_token_ids))
+
+            info['score'] = float(prob[1])
+            if len(pred_token_ids) >= self.max_dec_len:
+                info['score'] -= 1e3
+            elif cross_turn_repetition > 0:
+                info['score'] -= 1e3
+            elif in_turn_repetition > 0:
+                info['score'] -= 1e3
+            infos.append(info)
+
+        results = []
+        pre_idx = 0
+        sample = []
+        for idx, info in zip(data_id, infos):
+            if idx != pre_idx:
+                sample = sorted(sample, key=lambda info: -info["score"])
+                result = sample[0]
+                result['data_id'] = pre_idx
+                results.append(result)
+                sample = []
+                pre_idx = idx
+            sample.append(info)
+        if sample:
+            sample = sorted(sample, key=lambda info: -info["score"])
+            result = sample[0]
+            result['data_id'] = pre_idx
+            results.append(result)
+        return results
diff --git a/PaddleNLP/examples/dialogue/plato-2/readers/dialog_reader.py b/PaddleNLP/examples/dialogue/plato-2/readers/dialog_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccf347acf3efcbf4932d3db8ae31f45e2fb7f005
--- /dev/null
+++ b/PaddleNLP/examples/dialogue/plato-2/readers/dialog_reader.py
@@ -0,0 +1,491 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
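+# DialogReader is the shared base reader: PlatoReader (generation inputs) and
+# NSPReader (response-scoring inputs) both subclass it.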
+"""Dialogue Reader.""" + +import csv +from collections import namedtuple +from contextlib import contextmanager +import gzip + +import numpy as np + +from utils import pad_batch_data +from utils.args import str2bool +from utils.masking import mask +import utils.tokenization as tokenization + + +class DialogReader(object): + """The implement of DialogReader.""" + + @classmethod + def add_cmdline_args(cls, parser): + """Add cmdline argurments.""" + group = parser.add_argument_group("Reader") + group.add_argument("--max_src_len", type=int, default=128) + group.add_argument("--max_tgt_len", type=int, default=128) + group.add_argument( + "--truncate_first_turn", type=str2bool, default=False) + group.add_argument( + "--file_format", + type=str, + default="file", + choices=["file", "filelist"]) + group.add_argument( + "--data_format", + type=str, + default="raw", + choices=["raw", "tokenized", "numerical"]) + group.add_argument("--in_tokens", type=str2bool, default=False) + group.add_argument("--batch_size", type=int, default=16) + group.add_argument("--continuous_position", type=str2bool, default=True) + group.add_argument("--random_seed", type=int, default=11) + group.add_argument("--sort_pool_size", type=int, default=2**16) + + group = parser.add_argument_group("Tokenizer") + group.add_argument( + "--tokenizer", type=str, default="SentencePieceTokenizer") + args, _ = parser.parse_known_args() + tokenizer_cls = getattr(tokenization, args.tokenizer) + tokenizer_cls.add_cmdline_args(parser) + return group + + def __init__(self, args): + tokenizer_cls = getattr(tokenization, args.tokenizer) + self.tokenizer = tokenizer_cls(args) + self.vocab = self.tokenizer.vocab + self.pad_id = args.pad_id = self.vocab["[PAD]"] + self.bos_id = args.bos_id = self.vocab["[CLS]"] + self.eos_id = args.eos_id = self.vocab["[SEP]"] + self.unk_id = args.unk_id = self.vocab["[UNK]"] + self.mask_id = args.mask_id = self.vocab["[MASK]"] + self.vocab_size = args.get("vocab_size", 0) + self.max_src_len = args.max_src_len + self.max_tgt_len = args.max_tgt_len + self.truncate_first_turn = args.truncate_first_turn + self.file_format = args.file_format + self.data_format = args.data_format + self.in_tokens = args.in_tokens + self.batch_size = args.batch_size + self.continuous_position = args.continuous_position + self.sort_pool_size = args.sort_pool_size + + # random_seed must be set for data slicing when using multi-gpu + self.global_rng = np.random.RandomState(args.random_seed) + + # training progress + self.current_example = 0 + self.current_epoch = 0 + self.num_examples = 0 + + # model related + + self.fields = ["token_ids", "type_ids", "pos_ids"] + self.num_numerical_fields = len(self.fields) + self.fields += ["tgt_start_idx", "data_id"] + self.sort_key = lambda record: [len(record.token_ids)] + + self.Record = namedtuple( + "Record", self.fields, defaults=(None, ) * len(self.fields)) + + self.features = {} + return + + def get_train_progress(self): + """Gets progress for training phase.""" + return self.current_epoch, self.current_file_index, self.total_file + + def _convert_example_to_record(self, example, is_infer): + # process src + src_token_ids = [] + src_pos_ids = [] + + # tokenize src + s_token_ids_list = [] + for s in example.src.split("[SEP]"): + s = tokenization.convert_to_unicode(s).strip() + + if self.data_format == "tokenized": + s_tokens = s.split(" ") + else: + s_tokens = self.tokenizer.tokenize(s) + + s_token_ids = self.tokenizer.convert_tokens_to_ids( + s_tokens) + [self.eos_id] + 
s_token_ids_list.append(s_token_ids) + + # trim src + idx = len(s_token_ids_list) - 1 + total_token_num = 1 + while idx >= 0: + total_token_num += len(s_token_ids_list[idx]) + if total_token_num > self.max_src_len: + if self.truncate_first_turn and idx == 0: + truncated_ids = s_token_ids_list[idx][:self.max_src_len - + total_token_num] + if len(truncated_ids) > 1: + s_token_ids_list[ + idx] = truncated_ids[:-1] + [self.eos_id] + idx -= 1 + break + idx -= 1 + + for i, s_token_ids in enumerate(s_token_ids_list[idx + 1:], idx + 1): + src_token_ids += s_token_ids + src_pos_ids += list(range(1, len(s_token_ids) + 1)) + + src_token_ids = [self.bos_id] + src_token_ids + src_type_ids = [0] * len(src_token_ids) + src_pos_ids = [0] + src_pos_ids + assert len(src_token_ids) == len(src_type_ids) == len(src_pos_ids), \ + "not len(src_token_ids) == len(src_type_ids) == len(src_pos_ids)" + + token_ids = src_token_ids + type_ids = src_type_ids + pos_ids = src_pos_ids + tgt_start_idx = len(token_ids) + + if not is_infer: + # process tgt + # tokenize tgt + tgt = tokenization.convert_to_unicode(example.tgt).strip() + if self.data_format == "tokenized": + tgt_tokens = tgt.split(" ") + else: + tgt_tokens = self.tokenizer.tokenize(tgt) + + tgt_token_ids = self.tokenizer.convert_tokens_to_ids(tgt_tokens) + tgt_token_ids.append(self.eos_id) + + # trim tgt + if len(tgt_token_ids) > self.max_tgt_len - 1: + tgt_token_ids = tgt_token_ids[:self.max_tgt_len - 1] + + tgt_token_ids = [self.bos_id] + tgt_token_ids + tgt_type_ids = [1] * len(tgt_token_ids) + tgt_pos_ids = list(range(1, len(tgt_token_ids) + 1)) + assert len(tgt_token_ids) == len(tgt_type_ids) == len(tgt_pos_ids), \ + "not len(tgt_token_ids) == len(tgt_type_ids) == len(tgt_pos_ids)" + + token_ids += tgt_token_ids + type_ids += tgt_type_ids + pos_ids += tgt_pos_ids + + assert len(token_ids) == len(type_ids) == len(pos_ids), \ + "not len(token_ids) == len(type_ids) == len(pos_ids)" + + if self.continuous_position: + src_pos_ids = list(range(len(src_token_ids))) + if not is_infer: + tgt_pos_ids = list(range(len(tgt_token_ids))) + pos_ids = list(range(len(token_ids))) + + field_values = { + "token_ids": src_token_ids, + "type_ids": src_type_ids, + "pos_ids": src_pos_ids + } + field_values["tgt_start_idx"] = tgt_start_idx + field_values["data_id"] = example.data_id + + record = self.Record(**field_values) + return record + + def _read_tsv(self, fp, phase, is_infer, delimiter="\t", quotechar=None): + """Reads a tab separated value file.""" + csv.field_size_limit(2**20) + reader = csv.reader(fp, delimiter=delimiter, quotechar=quotechar) + headers = next(reader) + headers.append("data_id") + Example = namedtuple("Example", headers) + + for i, line in enumerate(reader): + example = Example(*line, data_id=i) + if is_infer or phase.endswith("test"): + self.features[phase][i] = example + record = self._convert_example_to_record(example, is_infer) + yield record + + def _read_numerical_file(self, fp, delimiter=";"): + for i, line in enumerate(fp): + cols = tokenization.convert_to_unicode(line).strip().split( + delimiter) + cols = list(map(lambda x: list(map(int, x.split(" "))), cols)) + if len(cols) > self.num_numerical_fields: + cols = cols[:self.num_numerical_fields] + tgt_start_idx = cols[0].index(self.bos_id, 1) + record = self.Record(*cols, tgt_start_idx=tgt_start_idx, data_id=i) + yield record + + def _read_file(self, input_file, phase, is_infer): + def __wrapper__(): + with open_file(input_file) as fp: + if self.data_format == "numerical": + records = 
self._read_numerical_file(fp) + else: + records = self._read_tsv(fp, phase, is_infer) + for record in records: + yield record + + return __wrapper__ + + def _read_files(self, filelist, phase, is_infer, shuffle_files): + input_files = open(filelist).readlines() + + def __wrapper__(): + if shuffle_files: + self.global_rng.shuffle(input_files) + + if phase == "train": + self.total_file = len(input_files) + for file_index, input_file in enumerate(input_files, 1): + if phase == "train": + self.current_file_index = file_index + self.current_file = input_file + file_reader = self._read_file(input_file.strip(), phase, + is_infer) + for record in file_reader(): + yield record + + return __wrapper__ + + def _batch_reader(self, + reader, + phase=None, + is_infer=False, + sort_pool_size=2**16): + """Construct a batch reader.""" + + def update_max_lens(max_lens, record): + """Update max_lens.""" + if max_lens is None: + return self.sort_key(record) + else: + return [ + max(max_len, l) for max_len, l in zip(max_lens, + self.sort_key(record)) + ] + + def get_batch(reader): + """Generate batches from reader.""" + batch, max_lens = [], None + for record in reader(): + if record is None: + yield batch + batch, max_lens = [], None + continue + + self.current_example += 1 + max_lens = update_max_lens(max_lens, record) + if self.in_tokens: + to_append = (len(batch) + 1 + ) * sum(max_lens) <= self.batch_size + else: + to_append = len(batch) < self.batch_size + if to_append: + batch.append(record) + else: + yield batch + batch, max_lens = [record], self.sort_key(record) + + if len(batch) > 0: + yield batch + + def get_sorted_batch(pool): + """Generate sorted batches from pool.""" + pool = sorted(pool, key=self.sort_key) + batches = [] + batch, max_lens = [], None + for record in pool: + self.current_example += 1 + max_lens = update_max_lens(max_lens, record) + if self.in_tokens: + to_append = (len(batch) + 1 + ) * sum(max_lens) <= self.batch_size + else: + to_append = len(batch) < self.batch_size + if to_append: + batch.append(record) + else: + batches.append(batch) + batch, max_lens = [record], self.sort_key(record) + + if len(batch) > 0: + batches.append(batch) + self.global_rng.shuffle(batches) + + for batch in batches: + yield batch + + def __wrapper__(): + if sort_pool_size > 0: + pool = [] + for record in reader(): + pool.append(record) + if len(pool) == sort_pool_size: + for batch in get_sorted_batch(pool): + yield batch + pool = [] + if len(pool) > 0: + for batch in get_sorted_batch(pool): + yield batch + else: + for batch in get_batch(reader): + yield batch + + return __wrapper__ + + def _distributed_batch_reader(self, + batch_reader, + num_part, + part_id, + is_test=False): + def __wrapper__(): + batches = [] + for batch in batch_reader(): + batches.append(batch) + if len(batches) == num_part: + yield batches[part_id] + batches = [] + if is_test and 0 <= part_id < len(batches): + yield batches[part_id] + return + + return __wrapper__ + + def data_generator(self, + input_file=None, + reader=None, + num_epochs=1, + num_part=1, + part_id=0, + phase=None, + is_infer=False): + """Data generator.""" + + def __wrapper__(): + if is_infer or phase.endswith("test"): + self.features[phase] = {} + + nonlocal reader + if reader is None: + if self.file_format == "filelist": + reader = self._read_files(input_file, phase, is_infer, + not phase.endswith("test")) + else: + if phase == "train": + self.total_file = 1 + self.current_file_index = 1 + self.current_file = input_file + reader = self._read_file(input_file, 
phase, is_infer) + + batch_reader = self._batch_reader( + reader, + phase, + is_infer, + sort_pool_size=self.sort_pool_size if not is_infer else 0) + if phase == "train": + batch_reader = self._distributed_batch_reader(batch_reader, + num_part, part_id) + elif phase.startswith("distributed"): + batch_reader = self._distributed_batch_reader( + batch_reader, num_part, part_id, is_test=True) + + for epoch_index in range(num_epochs): + if phase == "train": + self.current_example = 0 + self.current_epoch = epoch_index + 1 + for batch in batch_reader(): + yield self._pad_batch_records(batch, is_infer) + + return __wrapper__ + + def _gen_self_attn_mask(self, + batch_token_ids, + batch_tgt_start_idx=None, + is_unidirectional=True, + shift_len=0): + max_len = max(map(len, batch_token_ids)) + input_mask_data = np.zeros( + (len(batch_token_ids), max_len + shift_len, max_len + shift_len)) + if is_unidirectional: + for index, mask_data in enumerate(input_mask_data): + start = 0 if batch_tgt_start_idx is None else batch_tgt_start_idx[ + index] + end = len(batch_token_ids[index]) + mask_data[:end + shift_len, :start + shift_len] = 1.0 + # Generate the lower triangular matrix using the slice of matrix + b = np.tril(np.ones([end - start, end - start]), 0) + mask_data[start + shift_len:end + shift_len, start + shift_len: + end + shift_len] = b + else: + for index, token_ids in enumerate(batch_token_ids): + input_mask_data[index, :len(token_ids) + shift_len, :len( + token_ids) + shift_len] = 1.0 + return input_mask_data.astype("float32") + + def _pad_batch_records(self, batch_records, is_infer): + """ + Padding batch records and construct model's inputs. + """ + batch_size = len(batch_records) + batch = {} + batch_token_ids = [record.token_ids for record in batch_records] + batch_type_ids = [record.type_ids for record in batch_records] + batch_pos_ids = [record.pos_ids for record in batch_records] + batch["token_ids"] = pad_batch_data(batch_token_ids, pad_id=self.pad_id) + batch["type_ids"] = pad_batch_data(batch_type_ids, pad_id=self.pad_id) + batch["pos_ids"] = pad_batch_data(batch_pos_ids, pad_id=self.pad_id) + + batch_tgt_start_idx = [record.tgt_start_idx for record in batch_records] + batch["generation_mask"] = self._gen_self_attn_mask( + batch_token_ids, batch_tgt_start_idx=batch_tgt_start_idx) + + if is_infer: + tgt_ids = np.array( + [[[self.bos_id]]] * len(batch_token_ids), dtype="int64") + if self.continuous_position: + tgt_pos = np.array(batch_tgt_start_idx, dtype="int64") + else: + tgt_pos = np.zeros_like(batch_tgt_start_idx, dtype="int64") + tgt_pos = tgt_pos.reshape(-1, 1, 1) + batch["init_score"] = np.zeros_like( + tgt_ids, dtype="float32").reshape(-1, 1).tolist() + batch["tgt_ids"] = tgt_ids.tolist() + batch["tgt_pos"] = tgt_pos.tolist() + + batch["tgt_generation_mask"] = batch[ + "generation_mask"][:, 0:1, :].astype("float32") + else: + batch["tgt_label"], batch["tgt_pos"] = mask( + batch_tokens=batch_token_ids, + vocab_size=self.vocab_size, + sent_b_starts=batch_tgt_start_idx, + is_unidirectional=True) + + batch_data_id = [record.data_id for record in batch_records] + batch["data_id"] = np.array(batch_data_id).astype("int64").reshape( + [-1, 1]) + return batch + + +@contextmanager +def open_file(filename): + """Open file.""" + if filename.endswith(".gz"): + fp = gzip.open(filename, "rt") + else: + fp = open(filename) + yield fp + fp.close() diff --git a/PaddleNLP/examples/dialogue/plato-2/readers/nsp_reader.py b/PaddleNLP/examples/dialogue/plato-2/readers/nsp_reader.py new file mode 100644 
index 0000000000000000000000000000000000000000..ee3767d5bc42f7db9a59f0f4989e0ad315033d7f --- /dev/null +++ b/PaddleNLP/examples/dialogue/plato-2/readers/nsp_reader.py @@ -0,0 +1,172 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NSP Reader.""" + +from collections import namedtuple + +import numpy as np + +from readers.dialog_reader import DialogReader +from utils import pad_batch_data +from utils.args import str2bool +from utils.masking import mask + + +class NSPReader(DialogReader): + """NSP Reader.""" + + @classmethod + def add_cmdline_args(cls, parser): + """Add cmdline argurments.""" + group = DialogReader.add_cmdline_args(parser) + group.add_argument( + "--attention_style", + type=str, + default="bidirectional", + choices=["bidirectional", "unidirectional"]) + group.add_argument( + "--mix_negative_sample", type=str2bool, default=False) + return group + + def __init__(self, args): + super(NSPReader, self).__init__(args) + self.fields.append("label") + self.Record = namedtuple( + "Record", self.fields, defaults=(None, ) * len(self.fields)) + + self.attention_style = args.attention_style + self.mix_negative_sample = args.mix_negative_sample + return + + def _convert_example_to_record(self, example, is_infer): + record = super(NSPReader, self)._convert_example_to_record(example, + False) + if "label" in example._fields: + record = record._replace(label=int(example.label)) + return record + + def _mix_negative_sample(self, reader, neg_pool_size=2**16): + def gen_from_pool(pool): + num_samples = len(pool) + if num_samples == 1: + # only one sample: it is impossible to generate negative sample + yield pool[0]._replace(label=1) + return + self.global_rng.shuffle(pool) + for i in range(num_samples): + pool[i] = pool[i]._replace(label=1) + j = (i + 1) % num_samples + idx_i = pool[i].tgt_start_idx + idx_j = pool[j].tgt_start_idx + field_values = {} + field_values["token_ids"] = pool[i].token_ids[:idx_i] + pool[ + j].token_ids[idx_j:] + field_values["type_ids"] = pool[i].type_ids[:idx_i] + pool[ + j].type_ids[idx_j:] + field_values["pos_ids"] = list( + range(len(field_values["token_ids"]))) + neg_record = self.Record( + **field_values, tgt_start_idx=idx_i, data_id=-1, label=0) + pool.append(neg_record) + assert len(neg_record.token_ids) <= self.max_seq_len + self.global_rng.shuffle(pool) + for record in pool: + yield record + + def __wrapper__(): + pool = [] + for record in reader(): + pool.append(record) + if len(pool) == neg_pool_size: + for record in gen_from_pool(pool): + yield record + pool = [] + if len(pool) > 0: + for record in gen_from_pool(pool): + yield record + + return __wrapper__ + + def _batch_reader(self, + reader, + phase=None, + is_infer=False, + sort_pool_size=2**16): + if self.mix_negative_sample: + reader = self._mix_negative_sample(reader) + return super(NSPReader, self)._batch_reader( + reader, + phase=phase, + is_infer=is_infer, + sort_pool_size=sort_pool_size) + + def _pad_batch_records(self, 
batch_records, is_infer): + """ + Padding batch records and construct model's inputs. + """ + batch = {} + batch_token_ids = [record.token_ids for record in batch_records] + batch_type_ids = [record.type_ids for record in batch_records] + batch_pos_ids = [record.pos_ids for record in batch_records] + batch_tgt_start_idx = [record.tgt_start_idx for record in batch_records] + batch_label = [record.label for record in batch_records] + + if self.attention_style == "unidirectional": + batch["token_ids"] = pad_batch_data( + batch_token_ids, pad_id=self.pad_id) + batch["type_ids"] = pad_batch_data( + batch_type_ids, pad_id=self.pad_id) + batch["pos_ids"] = pad_batch_data(batch_pos_ids, pad_id=self.pad_id) + tgt_label, tgt_pos, label_pos = mask( + batch_tokens=batch_token_ids, + vocab_size=self.vocab_size, + bos_id=self.bos_id, + sent_b_starts=batch_tgt_start_idx, + labels=batch_label, + is_unidirectional=True) + attention_mask = self._gen_self_attn_mask(batch_token_ids, + batch_tgt_start_idx) + else: + batch_mask_token_ids, tgt_label, tgt_pos, label_pos = mask( + batch_tokens=batch_token_ids, + vocab_size=self.vocab_size, + bos_id=self.bos_id, + eos_id=self.eos_id, + mask_id=self.mask_id, + sent_b_starts=batch_tgt_start_idx, + labels=batch_label, + is_unidirectional=False) + if not is_infer: + batch_token_ids = batch_mask_token_ids + batch["token_ids"] = pad_batch_data( + batch_token_ids, pad_id=self.pad_id) + batch["type_ids"] = pad_batch_data( + batch_type_ids, pad_id=self.pad_id) + batch["pos_ids"] = pad_batch_data(batch_pos_ids, pad_id=self.pad_id) + attention_mask = self._gen_self_attn_mask( + batch_token_ids, is_unidirectional=False) + + batch["attention_mask"] = attention_mask + batch["label_pos"] = label_pos + + if not is_infer: + batch_label = np.array(batch_label).astype("int64").reshape([-1, 1]) + batch["label"] = batch_label + batch["tgt_label"] = tgt_label + batch["tgt_pos"] = tgt_pos + + batch_data_id = [record.data_id for record in batch_records] + batch["data_id"] = np.array(batch_data_id).astype("int64").reshape( + [-1, 1]) + return batch diff --git a/PaddleNLP/examples/dialogue/plato-2/readers/plato_reader.py b/PaddleNLP/examples/dialogue/plato-2/readers/plato_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..5e7dc4c107dc9b6b9e28402ecc5b14b662447a53 --- /dev/null +++ b/PaddleNLP/examples/dialogue/plato-2/readers/plato_reader.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Plato Reader.""" + +import numpy as np + +from readers.dialog_reader import DialogReader +from utils import pad_batch_data +from utils.masking import mask + + +class PlatoReader(DialogReader): + """The implement of PlatoReader""" + + def __init__(self, args): + super(PlatoReader, self).__init__(args) + self.latent_type_size = args.latent_type_size + self.use_bow = args.use_bow + + def _pad_batch_records(self, batch_records, is_infer): + """ + Padding batch records and construct model's inputs. 
+ """ + batch = {} + batch_token_ids = [record.token_ids for record in batch_records] + batch_type_ids = [record.type_ids for record in batch_records] + batch_pos_ids = [record.pos_ids for record in batch_records] + + batch_tgt_start_idx = [record.tgt_start_idx for record in batch_records] + + batch_size = len(batch_token_ids) + + # padding + batch["token_ids"] = pad_batch_data(batch_token_ids, pad_id=self.pad_id) + batch["type_ids"] = pad_batch_data(batch_type_ids, pad_id=self.pad_id) + batch["pos_ids"] = pad_batch_data(batch_pos_ids, pad_id=self.pad_id) + + batch["generation_mask"] = self._gen_self_attn_mask( + batch_token_ids, + batch_tgt_start_idx=batch_tgt_start_idx, + is_unidirectional=True, + shift_len=1) + if not is_infer: + batch["recognition_mask"] = self._gen_self_attn_mask( + batch_token_ids, is_unidirectional=False, shift_len=1) + + if is_infer: + tgt_ids = np.array([[[self.bos_id]]] * batch_size, dtype="int64") + if self.continuous_position: + tgt_pos = np.array(batch_tgt_start_idx, dtype="int64") + else: + tgt_pos = np.zeros_like(batch_tgt_start_idx, dtype="int64") + tgt_pos = tgt_pos.reshape(-1, 1, 1) + batch["init_score"] = np.zeros_like( + tgt_ids, dtype="float32").reshape(-1, 1).tolist() + batch["tgt_ids"] = tgt_ids.tolist() + batch["tgt_pos"] = tgt_pos.tolist() + batch["parent_idx"] = np.array(range(batch_size), dtype="int32") + + batch["tgt_generation_mask"] = batch[ + "generation_mask"][:, 0:1, :].astype("float32") + else: + mask_return_list = mask( + batch_tokens=batch_token_ids, + vocab_size=self.vocab_size, + sent_b_starts=batch_tgt_start_idx, + is_unidirectional=True, + use_latent=True, + use_bow=self.use_bow) + batch["tgt_label"] = mask_return_list[0] + batch["tgt_pos"] = mask_return_list[1] + if self.use_bow: + batch["bow_label"] = mask_return_list[2] + batch["bow_pos"] = mask_return_list[3] + + batch_data_id = [record.data_id for record in batch_records] + batch["data_id"] = np.array(batch_data_id).astype("int64").reshape( + [-1, 1]) + return batch diff --git a/PaddleNLP/examples/dialogue/plato-2/utils/__init__.py b/PaddleNLP/examples/dialogue/plato-2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8865dcbbdef2375bdbb6ddf4b7a5c19e73f8030b --- /dev/null +++ b/PaddleNLP/examples/dialogue/plato-2/utils/__init__.py @@ -0,0 +1,58 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Utils.""" + +from itertools import chain +import numpy as np +import paddle + + +def repeat_array(array, times): + """Repeate numpy array.""" + if isinstance(array, list): + return list(chain(*([array] * times))) + else: + return np.concatenate([array] * times, axis=0) + + +def gen_inputs(inputs, latent_type_size): + batch_size = len(inputs["data_id"]) + new_bsz = batch_size * latent_type_size + inputs = { + name: repeat_array(array, latent_type_size) + for name, array in inputs.items() + } + # Add latent_id + inputs["latent_id"] = np.array( + [i for i in range(latent_type_size) for _ in range(batch_size)], + dtype="int64").reshape([-1, 1]) + + #print('\nplato_inputs:') + for key in inputs: + inputs[key] = paddle.to_tensor(inputs[key]) + if key in [ + 'token_ids', 'type_ids', 'pos_ids', 'tgt_ids', 'tgt_pos', + 'data_id' + ]: + inputs[key] = paddle.squeeze(inputs[key], axis=-1) + #print(key, inputs[key].shape, inputs[key].dtype) + return inputs + + +def pad_batch_data(insts, pad_id=0): + """Pad the instances to the max sequence length in batch. """ + max_len = max(map(len, insts)) + inst_data = np.array( + [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]) + return inst_data.astype("int64").reshape([-1, max_len, 1]) diff --git a/PaddleNLP/examples/dialogue/plato-2/utils/args.py b/PaddleNLP/examples/dialogue/plato-2/utils/args.py new file mode 100644 index 0000000000000000000000000000000000000000..514151ecf02843180c2845cf1000f78dbf3ff165 --- /dev/null +++ b/PaddleNLP/examples/dialogue/plato-2/utils/args.py @@ -0,0 +1,98 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parse argument.""" + +import argparse +import json + + +def str2bool(v): + """ Support bool type for argparse. """ + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Unsupported value encountered.") + + +class Args(dict): + """ Arguments class + + Store arguments in training / infer / ... scripts. 
+ """ + + def __getattr__(self, name): + if name in self.keys(): + return self[name] + for v in self.values(): + if isinstance(v, Args): + if name in v: + return v[name] + return None + + def get(self, key, default_value=None): + """Get the value of corresponding key.""" + if key in self.keys(): + return self[key] + for v in self.values(): + if isinstance(v, Args): + if key in v: + return v[key] + return default_value + + def __setattr__(self, name, value): + self[name] = value + + def save(self, filename): + with open(filename, "w") as fp: + json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False) + + def load(self, filename, group_name=None): + if group_name is not None: + if group_name not in self: + self[group_name] = Args() + self[group_name].load(filename) + return + with open(filename, "r") as fp: + params_dict = json.load(fp) + for k, v in params_dict.items(): + if isinstance(v, dict): + self[k].update(Args(v)) + else: + self[k] = v + + +def parse_args(parser: argparse.ArgumentParser, allow_unknown=False) -> Args: + """ Parse hyper-parameters from cmdline. """ + if allow_unknown: + parsed, _ = parser.parse_known_args() + else: + parsed = parser.parse_args() + args = Args() + optional_args = parser._action_groups[1] + for action in optional_args._group_actions[1:]: + arg_name = action.dest + args[arg_name] = getattr(parsed, arg_name) + for group in parser._action_groups[2:]: + group_args = Args() + for action in group._group_actions: + arg_name = action.dest + group_args[arg_name] = getattr(parsed, arg_name) + if len(group_args) > 0: + if group.title in args: + args[group.title].update(group_args) + else: + args[group.title] = group_args + return args diff --git a/PaddleNLP/examples/dialogue/plato-2/utils/masking.py b/PaddleNLP/examples/dialogue/plato-2/utils/masking.py new file mode 100644 index 0000000000000000000000000000000000000000..968c1c4a17e118419e0dd6f273c9c3e1c9474496 --- /dev/null +++ b/PaddleNLP/examples/dialogue/plato-2/utils/masking.py @@ -0,0 +1,133 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Reader utils.""" + +import numpy as np + + +def mask(batch_tokens, + vocab_size, + bos_id=1, + eos_id=2, + mask_id=3, + sent_b_starts=None, + labels=None, + is_unidirectional=False, + use_latent=False, + use_bow=False): + """ + Add mask for batch_tokens, return out, mask_label, mask_pos; + Note: mask_pos responding the batch_tokens after padded; + """ + batch_tokens = np.copy(batch_tokens) + max_len = max(map(len, batch_tokens)) + mask_label = [] + mask_pos = [] + if labels is not None: + label_pos = [] + + if is_unidirectional: + # unidirectional language model + if use_latent: + max_len += 1 + shift_len = 1 + else: + shift_len = 0 + for sent_index, sent in enumerate(batch_tokens): + sent_b_index = sent_b_starts[ + sent_index] if sent_b_starts is not None else 0 + need_cal = True + if labels is not None: + label_pos.append(sent_index * max_len + len(sent) - 1 + + shift_len) + if labels[sent_index] == 0: + need_cal = False + mask_label.extend(sent[sent_b_index + 1:]) + mask_pos.extend([ + sent_index * max_len + i + shift_len + for i in range(sent_b_index, len(sent) - 1) + ]) + mask_label = np.array(mask_label).astype("int64").reshape([-1, 1]) + mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1]) + return_list = [mask_label, mask_pos] + + # latent related (bow label and pos) + if use_latent and use_bow: + bow_label = [] + bow_pos = [] + for sent_index, sent in enumerate(batch_tokens): + sent_b_index = sent_b_starts[ + sent_index] if sent_b_starts is not None else 0 + + def __filter__(tok_id): + # TODO: exclude [EOS] from bow loss + return True + + bow_pos.extend([ + sent_index for i in range(sent_b_index + 1, len(sent)) + if __filter__(sent[i]) + ]) + bow_label.extend([ + sent[i] for i in range(sent_b_index + 1, len(sent)) + if __filter__(sent[i]) + ]) + bow_label = np.array(bow_label).astype("int64").reshape([-1, 1]) + bow_pos = np.array(bow_pos).astype("int64").reshape([-1, 1]) + return_list += [bow_label, bow_pos] + else: + # bidirectional mask language model + total_token_num = sum(map(len, batch_tokens)) + prob_mask = np.random.rand(total_token_num) + # TODO: fix replace_ids, include [UNK] + replace_ids = np.random.randint( + 3, high=vocab_size, size=total_token_num) + prob_index = 0 + for sent_index, sent in enumerate(batch_tokens): + # add pair label position + if labels is not None: + label_pos.append(sent_index * max_len) + + # add mask label and position + for token_index, token in enumerate(sent): + if token == eos_id or token == bos_id: + continue + prob = prob_mask[prob_index + token_index] + if prob > 0.15: + continue + elif 0.03 < prob <= 0.15: + # mask + mask_label.append(sent[token_index]) + sent[token_index] = mask_id + mask_pos.append(sent_index * max_len + token_index) + elif 0.015 < prob <= 0.03: + # random replace + mask_label.append(sent[token_index]) + sent[token_index] = replace_ids[prob_index + token_index] + mask_pos.append(sent_index * max_len + token_index) + else: + # keep the original token + mask_label.append(sent[token_index]) + mask_pos.append(sent_index * max_len + token_index) + + prob_index += len(sent) + + mask_label = np.array(mask_label).astype("int64").reshape([-1, 1]) + mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1]) + return_list = [batch_tokens, mask_label, mask_pos] + + if labels is not None: + label_pos = np.array(label_pos).astype("int64").reshape([-1, 1]) + assert len(labels) == len(label_pos) + return_list.append(label_pos) + return return_list diff --git 
a/PaddleNLP/examples/dialogue/plato-2/utils/tokenization.py b/PaddleNLP/examples/dialogue/plato-2/utils/tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..780bf27a690dcfde349f8b2fc85e65f4add9edb9 --- /dev/null +++ b/PaddleNLP/examples/dialogue/plato-2/utils/tokenization.py @@ -0,0 +1,193 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +import collections +import sentencepiece as spm +import unicodedata + +from utils.args import str2bool + + +def clean_text(text): + """Performs invalid character removal and whitespace cleanup on text.""" + text = text.replace(u"“", u'"')\ + .replace(u'”', u'"')\ + .replace(u'‘', "'")\ + .replace(u'’', u"'")\ + .replace(u'—', u'-') + + output = [] + for char in text: + if _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +def preprocess_text(inputs, remove_space=True, lower=False): + """preprocess data by removing extra space and normalize data.""" + outputs = inputs + if remove_space: + outputs = " ".join(inputs.strip().split()) + + outputs = unicodedata.normalize("NFKD", outputs) + outputs = "".join([c for c in outputs if not unicodedata.combining(c)]) + if lower: + outputs = outputs.lower() + + return outputs + + +def encode_pieces(spm_model, text, return_unicode=True, sample=False): + """turn sentences into word pieces.""" + # liujiaxiang: add for ernie-albert, mainly consider for “/”/‘/’/— causing too many unk + text = clean_text(text) + + if not sample: + pieces = spm_model.EncodeAsPieces(text) + else: + pieces = spm_model.SampleEncodeAsPieces(text, 64, 0.1) + + return pieces + + +def encode_ids(spm_model, text, sample=False): + """turn sentences into word pieces.""" + pieces = encode_pieces(spm_model, text, return_unicode=False, sample=sample) + ids = [spm_model.PieceToId(piece) for piece in pieces] + return ids + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + fin = open(vocab_file) + for num, line in enumerate(fin): + items = convert_to_unicode(line.rstrip()).split("\t") + if len(items) > 2: + break + token = items[0] + index = items[1] if len(items) == 2 else num + token = token.strip() + vocab[token] = int(index) + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + +class SentencePieceTokenizer(object): + """Runs end-to-end tokenziation.""" + + @classmethod + def add_cmdline_args(cls, parser): + """Add cmdline 
argurments.""" + group = parser.add_argument_group("Tokenizer") + group.add_argument("--vocab_path", type=str, required=True) + group.add_argument("--do_lower_case", type=str2bool, default=False) + group.add_argument("--spm_model_file", type=str, required=True) + return group + + def __init__(self, args): + self.spm_model = spm.SentencePieceProcessor() + self.spm_model.Load(args.spm_model_file) + self.vocab = load_vocab(args.vocab_path) + self.do_lower_case = args.do_lower_case + self.inv_vocab = {v: k for k, v in self.vocab.items()} + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = preprocess_text(text, lower=self.do_lower_case) + return encode_pieces(self.spm_model, text, return_unicode=True) + + def convert_tokens_to_ids(self, tokens): + """Convert tokens to ids.""" + ret = [] + unk_id = self.vocab[""] + for token in tokens: + if token in self.vocab: + ret.append(self.vocab[token]) + else: + ret.append(unk_id) + return ret + + def convert_ids_to_tokens(self, ids): + """Convert ids to tokens.""" + return convert_by_vocab(self.inv_vocab, ids) + + def merge_subword(self, tokens): + """Merge subword.""" + ret = [] + for token in tokens: + if token.startswith(u"▁"): + ret.append(token[1:]) + else: + if len(ret): + ret[-1] += token + else: + ret.append(token) + + ret = [token for token in ret if token] + return ret + + def convert_ids_to_str(self, ids): + """Convert ids to string.""" + tokens = self.convert_ids_to_tokens(ids) + tokens = self.merge_subword(tokens) + res = " ".join(tokens).replace("", "") + res = res.replace("", "\n").replace("\n ", "\n").strip() + return res + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False
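
For a quick sanity check of the tokenizer in isolation, here is a minimal sketch. The data paths assume `data.tar.gz` was extracted as described in the README, and the `SimpleNamespace` merely stands in for the parsed command-line args; both are assumptions for illustration, not part of this PR:

```python
from types import SimpleNamespace

from utils.tokenization import SentencePieceTokenizer

# Stand-in for parsed args; paths assume the README's data layout.
args = SimpleNamespace(
    spm_model_file="./data/spm.model",
    vocab_path="./data/vocab.txt",
    do_lower_case=False)

tokenizer = SentencePieceTokenizer(args)
pieces = tokenizer.tokenize("Hello there!")    # sentencepiece pieces, e.g. ["▁Hello", "▁there", "!"]
ids = tokenizer.convert_tokens_to_ids(pieces)  # vocab lookup with unk fallback
print(tokenizer.merge_subword(pieces))         # ["Hello", "there!"]
```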