未验证 提交 4748684e 编写于 作者: 0 0YuanZhang0 提交者: GitHub

update d-net (#3591)

* Remove KD scripts

* Add ERNIE2.0 service

* Update server

* Update MTL

* Update README.md

* Update README.md for MTL

* Update README.md
上级 2bbe37ad
# D-NET
## Introduction
D-NET is the system Baidu submitted for MRQA (Machine Reading for Question Answering) 2019 Shared Task that focused on generalization of machine reading comprehension (MRC) models. Our system is built on a framework of pre-training and fine-tuning. The techniques of pre-trained language models, multi-task learning and knowledge distillation are employed to improve the generalization of MRC models and the experimental results show the effectiveness of these strategies. Our system is ranked at top 1 of all the participants in terms of averaged F1 score. Additionally, we won the first place for 10 of the 12 test sets and the second place for the other two in terms of F1 scores.
D-NET is a simple pre-training and fine-tuning framework that Baidu used for the MRQA (Machine Reading for Question Answering) 2019 Shared Task, which focused on the generalization of machine reading comprehension (MRC) models. Our system is ranked at top 1 of all the participants in terms of the averaged F1 score. Additionally, we won the first place for 10 of the 12 test sets and the second place for the other two in terms of F1 scores.
In this repository, we release the related code, data and model parametrs which have been used in the D-NET framework.
## Framework
An overview of the D-NET framework is shown in the figure below. To improve the generalization capability of a MRC system, we use mainly two techniques, i.e. **multi-task learning (MTL)** and **ensemble of multiple pre-trained models**.
<p align="center">
<img src="./images/D-NET_framework.png" width="500">
</p>
### D-NET includes 3 parts:
#### multi_task_learning
We use PaddlePaddle PALM multi-task learning library [Link](https://github.com/PaddlePaddle/PALM) to train single model for MRQA 2019 Shared Task.
#### knowledge_distillation
Model ensemble can improve the generalization of MRC models, we leverage the technique of distillation to ensemble multiple models into a single model, and no loss of accuracy, distillation solves the problem of slow inference process and reduce the use of a huge amount of resource.
#### server
MRQA2019 submission environment with baidu bert inference model and xlnet inference model.
#### Multi-task learning
In addition to the MRC task, we further introduce several auxiliary tasks in the fine-tuning stage to learn more general language representations. Specifically, we have the following auxiliary tasks:
- Unsupervised Task: masked Language Model
- Supervised Tasks:
- natural language inference
- paragraph ranking
We use the [PALM](https://github.com/PaddlePaddle/PALM) multi-task learning library based on [PaddlePaddle](https://www.paddlepaddle.org.cn/) in our experiments, which makes the implementation of new tasks and pre-trained models much easier than from scratch. To train the MRQA data sets with MTL, please refer to the instructions [here](multi_task_learning) (under `multi_task_learning/`).
#### Ensemble of multiple pre-trained models
In our experiments, we found that the ensemble system based on different pre-trained models shows better generalization capability than the system that based on the single ones. In this repository, we provide the parameters of 3 models that are fine-tuned on the MRQA in-domain data, based on ERNIE2.0, XL-NET and BERT, respectively. The ensemble of these models are implemented as servers. Please refer the instructions [here](server) (under `server/`) for more detials.
## Directory structure
```
├── multi_task_learning/ # scripts for multi-task learning
│ ├── configs/ # PALM config files
│ ├── scripts/ # auxiliary scripts
│ ├── wget_pretrained_model.sh # download pretrained model
│ ├── wget_data.sh # download data for MTL
│ ├── run_build_palm.sh # MLT preparation
│ ├── run_evaluation.sh # evaluation
│ ├── run_multi_task.sh # start MTL training
├── server/ # scripts for the ensemble of multiple pretrained models
│ ├── ernie_server/ # ERNIE mdoel server
│ ├── xlnet_server/ # XL-NET mdoel server
│ ├── bert_server/ # BERT mdoel server
│ ├── main_server.py # main server scripts for ensemble
│ ├── client/ # client scripts which read examples and make requests
│ ├── wget_server_inference_model.sh # script for downlowding model parameters
│ ├── start.sh # script for launching all the servers
```
## Copyright and License
Copyright 2019 Baidu.com, Inc. All Rights Reserved Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
# knowledge_distillation
## 1、Introduction
Model ensemble can improve the generalization of MRC models. However, such approach is not efficient. Because the inference of an ensemble model is slow and a huge amount of resources are required. We leverage the technique of distillation to ensemble multiple models into a single model solves the problem of slow inference process.
## 2、Quick Start
### Environment
- Python >= 2.7
- cuda >= 9.0
- cudnn >= 7.0
- PaddlePaddle >= 1.6 Please refer to Installation Guide [Installation Guide](http://www.paddlepaddle.org/#quick-start)
### Data and Models Preparation
User can get the data and trained knowledge_distillation models directly we provided:
```
bash wget_models_and_data.sh
```
user can get data and models directorys:
data:
./data/input/mlm_data: mask language model dataset.
./data/input/mrqa_distill_data: mrqa dataset, it includes two parts: mrqa_distill.json(json data we calculate from teacher models), mrqa-combined.all_dev.raw.json(merge all mrqa dev dataset).
./data/input/mrqa_evaluation_dataset: mrqa evaluation data(in_domain data and out_of_domain json data).
models:
./data/pretrain_model/squad2_model: pretrain model(google squad2.0 model as pretrain model [Model Link](https://worksheets.codalab.org/worksheets/0x3852e60a51d2444680606556d404c657)).
./data/saved_models/knowledge_distillation_model: baidu trained knowledge distillation model.
## 3、Train and Predict
Train and predict knowledge distillation model
```
bash run_distill.sh
```
## 4、Evaluation
To evaluate the result, run
```
sh run_evaluation.sh
```
Note that we use the evaluation script for SQuAD 1.1 here, which is equivalent to the official one.
## 5、Performance
| | dev in_domain(Macro-F1)| dev out_of_domain(Macro-F1) |
| ------------- | ------------ | ------------ |
| Official baseline | 77.87 | 58.67 |
| KD(4 teacher model-> student)| 83.67 | 67.34 |
KD: knowledge distillation model(ensemble 4 teacher models to student model)
## Copyright and License
Copyright 2019 Baidu.com, Inc. All Rights Reserved Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and
limitations under the License.
input data dir: mrqa distillation dataset and mask language model dataset
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import json
import numpy as np
import paddle.fluid as fluid
from model.transformer_encoder import encoder as encoder
from model.transformer_encoder import pre_process_layer as pre_process_layer
class BertModel(object):
def __init__(self,
src_ids,
position_ids,
sentence_ids,
input_mask,
config,
weight_sharing=True,
use_fp16=False,
model_name = ''):
self._emb_size = config["hidden_size"]
self._n_layer = config["num_hidden_layers"]
self._n_head = config["num_attention_heads"]
self._voc_size = config["vocab_size"]
self._max_position_seq_len = config["max_position_embeddings"]
self._sent_types = config["type_vocab_size"]
self._hidden_act = config["hidden_act"]
self._prepostprocess_dropout = config["hidden_dropout_prob"]
self._attention_dropout = config["attention_probs_dropout_prob"]
self._weight_sharing = weight_sharing
self.model_name = model_name
self._word_emb_name = self.model_name + "word_embedding"
self._pos_emb_name = self.model_name + "pos_embedding"
self._sent_emb_name = self.model_name + "sent_embedding"
self._dtype = "float16" if use_fp16 else "float32"
# Initialize all weigths by truncated normal initializer, and all biases
# will be initialized by constant zero by default.
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=config["initializer_range"])
self._build_model(src_ids, position_ids, sentence_ids, input_mask, config)
def _build_model(self, src_ids, position_ids, sentence_ids, input_mask, config):
# padding id in vocabulary must be set to 0
emb_out = fluid.layers.embedding(
input=src_ids,
size=[self._voc_size, self._emb_size],
dtype=self._dtype,
param_attr=fluid.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
self.emb_out =emb_out
position_emb_out = fluid.layers.embedding(
input=position_ids,
size=[self._max_position_seq_len, self._emb_size],
dtype=self._dtype,
param_attr=fluid.ParamAttr(
name=self._pos_emb_name, initializer=self._param_initializer))
self.position_emb_out = position_emb_out
sent_emb_out = fluid.layers.embedding(
sentence_ids,
size=[self._sent_types, self._emb_size],
dtype=self._dtype,
param_attr=fluid.ParamAttr(
name=self._sent_emb_name, initializer=self._param_initializer))
self.sent_emb_out = sent_emb_out
emb_out = emb_out + position_emb_out
emb_out = emb_out + sent_emb_out
emb_out = pre_process_layer(
emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')
if self._dtype == "float16":
input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype)
self_attn_mask = fluid.layers.matmul(
x = input_mask, y = input_mask, transpose_y = True)
self_attn_mask = fluid.layers.scale(
x = self_attn_mask, scale = 10000.0, bias = -1.0, bias_after_scale = False)
n_head_self_attn_mask = fluid.layers.stack(
x=[self_attn_mask] * self._n_head, axis=1)
n_head_self_attn_mask.stop_gradient = True
self._enc_out = encoder(
enc_input = emb_out,
attn_bias = n_head_self_attn_mask,
n_layer = self._n_layer,
n_head = self._n_head,
d_key = self._emb_size // self._n_head,
d_value = self._emb_size // self._n_head,
d_model = self._emb_size,
d_inner_hid = self._emb_size * 4,
prepostprocess_dropout = self._prepostprocess_dropout,
attention_dropout = self._attention_dropout,
relu_dropout = 0,
hidden_act = self._hidden_act,
preprocess_cmd = "",
postprocess_cmd = "dan",
param_initializer = self._param_initializer,
name = self.model_name + 'encoder')
def get_sequence_output(self):
return self._enc_out
def get_pooled_output(self):
"""Get the first feature of each sequence for classification"""
next_sent_feat = fluid.layers.slice(
input = self._enc_out, axes = [1], starts = [0], ends = [1])
next_sent_feat = fluid.layers.fc(
input = next_sent_feat,
size = self._emb_size,
act = "tanh",
param_attr = fluid.ParamAttr(
name = self.model_name + "pooled_fc.w_0",
initializer = self._param_initializer),
bias_attr = "pooled_fc.b_0")
return next_sent_feat
def get_pretraining_output(self, mask_label, mask_pos, labels):
"""Get the loss & accuracy for pretraining"""
mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
# extract the first token feature in each sentence
next_sent_feat = self.get_pooled_output()
reshaped_emb_out = fluid.layers.reshape(
x=self._enc_out, shape = [-1, self._emb_size])
# extract masked tokens' feature
mask_feat = fluid.layers.gather(input = reshaped_emb_out, index = mask_pos)
# transform: fc
mask_trans_feat = fluid.layers.fc(
input = mask_feat,
size = self._emb_size,
act = self._hidden_act,
param_attr = fluid.ParamAttr(
name = self.model_name + 'mask_lm_trans_fc.w_0',
initializer = self._param_initializer),
bias_attr = fluid.ParamAttr(name = self.model_name + 'mask_lm_trans_fc.b_0'))
# transform: layer norm
mask_trans_feat = pre_process_layer(
mask_trans_feat, 'n', name = self.model_name + 'mask_lm_trans')
mask_lm_out_bias_attr = fluid.ParamAttr(
name = self.model_name + "mask_lm_out_fc.b_0",
initializer = fluid.initializer.Constant(value = 0.0))
if self._weight_sharing:
fc_out = fluid.layers.matmul(
x = mask_trans_feat,
y = fluid.default_main_program().global_block().var(
self._word_emb_name),
transpose_y = True)
fc_out += fluid.layers.create_parameter(
shape = [self._voc_size],
dtype = self._dtype,
attr = mask_lm_out_bias_attr,
is_bias = True)
else:
fc_out = fluid.layers.fc(input = mask_trans_feat,
size = self._voc_size,
param_attr = fluid.ParamAttr(
name = self.model_name + "mask_lm_out_fc.w_0",
initializer = self._param_initializer),
bias_attr = mask_lm_out_bias_attr)
mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
logits = fc_out, label = mask_label)
mean_mask_lm_loss = fluid.layers.mean(mask_lm_loss)
next_sent_fc_out = fluid.layers.fc(
input = next_sent_feat,
size = 2,
param_attr = fluid.ParamAttr(
name = self.model_name + "next_sent_fc.w_0",
initializer = self._param_initializer),
bias_attr = self.model_name + "next_sent_fc.b_0")
next_sent_loss, next_sent_softmax = fluid.layers.softmax_with_cross_entropy(
logits = next_sent_fc_out, label = labels, return_softmax = True)
next_sent_acc = fluid.layers.accuracy(
input = next_sent_softmax, label = labels)
mean_next_sent_loss = fluid.layers.mean(next_sent_loss)
loss = mean_next_sent_loss + mean_mask_lm_loss
return next_sent_acc, mean_mask_lm_loss, loss
if __name__ == "__main__":
print("hello wolrd!")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import collections
import numpy as np
import multiprocessing
from copy import deepcopy as copy
import paddle
import paddle.fluid as fluid
from model.bert import BertModel
from utils.configure import JsonConfig
class ModelBERT(object):
def __init__(
self,
conf,
name = "",
is_training = False,
base_model = None):
# the name of this task
# name is used for identifying parameters
self.name = name
# deep copy the configure of model
self.conf = copy(conf)
self.is_training = is_training
## the overall loss of this task
self.loss = None
## outputs may be useful for the other models
self.outputs = {}
## the prediction of this task
self.predict = []
def create_model(self,
args,
reader_input,
base_model = None):
"""
given the base model, reader_input
return the create fn for create this model
"""
def _create_model():
src_ids, pos_ids, sent_ids, input_mask = reader_input
bert_conf = JsonConfig(self.conf["bert_conf_file"])
self.bert = BertModel(
src_ids = src_ids,
position_ids = pos_ids,
sentence_ids = sent_ids,
input_mask = input_mask,
config = bert_conf,
use_fp16 = args.use_fp16,
model_name = self.name)
self.loss = None
self.outputs = {
"sequence_output":self.bert.get_sequence_output(),
}
return _create_model
def get_output(self, name):
return self.outputs[name]
def get_outputs(self):
return self.outputs
def get_predict(self):
return self.predict
if __name__ == "__main__":
bert_model = ModelBERT(conf = {"json_conf_path" : "./data/pretrained_models/squad2_model/bert_config.json"})
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from model.transformer_encoder import pre_process_layer
from utils.configure import JsonConfig
def compute_loss(output_tensors, args=None):
"""Compute loss for mlm model"""
fc_out = output_tensors['mlm_out']
mask_label = output_tensors['mask_label']
mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
logits=fc_out, label=mask_label)
mean_mask_lm_loss = fluid.layers.mean(mask_lm_loss)
return mean_mask_lm_loss
def create_model(reader_input, base_model=None, is_training=True, args=None):
"""
given the base model, reader_input
return the output tensors
"""
mask_label, mask_pos = reader_input
config = JsonConfig(args.bert_config_path)
_emb_size = config['hidden_size']
_voc_size = config['vocab_size']
_hidden_act = config['hidden_act']
_word_emb_name = "word_embedding"
_dtype = "float16" if args.use_fp16 else "float32"
_param_initializer = fluid.initializer.TruncatedNormal(
scale=config['initializer_range'])
mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
enc_out = base_model.get_output("sequence_output")
# extract the first token feature in each sentence
reshaped_emb_out = fluid.layers.reshape(
x=enc_out, shape=[-1, _emb_size])
# extract masked tokens' feature
mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
num_seqs = fluid.layers.fill_constant(shape=[1], value=512, dtype='int64')
# transform: fc
mask_trans_feat = fluid.layers.fc(
input=mask_feat,
size=_emb_size,
act=_hidden_act,
param_attr=fluid.ParamAttr(
name='mask_lm_trans_fc.w_0',
initializer=_param_initializer),
bias_attr=fluid.ParamAttr(name='mask_lm_trans_fc.b_0'))
# transform: layer norm
mask_trans_feat = pre_process_layer(
mask_trans_feat, 'n', name='mask_lm_trans')
mask_lm_out_bias_attr = fluid.ParamAttr(
name="mask_lm_out_fc.b_0",
initializer=fluid.initializer.Constant(value=0.0))
fc_out = fluid.layers.matmul(
x=mask_trans_feat,
y=fluid.default_main_program().global_block().var(
_word_emb_name),
transpose_y=True)
fc_out += fluid.layers.create_parameter(
shape=[_voc_size],
dtype=_dtype,
attr=mask_lm_out_bias_attr,
is_bias=True)
output_tensors = {}
output_tensors['num_seqs'] = num_seqs
output_tensors['mlm_out'] = fc_out
output_tensors['mask_label'] = mask_label
return output_tensors
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
def compute_loss(output_tensors, args=None):
"""Compute loss for mrc model"""
def _compute_single_loss(logits, positions):
"""Compute start/end loss for mrc model"""
loss = fluid.layers.softmax_with_cross_entropy(
logits=logits, label=positions)
loss = fluid.layers.mean(x=loss)
return loss
start_logits = output_tensors['start_logits']
end_logits = output_tensors['end_logits']
start_positions = output_tensors['start_positions']
end_positions = output_tensors['end_positions']
start_loss = _compute_single_loss(start_logits, start_positions)
end_loss = _compute_single_loss(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2.0
if args.use_fp16 and args.loss_scaling > 1.0:
total_loss = total_loss * args.loss_scaling
return total_loss
def compute_distill_loss(output_tensors, args=None):
"""Compute loss for mrc model"""
start_logits = output_tensors['start_logits']
end_logits = output_tensors['end_logits']
start_logits_truth = output_tensors['start_logits_truth']
end_logits_truth = output_tensors['end_logits_truth']
input_mask = output_tensors['input_mask']
def _mask(logits, input_mask, nan=1e5):
input_mask = fluid.layers.reshape(input_mask, [-1, 512])
logits = logits - (1.0 - input_mask) * nan
return logits
start_logits = _mask(start_logits, input_mask)
end_logits = _mask(end_logits, input_mask)
start_logits_truth = _mask(start_logits_truth, input_mask)
end_logits_truth = _mask(end_logits_truth, input_mask)
start_logits_truth = fluid.layers.reshape(start_logits_truth, [-1, 512])
end_logits_truth = fluid.layers.reshape(end_logits_truth, [-1, 512])
T = 1.0
start_logits_softmax = fluid.layers.softmax(input=start_logits/T)
end_logits_softmax = fluid.layers.softmax(input=end_logits/T)
start_logits_truth_softmax = fluid.layers.softmax(input=start_logits_truth/T)
end_logits_truth_softmax = fluid.layers.softmax(input=end_logits_truth/T)
start_logits_truth_softmax.stop_gradient = True
end_logits_truth_softmax.stop_gradient = True
start_loss = fluid.layers.cross_entropy(start_logits_softmax, start_logits_truth_softmax, soft_label=True)
end_loss = fluid.layers.cross_entropy(end_logits_softmax, end_logits_truth_softmax, soft_label=True)
start_loss = fluid.layers.mean(x=start_loss)
end_loss = fluid.layers.mean(x=end_loss)
total_loss = (start_loss + end_loss) / 2.0
return total_loss
def create_model(reader_input, base_model=None, is_training=True, args=None):
"""
given the base model, reader_input
return the output tensors
"""
if is_training:
if args.do_distill:
src_ids, pos_ids, sent_ids, input_mask, \
start_logits_truth, end_logits_truth, start_positions, end_positions = reader_input
else:
src_ids, pos_ids, sent_ids, input_mask, \
start_positions, end_positions = reader_input
else:
src_ids, pos_ids, sent_ids, input_mask, unique_id = reader_input
enc_out = base_model.get_output("sequence_output")
logits = fluid.layers.fc(
input=enc_out,
size=2,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name="cls_squad_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_squad_out_b", initializer=fluid.initializer.Constant(0.)))
logits = fluid.layers.transpose(x=logits, perm=[2, 0, 1])
start_logits, end_logits = fluid.layers.unstack(x=logits, axis=0)
batch_ones = fluid.layers.fill_constant_batch_size_like(
input=start_logits, dtype='int64', shape=[1], value=1)
num_seqs = fluid.layers.reduce_sum(input=batch_ones)
output_tensors = {}
output_tensors['start_logits'] = start_logits
output_tensors['end_logits'] = end_logits
output_tensors['num_seqs'] = num_seqs
output_tensors['input_mask'] = input_mask
if is_training:
output_tensors['start_positions'] = start_positions
output_tensors['end_positions'] = end_positions
if args.do_distill:
output_tensors['start_logits_truth'] = start_logits_truth
output_tensors['end_logits_truth'] = end_logits_truth
else:
output_tensors['unique_id'] = unique_id
output_tensors['start_logits'] = start_logits
output_tensors['end_logits'] = end_logits
return output_tensors
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import random
import numpy as np
import paddle
import paddle.fluid as fluid
from utils.placeholder import Placeholder
def repeat(reader):
"""Repeat a generator forever"""
generator = reader()
while True:
try:
yield next(generator)
except StopIteration:
generator = reader()
yield next(generator)
def create_joint_generator(input_shape, generators, do_distill, is_multi_task=True):
def empty_output(input_shape, batch_size=1):
results = []
for i in range(len(input_shape)):
if input_shape[i][1] == 'int32':
dtype = np.int32
if input_shape[i][1] == 'int64':
dtype = np.int64
if input_shape[i][1] == 'float32':
dtype = np.float32
if input_shape[i][1] == 'float64':
dtype = np.float64
shape = input_shape[i][0]
shape[0] = batch_size
pad_tensor = np.zeros(shape=shape, dtype=dtype)
results.append(pad_tensor)
return results
def wrapper():
"""wrapper data"""
generators_inst = [repeat(gen[0]) for gen in generators]
generators_ratio = [gen[1] for gen in generators]
weights = [ratio/sum(generators_ratio) for ratio in generators_ratio]
run_task_id = range(len(generators))
while True:
idx = np.random.choice(run_task_id, p=weights)
gen_results = next(generators_inst[idx])
if not gen_results:
break
batch_size = gen_results[0].shape[0]
results = empty_output(input_shape, batch_size)
task_id_tensor = np.array([[idx]]).astype("int64")
results[0] = task_id_tensor
for i in range(4):
results[i+1] = gen_results[i]
if do_distill:
if idx == 0:
results[5] = gen_results[4]
results[6] = gen_results[5]
results[7] = gen_results[6]
results[8] = gen_results[7]
else:
results[9] = gen_results[4]
results[10] = gen_results[5]
else:
if idx == 0:
# mrc batch
results[5] = gen_results[4]
results[6] = gen_results[5]
elif idx == 1:
# mlm batch
results[7] = gen_results[4]
results[8] = gen_results[5]
# idx stands for the task index
yield results
return wrapper
def create_reader(reader_name, input_shape, is_multi_task, do_distill, *gens):
"""
build reader for multi_task_learning
"""
placeholder = Placeholder(input_shape)
pyreader, model_inputs = placeholder.build(capacity=100, reader_name=reader_name)
joint_generator = create_joint_generator(input_shape, gens[0], do_distill, is_multi_task=is_multi_task)
return joint_generator, pyreader, model_inputs
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import os
import re
import six
import gzip
import types
import logging
import numpy as np
import collections
import paddle
import paddle.fluid as fluid
from utils import tokenization
from utils.batching import prepare_batch_data
class DataReader(object):
def __init__(self,
data_dir,
vocab_path,
batch_size=4096,
in_tokens=True,
max_seq_len=512,
shuffle_files=True,
epoch=100,
voc_size=0,
is_test=False,
generate_neg_sample=False):
self.vocab = self.load_vocab(vocab_path)
self.data_dir = data_dir
self.batch_size = batch_size
self.in_tokens = in_tokens
self.shuffle_files = shuffle_files
self.epoch = epoch
self.current_epoch = 0
self.current_file_index = 0
self.total_file = 0
self.current_file = None
self.voc_size = voc_size
self.max_seq_len = max_seq_len
self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.mask_id = self.vocab["[MASK]"]
self.is_test = is_test
self.generate_neg_sample = generate_neg_sample
if self.in_tokens:
assert self.batch_size >= self.max_seq_len, "The number of " \
"tokens in batch should not be smaller than max seq length."
if self.is_test:
self.epoch = 1
self.shuffle_files = False
def get_progress(self):
"""return current progress of traning data
"""
return self.current_epoch, self.current_file_index, self.total_file, self.current_file
def parse_line(self, line, max_seq_len=512):
""" parse one line to token_ids, sentence_ids, pos_ids, label
"""
line = line.strip().decode().split(";")
assert len(line) == 4, "One sample must have 4 fields!"
(token_ids, sent_ids, pos_ids, label) = line
token_ids = [int(token) for token in token_ids.split(" ")]
sent_ids = [int(token) for token in sent_ids.split(" ")]
pos_ids = [int(token) for token in pos_ids.split(" ")]
assert len(token_ids) == len(sent_ids) == len(
pos_ids
), "[Must be true]len(token_ids) == len(sent_ids) == len(pos_ids)"
label = int(label)
if len(token_ids) > max_seq_len:
return None
return [token_ids, sent_ids, pos_ids, label]
def read_file(self, file):
assert file.endswith('.gz'), "[ERROR] %s is not a gzip file" % file
file_path = self.data_dir + "/" + file
with gzip.open(file_path, "rb") as f:
for line in f:
parsed_line = self.parse_line(
line, max_seq_len=self.max_seq_len)
if parsed_line is None:
continue
yield parsed_line
def convert_to_unicode(self, text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(self, vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = open(vocab_file)
for num, line in enumerate(fin):
items = self.convert_to_unicode(line.strip()).split("\t")
if len(items) > 2:
break
token = items[0]
index = items[1] if len(items) == 2 else num
token = token.strip()
vocab[token] = int(index)
return vocab
def random_pair_neg_samples(self, pos_samples):
""" randomly generate negtive samples using pos_samples
Args:
pos_samples: list of positive samples
Returns:
neg_samples: list of negtive samples
"""
np.random.shuffle(pos_samples)
num_sample = len(pos_samples)
neg_samples = []
miss_num = 0
for i in range(num_sample):
pair_index = (i + 1) % num_sample
origin_src_ids = pos_samples[i][0]
origin_sep_index = origin_src_ids.index(2)
pair_src_ids = pos_samples[pair_index][0]
pair_sep_index = pair_src_ids.index(2)
src_ids = origin_src_ids[:origin_sep_index + 1] + pair_src_ids[
pair_sep_index + 1:]
if len(src_ids) >= self.max_seq_len:
miss_num += 1
continue
sent_ids = [0] * len(origin_src_ids[:origin_sep_index + 1]) + [
1
] * len(pair_src_ids[pair_sep_index + 1:])
pos_ids = list(range(len(src_ids)))
neg_sample = [src_ids, sent_ids, pos_ids, 0]
assert len(src_ids) == len(sent_ids) == len(
pos_ids
), "[ERROR]len(src_id) == lne(sent_id) == len(pos_id) must be True"
neg_samples.append(neg_sample)
return neg_samples, miss_num
def mixin_negtive_samples(self, pos_sample_generator, buffer=1000):
""" 1. generate negtive samples by randomly group sentence_1 and sentence_2 of positive samples
2. combine negtive samples and positive samples
Args:
pos_sample_generator: a generator producing a parsed positive sample, which is a list: [token_ids, sent_ids, pos_ids, 1]
Returns:
sample: one sample from shuffled positive samples and negtive samples
"""
pos_samples = []
num_total_miss = 0
pos_sample_num = 0
try:
while True:
while len(pos_samples) < buffer:
pos_sample = next(pos_sample_generator)
label = pos_sample[3]
assert label == 1, "positive sample's label must be 1"
pos_samples.append(pos_sample)
pos_sample_num += 1
neg_samples, miss_num = self.random_pair_neg_samples(
pos_samples)
num_total_miss += miss_num
samples = pos_samples + neg_samples
pos_samples = []
np.random.shuffle(samples)
for sample in samples:
yield sample
except StopIteration:
print("stopiteration: reach end of file")
if len(pos_samples) == 1:
yield pos_samples[0]
elif len(pos_samples) == 0:
yield None
else:
neg_samples, miss_num = self.random_pair_neg_samples(
pos_samples)
num_total_miss += miss_num
samples = pos_samples + neg_samples
pos_samples = []
np.random.shuffle(samples)
for sample in samples:
yield sample
print("miss_num:%d\tideal_total_sample_num:%d\tmiss_rate:%f" %
(num_total_miss, pos_sample_num * 2,
num_total_miss / (pos_sample_num * 2)))
def data_generator(self):
"""
data_generator
"""
files = os.listdir(self.data_dir)
self.total_file = len(files)
assert self.total_file > 0, "[Error] data_dir is empty"
def wrapper():
def reader():
for epoch in range(self.epoch):
self.current_epoch = epoch + 1
if self.shuffle_files:
np.random.shuffle(files)
for index, file in enumerate(files):
self.current_file_index = index + 1
self.current_file = file
sample_generator = self.read_file(file)
if not self.is_test and self.generate_neg_sample:
sample_generator = self.mixin_negtive_samples(
sample_generator)
for sample in sample_generator:
if sample is None:
continue
yield sample
def batch_reader(reader, batch_size, in_tokens):
batch, total_token_num, max_len = [], 0, 0
for parsed_line in reader():
token_ids, sent_ids, pos_ids, label = parsed_line
max_len = max(max_len, len(token_ids))
if in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append:
batch.append(parsed_line)
total_token_num += len(token_ids)
else:
yield batch, total_token_num
batch, total_token_num, max_len = [parsed_line], len(
token_ids), len(token_ids)
if len(batch) > 0:
yield batch, total_token_num
for batch_data, total_token_num in batch_reader(
reader, self.batch_size, self.in_tokens):
yield prepare_batch_data(
batch_data,
total_token_num,
voc_size=self.voc_size,
pad_id=self.pad_id,
cls_id=self.cls_id,
sep_id=self.sep_id,
mask_id=self.mask_id,
max_len=self.max_seq_len,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
return wrapper
if __name__ == "__main__":
pass
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run MRQA"""
import six
import math
import json
import random
import collections
import numpy as np
from utils import tokenization
from utils.batching import prepare_batch_data
class DataProcessorDistill(object):
def __init__(self):
self.num_examples = -1
self.current_train_example = -1
self.current_train_epoch = -1
def get_features(self, data_path):
with open(data_path, 'r') as fr:
for line in fr:
yield line.strip()
def data_generator(self,
data_file,
batch_size,
max_len,
in_tokens,
dev_count,
epochs,
shuffle):
self.num_examples = len([ "" for line in open(data_file,"r")])
def batch_reader(data_file, in_tokens, batch_size):
batch = []
index = 0
for feature in self.get_features(data_file):
to_append = len(batch) < batch_size
if to_append:
batch.append(feature)
else:
yield batch
batch = []
if len(batch) > 0:
yield batch
def wrapper():
for epoch in range(epochs):
all_batches = []
for batch_data in batch_reader(data_file, in_tokens, batch_size):
batch_data_segment = []
for feature in batch_data:
data = json.loads(feature.strip())
example_index = data['example_index']
unique_id = data['unique_id']
input_ids = data['input_ids']
position_ids = data['position_ids']
input_mask = data['input_mask']
segment_ids = data['segment_ids']
start_position = data['start_position']
end_position = data['end_position']
start_logits = data['start_logits']
end_logits = data['end_logits']
instance = [input_ids, position_ids, segment_ids, input_mask, start_logits, end_logits, start_position, end_position]
batch_data_segment.append(instance)
batch_data = batch_data_segment
src_ids = [inst[0] for inst in batch_data]
pos_ids = [inst[1] for inst in batch_data]
sent_ids = [inst[2] for inst in batch_data]
input_mask = [inst[3] for inst in batch_data]
start_logits = [inst[4] for inst in batch_data]
end_logits = [inst[5] for inst in batch_data]
src_ids = np.array(src_ids).astype("int64").reshape([-1, max_len, 1])
pos_ids = np.array(pos_ids).astype("int64").reshape([-1, max_len, 1])
sent_ids = np.array(sent_ids).astype("int64").reshape([-1, max_len, 1])
input_mask = np.array(input_mask).astype("float32").reshape([-1, max_len, 1])
start_logits = np.array(start_logits).astype("float32").reshape([-1, max_len])
end_logits = np.array(end_logits).astype("float32").reshape([-1, max_len])
start_positions = [inst[6] for inst in batch_data]
end_positions = [inst[7] for inst in batch_data]
start_positions = np.array(start_positions).astype("int64").reshape([-1, 1])
end_positions = np.array(end_positions).astype("int64").reshape([-1, 1])
batch_data = [src_ids, pos_ids, sent_ids, input_mask, start_logits, end_logits, start_positions, end_positions]
if len(all_batches) < dev_count:
all_batches.append(batch_data)
if len(all_batches) == dev_count:
for batch in all_batches:
yield batch
all_batches = []
return wrapper
#!/bin/bash
export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
if [ ! "$CUDA_VISIBLE_DEVICES" ]
then
export CPU_NUM=1
use_cuda=false
else
use_cuda=true
fi
# path of pre_train model
INPUT_PATH="data/input"
PRETRAIN_MODEL_PATH="data/pretrain_model/squad2_model"
# path to save checkpoint
CHECKPOINT_PATH="data/output/output_mrqa"
mkdir -p $CHECKPOINT_PATH
python -u train.py --use_cuda ${use_cuda}\
--batch_size 8 \
--in_tokens false \
--init_pretraining_params ${PRETRAIN_MODEL_PATH}/params \
--checkpoints $CHECKPOINT_PATH \
--vocab_path ${PRETRAIN_MODEL_PATH}/vocab.txt \
--do_distill true \
--do_train true \
--do_predict true \
--save_steps 10000 \
--warmup_proportion 0.1 \
--weight_decay 0.01 \
--sample_rate 0.02 \
--epoch 2 \
--max_seq_len 512 \
--bert_config_path ${PRETRAIN_MODEL_PATH}/bert_config.json \
--predict_file ${INPUT_PATH}/mrqa_distill_data/mrqa-combined.all_dev.raw.json \
--do_lower_case false \
--doc_stride 128 \
--train_file ${INPUT_PATH}/mrqa_distill_data/mrqa_distill.json \
--mlm_path ${INPUT_PATH}/mlm_data \
--mix_ratio 2.0 \
--learning_rate 3e-5 \
--lr_scheduler linear_warmup_decay \
--skip_steps 100
#!/usr/bin/env bash
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# path of dev data
PATH_dev=./data/input/mrqa_evaluation_dataset
# path of dev predict
KD_prediction=./prediction_results/KD_ema_predictions.json
files=$(ls ./prediction_results/*.log 2> /dev/null | wc -l)
if [ "$files" != "0" ];
then
rm prediction_results/*.log
fi
# evaluation KD model
echo "evaluate knowledge distillation model........................................."
for dataset in `ls $PATH_dev/in_domain_dev/*.raw.json`;do
echo $dataset >> prediction_results/KD.log
python ../multi_task_learning/scripts/evaluate-v1.1.py $dataset $KD_prediction >> prediction_results/KD.log
done
for dataset in `ls $PATH_dev/out_of_domain_dev/*.raw.json`;do
echo $dataset >> prediction_results/KD.log
python ../multi_task_learning/scripts/evaluate-v1.1.py $dataset $KD_prediction >> prediction_results/KD.log
done
python ../multi_task_learning/scripts/macro_avg.py prediction_results/KD.log
# wget pretrain model
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/D-Net/squad2_model.tar.gz
tar -xvf squad2_model.tar.gz
rm squad2_model.tar.gz
mv squad2_model ./data/pretrain_model/
# wget knowledge_distillation dataset
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/D-Net/d_net_knowledge_distillation_dataset.tar.gz
tar -xvf d_net_knowledge_distillation_dataset.tar.gz
rm d_net_knowledge_distillation_dataset.tar.gz
mv mlm_data ./data/input
mv mrqa_distill_data ./data/input
# wget evaluation dev dataset
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/D-Net/mrqa_evaluation_dataset.tar.gz
tar -xvf mrqa_evaluation_dataset.tar.gz
rm mrqa_evaluation_dataset.tar.gz
mv mrqa_evaluation_dataset ./data/input
# wget predictions results
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/D-Net/kd_prediction_results.tar.gz
tar -xvf kd_prediction_results.tar.gz
rm kd_prediction_results.tar.gz
# wget MRQA baidu trained knowledge distillation model
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/D-Net/knowledge_distillation_model.tar.gz
tar -xvf knowledge_distillation_model.tar.gz
rm knowledge_distillation_model.tar.gz
mv knowledge_distillation_model ./data/saved_models
# Multi_task_learning
# Multi task learning
## 1Introduction
The pretraining is usually performed on corpus with restricted domains, it is expected that increasing the domain diversity by further pre-training on other corpus may improve the generalization capability. Hence, we incorporate masked language model and domain classify model by using corpus from various domains as an auxiliary tasks in the fine-tuning phase, along with MRC. Additionally, we explore multi-task learning by incorporating the supervised dataset from other NLP tasks to learn better language representation.
## 1. Introduction
Multi task learning (MTL) has been used in many NLP tasks to obtain better language representations. Hence, we experiment with several auxiliary tasks to improve the generalization capability of a MRC model. The auxiliary tasks that we use include
## 2、Quick Start
We use PaddlePaddle PALM(multi-task Learning Library) to train MRQA2019 MRC multi-task baseline model, download PALM:
```
git clone https://github.com/PaddlePaddle/PALM.git
```
- Unsupervised Task: masked Language Model
- Supervised Tasks:
- natural language inference
- paragraph ranking
In the MRQA 2019 shared task, We use [PALM](https://github.com/PaddlePaddle/PALM) v1.0 (a multi-task learning Library based on PaddlePaddle) to perform multi-task training, which makes the implementation of new tasks and pre-trained models much easier than from scratch.
PALM user guide: [README.md](https://github.com/PaddlePaddle/PALM/blob/master/README.md)
## 2.Preparation
### Environment
- Python >= 2.7
- cuda >= 9.0
- cudnn >= 7.0
- PaddlePaddle >= 1.6 Please refer to Installation Guide [Installation Guide](http://www.paddlepaddle.org/#quick-start)
- PaddlePaddle 1.5.2 (Please refer to the Installation Guide [Installation Guide](http://www.paddlepaddle.org/#quick-start))
- PALM v1.0
### Data Preparation
#### Get data directly:
User can get the data directly we provided:
```
bash wget_data.sh
```
### Install PALM
To install PALM v1.0, run the follwing command under `multi_task_learning/`,
#### Convert MRC dataset to squad format data:
To download the MRQA datasets, run
```
cd scripts && bash download_data.sh && cd ..
git clone --branch v1.0 --depth 1 https://github.com/PaddlePaddle/PALM.git
```
The training and prediction datasets will be saved in `./scripts/train/` and `./scripts/dev/`, respectively.
The Multi_task_learning model only supports dataset files in SQuAD format. Before running the model on MRQA datasets, one need to convert the official MRQA data to SQuAD format. To do the conversion, run
```
cd scripts && bash convert_mrqa2squad.sh && cd ..
```
The output files will be named as `xxx.raw.json`.
For more instructions, see the PALM user guide: [README.md](https://github.com/PaddlePaddle/PALM/blob/v1.0/README.md)
### Dowload data
To download the MRQA training and development data, as well as other auxiliary data for MTL, run
For convenience, we provide a script to combine all the training and development data into a single file respectively.
```
cd scripts && bash combine.sh && cd ..
bash wget_data.sh
```
The combined files will be saved in `./scripts/train/mrqa-combined.raw.json` and `./scripts/dev/mrqa-combined.raw.json`.
The downloaded data will be saved into `data/mrqa` (combined MRQA training and development data), `data/mrqa_dev` (seperated MRQA in-domain and out-of-domain data, for model evaluation), `mlm4mrqa` (training data for masked language model task) and `data/am4mrqa` (training data for paragraph matching task).
### Download pre-trained parameters
In our MTL experiments, we use BERT as our shared encoder. The parameters are initialized from the Whole Word Masking BERT (BERTwwm), further fine-tuned on the SQuAD 2.0 task with synthetic generated question answering corpora. The model parameters in Tensorflow format can be downloaded [here](https://worksheets.codalab.org/worksheets/0x3852e60a51d2444680606556d404c657). The following command can be used to convert the parameters to the format that is readable for PaddlePaddle.
### Models Preparation
In this competition, We use google squad2.0 model as pretrain model [Model Link](https://worksheets.codalab.org/worksheets/0x3852e60a51d2444680606556d404c657)
we provide script to convert tensorflow model to paddle model
```
cd scripts && python convert_model_params.py --init_tf_checkpoint tf_model --fluid_params_dir paddle_model && cd ..
```
or user can get the pretrain model and multi-task learning trained models we provided:
Alternatively, user can directly **download the parameters that we have converted**:
```
bash wget_models.sh
bash wget_pretrained_model.sh
```
## 3、Train and Predict
Preparing data, models, and task profiles for PALM
## 3. Training
In the following example, we use PALM library to preform a MLT with 3 tasks (i.e. machine reading comprehension as main task, masked lagnuage model and paragraph ranking as auxiliary tasks). For a detialed instruction on PALM, please refer to the [user guide](https://github.com/PaddlePaddle/PALM/blob/v1.0/README.md).
The PALM library requires a config file for every single task and a main config file `mtl_config.yaml`, which control the training behavior and hyper-parameters. For simplicity, we have prepared those files in the `multi_task_learning/configs` folder. To move the configuration files, data set and model parameters to the correct directory, run
```
bash run_build_palm.sh
```
Start training:
Once everything is in the right place, one can start training
```
cd PALM
bash run_multi_task.sh
```
The fine-tuned parameters and model predictions will be saved in `PALM/output/`, as specified by `mtl_config.yaml`.
## 4. Evaluation
The scripts for evaluation are in the folder `scripts/`. Here we provide an example for the usage of those scripts.
Before evaluation, one need a json file which contains the prediction results on the MRQA dev set. For convenience, we prepare two model prediction files with different MTL configurations, which have been saved in the `prediction_results/` folder, as downloaded in section **Download data**.
## 4、Evaluation
To evaluate the result, run
```
bash run_evaluation.sh
```
Note that we use the evaluation script for SQuAD 1.1 here, which is equivalent to the official one.
The F1 and EM score of the two model predictions will be saved into `prediction_results/BERT_MLM.log` and `prediction_results/BERT_MLM_ParaRank.log`. The macro average of F1 score will be printed on the console. The table below shows the results of our experiments with different MTL configurations.
## 5、Performance
| | dev in_domain(Macro-F1)| dev out_of_domain(Macro-F1) |
|models |in-domain dev (Macro-F1)|out-of-domain dev (Macro-F1) |
| ------------- | ------------ | ------------ |
| Official baseline | 77.87 | 58.67 |
| BERT | 82.40 | 66.35 |
| BERT (no MTL) | 82.40 | 66.35 |
| BERT + MLM | 83.19 | 67.45 |
| BERT + MLM + ParaRank | 83.51 | 66.83 |
BERT: reading comprehension single model.
BERT + MLM: reading comprehension single model as main task, mask language model as auxiliary task.
BERT + MLM + ParaRank: reading comprehension single model as main task, mask language model and paragraph classify rank as auxiliary tasks.
BERT config: configs/reading_comprehension.yaml
MLM config: configs/mask_language_model.yaml
ParaRank config: configs/answer_matching.yaml
## Copyright and License
Copyright 2019 Baidu.com, Inc. All Rights Reserved Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and
......
......@@ -5,5 +5,4 @@ cp configs/mtl_config.yaml PALM/
rm -rf PALM/data
mv data PALM/
mv squad2_model PALM/pretrain_model
mv mrqa_multi_task_models PALM/
cp run_multi_task.sh PALM/
......@@ -2,6 +2,3 @@ wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/D-Net/squad2_model.t
tar -xvf squad2_model.tar.gz
rm squad2_model.tar.gz
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/D-Net/mrqa_multi_task_models.tar.gz
tar -xvf mrqa_multi_task_models.tar.gz
rm mrqa_multi_task_models.tar.gz
# server
# ensemble server system
This directory contains the ensemble system for the three models that are fine-tuned on the MRQA in-domain data (i.e. models based on ERNIE2.0, XL-NET and BERT). The architecture of the ensemble system is shown in the figure below. We first start 3 independent model server for ERNIE, XL-NET and BERT. We then start a main server to receive client requests, invoke model servers and ensemble model results.
For convinience, users are able to explore **any ensemble combinations** (e.g. ERNIE+XL-NET, BERT+XL-NET), by simply modifying the configurations.
## Introduction
MRQA 2019 Shared Task submission will be handled through the [Codalab](https://worksheets.codalab.org/) platform: see [these instructions](https://worksheets.codalab.org/worksheets/0x926e37ac8b4941f793bf9b9758cc01be/).
<p align="center">
<img src="../images/D-NET_server.png" width="500">
</p>
We provided D-NET models submission environment for MRQA competition. it includes two server: bert server and xlnet server, we merged the results of two serves.
## Inference Model Preparation
Download bert inference model and xlnet inferece model
## Environment
In our test environment, we use
- Python 2.7.13
- PaddlePaddle 1.5.2
- sentencepiece 0.1.83
- flask 1.1.1
- Cuda 9.0
- CuDNN 7.0
## Download model parameters
To downlowd the model parameters that are fine-tuned on the MRQA in-domain data, run
```
bash wget_server_inference_model.sh
```
A folder named `infere_model` will appear in `ernie_server/`, `xlnet_server/` and `bert_server/`.
## Start server
## Start servers
Before starting the server, please make sure the ports `5118` to `5121` are available, and specify the `gpu_id` in `start.sh` (by default `GPU 0` on the machine will be used).
To start the servers, run
We can set GPU card for bert server or xlnet server, By setting variable CUDA_VISIBLE_DEVICES:
```
export CUDA_VISIBLE_DEVICES=1
```
In main_server.py file we set the server port for bert and xlnet model, as shown below, If the port 5118 or 5120 is occupied, please set up an idle port.
```
url_1 = 'http://127.0.0.1:5118' # url for model1
url_2 = 'http://127.0.0.1:5120' # url for model2
bash start.sh
```
start server
The log for the main server will be saved in `main_server.log`, and the logs for the 3 model servers witll be saved in `ernie_server/ernie.log`, `xlnet_server/xlnet.log` and `bert_server/bert.log`.
By default, the main server will ensemble the results from ERNIE and XL-NET. To explore other ensemble combinations, one can change the configuration in `start.sh` (e.g. `python main_server.py --ernie --xlnet --bert` for 3 models, `python main_server.py --bert --xlnet` for BERT and XL-NET only).
Note that in our test environment, we use Tesla K40 (12G) and the three modles are able to fit in a single card. For GPUs with smaller RAM, one can choose to put three models on different card by modifying the configurations in `start.sh`.
## Send requests
Once the servers are successfully launched, one can use the client script to send requests.
```
bash start.sh
cd client
python client.py demo.txt results.txt 5121
```
This will the read the examples in `demo.txt`, send requests to the main server, and save results into `results.txt`. The format of the input file (i.e. `demo.txt`) need to be in [MRQA official format](https://github.com/mrqa/MRQA-Shared-Task-2019).
\ No newline at end of file
#encoding=utf8
import os
import sys
import argparse
from copy import deepcopy as copy
import numpy as np
import paddle
import paddle.fluid as fluid
import collections
import multiprocessing
from pdnlp.nets.bert import BertModel
from pdnlp.toolkit.configure import JsonConfig
class ModelBERT(object):
def __init__(
self,
conf,
name = "",
is_training = False,
base_model = None):
# the name of this task
# name is used for identifying parameters
self.name = name
# deep copy the configure of model
self.conf = copy(conf)
self.is_training = is_training
## the overall loss of this task
self.loss = None
## outputs may be useful for the other models
self.outputs = {}
## the prediction of this task
self.predict = []
def create_model(self,
args,
reader_input,
base_model = None):
"""
given the base model, reader_input
return the create fn for create this model
"""
def _create_model():
src_ids, pos_ids, sent_ids, input_mask = reader_input
bert_conf = JsonConfig(self.conf["bert_conf_file"])
self.bert = BertModel(
src_ids = src_ids,
position_ids = pos_ids,
sentence_ids = sent_ids,
input_mask = input_mask,
config = bert_conf,
use_fp16 = args.use_fp16,
model_name = self.name)
self.loss = None
self.outputs = {
"sequence_output": self.bert.get_sequence_output(),
# "pooled_output": self.bert.get_pooled_output()
}
return _create_model
def get_output(self, name):
return self.outputs[name]
def get_outputs(self):
return self.outputs
def get_predict(self):
return self.predict
if __name__ == "__main__":
bert_model = ModelBERT(conf = {"json_conf_path" : "./data/pretrained_models/squad2_model/bert_config.json"})
......@@ -12,8 +12,6 @@ import argparse
import numpy as np
import paddle.fluid as fluid
from task_reader.mrqa import DataProcessor, get_answers
from bert_model import ModelBERT
import mrc_model
ema_decay = 0.9999
verbose = False
......
# encoding=utf8
import paddle.fluid as fluid
def compute_loss(output_tensors, args=None):
"""Compute loss for mrc model"""
def _compute_single_loss(logits, positions):
"""Compute start/end loss for mrc model"""
loss = fluid.layers.softmax_with_cross_entropy(
logits=logits, label=positions)
loss = fluid.layers.mean(x=loss)
return loss
start_logits = output_tensors['start_logits']
end_logits = output_tensors['end_logits']
start_positions = output_tensors['start_positions']
end_positions = output_tensors['end_positions']
start_loss = _compute_single_loss(start_logits, start_positions)
end_loss = _compute_single_loss(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2.0
if args.use_fp16 and args.loss_scaling > 1.0:
total_loss = total_loss * args.loss_scaling
return total_loss
def create_model(reader_input, base_model=None, is_training=True, args=None):
"""
given the base model, reader_input
return the output tensors
"""
if is_training:
src_ids, pos_ids, sent_ids, input_mask, \
start_positions, end_positions = reader_input
else:
src_ids, pos_ids, sent_ids, input_mask, unique_id = reader_input
enc_out = base_model.get_output("sequence_output")
logits = fluid.layers.fc(
input=enc_out,
size=2,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name="cls_squad_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_squad_out_b", initializer=fluid.initializer.Constant(0.)))
logits = fluid.layers.transpose(x=logits, perm=[2, 0, 1])
start_logits, end_logits = fluid.layers.unstack(x=logits, axis=0)
batch_ones = fluid.layers.fill_constant_batch_size_like(
input=start_logits, dtype='int64', shape=[1], value=1)
num_seqs = fluid.layers.reduce_sum(input=batch_ones)
output_tensors = {}
output_tensors['start_logits'] = start_logits
output_tensors['end_logits'] = end_logits
output_tensors['num_seqs'] = num_seqs
if is_training:
output_tensors['start_positions'] = start_positions
output_tensors['end_positions'] = end_positions
else:
output_tensors['unique_id'] = unique_id
output_tensors['start_logits'] = start_logits
output_tensors['end_logits'] = end_logits
return output_tensors
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer encoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
param_initializer=None,
name='multi_head_att'):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
they will not considered in attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input = queries,
size = d_key * n_head,
num_flatten_dims = 2,
param_attr = fluid.ParamAttr(
name = name + '_query_fc.w_0',
initializer = param_initializer),
bias_attr = name + '_query_fc.b_0')
k = layers.fc(input = keys,
size = d_key * n_head,
num_flatten_dims = 2,
param_attr = fluid.ParamAttr(
name = name + '_key_fc.w_0',
initializer = param_initializer),
bias_attr = name + '_key_fc.b_0')
v = layers.fc(input = values,
size = d_value * n_head,
num_flatten_dims = 2,
param_attr = fluid.ParamAttr(
name = name + '_value_fc.w_0',
initializer = param_initializer),
bias_attr = name + '_value_fc.b_0')
return q, k, v
def __split_heads(x, n_head):
"""
Reshape the last dimension of inpunt tensor x so that it becomes two
dimensions and then transpose. Specifically, input a tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] then output a tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
hidden_size = x.shape[-1]
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x = x, shape = [0, 0, n_head, hidden_size // n_head], inplace=True)
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
"""
Transpose and then reshape the last two dimensions of inpunt tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x = trans_x,
shape = [0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace = True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x = q, scale = d_key**-0.5)
product = layers.matmul(x = scaled_q, y = k, transpose_y = True)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
# Since the inplace reshape in __split_heads changes the shape of k and
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k = cache["k"] = layers.concat(
[layers.reshape(
cache["k"], shape=[0, 0, d_model]), k], axis=1)
v = cache["v"] = layers.concat(
[layers.reshape(
cache["v"], shape=[0, 0, d_model]), v], axis=1)
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
dropout_rate)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input = out,
size = d_model,
num_flatten_dims = 2,
param_attr=fluid.ParamAttr(
name = name + '_output_fc.w_0',
initializer = param_initializer),
bias_attr = name + '_output_fc.b_0')
return proj_out
def positionwise_feed_forward(x,
d_inner_hid,
d_hid,
dropout_rate,
hidden_act,
param_initializer=None,
name='ffn'):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
act=hidden_act,
param_attr=fluid.ParamAttr(
name=name + '_fc_0.w_0',
initializer=param_initializer),
bias_attr=name + '_fc_0.b_0')
if dropout_rate:
hidden = layers.dropout(
hidden,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test = False)
out = layers.fc(input = hidden,
size = d_hid,
num_flatten_dims = 2,
param_attr=fluid.ParamAttr(
name = name + '_fc_1.w_0',
initializer = param_initializer),
bias_attr = name + '_fc_1.b_0')
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.,
name=''):
"""
Add residual connection, layer normalization and droput to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out_dtype = out.dtype
if out_dtype == fluid.core.VarDesc.VarType.FP16:
out = layers.cast(x = out, dtype = "float32")
out = layers.layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.ParamAttr(
name = name + '_layer_norm_scale',
initializer = fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
name = name + '_layer_norm_bias',
initializer = fluid.initializer.Constant(0.)))
if out_dtype == fluid.core.VarDesc.VarType.FP16:
out = layers.cast(x = out, dtype = "float16")
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob = dropout_rate,
dropout_implementation = "upscale_in_train",
is_test = False)
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd="n",
postprocess_cmd="da",
param_initializer=None,
name=''):
"""The encoder layers that can be stacked to form a deep encoder.
This module consits of a multi-head (self) attention followed by
position-wise feed-forward networks and both the two components companied
with the post_process_layer to add residual connection, layer normalization
and droput.
"""
attn_output = multi_head_attention(
pre_process_layer(
enc_input,
preprocess_cmd,
prepostprocess_dropout,
name=name + '_pre_att'),
None,
None,
attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
param_initializer = param_initializer,
name = name + '_multi_head_att')
attn_output = post_process_layer(
enc_input,
attn_output,
postprocess_cmd,
prepostprocess_dropout,
name = name + '_post_att')
ffd_output = positionwise_feed_forward(
pre_process_layer(
attn_output,
preprocess_cmd,
prepostprocess_dropout,
name = name + '_pre_ffn'),
d_inner_hid,
d_model,
relu_dropout,
hidden_act,
param_initializer = param_initializer,
name = name + '_ffn')
return post_process_layer(
attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout,
name = name + '_post_ffn')
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd="n",
postprocess_cmd="da",
param_initializer=None,
name='',
return_all = False):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
enc_outputs = []
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd,
postprocess_cmd,
param_initializer = param_initializer,
name = name + '_layer_' + str(i))
enc_input = enc_output
if i < n_layer - 1:
enc_outputs.append(enc_output)
enc_output = pre_process_layer(
enc_output, preprocess_cmd, prepostprocess_dropout, name="post_encoder")
enc_outputs.append(enc_output)
if not return_all:
return enc_output
else:
return enc_output, enc_outputs
#encoding=utf8
import os
import sys
import random
import numpy as np
import paddle
import paddle.fluid as fluid
from pdnlp.toolkit.placeholder import Placeholder
def repeat(reader):
"""Repeat a generator forever"""
generator = reader()
while True:
try:
yield next(generator)
except StopIteration:
generator = reader()
yield next(generator)
def create_joint_generator(input_shape, generators, is_multi_task=True):
def empty_output(input_shape, batch_size=1):
results = []
for i in range(len(input_shape)):
if input_shape[i][1] == 'int32':
dtype = np.int32
if input_shape[i][1] == 'int64':
dtype = np.int64
if input_shape[i][1] == 'float32':
dtype = np.float32
if input_shape[i][1] == 'float64':
dtype = np.float64
shape = input_shape[i][0]
shape[0] = batch_size
pad_tensor = np.zeros(shape=shape, dtype=dtype)
results.append(pad_tensor)
return results
def wrapper():
"""wrapper data"""
generators_inst = [repeat(gen[0]) for gen in generators]
generators_ratio = [gen[1] for gen in generators]
weights = [ratio/sum(generators_ratio) for ratio in generators_ratio]
run_task_id = range(len(generators))
while True:
idx = np.random.choice(run_task_id, p=weights)
gen_results = next(generators_inst[idx])
if not gen_results:
break
batch_size = gen_results[0].shape[0]
results = empty_output(input_shape, batch_size)
task_id_tensor = np.array([[idx]]).astype("int64")
results[0] = task_id_tensor
for i in range(4):
results[i+1] = gen_results[i]
if idx == 0:
# mrc batch
results[5] = gen_results[4]
results[6] = gen_results[5]
elif idx == 1:
# mlm batch
results[7] = gen_results[4]
results[8] = gen_results[5]
elif idx == 2:
# MNLI batch
results[9] = gen_results[4]
else:
raise RuntimeError('Invalid task ID - {}'.format(idx))
# idx stands for the task index
yield results
return wrapper
def create_reader(reader_name, input_shape, is_multi_task, *gens):
"""
build reader for multi_task_learning
"""
placeholder = Placeholder(input_shape)
pyreader, model_inputs = placeholder.build(capacity=100, reader_name=reader_name)
joint_generator = create_joint_generator(input_shape, gens[0], is_multi_task=is_multi_task)
return joint_generator, pyreader, model_inputs
export FLAGS_fraction_of_gpu_memory_to_use=0.1
python start_service.py ./infer_model 5118 &
port=$1
gpu=$2
export CUDA_VISIBLE_DEVICES=$gpu
python start_service.py ./infer_model $port
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Provide MRC service for TOP1 short answer extraction system
Note the services here share some global pre/post process objects, which
are **NOT THREAD SAFE**. Try to use multi-process instead of multi-thread
for deployment.
"""
BERT model service
"""
import json
import sys
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Query the MRQA model server to generate predictions.
"""
import argparse
import json
import requests
import time
if __name__ == '__main__':
parse = argparse.ArgumentParser("")
parse.add_argument("dataset")
parse.add_argument("output_file")
parse.add_argument("port", type=int)
args = parse.parse_args()
all_predictions = {}
contexts = []
f = open(args.dataset)
for example in f:
context = json.loads(example)
if 'header' in context:
continue
contexts.append(context)
f.close()
results = {}
cnt = 0
for context in contexts:
cnt += 1
start = time.time()
pred = requests.post('http://127.0.0.1:%d' % args.port, json=context)
result = pred.json()
results.update(result)
end=time.time()
print('----- request cnt: {}, time elapsed: {:.2f} ms -----'.format(cnt, (end - start)*1000))
for qid, answer in result.items():
print('{}: {}'.format(qid, answer.encode('utf-8')))
with open(args.output_file,'w') as f:
json.dump(results, f, indent=1)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""ERNIE (PaddlePaddle) model wrapper"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import collections
import multiprocessing
import argparse
import numpy as np
import paddle.fluid as fluid
from pdnlp.toolkit.configure import ArgumentGroup
from task_reader.mrqa_infer import DataProcessor, get_answers
from pdnlp.toolkit.init import init_pretraining_params, init_checkpoint
ema_decay = 0.9999
verbose = False
max_seq_len = 512
max_query_length = 64
max_answer_length = 30
in_tokens = False
do_lower_case = True
doc_stride = 128
n_best_size = 20
use_cuda = True
class ERNIEModelWrapper():
"""
Wrap a tnet model
the basic processes include input checking, preprocessing, calling tf-serving
and postprocessing
"""
def __init__(self, model_dir):
""" """
if use_cuda:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
self.exe = fluid.Executor(place)
self.bert_preprocessor = DataProcessor(
vocab_path=os.path.join(model_dir, 'vocab.txt'),
do_lower_case=do_lower_case,
max_seq_length=max_seq_len,
in_tokens=in_tokens,
doc_stride=doc_stride,
max_query_length=max_query_length)
self.inference_program, self.feed_target_names, self.fetch_targets = \
fluid.io.load_inference_model(dirname=model_dir, executor=self.exe)
def preprocessor(self, samples, batch_size, examples_start_id, features_start_id):
"""Preprocess the input samples, including word seg, padding, token to ids"""
# Tokenization and paragraph padding
examples, features, batch = self.bert_preprocessor.data_generator(
samples, batch_size, max_len=max_seq_len, examples_start_id=examples_start_id, features_start_id=features_start_id)
self.samples = samples
return examples, features, batch
def call_mrc(self, batch, squeeze_dim0=False, return_list=False):
"""MRC"""
if squeeze_dim0 and return_list:
raise ValueError("squeeze_dim0 only work for dict-type return value.")
src_ids = batch[0]
pos_ids = batch[1]
sent_ids = batch[2]
input_mask = batch[3]
unique_id = batch[4]
feed_dict = {
self.feed_target_names[0]: src_ids,
self.feed_target_names[1]: pos_ids,
self.feed_target_names[2]: sent_ids,
self.feed_target_names[3]: input_mask,
self.feed_target_names[4]: unique_id
}
np_unique_ids, np_start_logits, np_end_logits, np_num_seqs = \
self.exe.run(self.inference_program, feed=feed_dict, fetch_list=self.fetch_targets)
if len(np_unique_ids) == 1 and squeeze_dim0:
np_unique_ids = np_unique_ids[0]
np_start_logits = np_start_logits[0]
np_end_logits = np_end_logits[0]
if return_list:
mrc_results = [{'unique_ids': id, 'start_logits': st, 'end_logits': end}
for id, st, end in zip(np_unique_ids, np_start_logits, np_end_logits)]
else:
mrc_results = {
'unique_ids': np_unique_ids,
'start_logits': np_start_logits,
'end_logits': np_end_logits,
}
return mrc_results
def postprocessor(self, examples, features, mrc_results):
"""Extract answer
batch: [examples, features] from preprocessor
mrc_results: model results from call_mrc. if mrc_results is list, each element of which is a size=1 batch.
"""
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
results = []
if isinstance(mrc_results, list):
for res in mrc_results:
unique_id = res['unique_ids'][0]
start_logits = [float(x) for x in res['start_logits'].flat]
end_logits = [float(x) for x in res['end_logits'].flat]
results.append(
RawResult(
unique_id=unique_id,
start_logits=start_logits,
end_logits=end_logits))
else:
assert isinstance(mrc_results, dict)
for idx in range(mrc_results['unique_ids'].shape[0]):
unique_id = int(mrc_results['unique_ids'][idx])
start_logits = [float(x) for x in mrc_results['start_logits'][idx].flat]
end_logits = [float(x) for x in mrc_results['end_logits'][idx].flat]
results.append(
RawResult(
unique_id=unique_id,
start_logits=start_logits,
end_logits=end_logits))
answers = get_answers(
examples, features, results, n_best_size,
max_answer_length, do_lower_case, verbose)
return answers
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Some utilities for MRC online service"""
import json
import sys
import logging
import time
import numpy as np
from flask import Response
from flask import request
from copy import deepcopy
verbose = False
def _request_check(input_json):
"""Check if the request json is valid"""
if input_json is None or not isinstance(input_json, dict):
return 'Can not parse the input json data - {}'.format(input_json)
try:
c = input_json['context']
qa = input_json['qas'][0]
qid = qa['qid']
q = qa['question']
except KeyError as e:
return 'Invalid request, key "{}" not found'.format(e)
return 'OK'
def _abort(status_code, message):
"""Create custom error message and status code"""
return Response(json.dumps(message), status=status_code, mimetype='application/json')
def _timmer(init_start, start, current, process_name):
cumulated_elapsed_time = (current - init_start) * 1000
current_elapsed_time = (current - start) * 1000
print('{}\t-\t{:.2f}\t{:.2f}'.format(process_name, cumulated_elapsed_time,
current_elapsed_time))
def _split_input_json(input_json):
if len(input_json['context_tokens']) > 810:
input_json['context'] = input_json['context'][:5000]
if len(input_json['qas']) == 1:
return [input_json]
else:
rets = []
for i in range(len(input_json['qas'])):
temp = deepcopy(input_json)
temp['qas'] = [input_json['qas'][i]]
rets.append(temp)
return rets
class BasicMRCService(object):
"""Provide basic MRC service for flask"""
def __init__(self, name, logger=None, log_data=False):
""" """
self.name = name
if logger is None:
self.logger = logging.getLogger('flask')
else:
self.logger = logger
self.log_data = log_data
def __call__(self, model, process_mode='serial', max_batch_size=5, timmer=False):
"""
Args:
mode: serial, parallel
"""
if timmer:
start = time.time()
"""Call mrc model wrapper and handle expectations"""
self.input_json = request.get_json(silent=True)
try:
if timmer:
start_request_check = time.time()
request_status = _request_check(self.input_json)
if timmer:
current_time = time.time()
_timmer(start, start_request_check, current_time, 'request check')
if self.log_data:
if self.logger is None:
logging.info(
'Client input - {}'.format(json.dumps(self.input_json, ensure_ascii=False))
)
else:
self.logger.info(
'Client input - {}'.format(json.dumps(self.input_json, ensure_ascii=False))
)
except Exception as e:
self.logger.error('server request checker error')
self.logger.exception(e)
return _abort(500, 'server request checker error - {}'.format(e))
if request_status != 'OK':
return _abort(400, request_status)
# call preprocessor
try:
if timmer:
start_preprocess = time.time()
jsons = _split_input_json(self.input_json)
processed = []
ex_start_idx = 0
feat_start_idx = 1000000000
for i in jsons:
e,f,b = model.preprocessor(i, batch_size=max_batch_size if process_mode == 'parallel' else 1, examples_start_id=ex_start_idx, features_start_id=feat_start_idx)
ex_start_idx += len(e)
feat_start_idx += len(f)
processed.append([e,f,b])
if timmer:
current_time = time.time()
_timmer(start, start_preprocess, current_time, 'preprocess')
except Exception as e:
self.logger.error('preprocessor error')
self.logger.exception(e)
return _abort(500, 'preprocessor error - {}'.format(e))
def transpose(mat):
return zip(*mat)
# call mrc
try:
if timmer:
start_call_mrc = time.time()
self.mrc_results = []
self.examples = []
self.features = []
for e, f, batches in processed:
if verbose:
if len(f) > max_batch_size:
print("get a too long example....")
if process_mode == 'serial':
self.mrc_results.extend([model.call_mrc(b, squeeze_dim0=True) for b in batches[:max_batch_size]])
elif process_mode == 'parallel':
# only keep first max_batch_size features
# batches = batches[0]
for b in batches:
self.mrc_results.extend(model.call_mrc(b, return_list=True))
else:
raise NotImplementedError()
self.examples.extend(e)
# self.features.extend(f[:max_batch_size])
self.features.extend(f)
if timmer:
current_time = time.time()
_timmer(start, start_call_mrc, current_time, 'call mrc')
except Exception as e:
self.logger.error('call_mrc error')
self.logger.exception(e)
return _abort(500, 'call_mrc error - {}'.format(e))
# call post processor
try:
if timmer:
start_post_precess = time.time()
self.results = model.postprocessor(self.examples, self.features, self.mrc_results)
# only nbest results is POSTed back
self.results = self.results[1]
# self.results = self.results[0]
if timmer:
current_time = time.time()
_timmer(start, start_post_precess, current_time, 'post process')
except Exception as e:
self.logger.error('postprocessor error')
self.logger.exception(e)
return _abort(500, 'postprocessor error - {}'.format(e))
return self._response_constructor()
def _response_constructor(self):
"""construct http response object"""
try:
response = {
# 'requestID': self.input_json['requestID'],
'results': self.results
}
if self.log_data:
self.logger.info(
'Response - {}'.format(json.dumps(response, ensure_ascii=False))
)
return Response(json.dumps(response), mimetype='application/json')
except Exception as e:
self.logger.error('response constructor error')
self.logger.exception(e)
return _abort(500, 'response constructor error - {}'.format(e))
from algorithm import optimization
from algorithm import multitask
from extension import fp16
from module import transformer_encoder
from toolkit import configure
from toolkit import init
from toolkit import placeholder
from nets import bert
#encoding=utf8
import os
import sys
import random
from copy import deepcopy as copy
import numpy as np
import paddle
import paddle.fluid as fluid
import multiprocessing
class Task:
def __init__(
self,
conf,
name = "",
is_training = False,
_DataProcesser = None,
shared_name = ""):
self.conf = copy(conf)
self.name = name
self.shared_name = shared_name
self.is_training = is_training
self.DataProcesser = _DataProcesser
def _create_reader(self):
raise NotImplementedError("Task:_create_reader not implemented")
def _create_model(self):
raise NotImplementedError("Task:_create_model not implemented")
def prepare(self, args):
raise NotImplementedError("Task:prepare not implemented")
def train_step(self, args):
raise NotImplementedError("Task:train_step not implemented")
def predict(self, args):
raise NotImplementedError("Task:_predict not implemented")
class JointTask:
def __init__(self):
self.tasks = []
#self.startup_exe = None
#self.train_exe = None
self.exe = None
self.share_vars_from = None
self.startup_prog = fluid.Program()
def __add__(self, task):
assert isinstance(task, Task)
self.tasks.append(task)
return self
def prepare(self, args):
if args.use_cuda:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
#self.startup_exe = fluid.Executor(place)
self.exe = fluid.Executor(place)
for idx, task in enumerate(self.tasks):
if idx == 0:
print("for idx : %d" % idx)
task.prepare(args, exe = self.exe)
self.share_vars_from = task.compiled_train_prog
else:
print("for idx : %d" % idx)
task.prepare(args, exe = self.exe, share_vars_from = self.share_vars_from)
def train(self, args):
joint_steps = []
for i in xrange(0, len(self.tasks)):
for _ in xrange(0, self.tasks[i].max_train_steps):
joint_steps.append(i)
self.tasks[0].train_step(args, exe = self.exe)
random.shuffle(joint_steps)
for next_task_id in joint_steps:
self.tasks[next_task_id].train_step(args, exe = self.exe)
if __name__ == "__main__":
basetask_a = Task(None)
basetask_b = Task(None)
joint_tasks = JointTask()
joint_tasks += basetask_a
print(joint_tasks.tasks)
joint_tasks += basetask_b
print(joint_tasks.tasks)
......@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from utils.fp16 import create_master_params_grads, master_param_to_train_param
from pdnlp.extension.fp16 import create_master_params_grads, master_param_to_train_param
def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
......
......@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from functools import partial
from functools import reduce
import numpy as np
import paddle.fluid as fluid
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#encoding=utf8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -25,7 +12,6 @@ import json
logging_only_message = "%(message)s"
logging_details = "%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s"
class JsonConfig(object):
def __init__(self, config_path):
self._config_dict = self._parse(config_path)
......@@ -62,7 +48,6 @@ class ArgumentGroup(object):
help=help + ' Default: %(default)s.',
**kwargs)
class ArgConfig(object):
def __init__(self):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
#encoding=utf8
from __future__ import print_function
import os
......@@ -54,6 +40,7 @@ class Placeholder(object):
self.lod_levels.append(lod_level)
self.names.append(name)
def build(self, capacity, reader_name, use_double_buffer = False):
pyreader = fluid.layers.py_reader(
capacity = capacity,
......@@ -65,6 +52,7 @@ class Placeholder(object):
return [pyreader, fluid.layers.read_file(pyreader)]
def __add__(self, new_holder):
assert isinstance(new_holder, tuple) or isinstance(new_holder, list)
assert len(new_holder) >= 2
......
export FLAGS_fraction_of_gpu_memory_to_use=0.1
port=$1
gpu=$2
export CUDA_VISIBLE_DEVICES=$gpu
python start_service.py ./infer_model $port
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
ERNIE model service
"""
import json
import sys
import logging
logging.basicConfig(
level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
import requests
from flask import Flask
from flask import Response
from flask import request
import mrc_service
import model_wrapper as ernie_wrapper
assert len(sys.argv) == 3 or len(sys.argv) == 4, "Usage: python serve.py <model_dir> <port> [process_mode]"
if len(sys.argv) == 3:
_, model_dir, port = sys.argv
mode = 'parallel'
else:
_, model_dir, port, mode = sys.argv
app = Flask(__name__)
app.logger.setLevel(logging.INFO)
ernie_model = ernie_wrapper.ERNIEModelWrapper(model_dir=model_dir)
server = mrc_service.BasicMRCService('Short answer MRC service', app.logger)
@app.route('/', methods=['POST'])
def mrqa_service():
"""Description"""
model = ernie_model
return server(model, process_mode=mode, max_batch_size=5)
if __name__ == '__main__':
app.run(port=port, debug=False, threaded=False, processes=1)
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,15 +14,15 @@
# limitations under the License.
"""Run MRQA"""
import re
import six
import math
import json
import random
import collections
import numpy as np
from utils import tokenization
from utils.batching import prepare_batch_data
import tokenization
from batching import prepare_batch_data
class MRQAExample(object):
......@@ -94,10 +95,8 @@ class InputFeatures(object):
self.is_impossible = is_impossible
def read_mrqa_examples(input_file, is_training, with_negative=False):
def read_mrqa_examples(sample, is_training=False, with_negative=False):
"""Read a MRQA json file into a list of MRQAExample."""
with open(input_file, "r") as reader:
input_data = json.load(reader)["data"]
def is_whitespace(c):
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
......@@ -105,74 +104,36 @@ def read_mrqa_examples(input_file, is_training, with_negative=False):
return False
examples = []
for entry in input_data:
for paragraph in entry["paragraphs"]:
paragraph_text = paragraph["context"]
doc_tokens = []
char_to_word_offset = []
# sample = json.loads(raw_sample)
paragraph_text = sample["context"]
paragraph_text = re.sub(r'\[TLE\]|\[DOC\]|\[PAR\]', '[SEP]', paragraph_text)
doc_tokens = []
char_to_word_offset = []
prev_is_whitespace = True
for c in paragraph_text:
if is_whitespace(c):
prev_is_whitespace = True
for c in paragraph_text:
if is_whitespace(c):
prev_is_whitespace = True
else:
if prev_is_whitespace:
doc_tokens.append(c)
else:
doc_tokens[-1] += c
prev_is_whitespace = False
char_to_word_offset.append(len(doc_tokens) - 1)
for qa in paragraph["qas"]:
qas_id = qa["id"]
question_text = qa["question"]
start_position = None
end_position = None
orig_answer_text = None
is_impossible = False
if is_training:
if with_negative:
is_impossible = qa["is_impossible"]
if (len(qa["answers"]) != 1) and (not is_impossible):
raise ValueError(
"For training, each question should have exactly 1 answer."
)
if not is_impossible:
answer = qa["answers"][0]
orig_answer_text = answer["text"]
answer_offset = answer["answer_start"]
answer_length = len(orig_answer_text)
start_position = char_to_word_offset[answer_offset]
end_position = char_to_word_offset[answer_offset +
answer_length - 1]
# Only add answers where the text can be exactly recovered from the
# document. If this CAN'T happen it's likely due to weird Unicode
# stuff so we will just skip the example.
#
# Note that this means for training mode, every example is NOT
# guaranteed to be preserved.
actual_text = " ".join(doc_tokens[start_position:(
end_position + 1)])
cleaned_answer_text = " ".join(
tokenization.whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
print("Could not find answer: '%s' vs. '%s'",
actual_text, cleaned_answer_text)
continue
else:
start_position = -1
end_position = -1
orig_answer_text = ""
example = MRQAExample(
qas_id=qas_id,
question_text=question_text,
doc_tokens=doc_tokens,
orig_answer_text=orig_answer_text,
start_position=start_position,
end_position=end_position,
is_impossible=is_impossible)
examples.append(example)
else:
if prev_is_whitespace:
doc_tokens.append(c)
else:
doc_tokens[-1] += c
prev_is_whitespace = False
char_to_word_offset.append(len(doc_tokens) - 1)
for qa in sample["qas"]:
qas_id = qa["qid"]
question_text = qa["question"]
start_position = None
end_position = None
orig_answer_text = None
is_impossible = False
example = MRQAExample(
qas_id=qas_id,
question_text=question_text,
doc_tokens=doc_tokens)
examples.append(example)
return examples
......@@ -184,13 +145,17 @@ def convert_examples_to_features(
doc_stride,
max_query_length,
is_training,
examples_start_id=0,
features_start_id=1000000000
#output_fn
):
"""Loads a data file into a list of `InputBatch`s."""
unique_id = 1000000000
unique_id = features_start_id
example_index = examples_start_id
for (example_index, example) in enumerate(examples):
features = []
for example in examples:
query_tokens = tokenizer.tokenize(example.question_text)
if len(query_tokens) > max_query_length:
......@@ -308,34 +273,6 @@ def convert_examples_to_features(
start_position = 0
end_position = 0
"""
if example_index < 3:
print("*** Example ***")
print("unique_id: %s" % (unique_id))
print("example_index: %s" % (example_index))
print("doc_span_index: %s" % (doc_span_index))
print("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in tokens]))
print("token_to_orig_map: %s" % " ".join([
"%d:%d" % (x, y)
for (x, y) in six.iteritems(token_to_orig_map)
]))
print("token_is_max_context: %s" % " ".join([
"%d:%s" % (x, y)
for (x, y) in six.iteritems(token_is_max_context)
]))
print("input_ids: %s" % " ".join([str(x) for x in input_ids]))
print("input_mask: %s" % " ".join([str(x) for x in input_mask]))
print("segment_ids: %s" %
" ".join([str(x) for x in segment_ids]))
if is_training and example.is_impossible:
print("impossible example")
if is_training and not example.is_impossible:
answer_text = " ".join(tokens[start_position:(end_position +
1)])
print("start_position: %d" % (start_position))
print("end_position: %d" % (end_position))
print("answer: %s" %
(tokenization.printable_text(answer_text)))
feature = InputFeatures(
unique_id=unique_id,
......@@ -352,8 +289,9 @@ def convert_examples_to_features(
is_impossible=example.is_impossible)
unique_id += 1
yield feature
features.append(feature)
example_index += 1
return features
def estimate_runtime_examples(data_path, sample_rate, tokenizer, \
......@@ -606,7 +544,6 @@ class DataProcessor(object):
self.current_train_epoch = -1
self.train_examples = None
self.predict_examples = None
self.num_examples = {'train': -1, 'predict': -1}
def get_train_progress(self):
......@@ -636,42 +573,30 @@ class DataProcessor(object):
self._max_seq_length, self._doc_stride, self._max_query_length, \
remove_impossible_questions=True, filter_invalid_spans=True)
def get_features(self, examples, is_training):
def get_features(self, examples, is_training, examples_start_id, features_start_id):
features = convert_examples_to_features(
examples=examples,
tokenizer=self._tokenizer,
max_seq_length=self._max_seq_length,
doc_stride=self._doc_stride,
max_query_length=self._max_query_length,
examples_start_id=examples_start_id,
features_start_id=features_start_id,
is_training=is_training)
return features
def data_generator(self,
data_path,
raw_samples,
batch_size,
max_len=None,
phase='train',
phase='predict',
shuffle=False,
dev_count=1,
with_negative=False,
epoch=1):
if phase == 'train':
self.train_examples = self.get_examples(
data_path,
is_training=True,
with_negative=with_negative)
examples = self.train_examples
self.num_examples['train'] = len(self.train_examples)
elif phase == 'predict':
self.predict_examples = self.get_examples(
data_path,
is_training=False,
with_negative=with_negative)
examples = self.predict_examples
self.num_examples['predict'] = len(self.predict_examples)
else:
raise ValueError(
"Unknown phase, which should be in ['train', 'predict'].")
epoch=1,
examples_start_id=0,
features_start_id=1000000000):
examples = read_mrqa_examples(raw_samples)
def batch_reader(features, batch_size, in_tokens):
batch, total_token_num, max_len = [], 0, 0
......@@ -704,57 +629,31 @@ class DataProcessor(object):
if len(batch) > 0:
yield batch, total_token_num
def wrapper():
for epoch_index in range(epoch):
if shuffle:
random.shuffle(examples)
if phase == 'train':
self.current_train_epoch = epoch_index
features = self.get_features(examples, is_training=True)
else:
features = self.get_features(examples, is_training=False)
all_dev_batches = []
for batch_data, total_token_num in batch_reader(
features, batch_size, self._in_tokens):
batch_data = prepare_batch_data(
batch_data,
total_token_num,
max_len=max_len,
voc_size=-1,
pad_id=self.pad_id,
cls_id=self.cls_id,
sep_id=self.sep_id,
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
if len(all_dev_batches) < dev_count:
all_dev_batches.append(batch_data)
if len(all_dev_batches) == dev_count:
for batch in all_dev_batches:
yield batch
all_dev_batches = []
if phase == 'predict' and len(all_dev_batches) > 0:
fake_batch = all_dev_batches[-1]
fake_batch = fake_batch[:-1] + [np.array([-1]*len(fake_batch[0]))]
all_dev_batches = all_dev_batches + [fake_batch] * (dev_count - len(all_dev_batches))
for batch in all_dev_batches:
yield batch
return wrapper
def write_predictions(all_examples, all_features, all_results, n_best_size,
max_answer_length, do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file,
with_negative, null_score_diff_threshold,
verbose):
features = self.get_features(examples, is_training=False, examples_start_id=examples_start_id, features_start_id=features_start_id)
all_dev_batches = []
for batch_data, total_token_num in batch_reader(
features, batch_size, self._in_tokens):
batch_data = prepare_batch_data(
batch_data,
total_token_num,
max_len=max_len,
voc_size=-1,
pad_id=self.pad_id,
cls_id=self.cls_id,
sep_id=self.sep_id,
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
all_dev_batches.append(batch_data)
return examples, features, all_dev_batches
def get_answers(all_examples, all_features, all_results, n_best_size,
max_answer_length, do_lower_case,
verbose=False):
"""Write final predictions to the json file and log-odds of null if needed."""
print("Writing predictions to: %s" % (output_prediction_file))
print("Writing nbest to: %s" % (output_nbest_file))
example_index_to_features = collections.defaultdict(list)
for feature in all_features:
......@@ -788,14 +687,6 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant
if with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[
0]
if feature_null_score < score_null:
score_null = feature_null_score
min_null_feature_index = feature_index
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indexes:
for end_index in end_indexes:
# We could hypothetically create invalid predictions, e.g., predict
......@@ -824,14 +715,6 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
if with_negative:
prelim_predictions.append(
_PrelimPrediction(
feature_index=min_null_feature_index,
start_index=0,
end_index=0,
start_logit=null_start_logit,
end_logit=null_end_logit))
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
......@@ -880,14 +763,6 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
# if we didn't inlude the empty option in the n-best, inlcude it
if with_negative:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="",
start_logit=null_start_logit,
end_logit=null_end_logit))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
......@@ -921,29 +796,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
assert len(nbest_json) >= 1
if not with_negative:
all_predictions[example.qas_id] = nbest_json[0]["text"]
else:
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
all_predictions[example.qas_id] = nbest_json[0]["text"]
all_nbest_json[example.qas_id] = nbest_json
with open(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
with open(output_nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if with_negative:
with open(output_null_log_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
return all_predictions, all_nbest_json
def get_final_text(pred_text, orig_text, do_lower_case, verbose):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
......@@ -11,13 +11,13 @@ from flask import Flask
from flask import Response
from flask import request
import numpy as np
import argparse
from multiprocessing.dummy import Pool as ThreadPool
app = Flask(__name__)
logger = logging.getLogger('flask')
url_1 = 'http://127.0.0.1:5118' # url for model1
url_2 = 'http://127.0.0.1:5120' # url for model2
def ensemble_example(answers, n_models=None):
if n_models is None:
......@@ -50,32 +50,45 @@ def mrqa_main():
return nbest
try:
input_json = request.get_json(silent=True)
pool = ThreadPool(2)
res1 = pool.apply_async(_call_model, (url_1, input_json))
res2 = pool.apply_async(_call_model, (url_2, input_json))
nbest1 = res1.get()
nbest2 = res2.get()
# print(res1)
# print(nbest1)
n_models = len(urls)
pool = ThreadPool(n_models)
results = []
for url in urls:
result = pool.apply_async(_call_model, (url, input_json))
results.append(result.get())
pool.close()
pool.join()
nbest1 = nbest1.json()['results']
nbest2 = nbest2.json()['results']
qids = list(nbest1.keys())
nbests = [nbest.json()['results'] for nbest in results]
qids = list(nbests[0].keys())
for qid in qids:
ensemble_nbest = ensemble_example([nbest1[qid], nbest2[qid]], n_models=2)
ensemble_nbest = ensemble_example([nbest[qid] for nbest in nbests], n_models=n_models)
pred[qid] = ensemble_nbest[0]['text']
except Exception as e:
pred['error'] = 'empty'
# logger.error('Error in mrc server - {}'.format(e))
logger.exception(e)
# import pdb; pdb.set_trace() # XXX BREAKPOINT
return Response(json.dumps(pred), mimetype='application/json')
if __name__ == '__main__':
url_1 = 'http://127.0.0.1:5118' # url for ernie
url_2 = 'http://127.0.0.1:5119' # url for xl-net
url_3 = 'http://127.0.0.1:5120' # url for bert
parser = argparse.ArgumentParser('main server')
parser.add_argument('--ernie', action='store_true', default=False, help="Include ERNIE")
parser.add_argument('--xlnet', action='store_true', default=False, help="Include XL-NET")
parser.add_argument('--bert', action='store_true', default=False, help="Include BERT")
args = parser.parse_args()
urls = []
if args.ernie:
print('Include ERNIE model')
urls.append(url_1)
if args.xlnet:
print('Include XL-NET model')
urls.append(url_2)
if args.bert:
print('Include BERT model')
urls.append(url_3)
assert len(urls) > 0, "At lease one model is required"
app.run(host='127.0.0.1', port=5121, debug=False, threaded=False, processes=1)
#!/bin/bash
gpu_id=0
# start ernie service
# usage: sh start.sh port gpu_id
cd ernie_server
nohup sh start.sh 5118 $gpu_id > ernie.log 2>&1 &
cd ..
# start xlnet service
cd xlnet_server
nohup sh start.sh 5119 $gpu_id > xlnet.log 2>&1 &
cd ..
# start bert service
cd bert_server
export CUDA_VISIBLE_DEVICES=1
sh start.sh
cd ../xlnet_server
export CUDA_VISIBLE_DEVICES=2
sh serve.sh
nohup sh start.sh 5120 $gpu_id > bert.log 2>&1 &
cd ..
sleep 60
python main_server.py
sleep 3
# start main server
# usage: python main_server.py --model_name
# the model_name specifies the model to be used in the ensemble.
nohup python main_server.py --ernie --xlnet > main_server.log 2>&1 &
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/D-Net/mrqa2019_inference_model.tar.gz
tar -xvf mrqa2019_inference_model.tar.gz
rm mrqa2019_inference_model.tar.gz
mv infer_model bert_server
mv infer_model_800_bs128 xlnet_server
mv bert_infer_model bert_server/infer_model
mv xlnet_infer_model xlnet_server/infer_model
mv ernie_infer_model ernie_server/infer_model
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Provide MRC service for TOP1 short answer extraction system
Note the services here share some global pre/post process objects, which
are **NOT THREAD SAFE**. Try to use multi-process instead of multi-thread
for deployment.
"""
XL-NET model service
"""
import json
import sys
......
export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=1
export FLAGS_fraction_of_gpu_memory_to_use=0.1
port=$1
gpu=$2
export CUDA_VISIBLE_DEVICES=$gpu
python serve.py ./infer_model_800_bs128 5001 &
python serve.py ./infer_model $port
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册