README.md 5.4 KB
Newer Older
W
wuzewu 已提交
1
```shell
K
KP 已提交
2
$ hub install rbtl3==2.0.0
W
wuzewu 已提交
3 4 5 6 7 8 9 10 11
```
<p align="center">
<img src="https://bj.bcebos.com/paddlehub/paddlehub-img/bert_network.png"  hspace='10'/> <br />
</p>

更多详情请参考[RoBERTa论文](https://arxiv.org/abs/1907.11692)[Chinese-BERT-wwm技术报告](https://arxiv.org/abs/1906.08101)

## API
```python
K
KP 已提交
12 13 14 15 16
def __init__(
    task=None,
    load_checkpoint=None,
    label_map=None,
    num_classes=2,
L
linjieccc 已提交
17
    suffix=False,
K
KP 已提交
18
    **kwargs,
W
wuzewu 已提交
19 20 21
)
```

K
KP 已提交
22
创建Module对象(动态图组网版本)。
W
wuzewu 已提交
23

K
KP 已提交
24
**参数**
W
wuzewu 已提交
25

K
KP 已提交
26 27 28 29
* `task`: 任务名称,可为`seq-cls`(文本分类任务,原来的`sequence_classification`在未来会被弃用)或`token-cls`(序列标注任务)。
* `load_checkpoint`:使用PaddleHub Fine-tune api训练保存的模型参数文件路径。
* `label_map`:预测时的类别映射表。
* `num_classes`:分类任务的类别数,如果指定了`label_map`,此参数可不传,默认2分类。
L
linjieccc 已提交
30
* `suffix`: 序列标注任务的标签格式,如果设定为`True`,标签以'-B', '-I', '-E' 或者 '-S'为结尾,此参数默认为`False`
K
KP 已提交
31
* `**kwargs`:用户额外指定的关键字字典类型的参数。
L
linjieccc 已提交
32

W
wuzewu 已提交
33
```python
K
KP 已提交
34 35 36 37 38
def predict(
    data,
    max_seq_len=128,
    batch_size=1,
    use_gpu=False
W
wuzewu 已提交
39 40 41 42 43
)
```

**参数**

44
* `data`: 待预测数据,格式为\[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\],…,\],其中每个元素都是一个样例,每个样例可以包含text\_a与text\_b。每个样例文本数量(1个或者2个)需和训练时保持一致。
K
KP 已提交
45 46 47
* `max_seq_len`:模型处理文本的最大长度
* `batch_size`:模型批处理大小
* `use_gpu`:是否使用gpu,默认为False。对于GPU用户,建议开启use_gpu。
W
wuzewu 已提交
48 49 50

**返回**

51 52 53 54
* `results`:list类型,不同任务类型的返回结果如下
  * 文本分类:列表里包含每个句子的预测标签,格式为\[label\_1, label\_2, …,\]
  * 序列标注:列表里包含每个句子每个token的预测标签,格式为\[\[token\_1, token\_2, …,\], \[token\_1, token\_2, …,\], …,\]

W
wuzewu 已提交
55
```python
K
KP 已提交
56 57 58 59
def get_embedding(
    data,
    use_gpu=False
)
W
wuzewu 已提交
60 61
```

K
KP 已提交
62
用于获取输入文本的句子粒度特征与字粒度特征
W
wuzewu 已提交
63 64 65

**参数**

K
KP 已提交
66 67
* `data`:输入文本列表,格式为\[\[sample\_a\_text\_a, sample\_a\_text\_b\], \[sample\_b\_text\_a, sample\_b\_text\_b\],…,\],其中每个元素都是一个样例,每个样例可以包含text\_a与text\_b。
* `use_gpu`:是否使用gpu,默认为False。对于GPU用户,建议开启use_gpu。
W
wuzewu 已提交
68 69 70

**返回**

K
KP 已提交
71 72
* `results`:list类型,格式为\[\[sample\_a\_pooled\_feature, sample\_a\_seq\_feature\], \[sample\_b\_pooled\_feature, sample\_b\_seq\_feature\],…,\],其中每个元素都是对应样例的特征输出,每个样例都有句子粒度特征pooled\_feature与字粒度特征seq\_feature。

W
wuzewu 已提交
73 74 75 76 77 78

**代码示例**

```python
import paddlehub as hub

K
KP 已提交
79 80 81 82 83 84 85 86 87
data = [
    ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般'],
    ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片'],
    ['作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'],
]
label_map = {0: 'negative', 1: 'positive'}

model = hub.Module(
    name='rbtl3',
K
KP 已提交
88
    version='2.0.0',
K
KP 已提交
89 90 91 92 93 94 95
    task='seq-cls',
    load_checkpoint='/path/to/parameters',
    label_map=label_map)
results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
for idx, text in enumerate(data):
    print('Data: {} \t Lable: {}'.format(text, results[idx]))
```
W
wuzewu 已提交
96

K
KP 已提交
97 98 99
详情可参考PaddleHub示例:
- [文本分类](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/text_classification)
- [序列标注](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-beta/demo/sequence_labeling)
W
wuzewu 已提交
100

K
KP 已提交
101
## 服务部署
W
wuzewu 已提交
102

K
KP 已提交
103
PaddleHub Serving可以部署一个在线获取预训练词向量。
W
wuzewu 已提交
104

K
KP 已提交
105
### Step1: 启动PaddleHub Serving
W
wuzewu 已提交
106

K
KP 已提交
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
运行启动命令:

```shell
$ hub serving start -m rbtl3
```

这样就完成了一个获取预训练词向量服务化API的部署,默认端口号为8866。

**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA_VISIBLE_DEVICES环境变量,否则不用设置。

### Step2: 发送预测请求

配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果

```python
import requests
import json

# 指定用于获取embedding的文本[[text_1], [text_2], ... ]}
text = [["今天是个好日子"], ["天气预报说今天要下雨"]]
# 以key的方式指定text传入预测方法的时的参数,此例中为"data"
# 对应本地部署,则为module.get_embedding(data=text)
data = {"data": text}
# 发送post请求,content-type类型应指定json方式,url中的ip地址需改为对应机器的ip
url = "http://10.12.121.132:8866/predict/rbtl3"
# 指定post请求的headers为application/json方式
headers = {"Content-Type": "application/json"}

r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
W
wuzewu 已提交
137 138 139 140 141 142 143 144 145 146 147 148
```

## 查看代码
https://github.com/ymcui/Chinese-BERT-wwm


## 贡献者

[ymcui](https://github.com/ymcui)

## 依赖

K
KP 已提交
149
paddlepaddle >= 2.0.0
W
wuzewu 已提交
150

K
KP 已提交
151
paddlehub >= 2.0.0
W
wuzewu 已提交
152 153 154 155 156 157

## 更新历史

* 1.0.0

  初始发布
K
KP 已提交
158

K
KP 已提交
159
* 2.0.0
K
KP 已提交
160 161

  全面升级动态图,接口有所变化。任务名称调整,增加序列标注任务`token-cls`
162 163 164

* 2.0.1

W
wuzewu 已提交
165
  增加文本匹配任务`text-matching`