diff --git a/ernie-gen/README.md b/ernie-gen/README.md new file mode 100644 index 0000000000000000000000000000000000000000..12751f3b7a0cfe5042f95abe821fbc0f339b08ba --- /dev/null +++ b/ernie-gen/README.md @@ -0,0 +1,228 @@ +English | [简体中文](./README.zh.md) + +## _ERNIE-GEN_: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation + +- [Proposed Generation Framework](#proposed-generation-framework) +- [Pre-trained Models](#pre-trained-models) +- [Fine-tuning on Downstream Tasks](#fine-tuning-on-downstream-tasks) + * [Abstractive Summarization](#abstractive-summarization) + * [Question Generation](#question-generation) + * [Generative Dialogue Response](#generative-dialogue-response) + * [Generative Question Answering](#generative-question-answering) +- [Usage](#usage) + * [Install PaddlePaddle](#install-paddlepaddle) + * [Fine-tuning](#fine-tuning) + * [Employ Dynamic Computation Graph](#employ-dynamic-computation-graph) + * [The ERNIE 1.0 is avaliable](#the-ernie-10-is-avaliable-for-chinese-generation-tasks) +- [Citation](#citation) + +For technical description of the algorithm, please see our paper: +>[_**ERNIE-GEN:An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation**_](https://arxiv.org/abs/2001.11314.pdf) + +>Dongling Xiao\*, Han Zhang\*, Yukun Li, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang (\* : equal contribution) + +>Preprint January 2020 + +>Accepted by **IJCAI-2020** + + +![ERNIE-GEN](https://img.shields.io/badge/Pretraining-Generation-green) ![Gigaword](https://img.shields.io/badge/Abstractive Summarization-Gigaword-yellow) ![Gigaword](https://img.shields.io/badge/Abstractive Summarization-CNN/Daily Mail-blue) ![SQuAD](https://img.shields.io/badge/Question Generation-SQuAD-green) ![Personal-Chat](https://img.shields.io/badge/Dialogue Response-Personal Chat-yellowgreen) ![CoQA](https://img.shields.io/badge/Generative Question Answering-CoQA-orange) +--- +**[ERNIE-GEN](https://arxiv.org/abs/2001.11314.pdf) is a multi-flow language generation framework for both pre-training and fine-tuning.** We propose a novel **span-by-span generation** pre-training task to enable the model to **generate a semantically-complete span** at each step rather than a word, in light of the fact that entities, phrases in human writing are organized in a coherent manner. An **infilling generation mechanism** and a **noise-aware generation method** are incorporated into both pre-training and fine-tuning to alleviate **the problem of exposure bias**. In the pre-training phase, ERNIE-GEN adopts a **multi-granularity target fragments sampling** strategy to force decoder to rely more on the encoder representations other than the previous generated words to enhancing the correlation between encoder and decoder. + +## Proposed Generation Framework + +We construct three novel methods to enhance the language generation ability: + +- **Span-by-span Generation Pre-training Task**: to enable model to generate a semantically-complete span at each step rather than a word. +- **Infilling Genration and Noise-aware Generation**: to alleviate the problem of exposure bias. +- **Multi-Granularity Target Fragments**: to enhance the correlation between encoder and decoder during pre-training. + +Specifically, the span-by-span generation task and word-by-word generation task based on infilling generation mechanism are impemented by a carefully designed **Multi-Flow Attention** architecture as shown below. + +![multi-flow-attention](.meta/multi-flow-attention.png) + +## Pre-trained Models + +We release the checkpoints for **ERNIE-GEN _base_** model and **ERNIE-GEN _large_** model which are both pre-trained on English Wikipedia and [BookCorpus](https://arxiv.org/abs/1506.06724) (totally 16GB). Besides, **ERNIE-GEN _large_** pre-trained on the 160GB corpus (used by [RoBERTa](https://arxiv.org/abs/1907.11692) and [BART](https://arxiv.org/abs/1910.13461)) is available as well. + +- [**ERNIE-GEN _base_**](https://ernie.bj.bcebos.com/ernie_gen_base.tgz) (_lowercased | 12-layer, 768-hidden, 12-heads, 110M parameters_) +- [**ERNIE-GEN _large_**](https://ernie.bj.bcebos.com/ernie_gen_large.tgz) (_lowercased | 24-layer, 1024-hidden, 16-heads, 340M parameters_) +- [**ERNIE-GEN _large with 160G_**](https://ernie.bj.bcebos.com/ernie_gen_large_160g.tgz) (_lowercased | 24-layer, 1024-hidden, 16-heads, 340M parameters_) + + +## Fine-tuning on Downstream Tasks + +We compare the performance of [ERNIE-GEN](https://arxiv.org/pdf/2001.11314.pdf) with the existing SOTA pre-training models for natural language generation ([UniLM](https://arxiv.org/abs/1905.03197), [MASS](https://arxiv.org/abs/1905.02450), [PEGASUS](https://arxiv.org/abs/1912.08777), [BART](https://arxiv.org/abs/1910.13461) and [T5](https://arxiv.org/abs/1910.10683)) on 5 genration tasks, including abstractive summarization (**_Gigaword_** and **_CNN/DailyMail_**), question generation (**_SQuAD_**), dialogue generation (**_Persona-Chat_**) and generative question answering (**_CoQA_**). + +### Abstractive Summarization + +- _**Gigaword**_ + +The results on Gigaword-10k (10K examples of Gigaword) are presented as follows: + +| Model | Data / Params | Rouge-1 | Rouge-2 | Rouge-L | +| :-------------------------------------------------------- | :----------------------------: | :----------------------: | :----------------------: | :----------------------: | +| UniLM | 16G / 340M | 34.21 | 15.28 | 31.54 | +| **ENRIE-GEN** _base_ | 16G / 110M | 33.75 | 15.23 | 31.35 | +| **ERNIE-GEN** _large_ | 16G / 340M | 35.05 | 16.10 | 32.50 | +| **ERNIE-GEN** _large_ (160G) | 160G / 340M | **35.51** | **16.79** | **33.23** | + +The results on Gigaword are presented as follows: + +| Model | Data / Params | Rouge-1 | Rouge-2 | Rouge-L | +| :-------------------------------------------------------- | :----------------------------: | :----------------------: | :----------------------: | :----------------------: | +| MASS | 18G / 160M | 38.73 | 19.71 | 35.96 | +| BERTSHARE | 16G / 110M | 38.13 | 19.81 | 35.62 | +| UniLM | 16G / 340M | 38.45 | 19.45 | 35.75 | +| PEGASUS (_C4_) | 750G / 568M | 38.75 | 19.96 | 36.14 | +| PEGASUS (_HugeNews_) | 3.8T / 568M | 39.12 | 19.86 | 36.24 | +| **ENRIE-GEN** _base_ | 16G / 110M | 38.83 | 20.04 | 36.20 | +| **ERNIE-GEN** _large_ | 16G / 340M | 39.25 | 20.25 | 36.53 | +| **ERNIE-GEN** _large_ (160G) | 160G / 340M | **39.46** | **20.34** | **36.74** | + +We preprocess the raw Gigaword dataset following UniLM, the preprocessed data is avalilable at this [Gigaword](https://ernie.bj.bcebos.com/gigaword.tgz). + +- _**CNN/Daily Mail**_ + +The results on CNN/Daily Mail are presented as follows: + +| Model | Data / Params | Rouge-1 | Rouge-2 | Rouge-L | +| :-------------------------------------------------------- | :-----------: | :----------------------: | :----------------------: | :----------------------: | +| MASS | 18G / 160M | 42.12 | 19.50 | 39.01 | +| UniLM | 16G / 340M | 43.33 | 20.21 | 40.51 | +| T5 _large_ | 750G / 340M | 42.50 | 20.68 | 39.75 | +| T5 _xlarge_ | 750G / 11B | 43.52 | **21.55** | 40.69 | +| BART | 160G / 400M | 44.16 | 21.28 | 40.90 | +| PEGASUS (_C4_) | 750G / 568M | 43.90 | 21.20 | 40.76 | +| PEGASUS (_HugeNews_) | 3.8T / 568M | 44.17 | 21.47 | 41.11 | +| **ENRIE-GEN** _base_ | 16G / 110M | 42.30 | 19.92 | 39.68 | +| **ENRIE-GEN** _large_ | 16G / 340M | 44.02 | 21.17 | 41.26 | +| **ENRIE-GEN** _large_ (160G) | 160G / 340M | **44.31** | 21.35 | **41.60** | + +We preprocess the raw CNN/Daily Mail dataset following UniLM, the preprocessed data is avalilable at this [CNN/Daily Mail](https://ernie.bj.bcebos.com/cnndm.tgz). + +### Question Generation + +- _**SQuAD**_ + +The results on the [SQuAD 1.1](https://arxiv.org/abs/1806.03822) dataset following the data split in [[Du et al., 2017]](https://arxiv.org/pdf/1705.00106.pdf) are presented as follows: + +| Model | BLEU-4 | METEOR | Rouge-L | +| :----------------------------------------------------------- | :----------------------: | :----------------------: | :----------------------: | +| [SemQG](https://arxiv.org/abs/1909.06356) | 18.37 | 22.65 | 46.68 | +| UniLM _large_ (beam size=1) | 22.12 | 25.06 | 51.07 | +| **ENRIE-GEN** _base_ (beam size=1) | 22.28 | 25.13 | 50.38 | +| **ERNIE-GEN** _large_ (beam size=1) | 24.03 | 26.31 | 52.36 | +| **ERNIE-GEN** _large_ (beam size=5) | 25.40 | **26.92** | 52.84 | +| **ERNIE-GEN** _large_ (beam size=5) + (160G) | **25.41** | 26.77 | **52.91** | + +The results following the reversed dev-test data split in [[Zhao et al., 2018]](https://www.aclweb.org/anthology/D18-1424/) are presented as follows: + +| Model | BLEU-4 | METEOR | Rouge-L | +| :----------------------------------------------------------- | :----------------------: | :----------------------: | :----------------------: | +| SemQG | 20.76 | 24.20 | 48.91 | +| UniLM _large_ (beam size=1) | 23.75 | 25.61 | 52.04 | +| **ENRIE-GEN** _base_ (beam size=1) | 23.52 | 25.61 | 51.45 | +| **ERNIE-GEN** _large_ (beam size=1) | 25.57 | 26.89 | 53.31 | +| **ERNIE-GEN** _large_ (beam size=5) | 26.95 | **27.57** | 53.77 | +| **ERNIE-GEN** _large_ (beam size=5) + (160G) | **27.05** | 27.43 | **53.83** | + +*_Note that we also report the results with higher beam size to 5._ + +The preprocessed data for question generation task can be downloaded from [SQuAD](https://ernie.bj.bcebos.com/squad_qg.tgz). + +### Generative Dialogue Response + +- _**Personal-Chat**_ + + Comparison with current state-of-the-art results on the multi-turn conversations task ([Persona-Chat](https://arxiv.org/abs/1801.07243)) is presented as follows: + +| Model | BLEU-1 | BLEU-2 | Distinct-1 | Distinct-2 | +| :-------------------------------------------------------- | :---------------------: | :---------------------: | :-------------------------: | :---------------------------: | +| [LIC](https://arxiv.org/abs/1910.07931) | 40.5 | 32.0 | 0.019 | 0.113 | +| [PLATO](https://arxiv.org/abs/1910.07931) | 45.8 | 35.7 | 0.012 | 0.064 | +| PLATO _w/o latent_ | 40.6 | 31.5 | 0.021 | 0.121 | +| **ERNIE-GEN** _large_ | **46.8** | **36.4** | **0.023** | **0.168** | + +The training data can be downloaded from [Personal-Chat](https://ernie.bj.bcebos.com/persona_chat.tgz). + +### Generative Question Answering + +- _**CoQA**_ + +Results of development set on CoQA task is presented as follows: + +| Model | F1-score | +| :-------------------------------------------------------- | :------: | +| [Seq2Seq](https://arxiv.org/abs/1910.07931) | 27.5 | +| [PGNet](https://arxiv.org/abs/1910.07931) | 45.4 | +| UniLM _large_ | 82.5 | +| **ERNIE-GEN** _large_ | **84.5** | + +We preprocess the raw [CoQA](https://arxiv.org/abs/1808.07042) dataset, the preprocessed data is avalilable at this [CoQA-preprocessed](https://ernie.bj.bcebos.com/coqa.tgz). + +Finally, we also compared with a concurrent work [ProphetNet](https://arxiv.org/abs/2001.04063), the fine-tuning results on Gigaword, CNN/Daily Mail and SQuAD are reported as follows: + +- _**Abstractive Summarization**_ + +| Model / Task | Data / Params | Gigaword |CNN/Daily Mail| +| :-------------------------------------------------------- | :----------------------------: | :----------------------: | :----------------------: | +| Metric | - | Rouge-1 / Rouge-2 / Rouge-L |Rouge-1 / Rouge-2 / Rouge-L| +| **ProphetNet** _large_ (160G) | 160G / 340M | **39.51** / **20.42** / 36.69 |44.20 / 21.17 / 41.30| +| **ERNIE-GEN** _large_ (160G) | 160G / 340M | 39.46 / 20.34 / **36.74** |**44.31** / **21.35** / **41.60**| + +- _**Question Generation**_ + +| Model | Data / Params | BLEU-4 / METEOR / Rouge-L |BLEU-4 / METEOR / Rouge-L| +| :-------------------------------------------------------- | :----------------------------: | :----------------------: |:----------------------: | +| Data split | - | Original |Reversed dev-test| +| **ProphetNet** _large_ (16G) | 16G / 340M | 25.01 / 26.83 / 52.57 |26.72 / **27.64** / **53.79** | +| **ERNIE-GEN** _large_ (16G) | 16G / 340M | **25.40** / **26.92** / **52.84** |**26.95** / 27.57 / **53.77**| + +## Usage + +### Install PaddlePaddle + +This code base has been tested with Paddle Fluid 1.7 with Python 2.7. Other dependency of ERNIE-GEN is listed in `requirements.txt`, you can install it by +```script +pip install -r requirements.txt +``` + +### Fine-tuning +Please update LD_LIBRARY_PATH about CUDA, cuDNN, NCCL2 before running ERNIE-GEN. We have put the parameter configurations of the above downstream tasks in `config/`. You can easily run finetuning through these configuration files. For example, you can finetune ERNIE-GEN base model on Gigaword by +```script +MODEL="base" # base or large or large_160g +TASK="gigaword" # cnndm, coqa, gigaword, squad_qg or persona-chat +sh run_seq2seq.sh ./configs/${MODEL}/${TASK}_conf +``` +The log of training and the evaluation results are in `log/job.log.0`. To finetune on your own task data, you can refer to the data format we provide for processing your data. + +Our fine-tuning experiments are carried on 8 NVIDIA V100 (32GB) GPUs. If your GPU memory is not enough, you can reduce the batch size in the corresponding configuration file. + +**NOTICE: ** The actual total batch size is equal to `configured batch size * number of used gpus`. + +### Employ Dynamic Computation Graph + +The ERNIE-GEN code using dynamic graph is more concise and flexible, please refer to [ERNIE-GEN Dygraph](https://github.com/PaddlePaddle/ERNIE/tree/develop/experimental/seq2seq) for specific use. + +### The ERNIE 1.0 is avaliable for Chinese Generation Tasks + +The ERNIE-GEN code is compatible with [ERNIE 1.0](https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz) model. After specifying the parameters related to the model and data in the configuration file, you can use ERNIE 1.0 to fine-tune chinese generation tasks. + +## Citation + +You can cite the paper as below: + +``` +@article{xiao2020ernie-gen, + title={ERNIE-GEN: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation}, + author={Xiao, Dongling and Zhang, Han and Li, Yukun and Sun, Yu and Tian, Hao and Wu, Hua and Wang, Haifeng}, + journal={arXiv preprint arXiv:2001.11314}, + year={2020} +} +``` + + + + diff --git a/ernie-gen/README.zh.md b/ernie-gen/README.zh.md new file mode 100644 index 0000000000000000000000000000000000000000..2ee9cd4ad2e4e16472c8dc2b5a73dc7056f31236 --- /dev/null +++ b/ernie-gen/README.zh.md @@ -0,0 +1,226 @@ +[English](./README.md) | 简体中文 + +## _ERNIE-GEN_: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation + +- [模型框架](#模型框架) +- [预训练模型](#预训练模型) +- [微调任务](#微调任务) + * [生成式摘要](#生成式摘要) + * [问题生成](#问题生成) + * [多轮对话](#多轮对话) + * [生成式多轮问答](#生成式多轮问答) +- [使用说明](#使用说明) + * [安装飞桨](#安装飞桨) + * [运行微调](#运行微调) + * [使用动态图](#使用动态图) + * [中文生成任务使用 ERNIE 1.0](#中文生成任务使用-ernie-10) +- [引用](#引用) + +关于算法的详细描述,请参见我们的论文: +>[_**ERNIE-GEN:An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation**_](https://arxiv.org/abs/2001.11314) + +>Dongling Xiao\*, Han Zhang\*, Yukun Li, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang (\* : equal contribution) + +>Preprint January 2020 + +>Accepted by **IJCAI-2020** + + +![ERNIE-GEN](https://img.shields.io/badge/预训练-生成-green) ![Gigaword](https://img.shields.io/badge/生成式摘要-Gigaword-yellow) ![Gigaword](https://img.shields.io/badge/生成式摘要-CNN/Daily Mail-blue) ![SQuAD](https://img.shields.io/badge/问题生成-SQuAD-green) ![Personal-Chat](https://img.shields.io/badge/多轮对话-Personal Chat-yellowgreen) ![CoQA](https://img.shields.io/badge/多轮问答-CoQA-orange) +--- +**ERNIE-GEN 是面向生成任务的预训练-微调框架**,首次在预训练阶段加入**span-by-span 生成**任务,让模型每次能够生成一个语义完整的片段。在预训练和微调中通过**填充式生成机制**和**噪声感知机制**来缓解曝光偏差问题。此外, ERNIE-GEN 采样**多片段-多粒度目标文本采样**策略, 增强源文本和目标文本的关联性,加强了编码器和解码器的交互。 + +## 模型框架 + +我们提出了三种方法来提高语言生成能力: + +- **Span-by-span 生成任务**: 让模型能够每次生成一个语义完整的片段。 +- **填充式生成**和**噪声感知生成**: 缓解曝光偏差问题。 +- **多片段-多粒度目标文本采样**: 预训练阶段增强编码器和解码器的交互。 + +我们基于 Transformer 模型设计了 **Mulit-Flow Attention** 框架,用于实现 span-by-span 的填充式生成。 + +![multi-flow-attention](.meta/multi-flow-attention.png) + +## 预训练模型 + +我们发布了 **ERNIE-GEN _base_** 模型和 **ERNIE-GEN _large_** 模型。 预训练数据使用英文维基百科和 BookCorpus,总共16GB。此外,我们还发布了基于 160GB 语料预训练的**ERNIE-GEN _large_** 模型,此份语料也被用于 [RoBERTa](https://arxiv.org/abs/1907.11692) 和 [BART](https://arxiv.org/abs/1910.13461) 的预训练。 + +- [**ERNIE-GEN _base_**](https://ernie.bj.bcebos.com/ernie_gen_base.tgz) (_lowercased | 12-layer, 768-hidden, 12-heads, 110M parameters_) +- [**ERNIE-GEN _large_**](https://ernie.bj.bcebos.com/ernie_gen_large.tgz) (_lowercased | 24-layer, 1024-hidden, 16-heads, 340M parameters_) +- [**ERNIE-GEN _large with 160G_**](https://ernie.bj.bcebos.com/ernie_gen_large_160g.tgz) (_lowercased | 24-layer, 1024-hidden, 16-heads, 340M parameters_) + + +## 微调任务 + +我们在五个典型生成任务上与当前效果最优的生成预训练模型([UniLM](https://arxiv.org/abs/1905.03197)、[MASS](https://arxiv.org/abs/1905.02450)、[PEGASUS](https://arxiv.org/abs/1912.08777)、[BART](https://arxiv.org/abs/1910.13461)、[T5](https://arxiv.org/abs/1910.10683)等)进行对比, 包括生成式摘要 (Gigaword 和 CNN/DailyMail), 问题生成(SQuAD), 多轮对话(Persona-Chat) 和生成式多轮问答(CoQA)。 + +### 生成式摘要 + +- _**Gigaword**_ + +在 Gigaword-10k (Gigaword 的子集) 上的效果: + +| 模型 | 数据量 / 参数量 | Rouge-1 | Rouge-2 | Rouge-L | +| :-------------------------------------------------------- | :------------------------------: | :----------------------: | :----------------------: | :----------------------: | +| UniLM | 16G / 340M | 34.21 | 15.28 | 31.54 | +| **ENRIE-GEN** _base_ | 16G / 110M | 33.75 | 15.23 | 31.35 | +| **ERNIE-GEN** _large_ | 16G / 340M | 35.05 | 16.10 | 32.50 | +| **ERNIE-GEN** _large_ (160G) | 160G / 340M | **35.51** | **16.79** | **33.23** | + +在 Gigaword 上的效果: + +| 模型 | 数量 / 参数量 | Rouge-1 | Rouge-2 | Rouge-L | +| :-------------------------------------------------------- | :----------------------------: | :----------------------: | :----------------------: | :----------------------: | +| MASS | 18G / 160M | 38.73 | 19.71 | 35.96 | +| [BERTSHARE](https://arxiv.org/abs/1907.12461) | 16G / 110M | 38.13 | 19.81 | 35.62 | +| UniLM | 16G / 340M | 38.45 | 19.45 | 35.75 | +| PEGASUS (_C4_) | 750G / 568M | 38.75 | 19.96 | 36.14 | +| PEGASUS (_HugeNews_) | 3.8T / 568M | 39.12 | 19.86 | 36.24 | +| **ENRIE-GEN** _base_ | 16G / 110M | 38.83 | 20.04 | 36.20 | +| **ERNIE-GEN** _large_ | 16G / 340M | 39.25 | 20.25 | 36.53 | +| **ERNIE-GEN** _large_ (160G) | 160G / 340M | **39.46** | **20.34** | **36.74** | + +我们按照 UniLM 的方式处理了数据,下载链接 [Gigaword](https://ernie.bj.bcebos.com/gigaword.tgz)。 + +- _**CNN/Daily Mail**_ + +在 CNN/Daily Mail 上的效果: + +| 模型 | 数据量 /参数量| Rouge-1 | Rouge-2 | Rouge-L | +| :-------------------------------------------------------- | :-----------: | :----------------------: | :----------------------: | :----------------------: | +| MASS | 18G / 160M | 42.12 | 19.50 | 39.01 | +| UniLM | 16G / 340M | 43.33 | 20.21 | 40.51 | +| T5 _large_ | 750G / 340M | 42.50 | 20.68 | 39.75 | +| T5 _xlarge_ | 750G / 11B | 43.52 | **21.55** | 40.69 | +| BART | 160G / 400M | 44.16 | 21.28 | 40.90 | +| PEGASUS (_C4_) | 750G / 568M | 43.90 | 21.20 | 40.76 | +| PEGASUS (_HugeNews_) | 3.8T / 568M | 44.17 | 21.47 | 41.11 | +| **ENRIE-GEN** _base_ | 16G / 110M | 42.30 | 19.92 | 39.68 | +| **ENRIE-GEN** _large_ | 16G / 340M | 44.02 | 21.17 | 41.26 | +| **ENRIE-GEN** _large_ (160G) | 160G / 340M | **44.31** | 21.35 | **41.60** | + +我们按照 UniLM 的方式处理了数据,下载链接 [CNN/Daily Mail](https://ernie.bj.bcebos.com/cnndm.tgz)。 + +### 问题生成 + +- _**SQuAD**_ + +在 SQuAD 1.1 数据集上的效果(测试集划分按照 [[Du et al., 2017]](https://arxiv.org/abs/1705.00106)) : + +| 模型 | BLEU-4 | METEOR | Rouge-L | +| :----------------------------------------------------------- | :----------------------: | :----------------------: | :----------------------: | +| [SemQG](https://arxiv.org/abs/1909.06356) | 18.37 | 22.65 | 46.68 | +| UniLM _large_ (beam size=1) | 22.12 | 25.06 | 51.07 | +| **ENRIE-GEN** _base_ (beam size=1) | 22.28 | 25.13 | 50.38 | +| **ERNIE-GEN** _large_ (beam size=1) | 24.03 | 26.31 | 52.36 | +| **ERNIE-GEN** _large_ (beam size=5) | 25.40 | **26.92** | 52.84 | +| **ERNIE-GEN** _large_ (beam size=5) + (160G) | **25.41** | 26.77 | **52.91** | + +按照 [[Zhao et al., 2018]](https://www.aclweb.org/anthology/D18-1424/) 反向使用验证集和测试集,效果如下: + +| Model | BLEU-4 | METEOR | Rouge-L | +| :----------------------------------------------------------- | :----------------------: | :----------------------: | :----------------------: | +| [SemQG](https://arxiv.org/abs/1909.06356) | 20.76 | 24.20 | 48.91 | +| UniLM _large_ (beam size=1) | 23.75 | 25.61 | 52.04 | +| **ENRIE-GEN** _base_ (beam size=1) | 23.52 | 25.61 | 51.45 | +| **ERNIE-GEN** _large_ (beam size=1) | 25.57 | 26.89 | 53.31 | +| **ERNIE-GEN** _large_ (beam size=5) | 26.95 | **27.57** | 53.77 | +| **ERNIE-GEN** _large_ (beam size=5) + (160G) | **27.05** | 27.43 | **53.83** | + +*_我们增加了将 beam size 扩大到 5 的结果。_ + +我们按照 UniLM 的方式处理了数据,下载链接 [SQuAD](https://ernie.bj.bcebos.com/squad_qg.tgz)。 + +### 多轮对话 + +- _**Personal-Chat**_ + +| Model | BLEU-1 | BLEU-2 | Distinct-1 | Distinct-2 | +| :-------------------------------------------------------- | :---------------------: | :---------------------: | :-------------------------: | :---------------------------: | +| [LIC](https://arxiv.org/abs/1910.07931) | 40.5 | 32.0 | 0.019 | 0.113 | +| [PLATO](https://arxiv.org/abs/1910.07931) | 45.8 | 35.7 | 0.012 | 0.064 | +| [PLATO](https://arxiv.org/abs/1910.07931) _w/o latent_ | 40.6 | 31.5 | 0.021 | 0.121 | +| **ERNIE-GEN** _large_ | **46.8** | **36.4** | **0.023** | **0.168** | + +我们处理的数据下载链接 [Personal-Chat](https://ernie.bj.bcebos.com/persona_chat.tgz)。 + +### 生成式多轮问答 + +- _**CoQA**_ + +在 CoQA 验证集上的效果: + +| 模型 | F1-score | +| :-------------------------------------------------------- | :------: | +| [Seq2Seq](https://arxiv.org/abs/1910.07931) | 27.5 | +| [PGNet](https://arxiv.org/abs/1910.07931) | 45.4 | +| UniLM _large_ | 82.5 | +| **ERNIE-GEN** _large_ | **84.5** | + +我们对原始的 CoQA 数据集进行了处理,下载链接 [CoQA](https://ernie.bj.bcebos.com/coqa.tgz)。 + +此外,我们与同期的工作 [ProphetNet](https://arxiv.org/abs/2001.04063) 在 Gigaword,CNN/Daily Mail 和 SQuAD 三个数据集上进行了对比: + +- _**生成式摘要**_ + +| 模型 / 任务 | 数据量 / 参数量 | Gigaword |CNN/Daily Mail| +| :-------------------------------------------------------- | :------------------------------: | :----------------------: | :----------------------: | +| Metric | - | Rouge-1 / Rouge-2 / Rouge-L |Rouge-1 / Rouge-2 / Rouge-L| +| ProphetNet _large_ (160G) | 160G / 340M | **39.51** / **20.42** / 36.69 |44.20 / 21.17 / 41.30| +| **ERNIE-GEN** _large_ (160G) | 160G / 340M | 39.46 / 20.34 / **36.74** |**44.31** / **21.35** / **41.60**| + +- _**问题生成**_ + +| 模型 | 数据量 / 参数量 | BLEU-4 / METEOR / Rouge-L |BLEU-4 / METEOR / Rouge-L| +| :-------------------------------------------------------- | :------------------------------: | :----------------------: |:----------------------: | +| Data split | - | Original |Reversed dev-test| +| ProphetNet** _large_ (16G) | 16G / 340M | 25.01 / 26.83 / 52.57 |26.72 / **27.64** / **53.79** | +| **ERNIE-GEN** _large_ (16G) | 16G / 340M | **25.40** / **26.92** / **52.84** |**26.95** / 27.57 / **53.77**| + +## 使用说明 + +### 安装飞桨 + +我们的代码基于 Paddle Fluid 1.7 和 Python 2.7。 ERNIE-GEN 依赖的其他模块也列举在 `requirements.txt`,可以通过下面的指令安装: +```script +pip install -r requirements.txt +``` + +### 运行微调 +在运行 ERNIE-GEN 前,需要将 CUDA 、cuDNN 、NCCL2 的动态库路径添加到 LD_LIBRARY_PATH 。 我们把下游任务的参数配置文件放到了 `config/` ,可以简单地通过配置文件运行。 例如,您可以通过下面的指令在 Gigaword 数据集上微调 ERNIE-GEN base 模型: +```script +MODEL="base" # base or large or large_160g +TASK="gigaword" # cnndm, coqa, gigaword, squad_qg or persona-chat +sh run_seq2seq.sh ./configs/${MODEL}/${TASK}_conf +``` +训练和评估的日志在 `log/job.log.0`。 如果要在您自己的数据集上微调,可以参考我们提供的数据格式处理自己的数据。 + +我们的微调实验在 8 张 32GB 显存的英伟达 V100 GPU 上运行,如果您的 GPU 显存不够,可以减小配置文件中的 batch_size 。 + +**注意**: 训练时实际的 batch size 等于 `配置的 batch size * GPU 卡数`。 + +### 使用动态图 + +动态图版本的 ERNIE-GEN 代码更加简洁灵活,使用请参考 [ERNIE-GEN Dygraph](https://github.com/PaddlePaddle/ERNIE/tree/develop/experimental/seq2seq)。 + +### 中文生成任务使用 ERNIE 1.0 + +ERNIE-GEN 的代码兼容 [ERNIE 1.0 模型](https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz),修改配置文件中模型和数据相关的设置,就可以用 ERNIE 1.0 在中文生成任务上微调。 + +## 引用 + +可以按下面的格式引用我们的论文: + +``` +@article{xiao2020ernie-gen, + title={ERNIE-GEN: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation}, + author={Xiao, Dongling and Zhang, Han and Li, Yukun and Sun, Yu and Tian, Hao and Wu, Hua and Wang, Haifeng}, + journal={arXiv preprint arXiv:2001.11314}, + year={2020} +} +``` + + + + diff --git a/ernie-gen/configs/base/cnndm_conf b/ernie-gen/configs/base/cnndm_conf new file mode 100644 index 0000000000000000000000000000000000000000..1e5708756e451a89ae60d02f5356d3247ce94ae0 --- /dev/null +++ b/ernie-gen/configs/base/cnndm_conf @@ -0,0 +1,48 @@ +#load model +vocab_path="ernie_gen_base/vocab.txt" +config_path="ernie_gen_base/ernie_config.json" +init_model="ernie_gen_base/params" + +#input +max_src_len=640 +max_tgt_len=192 +tokenized_input="true" +continuous_position="true" +batch_size=8 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=128 +beam_size=5 +length_penalty=1.0 +use_multi_gpu_test="true" + +#train +epoch=30 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.1 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=5e-5 +#noise +random_noise="true" +noise_prob=0.7 + +#dataset +data_path="./datasets/cnndm/" +train_set="train.tsv" +dev_set="dev.2k.tsv" +pred_set="test.tsv" +do_train="true" +do_val="true" +do_test="false" +do_pred="true" + +#evaluate +eval_script="sh ./eval/tasks/cnndm/eval.sh" +eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/base/gigaword-10k_conf b/ernie-gen/configs/base/gigaword-10k_conf new file mode 100644 index 0000000000000000000000000000000000000000..58b9e9735b278344c1457ce8edf585ecebdd90c7 --- /dev/null +++ b/ernie-gen/configs/base/gigaword-10k_conf @@ -0,0 +1,48 @@ +#load model +vocab_path="ernie_gen_base/vocab.txt" +config_path="ernie_gen_base/ernie_config.json" +init_model="ernie_gen_base/params" + +#input +max_src_len=192 +max_tgt_len=64 +tokenized_input="true" +continuous_position="true" +batch_size=16 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=32 +beam_size=5 +length_penalty=0.6 +use_multi_gpu_test="true" + +#train +epoch=30 +weight_decay=0.01 +label_smooth=0.1 +save_and_valid_by_epoch="true" +hidden_dropout_prob=0.1 +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=1.25e-5 +#noise +random_noise="true" +noise_prob=0.5 + +#dataset +data_path="./datasets/gigaword/" +train_set="train.10k.tsv" +dev_set="dev.20k.tsv" +test_set="test.tsv" +do_train="true" +do_val="true" +do_test="true" +do_pred="false" + +#evaluate +eval_script="sh ./eval/tasks/gigaword/eval.sh" +eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/base/gigaword_conf b/ernie-gen/configs/base/gigaword_conf new file mode 100644 index 0000000000000000000000000000000000000000..acbb5e698b1a7e7ddf7cc3eb6c7104134d6c4ba6 --- /dev/null +++ b/ernie-gen/configs/base/gigaword_conf @@ -0,0 +1,48 @@ +#load model +vocab_path="ernie_gen_base/vocab.txt" +config_path="ernie_gen_base/ernie_config.json" +init_model="ernie_gen_base/params" + +#input +max_src_len=192 +max_tgt_len=64 +tokenized_input="true" +continuous_position="true" +batch_size=16 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=32 +beam_size=5 +length_penalty=0.6 +use_multi_gpu_test="true" + +#train +epoch=10 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.1 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=3e-5 +#noise +random_noise="true" +noise_prob=0.5 + +#dataset +data_path="./datasets/gigaword/" +train_set="train.tsv" +dev_set="dev.20k.tsv" +test_set="test.tsv" +do_train="true" +do_val="true" +do_test="true" +do_pred="false" + +#evaluate +eval_script="sh ./eval/tasks/gigaword/eval.sh" +eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/base/squad-qg_conf b/ernie-gen/configs/base/squad-qg_conf new file mode 100644 index 0000000000000000000000000000000000000000..0dbe30024538d1590d74f9321031deabb215dea5 --- /dev/null +++ b/ernie-gen/configs/base/squad-qg_conf @@ -0,0 +1,47 @@ +#load model +vocab_path="ernie_gen_base/vocab.txt" +config_path="ernie_gen_base/ernie_config.json" +init_model="ernie_gen_base/params" + +#input +max_src_len=512 +max_tgt_len=96 +tokenized_input="true" +continuous_position="true" +batch_size=4 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=48 +beam_size=5 +length_penalty=1.0 +use_multi_gpu_test="true" + +#train +epoch=10 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.1 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=2.5e-5 +#noise +random_noise="true" +noise_prob=0.7 + +#dataset +data_path="./datasets/squad_qg/" +train_set="train.tsv" +dev_set="dev.tsv" +test_set="test.tsv" +do_train="true" +do_val="true" +do_test="true" + +#evaluate +eval_script="sh ./eval/tasks/squad_qg/eval.sh" +eval_mertrics="Bleu_4,METEOR,ROUGE_L" diff --git a/ernie-gen/configs/large/cnndm_conf b/ernie-gen/configs/large/cnndm_conf new file mode 100644 index 0000000000000000000000000000000000000000..a6b136a7b7a5a8bf53cdb690cbe7cea8440ec8a5 --- /dev/null +++ b/ernie-gen/configs/large/cnndm_conf @@ -0,0 +1,48 @@ +#load model +vocab_path="ernie_gen_large/vocab.txt" +config_path="ernie_gen_large/ernie_config.json" +init_model="ernie_gen_large/params" + +#input +max_src_len=640 +max_tgt_len=192 +tokenized_input="true" +continuous_position="true" +batch_size=4 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=128 +beam_size=5 +length_penalty=1.0 +use_multi_gpu_test="true" + +#train +epoch=20 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.1 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=4e-5 +#noise +random_noise="true" +noise_prob=0.7 + +#dataset +data_path="./datasets/cnndm/" +train_set="train.tsv" +dev_set="dev.2k.tsv" +pred_set="test.tsv" +do_train="true" +do_val="true" +do_test="false" +do_pred="true" + +#evaluate +eval_script="sh ./eval/tasks/cnndm/eval.sh" +eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large/coqa_conf b/ernie-gen/configs/large/coqa_conf new file mode 100644 index 0000000000000000000000000000000000000000..fb8178733e62c13acf355002d021df43727a3e95 --- /dev/null +++ b/ernie-gen/configs/large/coqa_conf @@ -0,0 +1,52 @@ +#load model +vocab_path="ernie_gen_large/vocab.txt" +config_path="ernie_gen_large/ernie_config.json" +init_model="ernie_gen_large/params" + +#for multi-turn dialog/qa +task_type="dialog" +role_type_size=3 +turn_type_size=16 + +#input +max_src_len=480 +max_tgt_len=32 +tokenized_input="true" +continuous_position="true" +batch_size=4 +in_tokens="false" +#tgt_type_id=1 + +#decode +do_decode="true" +max_dec_len=30 +beam_size=3 +length_penalty=0.0 +use_multi_gpu_test="true" + +#train +epoch=10 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.1 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=1e-5 +#noise +random_noise="false" +noise_prob=0.5 + +#dataset +data_path="./datasets/coqa/" +train_set="train.tsv" +dev_set="dev.tsv" +do_train="true" +do_val="true" +do_test="false" +do_pred="false" + +#evaluate +eval_script="sh ./eval/tasks/coqa/eval.sh" +eval_mertrics="f1" diff --git a/ernie-gen/configs/large/gigaword-10k_conf b/ernie-gen/configs/large/gigaword-10k_conf new file mode 100644 index 0000000000000000000000000000000000000000..8ba11c13c423a8e2c6692ff903171ba4adbb2378 --- /dev/null +++ b/ernie-gen/configs/large/gigaword-10k_conf @@ -0,0 +1,48 @@ +#load model +vocab_path="ernie_gen_large/vocab.txt" +config_path="ernie_gen_large/ernie_config.json" +init_model="ernie_gen_large/params" + +#input +max_src_len=192 +max_tgt_len=64 +tokenized_input="true" +continuous_position="true" +batch_size=16 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=32 +beam_size=5 +length_penalty=0.6 +use_multi_gpu_test="true" + +#train +epoch=30 +weight_decay=0.01 +label_smooth=0.1 +save_and_valid_by_epoch="true" +hidden_dropout_prob=0.1 +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=1e-5 +#noise +random_noise="true" +noise_prob=0.7 + +#dataset +data_path="./datasets/gigaword/" +train_set="train.10k.tsv" +dev_set="dev.20k.tsv" +test_set="test.tsv" +do_train="true" +do_val="true" +do_test="true" +do_pred="false" + +#evaluate +eval_script="sh ./eval/tasks/gigaword/eval.sh" +eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large/gigaword_conf b/ernie-gen/configs/large/gigaword_conf new file mode 100644 index 0000000000000000000000000000000000000000..69fd8e9482c2df434584d331ab91f16d404a107c --- /dev/null +++ b/ernie-gen/configs/large/gigaword_conf @@ -0,0 +1,48 @@ +#load model +vocab_path="ernie_gen_large/vocab.txt" +config_path="ernie_gen_large/ernie_config.json" +init_model="ernie_gen_large/params" + +#input +max_src_len=192 +max_tgt_len=64 +tokenized_input="true" +continuous_position="true" +batch_size=16 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=32 +beam_size=5 +length_penalty=0.6 +use_multi_gpu_test="true" + +#train +epoch=5 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.2 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=3e-5 +#noise +random_noise="true" +noise_prob=0.6 + +#dataset +data_path="./datasets/gigaword/" +train_set="train.tsv" +dev_set="dev.20k.tsv" +test_set="test.tsv" +do_train="true" +do_val="true" +do_test="true" +do_pred="false" + +#evaluate +eval_script="sh ./eval/tasks/gigaword/eval.sh" +eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large/persona-chat_conf b/ernie-gen/configs/large/persona-chat_conf new file mode 100644 index 0000000000000000000000000000000000000000..66e2379ff34de4d4781b814809f7f70b6fbbbe34 --- /dev/null +++ b/ernie-gen/configs/large/persona-chat_conf @@ -0,0 +1,53 @@ +#load model +vocab_path="ernie_gen_large/vocab.txt" +config_path="ernie_gen_large/ernie_config.json" +init_model="ernie_gen_large/params" + +#for multi-turn dialog/qa +task_type="dialog" +role_type_size=3 +turn_type_size=16 + +#input +max_src_len=472 +max_tgt_len=40 +tokenized_input="true" +continuous_position="true" +batch_size=8 +in_tokens="false" + +#decode +do_decode="true" +max_dec_len=32 +beam_size=10 +length_penalty=1.3 +use_multi_gpu_test="true" + +#train +epoch=30 +weight_decay=0.01 +label_smooth=0.0 +hidden_dropout_prob=0.1 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=1e-4 +#noise +random_noise="false" +noise_prob=0.0 + +#dataset +data_path="./datasets/persona_chat/" +train_set="train.tsv" +dev_set="dev.2k.tsv" +pred_set="test.tsv" +do_train="true" +do_val="true" +do_test="false" +do_pred="true" +do_decode="true" + +#evaluate +eval_script="sh ./eval/tasks/persona_chat/eval.sh" +eval_mertrics="bleu_1,bleu_2,distinct_1,distinct_2" diff --git a/ernie-gen/configs/large/squad-qg_conf b/ernie-gen/configs/large/squad-qg_conf new file mode 100644 index 0000000000000000000000000000000000000000..c4921dbb651ad33b17b2ff35ff242045bc669ccc --- /dev/null +++ b/ernie-gen/configs/large/squad-qg_conf @@ -0,0 +1,47 @@ +#load model +vocab_path="ernie_gen_large/vocab.txt" +config_path="ernie_gen_large/ernie_config.json" +init_model="ernie_gen_large/params" + +#input +max_src_len=512 +max_tgt_len=96 +tokenized_input="true" +continuous_position="true" +batch_size=4 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=48 +beam_size=5 +length_penalty=1.0 +use_multi_gpu_test="true" + +#train +epoch=10 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.2 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=1e-5 +#noise +random_noise="true" +noise_prob=0.7 + +#dataset +data_path="./datasets/squad_qg/" +train_set="train.tsv" +dev_set="dev.tsv" +test_set="test.tsv" +do_train="true" +do_val="true" +do_test="true" + +#evaluate +eval_script="sh ./eval/tasks/squad_qg/eval.sh" +eval_mertrics="Bleu_4,METEOR,ROUGE_L" diff --git a/ernie-gen/configs/large_160g/cnndm_conf b/ernie-gen/configs/large_160g/cnndm_conf new file mode 100644 index 0000000000000000000000000000000000000000..4b0bd7dbe0d3fa97a7df2fcad879266af3f532e0 --- /dev/null +++ b/ernie-gen/configs/large_160g/cnndm_conf @@ -0,0 +1,48 @@ +#load model +vocab_path="ernie_gen_large_160g/vocab.txt" +config_path="ernie_gen_large_160g/ernie_config.json" +init_model="ernie_gen_large_160g/params" + +#input +max_src_len=640 +max_tgt_len=192 +tokenized_input="true" +continuous_position="true" +batch_size=4 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=128 +beam_size=5 +length_penalty=1.2 +use_multi_gpu_test="true" + +#train +epoch=17 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.1 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.02 +lr_scheduler="linear_warmup_decay" +learning_rate=4e-5 +#noise +random_noise="true" +noise_prob=0.7 + +#dataset +data_path="./datasets/cnndm/" +train_set="train.tsv" +dev_set="dev.2k.tsv" +pred_set="test.tsv" +do_train="true" +do_val="true" +do_test="false" +do_pred="true" + +#evaluate +eval_script="sh ./eval/tasks/cnndm/eval.sh" +eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large_160g/coqa_conf b/ernie-gen/configs/large_160g/coqa_conf new file mode 100644 index 0000000000000000000000000000000000000000..fb8178733e62c13acf355002d021df43727a3e95 --- /dev/null +++ b/ernie-gen/configs/large_160g/coqa_conf @@ -0,0 +1,52 @@ +#load model +vocab_path="ernie_gen_large/vocab.txt" +config_path="ernie_gen_large/ernie_config.json" +init_model="ernie_gen_large/params" + +#for multi-turn dialog/qa +task_type="dialog" +role_type_size=3 +turn_type_size=16 + +#input +max_src_len=480 +max_tgt_len=32 +tokenized_input="true" +continuous_position="true" +batch_size=4 +in_tokens="false" +#tgt_type_id=1 + +#decode +do_decode="true" +max_dec_len=30 +beam_size=3 +length_penalty=0.0 +use_multi_gpu_test="true" + +#train +epoch=10 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.1 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=1e-5 +#noise +random_noise="false" +noise_prob=0.5 + +#dataset +data_path="./datasets/coqa/" +train_set="train.tsv" +dev_set="dev.tsv" +do_train="true" +do_val="true" +do_test="false" +do_pred="false" + +#evaluate +eval_script="sh ./eval/tasks/coqa/eval.sh" +eval_mertrics="f1" diff --git a/ernie-gen/configs/large_160g/gigaword-10k_conf b/ernie-gen/configs/large_160g/gigaword-10k_conf new file mode 100644 index 0000000000000000000000000000000000000000..89a4f90350817be87b33c02b6fef55f93b58d3a9 --- /dev/null +++ b/ernie-gen/configs/large_160g/gigaword-10k_conf @@ -0,0 +1,48 @@ +#load model +vocab_path="ernie_gen_large_160g/vocab.txt" +config_path="ernie_gen_large_160g/ernie_config.json" +init_model="ernie_gen_large_160g/params" + +#input +max_src_len=192 +max_tgt_len=64 +tokenized_input="true" +continuous_position="true" +batch_size=16 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=32 +beam_size=5 +length_penalty=0.6 +use_multi_gpu_test="true" + +#train +epoch=30 +weight_decay=0.01 +label_smooth=0.1 +save_and_valid_by_epoch="true" +hidden_dropout_prob=0.1 +#lr +warmup_proportion=0.15 +lr_scheduler="linear_warmup_decay" +learning_rate=7.5e-6 +#noise +random_noise="true" +noise_prob=0.65 + +#dataset +data_path="./datasets/gigaword/" +train_set="train.10k.tsv" +dev_set="dev.20k.tsv" +test_set="test.tsv" +do_train="true" +do_val="true" +do_test="true" +do_pred="false" + +#evaluate +eval_script="sh ./eval/tasks/gigaword/eval.sh" +eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large_160g/gigaword_conf b/ernie-gen/configs/large_160g/gigaword_conf new file mode 100644 index 0000000000000000000000000000000000000000..4d31e9a19b6b73c4eaeba2e448de9d9f96c374ea --- /dev/null +++ b/ernie-gen/configs/large_160g/gigaword_conf @@ -0,0 +1,48 @@ +#load model +vocab_path="ernie_gen_large_160g/vocab.txt" +config_path="ernie_gen_large_160g/ernie_config.json" +init_model="ernie_gen_large_160g/params" + +#input +max_src_len=192 +max_tgt_len=64 +tokenized_input="true" +continuous_position="true" +batch_size=16 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=32 +beam_size=6 +length_penalty=0.7 +use_multi_gpu_test="true" + +#train +epoch=5 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.2 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=3e-5 +#noise +random_noise="true" +noise_prob=0.6 + +#dataset +data_path="./datasets/gigaword/" +train_set="train.tsv" +dev_set="dev.20k.tsv" +test_set="test.tsv" +do_train="true" +do_val="true" +do_test="true" +do_pred="false" + +#evaluate +eval_script="sh ./eval/tasks/gigaword/eval.sh" +eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large_160g/persona-chat_conf b/ernie-gen/configs/large_160g/persona-chat_conf new file mode 100644 index 0000000000000000000000000000000000000000..66e2379ff34de4d4781b814809f7f70b6fbbbe34 --- /dev/null +++ b/ernie-gen/configs/large_160g/persona-chat_conf @@ -0,0 +1,53 @@ +#load model +vocab_path="ernie_gen_large/vocab.txt" +config_path="ernie_gen_large/ernie_config.json" +init_model="ernie_gen_large/params" + +#for multi-turn dialog/qa +task_type="dialog" +role_type_size=3 +turn_type_size=16 + +#input +max_src_len=472 +max_tgt_len=40 +tokenized_input="true" +continuous_position="true" +batch_size=8 +in_tokens="false" + +#decode +do_decode="true" +max_dec_len=32 +beam_size=10 +length_penalty=1.3 +use_multi_gpu_test="true" + +#train +epoch=30 +weight_decay=0.01 +label_smooth=0.0 +hidden_dropout_prob=0.1 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.1 +lr_scheduler="linear_warmup_decay" +learning_rate=1e-4 +#noise +random_noise="false" +noise_prob=0.0 + +#dataset +data_path="./datasets/persona_chat/" +train_set="train.tsv" +dev_set="dev.2k.tsv" +pred_set="test.tsv" +do_train="true" +do_val="true" +do_test="false" +do_pred="true" +do_decode="true" + +#evaluate +eval_script="sh ./eval/tasks/persona_chat/eval.sh" +eval_mertrics="bleu_1,bleu_2,distinct_1,distinct_2" diff --git a/ernie-gen/configs/large_160g/squad-qg_conf b/ernie-gen/configs/large_160g/squad-qg_conf new file mode 100644 index 0000000000000000000000000000000000000000..85953f8a705deb85ccd50070175c65777d271b01 --- /dev/null +++ b/ernie-gen/configs/large_160g/squad-qg_conf @@ -0,0 +1,47 @@ +#load model +vocab_path="ernie_gen_large_160g/vocab.txt" +config_path="ernie_gen_large_160g/ernie_config.json" +init_model="ernie_gen_large_160g/params" + +#input +max_src_len=512 +max_tgt_len=96 +tokenized_input="true" +continuous_position="true" +batch_size=4 +in_tokens="false" +tgt_type_id=3 + +#decode +do_decode="true" +max_dec_len=48 +beam_size=5 +length_penalty=1.0 +use_multi_gpu_test="true" + +#train +epoch=10 +weight_decay=0.01 +label_smooth=0.1 +hidden_dropout_prob=0.2 +save_and_valid_by_epoch="true" +#lr +warmup_proportion=0.25 +lr_scheduler="linear_warmup_decay" +learning_rate=1.25e-5 +#noise +random_noise="true" +noise_prob=0.7 + +#dataset +data_path="./datasets/squad_qg/" +train_set="train.tsv" +dev_set="dev.tsv" +test_set="test.tsv" +do_train="true" +do_val="true" +do_test="true" + +#evaluate +eval_script="sh ./eval/tasks/squad_qg/eval.sh" +eval_mertrics="Bleu_4,METEOR,ROUGE_L" diff --git a/ernie-gen/env.sh b/ernie-gen/env.sh new file mode 100644 index 0000000000000000000000000000000000000000..66aae09ab447b622583ab6adec5c59a303ce80f8 --- /dev/null +++ b/ernie-gen/env.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash + +set -xu + +function check_iplist() { + if [[ ${iplist:-""} == "" ]] ;then + iplist=`hostname -i` + fi + + export PADDLE_PSERVER_PORT=9184 + export PADDLE_TRAINER_IPS=${iplist} + export PADDLE_CURRENT_IP=`hostname -i` + + iparray=(${iplist//,/ }) + for i in "${!iparray[@]}"; do + if [ ${iparray[$i]} == ${PADDLE_CURRENT_IP} ]; then + export PADDLE_TRAINER_ID=$i + fi + done + + export TRAINING_ROLE=TRAINER + export PADDLE_INIT_TRAINER_COUNT=${#iparray[@]} + export PADDLE_PORT=${PADDLE_PSERVER_PORT} + export PADDLE_TRAINERS=${PADDLE_TRAINER_IPS} + export POD_IP=${PADDLE_CURRENT_IP} + export PADDLE_TRAINERS_NUM=${PADDLE_INIT_TRAINER_COUNT} + export PADDLE_IS_LOCAL=0 + + #paddle debug envs + export GLOG_v=0 + export GLOG_logtostderr=1 + + #nccl debug envs + export NCCL_DEBUG=INFO + export NCCL_IB_GID_INDEX=3 +} + diff --git a/ernie-gen/eval/__init__.py b/ernie-gen/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ernie-gen/eval/gen_eval.py b/ernie-gen/eval/gen_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea0edf0e70e454c475c59d26d62b42c51560083 --- /dev/null +++ b/ernie-gen/eval/gen_eval.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ultis help and eval functions for gen .""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import json +import math +import subprocess + +from reader.tokenization import BasicTokenizer + +class GenerationEval(object): + """GenerationEval""" + + def __init__(self, args, merge_subword=None): + self.basic_tokenizer = BasicTokenizer(do_lower_case=True) + self.merge_subword = merge_subword + self.eval_script = args.eval_script.split(" ") + self.eval_mertrics = args.eval_mertrics.split(",") if args.eval_mertrics else [] + self.tokenized_input = args.tokenized_input + + def eval(self, output_file, phase="", features=None): + """run eval""" + eval_res = {} + if self.eval_script: + eval_res = subprocess.check_output(self.eval_script + [output_file, phase]) + eval_res = json.loads(eval_res) + else: + preds = [] + for line in open(output_file): + preds.append(self.basic_tokenizer.tokenize(line.strip())) + + refs = [] + for id in sorted(features.keys()): + if self.tokenized_input: + ref = features[id].tgt.decode("utf8").split(" ") + refs.append([self.merge_subword(ref)]) + else: + refs.append([self.basic_tokenizer.tokenize(features[id].tgt)]) + + for mertric in self.eval_mertrics: + eval_func = getattr(self, mertric, None) + if eval_func: + eval_res[mertric] = eval_func(refs, preds) + + ret = [] + for mertric in self.eval_mertrics: + mertric_res = eval_res.get(mertric, None) + if mertric_res is None: + raise Exception("Eval mertric: %s is not supported" % mertric) + ret.append("%s: %f" % (mertric, mertric_res)) + + return ", ".join(ret) + + def bleu(self, refs, preds): + """bleu mertric""" + return _compute_bleu(refs, preds, max_order=4)[0] + + +def _get_ngrams(segment, max_order): + ngram_counts = collections.Counter() + for order in range(1, max_order + 1): + for i in range(0, len(segment) - order + 1): + ngram = tuple(segment[i: i + order]) + ngram_counts[ngram] += 1 + return ngram_counts + + +def _compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False): + matches_by_order = [0] * max_order + possible_matches_by_order = [0] * max_order + reference_length = 0 + translation_length = 0 + for (references, translation) in zip(reference_corpus, translation_corpus): + reference_length += min(len(r) for r in references) + translation_length += len(translation) + + merged_ref_ngram_counts = collections.Counter() + for reference in references: + merged_ref_ngram_counts |= _get_ngrams(reference, max_order) + translation_ngram_counts = _get_ngrams(translation, max_order) + overlap = translation_ngram_counts & merged_ref_ngram_counts + for ngram in overlap: + matches_by_order[len(ngram) - 1] += overlap[ngram] + for order in range(1, max_order + 1): + possible_matches = len(translation) - order + 1 + if possible_matches > 0: + possible_matches_by_order[order - 1] += possible_matches + + precisions = [0] * max_order + for i in range(0, max_order): + if smooth: + precisions[i] = ((matches_by_order[i] + 1.) / + (possible_matches_by_order[i] + 1.)) + else: + if possible_matches_by_order[i] > 0: + precisions[i] = (float(matches_by_order[i]) / + possible_matches_by_order[i]) + else: + precisions[i] = 0.0 + + if min(precisions) > 0: + p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) + geo_mean = math.exp(p_log_sum) + else: + geo_mean = 0 + + ratio = float(translation_length) / reference_length + + if ratio > 1.0: + bp = 1. + else: + bp = math.exp(1 - 1. / (ratio + 1e-4)) + + bleu = geo_mean * bp + ret = [bleu, precisions, bp, ratio, translation_length, reference_length] + return ret diff --git a/ernie-gen/eval/tasks/cnndm/cnndm/__init__.py b/ernie-gen/eval/tasks/cnndm/cnndm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ernie-gen/eval/tasks/cnndm/cnndm/bs_pyrouge.py b/ernie-gen/eval/tasks/cnndm/cnndm/bs_pyrouge.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4d350cd467311b2c4dfb835f3927b6d07611ba --- /dev/null +++ b/ernie-gen/eval/tasks/cnndm/cnndm/bs_pyrouge.py @@ -0,0 +1,644 @@ +from __future__ import print_function, unicode_literals, division + +import os +import re +import codecs +import platform + +from subprocess import check_output +from tempfile import mkdtemp +from functools import partial + +try: + from configparser import ConfigParser +except ImportError: + from ConfigParser import ConfigParser + +from pyrouge.utils import log +from pyrouge.utils.file_utils import verify_dir + + +REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}", + "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'} + + +def clean(x): + return re.sub( + r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''", + lambda m: REMAP.get(m.group()), x) + + +class DirectoryProcessor: + + @staticmethod + def process(input_dir, output_dir, function): + """ + Apply function to all files in input_dir and save the resulting ouput + files in output_dir. + + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + logger = log.get_global_console_logger() + logger.info("Processing files in {}.".format(input_dir)) + input_file_names = os.listdir(input_dir) + for input_file_name in input_file_names: + input_file = os.path.join(input_dir, input_file_name) + with codecs.open(input_file, "r", encoding="UTF-8") as f: + input_string = f.read() + output_string = function(input_string) + output_file = os.path.join(output_dir, input_file_name) + with codecs.open(output_file, "w", encoding="UTF-8") as f: + f.write(clean(output_string.lower())) + logger.info("Saved processed files to {}.".format(output_dir)) + + +class Rouge155(object): + """ + This is a wrapper for the ROUGE 1.5.5 summary evaluation package. + This class is designed to simplify the evaluation process by: + + 1) Converting summaries into a format ROUGE understands. + 2) Generating the ROUGE configuration file automatically based + on filename patterns. + + This class can be used within Python like this: + + rouge = Rouge155() + rouge.system_dir = 'test/systems' + rouge.model_dir = 'test/models' + + # The system filename pattern should contain one group that + # matches the document ID. + rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html' + + # The model filename pattern has '#ID#' as a placeholder for the + # document ID. If there are multiple model summaries, pyrouge + # will use the provided regex to automatically match them with + # the corresponding system summary. Here, [A-Z] matches + # multiple model summaries for a given #ID#. + rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html' + + rouge_output = rouge.evaluate() + print(rouge_output) + output_dict = rouge.output_to_dict(rouge_ouput) + print(output_dict) + -> {'rouge_1_f_score': 0.95652, + 'rouge_1_f_score_cb': 0.95652, + 'rouge_1_f_score_ce': 0.95652, + 'rouge_1_precision': 0.95652, + [...] + + + To evaluate multiple systems: + + rouge = Rouge155() + rouge.system_dir = '/PATH/TO/systems' + rouge.model_dir = 'PATH/TO/models' + for system_id in ['id1', 'id2', 'id3']: + rouge.system_filename_pattern = \ + 'SL.P/.10.R.{}.SL062003-(\d+).html'.format(system_id) + rouge.model_filename_pattern = \ + 'SL.P.10.R.[A-Z].SL062003-#ID#.html' + rouge_output = rouge.evaluate(system_id) + print(rouge_output) + + """ + + def __init__(self, rouge_dir=None, rouge_args=None, temp_dir=None): + """ + Create a Rouge155 object. + + rouge_dir: Directory containing Rouge-1.5.5.pl + rouge_args: Arguments to pass through to ROUGE if you + don't want to use the default pyrouge + arguments. + + """ + self.temp_dir = temp_dir + self.log = log.get_global_console_logger() + self.__set_dir_properties() + self._config_file = None + self._settings_file = self.__get_config_path() + self.__set_rouge_dir(rouge_dir) + self.args = self.__clean_rouge_args(rouge_args) + self._system_filename_pattern = None + self._model_filename_pattern = None + + def save_home_dir(self): + config = ConfigParser() + section = 'pyrouge settings' + config.add_section(section) + config.set(section, 'home_dir', self._home_dir) + with open(self._settings_file, 'w') as f: + config.write(f) + self.log.info("Set ROUGE home directory to {}.".format(self._home_dir)) + + @property + def settings_file(self): + """ + Path of the setttings file, which stores the ROUGE home dir. + + """ + return self._settings_file + + @property + def bin_path(self): + """ + The full path of the ROUGE binary (although it's technically + a script), i.e. rouge_home_dir/ROUGE-1.5.5.pl + + """ + if self._bin_path is None: + raise Exception( + "ROUGE path not set. Please set the ROUGE home directory " + "and ensure that ROUGE-1.5.5.pl exists in it.") + return self._bin_path + + @property + def system_filename_pattern(self): + """ + The regular expression pattern for matching system summary + filenames. The regex string. + + E.g. "SL.P.10.R.11.SL062003-(\d+).html" will match the system + filenames in the SPL2003/system folder of the ROUGE SPL example + in the "sample-test" folder. + + Currently, there is no support for multiple systems. + + """ + return self._system_filename_pattern + + @system_filename_pattern.setter + def system_filename_pattern(self, pattern): + self._system_filename_pattern = pattern + + @property + def model_filename_pattern(self): + """ + The regular expression pattern for matching model summary + filenames. The pattern needs to contain the string "#ID#", + which is a placeholder for the document ID. + + E.g. "SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model + filenames in the SPL2003/system folder of the ROUGE SPL + example in the "sample-test" folder. + + "#ID#" is a placeholder for the document ID which has been + matched by the "(\d+)" part of the system filename pattern. + The different model summaries for a given document ID are + matched by the "[A-Z]" part. + + """ + return self._model_filename_pattern + + @model_filename_pattern.setter + def model_filename_pattern(self, pattern): + self._model_filename_pattern = pattern + + @property + def config_file(self): + return self._config_file + + @config_file.setter + def config_file(self, path): + config_dir, _ = os.path.split(path) + verify_dir(config_dir, "configuration file") + self._config_file = path + + def split_sentences(self): + """ + ROUGE requires texts split into sentences. In case the texts + are not already split, this method can be used. + + """ + from pyrouge.utils.sentence_splitter import PunktSentenceSplitter + self.log.info("Splitting sentences.") + ss = PunktSentenceSplitter() + + def sent_split_to_string(s): return "\n".join(ss.split(s)) + process_func = partial( + DirectoryProcessor.process, function=sent_split_to_string) + self.__process_summaries(process_func) + + @staticmethod + def convert_summaries_to_rouge_format(input_dir, output_dir): + """ + Convert all files in input_dir into a format ROUGE understands + and saves the files to output_dir. The input files are assumed + to be plain text with one sentence per line. + + input_dir: Path of directory containing the input files. + output_dir: Path of directory in which the converted files + will be saved. + + """ + DirectoryProcessor.process( + input_dir, output_dir, Rouge155.convert_text_to_rouge_format) + + @staticmethod + def convert_text_to_rouge_format(text, title="dummy title"): + """ + Convert a text to a format ROUGE understands. The text is + assumed to contain one sentence per line. + + text: The text to convert, containg one sentence per line. + title: Optional title for the text. The title will appear + in the converted file, but doesn't seem to have + any other relevance. + + Returns: The converted text as string. + + """ + sentences = text.split("\n") + sent_elems = [ + "[{i}] " + "{text}".format(i=i, text=sent) + for i, sent in enumerate(sentences, start=1)] + html = """ +
+{name}
".format( + id=system_id, name=system_filename) + + model_elems = ["systemX
+#
systemX
+#
systemX
+#
systemX
+#
systemX
+#