diff --git a/PaddleNLP/docs/models.md b/PaddleNLP/docs/models.md index b9f36b643ebaa47f6bc51643916b8821edadc703..4b185ca1472955d43b37e4b06012ae478dc35d2f 100644
--- a/PaddleNLP/docs/models.md
+++ b/PaddleNLP/docs/models.md
@@ -1,3 +1,133 @@
 # paddlenlp.models
-高阶组网API说明,Ernie, SimNet, Senta
+该模块提供了百度自研模型的高阶API,如文本分类模型Senta、文本匹配模型SimNet、通用预训练模型ERNIE等。
+
+```python
+class paddlenlp.models.Ernie(model_name, num_classes, task=None, **kwargs):
+    """
+    预训练模型ERNIE。
+    更多信息参考:ERNIE: Enhanced Representation through Knowledge Integration(https://arxiv.org/abs/1904.09223)
+
+    参数:
+    `model_name (obj:`str`)`: 模型名称,如`ernie-1.0`、`ernie-tiny`、`ernie-2.0-en`、`ernie-2.0-large-en`。
+    `num_classes (obj:`int`)`: 分类类别数。
+    `task (obj:`str`)`: 预训练模型ERNIE用于的下游任务名称,可以为`seq-cls`、`token-cls`、`qa`,默认为None。
+
+    - task='seq-cls': ERNIE用于文本分类任务。其将从ERNIE模型中提取句子特征,用于最后一层全连接网络进行文本分类。
+      详细信息参考:`paddlenlp.transformers.ErnieForSequenceClassification`。
+    - task='token-cls': ERNIE用于序列标注任务。其将从ERNIE模型中提取每一个token特征,用于最后一层全连接网络进行token分类。
+      详细信息参考:`paddlenlp.transformers.ErnieForTokenClassification`。
+    - task='qa': ERNIE用于阅读理解任务。其将从ERNIE模型中提取每一个token特征,用于最后一层全连接网络预测答案在原文中的起止位置。
+      详细信息参考:`paddlenlp.transformers.ErnieForQuestionAnswering`。
+    - task=None: 预训练模型ERNIE。可将其作为backbone,用于提取句子特征pooled_output、token特征sequence_output。
+      详细信息参考:`paddlenlp.transformers.ErnieModel`。
+    """
+
+    def forward(input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
+        """
+        参数:
+        `input_ids (obj:`paddle.Tensor`)`: 文本token id,shape为(batch_size, sequence_length)。
+        `token_type_ids (obj:`paddle.Tensor`)`: 各token所在文本的标识(token属于文本1或者文本2),shape为(batch_size, sequence_length)。
+            默认为None,表示所有token都属于文本1。
+        `position_ids (obj:`paddle.Tensor`)`: 各token在输入序列中的位置,shape为(batch_size, sequence_length)。默认为None。
+        `attention_mask (obj:`paddle.Tensor`)`: 为了避免在padding token上做attention操作,`attention_mask`表示token是否为padding token的标志矩阵,
+            shape为(batch_size, sequence_length)。mask的值为0或1,为1表示该token是padding token,为0表示该token为真实输入token id。默认为None。
+
+        返回:
+        - 当`task`不为None时,返回相应下游任务的分类概率值`probs (obj:`paddle.Tensor`)`,shape为(batch_size, num_classes)。
+        - 当`task`为None时,返回预训练模型ERNIE的句子特征pooled_output、token特征sequence_output。
+          * pooled_output (obj:`paddle.Tensor`): shape为(batch_size, hidden_size)。
+          * sequence_output (obj:`paddle.Tensor`): shape为(batch_size, sequence_length, hidden_size)。
+
+        """
+
+```
+
+
+```python
+class paddlenlp.models.Senta(network, vocab_size, num_classes, emb_dim=128, pad_token_id=0):
+    """
+    文本分类模型Senta。
+
+    参数:
+    `network (obj:`str`)`: 网络名称,可选bow、bilstm、bilstm_attn、bigru、birnn、cnn、lstm、gru、rnn以及textcnn。
+
+    - network='bow': 对输入word embedding相加,作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.BoWEncoder`。
+    - network='bilstm': 对输入word embedding进行双向lstm操作,取最后一个step的表示作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.LSTMEncoder`。
+    - network='bilstm_attn': 对输入word embedding进行双向lstm和Attention操作,取最后一个step的表示作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.LSTMEncoder`。
+    - network='bigru': 对输入word embedding进行双向gru操作,取最后一个step的表示作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.GRUEncoder`。
+    - network='birnn': 对输入word embedding进行双向rnn操作,取最后一个step的表示作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.RNNEncoder`。
+    - network='cnn': 对输入word embedding进行一次卷积操作后进行max-pooling,作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.CNNEncoder`。
+    - network='lstm': 对输入word embedding进行lstm操作后进行max-pooling,作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.LSTMEncoder`。
+    - network='gru': 对输入word embedding进行gru操作后进行max-pooling,作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.GRUEncoder`。
+    - network='rnn': 对输入word embedding进行rnn操作后进行max-pooling,作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.RNNEncoder`。
+    - network='textcnn': 对输入word embedding进行多次卷积和max-pooling,作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.CNNEncoder`。
+
+    `vocab_size (obj:`int`)`: 词汇表大小。
+    `num_classes (obj:`int`)`: 分类类别数。
+    `emb_dim (obj:`int`)`: word embedding维度,默认为128。
+    `pad_token_id (obj:`int`)`: padding token在词汇表中的index,默认为0。
+
+    """
+
+    def forward(text, seq_len=None):
+        """
+        参数:
+        `text (obj:`paddle.Tensor`)`: 文本token id,shape为(batch_size, sequence_length)。
+        `seq_len (obj:`paddle.Tensor`)`: 文本序列长度,shape为(batch_size)。默认为None。
+
+        返回:
+        `probs (obj:`paddle.Tensor`)`: 分类概率值,shape为(batch_size, num_classes)。
+
+        """
+
+```
+
+```python
+class paddlenlp.models.SimNet(network, vocab_size, num_classes, emb_dim=128, pad_token_id=0):
+    """
+    文本匹配模型SimNet。
+
+    参数:
+    `network (obj:`str`)`: 网络名称,可选bow、cnn、lstm以及gru。
+
+    - network='bow': 对输入word embedding相加,作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.BoWEncoder`。
+    - network='cnn': 对输入word embedding进行一次卷积操作后进行max-pooling,作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.CNNEncoder`。
+    - network='lstm': 对输入word embedding进行lstm操作,取最后一个step的表示作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.LSTMEncoder`。
+    - network='gru': 对输入word embedding进行gru操作,取最后一个step的表示作为文本特征表示。
+      详细信息参考:`paddlenlp.seq2vec.GRUEncoder`。
+
+    `vocab_size (obj:`int`)`: 词汇表大小。
+    `num_classes (obj:`int`)`: 分类类别数。
+    `emb_dim (obj:`int`)`: word embedding维度,默认为128。
+    `pad_token_id (obj:`int`)`: padding token在词汇表中的index,默认为0。
+
+    """
+
+    def forward(query, title, query_seq_len=None, title_seq_len=None):
+        """
+        参数:
+        `query (obj:`paddle.Tensor`)`: query文本token id,shape为(batch_size, query_sequence_length)。
+        `title (obj:`paddle.Tensor`)`: title文本token id,shape为(batch_size, title_sequence_length)。
+        `query_seq_len (obj:`paddle.Tensor`)`: query文本序列长度,shape为(batch_size)。默认为None。
+        `title_seq_len (obj:`paddle.Tensor`)`: title文本序列长度,shape为(batch_size)。默认为None。
+
+        返回:
+        `probs (obj:`paddle.Tensor`)`: 分类概率值,shape为(batch_size, num_classes)。
+
+        """
+
+```
diff --git a/PaddleNLP/examples/text_classification/pretrained_models/README.md b/PaddleNLP/examples/text_classification/pretrained_models/README.md index efdb0f8fd69e88cf1ac59893e1b42ae2591dedc1..fe98917252615b6fb76ad597a2a7ef895c7813a8 100644
--- a/PaddleNLP/examples/text_classification/pretrained_models/README.md
+++ b/PaddleNLP/examples/text_classification/pretrained_models/README.md
@@ -11,8 +11,8 @@
 本项目针对中文文本分类问题,开源了一系列模型,供用户可配置地使用:
 + BERT([Bidirectional Encoder Representations from Transformers](https://arxiv.org/abs/1810.04805))中文模型,简写`bert-base-chinese`, 其由12层Transformer网络组成。
-+ ERNIE([Enhanced Representation through Knowledge Integration](https://arxiv.org/pdf/1904.09223)),支持ERNIE 1.0中文模型(简写`ernie-1.0`)和ERNIE Tiny中文模型(简写`ernie_tiny`)。
-  其中`ernie`由12层Transformer网络组成,`ernie_tiny`由3层Transformer网络组成。
++ ERNIE([Enhanced Representation through Knowledge Integration](https://arxiv.org/pdf/1904.09223)),支持ERNIE 1.0中文模型(简写`ernie-1.0`)和ERNIE Tiny中文模型(简写`ernie-tiny`)。
+  其中`ernie`由12层Transformer网络组成,`ernie-tiny`由3层Transformer网络组成。
 + RoBERTa([A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692)),支持24层Transformer网络的`roberta-wwm-ext-large`和12层Transformer网络的`roberta-wwm-ext`。
 + Electra([ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555)), 支持hidden_size=256的`chinese-electra-discriminator-small`和 hidden_size=768的`chinese-electra-discriminator-base`
@@ -28,8 +28,8 @@
 | roberta-wwm-ext-large | 0.95250 | 0.95333 |
 | rbt3 | 0.92583 | 0.93250 |
 | rbtl3 | 0.9341 | 0.93583 |
-| chinese-electra-discriminator-base | 0.94500 | 0.94500 |
-| chinese-electra-discriminator-small | 0.92417 | 0.93417 |
+| chinese-electra-base | 0.94500 | 0.94500 |
+| 
chinese-electra-small | 0.92417 | 0.93417 | ## 快速开始 @@ -66,13 +66,13 @@ pretrained_models/ ```shell # 设置使用的GPU卡号 CUDA_VISIBLE_DEVICES=0 -python train.py --model_type ernie --model_name ernie_tiny --n_gpu 1 --save_dir ./checkpoints +python train.py --model_type ernie --model_name ernie-tiny --n_gpu 1 --save_dir ./checkpoints ``` 可支持配置的参数: * `model_type`:必选,模型类型,可以选择bert,ernie,roberta。 -* `model_name`: 必选,具体的模型简称。如`model_type=ernie`,则model_name可以选择`ernie`和`ernie_tiny`。`model_type=bert`,则model_name可以选择`bert-base-chinese`。 +* `model_name`: 必选,具体的模型简称。如`model_type=ernie`,则model_name可以选择`ernie-1.0`和`ernie-tiny`。`model_type=bert`,则model_name可以选择`bert-base-chinese`。 `model_type=roberta`,则model_name可以选择`roberta-wwm-ext-large`和`roberta-wwm-ext`。 * `save_dir`:必选,保存训练模型的目录。 * `max_seq_length`:可选,ERNIE/BERT模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。 @@ -99,14 +99,14 @@ checkpoints/ **NOTE:** * 如需恢复模型训练,则可以设置`init_from_ckpt`, 如`init_from_ckpt=checkpoints/model_100/model_state.pdparams`。 -* 如需使用ernie_tiny模型,则需要提前先安装sentencepiece依赖,如`pip install sentencepiece` +* 如需使用ernie-tiny模型,则需要提前先安装sentencepiece依赖,如`pip install sentencepiece` ### 模型预测 启动预测: ```shell export CUDA_VISIBLE_DEVICES=0 -python predict.py --model_type ernie --model_name ernie_tiny --params_path checkpoints/model_400/model_state.pdparams +python predict.py --model_type ernie --model_name ernie-tiny --params_path checkpoints/model_400/model_state.pdparams ``` 将待预测数据如以下示例: diff --git a/PaddleNLP/examples/text_classification/pretrained_models/predict.py b/PaddleNLP/examples/text_classification/pretrained_models/predict.py index 59477e11110276d4e5a844fa21c2551295200e37..d34024e17c938e838763eb069f2961464a8ec6bc 100644 --- a/PaddleNLP/examples/text_classification/pretrained_models/predict.py +++ b/PaddleNLP/examples/text_classification/pretrained_models/predict.py @@ -42,7 +42,7 @@ def parse_args(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument("--model_type", default='ernie', required=True, type=str, help="Model type selected in the list: " +", ".join(MODEL_CLASSES.keys())) - parser.add_argument("--model_name_or_path", default='ernie_tiny', required=True, type=str, help="Path to pre-trained model or shortcut name selected in the list: " + + parser.add_argument("--model_name_or_path", default='ernie-tiny', required=True, type=str, help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(sum([list(classes[-1].pretrained_init_configuration.keys()) for classes in MODEL_CLASSES.values()], []))) parser.add_argument("--params_path", type=str, required=True, help="The path to model parameters to be loaded.") diff --git a/PaddleNLP/examples/text_classification/pretrained_models/train.py b/PaddleNLP/examples/text_classification/pretrained_models/train.py index 33cf221114320b7a7d3e8d334084f63d65ca42f2..45308b5d430a744e11b5aa3c2603720dc0a7a947 100644 --- a/PaddleNLP/examples/text_classification/pretrained_models/train.py +++ b/PaddleNLP/examples/text_classification/pretrained_models/train.py @@ -49,7 +49,7 @@ def parse_args(): ", ".join(MODEL_CLASSES.keys())) parser.add_argument( "--model_name", - default='ernie_tiny', + default='ernie-tiny', required=True, type=str, help="Path to pre-trained model or shortcut name selected in the list: " diff --git a/PaddleNLP/examples/text_matching/sentence_transformers/README.md b/PaddleNLP/examples/text_matching/sentence_transformers/README.md index 5670563c7636d1fd2c1b12c2b5235bee9988445b..83871a18a22e529f82fe14d5bdb05eacd0981aaa 100644 --- 
a/PaddleNLP/examples/text_matching/sentence_transformers/README.md +++ b/PaddleNLP/examples/text_matching/sentence_transformers/README.md @@ -39,8 +39,8 @@ PaddleNLP提供了丰富的预训练模型,并且可以便捷地获取PaddlePa 本项目针对中文文本匹配问题,开源了一系列模型,供用户可配置地使用: + BERT([Bidirectional Encoder Representations from Transformers](https://arxiv.org/abs/1810.04805))中文模型,简写`bert-base-chinese`, 其由12层Transformer网络组成。 -+ ERNIE([Enhanced Representation through Knowledge Integration](https://arxiv.org/pdf/1904.09223)),支持ERNIE 1.0中文模型(简写`ernie-1.0`)和ERNIE Tiny中文模型(简写`ernie_tiny`)。 - 其中`ernie`由12层Transformer网络组成,`ernie_tiny`由3层Transformer网络组成。 ++ ERNIE([Enhanced Representation through Knowledge Integration](https://arxiv.org/pdf/1904.09223)),支持ERNIE 1.0中文模型(简写`ernie-1.0`)和ERNIE Tiny中文模型(简写`ernie-tiny`)。 + 其中`ernie`由12层Transformer网络组成,`ernie-tiny`由3层Transformer网络组成。 + RoBERTa([A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692)),支持24层Transformer网络的`roberta-wwm-ext-large`和12层Transformer网络的`roberta-wwm-ext`。 + Electra([ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555)), 支持hidden_size=256的`chinese-electra-discriminator-small`和 hidden_size=768的`chinese-electra-discriminator-base` @@ -48,11 +48,11 @@ PaddleNLP提供了丰富的预训练模型,并且可以便捷地获取PaddlePa ## TODO 增加模型效果 | 模型 | dev acc | test acc | | ---- | ------- | -------- | -| bert-base-chinese | | | -| bert-wwm-chinese | | | +| bert-base-chinese | 0.86537 | 0.84440 | +| bert-wwm-chinese | 0.86333 | 0.84128 | | bert-wwm-ext-chinese | | | -| ernie | | | -| ernie-tiny | | | +| ernie | 0.87480 | 0.84760 | +| ernie-tiny | 0.86071 | 0.83352 | | roberta-wwm-ext | | | | roberta-wwm-ext-large | | | | rbt3 | | | @@ -108,7 +108,7 @@ python train.py --model_type ernie --model_name ernie-1.0 --n_gpu 1 --save_dir . 
可支持配置的参数: * `model_type`:必选,模型类型,可以选择bert,ernie,roberta。 -* `model_name`: 必选,具体的模型简称。如`model_type=ernie`,则model_name可以选择`ernie`和`ernie_tiny`。`model_type=bert`,则model_name可以选择`bert-base-chinese`。 +* `model_name`: 必选,具体的模型简称。如`model_type=ernie`,则model_name可以选择`ernie-1.0`和`ernie-tiny`。`model_type=bert`,则model_name可以选择`bert-base-chinese`。 `model_type=roberta`,则model_name可以选择`roberta-wwm-ext-large`和`roberta-wwm-ext`。 * `save_dir`:必选,保存训练模型的目录。 * `max_seq_length`:可选,ERNIE/BERT模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。 @@ -135,14 +135,14 @@ checkpoints/ **NOTE:** * 如需恢复模型训练,则可以设置`init_from_ckpt`, 如`init_from_ckpt=checkpoints/model_100/model_state.pdparams`。 -* 如需使用ernie_tiny模型,则需要提前先安装sentencepiece依赖,如`pip install sentencepiece` +* 如需使用ernie-tiny模型,则需要提前先安装sentencepiece依赖,如`pip install sentencepiece` ### 模型预测 启动预测: ```shell CUDA_VISIBLE_DEVICES=0 -python predict.py --model_type ernie --model_name ernie_tiny --params_path checkpoints/model_400/model_state.pdparams +python predict.py --model_type ernie --model_name ernie-tiny --params_path checkpoints/model_400/model_state.pdparams ``` 将待预测数据如以下示例: diff --git a/PaddleNLP/examples/text_matching/sentence_transformers/predict.py b/PaddleNLP/examples/text_matching/sentence_transformers/predict.py index 6d90dc9a50918f6c3f40ff9474d2f4e3a6fdbd3c..c67ac7a502570ceac4448cbb1c62832a77f2ddb5 100644 --- a/PaddleNLP/examples/text_matching/sentence_transformers/predict.py +++ b/PaddleNLP/examples/text_matching/sentence_transformers/predict.py @@ -31,8 +31,7 @@ MODEL_CLASSES = { 'ernie': (ppnlp.transformers.ErnieModel, ppnlp.transformers.ErnieTokenizer), 'roberta': (ppnlp.transformers.RobertaModel, ppnlp.transformers.RobertaTokenizer), - # 'electra': (ppnlp.transformers.Electra, - # ppnlp.transformers.ElectraTokenizer) + 'electra': (ppnlp.transformers.Electra, ppnlp.transformers.ElectraTokenizer) } @@ -176,10 +175,6 @@ def predict(model, data, tokenizer, label_map, batch_size=1): title_input_ids = paddle.to_tensor(title_input_ids) title_segment_ids = paddle.to_tensor(title_segment_ids) - print(query_input_ids) - print(query_segment_ids) - print(title_segment_ids) - probs = model( query_input_ids, title_input_ids, diff --git a/PaddleNLP/examples/text_matching/sentence_transformers/train.py b/PaddleNLP/examples/text_matching/sentence_transformers/train.py index bb9673abf35b41b096a46375110c1129013b6288..c9a1d46b968ebef981fae5b8cffc45f1d240fac3 100644 --- a/PaddleNLP/examples/text_matching/sentence_transformers/train.py +++ b/PaddleNLP/examples/text_matching/sentence_transformers/train.py @@ -32,8 +32,7 @@ MODEL_CLASSES = { 'ernie': (ppnlp.transformers.ErnieModel, ppnlp.transformers.ErnieTokenizer), 'roberta': (ppnlp.transformers.RobertaModel, ppnlp.transformers.RobertaTokenizer), - # 'electra': (ppnlp.transformers.Electra, - # ppnlp.transformers.ElectraTokenizer) + 'electra': (ppnlp.transformers.Electra, ppnlp.transformers.ElectraTokenizer) } diff --git a/PaddleNLP/paddlenlp/models/ernie.py b/PaddleNLP/paddlenlp/models/ernie.py index 1d03c69d7d4bee66592444569997fcb29d86c4d8..9d0508d56101c1d8a0de76b5378dd4b2ecc8d484 100644 --- a/PaddleNLP/paddlenlp/models/ernie.py +++ b/PaddleNLP/paddlenlp/models/ernie.py @@ -20,7 +20,7 @@ from paddlenlp.transformers import * class Ernie(nn.Layer): - def __init__(self, model_name, num_classes, task=None): + def __init__(self, model_name, num_classes, task=None, **kwargs): super().__init__() model_name = model_name.lower() self.task = task.lower() @@ -30,20 +30,21 @@ class Ernie(nn.Layer): assert model_name in required_names, 
"model_name must be in %s, unknown %s ." ( required_names, model_name) self.model = ErnieForSequenceClassification.from_pretrained( - model_name, num_classes=num_classes) + model_name, num_classes=num_classes, **kwargs) elif self.task == 'token-cls': required_names = list(ErnieForTokenClassification. pretrained_init_configuration.keys()) assert model_name in required_names, "model_name must be in %s, unknown %s ." ( required_names, model_name) self.model = ErnieForTokenClassification.from_pretrained( - model_name, num_classes=num_classes) + model_name, num_classes=num_classes, **kwargs) elif self.task == 'qa': required_names = list( ErnieForQuestionAnswering.pretrained_init_configuration.keys()) assert model_name in required_names, "model_name must be in %s, unknown %s ." ( required_names, model_name) - self.model = ErnieForQuestionAnswering.from_pretrained(model_name) + self.model = ErnieForQuestionAnswering.from_pretrained(model_name, + **kwargs) elif self.task is None: required_names = list(ErnieModel.pretrained_init_configuration.keys( )) diff --git a/PaddleNLP/paddlenlp/models/senta.py b/PaddleNLP/paddlenlp/models/senta.py index b6543a861be33dbe00d95e94e0cdfe3d670ba63c..f3f89c664d81554b91f701d262dc49ae4a0b8ee9 100644 --- a/PaddleNLP/paddlenlp/models/senta.py +++ b/PaddleNLP/paddlenlp/models/senta.py @@ -23,32 +23,32 @@ INF = 1. * 1e12 class Senta(nn.Layer): def __init__(self, - network_name, + network, vocab_size, num_classes, emb_dim=128, pad_token_id=0): super().__init__() - network_name = network_name.lower() - if network_name == 'bow': + network = network.lower() + if network == 'bow': self.model = BoWModel( vocab_size, num_classes, emb_dim, padding_idx=pad_token_id) - elif network_name == 'bigru': + elif network == 'bigru': self.model = GRUModel( vocab_size, num_classes, emb_dim, direction='bidirectional', padding_idx=pad_token_id) - elif network_name == 'bilstm': + elif network == 'bilstm': self.model = LSTMModel( vocab_size, num_classes, emb_dim, direction='bidirectional', padding_idx=pad_token_id) - elif network_name == 'bilstm_attn': + elif network == 'bilstm_attn': lstm_hidden_size = 196 attention = SelfInteractiveAttention(hidden_size=2 * lstm_hidden_size) @@ -58,17 +58,17 @@ class Senta(nn.Layer): lstm_hidden_size=lstm_hidden_size, num_classes=num_classes, padding_idx=pad_token_id) - elif network_name == 'birnn': + elif network == 'birnn': self.model = RNNModel( vocab_size, num_classes, emb_dim, direction='bidrectional', padding_idx=pad_token_id) - elif network_name == 'cnn': + elif network == 'cnn': self.model = CNNModel( vocab_size, num_classes, emb_dim, padding_idx=pad_token_id) - elif network_name == 'gru': + elif network == 'gru': self.model = GRUModel( vocab_size, num_classes, @@ -76,7 +76,7 @@ class Senta(nn.Layer): direction='forward', padding_idx=pad_token_id, pooling_type='max') - elif network_name == 'lstm': + elif network == 'lstm': self.model = LSTMModel( vocab_size, num_classes, @@ -84,7 +84,7 @@ class Senta(nn.Layer): direction='forward', padding_idx=pad_token_id, pooling_type='max') - elif network_name == 'rnn': + elif network == 'rnn': self.model = RNNModel( vocab_size, num_classes, @@ -92,15 +92,15 @@ class Senta(nn.Layer): direction='forward', padding_idx=pad_token_id, pooling_type='max') - elif network_name == 'textcnn': + elif network == 'textcnn': self.model = TextCNNModel( vocab_size, num_classes, emb_dim, padding_idx=pad_token_id) else: raise ValueError( "Unknown network: %s, it must be one of bow, lstm, bilstm, cnn, gru, bigru, rnn, birnn, 
bilstm_attn and textcnn." - % network_name) + % network) - def forward(self, text, seq_len): + def forward(self, text, seq_len=None): logits = self.model(text, seq_len) probs = F.softmax(logits, axis=-1) return probs @@ -137,7 +137,7 @@ class BoWModel(nn.Layer): self.fc2 = nn.Linear(hidden_size, fc_hidden_size) self.output_layer = nn.Linear(fc_hidden_size, num_classes) - def forward(self, text, seq_len): + def forward(self, text, seq_len=None): # Shape: (batch_size, num_tokens, embedding_dim) embedded_text = self.embedder(text) @@ -462,7 +462,7 @@ class CNNModel(nn.Layer): self.fc = nn.Linear(self.encoder.get_output_dim(), fc_hidden_size) self.output_layer = nn.Linear(fc_hidden_size, num_classes) - def forward(self, text, seq_len): + def forward(self, text, seq_len=None): # Shape: (batch_size, num_tokens, embedding_dim) embedded_text = self.embedder(text) # Shape: (batch_size, len(ngram_filter_sizes)*num_filter) @@ -511,7 +511,7 @@ class TextCNNModel(nn.Layer): self.fc = nn.Linear(self.encoder.get_output_dim(), fc_hidden_size) self.output_layer = nn.Linear(fc_hidden_size, num_classes) - def forward(self, text, seq_len): + def forward(self, text, seq_len=None): # Shape: (batch_size, num_tokens, embedding_dim) embedded_text = self.embedder(text) # Shape: (batch_size, len(ngram_filter_sizes)*num_filter) diff --git a/PaddleNLP/paddlenlp/transformers/ernie/modeling.py b/PaddleNLP/paddlenlp/transformers/ernie/modeling.py index e285b082e2b8cb0db5e22ccc69a43f7856fdaad2..44050560a2a89a91d9458c989c0e859e3d402f2d 100644 --- a/PaddleNLP/paddlenlp/transformers/ernie/modeling.py +++ b/PaddleNLP/paddlenlp/transformers/ernie/modeling.py @@ -106,7 +106,7 @@ class ErniePretrainedModel(PretrainedModel): "vocab_size": 18000, "pad_token_id": 0, }, - "ernie_tiny": { + "ernie-tiny": { "attention_probs_dropout_prob": 0.1, "hidden_act": "relu", "hidden_dropout_prob": 0.1, @@ -153,7 +153,7 @@ class ErniePretrainedModel(PretrainedModel): "model_state": { "ernie-1.0": "https://paddlenlp.bj.bcebos.com/models/transformers/ernie/ernie_v1_chn_base.pdparams", - "ernie_tiny": + "ernie-tiny": "https://paddlenlp.bj.bcebos.com/models/transformers/ernie_tiny/ernie_tiny.pdparams", "ernie-2.0-en": "https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_base/ernie-2.0-en.pdparams", diff --git a/PaddleNLP/paddlenlp/transformers/ernie/tokenizer.py b/PaddleNLP/paddlenlp/transformers/ernie/tokenizer.py index 239e6d5caeeb5789db83ad830e24da64d6f24cfb..e3c83b5db4f914080787669df9de0cf2c7d4bece 100644 --- a/PaddleNLP/paddlenlp/transformers/ernie/tokenizer.py +++ b/PaddleNLP/paddlenlp/transformers/ernie/tokenizer.py @@ -403,7 +403,7 @@ class ErnieTinyTokenizer(PretrainedTokenizer): Examples: .. 
code-block:: python
                from paddlenlp.transformers import ErnieTinyTokenizer
-                tokenizer = ErnieTinyTokenizer.from_pretrained('ernie_tiny)
+                tokenizer = ErnieTinyTokenizer.from_pretrained('ernie-tiny')
                # the following line get: ['he', 'was', 'a', 'puppet', '##eer']
                tokens = tokenizer('He was a puppeteer')
                # the following line get: 'he was a puppeteer'
@@ -416,19 +416,19 @@ class ErnieTinyTokenizer(PretrainedTokenizer):
     }  # for save_pretrained
     pretrained_resource_files_map = {
         "vocab_file": {
-            "ernie_tiny":
+            "ernie-tiny":
             "https://paddlenlp.bj.bcebos.com/models/transformers/ernie_tiny/vocab.txt"
         },
         "sentencepiece_model_file": {
-            "ernie_tiny":
+            "ernie-tiny":
            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie_tiny/spm_cased_simp_sampled.model"
         },
         "word_dict": {
-            "ernie_tiny":
+            "ernie-tiny":
            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie_tiny/dict.wordseg.pickle"
         },
     }
-    pretrained_init_configuration = {"ernie_tiny": {"do_lower_case": True}}
+    pretrained_init_configuration = {"ernie-tiny": {"do_lower_case": True}}

     def __init__(self,
                  vocab_file,
@@ -553,8 +553,8 @@ class ErnieTinyTokenizer(PretrainedTokenizer):
             save_directory (str): Directory to save files into.
         """
         for name, file_name in self.resource_files_names.items():
-            ### TODO: make the name 'ernie_tiny' as a variable
-            source_path = os.path.join(MODEL_HOME, 'ernie_tiny', file_name)
+            ### TODO: make the name 'ernie-tiny' as a variable
+            source_path = os.path.join(MODEL_HOME, 'ernie-tiny', file_name)
             save_path = os.path.join(save_directory,
                                      self.resource_files_names[name])
             shutil.copyfile(source_path, save_path)
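
补充一个调用上述 `paddlenlp.models` 高阶API的示意代码(非本补丁内容,仅基于上文 models.md 中记录的接口),演示 `Senta` 模型的基本用法;其中 `vocab_size`、`num_classes` 及输入数据均为假设的示例值。

```python
# 示意用法:基于上文 models.md 文档中的 Senta 高阶API(取值仅为示例假设)
import paddle
from paddlenlp.models import Senta

# network 可选 bow、lstm、gru、textcnn 等;vocab_size、num_classes 为示例值
model = Senta(network='bow', vocab_size=30000, num_classes=2)

# text 为已转换成 token id 的文本,shape 为 (batch_size, sequence_length)
text = paddle.to_tensor([[23, 456, 78, 9, 0]], dtype='int64')
# seq_len 为每条文本的真实长度,shape 为 (batch_size)
seq_len = paddle.to_tensor([4], dtype='int64')

# 返回分类概率,shape 为 (batch_size, num_classes)
probs = model(text, seq_len)
print(probs.shape)
```

SimNet 的用法与之类似,区别在于 forward 需要同时传入 query 与 title 两段文本的 token id(及可选的序列长度)。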