From b68da8e62855cabbf09c852e2e5fb0f6dd1d1425 Mon Sep 17 00:00:00 2001
From: KP <109694228@qq.com>
Date: Wed, 6 Jan 2021 20:01:12 +0800
Subject: [PATCH] Add electra modules (#1173)
* Add electra modules
* Amend module versions of chinese-bert-wwm, chinese-bert-wwm-ext, rbt3 and rbtl3
* Update demo README.md
---
demo/sequence_labeling/README.md | 13 +-
demo/text_classification/README.md | 13 +-
.../language_model/chinese_bert_wwm/README.md | 6 +-
.../language_model/chinese_bert_wwm/module.py | 2 +-
.../chinese_bert_wwm_ext/README.md | 6 +-
.../chinese_bert_wwm_ext/module.py | 2 +-
.../chinese_electra_base/README.md | 150 +++++----
.../chinese_electra_base/model/electra.py | 190 -----------
.../model/transformer_encoder.py | 295 ------------------
.../chinese_electra_base/module.py | 147 ++++++---
.../chinese_electra_small/README.md | 150 +++++----
.../chinese_electra_small/model/electra.py | 200 ------------
.../model/transformer_encoder.py | 295 ------------------
.../chinese_electra_small/module.py | 147 ++++++---
.../language_model/electra_base/README.md | 153 +++++++++
.../model => electra_base}/__init__.py | 0
.../language_model/electra_base/module.py | 130 ++++++++
.../language_model/electra_large/README.md | 153 +++++++++
.../model => electra_large}/__init__.py | 0
.../language_model/electra_large/module.py | 130 ++++++++
.../language_model/electra_small/README.md | 153 +++++++++
.../language_model/electra_small/__init__.py | 0
.../language_model/electra_small/module.py | 130 ++++++++
modules/text/language_model/rbt3/README.md | 6 +-
modules/text/language_model/rbt3/module.py | 2 +-
modules/text/language_model/rbtl3/README.md | 6 +-
modules/text/language_model/rbtl3/module.py | 2 +-
27 files changed, 1275 insertions(+), 1206 deletions(-)
delete mode 100644 modules/text/language_model/chinese_electra_base/model/electra.py
delete mode 100644 modules/text/language_model/chinese_electra_base/model/transformer_encoder.py
delete mode 100644 modules/text/language_model/chinese_electra_small/model/electra.py
delete mode 100644 modules/text/language_model/chinese_electra_small/model/transformer_encoder.py
create mode 100644 modules/text/language_model/electra_base/README.md
rename modules/text/language_model/{chinese_electra_base/model => electra_base}/__init__.py (100%)
create mode 100644 modules/text/language_model/electra_base/module.py
create mode 100644 modules/text/language_model/electra_large/README.md
rename modules/text/language_model/{chinese_electra_small/model => electra_large}/__init__.py (100%)
create mode 100644 modules/text/language_model/electra_large/module.py
create mode 100644 modules/text/language_model/electra_small/README.md
create mode 100644 modules/text/language_model/electra_small/__init__.py
create mode 100644 modules/text/language_model/electra_small/module.py
diff --git a/demo/sequence_labeling/README.md b/demo/sequence_labeling/README.md
index 5c8be253..2b39458e 100644
--- a/demo/sequence_labeling/README.md
+++ b/demo/sequence_labeling/README.md
@@ -60,10 +60,10 @@ ERNIE, Chinese | `hub.Module(name='ernie')`
ERNIE tiny, Chinese | `hub.Module(name='ernie_tiny')`
ERNIE 2.0 Base, English | `hub.Module(name='ernie_v2_eng_base')`
ERNIE 2.0 Large, English | `hub.Module(name='ernie_v2_eng_large')`
-BERT-Base, Cased | `hub.Module(name='bert-base-cased')`
-BERT-Base, Uncased | `hub.Module(name='bert-base-uncased')`
-BERT-Large, Cased | `hub.Module(name='bert-large-cased')`
-BERT-Large, Uncased | `hub.Module(name='bert-large-uncased')`
+BERT-Base, English Cased | `hub.Module(name='bert-base-cased')`
+BERT-Base, English Uncased | `hub.Module(name='bert-base-uncased')`
+BERT-Large, English Cased | `hub.Module(name='bert-large-cased')`
+BERT-Large, English Uncased | `hub.Module(name='bert-large-uncased')`
BERT-Base, Multilingual Cased | `hub.Module(name='bert-base-multilingual-cased')`
BERT-Base, Multilingual Uncased | `hub.Module(name='bert-base-multilingual-uncased')`
BERT-Base, Chinese | `hub.Module(name='bert-base-chinese')`
@@ -73,6 +73,11 @@ RoBERTa-wwm-ext, Chinese | `hub.Module(name='roberta-wwm-ext')`
RoBERTa-wwm-ext-large, Chinese | `hub.Module(name='roberta-wwm-ext-large')`
RBT3, Chinese | `hub.Module(name='rbt3')`
RBTL3, Chinese | `hub.Module(name='rbtl3')`
+ELECTRA-Small, English | `hub.Module(name='electra-small')`
+ELECTRA-Base, English | `hub.Module(name='electra-base')`
+ELECTRA-Large, English | `hub.Module(name='electra-large')`
+ELECTRA-Base, Chinese | `hub.Module(name='chinese-electra-base')`
+ELECTRA-Small, Chinese | `hub.Module(name='chinese-electra-small')`
With the single line of code above, `model` is initialized as a model suited to sequence labeling: the ERNIE Tiny pretrained model followed by a fully connected (FC) network shared across all output tokens.
![](https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=224484727,3049769188&fm=15&gp=0.jpg)
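
For the newly added ELECTRA entries, the same one-liner applies. A minimal sketch for sequence labeling (the `label_map` below is an illustrative BIO tag set, not part of this demo):

```python
import paddlehub as hub

# Hypothetical BIO tag set, for illustration only.
label_map = {0: 'B-PER', 1: 'I-PER', 2: 'O'}

# Any name from the table above can be substituted here.
model = hub.Module(name='chinese-electra-base', task='token-cls', label_map=label_map)
```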
diff --git a/demo/text_classification/README.md b/demo/text_classification/README.md
index 70113ccd..4d0de6e4 100644
--- a/demo/text_classification/README.md
+++ b/demo/text_classification/README.md
@@ -49,10 +49,10 @@ ERNIE, Chinese | `hub.Module(name='ernie')`
ERNIE tiny, Chinese | `hub.Module(name='ernie_tiny')`
ERNIE 2.0 Base, English | `hub.Module(name='ernie_v2_eng_base')`
ERNIE 2.0 Large, English | `hub.Module(name='ernie_v2_eng_large')`
-BERT-Base, Cased | `hub.Module(name='bert-base-cased')`
-BERT-Base, Uncased | `hub.Module(name='bert-base-uncased')`
-BERT-Large, Cased | `hub.Module(name='bert-large-cased')`
-BERT-Large, Uncased | `hub.Module(name='bert-large-uncased')`
+BERT-Base, English Cased | `hub.Module(name='bert-base-cased')`
+BERT-Base, English Uncased | `hub.Module(name='bert-base-uncased')`
+BERT-Large, English Cased | `hub.Module(name='bert-large-cased')`
+BERT-Large, English Uncased | `hub.Module(name='bert-large-uncased')`
BERT-Base, Multilingual Cased | `hub.Module(name='bert-base-multilingual-cased')`
BERT-Base, Multilingual Uncased | `hub.Module(name='bert-base-multilingual-uncased')`
BERT-Base, Chinese | `hub.Module(name='bert-base-chinese')`
@@ -62,6 +62,11 @@ RoBERTa-wwm-ext, Chinese | `hub.Module(name='roberta-wwm-ext')`
RoBERTa-wwm-ext-large, Chinese | `hub.Module(name='roberta-wwm-ext-large')`
RBT3, Chinese | `hub.Module(name='rbt3')`
RBTL3, Chinese | `hub.Module(name='rbtl3')`
+ELECTRA-Small, English | `hub.Module(name='electra-small')`
+ELECTRA-Base, English | `hub.Module(name='electra-base')`
+ELECTRA-Large, English | `hub.Module(name='electra-large')`
+ELECTRA-Base, Chinese | `hub.Module(name='chinese-electra-base')`
+ELECTRA-Small, Chinese | `hub.Module(name='chinese-electra-small')`
With the single line of code above, `model` is initialized as a model suited to text classification: the ERNIE Tiny pretrained model followed by a fully connected (FC) network.
![](https://ai-studio-static-online.cdn.bcebos.com/f9e1bf9d56c6412d939960f2e3767c2f13b93eab30554d738b137ab2b98e328c)
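
For the ELECTRA rows added above, the one-liner is the same. A minimal seq-cls sketch (the sentiment `label_map` is a placeholder):

```python
import paddlehub as hub

# Placeholder binary sentiment labels.
label_map = {0: 'negative', 1: 'positive'}

# 'electra-small' is one of the English entries added in the table above.
model = hub.Module(name='electra-small', task='seq-cls', label_map=label_map)
```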
diff --git a/modules/text/language_model/chinese_bert_wwm/README.md b/modules/text/language_model/chinese_bert_wwm/README.md
index ca1e3aae..61eabad7 100644
--- a/modules/text/language_model/chinese_bert_wwm/README.md
+++ b/modules/text/language_model/chinese_bert_wwm/README.md
@@ -1,5 +1,5 @@
```shell
-$ hub install chinese-bert-wwm==2.0.1
+$ hub install chinese-bert-wwm==2.0.0
```
@@ -82,7 +82,7 @@ label_map = {0: 'negative', 1: 'positive'}
model = hub.Module(
name='chinese-bert-wwm',
- version='2.0.1',
+ version='2.0.0',
task='seq-cls',
load_checkpoint='/path/to/parameters',
label_map=label_map)
@@ -153,6 +153,6 @@ paddlehub >= 2.0.0
Initial release
-* 2.0.1
+* 2.0.0
Comprehensive upgrade to dynamic graph; the interface has changed. Task names adjusted; added the sequence labeling task `token-cls`
diff --git a/modules/text/language_model/chinese_bert_wwm/module.py b/modules/text/language_model/chinese_bert_wwm/module.py
index 3ee03088..b225bb4e 100644
--- a/modules/text/language_model/chinese_bert_wwm/module.py
+++ b/modules/text/language_model/chinese_bert_wwm/module.py
@@ -29,7 +29,7 @@ from paddlehub.utils.log import logger
@moduleinfo(
name="chinese-bert-wwm",
- version="2.0.1",
+ version="2.0.0",
summary=
"chinese-bert-wwm, 12-layer, 768-hidden, 12-heads, 110M parameters. The module is executed as paddle.dygraph.",
author="ymcui",
diff --git a/modules/text/language_model/chinese_bert_wwm_ext/README.md b/modules/text/language_model/chinese_bert_wwm_ext/README.md
index 45709dbe..7a287a30 100644
--- a/modules/text/language_model/chinese_bert_wwm_ext/README.md
+++ b/modules/text/language_model/chinese_bert_wwm_ext/README.md
@@ -1,5 +1,5 @@
```shell
-$ hub install chinese-bert-wwm-ext==2.0.1
+$ hub install chinese-bert-wwm-ext==2.0.0
```
@@ -82,7 +82,7 @@ label_map = {0: 'negative', 1: 'positive'}
model = hub.Module(
name='chinese-bert-wwm-ext',
- version='2.0.1',
+ version='2.0.0',
task='seq-cls',
load_checkpoint='/path/to/parameters',
label_map=label_map)
@@ -153,6 +153,6 @@ paddlehub >= 2.0.0
Initial release
-* 2.0.1
+* 2.0.0
Comprehensive upgrade to dynamic graph; the interface has changed. Task names adjusted; added the sequence labeling task `token-cls`
diff --git a/modules/text/language_model/chinese_bert_wwm_ext/module.py b/modules/text/language_model/chinese_bert_wwm_ext/module.py
index 6ff6803f..2a4e8256 100644
--- a/modules/text/language_model/chinese_bert_wwm_ext/module.py
+++ b/modules/text/language_model/chinese_bert_wwm_ext/module.py
@@ -29,7 +29,7 @@ from paddlehub.utils.log import logger
@moduleinfo(
name="chinese-bert-wwm-ext",
- version="2.0.1",
+ version="2.0.0",
summary=
"chinese-bert-wwm-ext, 12-layer, 768-hidden, 12-heads, 110M parameters. The module is executed as paddle.dygraph.",
author="ymcui",
diff --git a/modules/text/language_model/chinese_electra_base/README.md b/modules/text/language_model/chinese_electra_base/README.md
index 9fbf6b60..b3dfff0c 100644
--- a/modules/text/language_model/chinese_electra_base/README.md
+++ b/modules/text/language_model/chinese_electra_base/README.md
@@ -1,73 +1,73 @@
```shell
-$ hub install chinese-electra-base==1.0.0
+$ hub install chinese-electra-base==2.0.0
```
For more details, please refer to the [ELECTRA paper](https://openreview.net/pdf?id=r1xMH1BtvB)
## API
```python
-def context(
- trainable=True,
- max_seq_len=128
+def __init__(
+ task=None,
+ load_checkpoint=None,
+ label_map=None,
+ num_classes=2,
+ **kwargs,
)
```
-Used to get the Module's context: its inputs, outputs, and a copy of the pretrained Paddle Program
-
-**Parameters**
-> trainable: when set to True, the parameters in the Module are also updated during fine-tuning; otherwise they stay unchanged.
-> max_seq_len: the maximum sequence length of the BERT model. Sequences shorter than **max_seq_len** are padded up to it, and longer ones are truncated down to **max_seq_len**; valid values range from 0 to 512.
+Creates the Module object (dynamic-graph version).
-**Returns**
-> inputs: dict with the following fields:
-> >**input_ids** holds the word ids (in the BERT vocabulary) of each token of the tokenized input text; shape \[batch_size, max_seq_len\], dtype int64;
-> >**position_ids** holds each token's position within the text after tokenization; shape \[batch_size, max_seq_len\], dtype int64;
-> >**segment_ids** marks which text each token belongs to (text 1 or text 2); shape \[batch_size, max_seq_len\], dtype int64;
-> >**input_mask** marks whether each token is padding; shape \[batch_size, max_seq_len\], dtype int64;
->
-> outputs: dict holding the Module's output features, with the following fields:
-> >**pooled_output** holds sentence-level features, usable for tasks such as text classification; shape \[batch_size, 768\], dtype float32;
-> >**sequence_output** holds token-level features, usable for tasks such as sequence labeling; shape \[batch_size, seq_len, 768\], dtype float32;
->
-> program: the Program containing this Module's computation graph.
+**Parameters**
+* `task`: task name; either `seq-cls` (text classification; the former name `sequence_classification` is deprecated and will be removed in the future) or `token-cls` (sequence labeling).
+* `load_checkpoint`: path to model parameters saved with the PaddleHub fine-tune API.
+* `label_map`: label mapping table used at prediction time.
+* `num_classes`: number of classes for the classification task; may be omitted when `label_map` is given. Defaults to 2 classes.
+* `**kwargs`: additional keyword arguments specified by the user.
```python
-def get_embedding(
- texts,
- use_gpu=False,
- batch_size=1
+def predict(
+ data,
+ max_seq_len=128,
+ batch_size=1,
+ use_gpu=False
)
```
-Used to get sentence-level and token-level features of the input text
-
**Parameters**
-> texts: list of input texts in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\],…,\], where each element is one sample and each sample may contain a text\_a and a text\_b.
-> use_gpu: whether to use the GPU; defaults to False. GPU users are advised to enable use_gpu.
+* `data`: the data to predict, in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\],…,\], where each element is one sample and each sample may contain a text\_a and a text\_b. The number of texts per sample (1 or 2) must match what was used in training.
+* `max_seq_len`: the maximum text length the model processes
+* `batch_size`: the batch size used by the model
+* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable use_gpu.
**Returns**
-> results: list in the format \[\[sample\_a\_pooled\_feature, sample\_a\_seq\_feature\], \[sample\_b\_pooled\_feature, sample\_b\_seq\_feature\],…,\], where each element is the feature output of the corresponding sample; every sample has a sentence-level feature pooled\_feature and a token-level feature seq\_feature.
->
+* `results`: list; per task type the result format is as follows
+ * text classification: a list with each sentence's predicted label, in the format \[label\_1, label\_2, …,\]
+ * sequence labeling: a list with the predicted label of every token in every sentence, in the format \[\[token\_1, token\_2, …,\], \[token\_1, token\_2, …,\], …,\]
```python
-def get_params_layer()
+def get_embedding(
+ data,
+ use_gpu=False
+)
```
-Used to get parameter-layer information; combined with ULMFiTStrategy, it lets you set layer-wise learning rates and gradual unfreezing strictly by layer.
+Used to get sentence-level and token-level features of the input text
**Parameters**
-> None
+* `data`: list of input texts in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\],…,\], where each element is one sample and each sample may contain a text\_a and a text\_b.
+* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable use_gpu.
**Returns**
-> params_layer: dict whose keys are parameter names and whose values are the layer each parameter belongs to
+* `results`: list in the format \[\[sample\_a\_pooled\_feature, sample\_a\_seq\_feature\], \[sample\_b\_pooled\_feature, sample\_b\_seq\_feature\],…,\], where each element is the feature output of the corresponding sample; every sample has a sentence-level feature pooled\_feature and a token-level feature seq\_feature.
**Code example**
@@ -75,45 +75,83 @@ def get_params_layer()
```python
import paddlehub as hub
-# Load the chinese-electra-base pretrained model
-module = hub.Module(name="chinese-electra-base")
-inputs, outputs, program = module.context(trainable=True, max_seq_len=128)
+data = [
+ ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般'],
+ ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片'],
+ ['作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'],
+]
+label_map = {0: 'negative', 1: 'positive'}
+
+model = hub.Module(
+ name='chinese-electra-base',
+ version='2.0.0',
+ task='seq-cls',
+ load_checkpoint='/path/to/parameters',
+ label_map=label_map)
+results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
+for idx, text in enumerate(data):
+ print('Data: {} \t Label: {}'.format(text, results[idx]))
+```
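+
+As a complement to the `predict` example above, a local `get_embedding` call might look like the following sketch (the sample sentences are placeholders; `task` is left unset so the module returns raw features):
+
+```python
+import paddlehub as hub
+
+model = hub.Module(name='chinese-electra-base', version='2.0.0')
+data = [['今天是个好日子'], ['天气预报说今天要下雨']]
+embeddings = model.get_embedding(data, use_gpu=False)
+# Each element pairs the sentence-level pooled feature with the token-level features.
+pooled_feature, seq_feature = embeddings[0]
+```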
+
+For details, see the PaddleHub demos:
+- [Text classification](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/text_classification)
+- [Sequence labeling](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/sequence_labeling)
-# Must feed all the tensors that the chinese-electra-base module needs
-input_ids = inputs["input_ids"]
-position_ids = inputs["position_ids"]
-segment_ids = inputs["segment_ids"]
-input_mask = inputs["input_mask"]
+## Serving deployment
-# Use "pooled_output" for sentence-level output.
-pooled_output = outputs["pooled_output"]
+With PaddleHub Serving you can deploy an online service for obtaining pretrained embeddings.
-# Use "sequence_output" for token-level output.
-sequence_output = outputs["sequence_output"]
+### Step 1: Start PaddleHub Serving
-# Use "get_embedding" to get embedding result.
-embedding_result = module.get_embedding(texts=[["Sample1_text_a"],["Sample2_text_a","Sample2_text_b"]], use_gpu=True)
+Run the start command:
-# Use "get_params_layer" to get params layer and used to ULMFiTStrategy.
-params_layer = module.get_params_layer()
-strategy = hub.finetune.strategy.ULMFiTStrategy(frz_params_layer=params_layer, dis_params_layer=params_layer)
+```shell
+$ hub serving start -m chinese-electra-base
```
+This deploys an online API serving pretrained embeddings; the default port is 8866.
+
+**NOTE:** To predict on GPU, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it does not need to be set.
+
+### Step 2: Send prediction requests
+
+With the server configured, the few lines of code below send a prediction request and retrieve the result.
+
+```python
+import requests
+import json
+
+# Specify the texts for which to get embeddings: [[text_1], [text_2], ... ]
+text = [["今天是个好日子"], ["天气预报说今天要下雨"]]
+# Pass the texts via a keyword argument of the prediction method; here the key is "data"
+# For a local deployment, the equivalent call is module.get_embedding(data=text)
+data = {"data": text}
+# Send a POST request; the content-type should be JSON, and the IP in the url must be changed to that of the target machine
+url = "http://10.12.121.132:8866/predict/chinese-electra-base"
+# Set the POST request headers to application/json
+headers = {"Content-Type": "application/json"}
+
+r = requests.post(url=url, headers=headers, data=json.dumps(data))
+print(r.json())
+```
## View the code
https://github.com/ymcui/Chinese-ELECTRA
-
## Dependencies
-paddlepaddle >= 1.6.2
+paddlepaddle >= 2.0.0
-paddlehub >= 1.6.0
+paddlehub >= 2.0.0
## Update history
* 1.0.0
Initial release
+
+* 2.0.0
+
+ Comprehensive upgrade to dynamic graph; the interface has changed. Task names adjusted; added the sequence labeling task `token-cls`
diff --git a/modules/text/language_model/chinese_electra_base/model/electra.py b/modules/text/language_model/chinese_electra_base/model/electra.py
deleted file mode 100644
index a2b647d1..00000000
--- a/modules/text/language_model/chinese_electra_base/model/electra.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""ELECTRA model."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import six
-import json
-
-import paddle.fluid as fluid
-
-from chinese_electra_base.model.transformer_encoder import encoder, pre_process_layer
-
-
-class ElectraConfig(object):
- def __init__(self, config_path):
- self._config_dict = self._parse(config_path)
-
- def _parse(self, config_path):
- try:
- with open(config_path) as json_file:
- config_dict = json.load(json_file)
- except Exception:
- raise IOError("Error in parsing electra model config file '%s'" % config_path)
- else:
- return config_dict
-
- def __getitem__(self, key):
- return self._config_dict[key]
-
- def print_config(self):
- for arg, value in sorted(six.iteritems(self._config_dict)):
- print('%s: %s' % (arg, value))
- print('------------------------------------------------')
-
-
-class ElectraModel(object):
- def __init__(self, src_ids, position_ids, sentence_ids, input_mask, config, weight_sharing=True, use_fp16=False):
-
- self._emb_size = config['hidden_size']
- self._n_layer = config['num_hidden_layers']
- self._n_head = config['num_attention_heads']
- self._voc_size = config['vocab_size']
- self._max_position_seq_len = config['max_position_embeddings']
- self._sent_types = config['type_vocab_size']
- self._hidden_act = config['hidden_act']
- self._prepostprocess_dropout = config['hidden_dropout_prob']
- self._attention_dropout = config['attention_probs_dropout_prob']
- self._weight_sharing = weight_sharing
-
- self._word_emb_name = "word_embedding"
- self._pos_emb_name = "pos_embedding"
- self._sent_emb_name = "sent_embedding"
- self._dtype = "float16" if use_fp16 else "float32"
-
- # Initialize all weights by truncated normal initializer, and all biases
- # will be initialized by constant zero by default.
- self._param_initializer = fluid.initializer.TruncatedNormal(scale=config['initializer_range'])
-
- self._build_model(src_ids, position_ids, sentence_ids, input_mask)
-
- def _build_model(self, src_ids, position_ids, sentence_ids, input_mask):
- # padding id in vocabulary must be set to 0
- emb_out = fluid.layers.embedding(input=src_ids,
- size=[self._voc_size, self._emb_size],
- dtype=self._dtype,
- param_attr=fluid.ParamAttr(name=self._word_emb_name,
- initializer=self._param_initializer),
- is_sparse=False)
- position_emb_out = fluid.layers.embedding(input=position_ids,
- size=[self._max_position_seq_len, self._emb_size],
- dtype=self._dtype,
- param_attr=fluid.ParamAttr(name=self._pos_emb_name,
- initializer=self._param_initializer))
-
- sent_emb_out = fluid.layers.embedding(sentence_ids,
- size=[self._sent_types, self._emb_size],
- dtype=self._dtype,
- param_attr=fluid.ParamAttr(name=self._sent_emb_name,
- initializer=self._param_initializer))
-
- emb_out = emb_out + position_emb_out
- emb_out = emb_out + sent_emb_out
-
- emb_out = pre_process_layer(emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')
-
- if self._dtype == "float16":
- input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype)
-
- self_attn_mask = fluid.layers.matmul(x=input_mask, y=input_mask, transpose_y=True)
- self_attn_mask = fluid.layers.scale(x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
- n_head_self_attn_mask = fluid.layers.stack(x=[self_attn_mask] * self._n_head, axis=1)
- n_head_self_attn_mask.stop_gradient = True
-
- self._enc_out = encoder(enc_input=emb_out,
- attn_bias=n_head_self_attn_mask,
- n_layer=self._n_layer,
- n_head=self._n_head,
- d_key=self._emb_size // self._n_head,
- d_value=self._emb_size // self._n_head,
- d_model=self._emb_size,
- d_inner_hid=self._emb_size * 4,
- prepostprocess_dropout=self._prepostprocess_dropout,
- attention_dropout=self._attention_dropout,
- relu_dropout=0,
- hidden_act=self._hidden_act,
- preprocess_cmd="",
- postprocess_cmd="dan",
- param_initializer=self._param_initializer,
- name='encoder')
-
- def get_sequence_output(self):
- return self._enc_out
-
- def get_pooled_output(self):
- """Get the first feature of each sequence for classification"""
- next_sent_feat = fluid.layers.slice(input=self._enc_out, axes=[1], starts=[0], ends=[1])
- return next_sent_feat
-
- def get_pretraining_output(self, mask_label, mask_pos, labels):
- """Get the loss & accuracy for pretraining"""
-
- mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
-
- # extract the first token feature in each sentence
- next_sent_feat = self.get_pooled_output()
- reshaped_emb_out = fluid.layers.reshape(x=self._enc_out, shape=[-1, self._emb_size])
- # extract masked tokens' feature
- mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
-
- # transform: fc
- mask_trans_feat = fluid.layers.fc(input=mask_feat,
- size=self._emb_size,
- act=self._hidden_act,
- param_attr=fluid.ParamAttr(name='mask_lm_trans_fc.w_0',
- initializer=self._param_initializer),
- bias_attr=fluid.ParamAttr(name='mask_lm_trans_fc.b_0'))
- # transform: layer norm
- mask_trans_feat = pre_process_layer(mask_trans_feat, 'n', name='mask_lm_trans')
-
- mask_lm_out_bias_attr = fluid.ParamAttr(name="mask_lm_out_fc.b_0",
- initializer=fluid.initializer.Constant(value=0.0))
- if self._weight_sharing:
- fc_out = fluid.layers.matmul(x=mask_trans_feat,
- y=fluid.default_main_program().global_block().var(self._word_emb_name),
- transpose_y=True)
- fc_out += fluid.layers.create_parameter(shape=[self._voc_size],
- dtype=self._dtype,
- attr=mask_lm_out_bias_attr,
- is_bias=True)
-
- else:
- fc_out = fluid.layers.fc(input=mask_trans_feat,
- size=self._voc_size,
- param_attr=fluid.ParamAttr(name="mask_lm_out_fc.w_0",
- initializer=self._param_initializer),
- bias_attr=mask_lm_out_bias_attr)
-
- mask_lm_loss = fluid.layers.softmax_with_cross_entropy(logits=fc_out, label=mask_label)
- mean_mask_lm_loss = fluid.layers.mean(mask_lm_loss)
-
- next_sent_fc_out = fluid.layers.fc(input=next_sent_feat,
- size=2,
- param_attr=fluid.ParamAttr(name="next_sent_fc.w_0",
- initializer=self._param_initializer),
- bias_attr="next_sent_fc.b_0")
-
- next_sent_loss, next_sent_softmax = fluid.layers.softmax_with_cross_entropy(logits=next_sent_fc_out,
- label=labels,
- return_softmax=True)
-
- next_sent_acc = fluid.layers.accuracy(input=next_sent_softmax, label=labels)
-
- mean_next_sent_loss = fluid.layers.mean(next_sent_loss)
-
- loss = mean_next_sent_loss + mean_mask_lm_loss
- return next_sent_acc, mean_mask_lm_loss, loss
diff --git a/modules/text/language_model/chinese_electra_base/model/transformer_encoder.py b/modules/text/language_model/chinese_electra_base/model/transformer_encoder.py
deleted file mode 100644
index b15d8388..00000000
--- a/modules/text/language_model/chinese_electra_base/model/transformer_encoder.py
+++ /dev/null
@@ -1,295 +0,0 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Transformer encoder."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from functools import partial
-
-import paddle.fluid as fluid
-import paddle.fluid.layers as layers
-
-
-def multi_head_attention(queries,
- keys,
- values,
- attn_bias,
- d_key,
- d_value,
- d_model,
- n_head=1,
- dropout_rate=0.,
- cache=None,
- param_initializer=None,
- name='multi_head_att'):
- """
- Multi-Head Attention. Note that attn_bias is added to the logit before
- computing softmax activation to mask certain selected positions so that
- they will not be considered in attention weights.
- """
- keys = queries if keys is None else keys
- values = keys if values is None else values
-
- if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
- raise ValueError("Inputs: quries, keys and values should all be 3-D tensors.")
-
- def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
- """
- Add linear projection to queries, keys, and values.
- """
- q = layers.fc(input=queries,
- size=d_key * n_head,
- num_flatten_dims=2,
- param_attr=fluid.ParamAttr(name=name + '_query_fc.w_0', initializer=param_initializer),
- bias_attr=name + '_query_fc.b_0')
- k = layers.fc(input=keys,
- size=d_key * n_head,
- num_flatten_dims=2,
- param_attr=fluid.ParamAttr(name=name + '_key_fc.w_0', initializer=param_initializer),
- bias_attr=name + '_key_fc.b_0')
- v = layers.fc(input=values,
- size=d_value * n_head,
- num_flatten_dims=2,
- param_attr=fluid.ParamAttr(name=name + '_value_fc.w_0', initializer=param_initializer),
- bias_attr=name + '_value_fc.b_0')
- return q, k, v
-
- def __split_heads(x, n_head):
- """
- Reshape the last dimension of input tensor x so that it becomes two
- dimensions and then transpose. Specifically, input a tensor with shape
- [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
- with shape [bs, n_head, max_sequence_length, hidden_dim].
- """
- hidden_size = x.shape[-1]
- # The value 0 in shape attr means copying the corresponding dimension
- # size of the input as the output dimension size.
- reshaped = layers.reshape(x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
-
- # permute the dimensions into:
- # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
- return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
-
- def __combine_heads(x):
- """
- Transpose and then reshape the last two dimensions of input tensor x
- so that it becomes one dimension, which is reverse to __split_heads.
- """
- if len(x.shape) == 3: return x
- if len(x.shape) != 4:
- raise ValueError("Input(x) should be a 4-D Tensor.")
-
- trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
- # The value 0 in shape attr means copying the corresponding dimension
- # size of the input as the output dimension size.
- return layers.reshape(x=trans_x, shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]], inplace=True)
-
- def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
- """
- Scaled Dot-Product Attention
- """
- scaled_q = layers.scale(x=q, scale=d_key**-0.5)
- product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
- if attn_bias:
- product += attn_bias
- weights = layers.softmax(product)
- if dropout_rate:
- weights = layers.dropout(weights,
- dropout_prob=dropout_rate,
- dropout_implementation="upscale_in_train",
- is_test=False)
- out = layers.matmul(weights, v)
- return out
-
- q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
-
- if cache is not None: # use cache and concat time steps
- # Since the inplace reshape in __split_heads changes the shape of k and
- # v, which is the cache input for next time step, reshape the cache
- # input from the previous time step first.
- k = cache["k"] = layers.concat([layers.reshape(cache["k"], shape=[0, 0, d_model]), k], axis=1)
- v = cache["v"] = layers.concat([layers.reshape(cache["v"], shape=[0, 0, d_model]), v], axis=1)
-
- q = __split_heads(q, n_head)
- k = __split_heads(k, n_head)
- v = __split_heads(v, n_head)
-
- ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate)
-
- out = __combine_heads(ctx_multiheads)
-
- # Project back to the model size.
- proj_out = layers.fc(input=out,
- size=d_model,
- num_flatten_dims=2,
- param_attr=fluid.ParamAttr(name=name + '_output_fc.w_0', initializer=param_initializer),
- bias_attr=name + '_output_fc.b_0')
- return proj_out
-
-
-def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate, hidden_act, param_initializer=None, name='ffn'):
- """
- Position-wise Feed-Forward Networks.
- This module consists of two linear transformations with a ReLU activation
- in between, which is applied to each position separately and identically.
- """
- hidden = layers.fc(input=x,
- size=d_inner_hid,
- num_flatten_dims=2,
- act=hidden_act,
- param_attr=fluid.ParamAttr(name=name + '_fc_0.w_0', initializer=param_initializer),
- bias_attr=name + '_fc_0.b_0')
- if dropout_rate:
- hidden = layers.dropout(hidden,
- dropout_prob=dropout_rate,
- dropout_implementation="upscale_in_train",
- is_test=False)
- out = layers.fc(input=hidden,
- size=d_hid,
- num_flatten_dims=2,
- param_attr=fluid.ParamAttr(name=name + '_fc_1.w_0', initializer=param_initializer),
- bias_attr=name + '_fc_1.b_0')
- return out
-
-
-def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0., name=''):
- """
- Add residual connection, layer normalization and dropout to the out tensor
- optionally according to the value of process_cmd.
- This will be used before or after multi-head attention and position-wise
- feed-forward networks.
- """
- for cmd in process_cmd:
- if cmd == "a": # add residual connection
- out = out + prev_out if prev_out else out
- elif cmd == "n": # add layer normalization
- out_dtype = out.dtype
- if out_dtype == fluid.core.VarDesc.VarType.FP16:
- out = layers.cast(x=out, dtype="float32")
- out = layers.layer_norm(out,
- begin_norm_axis=len(out.shape) - 1,
- param_attr=fluid.ParamAttr(name=name + '_layer_norm_scale',
- initializer=fluid.initializer.Constant(1.)),
- bias_attr=fluid.ParamAttr(name=name + '_layer_norm_bias',
- initializer=fluid.initializer.Constant(0.)))
- if out_dtype == fluid.core.VarDesc.VarType.FP16:
- out = layers.cast(x=out, dtype="float16")
- elif cmd == "d": # add dropout
- if dropout_rate:
- out = layers.dropout(out,
- dropout_prob=dropout_rate,
- dropout_implementation="upscale_in_train",
- is_test=False)
- return out
-
-
-pre_process_layer = partial(pre_post_process_layer, None)
-post_process_layer = pre_post_process_layer
-
-
-def encoder_layer(enc_input,
- attn_bias,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- hidden_act,
- preprocess_cmd="n",
- postprocess_cmd="da",
- param_initializer=None,
- name=''):
- """The encoder layers that can be stacked to form a deep encoder.
- This module consits of a multi-head (self) attention followed by
- position-wise feed-forward networks and both the two components companied
- with the post_process_layer to add residual connection, layer normalization
- and droput.
- """
- attn_output = multi_head_attention(pre_process_layer(enc_input,
- preprocess_cmd,
- prepostprocess_dropout,
- name=name + '_pre_att'),
- None,
- None,
- attn_bias,
- d_key,
- d_value,
- d_model,
- n_head,
- attention_dropout,
- param_initializer=param_initializer,
- name=name + '_multi_head_att')
- attn_output = post_process_layer(enc_input,
- attn_output,
- postprocess_cmd,
- prepostprocess_dropout,
- name=name + '_post_att')
- ffd_output = positionwise_feed_forward(pre_process_layer(attn_output,
- preprocess_cmd,
- prepostprocess_dropout,
- name=name + '_pre_ffn'),
- d_inner_hid,
- d_model,
- relu_dropout,
- hidden_act,
- param_initializer=param_initializer,
- name=name + '_ffn')
- return post_process_layer(attn_output, ffd_output, postprocess_cmd, prepostprocess_dropout, name=name + '_post_ffn')
-
-
-def encoder(enc_input,
- attn_bias,
- n_layer,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- hidden_act,
- preprocess_cmd="n",
- postprocess_cmd="da",
- param_initializer=None,
- name=''):
- """
- The encoder is composed of a stack of identical layers returned by calling
- encoder_layer.
- """
- for i in range(n_layer):
- enc_output = encoder_layer(enc_input,
- attn_bias,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- hidden_act,
- preprocess_cmd,
- postprocess_cmd,
- param_initializer=param_initializer,
- name=name + '_layer_' + str(i))
- enc_input = enc_output
- enc_output = pre_process_layer(enc_output, preprocess_cmd, prepostprocess_dropout, name="post_encoder")
-
- return enc_output
diff --git a/modules/text/language_model/chinese_electra_base/module.py b/modules/text/language_model/chinese_electra_base/module.py
index 8a24ffd3..338c6605 100644
--- a/modules/text/language_model/chinese_electra_base/module.py
+++ b/modules/text/language_model/chinese_electra_base/module.py
@@ -1,7 +1,6 @@
-# coding:utf-8
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
-# Licensed under the Apache License, Version 2.0 (the "License"
+# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
@@ -12,62 +11,120 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
+from typing import Dict
import os
-from paddlehub import TransformerModule
-from paddlehub.module.module import moduleinfo
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
-from chinese_electra_base.model.electra import ElectraConfig, ElectraModel
+from paddlenlp.transformers.electra.modeling import ElectraForSequenceClassification, ElectraForTokenClassification, ElectraModel
+from paddlenlp.transformers.electra.tokenizer import ElectraTokenizer
+from paddlenlp.metrics import ChunkEvaluator
+from paddlehub.module.module import moduleinfo
+from paddlehub.module.nlp_module import TransformerModule
+from paddlehub.utils.log import logger
@moduleinfo(
name="chinese-electra-base",
- version="1.0.0",
- summary="chinese-electra-base, 12-layer, 768-hidden, 12-heads, 102M parameters",
+ version="2.0.0",
+ summary=
+ "chinese-electra-base, 12-layer, 768-hidden, 12-heads, 102M parameters. The module is executed as paddle.dygraph.",
author="ymcui",
author_email="ymcui@ir.hit.edu.cn",
type="nlp/semantic_model",
+ meta=TransformerModule,
)
-class Electra(TransformerModule):
- def _initialize(self):
- self.MAX_SEQ_LEN = 512
- self.params_path = os.path.join(self.directory, "assets", "params")
- self.vocab_path = os.path.join(self.directory, "assets", "vocab.txt")
+class Electra(nn.Layer):
+ """
+ Electra model
+ """
- electra_config_path = os.path.join(self.directory, "assets", "config.json")
- self.electra_config = ElectraConfig(electra_config_path)
+ def __init__(
+ self,
+ task: str = None,
+ load_checkpoint: str = None,
+ label_map: Dict = None,
+ num_classes: int = 2,
+ **kwargs,
+ ):
+ super(Electra, self).__init__()
+ if label_map:
+ self.label_map = label_map
+ self.num_classes = len(label_map)
+ else:
+ self.num_classes = num_classes
- def net(self, input_ids, position_ids, segment_ids, input_mask):
- """
- create neural network.
+ if task == 'sequence_classification':
+ task = 'seq-cls'
+ logger.warning(
+ "current task name 'sequence_classification' was renamed to 'seq-cls', "
+ "'sequence_classification' has been deprecated and will be removed in the future.",
+ )
+ if task == 'seq-cls':
+ self.model = ElectraForSequenceClassification.from_pretrained(
+ pretrained_model_name_or_path='chinese-electra-base',
+ num_classes=self.num_classes,
+ **kwargs
+ )
+ self.criterion = paddle.nn.loss.CrossEntropyLoss()
+ self.metric = paddle.metric.Accuracy()
+ elif task == 'token-cls':
+ self.model = ElectraForTokenClassification.from_pretrained(
+ pretrained_model_name_or_path='chinese-electra-base',
+ num_classes=self.num_classes,
+ **kwargs
+ )
+ self.criterion = paddle.nn.loss.CrossEntropyLoss()
+ self.metric = ChunkEvaluator(
+ label_list=[self.label_map[i] for i in sorted(self.label_map.keys())]
+ )
+ elif task is None:
+ self.model = ElectraModel.from_pretrained(pretrained_model_name_or_path='chinese-electra-base', **kwargs)
+ else:
+ raise RuntimeError("Unknown task {}, task should be one in {}".format(
+                task, ['seq-cls', 'token-cls']))
- Args:
- input_ids (tensor): the word ids.
- position_ids (tensor): the position ids.
- segment_ids (tensor): the segment ids.
- input_mask (tensor): the padding mask.
+ self.task = task
- Returns:
- pooled_output (tensor): sentence-level output for classification task.
- sequence_output (tensor): token-level output for sequence task.
- """
- electra = ElectraModel(src_ids=input_ids,
- position_ids=position_ids,
- sentence_ids=segment_ids,
- input_mask=input_mask,
- config=self.electra_config,
- use_fp16=False)
- pooled_output = electra.get_pooled_output()
- sequence_output = electra.get_sequence_output()
- return pooled_output, sequence_output
+ if load_checkpoint is not None and os.path.isfile(load_checkpoint):
+ state_dict = paddle.load(load_checkpoint)
+ self.set_state_dict(state_dict)
+ logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
+ def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, seq_lengths=None, labels=None):
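+        # Run the shared ELECTRA backbone; the task-specific branches below turn its output into probabilities or features.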
+ result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
+ if self.task == 'seq-cls':
+ logits = result
+ probs = F.softmax(logits, axis=1)
+ if labels is not None:
+ loss = self.criterion(logits, labels)
+ correct = self.metric.compute(probs, labels)
+ acc = self.metric.update(correct)
+ return probs, loss, {'acc': acc}
+ return probs
+ elif self.task == 'token-cls':
+ logits = result
+ token_level_probs = F.softmax(logits, axis=-1)
+ preds = token_level_probs.argmax(axis=-1)
+ if labels is not None:
+ loss = self.criterion(logits, labels.unsqueeze(-1))
+ num_infer_chunks, num_label_chunks, num_correct_chunks = \
+ self.metric.compute(None, seq_lengths, preds, labels)
+ self.metric.update(
+ num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
+ _, _, f1_score = map(float, self.metric.accumulate())
+ return token_level_probs, loss, {'f1_score': f1_score}
+ return token_level_probs
+ else:
+ sequence_output, pooled_output = result
+ return sequence_output, pooled_output
-if __name__ == '__main__':
- test_module = Electra()
+ @staticmethod
+ def get_tokenizer(*args, **kwargs):
+ """
+ Gets the tokenizer that is customized for this module.
+ """
+ return ElectraTokenizer.from_pretrained(
+ pretrained_model_name_or_path='chinese-electra-base', *args, **kwargs)
diff --git a/modules/text/language_model/chinese_electra_small/README.md b/modules/text/language_model/chinese_electra_small/README.md
index 34580f1d..4a5f0e99 100644
--- a/modules/text/language_model/chinese_electra_small/README.md
+++ b/modules/text/language_model/chinese_electra_small/README.md
@@ -1,73 +1,73 @@
```shell
-$ hub install chinese-electra-small==1.0.0
+$ hub install chinese-electra-small==2.0.0
```
For more details, please refer to the [ELECTRA paper](https://openreview.net/pdf?id=r1xMH1BtvB)
## API
```python
-def context(
- trainable=True,
- max_seq_len=128
+def __init__(
+ task=None,
+ load_checkpoint=None,
+ label_map=None,
+ num_classes=2,
+ **kwargs,
)
```
-Used to get the Module's context: its inputs, outputs, and a copy of the pretrained Paddle Program
-
-**Parameters**
-> trainable: when set to True, the parameters in the Module are also updated during fine-tuning; otherwise they stay unchanged.
-> max_seq_len: the maximum sequence length of the BERT model. Sequences shorter than **max_seq_len** are padded up to it, and longer ones are truncated down to **max_seq_len**; valid values range from 0 to 512.
+Creates the Module object (dynamic-graph version).
-**Returns**
-> inputs: dict with the following fields:
-> >**input_ids** holds the word ids (in the BERT vocabulary) of each token of the tokenized input text; shape \[batch_size, max_seq_len\], dtype int64;
-> >**position_ids** holds each token's position within the text after tokenization; shape \[batch_size, max_seq_len\], dtype int64;
-> >**segment_ids** marks which text each token belongs to (text 1 or text 2); shape \[batch_size, max_seq_len\], dtype int64;
-> >**input_mask** marks whether each token is padding; shape \[batch_size, max_seq_len\], dtype int64;
->
-> outputs: dict holding the Module's output features, with the following fields:
-> >**pooled_output** holds sentence-level features, usable for tasks such as text classification; shape \[batch_size, 768\], dtype float32;
-> >**sequence_output** holds token-level features, usable for tasks such as sequence labeling; shape \[batch_size, seq_len, 768\], dtype float32;
->
-> program: the Program containing this Module's computation graph.
+**Parameters**
+* `task`: task name; either `seq-cls` (text classification; the former name `sequence_classification` is deprecated and will be removed in the future) or `token-cls` (sequence labeling).
+* `load_checkpoint`: path to model parameters saved with the PaddleHub fine-tune API.
+* `label_map`: label mapping table used at prediction time.
+* `num_classes`: number of classes for the classification task; may be omitted when `label_map` is given. Defaults to 2 classes.
+* `**kwargs`: additional keyword arguments specified by the user.
```python
-def get_embedding(
- texts,
- use_gpu=False,
- batch_size=1
+def predict(
+ data,
+ max_seq_len=128,
+ batch_size=1,
+ use_gpu=False
)
```
-Used to get sentence-level and token-level features of the input text
-
**Parameters**
-> texts: list of input texts in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\],…,\], where each element is one sample and each sample may contain a text\_a and a text\_b.
-> use_gpu: whether to use the GPU; defaults to False. GPU users are advised to enable use_gpu.
+* `data`: the data to predict, in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\],…,\], where each element is one sample and each sample may contain a text\_a and a text\_b. The number of texts per sample (1 or 2) must match what was used in training.
+* `max_seq_len`: the maximum text length the model processes
+* `batch_size`: the batch size used by the model
+* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable use_gpu.
**Returns**
-> results: list in the format \[\[sample\_a\_pooled\_feature, sample\_a\_seq\_feature\], \[sample\_b\_pooled\_feature, sample\_b\_seq\_feature\],…,\], where each element is the feature output of the corresponding sample; every sample has a sentence-level feature pooled\_feature and a token-level feature seq\_feature.
->
+* `results`: list; per task type the result format is as follows
+ * text classification: a list with each sentence's predicted label, in the format \[label\_1, label\_2, …,\]
+ * sequence labeling: a list with the predicted label of every token in every sentence, in the format \[\[token\_1, token\_2, …,\], \[token\_1, token\_2, …,\], …,\]
```python
-def get_params_layer()
+def get_embedding(
+ data,
+ use_gpu=False
+)
```
-Used to get parameter-layer information; combined with ULMFiTStrategy, it lets you set layer-wise learning rates and gradual unfreezing strictly by layer.
+Used to get sentence-level and token-level features of the input text
**Parameters**
-> None
+* `data`: list of input texts in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\],…,\], where each element is one sample and each sample may contain a text\_a and a text\_b.
+* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable use_gpu.
**Returns**
-> params_layer: dict whose keys are parameter names and whose values are the layer each parameter belongs to
+* `results`: list in the format \[\[sample\_a\_pooled\_feature, sample\_a\_seq\_feature\], \[sample\_b\_pooled\_feature, sample\_b\_seq\_feature\],…,\], where each element is the feature output of the corresponding sample; every sample has a sentence-level feature pooled\_feature and a token-level feature seq\_feature.
**Code example**
@@ -75,45 +75,83 @@ def get_params_layer()
```python
import paddlehub as hub
-# Load the chinese-electra-small pretrained model
-module = hub.Module(name="chinese-electra-small")
-inputs, outputs, program = module.context(trainable=True, max_seq_len=128)
+data = [
+ ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般'],
+ ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片'],
+ ['作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'],
+]
+label_map = {0: 'negative', 1: 'positive'}
+
+model = hub.Module(
+ name='chinese-electra-small',
+ version='2.0.0',
+ task='seq-cls',
+ load_checkpoint='/path/to/parameters',
+ label_map=label_map)
+results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
+for idx, text in enumerate(data):
+ print('Data: {} \t Label: {}'.format(text, results[idx]))
+```
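+
+For the sequence labeling path documented above, a hedged sketch of `predict` with `task='token-cls'` (the tag set and checkpoint path are placeholders):
+
+```python
+import paddlehub as hub
+
+label_map = {0: 'B-PER', 1: 'I-PER', 2: 'O'}  # hypothetical BIO tags
+model = hub.Module(
+    name='chinese-electra-small',
+    version='2.0.0',
+    task='token-cls',
+    load_checkpoint='/path/to/parameters',
+    label_map=label_map)
+results = model.predict([['今天是个好日子']], max_seq_len=50, batch_size=1, use_gpu=False)
+```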
+
+For details, see the PaddleHub demos:
+- [Text classification](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/text_classification)
+- [Sequence labeling](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/sequence_labeling)
-# Must feed all the tensors that the chinese-electra-small module needs
-input_ids = inputs["input_ids"]
-position_ids = inputs["position_ids"]
-segment_ids = inputs["segment_ids"]
-input_mask = inputs["input_mask"]
+## Serving deployment
-# Use "pooled_output" for sentence-level output.
-pooled_output = outputs["pooled_output"]
+With PaddleHub Serving you can deploy an online service for obtaining pretrained embeddings.
-# Use "sequence_output" for token-level output.
-sequence_output = outputs["sequence_output"]
+### Step 1: Start PaddleHub Serving
-# Use "get_embedding" to get embedding result.
-embedding_result = module.get_embedding(texts=[["Sample1_text_a"],["Sample2_text_a","Sample2_text_b"]], use_gpu=True)
+Run the start command:
-# Use "get_params_layer" to get params layer and used to ULMFiTStrategy.
-params_layer = module.get_params_layer()
-strategy = hub.finetune.strategy.ULMFiTStrategy(frz_params_layer=params_layer, dis_params_layer=params_layer)
+```shell
+$ hub serving start -m chinese-electra-small
```
+This deploys an online API serving pretrained embeddings; the default port is 8866.
+
+**NOTE:** To predict on GPU, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it does not need to be set.
+
+### Step 2: Send prediction requests
+
+With the server configured, the few lines of code below send a prediction request and retrieve the result.
+
+```python
+import requests
+import json
+
+# Specify the texts for which to get embeddings: [[text_1], [text_2], ... ]
+text = [["今天是个好日子"], ["天气预报说今天要下雨"]]
+# Pass the texts via a keyword argument of the prediction method; here the key is "data"
+# For a local deployment, the equivalent call is module.get_embedding(data=text)
+data = {"data": text}
+# Send a POST request; the content-type should be JSON, and the IP in the url must be changed to that of the target machine
+url = "http://10.12.121.132:8866/predict/chinese-electra-small"
+# Set the POST request headers to application/json
+headers = {"Content-Type": "application/json"}
+
+r = requests.post(url=url, headers=headers, data=json.dumps(data))
+print(r.json())
+```
## View the code
https://github.com/ymcui/Chinese-ELECTRA
-
## Dependencies
-paddlepaddle >= 1.6.2
+paddlepaddle >= 2.0.0
-paddlehub >= 1.6.0
+paddlehub >= 2.0.0
## Update history
* 1.0.0
Initial release
+
+* 2.0.0
+
+ Comprehensive upgrade to dynamic graph; the interface has changed. Task names adjusted; added the sequence labeling task `token-cls`
diff --git a/modules/text/language_model/chinese_electra_small/model/electra.py b/modules/text/language_model/chinese_electra_small/model/electra.py
deleted file mode 100644
index e1ec68f8..00000000
--- a/modules/text/language_model/chinese_electra_small/model/electra.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""ELECTRA model."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import six
-import json
-
-import paddle.fluid as fluid
-
-from chinese_electra_small.model.transformer_encoder import encoder, pre_process_layer
-
-
-class ElectraConfig(object):
- def __init__(self, config_path):
- self._config_dict = self._parse(config_path)
-
- def _parse(self, config_path):
- try:
- with open(config_path) as json_file:
- config_dict = json.load(json_file)
- except Exception:
- raise IOError("Error in parsing electra model config file '%s'" % config_path)
- else:
- return config_dict
-
- def __getitem__(self, key):
- return self._config_dict[key]
-
- def print_config(self):
- for arg, value in sorted(six.iteritems(self._config_dict)):
- print('%s: %s' % (arg, value))
- print('------------------------------------------------')
-
-
-class ElectraModel(object):
- def __init__(self, src_ids, position_ids, sentence_ids, input_mask, config, weight_sharing=True, use_fp16=False):
-
- self._emb_size = 128
- self._hidden_size = config['hidden_size']
- self._n_layer = config['num_hidden_layers']
- self._n_head = config['num_attention_heads']
- self._voc_size = config['vocab_size']
- self._max_position_seq_len = config['max_position_embeddings']
- self._sent_types = config['type_vocab_size']
- self._hidden_act = config['hidden_act']
- self._prepostprocess_dropout = config['hidden_dropout_prob']
- self._attention_dropout = config['attention_probs_dropout_prob']
- self._weight_sharing = weight_sharing
-
- self._word_emb_name = "word_embedding"
- self._pos_emb_name = "pos_embedding"
- self._sent_emb_name = "sent_embedding"
- self._dtype = "float16" if use_fp16 else "float32"
-
- # Initialize all weights by truncated normal initializer, and all biases
- # will be initialized by constant zero by default.
- self._param_initializer = fluid.initializer.TruncatedNormal(scale=config['initializer_range'])
-
- self._build_model(src_ids, position_ids, sentence_ids, input_mask)
-
- def _build_model(self, src_ids, position_ids, sentence_ids, input_mask):
- # padding id in vocabulary must be set to 0
- emb_out = fluid.layers.embedding(input=src_ids,
- size=[self._voc_size, self._emb_size],
- dtype=self._dtype,
- param_attr=fluid.ParamAttr(name=self._word_emb_name,
- initializer=self._param_initializer),
- is_sparse=False)
- position_emb_out = fluid.layers.embedding(input=position_ids,
- size=[self._max_position_seq_len, self._emb_size],
- dtype=self._dtype,
- param_attr=fluid.ParamAttr(name=self._pos_emb_name,
- initializer=self._param_initializer))
-
- sent_emb_out = fluid.layers.embedding(sentence_ids,
- size=[self._sent_types, self._emb_size],
- dtype=self._dtype,
- param_attr=fluid.ParamAttr(name=self._sent_emb_name,
- initializer=self._param_initializer))
-
- emb_out = emb_out + position_emb_out
- emb_out = emb_out + sent_emb_out
-
- emb_out = pre_process_layer(emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')
-
- if self._emb_size != self._hidden_size:
- emb_out = fluid.layers.fc(input=emb_out,
- size=self._hidden_size,
- act=None,
- param_attr=fluid.ParamAttr(name="embeddings_project.w_0",
- initializer=self._param_initializer),
- num_flatten_dims=2,
- bias_attr="embeddings_project.b_0")
-
- if self._dtype == "float16":
- input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype)
-
- self_attn_mask = fluid.layers.matmul(x=input_mask, y=input_mask, transpose_y=True)
- self_attn_mask = fluid.layers.scale(x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
- n_head_self_attn_mask = fluid.layers.stack(x=[self_attn_mask] * self._n_head, axis=1)
- n_head_self_attn_mask.stop_gradient = True
-
- self._enc_out = encoder(enc_input=emb_out,
- attn_bias=n_head_self_attn_mask,
- n_layer=self._n_layer,
- n_head=self._n_head,
- d_key=self._hidden_size // self._n_head,
- d_value=self._hidden_size // self._n_head,
- d_model=self._hidden_size,
- d_inner_hid=self._hidden_size * 4,
- prepostprocess_dropout=self._prepostprocess_dropout,
- attention_dropout=self._attention_dropout,
- relu_dropout=0,
- hidden_act=self._hidden_act,
- preprocess_cmd="",
- postprocess_cmd="dan",
- param_initializer=self._param_initializer,
- name='encoder')
-
- def get_sequence_output(self):
- return self._enc_out
-
- def get_pooled_output(self):
- """Get the first feature of each sequence for classification"""
- next_sent_feat = fluid.layers.slice(input=self._enc_out, axes=[1], starts=[0], ends=[1])
- return next_sent_feat
-
- def get_pretraining_output(self, mask_label, mask_pos, labels):
- """Get the loss & accuracy for pretraining"""
-
- mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
-
- # extract the first token feature in each sentence
- next_sent_feat = self.get_pooled_output()
- reshaped_emb_out = fluid.layers.reshape(x=self._enc_out, shape=[-1, self._hidden_size])
- # extract masked tokens' feature
- mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
-
- # transform: fc
- mask_trans_feat = fluid.layers.fc(input=mask_feat,
- size=self._hidden_size,
- act=self._hidden_act,
- param_attr=fluid.ParamAttr(name='mask_lm_trans_fc.w_0',
- initializer=self._param_initializer),
- bias_attr=fluid.ParamAttr(name='mask_lm_trans_fc.b_0'))
- # transform: layer norm
- mask_trans_feat = pre_process_layer(mask_trans_feat, 'n', name='mask_lm_trans')
-
- mask_lm_out_bias_attr = fluid.ParamAttr(name="mask_lm_out_fc.b_0",
- initializer=fluid.initializer.Constant(value=0.0))
- if self._weight_sharing:
- fc_out = fluid.layers.matmul(x=mask_trans_feat,
- y=fluid.default_main_program().global_block().var(self._word_emb_name),
- transpose_y=True)
- fc_out += fluid.layers.create_parameter(shape=[self._voc_size],
- dtype=self._dtype,
- attr=mask_lm_out_bias_attr,
- is_bias=True)
-
- else:
- fc_out = fluid.layers.fc(input=mask_trans_feat,
- size=self._voc_size,
- param_attr=fluid.ParamAttr(name="mask_lm_out_fc.w_0",
- initializer=self._param_initializer),
- bias_attr=mask_lm_out_bias_attr)
-
- mask_lm_loss = fluid.layers.softmax_with_cross_entropy(logits=fc_out, label=mask_label)
- mean_mask_lm_loss = fluid.layers.mean(mask_lm_loss)
-
- next_sent_fc_out = fluid.layers.fc(input=next_sent_feat,
- size=2,
- param_attr=fluid.ParamAttr(name="next_sent_fc.w_0",
- initializer=self._param_initializer),
- bias_attr="next_sent_fc.b_0")
-
- next_sent_loss, next_sent_softmax = fluid.layers.softmax_with_cross_entropy(logits=next_sent_fc_out,
- label=labels,
- return_softmax=True)
-
- next_sent_acc = fluid.layers.accuracy(input=next_sent_softmax, label=labels)
-
- mean_next_sent_loss = fluid.layers.mean(next_sent_loss)
-
- loss = mean_next_sent_loss + mean_mask_lm_loss
- return next_sent_acc, mean_mask_lm_loss, loss
diff --git a/modules/text/language_model/chinese_electra_small/model/transformer_encoder.py b/modules/text/language_model/chinese_electra_small/model/transformer_encoder.py
deleted file mode 100644
index b15d8388..00000000
--- a/modules/text/language_model/chinese_electra_small/model/transformer_encoder.py
+++ /dev/null
@@ -1,295 +0,0 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Transformer encoder."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from functools import partial
-
-import paddle.fluid as fluid
-import paddle.fluid.layers as layers
-
-
-def multi_head_attention(queries,
- keys,
- values,
- attn_bias,
- d_key,
- d_value,
- d_model,
- n_head=1,
- dropout_rate=0.,
- cache=None,
- param_initializer=None,
- name='multi_head_att'):
- """
- Multi-Head Attention. Note that attn_bias is added to the logit before
- computing softmax activation to mask certain selected positions so that
- they will not be considered in attention weights.
- """
- keys = queries if keys is None else keys
- values = keys if values is None else values
-
- if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
- raise ValueError("Inputs: quries, keys and values should all be 3-D tensors.")
-
- def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
- """
- Add linear projection to queries, keys, and values.
- """
- q = layers.fc(input=queries,
- size=d_key * n_head,
- num_flatten_dims=2,
- param_attr=fluid.ParamAttr(name=name + '_query_fc.w_0', initializer=param_initializer),
- bias_attr=name + '_query_fc.b_0')
- k = layers.fc(input=keys,
- size=d_key * n_head,
- num_flatten_dims=2,
- param_attr=fluid.ParamAttr(name=name + '_key_fc.w_0', initializer=param_initializer),
- bias_attr=name + '_key_fc.b_0')
- v = layers.fc(input=values,
- size=d_value * n_head,
- num_flatten_dims=2,
- param_attr=fluid.ParamAttr(name=name + '_value_fc.w_0', initializer=param_initializer),
- bias_attr=name + '_value_fc.b_0')
- return q, k, v
-
- def __split_heads(x, n_head):
- """
- Reshape the last dimension of input tensor x so that it becomes two
- dimensions and then transpose. Specifically, input a tensor with shape
- [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
- with shape [bs, n_head, max_sequence_length, hidden_dim].
- """
- hidden_size = x.shape[-1]
- # The value 0 in shape attr means copying the corresponding dimension
- # size of the input as the output dimension size.
- reshaped = layers.reshape(x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
-
- # permute the dimensions into:
- # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
- return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
-
- def __combine_heads(x):
- """
- Transpose and then reshape the last two dimensions of input tensor x
- so that it becomes one dimension, which is reverse to __split_heads.
- """
- if len(x.shape) == 3: return x
- if len(x.shape) != 4:
- raise ValueError("Input(x) should be a 4-D Tensor.")
-
- trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
- # The value 0 in shape attr means copying the corresponding dimension
- # size of the input as the output dimension size.
- return layers.reshape(x=trans_x, shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]], inplace=True)
-
- def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
- """
- Scaled Dot-Product Attention
- """
- scaled_q = layers.scale(x=q, scale=d_key**-0.5)
- product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
- if attn_bias:
- product += attn_bias
- weights = layers.softmax(product)
- if dropout_rate:
- weights = layers.dropout(weights,
- dropout_prob=dropout_rate,
- dropout_implementation="upscale_in_train",
- is_test=False)
- out = layers.matmul(weights, v)
- return out
-
- q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
-
- if cache is not None: # use cache and concat time steps
- # Since the inplace reshape in __split_heads changes the shape of k and
- # v, which is the cache input for next time step, reshape the cache
- # input from the previous time step first.
- k = cache["k"] = layers.concat([layers.reshape(cache["k"], shape=[0, 0, d_model]), k], axis=1)
- v = cache["v"] = layers.concat([layers.reshape(cache["v"], shape=[0, 0, d_model]), v], axis=1)
-
- q = __split_heads(q, n_head)
- k = __split_heads(k, n_head)
- v = __split_heads(v, n_head)
-
- ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate)
-
- out = __combine_heads(ctx_multiheads)
-
- # Project back to the model size.
- proj_out = layers.fc(input=out,
- size=d_model,
- num_flatten_dims=2,
- param_attr=fluid.ParamAttr(name=name + '_output_fc.w_0', initializer=param_initializer),
- bias_attr=name + '_output_fc.b_0')
- return proj_out
-
-
-def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate, hidden_act, param_initializer=None, name='ffn'):
- """
- Position-wise Feed-Forward Networks.
- This module consists of two linear transformations with a ReLU activation
- in between, which is applied to each position separately and identically.
- """
- hidden = layers.fc(input=x,
- size=d_inner_hid,
- num_flatten_dims=2,
- act=hidden_act,
- param_attr=fluid.ParamAttr(name=name + '_fc_0.w_0', initializer=param_initializer),
- bias_attr=name + '_fc_0.b_0')
- if dropout_rate:
- hidden = layers.dropout(hidden,
- dropout_prob=dropout_rate,
- dropout_implementation="upscale_in_train",
- is_test=False)
- out = layers.fc(input=hidden,
- size=d_hid,
- num_flatten_dims=2,
- param_attr=fluid.ParamAttr(name=name + '_fc_1.w_0', initializer=param_initializer),
- bias_attr=name + '_fc_1.b_0')
- return out
-
-
-def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0., name=''):
- """
-    Add residual connection, layer normalization and dropout to the out tensor
- optionally according to the value of process_cmd.
- This will be used before or after multi-head attention and position-wise
- feed-forward networks.
- """
- for cmd in process_cmd:
- if cmd == "a": # add residual connection
- out = out + prev_out if prev_out else out
- elif cmd == "n": # add layer normalization
- out_dtype = out.dtype
- if out_dtype == fluid.core.VarDesc.VarType.FP16:
- out = layers.cast(x=out, dtype="float32")
- out = layers.layer_norm(out,
- begin_norm_axis=len(out.shape) - 1,
- param_attr=fluid.ParamAttr(name=name + '_layer_norm_scale',
- initializer=fluid.initializer.Constant(1.)),
- bias_attr=fluid.ParamAttr(name=name + '_layer_norm_bias',
- initializer=fluid.initializer.Constant(0.)))
- if out_dtype == fluid.core.VarDesc.VarType.FP16:
- out = layers.cast(x=out, dtype="float16")
- elif cmd == "d": # add dropout
- if dropout_rate:
- out = layers.dropout(out,
- dropout_prob=dropout_rate,
- dropout_implementation="upscale_in_train",
- is_test=False)
- return out
-
-
-pre_process_layer = partial(pre_post_process_layer, None)
-post_process_layer = pre_post_process_layer
-
-
-def encoder_layer(enc_input,
- attn_bias,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- hidden_act,
- preprocess_cmd="n",
- postprocess_cmd="da",
- param_initializer=None,
- name=''):
- """The encoder layers that can be stacked to form a deep encoder.
-    This module consists of a multi-head (self) attention sublayer followed by
-    a position-wise feed-forward network, with both components wrapped by
-    post_process_layer to add residual connections, layer normalization
-    and dropout.
- """
- attn_output = multi_head_attention(pre_process_layer(enc_input,
- preprocess_cmd,
- prepostprocess_dropout,
- name=name + '_pre_att'),
- None,
- None,
- attn_bias,
- d_key,
- d_value,
- d_model,
- n_head,
- attention_dropout,
- param_initializer=param_initializer,
- name=name + '_multi_head_att')
- attn_output = post_process_layer(enc_input,
- attn_output,
- postprocess_cmd,
- prepostprocess_dropout,
- name=name + '_post_att')
- ffd_output = positionwise_feed_forward(pre_process_layer(attn_output,
- preprocess_cmd,
- prepostprocess_dropout,
- name=name + '_pre_ffn'),
- d_inner_hid,
- d_model,
- relu_dropout,
- hidden_act,
- param_initializer=param_initializer,
- name=name + '_ffn')
- return post_process_layer(attn_output, ffd_output, postprocess_cmd, prepostprocess_dropout, name=name + '_post_ffn')
-
-
-def encoder(enc_input,
- attn_bias,
- n_layer,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- hidden_act,
- preprocess_cmd="n",
- postprocess_cmd="da",
- param_initializer=None,
- name=''):
- """
- The encoder is composed of a stack of identical layers returned by calling
- encoder_layer.
- """
- for i in range(n_layer):
- enc_output = encoder_layer(enc_input,
- attn_bias,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- hidden_act,
- preprocess_cmd,
- postprocess_cmd,
- param_initializer=param_initializer,
- name=name + '_layer_' + str(i))
- enc_input = enc_output
- enc_output = pre_process_layer(enc_output, preprocess_cmd, prepostprocess_dropout, name="post_encoder")
-
- return enc_output
diff --git a/modules/text/language_model/chinese_electra_small/module.py b/modules/text/language_model/chinese_electra_small/module.py
index ac55aed3..763f7d4f 100644
--- a/modules/text/language_model/chinese_electra_small/module.py
+++ b/modules/text/language_model/chinese_electra_small/module.py
@@ -1,7 +1,6 @@
-# coding:utf-8
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
-# Licensed under the Apache License, Version 2.0 (the "License"
+# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
@@ -12,62 +11,120 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
+from typing import Dict
import os
-from paddlehub import TransformerModule
-from paddlehub.module.module import moduleinfo
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
-from chinese_electra_small.model.electra import ElectraConfig, ElectraModel
+from paddlenlp.transformers.electra.modeling import ElectraForSequenceClassification, ElectraForTokenClassification, ElectraModel
+from paddlenlp.transformers.electra.tokenizer import ElectraTokenizer
+from paddlenlp.metrics import ChunkEvaluator
+from paddlehub.module.module import moduleinfo
+from paddlehub.module.nlp_module import TransformerModule
+from paddlehub.utils.log import logger
+
+
@moduleinfo(
name="chinese-electra-small",
- version="1.0.0",
- summary="chinese-electra-small, 12-layer, 256-hidden, 4-heads, 12M parameters",
+ version="2.0.0",
+ summary=
+ "chinese-electra-small, 12-layer, 256-hidden, 4-heads, 12M parameters. The module is executed as paddle.dygraph.",
author="ymcui",
author_email="ymcui@ir.hit.edu.cn",
type="nlp/semantic_model",
+ meta=TransformerModule,
)
-class Electra(TransformerModule):
- def _initialize(self):
- self.MAX_SEQ_LEN = 512
- self.params_path = os.path.join(self.directory, "assets", "params")
- self.vocab_path = os.path.join(self.directory, "assets", "vocab.txt")
+class Electra(nn.Layer):
+ """
+ Electra model
+ """
- electra_config_path = os.path.join(self.directory, "assets", "config.json")
- self.electra_config = ElectraConfig(electra_config_path)
+ def __init__(
+ self,
+ task: str = None,
+ load_checkpoint: str = None,
+ label_map: Dict = None,
+ num_classes: int = 2,
+ **kwargs,
+ ):
+ super(Electra, self).__init__()
+ if label_map:
+ self.label_map = label_map
+ self.num_classes = len(label_map)
+ else:
+ self.num_classes = num_classes
- def net(self, input_ids, position_ids, segment_ids, input_mask):
- """
- create neural network.
+ if task == 'sequence_classification':
+ task = 'seq-cls'
+ logger.warning(
+ "current task name 'sequence_classification' was renamed to 'seq-cls', "
+ "'sequence_classification' has been deprecated and will be removed in the future.",
+ )
+ if task == 'seq-cls':
+ self.model = ElectraForSequenceClassification.from_pretrained(
+ pretrained_model_name_or_path='chinese-electra-small',
+ num_classes=self.num_classes,
+ **kwargs
+ )
+ self.criterion = paddle.nn.loss.CrossEntropyLoss()
+ self.metric = paddle.metric.Accuracy()
+ elif task == 'token-cls':
+ self.model = ElectraForTokenClassification.from_pretrained(
+ pretrained_model_name_or_path='chinese-electra-small',
+ num_classes=self.num_classes,
+ **kwargs
+ )
+ self.criterion = paddle.nn.loss.CrossEntropyLoss()
+ self.metric = ChunkEvaluator(
+ label_list=[self.label_map[i] for i in sorted(self.label_map.keys())]
+ )
+ elif task is None:
+ self.model = ElectraModel.from_pretrained(pretrained_model_name_or_path='chinese-electra-small', **kwargs)
+ else:
+            raise RuntimeError("Unknown task {}, task should be one in {}".format(
+                task, ['seq-cls', 'token-cls']))
- Args:
- input_ids (tensor): the word ids.
- position_ids (tensor): the position ids.
- segment_ids (tensor): the segment ids.
- input_mask (tensor): the padding mask.
+ self.task = task
- Returns:
- pooled_output (tensor): sentence-level output for classification task.
- sequence_output (tensor): token-level output for sequence task.
- """
- electra = ElectraModel(src_ids=input_ids,
- position_ids=position_ids,
- sentence_ids=segment_ids,
- input_mask=input_mask,
- config=self.electra_config,
- use_fp16=False)
- pooled_output = electra.get_pooled_output()
- sequence_output = electra.get_sequence_output()
- return pooled_output, sequence_output
+ if load_checkpoint is not None and os.path.isfile(load_checkpoint):
+ state_dict = paddle.load(load_checkpoint)
+ self.set_state_dict(state_dict)
+            logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
+
+    def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, seq_lengths=None, labels=None):
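+        # Dispatch on the configured task: 'seq-cls' yields sentence-level
+        # probabilities, 'token-cls' yields token-level probabilities (each also
+        # returning loss and a metric when labels are provided), and task=None
+        # returns the raw backbone outputs (sequence_output, pooled_output).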
+ result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
+ if self.task == 'seq-cls':
+ logits = result
+ probs = F.softmax(logits, axis=1)
+ if labels is not None:
+ loss = self.criterion(logits, labels)
+ correct = self.metric.compute(probs, labels)
+ acc = self.metric.update(correct)
+ return probs, loss, {'acc': acc}
+ return probs
+ elif self.task == 'token-cls':
+ logits = result
+ token_level_probs = F.softmax(logits, axis=-1)
+ preds = token_level_probs.argmax(axis=-1)
+ if labels is not None:
+ loss = self.criterion(logits, labels.unsqueeze(-1))
+ num_infer_chunks, num_label_chunks, num_correct_chunks = \
+ self.metric.compute(None, seq_lengths, preds, labels)
+ self.metric.update(
+ num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
+ _, _, f1_score = map(float, self.metric.accumulate())
+ return token_level_probs, loss, {'f1_score': f1_score}
+ return token_level_probs
+ else:
+ sequence_output, pooled_output = result
+ return sequence_output, pooled_output
-if __name__ == '__main__':
- test_module = Electra()
+ @staticmethod
+ def get_tokenizer(*args, **kwargs):
+ """
+ Gets the tokenizer that is customized for this module.
+ """
+ return ElectraTokenizer.from_pretrained(
+ pretrained_model_name_or_path='chinese-electra-small', *args, **kwargs)
diff --git a/modules/text/language_model/electra_base/README.md b/modules/text/language_model/electra_base/README.md
new file mode 100644
index 00000000..df076cc0
--- /dev/null
+++ b/modules/text/language_model/electra_base/README.md
@@ -0,0 +1,153 @@
+```shell
+$ hub install electra-base==1.0.0
+```
+
+For more details, please refer to the [ELECTRA paper](https://openreview.net/pdf?id=r1xMH1BtvB).
+
+## API
+```python
+def __init__(
+ task=None,
+ load_checkpoint=None,
+ label_map=None,
+ num_classes=2,
+ **kwargs,
+)
+```
+
+Creates a Module object (dynamic-graph version).
+
+**Parameters**
+
+* `task`: the task name; may be `seq-cls` (text classification; the former name `sequence_classification` is deprecated and will be removed in the future) or `token-cls` (sequence labeling).
+* `load_checkpoint`: the path to model parameters saved with the PaddleHub Fine-tune API.
+* `label_map`: the label map used at prediction time.
+* `num_classes`: the number of classes for the classification task; may be omitted when `label_map` is given. Defaults to 2 (binary classification).
+* `**kwargs`: additional user-specified keyword arguments.
+
+```python
+def predict(
+ data,
+ max_seq_len=128,
+ batch_size=1,
+ use_gpu=False
+)
+```
+
+**Parameters**
+
+* `data`: the data to predict, in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\], …\]; each element is one sample, and each sample may contain text\_a and text\_b. The number of texts per sample (1 or 2) must match the one used in training.
+* `max_seq_len`: the maximum text length the model processes.
+* `batch_size`: the batch size used by the model.
+* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable it.
+
+**Returns**
+
+* `results`: a list whose contents depend on the task type:
+  * text classification: the predicted label of each sentence, in the format \[label\_1, label\_2, …\]
+  * sequence labeling: the predicted labels of the tokens of each sentence, in the format \[\[token\_1, token\_2, …\], \[token\_1, token\_2, …\], …\]
+
+```python
+def get_embedding(
+ data,
+ use_gpu=False
+)
+```
+
+Gets the sentence-level and token-level features of the input texts.
+
+**Parameters**
+
+* `data`: the list of input texts, in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\], …\]; each element is one sample, and each sample may contain text\_a and text\_b.
+* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable it.
+
+**Returns**
+
+* `results`: a list in the format \[\[sample\_a\_pooled\_feature, sample\_a\_seq\_feature\], \[sample\_b\_pooled\_feature, sample\_b\_seq\_feature\], …\]; each element is the feature output of the corresponding sample, with a sentence-level pooled\_feature and a token-level seq\_feature.
+
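+A minimal sketch of calling `get_embedding` (illustrative only; it assumes the module is installed and relies on leaving `task` unset so that the bare backbone is loaded):
+
+```python
+import paddlehub as hub
+
+# task=None loads the plain ELECTRA backbone, which exposes get_embedding.
+model = hub.Module(name='electra-base', task=None)
+results = model.get_embedding(data=[['今天是个好日子'], ['天气预报说今天要下雨']], use_gpu=False)
+pooled_feature, seq_feature = results[0]
+```
+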
+
+**Code example**
+
+```python
+import paddlehub as hub
+
+data = [
+ ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般'],
+ ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片'],
+ ['作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'],
+]
+label_map = {0: 'negative', 1: 'positive'}
+
+model = hub.Module(
+ name='electra-base',
+ version='1.0.0',
+ task='seq-cls',
+ load_checkpoint='/path/to/parameters',
+ label_map=label_map)
+results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
+for idx, text in enumerate(data):
+    print('Data: {} \t Label: {}'.format(text, results[idx]))
+```
+
+For details, see the PaddleHub demos:
+- [Text classification](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/text_classification)
+- [Sequence labeling](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/sequence_labeling)
+
+## Serving Deployment
+
+PaddleHub Serving can deploy an online service for obtaining pretrained embeddings.
+
+### Step 1: Start PaddleHub Serving
+
+Run the start command:
+
+```shell
+$ hub serving start -m electra-base
+```
+
+This deploys an online API for serving pretrained embeddings; the default port is 8866.
+
+**NOTE:** To predict on a GPU, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it does not need to be set.
+
+### Step 2: Send a prediction request
+
+With the server configured, the following few lines of code send a prediction request and fetch the result:
+
+```python
+import requests
+import json
+
+# The texts to get embeddings for, in the format [[text_1], [text_2], ...]
+text = [["今天是个好日子"], ["天气预报说今天要下雨"]]
+# Pass the texts to the prediction method under a keyword; here the key is "data",
+# matching the local call module.get_embedding(data=text)
+data = {"data": text}
+# Send a POST request; the Content-Type must be JSON, and the IP in the URL must be replaced with that of the serving machine
+url = "http://10.12.121.132:8866/predict/electra-base"
+# Set the request headers to application/json
+headers = {"Content-Type": "application/json"}
+
+r = requests.post(url=url, headers=headers, data=json.dumps(data))
+print(r.json())
+```
+
+## View Code
+
+https://github.com/google-research/electra
+
+
+## Dependencies
+
+paddlepaddle >= 2.0.0
+
+paddlehub >= 2.0.0
+
+## Release Note
+
+* 1.0.0
+
+  Initial release: a dynamic-graph model supporting fine-tuning for the text classification (`seq-cls`) and sequence labeling (`token-cls`) tasks
diff --git a/modules/text/language_model/chinese_electra_base/model/__init__.py b/modules/text/language_model/electra_base/__init__.py
similarity index 100%
rename from modules/text/language_model/chinese_electra_base/model/__init__.py
rename to modules/text/language_model/electra_base/__init__.py
diff --git a/modules/text/language_model/electra_base/module.py b/modules/text/language_model/electra_base/module.py
new file mode 100644
index 00000000..1cfd62ff
--- /dev/null
+++ b/modules/text/language_model/electra_base/module.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+import os
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddlenlp.transformers.electra.modeling import ElectraForSequenceClassification, ElectraForTokenClassification, ElectraModel
+from paddlenlp.transformers.electra.tokenizer import ElectraTokenizer
+from paddlenlp.metrics import ChunkEvaluator
+from paddlehub.module.module import moduleinfo
+from paddlehub.module.nlp_module import TransformerModule
+from paddlehub.utils.log import logger
+
+
+@moduleinfo(
+ name="electra-base",
+ version="1.0.0",
+ summary=
+ "electra-base, 12-layer, 768-hidden, 12-heads, 110M parameters. The module is executed as paddle.dygraph.",
+ author="paddlepaddle",
+ author_email="",
+ type="nlp/semantic_model",
+ meta=TransformerModule,
+)
+class Electra(nn.Layer):
+ """
+ Electra model
+ """
+
+ def __init__(
+ self,
+ task: str = None,
+ load_checkpoint: str = None,
+ label_map: Dict = None,
+ num_classes: int = 2,
+ **kwargs,
+ ):
+ super(Electra, self).__init__()
+ if label_map:
+ self.label_map = label_map
+ self.num_classes = len(label_map)
+ else:
+ self.num_classes = num_classes
+
+ if task == 'sequence_classification':
+ task = 'seq-cls'
+ logger.warning(
+ "current task name 'sequence_classification' was renamed to 'seq-cls', "
+ "'sequence_classification' has been deprecated and will be removed in the future.",
+ )
+ if task == 'seq-cls':
+ self.model = ElectraForSequenceClassification.from_pretrained(
+ pretrained_model_name_or_path='electra-base',
+ num_classes=self.num_classes,
+ **kwargs
+ )
+ self.criterion = paddle.nn.loss.CrossEntropyLoss()
+ self.metric = paddle.metric.Accuracy()
+ elif task == 'token-cls':
+ self.model = ElectraForTokenClassification.from_pretrained(
+ pretrained_model_name_or_path='electra-base',
+ num_classes=self.num_classes,
+ **kwargs
+ )
+ self.criterion = paddle.nn.loss.CrossEntropyLoss()
+ self.metric = ChunkEvaluator(
+ label_list=[self.label_map[i] for i in sorted(self.label_map.keys())]
+ )
+ elif task is None:
+ self.model = ElectraModel.from_pretrained(pretrained_model_name_or_path='electra-base', **kwargs)
+ else:
+            raise RuntimeError("Unknown task {}, task should be one in {}".format(
+                task, ['seq-cls', 'token-cls']))
+
+ self.task = task
+
+ if load_checkpoint is not None and os.path.isfile(load_checkpoint):
+ state_dict = paddle.load(load_checkpoint)
+ self.set_state_dict(state_dict)
+ logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
+
+ def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, seq_lengths=None, labels=None):
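+        # Dispatch on the configured task: 'seq-cls' yields sentence-level
+        # probabilities, 'token-cls' yields token-level probabilities (each also
+        # returning loss and a metric when labels are provided), and task=None
+        # returns the raw backbone outputs (sequence_output, pooled_output).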
+ result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
+ if self.task == 'seq-cls':
+ logits = result
+ probs = F.softmax(logits, axis=1)
+ if labels is not None:
+ loss = self.criterion(logits, labels)
+ correct = self.metric.compute(probs, labels)
+ acc = self.metric.update(correct)
+ return probs, loss, {'acc': acc}
+ return probs
+ elif self.task == 'token-cls':
+ logits = result
+ token_level_probs = F.softmax(logits, axis=-1)
+ preds = token_level_probs.argmax(axis=-1)
+ if labels is not None:
+ loss = self.criterion(logits, labels.unsqueeze(-1))
+ num_infer_chunks, num_label_chunks, num_correct_chunks = \
+ self.metric.compute(None, seq_lengths, preds, labels)
+ self.metric.update(
+ num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
+ _, _, f1_score = map(float, self.metric.accumulate())
+ return token_level_probs, loss, {'f1_score': f1_score}
+ return token_level_probs
+ else:
+ sequence_output, pooled_output = result
+ return sequence_output, pooled_output
+
+ @staticmethod
+ def get_tokenizer(*args, **kwargs):
+ """
+ Gets the tokenizer that is customized for this module.
+ """
+ return ElectraTokenizer.from_pretrained(
+ pretrained_model_name_or_path='electra-base', *args, **kwargs)
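+
+
+# A minimal usage sketch (illustrative only, assuming the module has been
+# installed, e.g. via `hub install electra-base`):
+#
+#     import paddlehub as hub
+#
+#     model = hub.Module(name='electra-base', task='seq-cls', num_classes=2)
+#     tokenizer = model.get_tokenizer()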
diff --git a/modules/text/language_model/electra_large/README.md b/modules/text/language_model/electra_large/README.md
new file mode 100644
index 00000000..81f931d8
--- /dev/null
+++ b/modules/text/language_model/electra_large/README.md
@@ -0,0 +1,153 @@
+```shell
+$ hub install electra-large==1.0.0
+```
+
+For more details, please refer to the [ELECTRA paper](https://openreview.net/pdf?id=r1xMH1BtvB).
+
+## API
+```python
+def __init__(
+ task=None,
+ load_checkpoint=None,
+ label_map=None,
+ num_classes=2,
+ **kwargs,
+)
+```
+
+Creates a Module object (dynamic-graph version).
+
+**Parameters**
+
+* `task`: the task name; may be `seq-cls` (text classification; the former name `sequence_classification` is deprecated and will be removed in the future) or `token-cls` (sequence labeling).
+* `load_checkpoint`: the path to model parameters saved with the PaddleHub Fine-tune API.
+* `label_map`: the label map used at prediction time.
+* `num_classes`: the number of classes for the classification task; may be omitted when `label_map` is given. Defaults to 2 (binary classification).
+* `**kwargs`: additional user-specified keyword arguments.
+
+```python
+def predict(
+ data,
+ max_seq_len=128,
+ batch_size=1,
+ use_gpu=False
+)
+```
+
+**Parameters**
+
+* `data`: the data to predict, in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\], …\]; each element is one sample, and each sample may contain text\_a and text\_b. The number of texts per sample (1 or 2) must match the one used in training.
+* `max_seq_len`: the maximum text length the model processes.
+* `batch_size`: the batch size used by the model.
+* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable it.
+
+**Returns**
+
+* `results`: a list whose contents depend on the task type:
+  * text classification: the predicted label of each sentence, in the format \[label\_1, label\_2, …\]
+  * sequence labeling: the predicted labels of the tokens of each sentence, in the format \[\[token\_1, token\_2, …\], \[token\_1, token\_2, …\], …\]
+
+```python
+def get_embedding(
+ data,
+ use_gpu=False
+)
+```
+
+Gets the sentence-level and token-level features of the input texts.
+
+**Parameters**
+
+* `data`: the list of input texts, in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\], …\]; each element is one sample, and each sample may contain text\_a and text\_b.
+* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable it.
+
+**Returns**
+
+* `results`: a list in the format \[\[sample\_a\_pooled\_feature, sample\_a\_seq\_feature\], \[sample\_b\_pooled\_feature, sample\_b\_seq\_feature\], …\]; each element is the feature output of the corresponding sample, with a sentence-level pooled\_feature and a token-level seq\_feature.
+
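+A minimal sketch of calling `get_embedding` (illustrative only; it assumes the module is installed and relies on leaving `task` unset so that the bare backbone is loaded):
+
+```python
+import paddlehub as hub
+
+# task=None loads the plain ELECTRA backbone, which exposes get_embedding.
+model = hub.Module(name='electra-large', task=None)
+results = model.get_embedding(data=[['今天是个好日子'], ['天气预报说今天要下雨']], use_gpu=False)
+pooled_feature, seq_feature = results[0]
+```
+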
+
+**Code example**
+
+```python
+import paddlehub as hub
+
+data = [
+ ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般'],
+ ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片'],
+ ['作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'],
+]
+label_map = {0: 'negative', 1: 'positive'}
+
+model = hub.Module(
+ name='electra-large',
+ version='1.0.0',
+ task='seq-cls',
+ load_checkpoint='/path/to/parameters',
+ label_map=label_map)
+results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
+for idx, text in enumerate(data):
+    print('Data: {} \t Label: {}'.format(text, results[idx]))
+```
+
+For details, see the PaddleHub demos:
+- [Text classification](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/text_classification)
+- [Sequence labeling](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/sequence_labeling)
+
+## Serving Deployment
+
+PaddleHub Serving can deploy an online service for obtaining pretrained embeddings.
+
+### Step 1: Start PaddleHub Serving
+
+Run the start command:
+
+```shell
+$ hub serving start -m electra-large
+```
+
+This deploys an online API for serving pretrained embeddings; the default port is 8866.
+
+**NOTE:** To predict on a GPU, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it does not need to be set.
+
+### Step 2: Send a prediction request
+
+With the server configured, the following few lines of code send a prediction request and fetch the result:
+
+```python
+import requests
+import json
+
+# The texts to get embeddings for, in the format [[text_1], [text_2], ...]
+text = [["今天是个好日子"], ["天气预报说今天要下雨"]]
+# Pass the texts to the prediction method under a keyword; here the key is "data",
+# matching the local call module.get_embedding(data=text)
+data = {"data": text}
+# Send a POST request; the Content-Type must be JSON, and the IP in the URL must be replaced with that of the serving machine
+url = "http://10.12.121.132:8866/predict/electra-large"
+# Set the request headers to application/json
+headers = {"Content-Type": "application/json"}
+
+r = requests.post(url=url, headers=headers, data=json.dumps(data))
+print(r.json())
+```
+
+## View Code
+
+https://github.com/google-research/electra
+
+
+## Dependencies
+
+paddlepaddle >= 2.0.0
+
+paddlehub >= 2.0.0
+
+## Release Note
+
+* 1.0.0
+
+  Initial release: a dynamic-graph model supporting fine-tuning for the text classification (`seq-cls`) and sequence labeling (`token-cls`) tasks
diff --git a/modules/text/language_model/chinese_electra_small/model/__init__.py b/modules/text/language_model/electra_large/__init__.py
similarity index 100%
rename from modules/text/language_model/chinese_electra_small/model/__init__.py
rename to modules/text/language_model/electra_large/__init__.py
diff --git a/modules/text/language_model/electra_large/module.py b/modules/text/language_model/electra_large/module.py
new file mode 100644
index 00000000..ae11788d
--- /dev/null
+++ b/modules/text/language_model/electra_large/module.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+import os
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddlenlp.transformers.electra.modeling import ElectraForSequenceClassification, ElectraForTokenClassification, ElectraModel
+from paddlenlp.transformers.electra.tokenizer import ElectraTokenizer
+from paddlenlp.metrics import ChunkEvaluator
+from paddlehub.module.module import moduleinfo
+from paddlehub.module.nlp_module import TransformerModule
+from paddlehub.utils.log import logger
+
+
+@moduleinfo(
+ name="electra-large",
+ version="1.0.0",
+ summary=
+ "electra-large, 24-layer, 1024-hidden, 16-heads, 335M parameters. The module is executed as paddle.dygraph.",
+ author="paddlepaddle",
+ author_email="",
+ type="nlp/semantic_model",
+ meta=TransformerModule,
+)
+class Electra(nn.Layer):
+ """
+ Electra model
+ """
+
+ def __init__(
+ self,
+ task: str = None,
+ load_checkpoint: str = None,
+ label_map: Dict = None,
+ num_classes: int = 2,
+ **kwargs,
+ ):
+ super(Electra, self).__init__()
+ if label_map:
+ self.label_map = label_map
+ self.num_classes = len(label_map)
+ else:
+ self.num_classes = num_classes
+
+ if task == 'sequence_classification':
+ task = 'seq-cls'
+ logger.warning(
+ "current task name 'sequence_classification' was renamed to 'seq-cls', "
+ "'sequence_classification' has been deprecated and will be removed in the future.",
+ )
+ if task == 'seq-cls':
+ self.model = ElectraForSequenceClassification.from_pretrained(
+ pretrained_model_name_or_path='electra-large',
+ num_classes=self.num_classes,
+ **kwargs
+ )
+ self.criterion = paddle.nn.loss.CrossEntropyLoss()
+ self.metric = paddle.metric.Accuracy()
+ elif task == 'token-cls':
+ self.model = ElectraForTokenClassification.from_pretrained(
+ pretrained_model_name_or_path='electra-large',
+ num_classes=self.num_classes,
+ **kwargs
+ )
+ self.criterion = paddle.nn.loss.CrossEntropyLoss()
+ self.metric = ChunkEvaluator(
+ label_list=[self.label_map[i] for i in sorted(self.label_map.keys())]
+ )
+ elif task is None:
+ self.model = ElectraModel.from_pretrained(pretrained_model_name_or_path='electra-large', **kwargs)
+ else:
+            raise RuntimeError("Unknown task {}, task should be one in {}".format(
+                task, ['seq-cls', 'token-cls']))
+
+ self.task = task
+
+ if load_checkpoint is not None and os.path.isfile(load_checkpoint):
+ state_dict = paddle.load(load_checkpoint)
+ self.set_state_dict(state_dict)
+ logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
+
+ def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, seq_lengths=None, labels=None):
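+        # Dispatch on the configured task: 'seq-cls' yields sentence-level
+        # probabilities, 'token-cls' yields token-level probabilities (each also
+        # returning loss and a metric when labels are provided), and task=None
+        # returns the raw backbone outputs (sequence_output, pooled_output).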
+ result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
+ if self.task == 'seq-cls':
+ logits = result
+ probs = F.softmax(logits, axis=1)
+ if labels is not None:
+ loss = self.criterion(logits, labels)
+ correct = self.metric.compute(probs, labels)
+ acc = self.metric.update(correct)
+ return probs, loss, {'acc': acc}
+ return probs
+ elif self.task == 'token-cls':
+ logits = result
+ token_level_probs = F.softmax(logits, axis=-1)
+ preds = token_level_probs.argmax(axis=-1)
+ if labels is not None:
+ loss = self.criterion(logits, labels.unsqueeze(-1))
+ num_infer_chunks, num_label_chunks, num_correct_chunks = \
+ self.metric.compute(None, seq_lengths, preds, labels)
+ self.metric.update(
+ num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
+ _, _, f1_score = map(float, self.metric.accumulate())
+ return token_level_probs, loss, {'f1_score': f1_score}
+ return token_level_probs
+ else:
+ sequence_output, pooled_output = result
+ return sequence_output, pooled_output
+
+ @staticmethod
+ def get_tokenizer(*args, **kwargs):
+ """
+ Gets the tokenizer that is customized for this module.
+ """
+ return ElectraTokenizer.from_pretrained(
+ pretrained_model_name_or_path='electra-large', *args, **kwargs)
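+
+
+# A minimal usage sketch (illustrative only, assuming the module has been
+# installed, e.g. via `hub install electra-large`):
+#
+#     import paddlehub as hub
+#
+#     model = hub.Module(name='electra-large', task='seq-cls', num_classes=2)
+#     tokenizer = model.get_tokenizer()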
diff --git a/modules/text/language_model/electra_small/README.md b/modules/text/language_model/electra_small/README.md
new file mode 100644
index 00000000..65ec7548
--- /dev/null
+++ b/modules/text/language_model/electra_small/README.md
@@ -0,0 +1,153 @@
+```shell
+$ hub install electra-small==1.0.0
+```
+
+For more details, please refer to the [ELECTRA paper](https://openreview.net/pdf?id=r1xMH1BtvB).
+
+## API
+```python
+def __init__(
+ task=None,
+ load_checkpoint=None,
+ label_map=None,
+ num_classes=2,
+ **kwargs,
+)
+```
+
+Creates a Module object (dynamic-graph version).
+
+**Parameters**
+
+* `task`: the task name; may be `seq-cls` (text classification; the former name `sequence_classification` is deprecated and will be removed in the future) or `token-cls` (sequence labeling).
+* `load_checkpoint`: the path to model parameters saved with the PaddleHub Fine-tune API.
+* `label_map`: the label map used at prediction time.
+* `num_classes`: the number of classes for the classification task; may be omitted when `label_map` is given. Defaults to 2 (binary classification).
+* `**kwargs`: additional user-specified keyword arguments.
+
+```python
+def predict(
+ data,
+ max_seq_len=128,
+ batch_size=1,
+ use_gpu=False
+)
+```
+
+**Parameters**
+
+* `data`: the data to predict, in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\], …\]; each element is one sample, and each sample may contain text\_a and text\_b. The number of texts per sample (1 or 2) must match the one used in training.
+* `max_seq_len`: the maximum text length the model processes.
+* `batch_size`: the batch size used by the model.
+* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable it.
+
+**Returns**
+
+* `results`: a list whose contents depend on the task type:
+  * text classification: the predicted label of each sentence, in the format \[label\_1, label\_2, …\]
+  * sequence labeling: the predicted labels of the tokens of each sentence, in the format \[\[token\_1, token\_2, …\], \[token\_1, token\_2, …\], …\]
+
+```python
+def get_embedding(
+ data,
+ use_gpu=False
+)
+```
+
+Gets the sentence-level and token-level features of the input texts.
+
+**Parameters**
+
+* `data`: the list of input texts, in the format \[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\], …\]; each element is one sample, and each sample may contain text\_a and text\_b.
+* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable it.
+
+**Returns**
+
+* `results`: a list in the format \[\[sample\_a\_pooled\_feature, sample\_a\_seq\_feature\], \[sample\_b\_pooled\_feature, sample\_b\_seq\_feature\], …\]; each element is the feature output of the corresponding sample, with a sentence-level pooled\_feature and a token-level seq\_feature.
+
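+A minimal sketch of calling `get_embedding` (illustrative only; it assumes the module is installed and relies on leaving `task` unset so that the bare backbone is loaded):
+
+```python
+import paddlehub as hub
+
+# task=None loads the plain ELECTRA backbone, which exposes get_embedding.
+model = hub.Module(name='electra-small', task=None)
+results = model.get_embedding(data=[['今天是个好日子'], ['天气预报说今天要下雨']], use_gpu=False)
+pooled_feature, seq_feature = results[0]
+```
+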
+
+**Code example**
+
+```python
+import paddlehub as hub
+
+data = [
+ ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般'],
+ ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片'],
+ ['作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'],
+]
+label_map = {0: 'negative', 1: 'positive'}
+
+model = hub.Module(
+ name='electra-small',
+ version='1.0.0',
+ task='seq-cls',
+ load_checkpoint='/path/to/parameters',
+ label_map=label_map)
+results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
+for idx, text in enumerate(data):
+    print('Data: {} \t Label: {}'.format(text, results[idx]))
+```
+
+For details, see the PaddleHub demos:
+- [Text classification](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/text_classification)
+- [Sequence labeling](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/sequence_labeling)
+
+## Serving Deployment
+
+PaddleHub Serving can deploy an online service for obtaining pretrained embeddings.
+
+### Step 1: Start PaddleHub Serving
+
+Run the start command:
+
+```shell
+$ hub serving start -m electra-small
+```
+
+This deploys an online API for serving pretrained embeddings; the default port is 8866.
+
+**NOTE:** To predict on a GPU, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it does not need to be set.
+
+### Step 2: Send a prediction request
+
+With the server configured, the following few lines of code send a prediction request and fetch the result:
+
+```python
+import requests
+import json
+
+# The texts to get embeddings for, in the format [[text_1], [text_2], ...]
+text = [["今天是个好日子"], ["天气预报说今天要下雨"]]
+# Pass the texts to the prediction method under a keyword; here the key is "data",
+# matching the local call module.get_embedding(data=text)
+data = {"data": text}
+# Send a POST request; the Content-Type must be JSON, and the IP in the URL must be replaced with that of the serving machine
+url = "http://10.12.121.132:8866/predict/electra-small"
+# Set the request headers to application/json
+headers = {"Content-Type": "application/json"}
+
+r = requests.post(url=url, headers=headers, data=json.dumps(data))
+print(r.json())
+```
+
+## View Code
+
+https://github.com/google-research/electra
+
+
+## Dependencies
+
+paddlepaddle >= 2.0.0
+
+paddlehub >= 2.0.0
+
+## Release Note
+
+* 1.0.0
+
+  Initial release: a dynamic-graph model supporting fine-tuning for the text classification (`seq-cls`) and sequence labeling (`token-cls`) tasks
diff --git a/modules/text/language_model/electra_small/__init__.py b/modules/text/language_model/electra_small/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modules/text/language_model/electra_small/module.py b/modules/text/language_model/electra_small/module.py
new file mode 100644
index 00000000..ad60dd88
--- /dev/null
+++ b/modules/text/language_model/electra_small/module.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+import os
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddlenlp.transformers.electra.modeling import ElectraForSequenceClassification, ElectraForTokenClassification, ElectraModel
+from paddlenlp.transformers.electra.tokenizer import ElectraTokenizer
+from paddlenlp.metrics import ChunkEvaluator
+from paddlehub.module.module import moduleinfo
+from paddlehub.module.nlp_module import TransformerModule
+from paddlehub.utils.log import logger
+
+
+@moduleinfo(
+ name="electra-small",
+ version="1.0.0",
+ summary=
+ "electra-small, 12-layer, 256-hidden, 4-heads, 14M parameters. The module is executed as paddle.dygraph.",
+ author="paddlepaddle",
+ author_email="",
+ type="nlp/semantic_model",
+ meta=TransformerModule,
+)
+class Electra(nn.Layer):
+ """
+ Electra model
+ """
+
+ def __init__(
+ self,
+ task: str = None,
+ load_checkpoint: str = None,
+ label_map: Dict = None,
+ num_classes: int = 2,
+ **kwargs,
+ ):
+ super(Electra, self).__init__()
+ if label_map:
+ self.label_map = label_map
+ self.num_classes = len(label_map)
+ else:
+ self.num_classes = num_classes
+
+ if task == 'sequence_classification':
+ task = 'seq-cls'
+ logger.warning(
+ "current task name 'sequence_classification' was renamed to 'seq-cls', "
+ "'sequence_classification' has been deprecated and will be removed in the future.",
+ )
+ if task == 'seq-cls':
+ self.model = ElectraForSequenceClassification.from_pretrained(
+ pretrained_model_name_or_path='electra-small',
+ num_classes=self.num_classes,
+ **kwargs
+ )
+ self.criterion = paddle.nn.loss.CrossEntropyLoss()
+ self.metric = paddle.metric.Accuracy()
+ elif task == 'token-cls':
+ self.model = ElectraForTokenClassification.from_pretrained(
+ pretrained_model_name_or_path='electra-small',
+ num_classes=self.num_classes,
+ **kwargs
+ )
+ self.criterion = paddle.nn.loss.CrossEntropyLoss()
+ self.metric = ChunkEvaluator(
+ label_list=[self.label_map[i] for i in sorted(self.label_map.keys())]
+ )
+ elif task is None:
+ self.model = ElectraModel.from_pretrained(pretrained_model_name_or_path='electra-small', **kwargs)
+ else:
+            raise RuntimeError("Unknown task {}, task should be one in {}".format(
+                task, ['seq-cls', 'token-cls']))
+
+ self.task = task
+
+ if load_checkpoint is not None and os.path.isfile(load_checkpoint):
+ state_dict = paddle.load(load_checkpoint)
+ self.set_state_dict(state_dict)
+ logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
+
+ def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, seq_lengths=None, labels=None):
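+        # Dispatch on the configured task: 'seq-cls' yields sentence-level
+        # probabilities, 'token-cls' yields token-level probabilities (each also
+        # returning loss and a metric when labels are provided), and task=None
+        # returns the raw backbone outputs (sequence_output, pooled_output).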
+ result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
+ if self.task == 'seq-cls':
+ logits = result
+ probs = F.softmax(logits, axis=1)
+ if labels is not None:
+ loss = self.criterion(logits, labels)
+ correct = self.metric.compute(probs, labels)
+ acc = self.metric.update(correct)
+ return probs, loss, {'acc': acc}
+ return probs
+ elif self.task == 'token-cls':
+ logits = result
+ token_level_probs = F.softmax(logits, axis=-1)
+ preds = token_level_probs.argmax(axis=-1)
+ if labels is not None:
+ loss = self.criterion(logits, labels.unsqueeze(-1))
+ num_infer_chunks, num_label_chunks, num_correct_chunks = \
+ self.metric.compute(None, seq_lengths, preds, labels)
+ self.metric.update(
+ num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
+ _, _, f1_score = map(float, self.metric.accumulate())
+ return token_level_probs, loss, {'f1_score': f1_score}
+ return token_level_probs
+ else:
+ sequence_output, pooled_output = result
+ return sequence_output, pooled_output
+
+ @staticmethod
+ def get_tokenizer(*args, **kwargs):
+ """
+ Gets the tokenizer that is customized for this module.
+ """
+ return ElectraTokenizer.from_pretrained(
+ pretrained_model_name_or_path='electra-small', *args, **kwargs)
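+
+
+# A minimal usage sketch (illustrative only, assuming the module has been
+# installed, e.g. via `hub install electra-small`):
+#
+#     import paddlehub as hub
+#
+#     model = hub.Module(name='electra-small', task='seq-cls', num_classes=2)
+#     tokenizer = model.get_tokenizer()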
diff --git a/modules/text/language_model/rbt3/README.md b/modules/text/language_model/rbt3/README.md
index 0a41ed6d..89d69289 100644
--- a/modules/text/language_model/rbt3/README.md
+++ b/modules/text/language_model/rbt3/README.md
@@ -1,5 +1,5 @@
```shell
-$ hub install rtb3==2.0.1
+$ hub install rtb3==2.0.0
```
@@ -82,7 +82,7 @@ label_map = {0: 'negative', 1: 'positive'}
model = hub.Module(
name='rtb3',
- version='2.0.1',
+ version='2.0.0',
task='seq-cls',
load_checkpoint='/path/to/parameters',
label_map=label_map)
@@ -153,6 +153,6 @@ paddlehub >= 2.0.0
  Initial release
-* 2.0.1
+* 2.0.0
  Fully upgraded to the dynamic graph; the interface has changed. Task names were adjusted, and the sequence labeling task `token-cls` was added
diff --git a/modules/text/language_model/rbt3/module.py b/modules/text/language_model/rbt3/module.py
index 3833c987..63d2b5db 100644
--- a/modules/text/language_model/rbt3/module.py
+++ b/modules/text/language_model/rbt3/module.py
@@ -29,7 +29,7 @@ from paddlehub.utils.log import logger
@moduleinfo(
name="rbt3",
- version="2.0.1",
+ version="2.0.0",
summary="rbt3, 3-layer, 768-hidden, 12-heads, 38M parameters ",
author="ymcui",
author_email="ymcui@ir.hit.edu.cn",
diff --git a/modules/text/language_model/rbtl3/README.md b/modules/text/language_model/rbtl3/README.md
index 5cdcdefe..80b1c67e 100644
--- a/modules/text/language_model/rbtl3/README.md
+++ b/modules/text/language_model/rbtl3/README.md
@@ -1,5 +1,5 @@
```shell
-$ hub install rbtl3==2.0.1
+$ hub install rbtl3==2.0.0
```
@@ -82,7 +82,7 @@ label_map = {0: 'negative', 1: 'positive'}
model = hub.Module(
name='rbtl3',
- version='2.0.1',
+ version='2.0.0',
task='seq-cls',
load_checkpoint='/path/to/parameters',
label_map=label_map)
@@ -153,6 +153,6 @@ paddlehub >= 2.0.0
  Initial release
-* 2.0.1
+* 2.0.0
  Fully upgraded to the dynamic graph; the interface has changed. Task names were adjusted, and the sequence labeling task `token-cls` was added
diff --git a/modules/text/language_model/rbtl3/module.py b/modules/text/language_model/rbtl3/module.py
index 500fc42c..ac00a9a5 100644
--- a/modules/text/language_model/rbtl3/module.py
+++ b/modules/text/language_model/rbtl3/module.py
@@ -29,7 +29,7 @@ from paddlehub.utils.log import logger
@moduleinfo(
name="rbtl3",
- version="2.0.1",
+ version="2.0.0",
summary="rbtl3, 3-layer, 1024-hidden, 16-heads, 61M parameters ",
author="ymcui",
author_email="ymcui@ir.hit.edu.cn",
--
GitLab