Commit ff1178bb, authored by wangxiao1021

add sequence labeling

Parent 5d517171
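......@@ -11,10 +11,11 @@ PALM (PArallel Learning from Multi-tasks) is a powerful, general-purpose, richly pre-equipped
- [Theoretical background](#理论准备)
- [Framework principles](#框架原理)
- [Pretrained models](#预训练模型)
- [Getting started with PALM in three demos](#三个demo入门paddlepalm)
- [Getting started with PALM in four demos](#三个demo入门paddlepalm)
- [DEMO1: Single-task training](#demo1单任务训练)
- [DEMO2: Multi-task auxiliary training and target-task prediction](#demo2多任务辅助训练与目标任务预测)
- [DEMO3: Joint training of multiple target tasks and task-layer parameter reuse](#demo3多目标任务联合训练与任务层参数复用)
- [DEMO4: Sequence labeling](#demo4序列标注)
- [Advanced topics](#进阶篇)
- [Config broadcasting mechanism](#配置广播机制)
- [Choosing a reader, backbone, and paradigm](#readerbackbone与paradigm的选择)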
......@@ -78,12 +79,14 @@ cd PALM && python setup.py install
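│ ├── cls.py # text classification dataset tools
│ ├── match.py # text matching dataset tools
│ ├── mrc.py # machine reading comprehension dataset tools
│ └── mlm.py # masked language model (MLM) dataset generation and processing tools
│ ├── mlm.py # masked language model (MLM) dataset generation and processing tools
│ └── ner.py # sequence labeling dataset tools
└── paradigm # task paradigms
├── cls.py # text classification
├── match.py # text matching
├── mrc.py # machine reading comprehension
└── mlm.py # masked language model (MLM)
├── mlm.py # masked language model (MLM)
└── ner.py # sequence labeling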
```
......@@ -146,7 +149,7 @@ python download_models.py -d bert-en-uncased-large
## Getting started with PaddlePALM in three demos
## Getting started with PaddlePALM in four demos
### DEMO1: Single-task training
......@@ -470,6 +473,117 @@ cls3: inference model saved at output_model/thirdrun/infer_model
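For a deeper understanding of this demo, see [Training termination conditions and expected training steps with multiple target tasks](#多目标任务下的训练终止条件与预期训练步数).
### DEMO4: Sequence labeling
> This demo is located at `demo/demo4`
In addition to the tasks covered by the three demos above, the framework now supports sequence labeling. This demo trains and runs prediction for a slot-filling task on ATIS (Airline Travel Information System), a public dataset provided by Microsoft.
From the demo directory, the following script starts training for this section in one step: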
```shell
bash run.sh
```
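Using this task as an example, the steps below show how to implement a sequence labeling task with the paddlepalm framework.
**1. Configure the task instance**
First, write the configuration file `atis_slot.yaml` for this task instance in the `tasks` folder. If the task instance takes part in training or prediction, the framework parses this file automatically and creates the corresponding task instance. The file must be valid YAML. It sets the training file path `train_file`, the path of the label->index map file `label_map_config`, the dataset loading and processing tool `reader`, the task paradigm `paradigm`, and the total number of classes `n_classes`: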
```yaml
train_file: "data/atis_slot/train.tsv"
label_map_config: "data/atis_slot/label_map.json"
reader: ner
paradigm: ner
n_classes: 130
use_crf: true
```
To use a linear-chain conditional random field (CRF) in a sequence labeling task, set the `use_crf` parameter (default `false`). This demo sets it to `true`.
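The fields in this file have to agree with each other; in particular, `n_classes` should match the number of entries in the label map. The following is a minimal sketch of such a sanity check, assuming the label map is a flat JSON object of tag -> index as in `label_map.json`. The helper `check_label_map` and the tiny inline map are hypothetical and not part of paddlepalm.

```python
import json


def check_label_map(label_map, n_classes):
    """Return True when the map's indices are exactly 0..n_classes-1."""
    indices = sorted(label_map.values())
    return indices == list(range(n_classes))


# A tiny hypothetical label map; the real one would be read from label_map.json.
label_map = json.loads('{"PAD": 0, "O": 1, "B-fromloc.city_name": 2}')
print(check_label_map(label_map, n_classes=3))
```

Running a check like this before training can catch a mismatch between the config and the data early.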
Configure the reader's preprocessing rules:
```yaml
max_seq_len: 128
do_lower_case: False
vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
```
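**2. Configure the backbone and training rules**
Write the global configuration file `config.yaml`, which sets the task to learn `task_instance`, the model save path `save_path`, the backbone network `backbone`, the optimizer `optimizer`, and so on: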
```yaml
task_instance: "atis_slot"
save_path: "output_model/fourthrun"
backbone: "ernie"
backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
batch_size: 32
pred_batch_size: 32
num_epochs: 2
optimizer: "adam"
learning_rate: 2e-5
warmup_proportion: 0.1
weight_decay: 0.01
print_every_n_steps: 10
lr_scheduler: "linear_warmup_decay"
```
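To make `warmup_proportion` and `lr_scheduler` more concrete, here is a minimal sketch of a linear warmup followed by a linear decay. It only illustrates the general idea: the exact schedule that paddlepalm implements for `linear_warmup_decay` is not confirmed here, and `lr_at_step` is a hypothetical helper.

```python
def lr_at_step(step, total_steps, base_lr=2e-5, warmup_proportion=0.1):
    """Learning rate at a given step: linear warmup, then linear decay to 0."""
    warmup_steps = int(total_steps * warmup_proportion)
    if warmup_steps > 0 and step < warmup_steps:
        # Ramp up linearly from 0 to base_lr during warmup.
        return base_lr * step / warmup_steps
    # Decay linearly from base_lr (end of warmup) to 0 (end of training).
    remaining = max(total_steps - step, 0)
    return base_lr * remaining / max(total_steps - warmup_steps, 1)


# 2 epochs x 154 steps per epoch, the step count reported in this demo's training log.
total = 308
for step in (0, 15, 30, 169, 308):
    print(step, lr_at_step(step, total))
```

With these settings the rate peaks at `learning_rate` after roughly the first 10% of the steps and falls to zero by the last step.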
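**3. Start training**
As in the previous three demos, create a Controller, instantiate the task, load the pretrained model, and start training the atis_slot task:
```python
# Demo 4: single task training of ATIS_SLOT
import paddlepalm as palm

if __name__ == '__main__':
    controller = palm.Controller('config.yaml', task_dir='tasks')
    # The backbone configured above is ERNIE, so load the ERNIE pretrained parameters.
    controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
    controller.train()
```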
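The training log is shown below; the loss decreases as training converges. When training finishes, the `Controller` automatically saves an inference model for the atis_slot task.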
```
Global step: 10. Task: atis_slot, step 10/154 (epoch 0), loss: 59.974, speed: 0.64 steps/s
Global step: 20. Task: atis_slot, step 20/154 (epoch 0), loss: 33.286, speed: 0.77 steps/s
Global step: 30. Task: atis_slot, step 30/154 (epoch 0), loss: 19.285, speed: 0.68 steps/s
...
Global step: 280. Task: atis_slot, step 126/154 (epoch 1), loss: 2.350, speed: 0.56 steps/s
Global step: 290. Task: atis_slot, step 136/154 (epoch 1), loss: 1.436, speed: 0.58 steps/s
Global step: 300. Task: atis_slot, step 146/154 (epoch 1), loss: 2.353, speed: 0.58 steps/s
atis_slot: train finished!
atis_slot: inference model saved at output_model/fourthrun/atis_slot/infer_model
```
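To follow convergence programmatically, the log lines can be parsed into numbers. The sketch below assumes the log format shown above stays stable; `parse_log_line` and the regular expression are hypothetical helpers, not part of paddlepalm.

```python
import re

# Matches lines such as:
# Global step: 10. Task: atis_slot, step 10/154 (epoch 0), loss: 59.974, speed: 0.64 steps/s
LOG_PATTERN = re.compile(
    r"Global step: (?P<global_step>\d+)\. Task: (?P<task>\w+), "
    r"step (?P<step>\d+)/(?P<steps_per_epoch>\d+) \(epoch (?P<epoch>\d+)\), "
    r"loss: (?P<loss>[\d.]+), speed: (?P<speed>[\d.]+) steps/s"
)


def parse_log_line(line):
    """Return a dict of parsed fields, or None for lines that are not step logs."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    fields = match.groupdict()
    return {
        "global_step": int(fields["global_step"]),
        "task": fields["task"],
        "step": int(fields["step"]),
        "steps_per_epoch": int(fields["steps_per_epoch"]),
        "epoch": int(fields["epoch"]),
        "loss": float(fields["loss"]),
        "speed": float(fields["speed"]),
    }


line = "Global step: 10. Task: atis_slot, step 10/154 (epoch 0), loss: 59.974, speed: 0.64 steps/s"
print(parse_log_line(line))
```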
**4. Prediction**
Once the inference model (inference_model) for the target task is available, run prediction for that task. The prediction part of `run.py` is:
```python
controller = palm.Controller(config='config.yaml', task_dir='tasks', for_train=False)
controller.pred('atis_slot', inference_model_dir='output_model/fourthrun/atis_slot/infer_model')
```
Prediction results like the following are written to `predictions.json` under the `output_model/fourthrun/atis_slot/` folder:
```
[129, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 5, 19, 1, 1, 1, 1, 1, 21, 21, 68, 129]
[129, 1, 39, 37, 1, 1, 1, 1, 1, 2, 1, 5, 19, 1, 23, 3, 4, 129, 129, 129, 129, 129]
[129, 1, 39, 37, 1, 1, 1, 1, 1, 1, 2, 1, 5, 19, 129, 129, 129, 129, 129, 129, 129, 129]
...
```
As shown above, each line is the labeling result for one text in the test set, where `129` is padding.
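The predicted ids can be turned back into tag names by inverting the label->index map. The sketch below is one way to do it, under stated assumptions: `decode_predictions` is a hypothetical helper, the padding id is passed in explicitly (taken here from the description above), and the small map is only a subset of the entries in `label_map.json`.

```python
def decode_predictions(pred_ids, label_map, pad_id):
    """Map predicted ids to tag names, skipping positions that hold the padding id."""
    id_to_label = {index: label for label, index in label_map.items()}
    # Ids missing from the map are kept visible instead of being silently dropped.
    return [id_to_label.get(i, "<unknown:%d>" % i) for i in pred_ids if i != pad_id]


# A subset of label_map.json, enough for this short example.
label_map = {
    "O": 1,
    "B-fromloc.city_name": 2,
    "B-toloc.city_name": 5,
    "I-toloc.city_name": 19,
}
print(decode_predictions([129, 1, 2, 1, 5, 19, 129], label_map, pad_id=129))
```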
## Advanced topics
This chapter covers paddlepalm usage in more depth and offers some tips for working more efficiently.
......@@ -777,6 +891,40 @@ mask_pos: a vector of shape [None], the same length as `mask_label`, with elements
task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support ERNIE model inputs.
```
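#### Sequence labeling dataset reader: ner
This reader loads and processes sequence labeling datasets. It accepts datasets in [tsv format](https://en.wikipedia.org/wiki/Tab-separated_values) with two columns: the original text `text_a` and the sample label `label`. Words in the text, and tags in the label, are separated by `^B`. For an example, see the dataset files in `data/atis_slot`, which look like this: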
```
text_a label
i[^B]want[^B]to[^B]fly[^B]from[^B]boston[^B]at[^B]838[^B]am[^B]and[^B]arrive[^B]in[^B]denver[^B]at[^B]1110[^B]in[^B]the[^B]morning O[^B]O[^B]O[^B]O[^B]O[^B]B-fromloc.city_name[^B]O[^B]B-depart_time.time[^B]I-depart_time.time[^B]O[^B]O[^B]O[^B]B-toloc.city_name[^B]O[^B]B-arrive_time.time[^B]O[^B]O[^B]B-arrive_time.period_of_day
what[^B]flights[^B]are[^B]available[^B]from[^B]pittsburgh[^B]to[^B]baltimore[^B]on[^B]thursday[^B]morning O[^B]O[^B]O[^B]O[^B]O[^B]B-fromloc.city_name[^B]O[^B]B-toloc.city_name[^B]O[^B]B-depart_date.day_name[^B]B-depart_time.period_of_day
what[^B]is[^B]the[^B]arrival[^B]time[^B]in[^B]san[^B]francisco[^B]for[^B]the[^B]755[^B]am[^B]flight[^B]leaving[^B]washington O[^B]O[^B]O[^B]B-flight_time[^B]I-flight_time[^B]O[^B]B-fromloc.city_name[^B]I-fromloc.city_name[^B]O[^B]O[^B]B-depart_time.time[^B]I-depart_time.time[^B]O[^B]O[^B]B-fromloc.city_name
```
***Note: the first row of the dataset must be a header that names each column.***
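A minimal sketch of reading this format, assuming that `^B` in the sample stands for the ASCII control character 0x02 and that columns are separated by tabs. `read_examples` is a hypothetical helper and not the reader shipped with paddlepalm.

```python
SEP = "\x02"  # assumption: "^B" in the sample denotes the control character 0x02


def read_examples(lines):
    """Yield (tokens, tags) pairs from tab-separated lines that start with a header row."""
    rows = iter(lines)
    header = next(rows).rstrip("\n").split("\t")
    text_col, label_col = header.index("text_a"), header.index("label")
    for line in rows:
        fields = line.rstrip("\n").split("\t")
        tokens = fields[text_col].split(SEP)
        tags = fields[label_col].split(SEP)
        if len(tokens) != len(tags):
            raise ValueError("token/tag count mismatch: %r" % line)
        yield tokens, tags


sample = [
    "text_a\tlabel\n",
    SEP.join(["fly", "from", "boston"]) + "\t" + SEP.join(["O", "O", "B-fromloc.city_name"]) + "\n",
]
for tokens, tags in read_examples(sample):
    print(list(zip(tokens, tags)))
```

Checking that every text has as many tags as tokens is a cheap way to catch rows whose separators were damaged.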
This reader has the following additional configuration fields:
```yaml
label_map_config : str. Path of the file that stores the label->index map.
n_classes(REQUIRED): int. Number of classes in the sequence labeling task.
```
The reader's output (the data yielded by the generator each time) contains the following fields:
```yaml
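token_ids: a matrix of shape [batch_size, seq_len]; each row is one sample, and each element is the word id of the corresponding token in the text.
position_ids: a matrix of shape [batch_size, seq_len]; each row is one sample, and each element is the position id of the corresponding token in the text.
segment_ids: an all-zero matrix of shape [batch_size, seq_len], used to support the inputs of models such as BERT and ERNIE.
input_mask: a matrix of shape [batch_size, seq_len] whose elements are 0 or 1, indicating whether the position is padding (1 for a real token, 0 for padding).
label_ids: a matrix of shape [batch_size, seq_len]; each element is the tag id of the corresponding token.
task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support ERNIE model inputs.
seq_lens: a vector of shape [batch_size] holding the sequence length of each sample.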
```
During prediction, the data yielded by the reader does not contain the `label_ids` field.
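To illustrate how these fields relate to one another, the sketch below pads a small batch by hand. It is not the reader's actual implementation: `make_batch` is a hypothetical helper, the token ids are made up, and padding with id 0 for both tokens and labels is an assumption (the label map in this demo does map `PAD` to 0).

```python
def make_batch(token_id_seqs, label_id_seqs, pad_token_id=0, pad_label_id=0):
    """Pad variable-length samples to the longest sequence in the batch."""
    seq_lens = [len(seq) for seq in token_id_seqs]
    max_len = max(seq_lens)
    batch = {"token_ids": [], "position_ids": [], "segment_ids": [],
             "input_mask": [], "label_ids": [], "task_ids": [],
             "seq_lens": seq_lens}
    for tokens, labels in zip(token_id_seqs, label_id_seqs):
        n_real, n_pad = len(tokens), max_len - len(tokens)
        batch["token_ids"].append(tokens + [pad_token_id] * n_pad)
        batch["position_ids"].append(list(range(n_real)) + [0] * n_pad)
        batch["segment_ids"].append([0] * max_len)            # all zeros
        batch["input_mask"].append([1] * n_real + [0] * n_pad)  # 1 = real token
        batch["label_ids"].append(labels + [pad_label_id] * n_pad)
        batch["task_ids"].append([0] * max_len)               # all zeros
    return batch


# Two samples with made-up token ids and tag ids of different lengths.
batch = make_batch([[101, 7, 8, 102], [101, 9, 102]], [[1, 2, 1, 1], [1, 5, 1]])
print(batch["input_mask"])
print(batch["seq_lens"])
```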
## Appendix B: Built-in backbones
The framework ships with BERT and ERNIE as backbones; more backbones such as XLNet will be added in the future.
......@@ -797,9 +945,9 @@ input_mask: a matrix of shape [batch_size, seq_len], where each element
```yaml
word_embedding: a Tensor of shape [batch_size, seq_len, emb_size], float32. The (context-independent) word embedding sequence of each sample in the current batch.
embedding_table: a matrix of shape [vocab_size, emb_size], float32. The word embedding lookup table currently maintained by BERT.
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The BERT encoder's encoding of each sample in the current batch.
sentence_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the BERT encoder's sentence embedding for the corresponding sample in the current batch.
sentence_pair_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the BERT encoder's sentence embedding for the corresponding sample in the current batch.
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The backbone encoder's encoding of each sample in the current batch.
sentence_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the backbone encoder's sentence embedding for the corresponding sample in the current batch.
sentence_pair_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the backbone encoder's sentence embedding for the corresponding sample in the current batch.
```
#### ERNIE
......@@ -820,9 +968,9 @@ task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support ERNIE f
```yaml
word_embedding: a Tensor of shape [batch_size, seq_len, emb_size], float32. The (context-independent) word embedding sequence of each sample in the current batch.
embedding_table: a matrix of shape [vocab_size, emb_size], float32. The word embedding lookup table currently maintained by BERT.
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The BERT encoder's encoding of each sample in the current batch.
sentence_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the BERT encoder's sentence embedding for the corresponding sample in the current batch.
sentence_pair_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the BERT encoder's sentence embedding for the corresponding sample in the current batch.
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The backbone encoder's encoding of each sample in the current batch.
sentence_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the backbone encoder's sentence embedding for the corresponding sample in the current batch.
sentence_pair_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the backbone encoder's sentence embedding for the corresponding sample in the current batch.
```
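......@@ -843,13 +991,13 @@ save_infermodel_every_n_steps (OPTIONAL) : int. Periodically saving the inference
Training phase:
```yaml
sentence_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the BERT encoder's sentence embedding for the corresponding sample in the current batch.
sentence_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the backbone encoder's sentence embedding for the corresponding sample in the current batch.
label_ids: a matrix of shape [batch_size]; each element is the class label of the sample.
```
Prediction phase:
```yaml
sentence_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the BERT encoder's sentence embedding for the corresponding sample in the current batch.
sentence_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the backbone encoder's sentence embedding for the corresponding sample in the current batch.
```
In the training phase, the paradigm outputs the loss; in the prediction phase, it outputs the predicted probability of each class.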
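......@@ -868,13 +1016,13 @@ save_infermodel_every_n_steps (OPTIONAL) : int. Periodically saving the inference
Training phase:
```yaml
sentence_pair_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the BERT encoder's sentence embedding for the corresponding sample in the current batch.
sentence_pair_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the backbone encoder's sentence embedding for the corresponding sample in the current batch.
label_ids: a matrix of shape [batch_size]; each element is the class label of the sample: 0 means the two texts do not match, 1 means they match.
```
Prediction phase:
```yaml
sentence_pair_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the BERT encoder's sentence embedding for the corresponding sample in the current batch.
sentence_pair_embedding: a matrix of shape [batch_size, hidden_size], float32. Each row is the backbone encoder's sentence embedding for the corresponding sample in the current batch.
```
In the training phase, the paradigm outputs the loss; in the prediction phase, it outputs the probability distribution over match and no match.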
......@@ -895,14 +1043,14 @@ save_infermodel_every_n_steps (OPTIONAL) : int. Periodically saving the inference
Training phase:
```yaml
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The BERT encoder's encoding of each sample in the current batch.
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The backbone encoder's encoding of each sample in the current batch.
start_positions: a vector of shape [batch_size]; each element is the start position of the answer span of the current sample.
end_positions: a vector of shape [batch_size]; each element is the end position of the answer span of the current sample.
```
Prediction phase:
```yaml
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The BERT encoder's encoding of each sample in the current batch.
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The backbone encoder's encoding of each sample in the current batch.
unique_ids: a matrix of shape [batch_size, seq_len] holding the globally unique id of each sample, used to merge sliding-window results after prediction.
```
......@@ -915,9 +1063,36 @@ unique_ids: a matrix of shape [batch_size, seq_len] holding each sample's
mask_label: a vector of shape [None]; each element is the true word id of a masked-out word.
mask_pos: a vector of shape [None], the same length as `mask_label` with elements in one-to-one correspondence. Each element is the position of a masked-out word.
embedding_table: a matrix of shape [vocab_size, emb_size], float32. The word embedding lookup table currently maintained by BERT.
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The BERT encoder's encoding of each sample in the current batch.
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The backbone encoder's encoding of each sample in the current batch.
```
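#### Sequence labeling paradigm: ner
The sequence labeling paradigm has the following additional configuration fields:
```yaml
n_classes(REQUIRED): int. Number of classes in the sequence labeling task.
pred_output_path (OPTIONAL) : str. Path where prediction results are saved. When this parameter is empty, results are saved to the task directory under the path given by the `save_path` field of the global configuration file.
save_infermodel_every_n_steps (OPTIONAL) : int. Interval for periodically saving the inference model. When unset or set to -1, the inference model is saved only when training for this task finishes. Defaults to -1.
```
The sequence labeling paradigm takes the following input objects:
Training phase:
```yaml
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The backbone encoder's encoding of each sample in the current batch.
label_ids: a matrix of shape [batch_size, seq_len]; each element is the tag id of the corresponding token.
```
Prediction phase:
```yaml
encoder_outputs: a Tensor of shape [batch_size, seq_len, hidden_size], float32. The backbone encoder's encoding of each sample in the current batch.
```
In the training phase, the paradigm outputs the loss; in the prediction phase, it outputs the labeling for each line of text.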
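As a rough illustration of the prediction step, the sketch below picks the highest-scoring class for every real token from per-token class scores. This is plain per-token argmax, a common choice when no CRF layer is used; whether paddlepalm decodes exactly this way is an assumption, and with `use_crf` enabled a Viterbi-style decoding over the CRF transitions would take its place. `predict_tags` is a hypothetical helper working on nested Python lists.

```python
def predict_tags(logits, seq_lens):
    """Per-token argmax over class scores.

    logits: nested list of shape [batch_size, seq_len, n_classes]
    seq_lens: list of shape [batch_size] with the real length of each sample
    """
    predictions = []
    for sample_logits, length in zip(logits, seq_lens):
        # Only the first `length` positions are real tokens; the rest is padding.
        tags = [max(range(len(scores)), key=scores.__getitem__)
                for scores in sample_logits[:length]]
        predictions.append(tags)
    return predictions


# One sample, three positions (the last one is padding), two classes.
logits = [[[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]]]
print(predict_tags(logits, seq_lens=[2]))
```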
## Appendix D: List of configurable global parameters
```yaml
......
task_instance: "atis_slot"
save_path: "output_model/fourthrun"
backbone: "ernie"
backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
do_lower_case: False
max_seq_len: 128
batch_size: 32
pred_batch_size: 32
num_epochs: 2
optimizer: "adam"
learning_rate: 2e-5
warmup_proportion: 0.1
weight_decay: 0.01
print_every_n_steps: 10
lr_scheduler: "linear_warmup_decay"
{"B-time_relative": 109, "B-stoploc.state_code": 101, "B-depart_date.today_relative": 18, "B-arrive_date.date_relative": 78, "PAD": 0, "B-depart_date.date_relative": 26, "I-restriction_code": 87, "B-return_date.month_name": 50, "I-time": 110, "B-depart_date.day_name": 8, "I-arrive_time.end_time": 75, "B-fromloc.airport_code": 57, "B-cost_relative": 13, "B-connect": 84, "B-return_time.period_mod": 114, "B-arrive_time.period_mod": 65, "B-flight_number": 64, "B-depart_time.time_relative": 23, "I-toloc.city_name": 19, "B-arrive_time.period_of_day": 7, "B-depart_time.period_of_day": 9, "I-return_date.date_relative": 119, "I-depart_time.start_time": 31, "B-fare_amount": 16, "I-depart_time.time_relative": 96, "B-city_name": 20, "B-depart_date.day_number": 37, "I-arrive_time.period_of_day": 95, "I-depart_date.today_relative": 115, "I-airport_name": 90, "I-arrive_date.day_number": 61, "B-toloc.state_code": 48, "B-arrive_date.month_name": 45, "B-stoploc.airport_code": 126, "I-depart_time.time": 4, "B-airport_code": 81, "B-arrive_time.start_time": 73, "B-period_of_day": 98, "B-arrive_time.time": 6, "I-flight_stop": 72, "B-toloc.state_name": 36, "B-booking_class": 128, "I-meal_code": 103, "B-arrive_time.end_time": 74, "B-meal": 47, "B-arrive_time.time_relative": 29, "B-return_date.day_number": 51, "I-city_name": 56, "B-day_name": 97, "B-or": 60, "I-depart_date.day_name": 99, "I-arrive_time.time": 54, "B-economy": 62, "B-return_date.day_name": 123, "B-fromloc.airport_name": 34, "O": 1, "B-class_type": 24, "B-meal_code": 102, "B-depart_time.time": 3, "B-return_date.today_relative": 121, "I-depart_date.day_number": 38, "B-restriction_code": 86, "B-fare_basis_code": 41, "I-stoploc.city_name": 68, "I-fare_basis_code": 93, "B-flight": 129, "B-airline_name": 27, "B-compartment": 125, "B-airline_code": 52, "B-fromloc.state_name": 76, "B-flight_stop": 55, "B-day_number": 118, "B-flight_mod": 43, "I-meal_description": 120, "B-depart_time.start_time": 30, "B-today_relative": 100, 
"I-arrive_time.time_relative": 94, "B-arrive_date.day_number": 46, "I-flight_time": 11, "B-arrive_date.day_name": 58, "I-fromloc.state_name": 77, "B-mod": 40, "B-depart_date.month_name": 39, "B-flight_days": 67, "I-mod": 116, "I-cost_relative": 44, "B-stoploc.airport_name": 107, "B-flight_time": 10, "I-today_relative": 104, "B-fromloc.city_name": 2, "B-transport_type": 42, "B-return_time.period_of_day": 111, "B-time": 59, "B-toloc.country_name": 91, "B-return_date.date_relative": 80, "B-round_trip": 14, "I-transport_type": 66, "I-fromloc.city_name": 12, "B-depart_date.year": 79, "I-return_date.day_number": 112, "I-flight_mod": 105, "B-toloc.city_name": 5, "B-depart_time.period_mod": 53, "I-arrive_time.start_time": 85, "B-state_code": 71, "B-airport_name": 89, "B-stoploc.city_name": 21, "I-toloc.airport_name": 70, "B-meal_description": 49, "I-class_type": 25, "B-toloc.airport_code": 22, "I-depart_time.period_of_day": 113, "I-toloc.state_name": 88, "B-days_code": 92, "B-toloc.airport_name": 69, "B-arrive_date.today_relative": 108, "I-round_trip": 15, "I-state_name": 127, "I-fare_amount": 17, "I-fromloc.airport_name": 35, "I-flight_number": 124, "I-airline_name": 28, "B-state_name": 106, "I-economy": 63, "B-depart_time.end_time": 32, "B-aircraft_code": 82, "I-return_date.today_relative": 122, "B-month_name": 117, "B-fromloc.state_code": 83, "I-depart_time.end_time": 33}
text_a label
iwouldliketofindaflightfromcharlottetolasvegasthatmakesastopinst.louis OOOOOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOOOOOB-stoploc.city_nameI-stoploc.city_name
onaprilfirstineedaticketfromtacomatosanjosedepartingbefore7am OB-depart_date.month_nameB-depart_date.day_numberOOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_time.time_relativeB-depart_time.timeI-depart_time.time
onaprilfirstineedaflightgoingfromphoenixtosandiego OB-depart_date.month_nameB-depart_date.day_numberOOOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
iwouldlikeaflighttravelingonewayfromphoenixtosandiegoonaprilfirst OOOOOOB-round_tripI-round_tripOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_date.month_nameB-depart_date.day_number
iwouldlikeaflightfromorlandotosaltlakecityforaprilfirstondeltaairlines OOOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameI-toloc.city_nameOB-depart_date.month_nameB-depart_date.day_numberOB-airline_nameI-airline_name
ineedaflightfromtorontotonewarkonewayleavingwednesdayeveningorthursdaymorning OOOOOB-fromloc.city_nameOB-toloc.city_nameB-round_tripI-round_tripOB-depart_date.day_nameB-depart_time.period_of_dayOB-depart_date.day_nameB-depart_time.period_of_day
mondaymorningiwouldliketoflyfromcolumbustoindianapolis B-depart_date.day_nameB-depart_time.period_of_dayOOOOOOB-fromloc.city_nameOB-toloc.city_name
onwednesdayaprilsixthiwouldliketoflyfromlongbeachtocolumbusafter3pm OB-depart_date.day_nameB-depart_date.month_nameB-depart_date.day_numberOOOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameB-depart_time.time_relativeB-depart_time.timeI-depart_time.time
after12pmonwednesdayaprilsixthiwouldliketoflyfromlongbeachtocolumbus B-depart_time.time_relativeB-depart_time.timeI-depart_time.timeOB-depart_date.day_nameB-depart_date.month_nameB-depart_date.day_numberOOOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_name
arethereanyflightsfromlongbeachtocolumbusonwednesdayaprilsixth OOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameOB-depart_date.day_nameB-depart_date.month_nameB-depart_date.day_number
findaflightfrommemphistotacomadinner OOOOB-fromloc.city_nameOB-toloc.city_nameB-meal_description
onnextwednesdayflightfromkansascitytochicagoshouldarriveinchicagoaround7pmreturnflightonthursday OB-depart_date.date_relativeB-depart_date.day_nameOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameOOOB-toloc.city_nameB-arrive_time.time_relativeB-arrive_time.timeI-arrive_time.timeOOOB-return_date.day_name
showflightandpriceskansascitytochicagoonnextwednesdayarrivinginchicagoby7pm OOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameOB-depart_date.date_relativeB-depart_date.day_nameOOB-toloc.city_nameB-arrive_time.time_relativeB-arrive_time.timeI-arrive_time.time
flightonamericanfrommiamitochicagoarriveinchicagoabout5pm OOB-airline_nameOB-fromloc.city_nameOB-toloc.city_nameOOB-toloc.city_nameB-arrive_time.time_relativeB-arrive_time.timeI-arrive_time.time
findflightsarrivingnewyorkcitynextsaturday OOOB-toloc.city_nameI-toloc.city_nameI-toloc.city_nameB-arrive_date.date_relativeB-arrive_date.day_name
findnonstopflightsfromsaltlakecitytonewyorkonsaturdayaprilninth OB-flight_stopOOB-fromloc.city_nameI-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_date.day_nameB-depart_date.month_nameB-depart_date.day_number
showflightsfromburbanktomilwaukeefortoday OOOB-fromloc.city_nameOB-toloc.city_nameOB-depart_date.today_relative
showflightstomorroweveningfrommilwaukeetost.louis OOB-depart_date.today_relativeB-depart_time.period_of_dayOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
showflightssaturdayeveningfromst.louistoburbank OOB-depart_date.day_nameB-depart_time.period_of_dayOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_name
showflightsfromburbanktost.louisonmonday OOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_date.day_name
showflightsfromburbanktomilwaukeeonmonday OOOB-fromloc.city_nameOB-toloc.city_nameOB-depart_date.day_name
showflightstuesdayeveningfrommilwaukeetost.louis OOB-depart_date.day_nameB-depart_time.period_of_dayOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
showflightswednesdayeveningfromst.louistoburbank OOB-depart_date.day_nameB-depart_time.period_of_dayOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_name
whichflightstravelfromkansascitytolosangeles OOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
whatflightstravelfromlasvegastolosangeles OOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
whichflightstravelfromkansascitytolosangelesonaprilninth OOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_date.month_nameB-depart_date.day_number
whichflightstravelfromlasvegastolosangelescaliforniaandarriveonaprilninthbetween4and5pm OOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameB-toloc.state_nameOOOB-arrive_date.month_nameB-arrive_date.day_numberOB-arrive_time.start_timeOB-arrive_time.end_timeI-arrive_time.end_time
whichflightsonusairgofromorlandotocleveland OOOB-airline_nameI-airline_nameOOB-fromloc.city_nameOB-toloc.city_name
whichflightstravelfromclevelandtoindianapolisonaprilfifth OOOOB-fromloc.city_nameOB-toloc.city_nameOB-depart_date.month_nameB-depart_date.day_number
whichflightstravelfromindianapolistosandiegoonaprilfifth OOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_date.month_nameB-depart_date.day_number
whichflightsgofromclevelandtoindianapolisonaprilfifth OOOOB-fromloc.city_nameOB-toloc.city_nameOB-depart_date.month_nameB-depart_date.day_number
whichflightstravelfromnashvilletotacoma OOOOB-fromloc.city_nameOB-toloc.city_name
doestacomaairportoffertransportationfromtheairporttothedowntownarea OB-airport_nameI-airport_nameOOOOOOOOO
whichflightstravelfromtacomatosanjose OOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
whatdayoftheweekdoflightsfromnashvilletotacomaflyon OOOOOOOOB-fromloc.city_nameOB-toloc.city_nameOO
whataretheflightsfromtacomatosanjose OOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
whatdaysoftheweekdoflightsfromsanjosetonashvilleflyon OOOOOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameOO
whataretheflightsfromtacomatosanjose OOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
ineedaflightthatgoesfrombostontoorlando OOOOOOOB-fromloc.city_nameOB-toloc.city_name
arethereanyflightsfrombostontoorlandoconnectinginnewyork OOOOOB-fromloc.city_nameOB-toloc.city_nameB-connectOB-stoploc.city_nameI-stoploc.city_name
arethereanyflightsfrombostontoorlandoconnectingincolumbus OOOOOB-fromloc.city_nameOB-toloc.city_nameB-connectOB-stoploc.city_name
doesthisflightservedinner OOOOB-meal_description
ineedaflightfromcharlottetomiami OOOOOB-fromloc.city_nameOB-toloc.city_name
ineedanonstopflightfrommiamitotoronto OOOB-flight_stopOOB-fromloc.city_nameOB-toloc.city_name
ineedanonstopflightfromtorontotost.louis OOOB-flight_stopOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
ineedaflightfromtorontotost.louis OOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
ineedaflightfromst.louistocharlotte OOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_name
ineedaflightonunitedairlinesfromlaguardiatosanjose OOOOOB-airline_nameI-airline_nameOB-fromloc.airport_nameI-fromloc.airport_nameOB-toloc.city_nameI-toloc.city_name
ineedaflightfromtampatomilwaukee OOOOOB-fromloc.city_nameOB-toloc.city_name
ineedaflightfrommilwaukeetoseattle OOOOOB-fromloc.city_nameOB-toloc.city_name
whatmealsareservedonamericanflight811fromtampatomilwaukee OB-mealOOOB-airline_nameOB-flight_numberOB-fromloc.city_nameOB-toloc.city_name
whatmealsareservedonamericanflight665673frommilwaukeetoseattle OB-mealOOOB-airline_nameOB-flight_numberI-flight_numberOB-fromloc.city_nameOB-toloc.city_name
ineedaflightfrommemphistolasvegas OOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
pleasefindflightsavailablefromkansascitytonewark OOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_name
pleasefindaflightthatgoesfromkansascitytonewarktoorlandobacktokansascity OOOOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameOB-toloc.city_nameOOB-toloc.city_nameI-toloc.city_name
pleasefindaflightfromkansascitytonewark OOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_name
pleasefindaflightfromnewarktoorlando OOOOOB-fromloc.city_nameOB-toloc.city_name
pleasefindaflightfromorlandotokansascity OOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
iwouldliketoflyfromcolumbustophoenixthroughcincinnatiintheafternoon OOOOOOB-fromloc.city_nameOB-toloc.city_nameOB-stoploc.city_nameOOB-depart_time.period_of_day
iwouldliketoknowwhatairportsareinlosangeles OOOOOOOOOB-city_nameI-city_name
doestheairportatburbankhaveaflightthatcomesinfromkansascity OOOOB-toloc.city_nameOOOOOOOB-fromloc.city_nameI-fromloc.city_name
whichflightsarriveinburbankfromkansascityonsaturdaysintheafternoon OOOOB-toloc.city_nameOB-fromloc.city_nameI-fromloc.city_nameOB-arrive_date.day_nameOOB-arrive_time.period_of_day
whichflightsarriveinburbankfromlasvegasonsaturdayapriltwentythirdintheafternoon OOOOB-toloc.city_nameOB-fromloc.city_nameI-fromloc.city_nameOB-depart_date.day_nameB-depart_date.month_nameB-depart_date.day_numberI-depart_date.day_numberOOB-depart_time.period_of_day
whichflightsareavailablefromorlandotoclevelandthatarrivearound10pm OOOOOB-fromloc.city_nameOB-toloc.city_nameOOB-arrive_time.time_relativeB-arrive_time.timeI-arrive_time.time
pleaselisttheflightsfromcharlottetonewark OOOOOB-fromloc.city_nameOB-toloc.city_name
pleaselisttheflightsfromnewarktolosangeles OOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
pleaselisttheflightsfromcincinnatitoburbankonamericanairlines OOOOOB-fromloc.city_nameOB-toloc.city_nameOB-airline_nameI-airline_name
pleasegivemetheflightsfromkansascitytochicagoonjunesixteenth OOOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameOB-depart_date.month_nameB-depart_date.day_number
pleasegivemetheflightsfromchicagotokansascityonjuneseventeenth OOOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_date.month_nameB-depart_date.day_number
pleaselistalltheflightsfromkansascitytochicagoonjunesixteenth OOOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameOB-depart_date.month_nameB-depart_date.day_number
pleaselistalltheflightsfromchicagotokansascityonjuneseventeenth OOOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_date.month_nameB-depart_date.day_number
i'dliketotravelfromburbanktomilwaukee OOOOOOB-fromloc.city_nameOB-toloc.city_name
canyoufindmeaflightfromsaltlakecitytonewyorkcitynextsaturdaybeforearrivingbefore6pm OOOOOOOB-fromloc.city_nameI-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameI-toloc.city_nameB-depart_date.date_relativeB-depart_date.day_nameB-arrive_time.time_relativeOB-arrive_time.time_relativeB-arrive_time.timeI-arrive_time.time
canyoufindmeanotherflightfromcincinnatitonewyorkonsaturdaybefore6pm OOOOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_date.day_nameB-depart_time.time_relativeB-depart_time.timeI-depart_time.time
canyoulistallofthedeltaflightsfromsaltlakecitytonewyorknextsaturdayarrivingbefore6pm OOOOOOB-airline_nameOOB-fromloc.city_nameI-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameB-depart_date.date_relativeB-depart_date.day_nameOB-arrive_time.time_relativeB-arrive_time.timeI-arrive_time.time
i'dliketoflyfrommiamitochicagoononamericanairlinesarrivingaround5pm OOOOOOB-fromloc.city_nameOB-toloc.city_nameOOB-airline_nameI-airline_nameOB-arrive_time.time_relativeB-arrive_time.timeI-arrive_time.time
i'dliketotravelfromkansascitytochicagonextwednesday OOOOOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameB-depart_date.date_relativeB-depart_date.day_name
i'dlikearoundtripflightfromkansascitytochicagoonwednesdaymaytwentysixtharrivingat7pm OOOOB-round_tripI-round_tripOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameOB-depart_date.day_nameB-depart_date.month_nameB-depart_date.day_numberI-depart_date.day_numberOOB-arrive_time.timeI-arrive_time.time
yesi'dliketofindaflightfrommemphistotacomastoppinginlosangeles OOOOOOOOOB-fromloc.city_nameOB-toloc.city_nameOOB-stoploc.city_nameI-stoploc.city_name
findflightfromsandiegotophoenixonmondayam OOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameOB-depart_date.day_nameB-depart_time.period_of_day
findflightfromphoenixtodetroitonmonday OOOB-fromloc.city_nameOB-toloc.city_nameOB-depart_date.day_name
findflightfromdetroittosandiegoontuesday OOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_date.day_name
findflightfromcincinnatitosanjoseonmonday OOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOB-depart_date.day_name
findflightfromsanjosetohoustononwednesday OOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameOB-depart_date.day_name
findflightfromhoustontomemphisonfriday OOOB-fromloc.city_nameOB-toloc.city_nameOB-depart_date.day_name
findflightfrommemphistocincinnationsunday OOOB-fromloc.city_nameOB-toloc.city_nameOB-depart_date.day_name
findamericanflightfromnewarktonashvillearound630pm OB-airline_nameOOB-fromloc.city_nameOB-toloc.city_nameB-depart_time.time_relativeB-depart_time.timeI-depart_time.time
pleasefindaflightroundtripfromlosangelestotacomawashingtonwithastopoverinsanfrancisconotexceedingthepriceof300dollarsforjunetenth1993 OOOOB-round_tripI-round_tripOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameB-toloc.state_nameOOOOB-stoploc.city_nameI-stoploc.city_nameB-cost_relativeI-cost_relativeOOOB-fare_amountI-fare_amountOB-depart_date.month_nameB-depart_date.day_numberB-depart_date.year
arethereanyflightsonjunetenthfromburbanktotacoma OOOOOB-depart_date.month_nameB-depart_date.day_numberOB-fromloc.city_nameOB-toloc.city_name
pleasefindaflightfromontariotowestchesterthatmakesastopinchicagoonmayseventeenthonewaywithdinner OOOOOB-fromloc.city_nameOB-toloc.city_nameOOOOOB-stoploc.city_nameOB-depart_date.month_nameB-depart_date.day_numberB-round_tripI-round_tripOB-meal_description
liketobookaflightfromburbanktomilwaukee OOOOOOB-fromloc.city_nameOB-toloc.city_name
showmealltheflightsfromburbanktomilwaukee OOOOOOB-fromloc.city_nameOB-toloc.city_name
findmealltheflightsfrommilwaukeetost.louis OOOOOOB-fromloc.city_nameOB-city_nameI-city_name
nowshowmealltheflightsfromst.louistoburbank OOOOOOOB-city_nameI-city_nameOB-toloc.city_name
isthereoneairlinethatfliesfromburbanktomilwaukeemilwaukeetost.louisandfromst.louistoburbank OOOOOOOB-fromloc.city_nameOB-toloc.city_nameB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_name
findmealltheroundtripflightsfromburbanktomilwaukeestoppinginst.louis OOOOB-round_tripI-round_tripOOB-fromloc.city_nameOB-toloc.city_nameOOB-stoploc.city_nameI-stoploc.city_name
i'dliketobooktwoflightstowestchestercounty OOOOOOOOB-toloc.city_nameI-toloc.city_name
iwanttobookaflightfromsaltlakecitytowestchestercounty OOOOOOOB-fromloc.city_nameI-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_nameI-toloc.city_name
tellmealltheairportsnearwestchestercounty OOOOOOB-city_nameI-city_name
i'dliketobookaflightfromcincinnatitonewyorkcityonunitedairlinesfornextsaturday OOOOOOOOB-fromloc.city_nameOB-toloc.city_nameI-toloc.city_nameI-toloc.city_nameOB-airline_nameI-airline_nameOB-depart_date.date_relativeB-depart_date.day_name
tellmealltheairportsinthenewyorkcityarea OOOOOOOB-city_nameI-city_nameI-city_nameO
pleasefindalltheflightsfromcincinnatitoanyairportinthenewyorkcityareathatarrivenextsaturdaybefore6pm OOOOOOB-fromloc.city_nameOOOOOB-toloc.city_nameI-toloc.city_nameI-toloc.city_nameOOOB-arrive_date.date_relativeB-arrive_date.day_nameB-arrive_time.time_relativeB-arrive_time.timeI-arrive_time.time
findmeaflightfromcincinnatitoanyairportinthenewyorkcityarea OOOOOB-fromloc.city_nameOOOOOB-toloc.city_nameI-toloc.city_nameI-toloc.city_nameO
i'dliketoflyfrommiamitochicagoonamericanairlines OOOOOOB-fromloc.city_nameOB-toloc.city_nameOB-airline_nameI-airline_name
iwouldliketobookaroundtripflightfromkansascitytochicago OOOOOOB-round_tripI-round_tripOOB-fromloc.city_nameI-fromloc.city_nameOB-toloc.city_name
findmeaflightthatfliesfrommemphistotacoma OOOOOOOB-fromloc.city_nameOB-toloc.city_name
import paddlepalm as palm
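
if __name__ == '__main__':
    # Training: build the controller from the config, load the pretrained ERNIE parameters, then train.
    controller = palm.Controller('config.yaml', task_dir='tasks')
    controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
    controller.train()

    # Prediction: rebuild the controller in inference mode and run the saved inference model.
    controller = palm.Controller(config='config.yaml', task_dir='tasks', for_train=False)
    controller.pred('atis_slot', inference_model_dir='output_model/fourthrun/atis_slot/infer_model')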
export CUDA_VISIBLE_DEVICES=0
python run.py
train_file: "data/atis_slot/train.tsv"
pred_file: "data/atis_slot/test.tsv"
label_map_config: "data/atis_slot/label_map.json"
reader: ner
paradigm: ner
n_classes: 130
use_crf: true
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import SequenceLabelReader, ClassifyReader


class Reader(reader):

    def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
        """
        Args:
            phase: train, eval, pred
        """

        self._is_training = phase == 'train'

        reader = SequenceLabelReader(config['vocab_path'],
                                     max_seq_len=config['max_seq_len'],
                                     do_lower_case=config.get('do_lower_case', False),
                                     for_cn=config.get('for_cn', False),
                                     random_seed=config.get('seed', None),
                                     label_map_config=config['label_map_config'])

        self._reader = reader
        self._dev_count = dev_count

        self._batch_size = config['batch_size']
        self._max_seq_len = config['max_seq_len']
        self._num_classes = config['n_classes']

        if phase == 'train':
            self._input_file = config['train_file']
            self._num_epochs = None  # keep the iterator from terminating
            self._shuffle = config.get('shuffle', True)
            # self._shuffle_buffer = config.get('shuffle_buffer', 5000)
        elif phase == 'eval':
            self._input_file = config['dev_file']
            self._num_epochs = 1
            self._shuffle = False
            self._batch_size = config.get('pred_batch_size', self._batch_size)
        elif phase == 'pred':
            self._input_file = config['pred_file']
            self._num_epochs = 1
            self._shuffle = False
            self._batch_size = config.get('pred_batch_size', self._batch_size)

        self._phase = phase
        self._print_first_n = config.get('print_first_n', 0)

    @property
    def outputs_attr(self):
        rets = {"token_ids": [[-1, -1], 'int64'],
                "position_ids": [[-1, -1], 'int64'],
                "segment_ids": [[-1, -1], 'int64'],
                "task_ids": [[-1, -1], 'int64'],
                "input_mask": [[-1, -1, 1], 'float32'],
                "seq_lens": [[-1], 'int64']
                }
        # Labels are only produced in the training phase.
        if self._is_training:
            rets.update({"label_ids": [[-1, -1], 'int64']})
        return rets

    def load_data(self):
        self._data_generator = self._reader.data_generator(
            self._input_file, self._batch_size, self._num_epochs,
            dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)

    def iterator(self):

        def list_to_dict(x):
            names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
                     'label_ids', 'seq_lens']
            outputs = {n: i for n, i in zip(names, x)}
            if not self._is_training:
                del outputs['label_ids']
            return outputs

        for batch in self._data_generator():
            yield list_to_dict(batch)

    def get_epoch_outputs(self):
        return {'examples': self._reader.get_examples(self._phase),
                'features': self._reader.get_features(self._phase)}

    @property
    def num_examples(self):
        return self._reader.get_num_examples(phase=self._phase)
......@@ -615,19 +615,7 @@ class SequenceLabelReader(BaseReader):
                ret_labels.append(label)
                continue

            if label == "O" or label.startswith("I-"):
                ret_labels.extend([label] * len(sub_token))
            elif label.startswith("B-"):
                i_label = "I-" + label[2:]
                ret_labels.extend([label] + [i_label] * (len(sub_token) - 1))
            elif label.startswith("S-"):
                b_label = "B-" + label[2:]
                e_label = "E-" + label[2:]
                i_label = "I-" + label[2:]
                ret_labels.extend([b_label] + [i_label] * (len(sub_token) - 2) + [e_label])
            elif label.startswith("E-"):
                i_label = "I-" + label[2:]
                ret_labels.extend([i_label] * (len(sub_token) - 1) + [label])

        assert len(ret_tokens) == len(ret_labels)
        return ret_tokens, ret_labels
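
The label re-segmentation in the hunk above can be illustrated in isolation. The following is a minimal sketch, not the code from `reader4ernie.py`: it covers only the `O`/`B-`/`I-` cases, and `toy_tokenize` is a hypothetical stand-in for the real WordPiece tokenizer (it simply cuts words into fixed-width pieces).

```python
def reseg_token_label(tokens, labels, tokenize):
    """Expand word-level BIO labels so that every sub-token has a label."""
    assert len(tokens) == len(labels)
    ret_tokens, ret_labels = [], []
    for token, label in zip(tokens, labels):
        sub_tokens = tokenize(token)
        if not sub_tokens:
            continue
        ret_tokens.extend(sub_tokens)
        if label.startswith("B-"):
            # Only the first piece keeps B-; the remaining pieces continue the entity.
            ret_labels.extend([label] + ["I-" + label[2:]] * (len(sub_tokens) - 1))
        else:
            # "O" and "I-" labels are repeated for every piece.
            ret_labels.extend([label] * len(sub_tokens))
    assert len(ret_tokens) == len(ret_labels)
    return ret_tokens, ret_labels


def toy_tokenize(word, piece_len=4):
    """Hypothetical tokenizer for illustration: fixed-width chunks with a ## prefix."""
    pieces = [word[i:i + piece_len] for i in range(0, len(word), piece_len)]
    return [p if i == 0 else "##" + p for i, p in enumerate(pieces)]


tokens, labels = reseg_token_label(
    ["to", "westchester", "county"],
    ["O", "B-toloc.city_name", "I-toloc.city_name"],
    toy_tokenize,
)
print(list(zip(tokens, labels)))
```

With this toy tokenizer, `westchester` is split into three pieces: the first keeps `B-toloc.city_name` and the other two become `I-toloc.city_name`.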
......@@ -646,6 +634,9 @@ class SequenceLabelReader(BaseReader):
        position_ids = list(range(len(token_ids)))
        text_type_ids = [0] * len(token_ids)
        no_entity_id = len(self.label_map) - 1
        labels = [
            label if label in self.label_map else u"O" for label in labels
        ]
        label_ids = [no_entity_id] + [
            self.label_map[label] for label in labels
        ] + [no_entity_id]
......
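
The hunk above maps labels that are missing from the label map to `O` and pads both ends with `no_entity_id`. The standalone sketch below reproduces that logic. `toy_label_map` is made up for illustration, and it assumes, as `len(self.label_map) - 1` suggests, that `O` is the last entry of the label map.

```python
def labels_to_ids(labels, label_map):
    """Convert label strings to ids, mapping unknown labels to "O"."""
    # Assumption: the "no entity" label "O" has the largest id in the map.
    no_entity_id = len(label_map) - 1
    labels = [label if label in label_map else "O" for label in labels]
    # The two extra ids cover the special tokens at both ends of the sequence.
    return [no_entity_id] + [label_map[label] for label in labels] + [no_entity_id]


# Hypothetical three-class label map; the demo config declares 130 classes.
toy_label_map = {"B-fromloc.city_name": 0, "B-toloc.city_name": 1, "O": 2}
ids = labels_to_ids(["O", "B-fromloc.city_name", "B-unseen_label"], toy_label_map)
print(ids)
```

Without the fallback, a label that appears in the data but not in `label_map.json` would raise a `KeyError` when it is looked up.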
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import layers
from paddlepalm.interface import task_paradigm
import numpy as np
import os
import math


class TaskParadigm(task_paradigm):
    '''
    Sequence labeling
    '''
    def __init__(self, config, phase, backbone_config=None):
        self._is_training = phase == 'train'
        self._hidden_size = backbone_config['hidden_size']
        self.num_classes = config['n_classes']
        self.learning_rate = config['learning_rate']

        if 'initializer_range' in config:
            # Wrap the configured range in an initializer object; ParamAttr expects one, not a bare float.
            self._param_initializer = fluid.initializer.TruncatedNormal(
                scale=config['initializer_range'])
        else:
            self._param_initializer = fluid.initializer.TruncatedNormal(
                scale=backbone_config.get('initializer_range', 0.02))
        if 'dropout_prob' in config:
            self._dropout_prob = config['dropout_prob']
        else:
            self._dropout_prob = backbone_config.get('hidden_dropout_prob', 0.0)
        self._pred_output_path = config.get('pred_output_path', None)
        self._use_crf = config.get('use_crf', False)
        self._preds = []

    @property
    def inputs_attrs(self):
        reader = {}
        bb = {"encoder_outputs": [[-1, -1, -1], 'float32']}
        # Sequence lengths are only needed by the CRF layers.
        if self._use_crf:
            reader["seq_lens"] = [[-1], 'int64']
        if self._is_training:
            reader["label_ids"] = [[-1, -1], 'int64']
        return {'reader': reader, 'backbone': bb}

    @property
    def outputs_attrs(self):
        if self._is_training:
            return {'loss': [[1], 'float32']}
        else:
            if self._use_crf:
                # crf_decoding returns the decoded label ids, which are integers.
                return {'crf_decode': [[-1, -1], 'int64']}
            else:
                return {'logits': [[-1, -1, self.num_classes], 'float32']}

    def build(self, inputs, scope_name=''):
        token_emb = inputs['backbone']['encoder_outputs']
        # seq_lens is only declared in inputs_attrs when the CRF is enabled.
        seq_lens = inputs['reader']['seq_lens'] if self._use_crf else None
        if self._is_training:
            label_ids = inputs['reader']['label_ids']

        # Per-token projection from the encoder output to the label space.
        logits = fluid.layers.fc(
            size=self.num_classes,
            input=token_emb,
            param_attr=fluid.ParamAttr(
                initializer=self._param_initializer,
                regularizer=fluid.regularizer.L2DecayRegularizer(
                    regularization_coeff=1e-4)),
            bias_attr=fluid.ParamAttr(
                name=scope_name+"cls_out_b", initializer=fluid.initializer.Constant(0.)),
            num_flatten_dims=2)

        # use_crf
        if self._use_crf:
            if self._is_training:
                crf_cost = fluid.layers.linear_chain_crf(
                    input=logits,
                    label=label_ids,
                    param_attr=fluid.ParamAttr(
                        name=scope_name+'crfw', learning_rate=self.learning_rate),
                    length=seq_lens)

                avg_cost = fluid.layers.mean(x=crf_cost)
                crf_decode = fluid.layers.crf_decoding(
                    input=logits,
                    param_attr=fluid.ParamAttr(name=scope_name+'crfw'),
                    length=seq_lens)
                return {"loss": avg_cost}
            else:
                # At prediction time the transition parameter must exist before decoding.
                size = self.num_classes
                fluid.layers.create_parameter(
                    shape=[size+2, size], dtype=logits.dtype, name=scope_name+'crfw')
                crf_decode = fluid.layers.crf_decoding(
                    input=logits, param_attr=fluid.ParamAttr(name=scope_name+'crfw'),
                    length=seq_lens)
                return {"crf_decode": crf_decode}

        else:
            if self._is_training:
                probs = fluid.layers.softmax(logits)
                ce_loss = fluid.layers.cross_entropy(
                    input=probs, label=label_ids)
                avg_cost = fluid.layers.mean(x=ce_loss)
                return {"loss": avg_cost}
            else:
                return {"logits": logits}

    def postprocess(self, rt_outputs):
        if not self._is_training:
            if self._use_crf:
                preds = rt_outputs['crf_decode']
            else:
                logits = rt_outputs['logits']
                preds = np.argmax(logits, -1)

            self._preds.extend(preds.tolist())

    def epoch_postprocess(self, post_inputs):
        # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
        if not self._is_training:
            if self._pred_output_path is None:
                raise ValueError('argument pred_output_path not found in config. Please add it into config dict/file.')
            with open(os.path.join(self._pred_output_path, 'predictions.json'), 'w') as writer:
                for p in self._preds:
                    writer.write(str(p)+'\n')
            print('Predictions saved at '+os.path.join(self._pred_output_path, 'predictions.json'))
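
`postprocess` stores raw label ids, so reading the predictions back as slot names takes one more step. The sketch below shows one way to do it in plain Python. It is not part of the commit, and it rests on several assumptions: `decode_predictions` and `id_to_label` are hypothetical names, each sequence length is taken to count the two special tokens at the ends, and positions beyond the sequence length are treated as padding.

```python
def decode_predictions(logits, seq_lens, id_to_label):
    """Turn per-token scores into label names, ignoring padding and special tokens."""
    results = []
    for token_scores, seq_len in zip(logits, seq_lens):
        # Arg-max over the class dimension, for the non-padded positions only.
        ids = [max(range(len(scores)), key=scores.__getitem__)
               for scores in token_scores[:seq_len]]
        # Assumption: the first and last positions are the special tokens.
        results.append([id_to_label[i] for i in ids[1:-1]])
    return results


# Hypothetical three-class example: one sequence with 4 real positions and 1 padding position.
id_to_label = {0: "B-fromloc.city_name", 1: "B-toloc.city_name", 2: "O"}
logits = [[[0.1, 0.2, 0.9],
           [0.8, 0.1, 0.1],
           [0.2, 0.7, 0.1],
           [0.1, 0.1, 0.8],
           [0.5, 0.4, 0.1]]]
decoded = decode_predictions(logits, seq_lens=[4], id_to_label=id_to_label)
print(decoded)
```

When `use_crf` is enabled the model already returns label ids rather than scores, so only the trimming and the id-to-name lookup would apply.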