**PLATO: Pre-trained Dialogue Generation Model with Discrete Latent Variable**
[paper link](http://arxiv.org/abs/1910.07931)
## Requirements
```
- python >= 3.6
- paddlepaddle >= 1.5.2
- numpy
- nltk
- tqdm
- visualdl >= 1.3.0 (optional)
```
## Pre-trained dialogue generation model
This work introduces a pre-training model for dialogue generation that uses discrete latent variables to model one-to-many relationships. The model is flexible enough to support several kinds of conversation, including chit-chat, knowledge-grounded dialogue, and conversational question answering. Pre-training is carried out on Reddit and Twitter corpora. You can download the uncased pre-trained model from:
We also provide instructions for fine-tuning PLATO on different conversation datasets (chit-chat, knowledge-grounded dialogue, and conversational question answering).
### Data preparation
Download data from the [link](https://baidu-nlp.bj.bcebos.com/PLATO/data.tar.gz).
The tar file contains three processed datasets: DailyDialog, PersonaChat and DSTC7_AVSD.
```bash
mv /path/to/data.tar.gz .
tar xzf data.tar.gz
```
### Data format
Our model supports two kinds of data formats for dialogue context: "multi" and "multi_knowledge".
* multi: multi-turn dialogue context.
```txt
u_1 __eou__ u_2 __eou__ ... u_n \t r
```
* multi_knowledge: multi-turn dialogue context with background knowledge.
If you want to use this model on other datasets, you can process your data accordingly.
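
As a starting point for converting your own data, here is a minimal sketch that writes and reads one sample in the "multi" format. The helper names `to_multi_line` and `parse_multi_line` are hypothetical and not part of this repository. The sketch assumes that context and response are separated by a single tab character and that utterances contain no tabs of their own; check the processed datasets to confirm the exact whitespace conventions.

```python
EOU = "__eou__"  # utterance separator used in the "multi" format


def to_multi_line(context, response):
    """Serialize one sample as: u_1 __eou__ u_2 __eou__ ... u_n <TAB> r.

    `context` is a list of utterance strings, `response` is a string.
    Whitespace inside each utterance is collapsed so that stray tabs or
    newlines cannot break the line format (an assumption of this sketch).
    """
    utterances = [" ".join(u.split()) for u in context]
    response = " ".join(response.split())
    return f" {EOU} ".join(utterances) + "\t" + response


def parse_multi_line(line):
    """Split one "multi" line back into (list of context utterances, response)."""
    context, response = line.rstrip("\n").split("\t")
    return [u.strip() for u in context.split(EOU)], response.strip()


if __name__ == "__main__":
    sample = to_multi_line(["Hi  there", "Hello !"], "How are you ?")
    print(repr(sample))
    print(parse_multi_line(sample))
```

Because `parse_multi_line` strips whitespace around every field, it accepts lines with or without spaces around the tab.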
### Train
Fine-tune the pre-trained model on the dataset selected by `${DATASET}`.
```bash
# DailyDialog / PersonaChat / DSTC7_AVSD
DATASET=DailyDialog
sh scripts/${DATASET}/train.sh
```
After training, the output folder is `outputs/${DATASET}` (by default). It contains `best.model` (the model with the best results on the validation dataset), `hparams.json` (hyper-parameters of the training script) and `trainer.log` (training log).
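
To check which hyper-parameters a run used, you can load `hparams.json` from the output folder. The snippet below is a minimal sketch: the helper name `load_hparams` is hypothetical, and the only assumption about the file is that it is valid JSON, since its keys are not documented here.

```python
import json
from pathlib import Path


def load_hparams(output_dir):
    """Load hparams.json from an output folder such as outputs/DailyDialog."""
    with open(Path(output_dir) / "hparams.json", encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    # The default output folder for DATASET=DailyDialog; adjust as needed.
    hparams = load_hparams("outputs/DailyDialog")
    print(json.dumps(hparams, indent=2, sort_keys=True))
```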
#### Recommended settings
Fine-tuning the pre-trained model usually takes about 10 epochs to converge with a learning rate of 1e-5, or about 2-3 epochs with a learning rate of 5e-5.
GPU_MEM | batch_size | max_len
------|------|------
16G | 6 | 256
32G | 12 | 256
### Infer
Run inference on the test dataset.
```bash
# DailyDialog / PersonaChat / DSTC7_AVSD
DATASET=DailyDialog
sh scripts/${DATASET}/infer.sh
```
After inference, the output folder is `outputs/${DATASET}.infer` (by default). It contains `infer_0.result.json` (the inference result), `hparams.json` (hyper-parameters of the inference script) and `trainer.log` (inference log).
Note: In the experiments on DSTC7_AVSD, the response selection of our method is strengthened with an extra ranking step, which ranks the candidates according to the automatic scores and selects the top one as the final answer.
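
The ranking step described in the note can be illustrated with a generic sketch: score each candidate response, sort by score, and keep the top one. This is not the repository's implementation. `select_top_candidate` is a hypothetical helper, and `toy_score` is a stand-in for illustration only; it does not correspond to the automatic scores used in the experiments.

```python
def select_top_candidate(candidates, score_fn):
    """Rank candidates by score (highest first) and return (best, ranked list)."""
    if not candidates:
        raise ValueError("candidates must not be empty")
    ranked = sorted(candidates, key=score_fn, reverse=True)
    return ranked[0], ranked


def toy_score(text):
    """Placeholder score: number of distinct whitespace-separated tokens."""
    return len(set(text.split()))


if __name__ == "__main__":
    candidates = ["yes", "a man is cooking in the kitchen", "a man"]
    best, ranked = select_top_candidate(candidates, toy_score)
    print("best:", best)
    print("ranking:", ranked)
```

In practice, replace `toy_score` with whatever scoring function you want to rank by.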
## Citation
If you find PLATO useful in your work, please cite the following arXiv paper:
```
@article{bao2019plato,
title={PLATO: Pre-trained Dialogue Generation Model with Discrete Latent Variable},
author={Bao, Siqi and He, Huang and Wang, Fan and Wu, Hua},
journal={arXiv preprint arXiv:1910.07931},
year={2019}
}
```
## Contact information
For help or issues using PLATO, please submit a GitHub issue.
For personal communication related to PLATO, please contact Siqi Bao (`baosiqi@baidu.com`), or Huang He (`hehuang@baidu.com`).