# PLATO
**PLATO: Pre-trained Dialogue Generation Model with Discrete Latent Variable**
[paper link](http://arxiv.org/abs/1910.07931)

**\*\*\*\*\* Update \*\*\*\*\***

Nov. 14: Added support for the new APIs in PaddlePaddle 1.6.0 (the model files linked below have been updated accordingly), multi-GPU training, and a top-k sampling decoding strategy. Released our baseline model `PLATO w/o latent`.

## Requirements
```
- python >= 3.6
- paddlepaddle >= 1.6.0
- numpy
- nltk
- tqdm
- visualdl >= 1.3.0 (optional)
- regex
```
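
As a quick sanity check of the environment, a minimal sketch like the following prints the installed PaddlePaddle version, which should be at least 1.6.0:
```python
# Quick environment check (a minimal sketch): verify that PaddlePaddle
# is installed and meets the >= 1.6.0 requirement listed above.
import paddle

print(paddle.__version__)  # expect "1.6.0" or newer
```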

## Pre-trained dialogue generation model
This work introduces a novel pre-training model for dialogue generation that incorporates discrete latent variables to model the one-to-many relationship between a dialogue context and its valid responses. Our model is flexible enough to support various kinds of conversations, including chit-chat, knowledge-grounded dialogue, and conversational question answering. Pre-training is carried out on Reddit and Twitter corpora. You can download the uncased pre-trained models from:
* PLATO, uncased [model](https://baidu-nlp.bj.bcebos.com/PLATO/model.tar.gz): 12-layers, 768-hidden, 12-heads, 132M parameters
* PLATO w/o latent, uncased [model](https://baidu-nlp.bj.bcebos.com/PLATO/model-baseline.tar.gz): 12-layers, 768-hidden, 12-heads, 109M parameters

```bash
mv /path/to/model.tar.gz .
tar xzf model.tar.gz
```

## Fine-tuning
We also provide instructions for fine-tuning PLATO on different conversation datasets (chit-chat, knowledge-grounded dialogue, and conversational question answering).

### Data preparation
Download data from the [link](https://baidu-nlp.bj.bcebos.com/PLATO/data.tar.gz).
The tar file contains three processed datasets: `DailyDialog`, `PersonaChat` and `DSTC7_AVSD`.
```bash
mv /path/to/data.tar.gz .
tar xzf data.tar.gz
```

### Data format
Our model supports two data formats for the dialogue context: `multi` and `multi_knowledge`.
* `multi`: multi-turn dialogue context.
```txt
u_1 __eou__ u_2 __eou__ ... u_n \t r
```
* `multi_knowledge`: multi-turn dialogue context with background knowledge.
```txt
k_1 __eou__ k_2 __eou__ ... k_m \t u_1 __eou__ u_2 __eou__ ... u_n \t r
```

If you want to use this model on other datasets, process your data into one of these formats; a minimal conversion sketch is shown below.
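
For example, the following sketch (the in-memory structure and the output file name are hypothetical) writes a custom dataset into the `multi` format; for `multi_knowledge`, a knowledge field joined the same way would be prepended as an extra tab-separated column:
```python
# Minimal sketch: write a custom dataset into the `multi` format.
# The `dialogues` structure and the output file name are hypothetical;
# adapt them to your own data.
dialogues = [
    {
        "context": ["hello !", "hi , how are you doing ?"],  # u_1 ... u_n
        "response": "pretty good , thanks for asking .",     # r
    },
]

with open("dial.train", "w", encoding="utf-8") as f:          # hypothetical file name
    for sample in dialogues:
        context = " __eou__ ".join(sample["context"])         # join turns with __eou__
        f.write(context + "\t" + sample["response"] + "\n")   # context \t response
```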

### Train
Fine-tune the pre-trained model on a chosen `${DATASET}`:
```bash
# DailyDialog / PersonaChat / DSTC7_AVSD
DATASET=DailyDialog
sh scripts/${DATASET}/train.sh
```
After training, you can find the output folder `outputs/${DATASET}` (by default). It contains `best.model` (the checkpoint with the best results on the validation set), `hparams.json` (the hyper-parameters of the training script) and `trainer.log` (the training log).


You can also fine-tune the pre-trained model on multiple GPUs.

Note: You need to install the NCCL library and set the environment variable `LD_LIBRARY_PATH` so that it includes the NCCL library directory.
```bash
sh scripts/DailyDialog/multi_gpu_train.sh
```

You can also fine-tune `PLATO w/o latent` on these datasets. We provide an example script for the DailyDialog dataset.
```bash
sh scripts/DailyDialog/baseline_train.sh
```

#### Recommended settings

Fine-tuning our pre-trained model usually takes about 10 epochs to converge with a learning rate of 1e-5, and about 2-3 epochs with a learning rate of 5e-5.

GPU memory | Batch size | Max sequence length
------|------|------
16 GB | 6 | 256
32 GB | 12 | 256

### Infer
Run inference on the test set:
```bash
# DailyDialog / PersonaChat / DSTC7_AVSD
DATASET=DailyDialog
sh scripts/${DATASET}/infer.sh

# Running inference of PLATO w/o latent
sh scripts/DailyDialog/baseline_infer.sh
```
After inference, you can find the output folder `outputs/${DATASET}.infer` (by default). It contains `infer_0.result.json` (the inference results), `hparams.json` (the hyper-parameters of the inference script) and `trainer.log` (the inference log).
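
To take a quick look at the predictions, you can load the result file with a short script like this one (a sketch; the path assumes the default DailyDialog output folder, and nothing is assumed about the JSON schema beyond it being loadable):
```python
# Sketch: inspect the inference output. The exact schema of
# infer_0.result.json is not documented here, so we only load the file
# and print the first few entries as-is.
import json

with open("outputs/DailyDialog.infer/infer_0.result.json", encoding="utf-8") as f:
    results = json.load(f)

print(type(results), len(results))
for entry in list(results)[:3]:  # works whether the top level is a list or a dict
    print(entry)
```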

If you want to use top-k sampling instead of the default beam search, you can follow the example script:
```bash
sh scripts/DailyDialog/topk_infer.sh
```

## Results

### DailyDialog
Model | BLEU-1/2 | Distinct-1/2 | Fluency | Coherence | Informativeness | Overall
------|------|------|------|------|------|-------
Seq2Seq | 0.336/0.268 | 0.030/0.128 | 1.85 | 0.37 | 0.44 | 0.33
iVAE_MI | 0.309/0.249 | 0.029/0.250 | 1.53 | 0.34 | 0.59 | 0.30
Our w/o Latent | **0.405/0.322** | 0.046/0.246 | 1.91 | **1.58** | 1.03 | 1.44
Our Method | 0.397/0.311 | **0.053/0.291** | **1.97** | 1.57 | **1.23** | **1.48**

### PersonaChat
Model | BLEU-1/2 | Distinct-1/2 | Knowledge R/P/F1 | Fluency | Coherence | Informativeness | Overall
------|------|------|------|------|------|-------|-------
Seq2Seq | 0.448/0.353 | 0.004/0.016 | 0.004/0.016/0.006 | 1.82 | 0.37 | 0.85 | 0.34
LIC | 0.405/0.320 | 0.019/0.113 | 0.042/0.154/0.064 | 1.95 | 1.34 | 1.09 | 1.29
Our w/o Latent | **0.458/0.357** | 0.012/0.064 | 0.085/0.263/0.125 | 1.98 | 1.36 | 1.04 | 1.30
Our Method | 0.406/0.315 | **0.021/0.121** | **0.142/0.461/0.211** | **1.99** | **1.51** | **1.70** | **1.50**

### DSTC7_AVSD
Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr
------|------|------|------|------|------|-------|-------
Baseline | 0.629 | 0.485 | 0.383 | 0.309 | 0.215 | 0.487 | 0.746
CMU | 0.718 | 0.584 | 0.478 | 0.394 | 0.267 | 0.563 | 1.094
Our Method | **0.784** | **0.637** | **0.525** | **0.435** | **0.286** | **0.596** | **1.209**
Our Method Upper Bound | 0.925 | 0.843 | 0.767 | 0.689 | 0.361 | 0.731 | 1.716

Note: In the experiments on `DSTC7_AVSD`, the response selection of our method is strengthened with an extra ranking step, which ranks the candidates according to the automatic scores and selects the top one as the final answer.
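
The exact scoring procedure is not reproduced here, but the sketch below illustrates what such a ranking step could look like, using BLEU from nltk (already listed in the requirements) as a stand-in for the automatic score:
```python
# Illustrative sketch of an extra ranking step: score each candidate response
# with an automatic metric (BLEU via nltk here, as a stand-in) and keep the
# top-scoring candidate as the final answer.
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def select_best(candidates, reference):
    """candidates: list of token lists; reference: a single token list."""
    smooth = SmoothingFunction().method1
    scored = [
        (sentence_bleu([reference], cand, smoothing_function=smooth), cand)
        for cand in candidates
    ]
    return max(scored, key=lambda pair: pair[0])[1]

print(select_best(
    [["i", "am", "fine"], ["i", "am", "doing", "great", "thanks"]],
    ["i", "am", "doing", "great", "thanks"],
))
```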

## Citation
If you find PLATO useful in your work, please cite the following arXiv paper:
```
@article{bao2019plato,
    title={PLATO: Pre-trained Dialogue Generation Model with Discrete Latent Variable},
    author={Bao, Siqi and He, Huang and Wang, Fan and Wu, Hua and Wang, Haifeng},
    journal={arXiv preprint arXiv:1910.07931},
    year={2019}
}
```

## Disclaimer
This project aims to facilitate further research progress in dialogue generation. Baidu is not responsible for content generated by third parties with the pre-trained system.

## Contact information
For help or issues using PLATO, please submit a GitHub issue.

For personal communication related to PLATO, please contact Siqi Bao (`baosiqi@baidu.com`), or Huang He (`hehuang@baidu.com`).