提交 1d260307 编写于 作者: W Wgm-Inspur

Optimizing the output of classification probablity in...

Optimizing the output of classification probablity in demo\text_classification\predict.py and corresponding README.md
上级 c2abb605
......@@ -80,10 +80,12 @@ train_dataset = hub.datasets.ChnSentiCorp(
tokenizer=model.get_tokenizer(), max_seq_len=128, mode='train')
dev_dataset = hub.datasets.ChnSentiCorp(
tokenizer=model.get_tokenizer(), max_seq_len=128, mode='dev')
test_dataset = hub.datasets.ChnSentiCorp(
tokenizer=model.get_tokenizer(), max_seq_len=128, mode='test')
```
* `tokenizer`:表示该module所需用到的tokenizer,其将对输入文本完成切词,并转化成module运行所需模型输入格式。
* `mode`:选择数据模式,可选项有 `train`, `test`, `val`, 默认为`train`
* `mode`:选择数据模式,可选项有 `train`, `test`, `dev`, 默认为`train`
* `max_seq_len`:ERNIE/BERT模型使用的最大序列长度,若出现显存不足,请适当调低这一参数。
预训练模型ERNIE对中文数据的处理是以字为单位,tokenizer作用为将原始输入文本转化成模型model可以接受的输入数据形式。 PaddleHub 2.0中的各种预训练模型已经内置了相应的tokenizer,可以通过`model.get_tokenizer`方法获取。
......@@ -95,7 +97,7 @@ dev_dataset = hub.datasets.ChnSentiCorp(
```python
optimizer = paddle.optimizer.Adam(learning_rate=5e-5, parameters=model.parameters())
trainer = hub.Trainer(model, optimizer, checkpoint_dir='test_ernie_text_cls')
trainer = hub.Trainer(model, optimizer, checkpoint_dir='test_ernie_text_cls', use_gpu=True)
trainer.train(train_dataset, epochs=3, batch_size=32, eval_dataset=dev_dataset)
......@@ -153,6 +155,7 @@ data = [
['作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'],
]
label_map = {0: 'negative', 1: 'positive'}
label_map_rev = {'negative':0, 'positive':1}
model = hub.Module(
name='ernie_tiny',
......@@ -160,9 +163,9 @@ model = hub.Module(
task='seq-cls',
load_checkpoint='./test_ernie_text_cls/best_model/model.pdparams',
label_map=label_map)
results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
results, probs = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False, return_prob=True)
for idx, text in enumerate(data):
print('Data: {} \t Lable: {}'.format(text[0], results[idx]))
print('Data: {} \t Lable: {} \t Prob: {}'.format(text[0], results[idx], probs[idx][label_map_rev[results[idx]]]))
```
参数配置正确后,请执行脚本`python predict.py`, 加载模型具体可参见[加载](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc/api/paddle/framework/io/load_cn.html#load)
......@@ -21,6 +21,7 @@ if __name__ == '__main__':
['作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'],
]
label_map = {0: 'negative', 1: 'positive'}
label_map_rev = {'negative':0, 'positive':1}
model = hub.Module(
name='ernie_tiny',
......@@ -30,4 +31,4 @@ if __name__ == '__main__':
label_map=label_map)
results, probs = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False, return_prob=True)
for idx, text in enumerate(data):
print('Data: {} \t Lable: {}'.format(text[0], results[idx]))
print('Data: {} \t Lable: {} \t Prob: {}'.format(text[0], results[idx], probs[idx][label_map_rev[results[idx]]]))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册