README.md 5.5 KB
Newer Older
W
wuzewu 已提交
1
```shell
L
linjieccc 已提交
2
$ hub install chinese-electra-base==2.0.1
W
wuzewu 已提交
3
```
K
KP 已提交
4

W
wuzewu 已提交
5
<p align="center">
K
KP 已提交
6
<img src="http://bj.bcebos.com/ibox-thumbnail98/1a5578bfbe1ad629035f7ad1eb3d0bce?authorization=bce-auth-v1%2Ffbe74140929444858491fbf2b6bc0935%2F2020-03-31T06%3A45%3A51Z%2F1800%2F%2F02b8749292f8ba1c606410d0e4e5dbabdf1d367d80da395887775d36424ac13e"  hspace='10'/> <br />
W
wuzewu 已提交
7 8 9 10 11 12
</p>

更多详情请参考[ELECTRA论文](https://openreview.net/pdf?id=r1xMH1BtvB)

## API
```python
K
KP 已提交
13 14 15 16 17
def __init__(
        task=None,
        load_checkpoint=None,
        label_map=None,
        num_classes=2,
L
linjieccc 已提交
18 19
        suffix=False,
        **kwargs,
W
wuzewu 已提交
20 21 22
)
```

K
KP 已提交
23
创建Module对象(动态图组网版本)。
W
wuzewu 已提交
24

K
KP 已提交
25
**参数**
W
wuzewu 已提交
26

K
KP 已提交
27 28 29 30
* `task`: 任务名称,可为`seq-cls`(文本分类任务,原来的`sequence_classification`在未来会被弃用)或`token-cls`(序列标注任务)。
* `load_checkpoint`:使用PaddleHub Fine-tune api训练保存的模型参数文件路径。
* `label_map`:预测时的类别映射表。
* `num_classes`:分类任务的类别数,如果指定了`label_map`,此参数可不传,默认2分类。
L
linjieccc 已提交
31
* `suffix`: 序列标注任务的标签格式,如果设定为`True`,标签以'-B', '-I', '-E' 或者 '-S'为结尾,此参数默认为`False`
K
KP 已提交
32
* `**kwargs`:用户额外指定的关键字字典类型的参数。
W
wuzewu 已提交
33 34

```python
K
KP 已提交
35 36 37 38 39
def predict(
    data,
    max_seq_len=128,
    batch_size=1,
    use_gpu=False
W
wuzewu 已提交
40 41 42 43 44
)
```

**参数**

K
KP 已提交
45 46 47 48
* `data`: 待预测数据,格式为\[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\],…,\],其中每个元素都是一个样例,每个样例可以包含text\_a与text\_b。每个样例文本数量(1个或者2个)需和训练时保持一致。
* `max_seq_len`:模型处理文本的最大长度
* `batch_size`:模型批处理大小
* `use_gpu`:是否使用gpu,默认为False。对于GPU用户,建议开启use_gpu。
W
wuzewu 已提交
49 50 51

**返回**

K
KP 已提交
52 53 54
* `results`:list类型,不同任务类型的返回结果如下
  * 文本分类:列表里包含每个句子的预测标签,格式为\[label\_1, label\_2, …,\]
  * 序列标注:列表里包含每个句子每个token的预测标签,格式为\[\[token\_1, token\_2, …,\], \[token\_1, token\_2, …,\], …,\]
W
wuzewu 已提交
55 56

```python
K
KP 已提交
57 58 59 60
def get_embedding(
    data,
    use_gpu=False
)
W
wuzewu 已提交
61 62
```

K
KP 已提交
63
用于获取输入文本的句子粒度特征与字粒度特征
W
wuzewu 已提交
64 65 66

**参数**

K
KP 已提交
67 68
* `data`:输入文本列表,格式为\[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\],…,\],其中每个元素都是一个样例,每个样例可以包含text\_a与text\_b。
* `use_gpu`:是否使用gpu,默认为False。对于GPU用户,建议开启use_gpu。
W
wuzewu 已提交
69 70 71

**返回**

K
KP 已提交
72
* `results`:list类型,格式为\[\[sample\_a\_pooled\_feature, sample\_a\_seq\_feature\], \[sample\_b\_pooled\_feature, sample\_b\_seq\_feature\],…,\],其中每个元素都是对应样例的特征输出,每个样例都有句子粒度特征pooled\_feature与字粒度特征seq\_feature。
W
wuzewu 已提交
73 74 75 76 77 78 79


**代码示例**

```python
import paddlehub as hub

K
KP 已提交
80 81 82 83 84 85 86 87 88
data = [
    ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般'],
    ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片'],
    ['作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'],
]
label_map = {0: 'negative', 1: 'positive'}

model = hub.Module(
    name='chinese-electra-base',
L
linjieccc 已提交
89
    version='2.0.1',
K
KP 已提交
90 91 92 93 94 95 96 97 98 99 100
    task='seq-cls',
    load_checkpoint='/path/to/parameters',
    label_map=label_map)
results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
for idx, text in enumerate(data):
    print('Data: {} \t Lable: {}'.format(text, results[idx]))
```

详情可参考PaddleHub示例:
- [文本分类](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/text_classification)
- [序列标注](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/sequence_labeling)
W
wuzewu 已提交
101

K
KP 已提交
102
## 服务部署
W
wuzewu 已提交
103

K
KP 已提交
104
PaddleHub Serving可以部署一个在线获取预训练词向量。
W
wuzewu 已提交
105

K
KP 已提交
106
### Step1: 启动PaddleHub Serving
W
wuzewu 已提交
107

K
KP 已提交
108
运行启动命令:
W
wuzewu 已提交
109

K
KP 已提交
110 111
```shell
$ hub serving start -m chinese-electra-base
W
wuzewu 已提交
112 113
```

K
KP 已提交
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
这样就完成了一个获取预训练词向量服务化API的部署,默认端口号为8866。

**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA_VISIBLE_DEVICES环境变量,否则不用设置。

### Step2: 发送预测请求

配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果

```python
import requests
import json

# 指定用于获取embedding的文本[[text_1], [text_2], ... ]}
text = [["今天是个好日子"], ["天气预报说今天要下雨"]]
# 以key的方式指定text传入预测方法的时的参数,此例中为"data"
# 对应本地部署,则为module.get_embedding(data=text)
data = {"data": text}
# 发送post请求,content-type类型应指定json方式,url中的ip地址需改为对应机器的ip
url = "http://10.12.121.132:8866/predict/chinese-electra-base"
# 指定post请求的headers为application/json方式
headers = {"Content-Type": "application/json"}

r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
```
W
wuzewu 已提交
139 140 141 142 143 144 145 146

##   查看代码

https://github.com/ymcui/Chinese-ELECTRA


## 依赖

K
KP 已提交
147
paddlepaddle >= 2.0.0
W
wuzewu 已提交
148

K
KP 已提交
149
paddlehub >= 2.0.0
W
wuzewu 已提交
150 151 152 153 154 155

## 更新历史

* 1.0.0

  初始发布
W
wuzewu 已提交
156

K
KP 已提交
157 158 159
* 2.0.0

  全面升级动态图,接口有所变化。任务名称调整,增加序列标注任务`token-cls`
160 161 162

* 2.0.1

W
wuzewu 已提交
163
  增加文本匹配任务`text-matching`