Commit a1b17d6b authored by linjieccc

fix ernie_v2_eng_base

Parent 0e61fa66
@@ -6,7 +6,7 @@
|Dataset|Baidu self-built dataset|
|Supports Fine-tuning|Yes|
|Model size|1.3G|
|Latest update date|2021-03-16|
|Latest update date|2021-06-28|
|Data metrics|-|
## I. Basic Model Information
@@ -36,7 +36,7 @@ ERNIE is a knowledge-enhanced continual learning semantic understanding model proposed by Baidu. The
- ### 2. Installation
- ```shell
$ hub install ernie_tiny
$ hub install ernie_v2_eng_base
```
- If you run into problems during installation, please refer to: [Windows installation (from scratch)](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [Linux installation (from scratch)](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS installation (from scratch)](../../../../docs/docs_ch/get_start/mac_quickstart.md)
@@ -57,7 +57,7 @@ label_map = {0: 'negative', 1: 'positive'}
model = hub.Module(
name='ernie_v2_eng_base',
version='2.0.2',
version='2.0.3',
task='seq-cls',
load_checkpoint='/path/to/parameters',
label_map=label_map)
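
For orientation, here is a minimal inference sketch with a module loaded as above. The sample sentences and the `predict` arguments (`max_seq_len`, `batch_size`, `use_gpu`) are illustrative assumptions, not content from this diff:

```python
import paddlehub as hub

label_map = {0: 'negative', 1: 'positive'}

# Load the sequence-classification head; the checkpoint path is a placeholder.
model = hub.Module(
    name='ernie_v2_eng_base',
    version='2.0.3',
    task='seq-cls',
    load_checkpoint='/path/to/parameters',
    label_map=label_map)

# Each inner list holds one text sample (assumed input format for seq-cls).
data = [['this movie is awesome'], ['the plot makes no sense']]
results = model.predict(data, max_seq_len=128, batch_size=1, use_gpu=False)
for text, label in zip(data, results):
    print(text[0], '->', label)
```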
@@ -195,6 +195,10 @@ for idx, text in enumerate(data):
* 2.0.2
Add the text matching task `text-matching`
* 2.0.3
Adjust the backbone model name
```shell
$ hub install ernie_v2_eng_base==2.0.2
$ hub install ernie_v2_eng_base==2.0.3
```
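
Since the 2.0.2 entry above introduces the `text-matching` task, a brief usage sketch may help. The sentence pairs and the `predict` arguments below are assumptions for illustration, not content from this commit:

```python
import paddlehub as hub

# Load the text-matching head added in version 2.0.2
# (sentence pairs as input is an assumed format).
model = hub.Module(name='ernie_v2_eng_base', version='2.0.3', task='text-matching')

data = [['the weather is nice today', 'it is sunny outside'],
        ['he enjoys playing football', 'she is reading a novel']]
results = model.predict(data, max_seq_len=128, batch_size=1, use_gpu=False)
print(results)
```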
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,17 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import os
import math
import os
from typing import Dict
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlenlp.transformers.ernie.modeling import ErnieModel, ErnieForSequenceClassification, ErnieForTokenClassification
from paddlenlp.transformers.ernie.tokenizer import ErnieTokenizer
from paddlenlp.metrics import ChunkEvaluator
from paddlenlp.transformers.ernie.modeling import ErnieForSequenceClassification
from paddlenlp.transformers.ernie.modeling import ErnieForTokenClassification
from paddlenlp.transformers.ernie.modeling import ErnieModel
from paddlenlp.transformers.ernie.tokenizer import ErnieTokenizer
from paddlehub.module.module import moduleinfo
from paddlehub.module.nlp_module import TransformerModule
from paddlehub.utils.log import logger
@@ -29,7 +31,7 @@ from paddlehub.utils.log import logger
@moduleinfo(
name="ernie_v2_eng_base",
version="2.0.2",
version="2.0.3",
summary=
"Baidu's ERNIE 2.0, Enhanced Representation through kNowledge IntEgration, max_seq_len=512 when predtrained. The module is executed as paddle.dygraph.",
author="paddlepaddle",
@@ -64,22 +66,24 @@ class ErnieV2(nn.Layer):
"'sequence_classification' has been deprecated and will be removed in the future.", )
if task == 'seq-cls':
self.model = ErnieForSequenceClassification.from_pretrained(
pretrained_model_name_or_path='ernie-2.0-en', num_classes=self.num_classes, **kwargs)
pretrained_model_name_or_path='ernie-2.0-base-en', num_classes=self.num_classes, **kwargs)
self.criterion = paddle.nn.loss.CrossEntropyLoss()
self.metric = paddle.metric.Accuracy()
elif task == 'token-cls':
self.model = ErnieForTokenClassification.from_pretrained(
pretrained_model_name_or_path='ernie-2.0-en', num_classes=self.num_classes, **kwargs)
self.model = ErnieForTokenClassification.from_pretrained(pretrained_model_name_or_path='ernie-2.0-base-en',
num_classes=self.num_classes,
**kwargs)
self.criterion = paddle.nn.loss.CrossEntropyLoss()
self.metric = ChunkEvaluator(label_list=[self.label_map[i] for i in sorted(self.label_map.keys())], suffix=suffix)
self.metric = ChunkEvaluator(label_list=[self.label_map[i] for i in sorted(self.label_map.keys())],
suffix=suffix)
elif task == 'text-matching':
self.model = ErnieModel.from_pretrained(pretrained_model_name_or_path='ernie-2.0-en', **kwargs)
self.model = ErnieModel.from_pretrained(pretrained_model_name_or_path='ernie-2.0-base-en', **kwargs)
self.dropout = paddle.nn.Dropout(0.1)
self.classifier = paddle.nn.Linear(self.model.config['hidden_size'] * 3, 2)
self.criterion = paddle.nn.loss.CrossEntropyLoss()
self.metric = paddle.metric.Accuracy()
elif task is None:
self.model = ErnieModel.from_pretrained(pretrained_model_name_or_path='ernie-2.0-en', **kwargs)
self.model = ErnieModel.from_pretrained(pretrained_model_name_or_path='ernie-2.0-base-en', **kwargs)
else:
raise RuntimeError("Unknown task {}, task should be one in {}".format(task, self._tasks_supported))
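
The `hidden_size * 3` classifier width above suggests that the matching head concatenates both pooled sentence embeddings with their element-wise difference, a common Sentence-BERT-style design. The sketch below is an assumption about the forward pass, not code taken from this commit:

```python
import paddle

def matching_features(emb_a: paddle.Tensor, emb_b: paddle.Tensor) -> paddle.Tensor:
    # Concatenate both embeddings and their absolute difference, yielding a
    # feature of width hidden_size * 3 that matches the Linear layer above.
    return paddle.concat([emb_a, emb_b, paddle.abs(emb_a - emb_b)], axis=-1)
```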
@@ -171,4 +175,4 @@ class ErnieV2(nn.Layer):
"""
Gets the tokenizer that is customized for this module.
"""
return ErnieTokenizer.from_pretrained(pretrained_model_name_or_path='ernie-2.0-en', *args, **kwargs)
return ErnieTokenizer.from_pretrained(pretrained_model_name_or_path='ernie-2.0-base-en', *args, **kwargs)
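
A short sketch of how the tokenizer returned above might be used. The sample sentence is illustrative, and the output keys follow the usual PaddleNLP tokenizer convention:

```python
from paddlenlp.transformers.ernie.tokenizer import ErnieTokenizer

# Equivalent to what get_tokenizer() above returns (assumption based on the line above).
tokenizer = ErnieTokenizer.from_pretrained('ernie-2.0-base-en')
encoded = tokenizer('ERNIE 2.0 is an English pretrained model.')
print(encoded['input_ids'], encoded['token_type_ids'])
```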