diff --git a/PaddleNLP/examples/README.md b/PaddleNLP/examples/README.md index 806b0f20e263561efd9c46f644905f289620f0b4..af8d47ee88c006cf88cca4fa0b57b554b5d0e9be 100644 --- a/PaddleNLP/examples/README.md +++ b/PaddleNLP/examples/README.md @@ -7,16 +7,16 @@ | 任务类型 | 目录 | 简介 | | ----------------------------------| ------------------------------------------------------------ | ------------------------------------------------------------ | -| 中文词法分析 | [LAC(Lexical Analysis of Chinese)](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/lexical_analysis) | 百度自主研发中文特色模型词法分析任务,集成了中文分词、词性标注和命名实体识别任务。输入是一个字符串,而输出是句子中的词边界和词性、实体类别。 | +| 中文词法分析 | [LAC(Lexical Analysis of Chinese)](./lexical_analysis) | 百度自主研发中文特色模型词法分析任务,集成了中文分词、词性标注和命名实体识别任务。输入是一个字符串,而输出是句子中的词边界和词性、实体类别。 | | 预训练词向量 | [WordEmbedding](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/word_embedding) | 提供了丰富的中文预训练词向量,通过简单配置即可使用词向量来进行热启训练,能支持较多的中文场景下的训练任务的热启训练,加快训练收敛速度。| ### 核心技术模型 | 任务类型 | 目录 | 简介 | | -------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | -| ERNIE-GEN文本生成 | [ERNIE-GEN(An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation)](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/text_generation/ernie-gen) |ERNIE-GEN是百度发布的生成式预训练模型,是一种Multi-Flow结构的预训练和微调框架。ERNIE-GEN利用更少的参数量和数据,在摘要生成、问题生成、对话和生成式问答4个任务共5个数据集上取得了SOTA效果 | -| BERT 预训练&GLUE下游任务 | [BERT(Bidirectional Encoder Representation from Transformers)](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/bert) | BERT模型作为目前最为火热语义表示预训练模型,PaddleNLP提供了简洁功效的实现方式,同时易用性方面通过简单参数切换即可实现不同的BERT模型。 | -| Electra 预训练&GLUE下游任务 | [Electra(Efficiently Learning an Encoder that Classifies Token Replacements Accurately)](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/electra) |ELECTRA 创新性地引入GAN的思想对BERT预训练过程进行了改进,在和BERT具有相同的模型参数、预训练计算量一样的情况下,ELECTRA GLUE得分明显好。同时相比GPT、ELMo,在GLUE得分略好时,ELECTRA预训练模型只需要很少的参数和计算量。| +| ERNIE-GEN文本生成 | [ERNIE-GEN(An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation)](./text_generation/ernie-gen) |ERNIE-GEN是百度发布的生成式预训练模型,是一种Multi-Flow结构的预训练和微调框架。ERNIE-GEN利用更少的参数量和数据,在摘要生成、问题生成、对话和生成式问答4个任务共5个数据集上取得了SOTA效果 | +| BERT 预训练&GLUE下游任务 | [BERT(Bidirectional Encoder Representation from Transformers)](./language_model/bert) | BERT模型作为目前最为火热语义表示预训练模型,PaddleNLP提供了简洁功效的实现方式,同时易用性方面通过简单参数切换即可实现不同的BERT模型。 | +| Electra 预训练&GLUE下游任务 | [Electra(Efficiently Learning an Encoder that Classifies Token Replacements Accurately)](./language_model/electra) |ELECTRA 创新性地引入GAN的思想对BERT预训练过程进行了改进,在和BERT具有相同的模型参数、预训练计算量一样的情况下,ELECTRA GLUE得分明显好。同时相比GPT、ELMo,在GLUE得分略好时,ELECTRA预训练模型只需要很少的参数和计算量。| ### 核心应用模型 @@ -25,20 +25,20 @@ | 模型 | 简介 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [Seq2Seq](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/machine_translation/seq2seq) | 使用编码器-解码器(Encoder-Decoder)结构, 同时使用了Attention机制来加强Decoder和Encoder之间的信息交互,Seq2Seq 广泛应用于机器翻译,自动对话机器人,文档摘要自动生成,图片描述自动生成等任务中。| -| [Transformer](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/machine_translation/transformer) |基于PaddlePaddle框架的Transformer结构搭建的机器翻译模型,Transformer 计算并行度高,能解决学习长程依赖问题。并且模型框架集成了训练,验证,预测任务,功能完备,效果突出。| +| [Seq2Seq](./machine_translation/seq2seq) | 使用编码器-解码器(Encoder-Decoder)结构, 同时使用了Attention机制来加强Decoder和Encoder之间的信息交互,Seq2Seq 广泛应用于机器翻译,自动对话机器人,文档摘要自动生成,图片描述自动生成等任务中。| +| [Transformer](./machine_translation/transformer) |基于PaddlePaddle框架的Transformer结构搭建的机器翻译模型,Transformer 计算并行度高,能解决学习长程依赖问题。并且模型框架集成了训练,验证,预测任务,功能完备,效果突出。| #### 命名实体识别 (Named Entity Recognition) 命名实体识别(Named Entity Recognition,NER)是NLP中一项非常基础的任务。NER是信息提取、问答系统、句法分析、机器翻译等众多NLP任务的重要基础工具。命名实体识别的准确度,决定了下游任务的效果,是NLP中非常重要的一个基础问题。 在NER任务提供了两种解决方案,一类LSTM/GRU + CRF(Conditional Random Field),RNN类的模型来抽取底层文本的信息,而CRF(条件随机场)模型来学习底层Token之间的联系;另外一类是通过预训练模型,例如ERNIE,BERT模型,直接来预测Token的标签信息。 -因为该类模型较为抽象,提供了一份快递单信息抽取的训练脚本给大家使用,具体的任务是通过两类的模型来抽取快递单的核心信息,例如地址,姓名,手机号码,具体的[快递单任务链接](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/named_entity_recognition/express_ner)。 +因为该类模型较为抽象,提供了一份快递单信息抽取的训练脚本给大家使用,具体的任务是通过两类的模型来抽取快递单的核心信息,例如地址,姓名,手机号码,具体的[快递单任务链接](./named_entity_recognition/express_ner)。 下面是具体的模型信息。 | 模型 | 简介 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [BiGRU+CRF](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/named_entity_recognition/express_ner) |传统的序列标注模型,通过双向GRU模型能抽取文本序列的信息和联系,通过CRF模型来学习文本Token之间的联系,本模型集成PaddleNLP自己开发的CRF模型,模型结构清晰易懂。 | -| [ERNIE/BERT Fine-tuning](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/named_entity_recognition) |通过预训练模型提供的强大的语义信息和ERNIE/BERT类模型的Self-Attention机制来覆盖Token之间的联系,直接通过BERT/ERNIE的序列分类模型来预测文本每个token的标签信息,模型结构简单,效果优异。| +| [BiGRU-CRF](./named_entity_recognition/express_ner) |传统的序列标注模型,通过双向GRU模型能抽取文本序列的信息和联系,通过CRF模型来学习文本Token之间的联系,本模型集成PaddleNLP自己开发的CRF模型,模型结构清晰易懂。 | +| [ERNIE/BERT Fine-tuning](./named_entity_recognition) |通过预训练模型提供的强大的语义信息和ERNIE/BERT类模型的Self-Attention机制来覆盖Token之间的联系,直接通过BERT/ERNIE的序列分类模型来预测文本每个token的标签信息,模型结构简单,效果优异。| #### 文本分类 (Text Classification) @@ -46,8 +46,8 @@ | 模型 | 简介 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [RNN/GRU/LSTM](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/text_classification/rnn) | 面向通用场景的文本分类模型,网络结构接入常见的RNN类模型,例如LSTM,GRU,RNN。整体模型结构集成在百度的自研的Senta文本情感分类模型上,效果突出,用法简易。| -| [ERNIE/BERT Fine-tuning](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/text_classification/pretrained_models) |基于预训练后模型的文本分类的模型,多达11种的预训练模型可供使用,其中有较多中文预训练模型,预训练模型切换简单,情感分析任务上效果突出。| +| [RNN/GRU/LSTM](./text_classification/rnn) | 面向通用场景的文本分类模型,网络结构接入常见的RNN类模型,例如LSTM,GRU,RNN。整体模型结构集成在百度的自研的Senta文本情感分类模型上,效果突出,用法简易。| +| [ERNIE/BERT Fine-tuning](./text_classification/pretrained_models) |基于预训练后模型的文本分类的模型,多达11种的预训练模型可供使用,其中有较多中文预训练模型,预训练模型切换简单,情感分析任务上效果突出。| #### 文本生成 (Text Generation) @@ -55,7 +55,7 @@ | 模型 | 简介 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [ERNIE-GEN(An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation)](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/text_generation/ernie-gen) |ERNIE-GEN是百度发布的生成式预训练模型,通过Global-Attention的方式解决训练和预测曝光偏差的问题,同时使用Multi-Flow Attention机制来分别进行Global和Context信息的交互,同时通过片段生成的方式来增加语义相关性。| +| [ERNIE-GEN(An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation)](./text_generation/ernie-gen) |ERNIE-GEN是百度发布的生成式预训练模型,通过Global-Attention的方式解决训练和预测曝光偏差的问题,同时使用Multi-Flow Attention机制来分别进行Global和Context信息的交互,同时通过片段生成的方式来增加语义相关性。| @@ -65,8 +65,8 @@ | 模型 | 简介 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [SimNet](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/text_matching/simnet)|PaddleNLP提供的SimNet模型已经纳入了PaddleNLP的官方API中,用户直接调用API即完成一个SimNet模型的组网,在模型层面提供了Bow/CNN/LSTM/GRU常用信息抽取方式, 灵活高,使用方便。| -| [SentenceTransformer](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/text_matching/sentence_transformers)|直接调用简易的预训练模型接口接口完成对Sentence的语义表示,同时提供了较多的中文预训练模型,可以根据任务的来选择相关参数。| +| [SimNet](./text_matching/simnet)|PaddleNLP提供的SimNet模型已经纳入了PaddleNLP的官方API中,用户直接调用API即完成一个SimNet模型的组网,在模型层面提供了Bow/CNN/LSTM/GRU常用信息抽取方式, 灵活高,使用方便。| +| [SentenceTransformer](./text_matching/sentence_transformers)|直接调用简易的预训练模型接口接口完成对Sentence的语义表示,同时提供了较多的中文预训练模型,可以根据任务的来选择相关参数。| #### 语言模型 (Language Model) @@ -74,8 +74,8 @@ | 模型 | 简介 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [RNNLM](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/language_model/rnnlm) |序列任务常用的rnn网络,实现了一个两层的LSTM网络,然后LSTM的结果去预测下一个词出现的概率。是基于RNN的常规的语言模型。| -| [ELMo](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/language_model/elmo) |ElMo是一个双向的LSTM语言模型,由一个前向和一个后向语言模型构成,目标函数就是取这两个方向语言模型的最大似然。ELMo主要是解决了传统的WordEmbedding的向量表示单一的问题,ELMo通过结合上下文来增强语义表示。| +| [RNNLM](./language_model/rnnlm) |序列任务常用的rnn网络,实现了一个两层的LSTM网络,然后LSTM的结果去预测下一个词出现的概率。是基于RNN的常规的语言模型。| +| [ELMo](./language_model/elmo) |ElMo是一个双向的LSTM语言模型,由一个前向和一个后向语言模型构成,目标函数就是取这两个方向语言模型的最大似然。ELMo主要是解决了传统的WordEmbedding的向量表示单一的问题,ELMo通过结合上下文来增强语义表示。| #### 文本图学习 (Text Graph) 在很多工业应用中,往往出现一种特殊的图:Text Graph。顾名思义,图的节点属性由文本构成,而边的构建提供了结构信息。如搜索场景下的Text Graph,节点可由搜索词、网页标题、网页正文来表达,用户反馈和超链信息则可构成边关系。百度图学习PGL((Paddle Graph Learning)团队提出ERNIESage(ERNIE SAmple aggreGatE)模型同时建模文本语义与图结构信息,有效提升Text Graph的应用效果。图学习是深度学习领域目前的研究热点,如果想对图学习有更多的了解,可以访问[PGL Github链接](https://github.com/PaddlePaddle/PGL/)。 @@ -83,14 +83,14 @@ ERNIESage模型的具体信息如下。 | 模型 | 简介 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [ERNIESage(ERNIE SAmple aggreGatE)](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/text_graph/erniesage)|通过Graph(图)来来构建自身节点和邻居节点的连接关系,将自身节点和邻居节点的关系构建成一个关联样本输入到ERNIE中,ERNIE作为聚合函数(Aggregators)来表征自身节点和邻居节点的语义关系,最终强化图中节点的语义表示。在TextGraph的任务上ERNIESage的效果非常优秀。| +| [ERNIESage(ERNIE SAmple aggreGatE)](./text_graph/erniesage)|通过Graph(图)来来构建自身节点和邻居节点的连接关系,将自身节点和邻居节点的关系构建成一个关联样本输入到ERNIE中,ERNIE作为聚合函数(Aggregators)来表征自身节点和邻居节点的语义关系,最终强化图中节点的语义表示。在TextGraph的任务上ERNIESage的效果非常优秀。| #### 阅读理解(Machine Reading Comprehension) 机器阅读理解是近期自然语言处理领域的研究热点之一,也是人工智能在处理和理解人类语言进程中的一个长期目标。得益于深度学习技术和大规模标注数据集的发展,用端到端的神经网络来解决阅读理解任务取得了长足的进步。下面是具体的模型信息。 | 模型 | 简介 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [BERT Fine-tuning](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/machine_reading_comprehension/) |通过ERNIE/BERT等预训练模型的强大的语义表示能力,设置在阅读理解上面的下游任务,该模块主要是提供了多个数据集来验证BERT模型在阅读理解上的效果,数据集主要是包括了SQuAD,DuReader,DuReader-robust,DuReader-yesno。同时提供了和相关阅读理解相关的Metric(指标),用户可以简易的调用这些API,快速验证模型效果。| +| [BERT Fine-tuning](./machine_reading_comprehension/) |通过ERNIE/BERT等预训练模型的强大的语义表示能力,设置在阅读理解上面的下游任务,该模块主要是提供了多个数据集来验证BERT模型在阅读理解上的效果,数据集主要是包括了SQuAD,DuReader,DuReader-robust,DuReader-yesno。同时提供了和相关阅读理解相关的Metric(指标),用户可以简易的调用这些API,快速验证模型效果。| #### 对话系统(Dialogue System) @@ -98,7 +98,8 @@ ERNIESage模型的具体信息如下。 | 模型 | 简介 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [BERT-DGU](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/dialogue/dgu) |通过ERNIE/BERT等预训练模型的强大的语义表示能力,抽取对话中的文本语义信息,通过对文本分类等操作就可以完成对话中的诸多任务,例如意图识别,行文识别,状态跟踪等。| +| [DGU](./dialogue/dgu) |通过ERNIE/BERT等预训练模型的强大的语义表示能力,抽取对话中的文本语义信息,通过对文本分类等操作就可以完成对话中的诸多任务,例如意图识别,行文识别,状态跟踪等。| +| [PLATO-2](./dialogue/plato-2) | 百度自研领先的开放域对话预训练模型。[PLATO-2: Towards Building an Open-Domain Chatbot via Curriculum Learning](https://arxiv.org/abs/2006.16779) | #### 时间序列预测(Time Series) 时间序列是指按照时间先后顺序排列而成的序列,例如每日发电量、每小时营业额等组成的序列。通过分析时间序列中的发展过程、方向和趋势,我们可以预测下一段时间可能出现的情况。为了更好让大家了解时间序列预测任务,提供了基于19年新冠疫情预测的任务示例,有兴趣的话可以进行研究学习。 @@ -107,4 +108,4 @@ ERNIESage模型的具体信息如下。 | 模型 | 简介 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [TCN(Temporal convolutional network)](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/time_series)|TCN模型基于卷积的时间序列模型,通过因果卷积(Causal Convolution)和空洞卷积(Dilated Convolution) 特定的组合方式解决卷积不适合时间序列任务的问题,TCN具备并行度高,内存低等诸多优点,在某些时间序列任务上效果已经超过传统的RNN模型。| +| [TCN(Temporal convolutional network)](./time_series)|TCN模型基于卷积的时间序列模型,通过因果卷积(Causal Convolution)和空洞卷积(Dilated Convolution) 特定的组合方式解决卷积不适合时间序列任务的问题,TCN具备并行度高,内存低等诸多优点,在某些时间序列任务上效果已经超过传统的RNN模型。| diff --git a/PaddleNLP/examples/electra/README.md b/PaddleNLP/examples/language_model/electra/README.md similarity index 100% rename from PaddleNLP/examples/electra/README.md rename to PaddleNLP/examples/language_model/electra/README.md diff --git a/PaddleNLP/examples/language_model/electra/run_glue.py b/PaddleNLP/examples/language_model/electra/run_glue.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad07ae8a2f42d735f8a9c42db39db8291ca8574 --- /dev/null +++ b/PaddleNLP/examples/language_model/electra/run_glue.py @@ -0,0 +1,424 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import sys +import hashlib +import random +import time +import math +from functools import partial + +import numpy as np +import paddle +from paddle.io import DataLoader +from paddle.metric import Metric, Accuracy, Precision, Recall + +from paddlenlp.datasets import GlueCoLA, GlueSST2, GlueMRPC, GlueSTSB, GlueQQP, GlueMNLI, GlueQNLI, GlueRTE +from paddlenlp.data import Stack, Tuple, Pad +from paddlenlp.data.sampler import SamplerHelper +from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer +from paddlenlp.utils.log import logger +from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman + +TASK_CLASSES = { + "cola": (GlueCoLA, Mcc), + "sst-2": (GlueSST2, Accuracy), + "mrpc": (GlueMRPC, AccuracyAndF1), + "sts-b": (GlueSTSB, PearsonAndSpearman), + "qqp": (GlueQQP, AccuracyAndF1), + "mnli": (GlueMNLI, Accuracy), + "qnli": (GlueQNLI, Accuracy), + "rte": (GlueRTE, Accuracy), +} + +MODEL_CLASSES = { + "electra": (ElectraForSequenceClassification, ElectraTokenizer), +} + + +def set_seed(args): + random.seed(args.seed + paddle.distributed.get_rank()) + np.random.seed(args.seed + paddle.distributed.get_rank()) + paddle.seed(args.seed + paddle.distributed.get_rank()) + + +def evaluate(model, loss_fct, metric, data_loader): + model.eval() + metric.reset() + for batch in data_loader: + input_ids, segment_ids, labels = batch + logits = model(input_ids=input_ids, token_type_ids=segment_ids) + loss = loss_fct(logits, labels) + correct = metric.compute(logits, labels) + metric.update(correct) + acc = metric.accumulate() + print("eval loss: %f, acc: %s, " % (loss.numpy(), acc), end='') + model.train() + + +def convert_example(example, + tokenizer, + label_list, + max_seq_length=128, + is_test=False): + """convert a glue example into necessary features""" + + def _truncate_seqs(seqs, max_seq_length): + if len(seqs) == 1: # single sentence + # Account for [CLS] and [SEP] with "- 2" + seqs[0] = seqs[0][0:(max_seq_length - 2)] + else: # Sentence pair + # Account for [CLS], [SEP], [SEP] with "- 3" + tokens_a, tokens_b = seqs + max_seq_length -= 3 + while True: # Truncate with longest_first strategy + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_seq_length: + break + if len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + return seqs + + def _concat_seqs(seqs, separators, seq_mask=0, separator_mask=1): + concat = sum((seq + sep for sep, seq in zip(separators, seqs)), []) + segment_ids = sum( + ([i] * (len(seq) + len(sep)) + for i, (sep, seq) in enumerate(zip(separators, seqs))), []) + if isinstance(seq_mask, int): + seq_mask = [[seq_mask] * len(seq) for seq in seqs] + if isinstance(separator_mask, int): + separator_mask = [[separator_mask] * len(sep) for sep in separators] + p_mask = sum((s_mask + mask + for sep, seq, s_mask, mask in zip( + separators, seqs, seq_mask, separator_mask)), []) + return concat, segment_ids, p_mask + + if not is_test: + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + label = example[-1] + example = example[:-1] + # Create label maps if classification task + if label_list: + label_map = {} + for (i, l) in enumerate(label_list): + label_map[l] = i + label = label_map[label] + label = np.array([label], dtype=label_dtype) + + # Tokenize raw text + tokens_raw = [tokenizer(l) for l in example] + # Truncate to the truncate_length, + tokens_trun = _truncate_seqs(tokens_raw, max_seq_length) + # Concate the sequences with special tokens + tokens_trun[0] = [tokenizer.cls_token] + tokens_trun[0] + tokens, segment_ids, _ = _concat_seqs(tokens_trun, [[tokenizer.sep_token]] * + len(tokens_trun)) + # Convert the token to ids + input_ids = tokenizer.convert_tokens_to_ids(tokens) + valid_length = len(input_ids) + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + # input_mask = [1] * len(input_ids) + if not is_test: + return input_ids, segment_ids, valid_length, label + else: + return input_ids, segment_ids, valid_length + + +def do_train(args): + paddle.enable_static() if not args.eager_run else None + paddle.set_device("gpu" if args.n_gpu else "cpu") + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + + args.task_name = args.task_name.lower() + dataset_class, metric_class = TASK_CLASSES[args.task_name] + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + train_dataset = dataset_class.get_datasets(["train"]) + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_dataset.get_labels(), + max_seq_length=args.max_seq_length) + train_dataset = train_dataset.apply(trans_func, lazy=True) + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_dataset, batch_size=args.batch_size, shuffle=True) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment + Stack(), # length + Stack(dtype="int64" if train_dataset.get_labels() else "float32") # label + ): [data for i, data in enumerate(fn(samples)) if i != 2] + train_data_loader = DataLoader( + dataset=train_dataset, + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + if args.task_name == "mnli": + dev_dataset_matched, dev_dataset_mismatched = dataset_class.get_datasets( + ["dev_matched", "dev_mismatched"]) + dev_dataset_matched = dev_dataset_matched.apply(trans_func, lazy=True) + dev_dataset_mismatched = dev_dataset_mismatched.apply( + trans_func, lazy=True) + dev_batch_sampler_matched = paddle.io.BatchSampler( + dev_dataset_matched, batch_size=args.batch_size, shuffle=False) + dev_data_loader_matched = DataLoader( + dataset=dev_dataset_matched, + batch_sampler=dev_batch_sampler_matched, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + dev_batch_sampler_mismatched = paddle.io.BatchSampler( + dev_dataset_mismatched, batch_size=args.batch_size, shuffle=False) + dev_data_loader_mismatched = DataLoader( + dataset=dev_dataset_mismatched, + batch_sampler=dev_batch_sampler_mismatched, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + else: + dev_dataset = dataset_class.get_datasets(["dev"]) + dev_dataset = dev_dataset.apply(trans_func, lazy=True) + dev_batch_sampler = paddle.io.BatchSampler( + dev_dataset, batch_size=args.batch_size, shuffle=False) + dev_data_loader = DataLoader( + dataset=dev_dataset, + batch_sampler=dev_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + num_labels = 1 if train_dataset.get_labels() == None else len( + train_dataset.get_labels()) + model = model_class.from_pretrained( + args.model_name_or_path, num_labels=num_labels) + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + num_training_steps = args.max_steps if args.max_steps > 0 else ( + len(train_data_loader) * args.num_train_epochs) + warmup_steps = int(math.floor(num_training_steps * args.warmup_proportion)) + lr_scheduler = paddle.optimizer.lr.LambdaDecay( + args.learning_rate, + lambda current_step, num_warmup_steps=warmup_steps, + num_training_steps=num_training_steps : float( + current_step) / float(max(1, num_warmup_steps)) + if current_step < num_warmup_steps else max( + 0.0, + float(num_training_steps - current_step) / float( + max(1, num_training_steps - num_warmup_steps)))) + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=args.adam_epsilon, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm", "LayerNorm"]) + ]) + + loss_fct = paddle.nn.loss.CrossEntropyLoss() if train_dataset.get_labels( + ) else paddle.nn.loss.MSELoss() + + metric = metric_class() + + ### TODO: use hapi + # trainer = paddle.hapi.Model(model) + # trainer.prepare(optimizer, loss_fct, paddle.metric.Accuracy()) + # trainer.fit(train_data_loader, + # dev_data_loader, + # log_freq=args.logging_steps, + # epochs=args.num_train_epochs, + # save_dir=args.output_dir) + + global_step = 0 + tic_train = time.time() + for epoch in range(args.num_train_epochs): + for step, batch in enumerate(train_data_loader): + global_step += 1 + input_ids, segment_ids, labels = batch + logits = model(input_ids=input_ids, token_type_ids=segment_ids) + loss = loss_fct(logits, labels) + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_gradients() + if global_step % args.logging_steps == 0: + print( + "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s" + % (global_step, num_training_steps, epoch, step, + paddle.distributed.get_rank(), loss, optimizer.get_lr(), + args.logging_steps / (time.time() - tic_train))) + tic_train = time.time() + if global_step % args.save_steps == 0: + tic_eval = time.time() + if args.task_name == "mnli": + evaluate(model, loss_fct, metric, dev_data_loader_matched) + evaluate(model, loss_fct, metric, + dev_data_loader_mismatched) + print("eval done total : %s s" % (time.time() - tic_eval)) + else: + evaluate(model, loss_fct, metric, dev_data_loader) + print("eval done total : %s s" % (time.time() - tic_eval)) + if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + output_dir = os.path.join(args.output_dir, + "%s_ft_model_%d.pdparams" % + (args.task_name, global_step)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + +def get_md5sum(file_path): + md5sum = None + if os.path.isfile(file_path): + with open(file_path, 'rb') as f: + md5_obj = hashlib.md5() + md5_obj.update(f.read()) + hash_code = md5_obj.hexdigest() + md5sum = str(hash_code).lower() + return md5sum + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--task_name", + default=None, + type=str, + required=True, + help="The name of the task to train selected in the list: " + + ", ".join(TASK_CLASSES.keys()), ) + parser.add_argument( + "--model_type", + default="electra", + type=str, + required=False, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_name_or_path", + default="./", + type=str, + required=False, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--output_dir", + default="./ft_model/", + type=str, + required=False, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--learning_rate", + default=1e-4, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--num_train_epochs", + default=3, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--logging_steps", + type=int, + default=100, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=100, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--batch_size", + default=32, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--weight_decay", + default=0.0, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--warmup_proportion", + default=0.1, + type=float, + help="Linear warmup proportion over total steps.") + parser.add_argument( + "--adam_epsilon", + default=1e-6, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="random seed for initialization") + parser.add_argument( + "--eager_run", default=True, type=eval, help="Use dygraph mode.") + parser.add_argument( + "--n_gpu", + default=1, + type=int, + help="number of gpus to use, 0 for cpu.") + args, unparsed = parser.parse_known_args() + print_arguments(args) + if args.n_gpu > 1: + paddle.distributed.spawn(do_train, args=(args, ), nprocs=args.n_gpu) + else: + do_train(args) diff --git a/PaddleNLP/examples/electra/run_pretrain.py b/PaddleNLP/examples/language_model/electra/run_pretrain.py similarity index 100% rename from PaddleNLP/examples/electra/run_pretrain.py rename to PaddleNLP/examples/language_model/electra/run_pretrain.py