Unverified · Commit 1388a943 · Authored by: xuezhong · Committed by: GitHub

Merge pull request #1288 from xuezhong/machine_reading_comprehesion

add files for DuReader
# Abstract
DuReader is an end-to-end neural network model for machine reading comprehension style question answering, which aims to answer questions from given passages. We first match the question and passage with a bidirectional attention flow network to obtain the question-aware passage representation. Then we employ a pointer network to locate the positions of answers in the passages. Our experimental evaluations show that the DuReader model achieves state-of-the-art results on the DuReader dataset.
# Dataset
The DuReader dataset is a new large-scale, real-world, human-sourced MRC dataset in Chinese. DuReader focuses on real-world open-domain question answering. The advantages of DuReader over existing datasets are summarized as follows:
- Real question
- Real article
- Real answer
- Real application scenario
- Rich annotation
# Network
DuReader is inspired by 3 classic reading comprehension models([BiDAF](https://arxiv.org/abs/1611.01603), [Match-LSTM](https://arxiv.org/abs/1608.07905), [R-NET](https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf)).
The DuReader model is a hierarchical multi-stage process that consists of five layers:
- **Word Embedding Layer** maps each word to a vector using a pre-trained word embedding model.
- **Encoding Layer** extracts context information for each position in the question and passages with a bi-directional LSTM network.
- **Attention Flow Layer** couples the query and context vectors and produces a set of query-aware feature vectors for each word in the context. Please refer to [BiDAF](https://arxiv.org/abs/1611.01603) for more details.
- **Fusion Layer** employs a layer of bi-directional LSTM to capture the interaction among context words independent of the query.
- **Decode Layer** employs an answer pointer network with attention pooling of the question to locate the positions of answers in the passages. Please refer to [Match-LSTM](https://arxiv.org/abs/1608.07905) and [R-NET](https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf) for more details; a minimal sketch of how these layers fit together is shown below.
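The following sketch is illustrative only: it reuses the functions defined in the model code later on this page (`encoder`, `attn_flow`, `fusion`, `point_network_decoder`), with `emb_shape`, `hidden_size` and `args` assumed to be set up as in `rc_model`, and shows the wiring for a single passage:
```
# Illustrative wiring of the five layers for one passage, following rc_model:
q_enc = encoder('qids_0', 'q_enc', emb_shape, hidden_size, args)  # embedding + encoding
p_enc = encoder('pids_0', 'p_enc', emb_shape, hidden_size, args)
g = attn_flow(q_enc, p_enc, 'pids_0', args)                       # attention flow
m = fusion(g, args)                                               # fusion
start_probs, end_probs = point_network_decoder(
    p_vec=m, q_vec=q_enc, hidden_size=hidden_size, args=args)     # decode
```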
## How to Run
### Download the Dataset
To download the DuReader dataset, run:
```
cd data && bash download.sh
```
For more details about the DuReader dataset, please refer to the [DuReader Dataset Homepage](https://ai.baidu.com//broad/subordinate?dataset=dureader).
### Download Thirdparty Dependencies
We use Bleu and Rouge as evaluation metrics. The calculation of these metrics relies on the scoring scripts under [coco-caption](https://github.com/tylin/coco-caption); to download them, run:
```
cd utils && bash download_thirdparty.sh
```
### Environment Requirements
For now, we have only tested this code on PaddlePaddle v1.0. To install PaddlePaddle and for more details about it, see the [PaddlePaddle Homepage](http://paddlepaddle.org).
### Preparation
Before training the model, we have to make sure that the data is ready. The preparation step checks the data files, creates the necessary directories, and extracts a vocabulary for later use. You can run the following command to do this:
```
sh run.sh --prepare
```
You can specify the files for train/dev/test by setting the `--trainset`/`--devset`/`--testset` options; for example, see the command below.
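For instance, to run preparation on the demo data (the train-set path here is illustrative; adapt it to where your data actually lives):
```
sh run.sh --prepare --trainset ../data/demo/trainset/search.train.json --devset ../data/demo/devset/search.dev.json
```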
### Training
To train the model, you can also set hyper-parameters such as the learning rate by using `--learning_rate NUM`. For example, to train the model for 10 passes, you can run:
```
sh run.sh --train --pass_num 10
```
The training process includes an evaluation on the dev set after each training epoch. By default, the model with the least Bleu-4 score on the dev set will be saved.
### Evaluation
To conduct a single evaluation on the dev set with the model already trained, you can run the following command:
```
sh run.sh --evaluate --load_dir models/1
```
### Prediction
You can also predict answers for the samples in some files using the following command:
```
sh run.sh --predict --load_dir models/1 --testset ../data/demo/devset/search.dev.json
```
By default, the results are saved in the `../data/results/` folder. You can change this by specifying `--result_dir DIR_PATH`.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import distutils.util
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--prepare',
action='store_true',
help='create the directories, prepare the vocabulary and embeddings')
parser.add_argument('--train', action='store_true', help='train the model')
parser.add_argument(
'--evaluate', action='store_true', help='evaluate the model on dev set')
parser.add_argument(
'--predict',
action='store_true',
help='predict the answers for test set with trained model')
parser.add_argument(
"--embed_size",
type=int,
default=300,
help="The dimension of embedding table. (default: %(default)d)")
parser.add_argument(
"--hidden_size",
type=int,
default=300,
help="The size of rnn hidden unit. (default: %(default)d)")
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
"--pass_num",
type=int,
default=5,
help="The pass number to train. (default: %(default)d)")
parser.add_argument(
"--learning_rate",
type=float,
default=0.001,
help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
"--use_gpu",
type=distutils.util.strtobool,
default=True,
help="Whether to use gpu. (default: %(default)d)")
parser.add_argument(
"--save_dir",
type=str,
default="model",
help="Specify the path to save trained models.")
parser.add_argument(
"--load_dir",
type=str,
default="",
help="Specify the path to load trained models.")
parser.add_argument(
"--save_interval",
type=int,
default=1,
help="Save the trained model every n passes."
"(default: %(default)d)")
parser.add_argument(
"--log_interval",
type=int,
default=50,
help="log the train loss every n batches."
"(default: %(default)d)")
parser.add_argument(
"--dev_interval",
type=int,
default=1000,
help="cal dev loss every n batches."
"(default: %(default)d)")
parser.add_argument('--optim', default='adam', help='optimizer type')
parser.add_argument('--trainset', nargs='+', help='train dataset')
parser.add_argument('--devset', nargs='+', help='dev dataset')
parser.add_argument('--testset', nargs='+', help='test dataset')
parser.add_argument('--vocab_dir', help='dict')
parser.add_argument('--max_p_num', type=int, default=5)
parser.add_argument('--max_a_len', type=int, default=200)
parser.add_argument('--max_p_len', type=int, default=500)
parser.add_argument('--max_q_len', type=int, default=9)
parser.add_argument('--doc_num', type=int, default=5)
parser.add_argument('--para_print', action='store_true')
parser.add_argument('--drop_rate', type=float, default=0.0)
parser.add_argument('--random_seed', type=int, default=123)
parser.add_argument(
'--log_path',
help='path of the log file. If not set, logs are printed to console')
parser.add_argument(
'--result_dir',
default='../data/results/',
help='the dir to output the results')
parser.add_argument(
'--result_name',
default='test_result',
help='the file name of the results')
args = parser.parse_args()
return args
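# A hedged usage sketch (illustrative; run.py itself is collapsed on this page,
# so this only mirrors how these flags are plausibly consumed):
#
#     args = parse_args()
#     if args.prepare:
#         ...  # check data files, make directories, build the vocabulary
#     if args.train:
#         ...  # train for args.pass_num passes at args.learning_rate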
#!/bin/bash
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
if [[ -d preprocessed ]] && [[ -d raw ]]; then
echo "data exist"
exit 0
else
wget -c --no-check-certificate http://dureader.gz.bcebos.com/dureader_preprocessed.zip
fi
if md5sum --status -c md5sum.txt; then
unzip dureader_preprocessed.zip
else
echo "download data error!" >> /dev/stderr
exit 1
fi
7a4c28026f7dc94e8135d17203c63664 dureader_preprocessed.zip
# -*- coding:utf8 -*-
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module implements data process strategies.
"""
import os
import json
import logging
import numpy as np
from collections import Counter
class BRCDataset(object):
"""
    This module implements the APIs for loading and using the Baidu reading comprehension dataset
"""
def __init__(self,
max_p_num,
max_p_len,
max_q_len,
train_files=[],
dev_files=[],
test_files=[]):
self.logger = logging.getLogger("brc")
self.max_p_num = max_p_num
self.max_p_len = max_p_len
self.max_q_len = max_q_len
self.train_set, self.dev_set, self.test_set = [], [], []
if train_files:
for train_file in train_files:
self.train_set += self._load_dataset(train_file, train=True)
self.logger.info('Train set size: {} questions.'.format(
len(self.train_set)))
if dev_files:
for dev_file in dev_files:
self.dev_set += self._load_dataset(dev_file)
self.logger.info('Dev set size: {} questions.'.format(
len(self.dev_set)))
if test_files:
for test_file in test_files:
self.test_set += self._load_dataset(test_file)
self.logger.info('Test set size: {} questions.'.format(
len(self.test_set)))
def _load_dataset(self, data_path, train=False):
"""
Loads the dataset
Args:
data_path: the data file to load
"""
with open(data_path) as fin:
data_set = []
for lidx, line in enumerate(fin):
sample = json.loads(line.strip())
if train:
if len(sample['answer_spans']) == 0:
continue
if sample['answer_spans'][0][1] >= self.max_p_len:
continue
if 'answer_docs' in sample:
sample['answer_passages'] = sample['answer_docs']
sample['question_tokens'] = sample['segmented_question']
sample['passages'] = []
for d_idx, doc in enumerate(sample['documents']):
if train:
most_related_para = doc['most_related_para']
sample['passages'].append({
'passage_tokens':
doc['segmented_paragraphs'][most_related_para],
'is_selected': doc['is_selected']
})
else:
para_infos = []
for para_tokens in doc['segmented_paragraphs']:
question_tokens = sample['segmented_question']
common_with_question = Counter(
para_tokens) & Counter(question_tokens)
correct_preds = sum(common_with_question.values())
if correct_preds == 0:
recall_wrt_question = 0
else:
recall_wrt_question = float(
correct_preds) / len(question_tokens)
para_infos.append((para_tokens, recall_wrt_question,
len(para_tokens)))
para_infos.sort(key=lambda x: (-x[1], x[2]))
fake_passage_tokens = []
for para_info in para_infos[:1]:
fake_passage_tokens += para_info[0]
sample['passages'].append({
'passage_tokens': fake_passage_tokens
})
data_set.append(sample)
return data_set
def _one_mini_batch(self, data, indices, pad_id):
"""
Get one mini batch
Args:
data: all data
indices: the indices of the samples to be selected
            pad_id: the token id used for padding
Returns:
one batch of data
"""
batch_data = {
'raw_data': [data[i] for i in indices],
'question_token_ids': [],
'question_length': [],
'passage_token_ids': [],
'passage_length': [],
'start_id': [],
'end_id': []
}
max_passage_num = max(
[len(sample['passages']) for sample in batch_data['raw_data']])
#max_passage_num = min(self.max_p_num, max_passage_num)
max_passage_num = self.max_p_num
for sidx, sample in enumerate(batch_data['raw_data']):
for pidx in range(max_passage_num):
if pidx < len(sample['passages']):
batch_data['question_token_ids'].append(sample[
'question_token_ids'])
batch_data['question_length'].append(
len(sample['question_token_ids']))
passage_token_ids = sample['passages'][pidx][
'passage_token_ids']
batch_data['passage_token_ids'].append(passage_token_ids)
batch_data['passage_length'].append(
min(len(passage_token_ids), self.max_p_len))
else:
batch_data['question_token_ids'].append([])
batch_data['question_length'].append(0)
batch_data['passage_token_ids'].append([])
batch_data['passage_length'].append(0)
batch_data, padded_p_len, padded_q_len = self._dynamic_padding(
batch_data, pad_id)
for sample in batch_data['raw_data']:
if 'answer_passages' in sample and len(sample['answer_passages']):
gold_passage_offset = padded_p_len * sample['answer_passages'][
0]
batch_data['start_id'].append(gold_passage_offset + sample[
'answer_spans'][0][0])
batch_data['end_id'].append(gold_passage_offset + sample[
'answer_spans'][0][1])
else:
# fake span for some samples, only valid for testing
batch_data['start_id'].append(0)
batch_data['end_id'].append(0)
return batch_data
def _dynamic_padding(self, batch_data, pad_id):
"""
Dynamically pads the batch_data with pad_id
"""
pad_p_len = min(self.max_p_len, max(batch_data['passage_length']))
pad_q_len = min(self.max_q_len, max(batch_data['question_length']))
batch_data['passage_token_ids'] = [
(ids + [pad_id] * (pad_p_len - len(ids)))[:pad_p_len]
for ids in batch_data['passage_token_ids']
]
batch_data['question_token_ids'] = [
(ids + [pad_id] * (pad_q_len - len(ids)))[:pad_q_len]
for ids in batch_data['question_token_ids']
]
return batch_data, pad_p_len, pad_q_len
def word_iter(self, set_name=None):
"""
Iterates over all the words in the dataset
Args:
set_name: if it is set, then the specific set will be used
Returns:
a generator
"""
if set_name is None:
data_set = self.train_set + self.dev_set + self.test_set
elif set_name == 'train':
data_set = self.train_set
elif set_name == 'dev':
data_set = self.dev_set
elif set_name == 'test':
data_set = self.test_set
else:
raise NotImplementedError('No data set named as {}'.format(
set_name))
if data_set is not None:
for sample in data_set:
for token in sample['question_tokens']:
yield token
for passage in sample['passages']:
for token in passage['passage_tokens']:
yield token
def convert_to_ids(self, vocab):
"""
Convert the question and passage in the original dataset to ids
Args:
vocab: the vocabulary on this dataset
"""
for data_set in [self.train_set, self.dev_set, self.test_set]:
if data_set is None:
continue
for sample in data_set:
sample['question_token_ids'] = vocab.convert_to_ids(sample[
'question_tokens'])
for passage in sample['passages']:
passage['passage_token_ids'] = vocab.convert_to_ids(passage[
'passage_tokens'])
def gen_mini_batches(self, set_name, batch_size, pad_id, shuffle=True):
"""
Generate data batches for a specific dataset (train/dev/test)
Args:
set_name: train/dev/test to indicate the set
batch_size: number of samples in one batch
pad_id: pad id
shuffle: if set to be true, the data is shuffled.
Returns:
a generator for all batches
"""
if set_name == 'train':
data = self.train_set
elif set_name == 'dev':
data = self.dev_set
elif set_name == 'test':
data = self.test_set
else:
raise NotImplementedError('No data set named as {}'.format(
set_name))
data_size = len(data)
indices = np.arange(data_size)
if shuffle:
np.random.shuffle(indices)
for batch_start in np.arange(0, data_size, batch_size):
batch_indices = indices[batch_start:batch_start + batch_size]
yield self._one_mini_batch(data, batch_indices, pad_id)
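# A hedged usage sketch (illustrative paths; `vocab` is a Vocab instance from
# the vocab module shipped with this model):
#
#     brc_data = BRCDataset(max_p_num=5, max_p_len=500, max_q_len=60,
#                           train_files=['../data/demo/trainset/search.train.json'])
#     brc_data.convert_to_ids(vocab)
#     pad_id = vocab.get_id(vocab.pad_token)
#     for batch in brc_data.gen_mini_batches('train', batch_size=32, pad_id=pad_id):
#         pass  # feed batch['passage_token_ids'], batch['start_id'], ... to the network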
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
def dropout(input, args):
if args.drop_rate:
return layers.dropout(
input,
dropout_prob=args.drop_rate,
seed=args.random_seed,
is_test=False)
else:
return input
def bi_lstm_encoder(input_seq, gate_size, para_name, args):
    # A bi-directional LSTM encoder implementation.
    # The linear transformations for the input gate, output gate, forget gate
    # and cell activation vector need to be done outside of dynamic_lstm,
    # so the projection output size is 4 times gate_size.
input_forward_proj = layers.fc(
input=input_seq,
param_attr=fluid.ParamAttr(name=para_name + '_fw_gate_w'),
size=gate_size * 4,
act=None,
bias_attr=False)
input_reversed_proj = layers.fc(
input=input_seq,
param_attr=fluid.ParamAttr(name=para_name + '_bw_gate_w'),
size=gate_size * 4,
act=None,
bias_attr=False)
forward, _ = layers.dynamic_lstm(
input=input_forward_proj,
size=gate_size * 4,
use_peepholes=False,
param_attr=fluid.ParamAttr(name=para_name + '_fw_lstm_w'),
bias_attr=fluid.ParamAttr(name=para_name + '_fw_lstm_b'))
reversed, _ = layers.dynamic_lstm(
input=input_reversed_proj,
param_attr=fluid.ParamAttr(name=para_name + '_bw_lstm_w'),
bias_attr=fluid.ParamAttr(name=para_name + '_bw_lstm_b'),
size=gate_size * 4,
is_reverse=True,
use_peepholes=False)
encoder_out = layers.concat(input=[forward, reversed], axis=1)
return encoder_out
def encoder(input_name, para_name, shape, hidden_size, args):
input_ids = layers.data(
name=input_name, shape=[1], dtype='int64', lod_level=1)
input_embedding = layers.embedding(
input=input_ids,
size=shape,
dtype='float32',
is_sparse=True,
param_attr=fluid.ParamAttr(name='embedding_para'))
encoder_out = bi_lstm_encoder(
input_seq=input_embedding,
gate_size=hidden_size,
para_name=para_name,
args=args)
return dropout(encoder_out, args)
def attn_flow(q_enc, p_enc, p_ids_name, args):
tag = p_ids_name + "::"
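    # Context-to-question attention: each DynamicRNN step below takes one
    # passage position (h_cur) and attends over all question positions,
    # producing U_expr. Question-to-context attention then weights passage
    # positions by their best similarity to the question (b -> H_expr),
    # following BiDAF.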
drnn = layers.DynamicRNN()
with drnn.block():
h_cur = drnn.step_input(p_enc)
u_all = drnn.static_input(q_enc)
h_expd = layers.sequence_expand(x=h_cur, y=u_all)
s_t_mul = layers.elementwise_mul(x=u_all, y=h_expd, axis=0)
s_t_sum = layers.reduce_sum(input=s_t_mul, dim=1, keep_dim=True)
s_t_re = layers.reshape(s_t_sum, shape=[-1, 0])
s_t = layers.sequence_softmax(input=s_t_re)
u_expr = layers.elementwise_mul(x=u_all, y=s_t, axis=0)
u_expr = layers.sequence_pool(input=u_expr, pool_type='sum')
b_t = layers.sequence_pool(input=s_t_sum, pool_type='max')
drnn.output(u_expr, b_t)
U_expr, b = drnn()
b_norm = layers.sequence_softmax(input=b)
h_expr = layers.elementwise_mul(x=p_enc, y=b_norm, axis=0)
h_expr = layers.sequence_pool(input=h_expr, pool_type='sum')
H_expr = layers.sequence_expand(x=h_expr, y=p_enc)
H_expr = layers.lod_reset(x=H_expr, y=p_enc)
h_u = layers.elementwise_mul(x=p_enc, y=U_expr, axis=0)
h_h = layers.elementwise_mul(x=p_enc, y=H_expr, axis=0)
g = layers.concat(input=[p_enc, U_expr, h_u, h_h], axis=1)
return dropout(g, args)
def lstm_step(x_t, hidden_t_prev, cell_t_prev, size, para_name, args):
def linear(inputs, para_name, args):
return layers.fc(input=inputs,
size=size,
param_attr=fluid.ParamAttr(name=para_name + '_w'),
bias_attr=fluid.ParamAttr(name=para_name + '_b'))
input_cat = layers.concat([hidden_t_prev, x_t], axis=1)
forget_gate = layers.sigmoid(x=linear(input_cat, para_name + '_lstm_f',
args))
input_gate = layers.sigmoid(x=linear(input_cat, para_name + '_lstm_i',
args))
output_gate = layers.sigmoid(x=linear(input_cat, para_name + '_lstm_o',
args))
cell_tilde = layers.tanh(x=linear(input_cat, para_name + '_lstm_c', args))
cell_t = layers.sums(input=[
layers.elementwise_mul(
x=forget_gate, y=cell_t_prev), layers.elementwise_mul(
x=input_gate, y=cell_tilde)
])
hidden_t = layers.elementwise_mul(x=output_gate, y=layers.tanh(x=cell_t))
return hidden_t, cell_t
# Pointer network decoder
def point_network_decoder(p_vec, q_vec, hidden_size, args):
tag = 'pn_decoder:'
init_random = fluid.initializer.Normal(loc=0.0, scale=1.0)
random_attn = layers.create_parameter(
shape=[1, hidden_size],
dtype='float32',
default_initializer=init_random)
random_attn = layers.fc(
input=random_attn,
size=hidden_size,
act=None,
param_attr=fluid.ParamAttr(name=tag + 'random_attn_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'random_attn_fc_b'))
random_attn = layers.reshape(random_attn, shape=[-1])
U = layers.fc(input=q_vec,
param_attr=fluid.ParamAttr(name=tag + 'q_vec_fc_w'),
bias_attr=False,
size=hidden_size,
act=None) + random_attn
U = layers.tanh(U)
logits = layers.fc(input=U,
param_attr=fluid.ParamAttr(name=tag + 'logits_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'logits_fc_b'),
size=1,
act=None)
scores = layers.sequence_softmax(input=logits)
pooled_vec = layers.elementwise_mul(x=q_vec, y=scores, axis=0)
pooled_vec = layers.sequence_pool(input=pooled_vec, pool_type='sum')
init_state = layers.fc(
input=pooled_vec,
param_attr=fluid.ParamAttr(name=tag + 'init_state_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'init_state_fc_b'),
size=hidden_size,
act=None)
def custom_dynamic_rnn(p_vec, init_state, hidden_size, para_name, args):
tag = para_name + "custom_dynamic_rnn:"
def static_rnn(step,
p_vec=p_vec,
init_state=None,
para_name='',
args=args):
tag = para_name + "static_rnn:"
ctx = layers.fc(
input=p_vec,
param_attr=fluid.ParamAttr(name=tag + 'context_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'context_fc_b'),
size=hidden_size,
act=None)
beta = []
c_prev = init_state
m_prev = init_state
for i in range(step):
m_prev0 = layers.fc(
input=m_prev,
size=hidden_size,
act=None,
param_attr=fluid.ParamAttr(name=tag + 'm_prev0_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'm_prev0_fc_b'))
m_prev1 = layers.sequence_expand(x=m_prev0, y=ctx)
Fk = ctx + m_prev1
Fk = layers.tanh(Fk)
logits = layers.fc(
input=Fk,
size=1,
act=None,
param_attr=fluid.ParamAttr(name=tag + 'logits_fc_w'),
bias_attr=fluid.ParamAttr(name=tag + 'logits_fc_b'))
scores = layers.sequence_softmax(input=logits)
attn_ctx = layers.elementwise_mul(x=p_vec, y=scores, axis=0)
attn_ctx = layers.sequence_pool(input=attn_ctx, pool_type='sum')
hidden_t, cell_t = lstm_step(
attn_ctx,
hidden_t_prev=m_prev,
cell_t_prev=c_prev,
size=hidden_size,
para_name=tag,
args=args)
m_prev = hidden_t
c_prev = cell_t
beta.append(scores)
return beta
return static_rnn(
2, p_vec=p_vec, init_state=init_state, para_name=para_name)
fw_outputs = custom_dynamic_rnn(p_vec, init_state, hidden_size, tag + "fw:",
args)
bw_outputs = custom_dynamic_rnn(p_vec, init_state, hidden_size, tag + "bw:",
args)
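    # The forward decoder scores (start, end) in that order while the backward
    # decoder scores (end, start), so the two estimates of each boundary are
    # averaged below.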
start_prob = layers.elementwise_add(
x=fw_outputs[0], y=bw_outputs[1], axis=0) / 2
end_prob = layers.elementwise_add(
x=fw_outputs[1], y=bw_outputs[0], axis=0) / 2
return start_prob, end_prob
def fusion(g, args):
m = bi_lstm_encoder(
input_seq=g, gate_size=args.hidden_size, para_name='fusion', args=args)
return dropout(m, args)
def rc_model(hidden_size, vocab, args):
emb_shape = [vocab.size(), vocab.embed_dim]
# stage 1:encode
p_ids_names = []
q_ids_names = []
ms = []
gs = []
qs = []
for i in range(args.doc_num):
p_ids_name = "pids_%d" % i
p_ids_names.append(p_ids_name)
p_enc_i = encoder(p_ids_name, 'p_enc', emb_shape, hidden_size, args)
q_ids_name = "qids_%d" % i
q_ids_names.append(q_ids_name)
q_enc_i = encoder(q_ids_name, 'q_enc', emb_shape, hidden_size, args)
# stage 2:match
g_i = attn_flow(q_enc_i, p_enc_i, p_ids_name, args)
# stage 3:fusion
m_i = fusion(g_i, args)
ms.append(m_i)
gs.append(g_i)
qs.append(q_enc_i)
m = layers.sequence_concat(input=ms)
g = layers.sequence_concat(input=gs)
q_vec = layers.sequence_concat(input=qs)
# stage 4:decode
start_probs, end_probs = point_network_decoder(
p_vec=m, q_vec=q_vec, hidden_size=hidden_size, args=args)
start_labels = layers.data(
name="start_lables", shape=[1], dtype='float32', lod_level=1)
end_labels = layers.data(
name="end_lables", shape=[1], dtype='float32', lod_level=1)
cost0 = layers.sequence_pool(
layers.cross_entropy(
input=start_probs, label=start_labels, soft_label=True),
'sum')
cost1 = layers.sequence_pool(
layers.cross_entropy(
input=end_probs, label=end_labels, soft_label=True),
'sum')
cost0 = layers.mean(cost0)
cost1 = layers.mean(cost1)
cost = cost0 + cost1
cost.persistable = True
feeding_list = q_ids_names + ["start_lables", "end_lables"] + p_ids_names
return cost, start_probs, end_probs, feeding_list
(This diff has been collapsed.)
export CUDA_VISIBLE_DEVICES=1
python run.py \
--trainset 'data/preprocessed/trainset/search.train.json' \
'data/preprocessed/trainset/zhidao.train.json' \
--devset 'data/preprocessed/devset/search.dev.json' \
'data/preprocessed/devset/zhidao.dev.json' \
--testset 'data/preprocessed/testset/search.test.json' \
'data/preprocessed/testset/zhidao.test.json' \
--vocab_dir 'data/vocab' \
--use_gpu true \
--save_dir ./models \
--pass_num 10 \
--learning_rate 0.001 \
--batch_size 8 \
--embed_size 300 \
--hidden_size 150 \
--max_p_num 5 \
--max_p_len 500 \
--max_q_len 60 \
--max_a_len 200 \
--drop_rate 0.2 $@
# coding:utf8
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This package implements some utility functions shared by PaddlePaddle
and Tensorflow model implementations.
Authors: liuyuan(liuyuan04@baidu.com)
Date: 2017/10/06 18:23:06
"""
from .dureader_eval import compute_bleu_rouge
from .dureader_eval import normalize
from .preprocess import find_fake_answer
from .preprocess import find_best_question_match
__all__ = [
'compute_bleu_rouge',
'normalize',
'find_fake_answer',
'find_best_question_match',
]
#!/bin/bash
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# We use Bleu and Rouge as evaluation metrics, the calculation of these metrics
# relies on the scoring scripts under "https://github.com/tylin/coco-caption"
bleu_base_url='https://raw.githubusercontent.com/tylin/coco-caption/master/pycocoevalcap/bleu'
bleu_files=("LICENSE" "__init__.py" "bleu.py" "bleu_scorer.py")
rouge_base_url="https://raw.githubusercontent.com/tylin/coco-caption/master/pycocoevalcap/rouge"
rouge_files=("__init__.py" "rouge.py")
download() {
local metric=$1; shift;
local base_url=$1; shift;
local fnames=($@);
mkdir -p ${metric}
for fname in ${fnames[@]};
do
printf "downloading: %s\n" ${base_url}/${fname}
wget --no-check-certificate ${base_url}/${fname} -O ${metric}/${fname}
done
}
# prepare rouge
download "rouge_metric" ${rouge_base_url} ${rouge_files[@]}
# prepare bleu
download "bleu_metric" ${bleu_base_url} ${bleu_files[@]}
# convert python 2.x source code to python 3.x
2to3 -w "../utils/bleu_metric/bleu_scorer.py"
2to3 -w "../utils/bleu_metric/bleu.py"
# -*- coding:utf8 -*-
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module computes evaluation metrics for DuReader dataset.
"""
import argparse
import json
import sys
import zipfile
from collections import Counter
from .bleu_metric.bleu import Bleu
from .rouge_metric.rouge import Rouge
EMPTY = ''
YESNO_LABELS = set(['Yes', 'No', 'Depends'])
def normalize(s):
"""
Normalize strings to space joined chars.
Args:
s: a list of strings.
Returns:
A list of normalized strings.
"""
if not s:
return s
normalized = []
for ss in s:
tokens = [c for c in list(ss) if len(c.strip()) != 0]
normalized.append(' '.join(tokens))
return normalized
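# Example (illustrative): normalize([u'ab c']) -> [u'a b c']; each string is
# split into characters, whitespace characters are dropped, and the characters
# are re-joined with single spaces.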
def data_check(obj, task):
"""
Check data.
Raises:
        AssertionError: when data is not legal.
"""
assert 'question_id' in obj, "Missing 'question_id' field."
    assert 'question_type' in obj, \
        "Missing 'question_type' field. question_id: {}".format(obj['question_id'])
assert 'yesno_answers' in obj, \
"Missing 'yesno_answers' field. question_id: {}".format(obj['question_id'])
assert isinstance(obj['yesno_answers'], list), \
r"""'yesno_answers' field must be a list, if the 'question_type' is not
'YES_NO', then this field should be an empty list.
question_id: {}""".format(obj['question_id'])
assert 'entity_answers' in obj, \
"Missing 'entity_answers' field. question_id: {}".format(obj['question_id'])
assert isinstance(obj['entity_answers'], list) \
and len(obj['entity_answers']) > 0, \
r"""'entity_answers' field must be a list, and has at least one element,
which can be a empty list. question_id: {}""".format(obj['question_id'])
def read_file(file_name, task, is_ref=False):
"""
Read predict answers or reference answers from file.
Args:
file_name: the name of the file containing predict result or reference
result.
Returns:
A dictionary mapping question_id to the result information. The result
        information itself is also a dictionary which has four keys:
- question_type: type of the query.
- yesno_answers: A list of yesno answers corresponding to 'answers'.
- answers: A list of predicted answers.
- entity_answers: A list, each element is also a list containing the entities
tagged out from the corresponding answer string.
"""
def _open(file_name, mode, zip_obj=None):
if zip_obj is not None:
return zip_obj.open(file_name, mode)
return open(file_name, mode)
results = {}
keys = ['answers', 'yesno_answers', 'entity_answers', 'question_type']
if is_ref:
keys += ['source']
zf = zipfile.ZipFile(file_name, 'r') if file_name.endswith('.zip') else None
file_list = [file_name] if zf is None else zf.namelist()
for fn in file_list:
for line in _open(fn, 'r', zip_obj=zf):
try:
obj = json.loads(line.strip())
except ValueError:
raise ValueError("Every line of data should be legal json")
data_check(obj, task)
qid = obj['question_id']
assert qid not in results, "Duplicate question_id: {}".format(qid)
results[qid] = {}
for k in keys:
results[qid][k] = obj[k]
return results
def compute_bleu_rouge(pred_dict, ref_dict, bleu_order=4):
"""
Compute bleu and rouge scores.
"""
assert set(pred_dict.keys()) == set(ref_dict.keys()), \
"missing keys: {}".format(set(ref_dict.keys()) - set(pred_dict.keys()))
scores = {}
bleu_scores, _ = Bleu(bleu_order).compute_score(ref_dict, pred_dict)
for i, bleu_score in enumerate(bleu_scores):
scores['Bleu-%d' % (i + 1)] = bleu_score
rouge_score, _ = Rouge().compute_score(ref_dict, pred_dict)
scores['Rouge-L'] = rouge_score
return scores
def local_prf(pred_list, ref_list):
"""
Compute local precision recall and f1-score,
given only one prediction list and one reference list
"""
common = Counter(pred_list) & Counter(ref_list)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
p = 1.0 * num_same / len(pred_list)
r = 1.0 * num_same / len(ref_list)
f1 = (2 * p * r) / (p + r)
return p, r, f1
def compute_prf(pred_dict, ref_dict):
"""
Compute precision recall and f1-score.
"""
pred_question_ids = set(pred_dict.keys())
ref_question_ids = set(ref_dict.keys())
correct_preds, total_correct, total_preds = 0, 0, 0
for question_id in ref_question_ids:
pred_entity_list = pred_dict.get(question_id, [[]])
assert len(pred_entity_list) == 1, \
'the number of entity list for question_id {} is not 1.'.format(question_id)
pred_entity_list = pred_entity_list[0]
all_ref_entity_lists = ref_dict[question_id]
best_local_f1 = 0
best_ref_entity_list = None
for ref_entity_list in all_ref_entity_lists:
local_f1 = local_prf(pred_entity_list, ref_entity_list)[2]
if local_f1 > best_local_f1:
best_ref_entity_list = ref_entity_list
best_local_f1 = local_f1
if best_ref_entity_list is None:
if len(all_ref_entity_lists) > 0:
best_ref_entity_list = sorted(
all_ref_entity_lists, key=lambda x: len(x))[0]
else:
best_ref_entity_list = []
gold_entities = set(best_ref_entity_list)
pred_entities = set(pred_entity_list)
correct_preds += len(gold_entities & pred_entities)
total_preds += len(pred_entities)
total_correct += len(gold_entities)
p = float(correct_preds) / total_preds if correct_preds > 0 else 0
r = float(correct_preds) / total_correct if correct_preds > 0 else 0
f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0
return {'Precision': p, 'Recall': r, 'F1': f1}
def prepare_prf(pred_dict, ref_dict):
"""
Prepares data for calculation of prf scores.
"""
preds = {k: v['entity_answers'] for k, v in pred_dict.items()}
refs = {k: v['entity_answers'] for k, v in ref_dict.items()}
return preds, refs
def filter_dict(result_dict, key_tag):
"""
    Filter a subset of the result_dict, where keys end with 'key_tag'.
"""
filtered = {}
for k, v in result_dict.items():
if k.endswith(key_tag):
filtered[k] = v
return filtered
def get_metrics(pred_result, ref_result, task, source):
"""
Computes metrics.
"""
metrics = {}
ref_result_filtered = {}
pred_result_filtered = {}
if source == 'both':
ref_result_filtered = ref_result
pred_result_filtered = pred_result
else:
for question_id, info in ref_result.items():
if info['source'] == source:
ref_result_filtered[question_id] = info
if question_id in pred_result:
pred_result_filtered[question_id] = pred_result[question_id]
if task == 'main' or task == 'all' \
or task == 'description':
pred_dict, ref_dict = prepare_bleu(pred_result_filtered,
ref_result_filtered, task)
metrics = compute_bleu_rouge(pred_dict, ref_dict)
elif task == 'yesno':
pred_dict, ref_dict = prepare_bleu(pred_result_filtered,
ref_result_filtered, task)
keys = ['Yes', 'No', 'Depends']
preds = [filter_dict(pred_dict, k) for k in keys]
refs = [filter_dict(ref_dict, k) for k in keys]
metrics = compute_bleu_rouge(pred_dict, ref_dict)
for k, pred, ref in zip(keys, preds, refs):
m = compute_bleu_rouge(pred, ref)
k_metric = [(k + '|' + key, v) for key, v in m.items()]
metrics.update(k_metric)
elif task == 'entity':
pred_dict, ref_dict = prepare_prf(pred_result_filtered,
ref_result_filtered)
pred_dict_bleu, ref_dict_bleu = prepare_bleu(pred_result_filtered,
ref_result_filtered, task)
metrics = compute_prf(pred_dict, ref_dict)
metrics.update(compute_bleu_rouge(pred_dict_bleu, ref_dict_bleu))
else:
raise ValueError("Illegal task name: {}".format(task))
return metrics
def prepare_bleu(pred_result, ref_result, task):
"""
Prepares data for calculation of bleu and rouge scores.
"""
pred_list, ref_list = [], []
qids = ref_result.keys()
for qid in qids:
if task == 'main':
pred, ref = get_main_result(qid, pred_result, ref_result)
elif task == 'yesno':
pred, ref = get_yesno_result(qid, pred_result, ref_result)
elif task == 'all':
pred, ref = get_all_result(qid, pred_result, ref_result)
elif task == 'entity':
pred, ref = get_entity_result(qid, pred_result, ref_result)
elif task == 'description':
pred, ref = get_desc_result(qid, pred_result, ref_result)
else:
raise ValueError("Illegal task name: {}".format(task))
if pred and ref:
pred_list += pred
ref_list += ref
pred_dict = dict(pred_list)
ref_dict = dict(ref_list)
for qid, ans in ref_dict.items():
ref_dict[qid] = normalize(ref_dict[qid])
pred_dict[qid] = normalize(pred_dict.get(qid, [EMPTY]))
if not ans or ans == [EMPTY]:
del ref_dict[qid]
del pred_dict[qid]
for k, v in pred_dict.items():
assert len(v) == 1, \
"There should be only one predict answer. question_id: {}".format(k)
return pred_dict, ref_dict
def get_main_result(qid, pred_result, ref_result):
"""
Prepare answers for task 'main'.
Args:
qid: question_id.
        pred_result: A dict including all question_id's result information read
                     from args.pred_file.
        ref_result: A dict including all question_id's result information read
                    from args.ref_file.
Returns:
Two lists, the first one contains predict result, the second
one contains reference result of the same question_id. Each list has
elements of tuple (question_id, answers), 'answers' is a list of strings.
"""
ref_ans = ref_result[qid]['answers']
if not ref_ans:
ref_ans = [EMPTY]
pred_ans = pred_result.get(qid, {}).get('answers', [])[:1]
if not pred_ans:
pred_ans = [EMPTY]
return [(qid, pred_ans)], [(qid, ref_ans)]
def get_entity_result(qid, pred_result, ref_result):
"""
Prepare answers for task 'entity'.
Args:
qid: question_id.
        pred_result: A dict including all question_id's result information read
                     from args.pred_file.
        ref_result: A dict including all question_id's result information read
                    from args.ref_file.
Returns:
Two lists, the first one contains predict result, the second
one contains reference result of the same question_id. Each list has
elements of tuple (question_id, answers), 'answers' is a list of strings.
"""
if ref_result[qid]['question_type'] != 'ENTITY':
return None, None
return get_main_result(qid, pred_result, ref_result)
def get_desc_result(qid, pred_result, ref_result):
"""
Prepare answers for task 'description'.
Args:
qid: question_id.
        pred_result: A dict including all question_id's result information read
                     from args.pred_file.
        ref_result: A dict including all question_id's result information read
                    from args.ref_file.
Returns:
Two lists, the first one contains predict result, the second
one contains reference result of the same question_id. Each list has
elements of tuple (question_id, answers), 'answers' is a list of strings.
"""
if ref_result[qid]['question_type'] != 'DESCRIPTION':
return None, None
return get_main_result(qid, pred_result, ref_result)
def get_yesno_result(qid, pred_result, ref_result):
"""
Prepare answers for task 'yesno'.
Args:
qid: question_id.
        pred_result: A dict including all question_id's result information read
                     from args.pred_file.
        ref_result: A dict including all question_id's result information read
                    from args.ref_file.
Returns:
Two lists, the first one contains predict result, the second
one contains reference result of the same question_id. Each list has
elements of tuple (question_id, answers), 'answers' is a list of strings.
"""
def _uniq(li, is_ref):
uniq_li = []
left = []
keys = set()
for k, v in li:
if k not in keys:
uniq_li.append((k, v))
keys.add(k)
else:
left.append((k, v))
if is_ref:
dict_li = dict(uniq_li)
for k, v in left:
dict_li[k] += v
uniq_li = [(k, v) for k, v in dict_li.items()]
return uniq_li
def _expand_result(uniq_li):
expanded = uniq_li[:]
keys = set([x[0] for x in uniq_li])
for k in YESNO_LABELS - keys:
expanded.append((k, [EMPTY]))
return expanded
def _get_yesno_ans(qid, result_dict, is_ref=False):
if qid not in result_dict:
return [(str(qid) + '_' + k, v) for k, v in _expand_result([])]
yesno_answers = result_dict[qid]['yesno_answers']
answers = result_dict[qid]['answers']
lbl_ans = _uniq([(k, [v]) for k, v in zip(yesno_answers, answers)],
is_ref)
ret = [(str(qid) + '_' + k, v) for k, v in _expand_result(lbl_ans)]
return ret
if ref_result[qid]['question_type'] != 'YES_NO':
return None, None
ref_ans = _get_yesno_ans(qid, ref_result, is_ref=True)
pred_ans = _get_yesno_ans(qid, pred_result)
return pred_ans, ref_ans
def get_all_result(qid, pred_result, ref_result):
"""
Prepare answers for task 'all'.
Args:
qid: question_id.
        pred_result: A dict including all question_id's result information read
                     from args.pred_file.
        ref_result: A dict including all question_id's result information read
                    from args.ref_file.
Returns:
Two lists, the first one contains predict result, the second
one contains reference result of the same question_id. Each list has
elements of tuple (question_id, answers), 'answers' is a list of strings.
"""
if ref_result[qid]['question_type'] == 'YES_NO':
return get_yesno_result(qid, pred_result, ref_result)
return get_main_result(qid, pred_result, ref_result)
def format_metrics(metrics, task, err_msg):
"""
    Format metrics. The 'err' field returns any error that occurred during evaluation.
Args:
metrics: A dict object contains metrics for different tasks.
task: Task name.
err_msg: Exception raised during evaluation.
Returns:
Formatted result.
"""
result = {}
sources = ["both", "search", "zhidao"]
if err_msg is not None:
return {'errorMsg': str(err_msg), 'errorCode': 1, 'data': []}
data = []
if task != 'all' and task != 'main':
sources = ["both"]
if task == 'entity':
metric_names = ["Bleu-4", "Rouge-L"]
metric_names_prf = ["F1", "Precision", "Recall"]
for name in metric_names + metric_names_prf:
for src in sources:
obj = {
"name": name,
"value": round(metrics[src].get(name, 0) * 100, 2),
"type": src,
}
data.append(obj)
elif task == 'yesno':
metric_names = ["Bleu-4", "Rouge-L"]
details = ["Yes", "No", "Depends"]
src = sources[0]
for name in metric_names:
obj = {
"name": name,
"value": round(metrics[src].get(name, 0) * 100, 2),
"type": 'All',
}
data.append(obj)
for d in details:
obj = {
"name": name,
"value": \
round(metrics[src].get(d + '|' + name, 0) * 100, 2),
"type": d,
}
data.append(obj)
else:
metric_names = ["Bleu-4", "Rouge-L"]
for name in metric_names:
for src in sources:
obj = {
"name": name,
"value": \
round(metrics[src].get(name, 0) * 100, 2),
"type": src,
}
data.append(obj)
result["data"] = data
result["errorCode"] = 0
result["errorMsg"] = "success"
return result
def main(args):
"""
Do evaluation.
"""
err = None
metrics = {}
try:
pred_result = read_file(args.pred_file, args.task)
ref_result = read_file(args.ref_file, args.task, is_ref=True)
sources = ['both', 'search', 'zhidao']
if args.task not in set(['main', 'all']):
sources = sources[:1]
for source in sources:
metrics[source] = get_metrics(pred_result, ref_result, args.task,
source)
except ValueError as ve:
err = ve
except AssertionError as ae:
err = ae
print(
json.dumps(
format_metrics(metrics, args.task, err), ensure_ascii=False).encode(
'utf8'))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('pred_file', help='predict file')
parser.add_argument('ref_file', help='reference file')
parser.add_argument(
'task', help='task name: Main|Yes_No|All|Entity|Description')
args = parser.parse_args()
args.task = args.task.lower().replace('_', '')
main(args)
# -*- coding:utf8 -*-
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Utility function to generate vocabulary file.
"""
import argparse
import sys
import json
from itertools import chain
def get_vocab(files, vocab_file):
"""
    Builds a vocabulary file from the fields 'segmented_paragraphs'
and 'segmented_question'.
Args:
files: A list of file names.
vocab_file: The file that stores the vocabulary.
"""
vocab = {}
for f in files:
with open(f, 'r') as fin:
for line in fin:
obj = json.loads(line.strip())
paras = [
chain(*d['segmented_paragraphs']) for d in obj['documents']
]
doc_tokens = chain(*paras)
question_tokens = obj['segmented_question']
for t in list(doc_tokens) + question_tokens:
vocab[t] = vocab.get(t, 0) + 1
# output
sorted_vocab = sorted(
[(v, c) for v, c in vocab.items()], key=lambda x: x[1], reverse=True)
with open(vocab_file, 'w') as outf:
for w, c in sorted_vocab:
print >> outf, '{}\t{}'.format(w.encode('utf8'), c)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--files',
nargs='+',
required=True,
help='file list to count vocab from.')
parser.add_argument(
'--vocab', required=True, help='file to store counted vocab.')
args = parser.parse_args()
get_vocab(args.files, args.vocab)
#coding=utf8
import os, sys, json
import nltk
def _nltk_tokenize(sequence):
tokens = nltk.word_tokenize(sequence)
cur_char_offset = 0
token_offsets = []
token_words = []
for token in tokens:
cur_char_offset = sequence.find(token, cur_char_offset)
token_offsets.append(
[cur_char_offset, cur_char_offset + len(token) - 1])
token_words.append(token)
return token_offsets, token_words
def segment(input_js):
_, input_js['segmented_question'] = _nltk_tokenize(input_js['question'])
for doc_id, doc in enumerate(input_js['documents']):
doc['segmented_title'] = []
doc['segmented_paragraphs'] = []
for para_id, para in enumerate(doc['paragraphs']):
_, seg_para = _nltk_tokenize(para)
doc['segmented_paragraphs'].append(seg_para)
if 'answers' in input_js:
input_js['segmented_answers'] = []
for answer_id, answer in enumerate(input_js['answers']):
_, seg_answer = _nltk_tokenize(answer)
input_js['segmented_answers'].append(seg_answer)
if __name__ == '__main__':
if len(sys.argv) != 2:
print('Usage: tokenize_data.py <input_path>')
exit()
nltk.download('punkt')
for line in open(sys.argv[1]):
dureader_js = json.loads(line.strip())
segment(dureader_js)
print(json.dumps(dureader_js))
#coding=utf8
import sys
import json
import pandas as pd
def trans(input_js):
output_js = {}
output_js['question'] = input_js['query']
output_js['question_type'] = input_js['query_type']
output_js['question_id'] = input_js['query_id']
output_js['fact_or_opinion'] = ""
output_js['documents'] = []
for para_id, para in enumerate(input_js['passages']):
doc = {}
doc['title'] = ""
if 'is_selected' in para:
doc['is_selected'] = True if para['is_selected'] != 0 else False
doc['paragraphs'] = [para['passage_text']]
output_js['documents'].append(doc)
if 'answers' in input_js:
output_js['answers'] = input_js['answers']
return output_js
if __name__ == '__main__':
if len(sys.argv) != 2:
print('Usage: marcov1_to_dureader.py <input_path>')
exit()
df = pd.read_json(sys.argv[1])
for row in df.iterrows():
marco_js = json.loads(row[1].to_json())
dureader_js = trans(marco_js)
print(json.dumps(dureader_js))
import sys
import json
import pandas as pd
if __name__ == '__main__':
if len(sys.argv) != 3:
print('Usage: tojson.py <input_path> <output_path>')
exit()
infile = sys.argv[1]
outfile = sys.argv[2]
df = pd.read_json(infile)
with open(outfile, 'w') as f:
for row in df.iterrows():
f.write(row[1].to_json() + '\n')
###############################################################################
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module finds the most related paragraph of each document according to recall.
"""
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
import json
from collections import Counter
def precision_recall_f1(prediction, ground_truth):
"""
This function calculates and returns the precision, recall and f1-score
Args:
prediction: prediction string or list to be matched
ground_truth: golden string or list reference
Returns:
floats of (p, r, f1)
Raises:
None
"""
if not isinstance(prediction, list):
prediction_tokens = prediction.split()
else:
prediction_tokens = prediction
if not isinstance(ground_truth, list):
ground_truth_tokens = ground_truth.split()
else:
ground_truth_tokens = ground_truth
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
p = 1.0 * num_same / len(prediction_tokens)
r = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * p * r) / (p + r)
return p, r, f1
def recall(prediction, ground_truth):
"""
This function calculates and returns the recall
Args:
prediction: prediction string or list to be matched
ground_truth: golden string or list reference
Returns:
floats of recall
Raises:
None
"""
return precision_recall_f1(prediction, ground_truth)[1]
def f1_score(prediction, ground_truth):
"""
This function calculates and returns the f1-score
Args:
prediction: prediction string or list to be matched
ground_truth: golden string or list reference
Returns:
floats of f1
Raises:
None
"""
return precision_recall_f1(prediction, ground_truth)[2]
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""
    This function calculates and returns the maximum metric score of the
    prediction over all ground truths.
    Args:
        metric_fn: metric function pointer which calculates scores according to corresponding logic.
        prediction: prediction string or list to be matched
        ground_truths: a list of golden strings or lists used as references
    Returns:
        the maximum metric score over all ground truths
Raises:
None
"""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def find_best_question_match(doc, question, with_score=False):
"""
    For each document, find the paragraph that best matches the question.
Args:
doc: The document object.
question: The question tokens.
        with_score: If True, the match score is returned together with
                    the index.
Returns:
The index of the best match paragraph, if with_score=False,
otherwise returns a tuple of the index of the best match paragraph
and the match score of that paragraph.
"""
most_related_para = -1
max_related_score = 0
most_related_para_len = 0
for p_idx, para_tokens in enumerate(doc['segmented_paragraphs']):
if len(question) > 0:
related_score = metric_max_over_ground_truths(recall, para_tokens,
question)
else:
related_score = 0
if related_score > max_related_score \
or (related_score == max_related_score \
and len(para_tokens) < most_related_para_len):
most_related_para = p_idx
max_related_score = related_score
most_related_para_len = len(para_tokens)
if most_related_para == -1:
most_related_para = 0
if with_score:
return most_related_para, max_related_score
return most_related_para
def find_fake_answer(sample):
"""
    For each document, finds the most related paragraph based on recall,
    then finds a span that maximizes the f1_score compared with the gold
    answers, and uses this span as a fake answer span.
Args:
sample: a sample in the dataset
Returns:
None
Raises:
None
"""
for doc in sample['documents']:
most_related_para = -1
most_related_para_len = 999999
max_related_score = 0
for p_idx, para_tokens in enumerate(doc['segmented_paragraphs']):
if len(sample['segmented_answers']) > 0:
related_score = metric_max_over_ground_truths(
recall, para_tokens, sample['segmented_answers'])
else:
continue
if related_score > max_related_score \
or (related_score == max_related_score
and len(para_tokens) < most_related_para_len):
most_related_para = p_idx
most_related_para_len = len(para_tokens)
max_related_score = related_score
doc['most_related_para'] = most_related_para
sample['answer_docs'] = []
sample['answer_spans'] = []
sample['fake_answers'] = []
sample['match_scores'] = []
best_match_score = 0
best_match_d_idx, best_match_span = -1, [-1, -1]
best_fake_answer = None
answer_tokens = set()
for segmented_answer in sample['segmented_answers']:
answer_tokens = answer_tokens | set(
[token for token in segmented_answer])
for d_idx, doc in enumerate(sample['documents']):
if not doc['is_selected']:
continue
if doc['most_related_para'] == -1:
doc['most_related_para'] = 0
most_related_para_tokens = doc['segmented_paragraphs'][doc[
'most_related_para']][:1000]
for start_tidx in range(len(most_related_para_tokens)):
if most_related_para_tokens[start_tidx] not in answer_tokens:
continue
for end_tidx in range(
len(most_related_para_tokens) - 1, start_tidx - 1, -1):
span_tokens = most_related_para_tokens[start_tidx:end_tidx + 1]
if len(sample['segmented_answers']) > 0:
match_score = metric_max_over_ground_truths(
f1_score, span_tokens, sample['segmented_answers'])
else:
match_score = 0
if match_score == 0:
break
if match_score > best_match_score:
best_match_d_idx = d_idx
best_match_span = [start_tidx, end_tidx]
best_match_score = match_score
best_fake_answer = ''.join(span_tokens)
if best_match_score > 0:
sample['answer_docs'].append(best_match_d_idx)
sample['answer_spans'].append(best_match_span)
sample['fake_answers'].append(best_fake_answer)
sample['match_scores'].append(best_match_score)
if __name__ == '__main__':
for line in sys.stdin:
sample = json.loads(line)
find_fake_answer(sample)
print(json.dumps(sample, encoding='utf8', ensure_ascii=False))
#!/bin/bash
input_file=$1
output_file=$2
# convert the data from MARCO V2 (json) format to MARCO V1 (jsonl) format.
# the script was forked from MARCO repo.
# the format of MARCO V1 is much easier to explore.
python3 marcov2_to_v1_tojsonl.py $input_file $input_file.marcov1
# convert the data from MARCO V1 format to DuReader format.
python3 marcov1_to_dureader.py $input_file.marcov1 >$input_file.dureader_raw
# tokenize the data.
python3 marco_tokenize_data.py $input_file.dureader_raw >$input_file.segmented
# find fake answers (indicating the start and end positions of answers in the document) for train and dev sets.
# note that this should not be applied for test set, since there is no ground truth in test set.
python preprocess.py $input_file.segmented >$output_file
# remove the temporary data files.
rm -rf $input_file.dureader_raw $input_file.segmented
# -*- coding:utf8 -*-
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module implements the Vocab class for converting string to id and back
"""
import numpy as np
class Vocab(object):
"""
Implements a vocabulary to store the tokens in the data, with their corresponding embeddings.
"""
def __init__(self, filename=None, initial_tokens=None, lower=False):
self.id2token = {}
self.token2id = {}
self.token_cnt = {}
self.lower = lower
self.embed_dim = None
self.embeddings = None
self.pad_token = '<blank>'
self.unk_token = '<unk>'
self.initial_tokens = initial_tokens if initial_tokens is not None else []
self.initial_tokens.extend([self.pad_token, self.unk_token])
for token in self.initial_tokens:
self.add(token)
if filename is not None:
self.load_from_file(filename)
def size(self):
"""
get the size of vocabulary
Returns:
an integer indicating the size
"""
return len(self.id2token)
def load_from_file(self, file_path):
"""
loads the vocab from file_path
Args:
file_path: a file with a word in each line
"""
for line in open(file_path, 'r'):
token = line.rstrip('\n')
self.add(token)
def get_id(self, token):
"""
gets the id of a token, returns the id of unk token if token is not in vocab
Args:
            token: a string indicating the word
Returns:
an integer
"""
token = token.lower() if self.lower else token
try:
return self.token2id[token]
except KeyError:
return self.token2id[self.unk_token]
def get_token(self, idx):
"""
gets the token corresponding to idx, returns unk token if idx is not in vocab
Args:
idx: an integer
returns:
a token string
"""
try:
return self.id2token[idx]
except KeyError:
return self.unk_token
def add(self, token, cnt=1):
"""
adds the token to vocab
Args:
token: a string
cnt: a num indicating the count of the token to add, default is 1
"""
token = token.lower() if self.lower else token
if token in self.token2id:
idx = self.token2id[token]
else:
idx = len(self.id2token)
self.id2token[idx] = token
self.token2id[token] = idx
if cnt > 0:
if token in self.token_cnt:
self.token_cnt[token] += cnt
else:
self.token_cnt[token] = cnt
return idx
def filter_tokens_by_cnt(self, min_cnt):
"""
filter the tokens in vocab by their count
Args:
            min_cnt: tokens with frequency less than min_cnt are filtered
"""
filtered_tokens = [
token for token in self.token2id if self.token_cnt[token] >= min_cnt
]
# rebuild the token x id map
self.token2id = {}
self.id2token = {}
for token in self.initial_tokens:
self.add(token, cnt=0)
for token in filtered_tokens:
self.add(token, cnt=0)
def randomly_init_embeddings(self, embed_dim):
"""
randomly initializes the embeddings for each token
Args:
embed_dim: the size of the embedding for each token
"""
self.embed_dim = embed_dim
self.embeddings = np.random.rand(self.size(), embed_dim)
for token in [self.pad_token, self.unk_token]:
self.embeddings[self.get_id(token)] = np.zeros([self.embed_dim])
def load_pretrained_embeddings(self, embedding_path):
"""
loads the pretrained embeddings from embedding_path,
tokens not in pretrained embeddings will be filtered
Args:
embedding_path: the path of the pretrained embedding file
"""
trained_embeddings = {}
with open(embedding_path, 'r') as fin:
for line in fin:
contents = line.strip().split()
token = contents[0].decode('utf8')
if token not in self.token2id:
continue
trained_embeddings[token] = list(map(float, contents[1:]))
if self.embed_dim is None:
self.embed_dim = len(contents) - 1
filtered_tokens = trained_embeddings.keys()
# rebuild the token x id map
self.token2id = {}
self.id2token = {}
for token in self.initial_tokens:
self.add(token, cnt=0)
for token in filtered_tokens:
self.add(token, cnt=0)
# load embeddings
self.embeddings = np.zeros([self.size(), self.embed_dim])
for token in self.token2id.keys():
if token in trained_embeddings:
self.embeddings[self.get_id(token)] = trained_embeddings[token]
def convert_to_ids(self, tokens):
"""
Convert a list of tokens to ids, use unk_token if the token is not in vocab.
Args:
tokens: a list of token
Returns:
a list of ids
"""
vec = [self.get_id(label) for label in tokens]
return vec
def recover_from_ids(self, ids, stop_id=None):
"""
Convert a list of ids to tokens, stop converting if the stop_id is encountered
Args:
ids: a list of ids to convert
stop_id: the stop id, default is None
Returns:
a list of tokens
"""
tokens = []
for i in ids:
tokens += [self.get_token(i)]
if stop_id is not None and i == stop_id:
break
return tokens
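# A hedged usage sketch (illustrative tokens):
#
#     vocab = Vocab(lower=True)
#     for token in ['Hello', 'world']:
#         vocab.add(token)
#     ids = vocab.convert_to_ids(['hello', 'unseen'])  # 'unseen' maps to <unk>
#     tokens = vocab.recover_from_ids(ids)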