diff --git a/README.md b/README.md index dc48329272ea0ce766dc9d1557dbeb4365356bef..e0aa4032a176911ecb0917550dd4f2d7eca79927 100644 --- a/README.md +++ b/README.md @@ -17,10 +17,10 @@ ERNIE是百度开创性提出的基于知识增强的持续学习语义理解框 - 易于部署。 - 通过Aistudio 教程快速入门NLP。 - 向后兼容老版 checkpoint。 - - `ERNIE-GEN` 模型正式开源! ([点击进入](ernie-gen)) + - `ERNIE-GEN` 模型正式开源! ([点击进入](https://github.com/PaddlePaddle/ERNIE/tree/repro/ernie-gen)) - 最强文本生成预训练模型正式开源,相关工作已被 `IJCAI-2020` 收录。 - 首次把 ERNIE 预训练技术能力扩展至文本生成领域,在多个典型任务上取得最佳。 - - 您现在即可下载论文报告的所有模型(包含 [`base/large/large-160G`](ernie-gen/README.zh.md#预训练模型))。 + - 您现在即可下载论文报告的所有模型(包含 [`base/large/large-160G`](https://github.com/PaddlePaddle/ERNIE/tree/repro/ernie-gen/README.zh.md#预训练模型))。 - 首次在预训练阶段加入span-by-span 生成任务,让模型每次能够生成一个语义完整的片段。 - 提出填充式生成机制和噪声感知机制来缓解曝光偏差问题。 - 精巧的 Mulit-Flow Attention 实现框架。 diff --git a/ernie-gen/.meta/ernie-gen-paper.png b/ernie-gen/.meta/ernie-gen-paper.png deleted file mode 100644 index ff8fcd48fc7ac28557634a42ea020ecc37edfb4c..0000000000000000000000000000000000000000 Binary files a/ernie-gen/.meta/ernie-gen-paper.png and /dev/null differ diff --git a/ernie-gen/.meta/multi-flow-attention.png b/ernie-gen/.meta/multi-flow-attention.png deleted file mode 100644 index 9802adc0fa85621c4e52dd61cc62bcbf204fa529..0000000000000000000000000000000000000000 Binary files a/ernie-gen/.meta/multi-flow-attention.png and /dev/null differ diff --git a/ernie-gen/README.md b/ernie-gen/README.md deleted file mode 100644 index 0976d283de96b55ba483f1d3af8c623d9a03d2a1..0000000000000000000000000000000000000000 --- a/ernie-gen/README.md +++ /dev/null @@ -1,227 +0,0 @@ -English | [简体中文](./README.zh.md) - -## _ERNIE-GEN_: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation - -- [Proposed Generation Framework](#proposed-generation-framework) -- [Pre-trained Models](#pre-trained-models) -- [Fine-tuning on Downstream Tasks](#fine-tuning-on-downstream-tasks) - * [Abstractive Summarization](#abstractive-summarization) - * [Question Generation](#question-generation) - * [Generative Dialogue Response](#generative-dialogue-response) - * [Generative Question Answering](#generative-question-answering) -- [Usage](#usage) - * [Install PaddlePaddle](#install-paddlepaddle) - * [Fine-tuning](#fine-tuning) - * [Employ Dynamic Computation Graph](#employ-dynamic-computation-graph) - * [The ERNIE 1.0 is avaliable](#the-ernie-10-is-avaliable-for-chinese-generation-tasks) -- [Citation](#citation) - -For technical description of the algorithm, please see our paper: ->[_**ERNIE-GEN:An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation**_](https://arxiv.org/abs/2001.11314.pdf) -> ->Dongling Xiao\*, Han Zhang\*, Yukun Li, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang (\* : equal contribution) -> ->Preprint January 2020 -> ->Accepted by **IJCAI-2020** - -![ERNIE-GEN](https://img.shields.io/badge/Pretraining-Generation-green) ![Gigaword](https://img.shields.io/badge/Abstractive%20Summarization-Gigaword-yellow) ![Gigaword](https://img.shields.io/badge/Abstractive%20Summarization-CNN/Daily%20Mail-blue) ![SQuAD](https://img.shields.io/badge/Question%20Generation-SQuAD-green) ![Personal-Chat](https://img.shields.io/badge/Dialogue%20Response-Personal%20Chat-yellowgreen) ![CoQA](https://img.shields.io/badge/Generative%20Question%20Answering-CoQA-orange) ---- -**[ERNIE-GEN](https://arxiv.org/abs/2001.11314.pdf) is a multi-flow language generation framework for both pre-training and fine-tuning.** We propose a novel **span-by-span generation** pre-training task to enable the model to **generate a semantically-complete span** at each step rather than a word, in light of the fact that entities, phrases in human writing are organized in a coherent manner. An **infilling generation mechanism** and a **noise-aware generation method** are incorporated into both pre-training and fine-tuning to alleviate **the problem of exposure bias**. In the pre-training phase, ERNIE-GEN adopts a **multi-granularity target fragments sampling** strategy to force decoder to rely more on the encoder representations other than the previous generated words to enhancing the correlation between encoder and decoder. - -## Proposed Generation Framework - -We construct three novel methods to enhance the language generation ability: - -- **Span-by-span Generation Pre-training Task**: to enable model to generate a semantically-complete span at each step rather than a word. -- **Infilling Genration and Noise-aware Generation**: to alleviate the problem of exposure bias. -- **Multi-Granularity Target Fragments**: to enhance the correlation between encoder and decoder during pre-training. - -Specifically, the span-by-span generation task and word-by-word generation task based on infilling generation mechanism are impemented by a carefully designed **Multi-Flow Attention** architecture as shown below. - -![multi-flow-attention](.meta/multi-flow-attention.png) - -## Pre-trained Models - -We release the checkpoints for **ERNIE-GEN _base_** model and **ERNIE-GEN _large_** model which are both pre-trained on English Wikipedia and [BookCorpus](https://arxiv.org/abs/1506.06724) (totally 16GB). Besides, **ERNIE-GEN _large_** pre-trained on the 160GB corpus (used by [RoBERTa](https://arxiv.org/abs/1907.11692) and [BART](https://arxiv.org/abs/1910.13461)) is available as well. - -- [**ERNIE-GEN _base_**](https://ernie.bj.bcebos.com/ernie_gen_base.tgz) (_lowercased | 12-layer, 768-hidden, 12-heads, 110M parameters_) -- [**ERNIE-GEN _large_**](https://ernie.bj.bcebos.com/ernie_gen_large.tgz) (_lowercased | 24-layer, 1024-hidden, 16-heads, 340M parameters_) -- [**ERNIE-GEN _large with 160G_**](https://ernie.bj.bcebos.com/ernie_gen_large_160g.tgz) (_lowercased | 24-layer, 1024-hidden, 16-heads, 340M parameters_) - - -## Fine-tuning on Downstream Tasks - -We compare the performance of [ERNIE-GEN](https://arxiv.org/pdf/2001.11314.pdf) with the existing SOTA pre-training models for natural language generation ([UniLM](https://arxiv.org/abs/1905.03197), [MASS](https://arxiv.org/abs/1905.02450), [PEGASUS](https://arxiv.org/abs/1912.08777), [BART](https://arxiv.org/abs/1910.13461) and [T5](https://arxiv.org/abs/1910.10683)) on 5 genration tasks, including abstractive summarization (**_Gigaword_** and **_CNN/DailyMail_**), question generation (**_SQuAD_**), dialogue generation (**_Persona-Chat_**) and generative question answering (**_CoQA_**). - -### Abstractive Summarization - -- _**Gigaword**_ - -The results on Gigaword-10k (10K examples of Gigaword) are presented as follows: - -| Model | Data / Params | Rouge-1 | Rouge-2 | Rouge-L | -| :-------------------------------------------------------- | :----------------------------: | :----------------------: | :----------------------: | :----------------------: | -| UniLM | 16G / 340M | 34.21 | 15.28 | 31.54 | -| **ENRIE-GEN** _base_ | 16G / 110M | 33.75 | 15.23 | 31.35 | -| **ERNIE-GEN** _large_ | 16G / 340M | 35.05 | 16.10 | 32.50 | -| **ERNIE-GEN** _large_ (160G) | 160G / 340M | **35.51** | **16.79** | **33.23** | - -The results on Gigaword are presented as follows: - -| Model | Data / Params | Rouge-1 | Rouge-2 | Rouge-L | -| :-------------------------------------------------------- | :----------------------------: | :----------------------: | :----------------------: | :----------------------: | -| MASS | 18G / 160M | 38.73 | 19.71 | 35.96 | -| BERTSHARE | 16G / 110M | 38.13 | 19.81 | 35.62 | -| UniLM | 16G / 340M | 38.45 | 19.45 | 35.75 | -| PEGASUS (_C4_) | 750G / 568M | 38.75 | 19.96 | 36.14 | -| PEGASUS (_HugeNews_) | 3.8T / 568M | 39.12 | 19.86 | 36.24 | -| **ENRIE-GEN** _base_ | 16G / 110M | 38.83 | 20.04 | 36.20 | -| **ERNIE-GEN** _large_ | 16G / 340M | 39.25 | 20.25 | 36.53 | -| **ERNIE-GEN** _large_ (160G) | 160G / 340M | **39.46** | **20.34** | **36.74** | - -We preprocess the raw Gigaword dataset following UniLM, the preprocessed data is avalilable at this [Gigaword](https://ernie.bj.bcebos.com/gigaword.tgz). - -- _**CNN/Daily Mail**_ - -The results on CNN/Daily Mail are presented as follows: - -| Model | Data / Params | Rouge-1 | Rouge-2 | Rouge-L | -| :-------------------------------------------------------- | :-----------: | :----------------------: | :----------------------: | :----------------------: | -| MASS | 18G / 160M | 42.12 | 19.50 | 39.01 | -| UniLM | 16G / 340M | 43.33 | 20.21 | 40.51 | -| T5 _large_ | 750G / 340M | 42.50 | 20.68 | 39.75 | -| T5 _xlarge_ | 750G / 11B | 43.52 | **21.55** | 40.69 | -| BART | 160G / 400M | 44.16 | 21.28 | 40.90 | -| PEGASUS (_C4_) | 750G / 568M | 43.90 | 21.20 | 40.76 | -| PEGASUS (_HugeNews_) | 3.8T / 568M | 44.17 | 21.47 | 41.11 | -| **ENRIE-GEN** _base_ | 16G / 110M | 42.30 | 19.92 | 39.68 | -| **ENRIE-GEN** _large_ | 16G / 340M | 44.02 | 21.17 | 41.26 | -| **ENRIE-GEN** _large_ (160G) | 160G / 340M | **44.31** | 21.35 | **41.60** | - -We preprocess the raw CNN/Daily Mail dataset following UniLM, the preprocessed data is avalilable at this [CNN/Daily Mail](https://ernie.bj.bcebos.com/cnndm.tgz). - -### Question Generation - -- _**SQuAD**_ - -The results on the [SQuAD 1.1](https://arxiv.org/abs/1806.03822) dataset following the data split in [[Du et al., 2017]](https://arxiv.org/pdf/1705.00106.pdf) are presented as follows: - -| Model | BLEU-4 | METEOR | Rouge-L | -| :----------------------------------------------------------- | :----------------------: | :----------------------: | :----------------------: | -| [SemQG](https://arxiv.org/abs/1909.06356) | 18.37 | 22.65 | 46.68 | -| UniLM _large_ (beam size=1) | 22.12 | 25.06 | 51.07 | -| **ENRIE-GEN** _base_ (beam size=1) | 22.28 | 25.13 | 50.38 | -| **ERNIE-GEN** _large_ (beam size=1) | 24.03 | 26.31 | 52.36 | -| **ERNIE-GEN** _large_ (beam size=5) | 25.40 | **26.92** | 52.84 | -| **ERNIE-GEN** _large_ (beam size=5) + (160G) | **25.41** | 26.77 | **52.91** | - -The results following the reversed dev-test data split in [[Zhao et al., 2018]](https://www.aclweb.org/anthology/D18-1424/) are presented as follows: - -| Model | BLEU-4 | METEOR | Rouge-L | -| :----------------------------------------------------------- | :----------------------: | :----------------------: | :----------------------: | -| SemQG | 20.76 | 24.20 | 48.91 | -| UniLM _large_ (beam size=1) | 23.75 | 25.61 | 52.04 | -| **ENRIE-GEN** _base_ (beam size=1) | 23.52 | 25.61 | 51.45 | -| **ERNIE-GEN** _large_ (beam size=1) | 25.57 | 26.89 | 53.31 | -| **ERNIE-GEN** _large_ (beam size=5) | 26.95 | **27.57** | 53.77 | -| **ERNIE-GEN** _large_ (beam size=5) + (160G) | **27.05** | 27.43 | **53.83** | - -*_Note that we also report the results with higher beam size to 5._ - -The preprocessed data for question generation task can be downloaded from [SQuAD](https://ernie.bj.bcebos.com/squad_qg.tgz). - -### Generative Dialogue Response - -- _**Personal-Chat**_ - - Comparison with current state-of-the-art results on the multi-turn conversations task ([Persona-Chat](https://arxiv.org/abs/1801.07243)) is presented as follows: - -| Model | BLEU-1 | BLEU-2 | Distinct-1 | Distinct-2 | -| :-------------------------------------------------------- | :---------------------: | :---------------------: | :-------------------------: | :---------------------------: | -| [LIC](https://arxiv.org/abs/1910.07931) | 40.5 | 32.0 | 0.019 | 0.113 | -| [PLATO](https://arxiv.org/abs/1910.07931) | 45.8 | 35.7 | 0.012 | 0.064 | -| PLATO _w/o latent_ | 40.6 | 31.5 | 0.021 | 0.121 | -| **ERNIE-GEN** _large_ | **46.8** | **36.4** | **0.023** | **0.168** | - -The training data can be downloaded from [Personal-Chat](https://ernie.bj.bcebos.com/persona_chat.tgz). - -### Generative Question Answering - -- _**CoQA**_ - -Results of development set on CoQA task is presented as follows: - -| Model | F1-score | -| :-------------------------------------------------------- | :------: | -| [Seq2Seq](https://arxiv.org/abs/1910.07931) | 27.5 | -| [PGNet](https://arxiv.org/abs/1910.07931) | 45.4 | -| UniLM _large_ | 82.5 | -| **ERNIE-GEN** _large_ | **84.5** | - -We preprocess the raw [CoQA](https://arxiv.org/abs/1808.07042) dataset, the preprocessed data is avalilable at this [CoQA-preprocessed](https://ernie.bj.bcebos.com/coqa.tgz). - -Finally, we also compared with a concurrent work [ProphetNet](https://arxiv.org/abs/2001.04063), the fine-tuning results on Gigaword, CNN/Daily Mail and SQuAD are reported as follows: - -- _**Abstractive Summarization**_ - -| Model / Task | Data / Params | Gigaword |CNN/Daily Mail| -| :-------------------------------------------------------- | :----------------------------: | :----------------------: | :----------------------: | -| Metric | - | Rouge-1 / Rouge-2 / Rouge-L |Rouge-1 / Rouge-2 / Rouge-L| -| **ProphetNet** _large_ (160G) | 160G / 340M | **39.51** / **20.42** / 36.69 |44.20 / 21.17 / 41.30| -| **ERNIE-GEN** _large_ (160G) | 160G / 340M | 39.46 / 20.34 / **36.74** |**44.31** / **21.35** / **41.60**| - -- _**Question Generation**_ - -| Model | Data / Params | BLEU-4 / METEOR / Rouge-L |BLEU-4 / METEOR / Rouge-L| -| :-------------------------------------------------------- | :----------------------------: | :----------------------: |:----------------------: | -| Data split | - | Original |Reversed dev-test| -| **ProphetNet** _large_ (16G) | 16G / 340M | 25.01 / 26.83 / 52.57 |26.72 / **27.64** / **53.79** | -| **ERNIE-GEN** _large_ (16G) | 16G / 340M | **25.40** / **26.92** / **52.84** |**26.95** / 27.57 / **53.77**| - -## Usage - -### Install PaddlePaddle - -This code base has been tested with Paddle Fluid 1.7 with Python 2.7. Other dependency of ERNIE-GEN is listed in `requirements.txt`, you can install it by -```script -pip install -r requirements.txt -``` - -### Fine-tuning -Please update LD_LIBRARY_PATH about CUDA, cuDNN, NCCL2 before running ERNIE-GEN. We have put the parameter configurations of the above downstream tasks in `config/`. You can easily run finetuning through these configuration files. For example, you can finetune ERNIE-GEN base model on Gigaword by -```script -MODEL="base" # base or large or large_160g -TASK="gigaword" # cnndm, coqa, gigaword, squad_qg or persona-chat -sh run_seq2seq.sh ./configs/${MODEL}/${TASK}_conf -``` -The log of training and the evaluation results are in `log/job.log.0`. To finetune on your own task data, you can refer to the data format we provide for processing your data. - -Our fine-tuning experiments are carried on 8 NVIDIA V100 (32GB) GPUs. If your GPU memory is not enough, you can reduce the batch size in the corresponding configuration file. - -**NOTICE: ** The actual total batch size is equal to `configured batch size * number of used gpus`. - -### Employ Dynamic Computation Graph - -The ERNIE-GEN code using dynamic graph is more concise and flexible, please refer to [ERNIE-GEN Dygraph](https://github.com/PaddlePaddle/ERNIE/tree/develop/experimental/seq2seq) for specific use. - -### The ERNIE 1.0 is avaliable for Chinese Generation Tasks - -The ERNIE-GEN code is compatible with [ERNIE 1.0](https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz) model. After specifying the parameters related to the model and data in the configuration file, you can use ERNIE 1.0 to fine-tune chinese generation tasks. - -## Citation - -You can cite the paper as below: - -``` -@article{xiao2020ernie-gen, - title={ERNIE-GEN: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation}, - author={Xiao, Dongling and Zhang, Han and Li, Yukun and Sun, Yu and Tian, Hao and Wu, Hua and Wang, Haifeng}, - journal={arXiv preprint arXiv:2001.11314}, - year={2020} -} -``` - - - - diff --git a/ernie-gen/README.zh.md b/ernie-gen/README.zh.md deleted file mode 100644 index f748ed3303366102a3cbf5cef21790bb419d651a..0000000000000000000000000000000000000000 --- a/ernie-gen/README.zh.md +++ /dev/null @@ -1,225 +0,0 @@ -[English](./README.md) | 简体中文 - -## _ERNIE-GEN_: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation - -- [模型框架](#模型框架) -- [预训练模型](#预训练模型) -- [微调任务](#微调任务) - * [生成式摘要](#生成式摘要) - * [问题生成](#问题生成) - * [多轮对话](#多轮对话) - * [生成式多轮问答](#生成式多轮问答) -- [使用说明](#使用说明) - * [安装飞桨](#安装飞桨) - * [运行微调](#运行微调) - * [使用动态图](#使用动态图) - * [中文生成任务使用 ERNIE 1.0](#中文生成任务使用-ernie-10) -- [引用](#引用) - -关于算法的详细描述,请参见我们的论文: ->[_**ERNIE-GEN:An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation**_](https://arxiv.org/abs/2001.11314) -> ->Dongling Xiao\*, Han Zhang\*, Yukun Li, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang (\* : equal contribution) -> ->Preprint January 2020 -> ->Accepted by **IJCAI-2020** - -![ERNIE-GEN](https://img.shields.io/badge/预训练-语言生成-green) ![Gigaword](https://img.shields.io/badge/生成式摘要-Gigaword-yellow) ![Gigaword](https://img.shields.io/badge/生成式摘要-CNN/Daily%20Mail-blue) ![SQuAD](https://img.shields.io/badge/问题生成-SQuAD-green) ![Personal-Chat](https://img.shields.io/badge/多轮对话-Personal%20Chat-yellowgreen) ![CoQA](https://img.shields.io/badge/多轮问答-CoQA-orange) ---- -**ERNIE-GEN 是面向生成任务的预训练-微调框架**,首次在预训练阶段加入**span-by-span 生成**任务,让模型每次能够生成一个语义完整的片段。在预训练和微调中通过**填充式生成机制**和**噪声感知机制**来缓解曝光偏差问题。此外, ERNIE-GEN 采样**多片段-多粒度目标文本采样**策略, 增强源文本和目标文本的关联性,加强了编码器和解码器的交互。 - -## 模型框架 - -我们提出了三种方法来提高语言生成能力: - -- **Span-by-span 生成任务**: 让模型能够每次生成一个语义完整的片段。 -- **填充式生成**和**噪声感知生成**: 缓解曝光偏差问题。 -- **多片段-多粒度目标文本采样**: 预训练阶段增强编码器和解码器的交互。 - -我们基于 Transformer 模型设计了 **Mulit-Flow Attention** 框架,用于实现 span-by-span 的填充式生成。 - -![multi-flow-attention](.meta/multi-flow-attention.png) - -## 预训练模型 - -我们发布了 **ERNIE-GEN _base_** 模型和 **ERNIE-GEN _large_** 模型。 预训练数据使用英文维基百科和 BookCorpus,总共16GB。此外,我们还发布了基于 160GB 语料预训练的**ERNIE-GEN _large_** 模型,此份语料也被用于 [RoBERTa](https://arxiv.org/abs/1907.11692) 和 [BART](https://arxiv.org/abs/1910.13461) 的预训练。 - -- [**ERNIE-GEN _base_**](https://ernie.bj.bcebos.com/ernie_gen_base.tgz) (_lowercased | 12-layer, 768-hidden, 12-heads, 110M parameters_) -- [**ERNIE-GEN _large_**](https://ernie.bj.bcebos.com/ernie_gen_large.tgz) (_lowercased | 24-layer, 1024-hidden, 16-heads, 340M parameters_) -- [**ERNIE-GEN _large with 160G_**](https://ernie.bj.bcebos.com/ernie_gen_large_160g.tgz) (_lowercased | 24-layer, 1024-hidden, 16-heads, 340M parameters_) - - -## 微调任务 - -我们在五个典型生成任务上与当前效果最优的生成预训练模型([UniLM](https://arxiv.org/abs/1905.03197)、[MASS](https://arxiv.org/abs/1905.02450)、[PEGASUS](https://arxiv.org/abs/1912.08777)、[BART](https://arxiv.org/abs/1910.13461)、[T5](https://arxiv.org/abs/1910.10683)等)进行对比, 包括生成式摘要 (Gigaword 和 CNN/DailyMail), 问题生成(SQuAD), 多轮对话(Persona-Chat) 和生成式多轮问答(CoQA)。 - -### 生成式摘要 - -- _**Gigaword**_ - -在 Gigaword-10k (Gigaword 的子集) 上的效果: - -| 模型 | 数据量 / 参数量 | Rouge-1 | Rouge-2 | Rouge-L | -| :-------------------------------------------------------- | :------------------------------: | :----------------------: | :----------------------: | :----------------------: | -| UniLM | 16G / 340M | 34.21 | 15.28 | 31.54 | -| **ENRIE-GEN** _base_ | 16G / 110M | 33.75 | 15.23 | 31.35 | -| **ERNIE-GEN** _large_ | 16G / 340M | 35.05 | 16.10 | 32.50 | -| **ERNIE-GEN** _large_ (160G) | 160G / 340M | **35.51** | **16.79** | **33.23** | - -在 Gigaword 上的效果: - -| 模型 | 数量 / 参数量 | Rouge-1 | Rouge-2 | Rouge-L | -| :-------------------------------------------------------- | :----------------------------: | :----------------------: | :----------------------: | :----------------------: | -| MASS | 18G / 160M | 38.73 | 19.71 | 35.96 | -| [BERTSHARE](https://arxiv.org/abs/1907.12461) | 16G / 110M | 38.13 | 19.81 | 35.62 | -| UniLM | 16G / 340M | 38.45 | 19.45 | 35.75 | -| PEGASUS (_C4_) | 750G / 568M | 38.75 | 19.96 | 36.14 | -| PEGASUS (_HugeNews_) | 3.8T / 568M | 39.12 | 19.86 | 36.24 | -| **ENRIE-GEN** _base_ | 16G / 110M | 38.83 | 20.04 | 36.20 | -| **ERNIE-GEN** _large_ | 16G / 340M | 39.25 | 20.25 | 36.53 | -| **ERNIE-GEN** _large_ (160G) | 160G / 340M | **39.46** | **20.34** | **36.74** | - -我们按照 UniLM 的方式处理了数据,下载链接 [Gigaword](https://ernie.bj.bcebos.com/gigaword.tgz)。 - -- _**CNN/Daily Mail**_ - -在 CNN/Daily Mail 上的效果: - -| 模型 | 数据量 /参数量| Rouge-1 | Rouge-2 | Rouge-L | -| :-------------------------------------------------------- | :-----------: | :----------------------: | :----------------------: | :----------------------: | -| MASS | 18G / 160M | 42.12 | 19.50 | 39.01 | -| UniLM | 16G / 340M | 43.33 | 20.21 | 40.51 | -| T5 _large_ | 750G / 340M | 42.50 | 20.68 | 39.75 | -| T5 _xlarge_ | 750G / 11B | 43.52 | **21.55** | 40.69 | -| BART | 160G / 400M | 44.16 | 21.28 | 40.90 | -| PEGASUS (_C4_) | 750G / 568M | 43.90 | 21.20 | 40.76 | -| PEGASUS (_HugeNews_) | 3.8T / 568M | 44.17 | 21.47 | 41.11 | -| **ENRIE-GEN** _base_ | 16G / 110M | 42.30 | 19.92 | 39.68 | -| **ENRIE-GEN** _large_ | 16G / 340M | 44.02 | 21.17 | 41.26 | -| **ENRIE-GEN** _large_ (160G) | 160G / 340M | **44.31** | 21.35 | **41.60** | - -我们按照 UniLM 的方式处理了数据,下载链接 [CNN/Daily Mail](https://ernie.bj.bcebos.com/cnndm.tgz)。 - -### 问题生成 - -- _**SQuAD**_ - -在 SQuAD 1.1 数据集上的效果(测试集划分按照 [[Du et al., 2017]](https://arxiv.org/abs/1705.00106)) : - -| 模型 | BLEU-4 | METEOR | Rouge-L | -| :----------------------------------------------------------- | :----------------------: | :----------------------: | :----------------------: | -| [SemQG](https://arxiv.org/abs/1909.06356) | 18.37 | 22.65 | 46.68 | -| UniLM _large_ (beam size=1) | 22.12 | 25.06 | 51.07 | -| **ENRIE-GEN** _base_ (beam size=1) | 22.28 | 25.13 | 50.38 | -| **ERNIE-GEN** _large_ (beam size=1) | 24.03 | 26.31 | 52.36 | -| **ERNIE-GEN** _large_ (beam size=5) | 25.40 | **26.92** | 52.84 | -| **ERNIE-GEN** _large_ (beam size=5) + (160G) | **25.41** | 26.77 | **52.91** | - -按照 [[Zhao et al., 2018]](https://www.aclweb.org/anthology/D18-1424/) 反向使用验证集和测试集,效果如下: - -| Model | BLEU-4 | METEOR | Rouge-L | -| :----------------------------------------------------------- | :----------------------: | :----------------------: | :----------------------: | -| [SemQG](https://arxiv.org/abs/1909.06356) | 20.76 | 24.20 | 48.91 | -| UniLM _large_ (beam size=1) | 23.75 | 25.61 | 52.04 | -| **ENRIE-GEN** _base_ (beam size=1) | 23.52 | 25.61 | 51.45 | -| **ERNIE-GEN** _large_ (beam size=1) | 25.57 | 26.89 | 53.31 | -| **ERNIE-GEN** _large_ (beam size=5) | 26.95 | **27.57** | 53.77 | -| **ERNIE-GEN** _large_ (beam size=5) + (160G) | **27.05** | 27.43 | **53.83** | - -*_我们增加了将 beam size 扩大到 5 的结果。_ - -我们按照 UniLM 的方式处理了数据,下载链接 [SQuAD](https://ernie.bj.bcebos.com/squad_qg.tgz)。 - -### 多轮对话 - -- _**Personal-Chat**_ - -| Model | BLEU-1 | BLEU-2 | Distinct-1 | Distinct-2 | -| :-------------------------------------------------------- | :---------------------: | :---------------------: | :-------------------------: | :---------------------------: | -| [LIC](https://arxiv.org/abs/1910.07931) | 40.5 | 32.0 | 0.019 | 0.113 | -| [PLATO](https://arxiv.org/abs/1910.07931) | 45.8 | 35.7 | 0.012 | 0.064 | -| [PLATO](https://arxiv.org/abs/1910.07931) _w/o latent_ | 40.6 | 31.5 | 0.021 | 0.121 | -| **ERNIE-GEN** _large_ | **46.8** | **36.4** | **0.023** | **0.168** | - -我们处理的数据下载链接 [Personal-Chat](https://ernie.bj.bcebos.com/persona_chat.tgz)。 - -### 生成式多轮问答 - -- _**CoQA**_ - -在 CoQA 验证集上的效果: - -| 模型 | F1-score | -| :-------------------------------------------------------- | :------: | -| [Seq2Seq](https://arxiv.org/abs/1910.07931) | 27.5 | -| [PGNet](https://arxiv.org/abs/1910.07931) | 45.4 | -| UniLM _large_ | 82.5 | -| **ERNIE-GEN** _large_ | **84.5** | - -我们对原始的 CoQA 数据集进行了处理,下载链接 [CoQA](https://ernie.bj.bcebos.com/coqa.tgz)。 - -此外,我们与同期的工作 [ProphetNet](https://arxiv.org/abs/2001.04063) 在 Gigaword,CNN/Daily Mail 和 SQuAD 三个数据集上进行了对比: - -- _**生成式摘要**_ - -| 模型 / 任务 | 数据量 / 参数量 | Gigaword |CNN/Daily Mail| -| :-------------------------------------------------------- | :------------------------------: | :----------------------: | :----------------------: | -| Metric | - | Rouge-1 / Rouge-2 / Rouge-L |Rouge-1 / Rouge-2 / Rouge-L| -| ProphetNet _large_ (160G) | 160G / 340M | **39.51** / **20.42** / 36.69 |44.20 / 21.17 / 41.30| -| **ERNIE-GEN** _large_ (160G) | 160G / 340M | 39.46 / 20.34 / **36.74** |**44.31** / **21.35** / **41.60**| - -- _**问题生成**_ - -| 模型 | 数据量 / 参数量 | BLEU-4 / METEOR / Rouge-L |BLEU-4 / METEOR / Rouge-L| -| :-------------------------------------------------------- | :------------------------------: | :----------------------: |:----------------------: | -| Data split | - | Original |Reversed dev-test| -| ProphetNet** _large_ (16G) | 16G / 340M | 25.01 / 26.83 / 52.57 |26.72 / **27.64** / **53.79** | -| **ERNIE-GEN** _large_ (16G) | 16G / 340M | **25.40** / **26.92** / **52.84** |**26.95** / 27.57 / **53.77**| - -## 使用说明 - -### 安装飞桨 - -我们的代码基于 Paddle Fluid 1.7 和 Python 2.7。 ERNIE-GEN 依赖的其他模块也列举在 `requirements.txt`,可以通过下面的指令安装: -```script -pip install -r requirements.txt -``` - -### 运行微调 -在运行 ERNIE-GEN 前,需要将 CUDA 、cuDNN 、NCCL2 的动态库路径添加到 LD_LIBRARY_PATH 。 我们把下游任务的参数配置文件放到了 `config/` ,可以简单地通过配置文件运行。 例如,您可以通过下面的指令在 Gigaword 数据集上微调 ERNIE-GEN base 模型: -```script -MODEL="base" # base or large or large_160g -TASK="gigaword" # cnndm, coqa, gigaword, squad_qg or persona-chat -sh run_seq2seq.sh ./configs/${MODEL}/${TASK}_conf -``` -训练和评估的日志在 `log/job.log.0`。 如果要在您自己的数据集上微调,可以参考我们提供的数据格式处理自己的数据。 - -我们的微调实验在 8 张 32GB 显存的英伟达 V100 GPU 上运行,如果您的 GPU 显存不够,可以减小配置文件中的 batch_size 。 - -**注意**: 训练时实际的 batch size 等于 `配置的 batch size * GPU 卡数`。 - -### 使用动态图 - -动态图版本的 ERNIE-GEN 代码更加简洁灵活,使用请参考 [ERNIE-GEN Dygraph](https://github.com/PaddlePaddle/ERNIE/tree/develop/experimental/seq2seq)。 - -### 中文生成任务使用 ERNIE 1.0 - -ERNIE-GEN 的代码兼容 [ERNIE 1.0 模型](https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz),修改配置文件中模型和数据相关的设置,就可以用 ERNIE 1.0 在中文生成任务上微调。 - -## 引用 - -可以按下面的格式引用我们的论文: - -``` -@article{xiao2020ernie-gen, - title={ERNIE-GEN: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation}, - author={Xiao, Dongling and Zhang, Han and Li, Yukun and Sun, Yu and Tian, Hao and Wu, Hua and Wang, Haifeng}, - journal={arXiv preprint arXiv:2001.11314}, - year={2020} -} -``` - - - - diff --git a/ernie-gen/configs/base/cnndm_conf b/ernie-gen/configs/base/cnndm_conf deleted file mode 100644 index 1e5708756e451a89ae60d02f5356d3247ce94ae0..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/base/cnndm_conf +++ /dev/null @@ -1,48 +0,0 @@ -#load model -vocab_path="ernie_gen_base/vocab.txt" -config_path="ernie_gen_base/ernie_config.json" -init_model="ernie_gen_base/params" - -#input -max_src_len=640 -max_tgt_len=192 -tokenized_input="true" -continuous_position="true" -batch_size=8 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=128 -beam_size=5 -length_penalty=1.0 -use_multi_gpu_test="true" - -#train -epoch=30 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.1 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=5e-5 -#noise -random_noise="true" -noise_prob=0.7 - -#dataset -data_path="./datasets/cnndm/" -train_set="train.tsv" -dev_set="dev.2k.tsv" -pred_set="test.tsv" -do_train="true" -do_val="true" -do_test="false" -do_pred="true" - -#evaluate -eval_script="sh ./eval/tasks/cnndm/eval.sh" -eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/base/gigaword-10k_conf b/ernie-gen/configs/base/gigaword-10k_conf deleted file mode 100644 index 58b9e9735b278344c1457ce8edf585ecebdd90c7..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/base/gigaword-10k_conf +++ /dev/null @@ -1,48 +0,0 @@ -#load model -vocab_path="ernie_gen_base/vocab.txt" -config_path="ernie_gen_base/ernie_config.json" -init_model="ernie_gen_base/params" - -#input -max_src_len=192 -max_tgt_len=64 -tokenized_input="true" -continuous_position="true" -batch_size=16 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=32 -beam_size=5 -length_penalty=0.6 -use_multi_gpu_test="true" - -#train -epoch=30 -weight_decay=0.01 -label_smooth=0.1 -save_and_valid_by_epoch="true" -hidden_dropout_prob=0.1 -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=1.25e-5 -#noise -random_noise="true" -noise_prob=0.5 - -#dataset -data_path="./datasets/gigaword/" -train_set="train.10k.tsv" -dev_set="dev.20k.tsv" -test_set="test.tsv" -do_train="true" -do_val="true" -do_test="true" -do_pred="false" - -#evaluate -eval_script="sh ./eval/tasks/gigaword/eval.sh" -eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/base/gigaword_conf b/ernie-gen/configs/base/gigaword_conf deleted file mode 100644 index acbb5e698b1a7e7ddf7cc3eb6c7104134d6c4ba6..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/base/gigaword_conf +++ /dev/null @@ -1,48 +0,0 @@ -#load model -vocab_path="ernie_gen_base/vocab.txt" -config_path="ernie_gen_base/ernie_config.json" -init_model="ernie_gen_base/params" - -#input -max_src_len=192 -max_tgt_len=64 -tokenized_input="true" -continuous_position="true" -batch_size=16 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=32 -beam_size=5 -length_penalty=0.6 -use_multi_gpu_test="true" - -#train -epoch=10 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.1 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=3e-5 -#noise -random_noise="true" -noise_prob=0.5 - -#dataset -data_path="./datasets/gigaword/" -train_set="train.tsv" -dev_set="dev.20k.tsv" -test_set="test.tsv" -do_train="true" -do_val="true" -do_test="true" -do_pred="false" - -#evaluate -eval_script="sh ./eval/tasks/gigaword/eval.sh" -eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/base/squad-qg_conf b/ernie-gen/configs/base/squad-qg_conf deleted file mode 100644 index 0dbe30024538d1590d74f9321031deabb215dea5..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/base/squad-qg_conf +++ /dev/null @@ -1,47 +0,0 @@ -#load model -vocab_path="ernie_gen_base/vocab.txt" -config_path="ernie_gen_base/ernie_config.json" -init_model="ernie_gen_base/params" - -#input -max_src_len=512 -max_tgt_len=96 -tokenized_input="true" -continuous_position="true" -batch_size=4 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=48 -beam_size=5 -length_penalty=1.0 -use_multi_gpu_test="true" - -#train -epoch=10 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.1 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=2.5e-5 -#noise -random_noise="true" -noise_prob=0.7 - -#dataset -data_path="./datasets/squad_qg/" -train_set="train.tsv" -dev_set="dev.tsv" -test_set="test.tsv" -do_train="true" -do_val="true" -do_test="true" - -#evaluate -eval_script="sh ./eval/tasks/squad_qg/eval.sh" -eval_mertrics="Bleu_4,METEOR,ROUGE_L" diff --git a/ernie-gen/configs/large/cnndm_conf b/ernie-gen/configs/large/cnndm_conf deleted file mode 100644 index a6b136a7b7a5a8bf53cdb690cbe7cea8440ec8a5..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large/cnndm_conf +++ /dev/null @@ -1,48 +0,0 @@ -#load model -vocab_path="ernie_gen_large/vocab.txt" -config_path="ernie_gen_large/ernie_config.json" -init_model="ernie_gen_large/params" - -#input -max_src_len=640 -max_tgt_len=192 -tokenized_input="true" -continuous_position="true" -batch_size=4 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=128 -beam_size=5 -length_penalty=1.0 -use_multi_gpu_test="true" - -#train -epoch=20 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.1 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=4e-5 -#noise -random_noise="true" -noise_prob=0.7 - -#dataset -data_path="./datasets/cnndm/" -train_set="train.tsv" -dev_set="dev.2k.tsv" -pred_set="test.tsv" -do_train="true" -do_val="true" -do_test="false" -do_pred="true" - -#evaluate -eval_script="sh ./eval/tasks/cnndm/eval.sh" -eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large/coqa_conf b/ernie-gen/configs/large/coqa_conf deleted file mode 100644 index fb8178733e62c13acf355002d021df43727a3e95..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large/coqa_conf +++ /dev/null @@ -1,52 +0,0 @@ -#load model -vocab_path="ernie_gen_large/vocab.txt" -config_path="ernie_gen_large/ernie_config.json" -init_model="ernie_gen_large/params" - -#for multi-turn dialog/qa -task_type="dialog" -role_type_size=3 -turn_type_size=16 - -#input -max_src_len=480 -max_tgt_len=32 -tokenized_input="true" -continuous_position="true" -batch_size=4 -in_tokens="false" -#tgt_type_id=1 - -#decode -do_decode="true" -max_dec_len=30 -beam_size=3 -length_penalty=0.0 -use_multi_gpu_test="true" - -#train -epoch=10 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.1 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=1e-5 -#noise -random_noise="false" -noise_prob=0.5 - -#dataset -data_path="./datasets/coqa/" -train_set="train.tsv" -dev_set="dev.tsv" -do_train="true" -do_val="true" -do_test="false" -do_pred="false" - -#evaluate -eval_script="sh ./eval/tasks/coqa/eval.sh" -eval_mertrics="f1" diff --git a/ernie-gen/configs/large/gigaword-10k_conf b/ernie-gen/configs/large/gigaword-10k_conf deleted file mode 100644 index 8ba11c13c423a8e2c6692ff903171ba4adbb2378..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large/gigaword-10k_conf +++ /dev/null @@ -1,48 +0,0 @@ -#load model -vocab_path="ernie_gen_large/vocab.txt" -config_path="ernie_gen_large/ernie_config.json" -init_model="ernie_gen_large/params" - -#input -max_src_len=192 -max_tgt_len=64 -tokenized_input="true" -continuous_position="true" -batch_size=16 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=32 -beam_size=5 -length_penalty=0.6 -use_multi_gpu_test="true" - -#train -epoch=30 -weight_decay=0.01 -label_smooth=0.1 -save_and_valid_by_epoch="true" -hidden_dropout_prob=0.1 -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=1e-5 -#noise -random_noise="true" -noise_prob=0.7 - -#dataset -data_path="./datasets/gigaword/" -train_set="train.10k.tsv" -dev_set="dev.20k.tsv" -test_set="test.tsv" -do_train="true" -do_val="true" -do_test="true" -do_pred="false" - -#evaluate -eval_script="sh ./eval/tasks/gigaword/eval.sh" -eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large/gigaword_conf b/ernie-gen/configs/large/gigaword_conf deleted file mode 100644 index 69fd8e9482c2df434584d331ab91f16d404a107c..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large/gigaword_conf +++ /dev/null @@ -1,48 +0,0 @@ -#load model -vocab_path="ernie_gen_large/vocab.txt" -config_path="ernie_gen_large/ernie_config.json" -init_model="ernie_gen_large/params" - -#input -max_src_len=192 -max_tgt_len=64 -tokenized_input="true" -continuous_position="true" -batch_size=16 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=32 -beam_size=5 -length_penalty=0.6 -use_multi_gpu_test="true" - -#train -epoch=5 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.2 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=3e-5 -#noise -random_noise="true" -noise_prob=0.6 - -#dataset -data_path="./datasets/gigaword/" -train_set="train.tsv" -dev_set="dev.20k.tsv" -test_set="test.tsv" -do_train="true" -do_val="true" -do_test="true" -do_pred="false" - -#evaluate -eval_script="sh ./eval/tasks/gigaword/eval.sh" -eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large/persona-chat_conf b/ernie-gen/configs/large/persona-chat_conf deleted file mode 100644 index 66e2379ff34de4d4781b814809f7f70b6fbbbe34..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large/persona-chat_conf +++ /dev/null @@ -1,53 +0,0 @@ -#load model -vocab_path="ernie_gen_large/vocab.txt" -config_path="ernie_gen_large/ernie_config.json" -init_model="ernie_gen_large/params" - -#for multi-turn dialog/qa -task_type="dialog" -role_type_size=3 -turn_type_size=16 - -#input -max_src_len=472 -max_tgt_len=40 -tokenized_input="true" -continuous_position="true" -batch_size=8 -in_tokens="false" - -#decode -do_decode="true" -max_dec_len=32 -beam_size=10 -length_penalty=1.3 -use_multi_gpu_test="true" - -#train -epoch=30 -weight_decay=0.01 -label_smooth=0.0 -hidden_dropout_prob=0.1 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=1e-4 -#noise -random_noise="false" -noise_prob=0.0 - -#dataset -data_path="./datasets/persona_chat/" -train_set="train.tsv" -dev_set="dev.2k.tsv" -pred_set="test.tsv" -do_train="true" -do_val="true" -do_test="false" -do_pred="true" -do_decode="true" - -#evaluate -eval_script="sh ./eval/tasks/persona_chat/eval.sh" -eval_mertrics="bleu_1,bleu_2,distinct_1,distinct_2" diff --git a/ernie-gen/configs/large/squad-qg_conf b/ernie-gen/configs/large/squad-qg_conf deleted file mode 100644 index c4921dbb651ad33b17b2ff35ff242045bc669ccc..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large/squad-qg_conf +++ /dev/null @@ -1,47 +0,0 @@ -#load model -vocab_path="ernie_gen_large/vocab.txt" -config_path="ernie_gen_large/ernie_config.json" -init_model="ernie_gen_large/params" - -#input -max_src_len=512 -max_tgt_len=96 -tokenized_input="true" -continuous_position="true" -batch_size=4 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=48 -beam_size=5 -length_penalty=1.0 -use_multi_gpu_test="true" - -#train -epoch=10 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.2 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=1e-5 -#noise -random_noise="true" -noise_prob=0.7 - -#dataset -data_path="./datasets/squad_qg/" -train_set="train.tsv" -dev_set="dev.tsv" -test_set="test.tsv" -do_train="true" -do_val="true" -do_test="true" - -#evaluate -eval_script="sh ./eval/tasks/squad_qg/eval.sh" -eval_mertrics="Bleu_4,METEOR,ROUGE_L" diff --git a/ernie-gen/configs/large_160g/cnndm_conf b/ernie-gen/configs/large_160g/cnndm_conf deleted file mode 100644 index 4b0bd7dbe0d3fa97a7df2fcad879266af3f532e0..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large_160g/cnndm_conf +++ /dev/null @@ -1,48 +0,0 @@ -#load model -vocab_path="ernie_gen_large_160g/vocab.txt" -config_path="ernie_gen_large_160g/ernie_config.json" -init_model="ernie_gen_large_160g/params" - -#input -max_src_len=640 -max_tgt_len=192 -tokenized_input="true" -continuous_position="true" -batch_size=4 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=128 -beam_size=5 -length_penalty=1.2 -use_multi_gpu_test="true" - -#train -epoch=17 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.1 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.02 -lr_scheduler="linear_warmup_decay" -learning_rate=4e-5 -#noise -random_noise="true" -noise_prob=0.7 - -#dataset -data_path="./datasets/cnndm/" -train_set="train.tsv" -dev_set="dev.2k.tsv" -pred_set="test.tsv" -do_train="true" -do_val="true" -do_test="false" -do_pred="true" - -#evaluate -eval_script="sh ./eval/tasks/cnndm/eval.sh" -eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large_160g/coqa_conf b/ernie-gen/configs/large_160g/coqa_conf deleted file mode 100644 index fb8178733e62c13acf355002d021df43727a3e95..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large_160g/coqa_conf +++ /dev/null @@ -1,52 +0,0 @@ -#load model -vocab_path="ernie_gen_large/vocab.txt" -config_path="ernie_gen_large/ernie_config.json" -init_model="ernie_gen_large/params" - -#for multi-turn dialog/qa -task_type="dialog" -role_type_size=3 -turn_type_size=16 - -#input -max_src_len=480 -max_tgt_len=32 -tokenized_input="true" -continuous_position="true" -batch_size=4 -in_tokens="false" -#tgt_type_id=1 - -#decode -do_decode="true" -max_dec_len=30 -beam_size=3 -length_penalty=0.0 -use_multi_gpu_test="true" - -#train -epoch=10 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.1 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=1e-5 -#noise -random_noise="false" -noise_prob=0.5 - -#dataset -data_path="./datasets/coqa/" -train_set="train.tsv" -dev_set="dev.tsv" -do_train="true" -do_val="true" -do_test="false" -do_pred="false" - -#evaluate -eval_script="sh ./eval/tasks/coqa/eval.sh" -eval_mertrics="f1" diff --git a/ernie-gen/configs/large_160g/gigaword-10k_conf b/ernie-gen/configs/large_160g/gigaword-10k_conf deleted file mode 100644 index 89a4f90350817be87b33c02b6fef55f93b58d3a9..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large_160g/gigaword-10k_conf +++ /dev/null @@ -1,48 +0,0 @@ -#load model -vocab_path="ernie_gen_large_160g/vocab.txt" -config_path="ernie_gen_large_160g/ernie_config.json" -init_model="ernie_gen_large_160g/params" - -#input -max_src_len=192 -max_tgt_len=64 -tokenized_input="true" -continuous_position="true" -batch_size=16 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=32 -beam_size=5 -length_penalty=0.6 -use_multi_gpu_test="true" - -#train -epoch=30 -weight_decay=0.01 -label_smooth=0.1 -save_and_valid_by_epoch="true" -hidden_dropout_prob=0.1 -#lr -warmup_proportion=0.15 -lr_scheduler="linear_warmup_decay" -learning_rate=7.5e-6 -#noise -random_noise="true" -noise_prob=0.65 - -#dataset -data_path="./datasets/gigaword/" -train_set="train.10k.tsv" -dev_set="dev.20k.tsv" -test_set="test.tsv" -do_train="true" -do_val="true" -do_test="true" -do_pred="false" - -#evaluate -eval_script="sh ./eval/tasks/gigaword/eval.sh" -eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large_160g/gigaword_conf b/ernie-gen/configs/large_160g/gigaword_conf deleted file mode 100644 index 4d31e9a19b6b73c4eaeba2e448de9d9f96c374ea..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large_160g/gigaword_conf +++ /dev/null @@ -1,48 +0,0 @@ -#load model -vocab_path="ernie_gen_large_160g/vocab.txt" -config_path="ernie_gen_large_160g/ernie_config.json" -init_model="ernie_gen_large_160g/params" - -#input -max_src_len=192 -max_tgt_len=64 -tokenized_input="true" -continuous_position="true" -batch_size=16 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=32 -beam_size=6 -length_penalty=0.7 -use_multi_gpu_test="true" - -#train -epoch=5 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.2 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=3e-5 -#noise -random_noise="true" -noise_prob=0.6 - -#dataset -data_path="./datasets/gigaword/" -train_set="train.tsv" -dev_set="dev.20k.tsv" -test_set="test.tsv" -do_train="true" -do_val="true" -do_test="true" -do_pred="false" - -#evaluate -eval_script="sh ./eval/tasks/gigaword/eval.sh" -eval_mertrics="rouge-1,rouge-2,rouge-l" diff --git a/ernie-gen/configs/large_160g/persona-chat_conf b/ernie-gen/configs/large_160g/persona-chat_conf deleted file mode 100644 index 66e2379ff34de4d4781b814809f7f70b6fbbbe34..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large_160g/persona-chat_conf +++ /dev/null @@ -1,53 +0,0 @@ -#load model -vocab_path="ernie_gen_large/vocab.txt" -config_path="ernie_gen_large/ernie_config.json" -init_model="ernie_gen_large/params" - -#for multi-turn dialog/qa -task_type="dialog" -role_type_size=3 -turn_type_size=16 - -#input -max_src_len=472 -max_tgt_len=40 -tokenized_input="true" -continuous_position="true" -batch_size=8 -in_tokens="false" - -#decode -do_decode="true" -max_dec_len=32 -beam_size=10 -length_penalty=1.3 -use_multi_gpu_test="true" - -#train -epoch=30 -weight_decay=0.01 -label_smooth=0.0 -hidden_dropout_prob=0.1 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.1 -lr_scheduler="linear_warmup_decay" -learning_rate=1e-4 -#noise -random_noise="false" -noise_prob=0.0 - -#dataset -data_path="./datasets/persona_chat/" -train_set="train.tsv" -dev_set="dev.2k.tsv" -pred_set="test.tsv" -do_train="true" -do_val="true" -do_test="false" -do_pred="true" -do_decode="true" - -#evaluate -eval_script="sh ./eval/tasks/persona_chat/eval.sh" -eval_mertrics="bleu_1,bleu_2,distinct_1,distinct_2" diff --git a/ernie-gen/configs/large_160g/squad-qg_conf b/ernie-gen/configs/large_160g/squad-qg_conf deleted file mode 100644 index 85953f8a705deb85ccd50070175c65777d271b01..0000000000000000000000000000000000000000 --- a/ernie-gen/configs/large_160g/squad-qg_conf +++ /dev/null @@ -1,47 +0,0 @@ -#load model -vocab_path="ernie_gen_large_160g/vocab.txt" -config_path="ernie_gen_large_160g/ernie_config.json" -init_model="ernie_gen_large_160g/params" - -#input -max_src_len=512 -max_tgt_len=96 -tokenized_input="true" -continuous_position="true" -batch_size=4 -in_tokens="false" -tgt_type_id=3 - -#decode -do_decode="true" -max_dec_len=48 -beam_size=5 -length_penalty=1.0 -use_multi_gpu_test="true" - -#train -epoch=10 -weight_decay=0.01 -label_smooth=0.1 -hidden_dropout_prob=0.2 -save_and_valid_by_epoch="true" -#lr -warmup_proportion=0.25 -lr_scheduler="linear_warmup_decay" -learning_rate=1.25e-5 -#noise -random_noise="true" -noise_prob=0.7 - -#dataset -data_path="./datasets/squad_qg/" -train_set="train.tsv" -dev_set="dev.tsv" -test_set="test.tsv" -do_train="true" -do_val="true" -do_test="true" - -#evaluate -eval_script="sh ./eval/tasks/squad_qg/eval.sh" -eval_mertrics="Bleu_4,METEOR,ROUGE_L" diff --git a/ernie-gen/env.sh b/ernie-gen/env.sh deleted file mode 100644 index 66aae09ab447b622583ab6adec5c59a303ce80f8..0000000000000000000000000000000000000000 --- a/ernie-gen/env.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env bash - -set -xu - -function check_iplist() { - if [[ ${iplist:-""} == "" ]] ;then - iplist=`hostname -i` - fi - - export PADDLE_PSERVER_PORT=9184 - export PADDLE_TRAINER_IPS=${iplist} - export PADDLE_CURRENT_IP=`hostname -i` - - iparray=(${iplist//,/ }) - for i in "${!iparray[@]}"; do - if [ ${iparray[$i]} == ${PADDLE_CURRENT_IP} ]; then - export PADDLE_TRAINER_ID=$i - fi - done - - export TRAINING_ROLE=TRAINER - export PADDLE_INIT_TRAINER_COUNT=${#iparray[@]} - export PADDLE_PORT=${PADDLE_PSERVER_PORT} - export PADDLE_TRAINERS=${PADDLE_TRAINER_IPS} - export POD_IP=${PADDLE_CURRENT_IP} - export PADDLE_TRAINERS_NUM=${PADDLE_INIT_TRAINER_COUNT} - export PADDLE_IS_LOCAL=0 - - #paddle debug envs - export GLOG_v=0 - export GLOG_logtostderr=1 - - #nccl debug envs - export NCCL_DEBUG=INFO - export NCCL_IB_GID_INDEX=3 -} - diff --git a/ernie-gen/eval/__init__.py b/ernie-gen/eval/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/ernie-gen/eval/gen_eval.py b/ernie-gen/eval/gen_eval.py deleted file mode 100644 index 4ea0edf0e70e454c475c59d26d62b42c51560083..0000000000000000000000000000000000000000 --- a/ernie-gen/eval/gen_eval.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""ultis help and eval functions for gen .""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import json -import math -import subprocess - -from reader.tokenization import BasicTokenizer - -class GenerationEval(object): - """GenerationEval""" - - def __init__(self, args, merge_subword=None): - self.basic_tokenizer = BasicTokenizer(do_lower_case=True) - self.merge_subword = merge_subword - self.eval_script = args.eval_script.split(" ") - self.eval_mertrics = args.eval_mertrics.split(",") if args.eval_mertrics else [] - self.tokenized_input = args.tokenized_input - - def eval(self, output_file, phase="", features=None): - """run eval""" - eval_res = {} - if self.eval_script: - eval_res = subprocess.check_output(self.eval_script + [output_file, phase]) - eval_res = json.loads(eval_res) - else: - preds = [] - for line in open(output_file): - preds.append(self.basic_tokenizer.tokenize(line.strip())) - - refs = [] - for id in sorted(features.keys()): - if self.tokenized_input: - ref = features[id].tgt.decode("utf8").split(" ") - refs.append([self.merge_subword(ref)]) - else: - refs.append([self.basic_tokenizer.tokenize(features[id].tgt)]) - - for mertric in self.eval_mertrics: - eval_func = getattr(self, mertric, None) - if eval_func: - eval_res[mertric] = eval_func(refs, preds) - - ret = [] - for mertric in self.eval_mertrics: - mertric_res = eval_res.get(mertric, None) - if mertric_res is None: - raise Exception("Eval mertric: %s is not supported" % mertric) - ret.append("%s: %f" % (mertric, mertric_res)) - - return ", ".join(ret) - - def bleu(self, refs, preds): - """bleu mertric""" - return _compute_bleu(refs, preds, max_order=4)[0] - - -def _get_ngrams(segment, max_order): - ngram_counts = collections.Counter() - for order in range(1, max_order + 1): - for i in range(0, len(segment) - order + 1): - ngram = tuple(segment[i: i + order]) - ngram_counts[ngram] += 1 - return ngram_counts - - -def _compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False): - matches_by_order = [0] * max_order - possible_matches_by_order = [0] * max_order - reference_length = 0 - translation_length = 0 - for (references, translation) in zip(reference_corpus, translation_corpus): - reference_length += min(len(r) for r in references) - translation_length += len(translation) - - merged_ref_ngram_counts = collections.Counter() - for reference in references: - merged_ref_ngram_counts |= _get_ngrams(reference, max_order) - translation_ngram_counts = _get_ngrams(translation, max_order) - overlap = translation_ngram_counts & merged_ref_ngram_counts - for ngram in overlap: - matches_by_order[len(ngram) - 1] += overlap[ngram] - for order in range(1, max_order + 1): - possible_matches = len(translation) - order + 1 - if possible_matches > 0: - possible_matches_by_order[order - 1] += possible_matches - - precisions = [0] * max_order - for i in range(0, max_order): - if smooth: - precisions[i] = ((matches_by_order[i] + 1.) / - (possible_matches_by_order[i] + 1.)) - else: - if possible_matches_by_order[i] > 0: - precisions[i] = (float(matches_by_order[i]) / - possible_matches_by_order[i]) - else: - precisions[i] = 0.0 - - if min(precisions) > 0: - p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) - geo_mean = math.exp(p_log_sum) - else: - geo_mean = 0 - - ratio = float(translation_length) / reference_length - - if ratio > 1.0: - bp = 1. - else: - bp = math.exp(1 - 1. / (ratio + 1e-4)) - - bleu = geo_mean * bp - ret = [bleu, precisions, bp, ratio, translation_length, reference_length] - return ret diff --git a/ernie-gen/eval/tasks/cnndm/cnndm/__init__.py b/ernie-gen/eval/tasks/cnndm/cnndm/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/ernie-gen/eval/tasks/cnndm/cnndm/bs_pyrouge.py b/ernie-gen/eval/tasks/cnndm/cnndm/bs_pyrouge.py deleted file mode 100644 index 9c4d350cd467311b2c4dfb835f3927b6d07611ba..0000000000000000000000000000000000000000 --- a/ernie-gen/eval/tasks/cnndm/cnndm/bs_pyrouge.py +++ /dev/null @@ -1,644 +0,0 @@ -from __future__ import print_function, unicode_literals, division - -import os -import re -import codecs -import platform - -from subprocess import check_output -from tempfile import mkdtemp -from functools import partial - -try: - from configparser import ConfigParser -except ImportError: - from ConfigParser import ConfigParser - -from pyrouge.utils import log -from pyrouge.utils.file_utils import verify_dir - - -REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}", - "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'} - - -def clean(x): - return re.sub( - r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''", - lambda m: REMAP.get(m.group()), x) - - -class DirectoryProcessor: - - @staticmethod - def process(input_dir, output_dir, function): - """ - Apply function to all files in input_dir and save the resulting ouput - files in output_dir. - - """ - if not os.path.exists(output_dir): - os.makedirs(output_dir) - logger = log.get_global_console_logger() - logger.info("Processing files in {}.".format(input_dir)) - input_file_names = os.listdir(input_dir) - for input_file_name in input_file_names: - input_file = os.path.join(input_dir, input_file_name) - with codecs.open(input_file, "r", encoding="UTF-8") as f: - input_string = f.read() - output_string = function(input_string) - output_file = os.path.join(output_dir, input_file_name) - with codecs.open(output_file, "w", encoding="UTF-8") as f: - f.write(clean(output_string.lower())) - logger.info("Saved processed files to {}.".format(output_dir)) - - -class Rouge155(object): - """ - This is a wrapper for the ROUGE 1.5.5 summary evaluation package. - This class is designed to simplify the evaluation process by: - - 1) Converting summaries into a format ROUGE understands. - 2) Generating the ROUGE configuration file automatically based - on filename patterns. - - This class can be used within Python like this: - - rouge = Rouge155() - rouge.system_dir = 'test/systems' - rouge.model_dir = 'test/models' - - # The system filename pattern should contain one group that - # matches the document ID. - rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html' - - # The model filename pattern has '#ID#' as a placeholder for the - # document ID. If there are multiple model summaries, pyrouge - # will use the provided regex to automatically match them with - # the corresponding system summary. Here, [A-Z] matches - # multiple model summaries for a given #ID#. - rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html' - - rouge_output = rouge.evaluate() - print(rouge_output) - output_dict = rouge.output_to_dict(rouge_ouput) - print(output_dict) - -> {'rouge_1_f_score': 0.95652, - 'rouge_1_f_score_cb': 0.95652, - 'rouge_1_f_score_ce': 0.95652, - 'rouge_1_precision': 0.95652, - [...] - - - To evaluate multiple systems: - - rouge = Rouge155() - rouge.system_dir = '/PATH/TO/systems' - rouge.model_dir = 'PATH/TO/models' - for system_id in ['id1', 'id2', 'id3']: - rouge.system_filename_pattern = \ - 'SL.P/.10.R.{}.SL062003-(\d+).html'.format(system_id) - rouge.model_filename_pattern = \ - 'SL.P.10.R.[A-Z].SL062003-#ID#.html' - rouge_output = rouge.evaluate(system_id) - print(rouge_output) - - """ - - def __init__(self, rouge_dir=None, rouge_args=None, temp_dir=None): - """ - Create a Rouge155 object. - - rouge_dir: Directory containing Rouge-1.5.5.pl - rouge_args: Arguments to pass through to ROUGE if you - don't want to use the default pyrouge - arguments. - - """ - self.temp_dir = temp_dir - self.log = log.get_global_console_logger() - self.__set_dir_properties() - self._config_file = None - self._settings_file = self.__get_config_path() - self.__set_rouge_dir(rouge_dir) - self.args = self.__clean_rouge_args(rouge_args) - self._system_filename_pattern = None - self._model_filename_pattern = None - - def save_home_dir(self): - config = ConfigParser() - section = 'pyrouge settings' - config.add_section(section) - config.set(section, 'home_dir', self._home_dir) - with open(self._settings_file, 'w') as f: - config.write(f) - self.log.info("Set ROUGE home directory to {}.".format(self._home_dir)) - - @property - def settings_file(self): - """ - Path of the setttings file, which stores the ROUGE home dir. - - """ - return self._settings_file - - @property - def bin_path(self): - """ - The full path of the ROUGE binary (although it's technically - a script), i.e. rouge_home_dir/ROUGE-1.5.5.pl - - """ - if self._bin_path is None: - raise Exception( - "ROUGE path not set. Please set the ROUGE home directory " - "and ensure that ROUGE-1.5.5.pl exists in it.") - return self._bin_path - - @property - def system_filename_pattern(self): - """ - The regular expression pattern for matching system summary - filenames. The regex string. - - E.g. "SL.P.10.R.11.SL062003-(\d+).html" will match the system - filenames in the SPL2003/system folder of the ROUGE SPL example - in the "sample-test" folder. - - Currently, there is no support for multiple systems. - - """ - return self._system_filename_pattern - - @system_filename_pattern.setter - def system_filename_pattern(self, pattern): - self._system_filename_pattern = pattern - - @property - def model_filename_pattern(self): - """ - The regular expression pattern for matching model summary - filenames. The pattern needs to contain the string "#ID#", - which is a placeholder for the document ID. - - E.g. "SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model - filenames in the SPL2003/system folder of the ROUGE SPL - example in the "sample-test" folder. - - "#ID#" is a placeholder for the document ID which has been - matched by the "(\d+)" part of the system filename pattern. - The different model summaries for a given document ID are - matched by the "[A-Z]" part. - - """ - return self._model_filename_pattern - - @model_filename_pattern.setter - def model_filename_pattern(self, pattern): - self._model_filename_pattern = pattern - - @property - def config_file(self): - return self._config_file - - @config_file.setter - def config_file(self, path): - config_dir, _ = os.path.split(path) - verify_dir(config_dir, "configuration file") - self._config_file = path - - def split_sentences(self): - """ - ROUGE requires texts split into sentences. In case the texts - are not already split, this method can be used. - - """ - from pyrouge.utils.sentence_splitter import PunktSentenceSplitter - self.log.info("Splitting sentences.") - ss = PunktSentenceSplitter() - - def sent_split_to_string(s): return "\n".join(ss.split(s)) - process_func = partial( - DirectoryProcessor.process, function=sent_split_to_string) - self.__process_summaries(process_func) - - @staticmethod - def convert_summaries_to_rouge_format(input_dir, output_dir): - """ - Convert all files in input_dir into a format ROUGE understands - and saves the files to output_dir. The input files are assumed - to be plain text with one sentence per line. - - input_dir: Path of directory containing the input files. - output_dir: Path of directory in which the converted files - will be saved. - - """ - DirectoryProcessor.process( - input_dir, output_dir, Rouge155.convert_text_to_rouge_format) - - @staticmethod - def convert_text_to_rouge_format(text, title="dummy title"): - """ - Convert a text to a format ROUGE understands. The text is - assumed to contain one sentence per line. - - text: The text to convert, containg one sentence per line. - title: Optional title for the text. The title will appear - in the converted file, but doesn't seem to have - any other relevance. - - Returns: The converted text as string. - - """ - sentences = text.split("\n") - sent_elems = [ - "[{i}] " - "{text}".format(i=i, text=sent) - for i, sent in enumerate(sentences, start=1)] - html = """ -
-{name}
".format( - id=system_id, name=system_filename) - - model_elems = ["systemX
-#
systemX
-#
systemX
-#
systemX
-#
systemX
-#