Commit b3175bd2 authored by qiuxuezhong

add files for DuReader

Parent: d1e78e57
# DuReader Dataset
DuReader is a new large-scale, real-world, human-sourced MRC dataset in Chinese. DuReader focuses on real-world open-domain question answering. Its advantages over existing datasets can be summarized as follows:
- Real questions
- Real articles
- Real answers
- Real application scenarios
- Rich annotations
# DuReader Network
DuReader is inspired by three classic reading comprehension models ([BiDAF](https://arxiv.org/abs/1611.01603), [Match-LSTM](https://arxiv.org/abs/1608.07905), [R-NET](https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf)):
- Attention Flow from [BiDAF](https://arxiv.org/abs/1611.01603)
- Answer Pointer Network from [Match-LSTM](https://arxiv.org/abs/1608.07905)
- Attention Pooling from [R-NET](https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf)
## How to Run
### Download the Dataset
To download the DuReader dataset, run:
```
cd data && bash download.sh
```
For more details about the DuReader dataset, please refer to the [DuReader Dataset Homepage](https://ai.baidu.com//broad/subordinate?dataset=dureader).
### Download Thirdparty Dependencies
We use Bleu and Rouge as evaluation metrics. The calculation of these metrics relies on the scoring scripts from https://github.com/tylin/coco-caption; to download them, run:
```
cd utils && bash download_thirdparty.sh
```
### Preprocess the Data
After the dataset is downloaded, some preprocessing is still needed before the baseline systems can be run. DuReader offers a rich set of documents for every user question, and these documents are too long for popular RC models to cope with. In our baseline models, we preprocess the train and development sets by selecting the paragraph that is most related to the answer string; for inference (where no gold answer is available), we select the paragraph that is most related to the question string. This preprocessing strategy is implemented in `utils/preprocess.py`. To preprocess the raw data, first segment 'question', 'title' and 'paragraphs', and store the results in 'segmented_question', 'segmented_title' and 'segmented_paragraphs', following the format of the downloaded preprocessed data (a sketch of this segmentation step is given after the command below); then run:
```
cat data/raw/trainset/search.train.json | python utils/preprocess.py > data/preprocessed/trainset/search.train.json
```
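If your raw data is not segmented yet, any Chinese word segmenter can produce the required fields. Below is a minimal sketch of this step; it assumes the third-party `jieba` tokenizer (not shipped with this repo), and any other tokenizer would work the same way:
```
import json
import sys

import jieba  # assumed third-party segmenter; any Chinese tokenizer works


def segment_sample(sample):
    # Adds the 'segmented_*' fields expected by utils/preprocess.py.
    sample['segmented_question'] = jieba.lcut(sample['question'])
    for doc in sample['documents']:
        doc['segmented_title'] = jieba.lcut(doc['title'])
        doc['segmented_paragraphs'] = [
            jieba.lcut(para) for para in doc['paragraphs']
        ]
    return sample


if __name__ == '__main__':
    # usage (hypothetical file name): cat raw.json | python segment.py
    for line in sys.stdin:
        print(json.dumps(segment_sample(json.loads(line)), ensure_ascii=False))
```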
The preprocessed data can be downloaded automatically by `data/download.sh` and is stored in `data/preprocessed`; the raw data before preprocessing is under `data/raw`.
### Run with PaddlePaddle
#### Get the Vocab File
Once the preprocessed data is ready, run `utils/get_vocab.py` to generate the vocabulary file. For example, to train a model with the Baidu Search data:
```
python utils/get_vocab.py --files data/preprocessed/trainset/search.train.json data/preprocessed/devset/search.dev.json --vocab data/vocab.search
```
If you want to use the demo data, run:
```
python utils/get_vocab.py --files data/demo/trainset/search.train.json data/demo/devset/search.dev.json --vocab data/demo/vocab.search
```
#### Environment Requirements
For now we have only tested on PaddlePaddle v1.0. To install PaddlePaddle, and for more details about it, see the [PaddlePaddle Homepage](http://paddlepaddle.org).
#### Training
The DuReader model can be trained by running `train.py`; for complete usage, run `python train.py -h`.
The basic training and inference process is wrapped in `run.sh`; its basic usage is:
```
bash run.sh TASK_NAME
```
For example, to train the model, run:
```
bash run.sh train
```
## Run DuReader on multilingual datasets
To help evaluate system performance on multilingual datasets, we provide scripts that convert MS MARCO V2 data from its original format to the DuReader format.
[MS MARCO](http://www.msmarco.org/dataset.aspx) (Microsoft Machine Reading Comprehension) is an English dataset focused on machine reading comprehension and question answering. The designs of MS MARCO and DuReader are similar, so it is worthwhile to examine MRC systems on both the Chinese (DuReader) and English (MS MARCO) datasets.
Download the MS MARCO V2 data and run the following scripts to convert it from MS MARCO V2 format to DuReader format. You can then run and evaluate our DuReader baselines, or your own DuReader systems, on the MS MARCO data.
```
./run_marco2dureader_preprocess.sh ../data/marco/train_v2.1.json ../data/marco/train_v2.1_dureaderformat.json
./run_marco2dureader_preprocess.sh ../data/marco/dev_v2.1.json ../data/marco/dev_v2.1_dureaderformat.json
```
## Copyright and License
Copyright 2017 Baidu.com, Inc. All Rights Reserved
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import distutils.util
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--embed_size",
type=int,
default=300,
help="The dimension of embedding table. (default: %(default)d)")
parser.add_argument(
"--hidden_size",
type=int,
default=300,
help="The size of rnn hidden unit. (default: %(default)d)")
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
"--pass_num",
type=int,
default=5,
help="The pass number to train. (default: %(default)d)")
parser.add_argument(
"--learning_rate",
type=float,
default=0.001,
help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
"--use_gpu",
type=distutils.util.strtobool,
default=True,
help="Whether to use gpu. (default: %(default)d)")
parser.add_argument(
"--save_dir",
type=str,
default="model",
help="Specify the path to save trained models.")
parser.add_argument(
"--load_dir",
type=str,
default="",
help="Specify the path to load trained models.")
parser.add_argument(
"--save_interval",
type=int,
default=1,
help="Save the trained model every n passes."
"(default: %(default)d)")
parser.add_argument(
"--log_interval",
type=int,
default=50,
help="log the train loss every n batches."
"(default: %(default)d)")
parser.add_argument(
"--dev_interval",
type=int,
default=1000,
help="cal dev loss every n batches."
"(default: %(default)d)")
parser.add_argument('--optim', default='rprop', help='optimizer type')
parser.add_argument('--trainset', nargs='+', help='train dataset')
parser.add_argument('--devset', nargs='+', help='dev dataset')
parser.add_argument('--testset', nargs='+', help='test dataset')
parser.add_argument('--vocab_dir', help='directory that contains the vocab.data file')
parser.add_argument('--max_p_num', type=int, default=5)
parser.add_argument('--max_a_len', type=int, default=200)
parser.add_argument('--max_p_len', type=int, default=500)
parser.add_argument('--max_q_len', type=int, default=9)
parser.add_argument('--doc_num', type=int, default=5)
parser.add_argument('--para_print', action='store_true')
parser.add_argument('--drop_rate', type=float, default=0.0)
parser.add_argument('--random_seed', type=int, default=123)
parser.add_argument(
'--log_path',
help='path of the log file. If not set, logs are printed to console')
args = parser.parse_args()
return args
# -*- coding:utf8 -*-
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module implements the data processing strategies.
"""
import os
import json
import logging
import numpy as np
from collections import Counter
class BRCDataset(object):
"""
This class implements the APIs for loading and using the Baidu reading comprehension (DuReader) dataset
"""
def __init__(self,
max_p_num,
max_p_len,
max_q_len,
train_files=[],
dev_files=[],
test_files=[]):
self.logger = logging.getLogger("brc")
self.max_p_num = max_p_num
self.max_p_len = max_p_len
self.max_q_len = max_q_len
self.train_set, self.dev_set, self.test_set = [], [], []
if train_files:
for train_file in train_files:
self.train_set += self._load_dataset(train_file, train=True)
self.logger.info('Train set size: {} questions.'.format(
len(self.train_set)))
if dev_files:
for dev_file in dev_files:
self.dev_set += self._load_dataset(dev_file)
self.logger.info('Dev set size: {} questions.'.format(
len(self.dev_set)))
if test_files:
for test_file in test_files:
self.test_set += self._load_dataset(test_file)
self.logger.info('Test set size: {} questions.'.format(
len(self.test_set)))
def _load_dataset(self, data_path, train=False):
"""
Loads the dataset
Args:
data_path: the data file to load
"""
with open(data_path) as fin:
data_set = []
for lidx, line in enumerate(fin):
sample = json.loads(line.strip())
if train:
if len(sample['answer_spans']) == 0:
continue
if sample['answer_spans'][0][1] >= self.max_p_len:
continue
if 'answer_docs' in sample:
sample['answer_passages'] = sample['answer_docs']
sample['question_tokens'] = sample['segmented_question']
sample['passages'] = []
for d_idx, doc in enumerate(sample['documents']):
if train:
most_related_para = doc['most_related_para']
sample['passages'].append({
'passage_tokens':
doc['segmented_paragraphs'][most_related_para],
'is_selected': doc['is_selected']
})
else:
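# No gold answer is available at dev/test time, so for each document pick
# the paragraph with the highest word recall w.r.t. the question.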
para_infos = []
for para_tokens in doc['segmented_paragraphs']:
question_tokens = sample['segmented_question']
common_with_question = Counter(
para_tokens) & Counter(question_tokens)
correct_preds = sum(common_with_question.values())
if correct_preds == 0:
recall_wrt_question = 0
else:
recall_wrt_question = float(
correct_preds) / len(question_tokens)
para_infos.append((para_tokens, recall_wrt_question,
len(para_tokens)))
para_infos.sort(key=lambda x: (-x[1], x[2]))
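# rank paragraphs by recall (descending), break ties by preferring
# shorter paragraphs, and keep only the top-ranked one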
fake_passage_tokens = []
for para_info in para_infos[:1]:
fake_passage_tokens += para_info[0]
sample['passages'].append({
'passage_tokens': fake_passage_tokens
})
data_set.append(sample)
return data_set
def _one_mini_batch(self, data, indices, pad_id):
"""
Get one mini batch
Args:
data: all data
indices: the indices of the samples to be selected
pad_id: the token id used for padding
Returns:
one batch of data
"""
batch_data = {
'raw_data': [data[i] for i in indices],
'question_token_ids': [],
'question_length': [],
'passage_token_ids': [],
'passage_length': [],
'start_id': [],
'end_id': []
}
max_passage_num = max(
[len(sample['passages']) for sample in batch_data['raw_data']])
#max_passage_num = min(self.max_p_num, max_passage_num)
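# pad every sample to exactly max_p_num passages so that the flattened
# feed (batch_size * max_p_num entries) has a fixed layout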
max_passage_num = self.max_p_num
for sidx, sample in enumerate(batch_data['raw_data']):
for pidx in range(max_passage_num):
if pidx < len(sample['passages']):
batch_data['question_token_ids'].append(sample[
'question_token_ids'])
batch_data['question_length'].append(
len(sample['question_token_ids']))
passage_token_ids = sample['passages'][pidx][
'passage_token_ids']
batch_data['passage_token_ids'].append(passage_token_ids)
batch_data['passage_length'].append(
min(len(passage_token_ids), self.max_p_len))
else:
batch_data['question_token_ids'].append([])
batch_data['question_length'].append(0)
batch_data['passage_token_ids'].append([])
batch_data['passage_length'].append(0)
batch_data, padded_p_len, padded_q_len = self._dynamic_padding(
batch_data, pad_id)
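# answer spans are addressed globally: the padded passages of a sample are
# conceptually concatenated, so the gold start/end ids are offset by
# padded_p_len times the index of the answer passage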
for sample in batch_data['raw_data']:
if 'answer_passages' in sample and len(sample['answer_passages']):
gold_passage_offset = padded_p_len * sample['answer_passages'][
0]
batch_data['start_id'].append(gold_passage_offset + sample[
'answer_spans'][0][0])
batch_data['end_id'].append(gold_passage_offset + sample[
'answer_spans'][0][1])
else:
# fake span for some samples, only valid for testing
batch_data['start_id'].append(0)
batch_data['end_id'].append(0)
return batch_data
def _dynamic_padding(self, batch_data, pad_id):
"""
Dynamically pads the batch_data with pad_id
"""
pad_p_len = min(self.max_p_len, max(batch_data['passage_length']))
pad_q_len = min(self.max_q_len, max(batch_data['question_length']))
batch_data['passage_token_ids'] = [
(ids + [pad_id] * (pad_p_len - len(ids)))[:pad_p_len]
for ids in batch_data['passage_token_ids']
]
batch_data['question_token_ids'] = [
(ids + [pad_id] * (pad_q_len - len(ids)))[:pad_q_len]
for ids in batch_data['question_token_ids']
]
return batch_data, pad_p_len, pad_q_len
def word_iter(self, set_name=None):
"""
Iterates over all the words in the dataset
Args:
set_name: if it is set, then the specific set will be used
Returns:
a generator
"""
if set_name is None:
data_set = self.train_set + self.dev_set + self.test_set
elif set_name == 'train':
data_set = self.train_set
elif set_name == 'dev':
data_set = self.dev_set
elif set_name == 'test':
data_set = self.test_set
else:
raise NotImplementedError('No data set named as {}'.format(
set_name))
if data_set is not None:
for sample in data_set:
for token in sample['question_tokens']:
yield token
for passage in sample['passages']:
for token in passage['passage_tokens']:
yield token
def convert_to_ids(self, vocab):
"""
Convert the question and passage in the original dataset to ids
Args:
vocab: the vocabulary on this dataset
"""
for data_set in [self.train_set, self.dev_set, self.test_set]:
if data_set is None:
continue
for sample in data_set:
sample['question_token_ids'] = vocab.convert_to_ids(sample[
'question_tokens'])
for passage in sample['passages']:
passage['passage_token_ids'] = vocab.convert_to_ids(passage[
'passage_tokens'])
def gen_mini_batches(self, set_name, batch_size, pad_id, shuffle=True):
"""
Generate data batches for a specific dataset (train/dev/test)
Args:
set_name: train/dev/test to indicate the set
batch_size: number of samples in one batch
pad_id: pad id
shuffle: if set to be true, the data is shuffled.
Returns:
a generator for all batches
"""
if set_name == 'train':
data = self.train_set
elif set_name == 'dev':
data = self.dev_set
elif set_name == 'test':
data = self.test_set
else:
raise NotImplementedError('No data set named as {}'.format(
set_name))
data_size = len(data)
indices = np.arange(data_size)
if shuffle:
np.random.shuffle(indices)
for batch_start in np.arange(0, data_size, batch_size):
batch_indices = indices[batch_start:batch_start + batch_size]
yield self._one_mini_batch(data, batch_indices, pad_id)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
def dropout(input, args):
if args.drop_rate:
return layers.dropout(
input,
dropout_prob=args.drop_rate,
seed=args.random_seed,
is_test=False)
else:
return input
def bi_lstm_encoder(input_seq, gate_size, para_name, args):
# A bi-directional LSTM encoder implementation.
# The linear transformations for the input gate, output gate, forget gate
# and cell activation vectors must be done outside of dynamic_lstm,
# so the projection size is 4 times gate_size.
input_forward_proj = layers.fc(
input=input_seq,
param_attr=fluid.ParamAttr(name=para_name + '_fw_gate_w'),
size=gate_size * 4,
act=None,
bias_attr=False)
input_reversed_proj = layers.fc(
input=input_seq,
param_attr=fluid.ParamAttr(name=para_name + '_bw_gate_w'),
size=gate_size * 4,
act=None,
bias_attr=False)
forward, _ = layers.dynamic_lstm(
input=input_forward_proj,
size=gate_size * 4,
use_peepholes=False,
param_attr=fluid.ParamAttr(name=para_name + '_fw_lstm_w'),
bias_attr=fluid.ParamAttr(name=para_name + '_fw_lstm_b'))
reversed, _ = layers.dynamic_lstm(
input=input_reversed_proj,
param_attr=fluid.ParamAttr(name=para_name + '_bw_lstm_w'),
bias_attr=fluid.ParamAttr(name=para_name + '_bw_lstm_b'),
size=gate_size * 4,
is_reverse=True,
use_peepholes=False)
encoder_out = layers.concat(input=[forward, reversed], axis=1)
return encoder_out
def encoder(input_name, para_name, shape, hidden_size, args):
input_ids = layers.data(
name=input_name, shape=[1], dtype='int64', lod_level=1)
input_embedding = layers.embedding(
input=input_ids,
size=shape,
dtype='float32',
param_attr=fluid.ParamAttr(name='embedding_para'))
encoder_out = bi_lstm_encoder(
input_seq=input_embedding,
gate_size=hidden_size,
para_name=para_name,
args=args)
return dropout(encoder_out, args)
def attn_flow(q_enc, p_enc, p_ids_name, args):
tag = p_ids_name + "::"
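# Context-to-query attention (BiDAF): for every passage position, attend
# over all question positions and pool an attended question vector.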
drnn = layers.DynamicRNN()
with drnn.block():
h_cur = drnn.step_input(p_enc)
u_all = drnn.static_input(q_enc)
h_expd = layers.sequence_expand(x=h_cur, y=u_all)
s_t_mul = layers.elementwise_mul(x=u_all, y=h_expd, axis=0)
s_t_sum = layers.reduce_sum(input=s_t_mul, dim=1, keep_dim=True)
s_t_re = layers.reshape(s_t_sum, shape=[-1, 0])
s_t = layers.sequence_softmax(input=s_t_re)
u_expr = layers.elementwise_mul(x=u_all, y=s_t, axis=0)
u_expr = layers.sequence_pool(input=u_expr, pool_type='sum')
b_t = layers.sequence_pool(input=s_t_sum, pool_type='max')
drnn.output(u_expr, b_t)
U_expr, b = drnn()
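# Query-to-context attention: a softmax over the per-position maximum
# similarities yields a single attended passage vector, which is broadcast
# back over the passage positions.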
b_norm = layers.sequence_softmax(input=b)
h_expr = layers.elementwise_mul(x=p_enc, y=b_norm, axis=0)
h_expr = layers.sequence_pool(input=h_expr, pool_type='sum')
H_expr = layers.sequence_expand(x=h_expr, y=p_enc)
H_expr = layers.lod_reset(x=H_expr, y=p_enc)
h_u = layers.elementwise_mul(x=p_enc, y=U_expr, axis=0)
h_h = layers.elementwise_mul(x=p_enc, y=H_expr, axis=0)
g = layers.concat(input=[p_enc, U_expr, h_u, h_h], axis=1)
return dropout(g, args)
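# A single LSTM step assembled from fc layers; used by the pointer-network
# decoder below, which feeds the attended context in as the cell input.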
def lstm_step(x_t, hidden_t_prev, cell_t_prev, size, para_name, args):
def linear(inputs, para_name, args):
return layers.fc(input=inputs,
size=size,
param_attr=fluid.ParamAttr(name=para_name + '_w'),
bias_attr=fluid.ParamAttr(name=para_name + '_b'))
input_cat = layers.concat([hidden_t_prev, x_t], axis=1)
forget_gate = layers.sigmoid(x=linear(input_cat, para_name + '_lstm_f',
args))
input_gate = layers.sigmoid(x=linear(input_cat, para_name + '_lstm_i',
args))
output_gate = layers.sigmoid(x=linear(input_cat, para_name + '_lstm_o',
args))
cell_tilde = layers.tanh(x=linear(input_cat, para_name + '_lstm_c', args))
cell_t = layers.sums(input=[
layers.elementwise_mul(
x=forget_gate, y=cell_t_prev), layers.elementwise_mul(
x=input_gate, y=cell_tilde)
])
hidden_t = layers.elementwise_mul(x=output_gate, y=layers.tanh(x=cell_t))
return hidden_t, cell_t
# Pointer-network decoder (Answer Pointer, following Match-LSTM): predicts
# the start and end positions of the answer span.
def point_network_decoder(p_vec, q_vec, hidden_size, args):
tag = 'pn_decoder:'
init_random = fluid.initializer.Normal(loc=0.0, scale=1.0)
random_attn = layers.create_parameter(
shape=[1, hidden_size],
dtype='float32',
default_initializer=init_random)
random_attn = layers.fc(
input=random_attn,
size=hidden_size,
act=None,
param_attr=fluid.ParamAttr(name=tag + 'random_attn_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'random_attn_fc_b'))
random_attn = layers.reshape(random_attn, shape=[-1])
U = layers.fc(input=q_vec,
param_attr=fluid.ParamAttr(name=tag + 'q_vec_fc_w'),
bias_attr=False,
size=hidden_size,
act=None) + random_attn
U = layers.tanh(U)
logits = layers.fc(input=U,
param_attr=fluid.ParamAttr(name=tag + 'logits_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'logits_fc_b'),
size=1,
act=None)
scores = layers.sequence_softmax(input=logits)
pooled_vec = layers.elementwise_mul(x=q_vec, y=scores, axis=0)
pooled_vec = layers.sequence_pool(input=pooled_vec, pool_type='sum')
init_state = layers.fc(
input=pooled_vec,
param_attr=fluid.ParamAttr(name=tag + 'init_state_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'init_state_fc_b'),
size=hidden_size,
act=None)
def custom_dynamic_rnn(p_vec, init_state, hidden_size, para_name, args):
tag = para_name + "custom_dynamic_rnn:"
def static_rnn(step,
p_vec=p_vec,
init_state=None,
para_name='',
args=args):
tag = para_name + "static_rnn:"
ctx = layers.fc(
input=p_vec,
param_attr=fluid.ParamAttr(name=tag + 'context_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'context_fc_b'),
size=hidden_size,
act=None)
beta = []
c_prev = init_state
m_prev = init_state
for i in range(step):
m_prev0 = layers.fc(
input=m_prev,
size=hidden_size,
act=None,
param_attr=fluid.ParamAttr(name=tag + 'm_prev0_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'm_prev0_fc_b'))
m_prev1 = layers.sequence_expand(x=m_prev0, y=ctx)
Fk = ctx + m_prev1
Fk = layers.tanh(Fk)
logits = layers.fc(
input=Fk,
size=1,
act=None,
param_attr=fluid.ParamAttr(name=tag + 'logits_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'logits_fc_b'))
scores = layers.sequence_softmax(input=logits)
attn_ctx = layers.elementwise_mul(x=p_vec, y=scores, axis=0)
attn_ctx = layers.sequence_pool(input=attn_ctx, pool_type='sum')
hidden_t, cell_t = lstm_step(
attn_ctx,
hidden_t_prev=m_prev,
cell_t_prev=c_prev,
size=hidden_size,
para_name=tag,
args=args)
m_prev = hidden_t
c_prev = cell_t
beta.append(scores)
return beta
return static_rnn(
2, p_vec=p_vec, init_state=init_state, para_name=para_name)
fw_outputs = custom_dynamic_rnn(p_vec, init_state, hidden_size, tag + "fw:",
args)
bw_outputs = custom_dynamic_rnn(p_vec, init_state, hidden_size, tag + "bw:",
args)
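# The forward pass points at (start, end) and the backward pass at
# (end, start); the two predictions for each boundary are averaged.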
start_prob = layers.elementwise_add(
x=fw_outputs[0], y=bw_outputs[1], axis=0) / 2
end_prob = layers.elementwise_add(
x=fw_outputs[1], y=bw_outputs[0], axis=0) / 2
return start_prob, end_prob
def fusion(g, args):
m = bi_lstm_encoder(
input_seq=g, gate_size=args.hidden_size, para_name='fusion', args=args)
return dropout(m, args)
def rc_model(hidden_size, vocab, args):
emb_shape = [vocab.size(), vocab.embed_dim]
# stage 1:encode
p_ids_names = []
q_ids_names = []
ms = []
gs = []
qs = []
for i in range(args.doc_num):
p_ids_name = "pids_%d" % i
p_ids_names.append(p_ids_name)
p_enc_i = encoder(p_ids_name, 'p_enc', emb_shape, hidden_size, args)
q_ids_name = "qids_%d" % i
q_ids_names.append(q_ids_name)
q_enc_i = encoder(q_ids_name, 'q_enc', emb_shape, hidden_size, args)
# stage 2:match
g_i = attn_flow(q_enc_i, p_enc_i, p_ids_name, args)
# stage 3:fusion
m_i = fusion(g_i, args)
ms.append(m_i)
gs.append(g_i)
qs.append(q_enc_i)
m = layers.sequence_concat(input=ms)
g = layers.sequence_concat(input=gs)
q_vec = layers.sequence_concat(input=qs)
# stage 4:decode
start_probs, end_probs = point_network_decoder(
p_vec=m, q_vec=q_vec, hidden_size=hidden_size, args=args)
start_labels = layers.data(
name="start_lables", shape=[1], dtype='float32', lod_level=1)
end_labels = layers.data(
name="end_lables", shape=[1], dtype='float32', lod_level=1)
cost0 = layers.sequence_pool(
layers.cross_entropy(
input=start_probs, label=start_labels, soft_label=True),
'sum')
cost1 = layers.sequence_pool(
layers.cross_entropy(
input=end_probs, label=end_labels, soft_label=True),
'sum')
cost0 = layers.mean(cost0)
cost1 = layers.mean(cost1)
cost = cost0 + cost1
cost.persistable = True
feeding_list = q_ids_names + ["start_lables", "end_lables"] + p_ids_names
return cost, start_probs, end_probs, feeding_list
python train.py \
--trainset 'data/preprocessed/trainset/search.train.json' \
'data/preprocessed/trainset/zhidao.train.json' \
--devset 'data/preprocessed/devset/search.dev.json' \
'data/preprocessed/devset/zhidao.dev.json' \
--testset 'data/preprocessed/testset/search.test.json' \
'data/preprocessed/testset/zhidao.test.json' \
--vocab_dir 'data/vocab' \
--use_gpu true \
--save_dir ./models \
--pass_num 10 \
--learning_rate 0.001 \
--batch_size 8 \
--embed_size 300 \
--hidden_size 150 \
--max_p_num 5 \
--max_p_len 500 \
--max_q_len 60 \
--max_a_len 200 \
--drop_rate 0.2
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import time
import os
import random
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
from paddle.fluid.executor import Executor
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
sys.path.append('..')
from args import *
import rc_model
from dataset import BRCDataset
import logging
import pickle
from utils import normalize
from utils import compute_bleu_rouge
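# prepare_batch_input regroups the flat batch produced by BRCDataset: each
# sample owns doc_num consecutive passage/question entries, which are
# rearranged into the feed order expected by rc_model
# (q_ids + [start_label, end_label] + p_ids).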
def prepare_batch_input(insts, args):
doc_num = args.doc_num
batch_size = len(insts['raw_data'])
new_insts = []
for i in range(batch_size):
p_id = []
q_id = []
p_ids = []
q_ids = []
p_len = 0
for j in range(i * doc_num, (i + 1) * doc_num):
p_ids.append(insts['passage_token_ids'][j])
p_id = p_id + insts['passage_token_ids'][j]
q_ids.append(insts['question_token_ids'][j])
q_id = q_id + insts['question_token_ids'][j]
p_len = len(p_id)
def _get_label(idx, ref_len):
ret = [0.0] * ref_len
if idx >= 0 and idx < ref_len:
ret[idx] = 1.0
return [[x] for x in ret]
start_label = _get_label(insts['start_id'][i], p_len)
end_label = _get_label(insts['end_id'][i], p_len)
new_inst = q_ids + [start_label, end_label] + p_ids
new_insts.append(new_inst)
return new_insts
def LodTensor_Array(lod_tensor):
lod = lod_tensor.lod()
array = np.array(lod_tensor)
new_array = []
for i in range(len(lod[0]) - 1):
new_array.append(array[lod[0][i]:lod[0][i + 1]])
return new_array
def print_para(train_prog, train_exe, logger, args):
if args.para_print:
param_list = train_prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list]
num_sum = 0
for p_name in param_name_list:
p_array = np.array(train_exe.scope.find_var(p_name).get_tensor())
param_num = np.prod(p_array.shape)
num_sum = num_sum + param_num
print("param: {0}, mean={1} max={2} min={3} num={4} {5}".format(
p_name,
p_array.mean(),
p_array.max(), p_array.min(), p_array.shape, param_num))
print("total param num: {0}".format(num_sum))
def find_best_answer_for_passage(start_probs, end_probs, passage_len, args):
"""
Finds the best answer with the maximum start_prob * end_prob from a single passage
"""
if passage_len is None:
passage_len = len(start_probs)
else:
passage_len = min(len(start_probs), passage_len)
best_start, best_end, max_prob = -1, -1, 0
for start_idx in range(passage_len):
for ans_len in range(args.max_a_len):
end_idx = start_idx + ans_len
if end_idx >= passage_len:
continue
prob = start_probs[start_idx] * end_probs[end_idx]
if prob > max_prob:
best_start = start_idx
best_end = end_idx
max_prob = prob
return (best_start, best_end), max_prob
def find_best_answer(sample, start_prob, end_prob, padded_p_len, args):
"""
Finds the best answer for a sample given start_prob and end_prob for each position.
This will call find_best_answer_for_passage because there are multiple passages in a sample
"""
best_p_idx, best_span, best_score = None, None, 0
for p_idx, passage in enumerate(sample['passages']):
if p_idx >= args.max_p_num:
continue
passage_len = min(args.max_p_len, len(passage['passage_tokens']))
answer_span, score = find_best_answer_for_passage(
start_prob[p_idx * padded_p_len:(p_idx + 1) * padded_p_len],
end_prob[p_idx * padded_p_len:(p_idx + 1) * padded_p_len],
passage_len, args)
if score > best_score:
best_score = score
best_p_idx = p_idx
best_span = answer_span
if best_p_idx is None or best_span is None:
best_answer = ''
else:
best_answer = ''.join(sample['passages'][best_p_idx]['passage_tokens'][
best_span[0]:best_span[1] + 1])
return best_answer
def validation(exe, inference_program, avg_cost, s_probs, e_probs, feed_order,
place, vocab, brc_data, args):
"""
"""
# Use test set as validation each pass
total_loss = 0.0
count = 0
pred_answers, ref_answers = [], []
val_feed_list = [
inference_program.global_block().var(var_name)
for var_name in feed_order
]
val_feeder = fluid.DataFeeder(val_feed_list, place)
pad_id = vocab.get_id(vocab.pad_token)
dev_batches = brc_data.gen_mini_batches(
'dev', args.batch_size, pad_id, shuffle=False)
for batch_id, batch in enumerate(dev_batches, 1):
feed_data = prepare_batch_input(batch, args)
val_fetch_outs = exe.run(inference_program,
feed=val_feeder.feed(feed_data),
fetch_list=[avg_cost, s_probs, e_probs],
return_numpy=False)
total_loss += np.array(val_fetch_outs[0])[0]
start_probs = LodTensor_Array(val_fetch_outs[1])
end_probs = LodTensor_Array(val_fetch_outs[2])
count += len(batch['raw_data'])
padded_p_len = len(batch['passage_token_ids'][0])
for sample, start_prob, end_prob in zip(batch['raw_data'], start_probs,
end_probs):
best_answer = find_best_answer(sample, start_prob, end_prob,
padded_p_len, args)
pred_answers.append({
'question_id': sample['question_id'],
'question_type': sample['question_type'],
'answers': [best_answer],
'entity_answers': [[]],
'yesno_answers': []
})
if 'answers' in sample:
ref_answers.append({
'question_id': sample['question_id'],
'question_type': sample['question_type'],
'answers': sample['answers'],
'entity_answers': [[]],
'yesno_answers': []
})
ave_loss = 1.0 * total_loss / count
# compute the bleu and rouge scores if reference answers are provided
if len(ref_answers) > 0:
pred_dict, ref_dict = {}, {}
for pred, ref in zip(pred_answers, ref_answers):
question_id = ref['question_id']
if len(ref['answers']) > 0:
pred_dict[question_id] = normalize(pred['answers'])
ref_dict[question_id] = normalize(ref['answers'])
bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
else:
bleu_rouge = None
return ave_loss, bleu_rouge
def train():
args = parse_args()
random.seed(args.random_seed)
np.random.seed(args.random_seed)
fluid.framework.default_startup_program().random_seed = args.random_seed
fluid.default_main_program().random_seed = args.random_seed
logger = logging.getLogger("brc")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
if args.log_path:
file_handler = logging.FileHandler(args.log_path)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
else:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
logger.info('Running with args : {}'.format(args))
logger.info('Load data_set and vocab...')
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
vocab = pickle.load(fin)
logger.info('vocab size is {} and embed dim is {}'.format(vocab.size(
), vocab.embed_dim))
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
args.trainset, args.devset)
logger.info('Converting text into ids...')
brc_data.convert_to_ids(vocab)
logger.info('Initialize the model...')
# build model
avg_cost, s_probs, e_probs, feed_order = rc_model.rc_model(args.hidden_size,
vocab, args)
# clone from default main program and use it as the validation program
main_program = fluid.default_main_program()
inference_program = fluid.default_main_program().clone(for_test=True)
# build optimizer
if args.optim == 'sgd':
optimizer = fluid.optimizer.SGD(learning_rate=args.learning_rate)
elif args.optim == 'adam':
optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
elif args.optim == 'rprop':
optimizer = fluid.optimizer.RMSPropOptimizer(
learning_rate=args.learning_rate)
else:
logger.error('Unsupported optimizer: {}'.format(args.optim))
exit(-1)
optimizer.minimize(avg_cost)
# initialize parameters
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
exe = Executor(place)
if args.load_dir:
logger.info('load from {}'.format(args.load_dir))
fluid.io.load_persistables(exe, args.load_dir)
else:
exe.run(framework.default_startup_program())
embedding_para = fluid.global_scope().find_var(
'embedding_para').get_tensor()
embedding_para.set(vocab.embeddings.astype(np.float32), place)
# prepare data
feed_list = [
main_program.global_block().var(var_name) for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
logger.info('Training the model...')
parallel_executor = fluid.ParallelExecutor(
use_cuda=bool(args.use_gpu), loss_name=avg_cost.name)
print_para(fluid.framework.default_main_program(), parallel_executor,
logger, args)
for pass_id in range(1, args.pass_num + 1):
pass_start_time = time.time()
pad_id = vocab.get_id(vocab.pad_token)
train_batches = brc_data.gen_mini_batches(
'train', args.batch_size, pad_id, shuffle=True)
log_every_n_batch, n_batch_loss = args.log_interval, 0
total_num, total_loss = 0, 0
for batch_id, batch in enumerate(train_batches, 1):
input_data_dict = prepare_batch_input(batch, args)
fetch_outs = parallel_executor.run(
feed=feeder.feed(input_data_dict),
fetch_list=[avg_cost.name],
return_numpy=False)
cost_train = np.array(fetch_outs[0])[0]
total_num += len(batch['raw_data'])
n_batch_loss += cost_train
total_loss += cost_train * len(batch['raw_data'])
if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
print_para(fluid.framework.default_main_program(),
parallel_executor, logger, args)
logger.info('Average loss from batch {} to {} is {}'.format(
batch_id - log_every_n_batch + 1, batch_id, "%.10f" % (
n_batch_loss / log_every_n_batch)))
n_batch_loss = 0
if args.dev_interval > 0 and batch_id % args.dev_interval == 0:
eval_loss, bleu_rouge = validation(
exe, inference_program, avg_cost, s_probs, e_probs,
feed_order, place, vocab, brc_data, args)
logger.info('Dev eval loss {}'.format(eval_loss))
logger.info('Dev eval result: {}'.format(bleu_rouge))
pass_end_time = time.time()
logger.info('Evaluating the model after epoch {}'.format(pass_id))
if brc_data.dev_set is not None:
eval_loss, bleu_rouge = validation(exe, inference_program, avg_cost,
s_probs, e_probs, feed_order,
place, vocab, brc_data, args)
logger.info('Dev eval loss {}'.format(eval_loss))
logger.info('Dev eval result: {}'.format(bleu_rouge))
else:
logger.warning(
'No dev set is loaded for evaluation in the dataset!')
time_consumed = pass_end_time - pass_start_time
logger.info('Average train loss for epoch {} is {}'.format(
pass_id, "%.10f" % (1.0 * total_loss / total_num)))
if pass_id % args.save_interval == 0:
model_path = os.path.join(args.save_dir, str(pass_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(
executor=exe, dirname=model_path, main_program=main_program)
if __name__ == '__main__':
train()
# coding:utf8
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This package implements some utility functions shared by PaddlePaddle
and Tensorflow model implementations.
Authors: liuyuan(liuyuan04@baidu.com)
Date: 2017/10/06 18:23:06
"""
from .dureader_eval import compute_bleu_rouge
from .dureader_eval import normalize
from .preprocess import find_fake_answer
from .preprocess import find_best_question_match
__all__ = [
'compute_bleu_rouge',
'normalize',
'find_fake_answer',
'find_best_question_match',
]
#!/bin/bash
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# We use Bleu and Rouge as evaluation metrics; the calculation of these
# metrics relies on the scoring scripts under https://github.com/tylin/coco-caption
bleu_base_url='https://raw.githubusercontent.com/tylin/coco-caption/master/pycocoevalcap/bleu'
bleu_files=("LICENSE" "__init__.py" "bleu.py" "bleu_scorer.py")
rouge_base_url="https://raw.githubusercontent.com/tylin/coco-caption/master/pycocoevalcap/rouge"
rouge_files=("__init__.py" "rouge.py")
download() {
local metric=$1; shift;
local base_url=$1; shift;
local fnames=($@);
mkdir -p ${metric}
for fname in ${fnames[@]};
do
printf "downloading: %s\n" ${base_url}/${fname}
wget --no-check-certificate ${base_url}/${fname} -O ${metric}/${fname}
done
}
# prepare rouge
download "rouge_metric" ${rouge_base_url} ${rouge_files[@]}
# prepare bleu
download "bleu_metric" ${bleu_base_url} ${bleu_files[@]}
# convert python 2.x source code to python 3.x
2to3 -w "../utils/bleu_metric/bleu_scorer.py"
2to3 -w "../utils/bleu_metric/bleu.py"
# -*- coding:utf8 -*-
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module computes evaluation metrics for DuReader dataset.
"""
import argparse
import json
import sys
import zipfile
from collections import Counter
from .bleu_metric.bleu import Bleu
from .rouge_metric.rouge import Rouge
EMPTY = ''
YESNO_LABELS = set(['Yes', 'No', 'Depends'])
def normalize(s):
"""
Normalize strings to space-joined chars.
Args:
s: a list of strings.
Returns:
A list of normalized strings.
"""
if not s:
return s
normalized = []
for ss in s:
tokens = [c for c in list(ss) if len(c.strip()) != 0]
normalized.append(' '.join(tokens))
return normalized
def data_check(obj, task):
"""
Check that a result object contains all required fields.
Raises:
AssertionError: when the data is not legal.
"""
assert 'question_id' in obj, "Missing 'question_id' field."
assert 'question_type' in obj, \
"Missing 'question_type' field. question_id: {}".format(obj['question_id'])
assert 'yesno_answers' in obj, \
"Missing 'yesno_answers' field. question_id: {}".format(obj['question_id'])
assert isinstance(obj['yesno_answers'], list), \
r"""'yesno_answers' field must be a list, if the 'question_type' is not
'YES_NO', then this field should be an empty list.
question_id: {}""".format(obj['question_id'])
assert 'entity_answers' in obj, \
"Missing 'entity_answers' field. question_id: {}".format(obj['question_id'])
assert isinstance(obj['entity_answers'], list) \
and len(obj['entity_answers']) > 0, \
r"""'entity_answers' field must be a list, and has at least one element,
which can be a empty list. question_id: {}""".format(obj['question_id'])
def read_file(file_name, task, is_ref=False):
"""
Read predict answers or reference answers from file.
Args:
file_name: the name of the file containing predict result or reference
result.
Returns:
A dictionary mapping question_id to the result information. The result
information itself is also a dictionary that has four keys:
- question_type: type of the query.
- yesno_answers: A list of yesno answers corresponding to 'answers'.
- answers: A list of predicted answers.
- entity_answers: A list, each element is also a list containing the entities
tagged out from the corresponding answer string.
"""
def _open(file_name, mode, zip_obj=None):
if zip_obj is not None:
return zip_obj.open(file_name, mode)
return open(file_name, mode)
results = {}
keys = ['answers', 'yesno_answers', 'entity_answers', 'question_type']
if is_ref:
keys += ['source']
zf = zipfile.ZipFile(file_name, 'r') if file_name.endswith('.zip') else None
file_list = [file_name] if zf is None else zf.namelist()
for fn in file_list:
for line in _open(fn, 'r', zip_obj=zf):
try:
obj = json.loads(line.strip())
except ValueError:
raise ValueError("Every line of data should be legal json")
data_check(obj, task)
qid = obj['question_id']
assert qid not in results, "Duplicate question_id: {}".format(qid)
results[qid] = {}
for k in keys:
results[qid][k] = obj[k]
return results
def compute_bleu_rouge(pred_dict, ref_dict, bleu_order=4):
"""
Compute bleu and rouge scores.
"""
assert set(pred_dict.keys()) == set(ref_dict.keys()), \
"missing keys: {}".format(set(ref_dict.keys()) - set(pred_dict.keys()))
scores = {}
bleu_scores, _ = Bleu(bleu_order).compute_score(ref_dict, pred_dict)
for i, bleu_score in enumerate(bleu_scores):
scores['Bleu-%d' % (i + 1)] = bleu_score
rouge_score, _ = Rouge().compute_score(ref_dict, pred_dict)
scores['Rouge-L'] = rouge_score
return scores
def local_prf(pred_list, ref_list):
"""
Compute local precision recall and f1-score,
given only one prediction list and one reference list
"""
common = Counter(pred_list) & Counter(ref_list)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
p = 1.0 * num_same / len(pred_list)
r = 1.0 * num_same / len(ref_list)
f1 = (2 * p * r) / (p + r)
return p, r, f1
def compute_prf(pred_dict, ref_dict):
"""
Compute precision recall and f1-score.
"""
pred_question_ids = set(pred_dict.keys())
ref_question_ids = set(ref_dict.keys())
correct_preds, total_correct, total_preds = 0, 0, 0
for question_id in ref_question_ids:
pred_entity_list = pred_dict.get(question_id, [[]])
assert len(pred_entity_list) == 1, \
'the number of entity list for question_id {} is not 1.'.format(question_id)
pred_entity_list = pred_entity_list[0]
all_ref_entity_lists = ref_dict[question_id]
best_local_f1 = 0
best_ref_entity_list = None
for ref_entity_list in all_ref_entity_lists:
local_f1 = local_prf(pred_entity_list, ref_entity_list)[2]
if local_f1 > best_local_f1:
best_ref_entity_list = ref_entity_list
best_local_f1 = local_f1
if best_ref_entity_list is None:
if len(all_ref_entity_lists) > 0:
best_ref_entity_list = sorted(
all_ref_entity_lists, key=lambda x: len(x))[0]
else:
best_ref_entity_list = []
gold_entities = set(best_ref_entity_list)
pred_entities = set(pred_entity_list)
correct_preds += len(gold_entities & pred_entities)
total_preds += len(pred_entities)
total_correct += len(gold_entities)
p = float(correct_preds) / total_preds if correct_preds > 0 else 0
r = float(correct_preds) / total_correct if correct_preds > 0 else 0
f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0
return {'Precision': p, 'Recall': r, 'F1': f1}
def prepare_prf(pred_dict, ref_dict):
"""
Prepares data for calculation of prf scores.
"""
preds = {k: v['entity_answers'] for k, v in pred_dict.items()}
refs = {k: v['entity_answers'] for k, v in ref_dict.items()}
return preds, refs
def filter_dict(result_dict, key_tag):
"""
Filter a subset of the result_dict, whose keys end with 'key_tag'.
"""
filtered = {}
for k, v in result_dict.items():
if k.endswith(key_tag):
filtered[k] = v
return filtered
def get_metrics(pred_result, ref_result, task, source):
"""
Computes metrics.
"""
metrics = {}
ref_result_filtered = {}
pred_result_filtered = {}
if source == 'both':
ref_result_filtered = ref_result
pred_result_filtered = pred_result
else:
for question_id, info in ref_result.items():
if info['source'] == source:
ref_result_filtered[question_id] = info
if question_id in pred_result:
pred_result_filtered[question_id] = pred_result[question_id]
if task == 'main' or task == 'all' \
or task == 'description':
pred_dict, ref_dict = prepare_bleu(pred_result_filtered,
ref_result_filtered, task)
metrics = compute_bleu_rouge(pred_dict, ref_dict)
elif task == 'yesno':
pred_dict, ref_dict = prepare_bleu(pred_result_filtered,
ref_result_filtered, task)
keys = ['Yes', 'No', 'Depends']
preds = [filter_dict(pred_dict, k) for k in keys]
refs = [filter_dict(ref_dict, k) for k in keys]
metrics = compute_bleu_rouge(pred_dict, ref_dict)
for k, pred, ref in zip(keys, preds, refs):
m = compute_bleu_rouge(pred, ref)
k_metric = [(k + '|' + key, v) for key, v in m.items()]
metrics.update(k_metric)
elif task == 'entity':
pred_dict, ref_dict = prepare_prf(pred_result_filtered,
ref_result_filtered)
pred_dict_bleu, ref_dict_bleu = prepare_bleu(pred_result_filtered,
ref_result_filtered, task)
metrics = compute_prf(pred_dict, ref_dict)
metrics.update(compute_bleu_rouge(pred_dict_bleu, ref_dict_bleu))
else:
raise ValueError("Illegal task name: {}".format(task))
return metrics
def prepare_bleu(pred_result, ref_result, task):
"""
Prepares data for calculation of bleu and rouge scores.
"""
pred_list, ref_list = [], []
qids = ref_result.keys()
for qid in qids:
if task == 'main':
pred, ref = get_main_result(qid, pred_result, ref_result)
elif task == 'yesno':
pred, ref = get_yesno_result(qid, pred_result, ref_result)
elif task == 'all':
pred, ref = get_all_result(qid, pred_result, ref_result)
elif task == 'entity':
pred, ref = get_entity_result(qid, pred_result, ref_result)
elif task == 'description':
pred, ref = get_desc_result(qid, pred_result, ref_result)
else:
raise ValueError("Illegal task name: {}".format(task))
if pred and ref:
pred_list += pred
ref_list += ref
pred_dict = dict(pred_list)
ref_dict = dict(ref_list)
for qid, ans in ref_dict.items():
ref_dict[qid] = normalize(ref_dict[qid])
pred_dict[qid] = normalize(pred_dict.get(qid, [EMPTY]))
if not ans or ans == [EMPTY]:
del ref_dict[qid]
del pred_dict[qid]
for k, v in pred_dict.items():
assert len(v) == 1, \
"There should be only one predict answer. question_id: {}".format(k)
return pred_dict, ref_dict
def get_main_result(qid, pred_result, ref_result):
"""
Prepare answers for task 'main'.
Args:
qid: question_id.
pred_result: A dict that includes all question_id's result information, read
from args.pred_file.
ref_result: A dict that includes all question_id's result information, read
from args.ref_file.
Returns:
Two lists, the first one contains predict result, the second
one contains reference result of the same question_id. Each list has
elements of tuple (question_id, answers), 'answers' is a list of strings.
"""
ref_ans = ref_result[qid]['answers']
if not ref_ans:
ref_ans = [EMPTY]
pred_ans = pred_result.get(qid, {}).get('answers', [])[:1]
if not pred_ans:
pred_ans = [EMPTY]
return [(qid, pred_ans)], [(qid, ref_ans)]
def get_entity_result(qid, pred_result, ref_result):
"""
Prepare answers for task 'entity'.
Args:
qid: question_id.
pred_result: A dict that includes all question_id's result information, read
from args.pred_file.
ref_result: A dict that includes all question_id's result information, read
from args.ref_file.
Returns:
Two lists, the first one contains predict result, the second
one contains reference result of the same question_id. Each list has
elements of tuple (question_id, answers), 'answers' is a list of strings.
"""
if ref_result[qid]['question_type'] != 'ENTITY':
return None, None
return get_main_result(qid, pred_result, ref_result)
def get_desc_result(qid, pred_result, ref_result):
"""
Prepare answers for task 'description'.
Args:
qid: question_id.
pred_result: A dict that includes all question_id's result information, read
from args.pred_file.
ref_result: A dict that includes all question_id's result information, read
from args.ref_file.
Returns:
Two lists, the first one contains predict result, the second
one contains reference result of the same question_id. Each list has
elements of tuple (question_id, answers), 'answers' is a list of strings.
"""
if ref_result[qid]['question_type'] != 'DESCRIPTION':
return None, None
return get_main_result(qid, pred_result, ref_result)
def get_yesno_result(qid, pred_result, ref_result):
"""
Prepare answers for task 'yesno'.
Args:
qid: question_id.
pred_result: A dict that includes all question_id's result information, read
from args.pred_file.
ref_result: A dict that includes all question_id's result information, read
from args.ref_file.
Returns:
Two lists, the first one contains predict result, the second
one contains reference result of the same question_id. Each list has
elements of tuple (question_id, answers), 'answers' is a list of strings.
"""
def _uniq(li, is_ref):
uniq_li = []
left = []
keys = set()
for k, v in li:
if k not in keys:
uniq_li.append((k, v))
keys.add(k)
else:
left.append((k, v))
if is_ref:
dict_li = dict(uniq_li)
for k, v in left:
dict_li[k] += v
uniq_li = [(k, v) for k, v in dict_li.items()]
return uniq_li
def _expand_result(uniq_li):
expanded = uniq_li[:]
keys = set([x[0] for x in uniq_li])
for k in YESNO_LABELS - keys:
expanded.append((k, [EMPTY]))
return expanded
def _get_yesno_ans(qid, result_dict, is_ref=False):
if qid not in result_dict:
return [(str(qid) + '_' + k, v) for k, v in _expand_result([])]
yesno_answers = result_dict[qid]['yesno_answers']
answers = result_dict[qid]['answers']
lbl_ans = _uniq([(k, [v]) for k, v in zip(yesno_answers, answers)],
is_ref)
ret = [(str(qid) + '_' + k, v) for k, v in _expand_result(lbl_ans)]
return ret
if ref_result[qid]['question_type'] != 'YES_NO':
return None, None
ref_ans = _get_yesno_ans(qid, ref_result, is_ref=True)
pred_ans = _get_yesno_ans(qid, pred_result)
return pred_ans, ref_ans
def get_all_result(qid, pred_result, ref_result):
"""
Prepare answers for task 'all'.
Args:
qid: question_id.
pred_result: A dict that includes all question_id's result information, read
from args.pred_file.
ref_result: A dict that includes all question_id's result information, read
from args.ref_file.
Returns:
Two lists, the first one contains predict result, the second
one contains reference result of the same question_id. Each list has
elements of tuple (question_id, answers), 'answers' is a list of strings.
"""
if ref_result[qid]['question_type'] == 'YES_NO':
return get_yesno_result(qid, pred_result, ref_result)
return get_main_result(qid, pred_result, ref_result)
def format_metrics(metrics, task, err_msg):
"""
Format metrics. The 'err' field returns any error that occurred during evaluation.
Args:
metrics: A dict object contains metrics for different tasks.
task: Task name.
err_msg: Exception raised during evaluation.
Returns:
Formatted result.
"""
result = {}
sources = ["both", "search", "zhidao"]
if err_msg is not None:
return {'errorMsg': str(err_msg), 'errorCode': 1, 'data': []}
data = []
if task != 'all' and task != 'main':
sources = ["both"]
if task == 'entity':
metric_names = ["Bleu-4", "Rouge-L"]
metric_names_prf = ["F1", "Precision", "Recall"]
for name in metric_names + metric_names_prf:
for src in sources:
obj = {
"name": name,
"value": round(metrics[src].get(name, 0) * 100, 2),
"type": src,
}
data.append(obj)
elif task == 'yesno':
metric_names = ["Bleu-4", "Rouge-L"]
details = ["Yes", "No", "Depends"]
src = sources[0]
for name in metric_names:
obj = {
"name": name,
"value": round(metrics[src].get(name, 0) * 100, 2),
"type": 'All',
}
data.append(obj)
for d in details:
obj = {
"name": name,
"value": \
round(metrics[src].get(d + '|' + name, 0) * 100, 2),
"type": d,
}
data.append(obj)
else:
metric_names = ["Bleu-4", "Rouge-L"]
for name in metric_names:
for src in sources:
obj = {
"name": name,
"value": \
round(metrics[src].get(name, 0) * 100, 2),
"type": src,
}
data.append(obj)
result["data"] = data
result["errorCode"] = 0
result["errorMsg"] = "success"
return result
def main(args):
"""
Do evaluation.
"""
err = None
metrics = {}
try:
pred_result = read_file(args.pred_file, args.task)
ref_result = read_file(args.ref_file, args.task, is_ref=True)
sources = ['both', 'search', 'zhidao']
if args.task not in set(['main', 'all']):
sources = sources[:1]
for source in sources:
metrics[source] = get_metrics(pred_result, ref_result, args.task,
source)
except ValueError as ve:
err = ve
except AssertionError as ae:
err = ae
print(
json.dumps(
format_metrics(metrics, args.task, err), ensure_ascii=False).encode(
'utf8'))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('pred_file', help='predict file')
parser.add_argument('ref_file', help='reference file')
parser.add_argument(
'task', help='task name: Main|Yes_No|All|Entity|Description')
args = parser.parse_args()
args.task = args.task.lower().replace('_', '')
main(args)
# -*- coding:utf8 -*-
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Utility function to generate vocabulary file.
"""
import argparse
import sys
import json
from itertools import chain
def get_vocab(files, vocab_file):
"""
Builds a vocabulary file from the fields 'segmented_paragraphs'
and 'segmented_question'.
Args:
files: A list of file names.
vocab_file: The file that stores the vocabulary.
"""
vocab = {}
for f in files:
with open(f, 'r') as fin:
for line in fin:
obj = json.loads(line.strip())
paras = [
chain(*d['segmented_paragraphs']) for d in obj['documents']
]
doc_tokens = chain(*paras)
question_tokens = obj['segmented_question']
for t in list(doc_tokens) + question_tokens:
vocab[t] = vocab.get(t, 0) + 1
# output
sorted_vocab = sorted(
[(v, c) for v, c in vocab.items()], key=lambda x: x[1], reverse=True)
with open(vocab_file, 'w') as outf:
for w, c in sorted_vocab:
print >> outf, '{}\t{}'.format(w.encode('utf8'), c)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--files',
nargs='+',
required=True,
help='file list to count vocab from.')
parser.add_argument(
'--vocab', required=True, help='file to store counted vocab.')
args = parser.parse_args()
get_vocab(args.files, args.vocab)
#coding=utf8
import os, sys, json
import nltk
def _nltk_tokenize(sequence):
tokens = nltk.word_tokenize(sequence)
cur_char_offset = 0
token_offsets = []
token_words = []
for token in tokens:
cur_char_offset = sequence.find(token, cur_char_offset)
token_offsets.append(
[cur_char_offset, cur_char_offset + len(token) - 1])
token_words.append(token)
return token_offsets, token_words
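# A minimal sketch of what _nltk_tokenize returns (offsets are inclusive
# [start, end] character positions into the original string):
#   _nltk_tokenize("hello world")
#   -> ([[0, 4], [6, 10]], ['hello', 'world'])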
def segment(input_js):
_, input_js['segmented_question'] = _nltk_tokenize(input_js['question'])
for doc_id, doc in enumerate(input_js['documents']):
doc['segmented_title'] = []
doc['segmented_paragraphs'] = []
for para_id, para in enumerate(doc['paragraphs']):
_, seg_para = _nltk_tokenize(para)
doc['segmented_paragraphs'].append(seg_para)
if 'answers' in input_js:
input_js['segmented_answers'] = []
for answer_id, answer in enumerate(input_js['answers']):
_, seg_answer = _nltk_tokenize(answer)
input_js['segmented_answers'].append(seg_answer)
if __name__ == '__main__':
if len(sys.argv) != 2:
        print('Usage: marco_tokenize_data.py <input_path>')
exit()
nltk.download('punkt')
for line in open(sys.argv[1]):
dureader_js = json.loads(line.strip())
segment(dureader_js)
print(json.dumps(dureader_js))
#coding=utf8
import sys
import json
import pandas as pd
def trans(input_js):
output_js = {}
output_js['question'] = input_js['query']
output_js['question_type'] = input_js['query_type']
output_js['question_id'] = input_js['query_id']
output_js['fact_or_opinion'] = ""
output_js['documents'] = []
for para_id, para in enumerate(input_js['passages']):
doc = {}
doc['title'] = ""
if 'is_selected' in para:
            doc['is_selected'] = para['is_selected'] != 0
doc['paragraphs'] = [para['passage_text']]
output_js['documents'].append(doc)
if 'answers' in input_js:
output_js['answers'] = input_js['answers']
return output_js
if __name__ == '__main__':
if len(sys.argv) != 2:
print('Usage: marcov1_to_dureader.py <input_path>')
exit()
df = pd.read_json(sys.argv[1])
for row in df.iterrows():
marco_js = json.loads(row[1].to_json())
dureader_js = trans(marco_js)
print(json.dumps(dureader_js))
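# Field mapping applied by trans() (MS MARCO v1 -> DuReader):
#   query -> question, query_type -> question_type, query_id -> question_id,
#   passages[i]['passage_text'] -> documents[i]['paragraphs'][0]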
import sys
import json
import pandas as pd
if __name__ == '__main__':
if len(sys.argv) != 3:
        print('Usage: marcov2_to_v1_tojsonl.py <input_path> <output_path>')
exit()
infile = sys.argv[1]
outfile = sys.argv[2]
df = pd.read_json(infile)
with open(outfile, 'w') as f:
for row in df.iterrows():
f.write(row[1].to_json() + '\n')
###############################################################################
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module finds the most related paragraph of each document according to recall.
"""
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
import json
from collections import Counter
def precision_recall_f1(prediction, ground_truth):
"""
This function calculates and returns the precision, recall and f1-score
Args:
prediction: prediction string or list to be matched
ground_truth: golden string or list reference
Returns:
floats of (p, r, f1)
Raises:
None
"""
if not isinstance(prediction, list):
prediction_tokens = prediction.split()
else:
prediction_tokens = prediction
if not isinstance(ground_truth, list):
ground_truth_tokens = ground_truth.split()
else:
ground_truth_tokens = ground_truth
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
p = 1.0 * num_same / len(prediction_tokens)
r = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * p * r) / (p + r)
return p, r, f1
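# Worked example (a sketch): with prediction ['a', 'b', 'c'] and
# ground_truth ['a', 'b', 'd'], the bag intersection holds 2 tokens, so
# p = 2/3, r = 2/3 and f1 = (2 * p * r) / (p + r) = 2/3.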
def recall(prediction, ground_truth):
"""
This function calculates and returns the recall
Args:
prediction: prediction string or list to be matched
ground_truth: golden string or list reference
    Returns:
        float of recall
Raises:
None
"""
return precision_recall_f1(prediction, ground_truth)[1]
def f1_score(prediction, ground_truth):
"""
This function calculates and returns the f1-score
Args:
prediction: prediction string or list to be matched
ground_truth: golden string or list reference
    Returns:
        float of f1
Raises:
None
"""
return precision_recall_f1(prediction, ground_truth)[2]
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""
This function calculates and returns the precision, recall and f1-score
Args:
metric_fn: metric function pointer which calculates scores according to corresponding logic.
prediction: prediction string or list to be matched
ground_truth: golden string or list reference
Returns:
floats of (p, r, f1)
Raises:
None
"""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def find_best_question_match(doc, question, with_score=False):
"""
For each docment, find the paragraph that matches best to the question.
Args:
doc: The document object.
question: The question tokens.
with_score: If True then the match score will be returned,
otherwise False.
Returns:
The index of the best match paragraph, if with_score=False,
otherwise returns a tuple of the index of the best match paragraph
and the match score of that paragraph.
"""
most_related_para = -1
max_related_score = 0
most_related_para_len = 0
for p_idx, para_tokens in enumerate(doc['segmented_paragraphs']):
if len(question) > 0:
related_score = metric_max_over_ground_truths(recall, para_tokens,
question)
else:
related_score = 0
if related_score > max_related_score \
or (related_score == max_related_score \
and len(para_tokens) < most_related_para_len):
most_related_para = p_idx
max_related_score = related_score
most_related_para_len = len(para_tokens)
if most_related_para == -1:
most_related_para = 0
if with_score:
return most_related_para, max_related_score
return most_related_para
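# Note on the selection logic above (a sketch): if two paragraphs reach the
# same recall over the question tokens, the shorter paragraph wins, and
# paragraph 0 is used as a fallback when no paragraph matches at all.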
def find_fake_answer(sample):
"""
    For each document, finds the most related paragraph based on recall,
    then finds the span that maximizes the f1_score against the gold answers
    and uses this span as a fake answer span.
Args:
sample: a sample in the dataset
Returns:
None
Raises:
None
"""
for doc in sample['documents']:
most_related_para = -1
most_related_para_len = 999999
max_related_score = 0
for p_idx, para_tokens in enumerate(doc['segmented_paragraphs']):
if len(sample['segmented_answers']) > 0:
related_score = metric_max_over_ground_truths(
recall, para_tokens, sample['segmented_answers'])
else:
continue
if related_score > max_related_score \
or (related_score == max_related_score
and len(para_tokens) < most_related_para_len):
most_related_para = p_idx
most_related_para_len = len(para_tokens)
max_related_score = related_score
doc['most_related_para'] = most_related_para
sample['answer_docs'] = []
sample['answer_spans'] = []
sample['fake_answers'] = []
sample['match_scores'] = []
best_match_score = 0
best_match_d_idx, best_match_span = -1, [-1, -1]
best_fake_answer = None
answer_tokens = set()
    for segmented_answer in sample['segmented_answers']:
        answer_tokens |= set(segmented_answer)
for d_idx, doc in enumerate(sample['documents']):
if not doc['is_selected']:
continue
if doc['most_related_para'] == -1:
doc['most_related_para'] = 0
most_related_para_tokens = doc['segmented_paragraphs'][doc[
'most_related_para']][:1000]
for start_tidx in range(len(most_related_para_tokens)):
if most_related_para_tokens[start_tidx] not in answer_tokens:
continue
for end_tidx in range(
len(most_related_para_tokens) - 1, start_tidx - 1, -1):
span_tokens = most_related_para_tokens[start_tidx:end_tidx + 1]
if len(sample['segmented_answers']) > 0:
match_score = metric_max_over_ground_truths(
f1_score, span_tokens, sample['segmented_answers'])
else:
match_score = 0
if match_score == 0:
break
if match_score > best_match_score:
best_match_d_idx = d_idx
best_match_span = [start_tidx, end_tidx]
best_match_score = match_score
best_fake_answer = ''.join(span_tokens)
if best_match_score > 0:
sample['answer_docs'].append(best_match_d_idx)
sample['answer_spans'].append(best_match_span)
sample['fake_answers'].append(best_fake_answer)
sample['match_scores'].append(best_match_score)
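# A minimal sketch of the span search above: if the most related paragraph is
# ['病毒', '通过', '飞沫', '传播'] and one gold answer is ['通过', '飞沫', '传播'],
# the span [1, 3] reaches f1 = 1.0 and is recorded as the fake answer
# '通过飞沫传播' (span tokens are joined without separators, which suits the
# Chinese text this script targets).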
if __name__ == '__main__':
for line in sys.stdin:
sample = json.loads(line)
find_fake_answer(sample)
        # the 'encoding' kwarg of json.dumps only exists under Python 2 and
        # raises TypeError under Python 3, so it is omitted here
        print(json.dumps(sample, ensure_ascii=False))
#!/bin/bash
input_file=$1
output_file=$2
# convert the data from MARCO V2 (json) format to MARCO V1 (jsonl) format.
# the script was forked from the MARCO repo.
# the MARCO V1 format is much easier to explore.
python3 marcov2_to_v1_tojsonl.py $input_file $input_file.marcov1
# convert the data from MARCO V1 format to DuReader format.
python3 marcov1_to_dureader.py $input_file.marcov1 >$input_file.dureader_raw
# tokenize the data.
python3 marco_tokenize_data.py $input_file.dureader_raw >$input_file.segmented
# find fake answers (indicating the start and end positions of answers in the document) for train and dev sets.
# note that this should not be applied for test set, since there is no ground truth in test set.
python preprocess.py $input_file.segmented >$output_file
# remove the temporary data files.
rm -rf $input_file.dureader_raw $input_file.segmented
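# Example invocation (hypothetical file names):
#   bash run_marco.sh dev_v2.1.json dev.preprocessed.json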