提交 20dac612 编写于 作者: X xuezhong

use dynamic rnn for multi-doc

上级 b8613b9a
......@@ -132,35 +132,33 @@ class BRCDataset(object):
'passage_token_ids': [],
'passage_length': [],
'start_id': [],
'end_id': []
'end_id': [],
'passage_num': []
}
max_passage_num = max(
[len(sample['passages']) for sample in batch_data['raw_data']])
#max_passage_num = min(self.max_p_num, max_passage_num)
max_passage_num = self.max_p_num
max_passage_num = min(self.max_p_num, max_passage_num)
for sidx, sample in enumerate(batch_data['raw_data']):
count = 0
for pidx in range(max_passage_num):
if pidx < len(sample['passages']):
count += 1
batch_data['question_token_ids'].append(sample[
'question_token_ids'])
'question_token_ids'][0:self.max_q_len])
batch_data['question_length'].append(
len(sample['question_token_ids']))
min(len(sample['question_token_ids']), self.max_q_len))
passage_token_ids = sample['passages'][pidx][
'passage_token_ids']
'passage_token_ids'][0:self.max_p_len]
batch_data['passage_token_ids'].append(passage_token_ids)
batch_data['passage_length'].append(
min(len(passage_token_ids), self.max_p_len))
else:
batch_data['question_token_ids'].append([])
batch_data['question_length'].append(0)
batch_data['passage_token_ids'].append([])
batch_data['passage_length'].append(0)
batch_data, padded_p_len, padded_q_len = self._dynamic_padding(
batch_data, pad_id)
batch_data['passage_num'].append(count)
for sample in batch_data['raw_data']:
gold_passage_offset = 0
if 'answer_passages' in sample and len(sample['answer_passages']):
gold_passage_offset = padded_p_len * sample['answer_passages'][
0]
for i in range(sample['answer_passages'][0]):
gold_passage_offset += len(batch_data['passage_token_ids'][
i])
batch_data['start_id'].append(gold_passage_offset + sample[
'answer_spans'][0][0])
batch_data['end_id'].append(gold_passage_offset + sample[
......
......@@ -68,16 +68,23 @@ def bi_lstm_encoder(input_seq, gate_size, para_name, args):
return encoder_out
def encoder(input_name, para_name, shape, hidden_size, args):
def get_data(input_name, lod_level, args):
    """Declare an int64 data layer for token ids.

    Args:
        input_name: Name passed to ``layers.data``; it is also the key the
            feeder uses for this input.
        lod_level: LoD nesting depth forwarded to ``layers.data``.
        args: Not read by this function.

    Returns:
        The variable returned by ``layers.data``.
    """
    # Forward the caller's lod_level instead of a hard-coded 1 so the same
    # helper can declare inputs with different nesting depths.
    input_ids = layers.data(
        name=input_name, shape=[1], dtype='int64', lod_level=lod_level)
    return input_ids
def embedding(input_ids, shape, args):
    """Map token ids to float32 embedding vectors.

    Args:
        input_ids: Variable holding the ids to look up.
        shape: Size of the embedding table, forwarded as ``size``.
        args: Not read by this function.

    Returns:
        The variable produced by ``layers.embedding``.
    """
    # The parameter name is a fixed literal, so every call refers to the
    # parameter called 'embedding_para'.
    table_attr = fluid.ParamAttr(name='embedding_para')
    return layers.embedding(
        input=input_ids,
        size=shape,
        dtype='float32',
        is_sparse=True,
        param_attr=table_attr)
def encoder(input_embedding, para_name, hidden_size, args):
encoder_out = bi_lstm_encoder(
input_seq=input_embedding,
gate_size=hidden_size,
......@@ -259,40 +266,41 @@ def fusion(g, args):
def rc_model(hidden_size, vocab, args):
emb_shape = [vocab.size(), vocab.embed_dim]
start_labels = layers.data(
name="start_lables", shape=[1], dtype='float32', lod_level=1)
end_labels = layers.data(
name="end_lables", shape=[1], dtype='float32', lod_level=1)
# stage 1:encode
p_ids_names = []
q_ids_names = []
ms = []
gs = []
qs = []
for i in range(args.doc_num):
p_ids_name = "pids_%d" % i
p_ids_names.append(p_ids_name)
p_enc_i = encoder(p_ids_name, 'p_enc', emb_shape, hidden_size, args)
q_ids_name = "qids_%d" % i
q_ids_names.append(q_ids_name)
q_enc_i = encoder(q_ids_name, 'q_enc', emb_shape, hidden_size, args)
q_id0 = get_data('q_id0', 1, args)
q_ids = get_data('q_ids', 2, args)
p_ids_name = 'p_ids'
p_ids = get_data('p_ids', 2, args)
p_embs = embedding(p_ids, emb_shape, args)
q_embs = embedding(q_ids, emb_shape, args)
drnn = layers.DynamicRNN()
with drnn.block():
p_emb = drnn.step_input(p_embs)
q_emb = drnn.step_input(q_embs)
p_enc = encoder(p_emb, 'p_enc', hidden_size, args)
q_enc = encoder(q_emb, 'q_enc', hidden_size, args)
# stage 2:match
g_i = attn_flow(q_enc_i, p_enc_i, p_ids_name, args)
g_i = attn_flow(q_enc, p_enc, p_ids_name, args)
# stage 3:fusion
m_i = fusion(g_i, args)
ms.append(m_i)
gs.append(g_i)
qs.append(q_enc_i)
m = layers.sequence_concat(input=ms)
g = layers.sequence_concat(input=gs)
q_vec = layers.sequence_concat(input=qs)
drnn.output(m_i, q_enc)
ms, q_encs = drnn()
p_vec = layers.lod_reset(x=ms, y=start_labels)
q_vec = layers.lod_reset(x=q_encs, y=q_id0)
# stage 4:decode
start_probs, end_probs = point_network_decoder(
p_vec=m, q_vec=q_vec, hidden_size=hidden_size, args=args)
start_labels = layers.data(
name="start_lables", shape=[1], dtype='float32', lod_level=1)
end_labels = layers.data(
name="end_lables", shape=[1], dtype='float32', lod_level=1)
p_vec=p_vec, q_vec=q_vec, hidden_size=hidden_size, args=args)
cost0 = layers.sequence_pool(
layers.cross_entropy(
......@@ -308,5 +316,5 @@ def rc_model(hidden_size, vocab, args):
cost = cost0 + cost1
cost.persistable = True
feeding_list = q_ids_names + ["start_lables", "end_lables"] + p_ids_names
return cost, start_probs, end_probs, feeding_list
feeding_list = ["q_ids", "start_lables", "end_lables", "p_ids", "q_id0"]
return cost, start_probs, end_probs, ms, feeding_list
......@@ -46,22 +46,32 @@ from vocab import Vocab
def prepare_batch_input(insts, args):
doc_num = args.doc_num
batch_size = len(insts['raw_data'])
inst_num = len(insts['passage_num'])
if batch_size != inst_num:
print("data error %d, %d" % (batch_size, inst_num))
return None
new_insts = []
passage_idx = 0
for i in range(batch_size):
p_len = 0
p_id = []
q_id = []
p_ids = []
q_ids = []
p_len = 0
for j in range(i * doc_num, (i + 1) * doc_num):
p_ids.append(insts['passage_token_ids'][j])
p_id = p_id + insts['passage_token_ids'][j]
q_ids.append(insts['question_token_ids'][j])
q_id = q_id + insts['question_token_ids'][j]
q_id = []
p_id_r = []
p_ids_r = []
q_ids_r = []
q_id_r = []
for j in range(insts['passage_num'][i]):
p_ids.append(insts['passage_token_ids'][passage_idx + j])
p_id = p_id + insts['passage_token_ids'][passage_idx + j]
q_ids.append(insts['question_token_ids'][passage_idx + j])
q_id = q_id + insts['question_token_ids'][passage_idx + j]
passage_idx += insts['passage_num'][i]
p_len = len(p_id)
def _get_label(idx, ref_len):
......@@ -72,11 +82,46 @@ def prepare_batch_input(insts, args):
start_label = _get_label(insts['start_id'][i], p_len)
end_label = _get_label(insts['end_id'][i], p_len)
new_inst = q_ids + [start_label, end_label] + p_ids
new_inst = [q_ids, start_label, end_label, p_ids, q_id]
new_insts.append(new_inst)
return new_insts
def batch_reader(batch_list, args):
    """Run ``prepare_batch_input`` over every batch in ``batch_list``.

    Args:
        batch_list: Iterable of batches; each one is handed unchanged to
            ``prepare_batch_input``.
        args: Forwarded to ``prepare_batch_input`` for every batch.

    Returns:
        A list with one ``prepare_batch_input`` result per batch, in the
        same order as ``batch_list``.
    """
    return [prepare_batch_input(one_batch, args) for one_batch in batch_list]
def read_multiple(reader, count, clip_last=True):
    """Stack data from ``reader`` into groups of ``count`` for multi-devices.

    Args:
        reader: Zero-argument callable returning an iterable of items.
        count: Number of items to put in each yielded group.
        clip_last: When true (the default), a trailing group with fewer than
            ``count`` items is discarded. When false, the trailing items are
            concatenated with ``+=`` (so each item must be a list of
            instances) and split into ``count`` equally sized parts; any
            remainder instances that do not divide evenly are dropped.

    Returns:
        A zero-argument generator function yielding lists of length
        ``count``.
    """

    def __impl__():
        res = []
        for item in reader():
            res.append(item)
            if len(res) == count:
                yield res
                res = []
        if len(res) == count:
            yield res
        elif not clip_last:
            # Flatten the leftover items into one list of instances.
            data = []
            for item in res:
                data += item
            # ">=" rather than ">": with exactly ``count`` instances left,
            # every part can still receive one instance, so the tail must
            # not be thrown away.
            if len(data) >= count:
                inst_num_per_part = len(data) // count
                yield [
                    data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
                    for i in range(count)
                ]

    return __impl__
def LodTensor_Array(lod_tensor):
lod = lod_tensor.lod()
array = np.array(lod_tensor)
......@@ -103,7 +148,7 @@ def print_para(train_prog, train_exe, logger, args):
logger.info("total param num: {0}".format(num_sum))
def find_best_answer_for_passage(start_probs, end_probs, passage_len, args):
def find_best_answer_for_passage(start_probs, end_probs, passage_len):
"""
Finds the best answer with the maximum start_prob * end_prob from a single passage
"""
......@@ -125,7 +170,7 @@ def find_best_answer_for_passage(start_probs, end_probs, passage_len, args):
return (best_start, best_end), max_prob
def find_best_answer(sample, start_prob, end_prob, padded_p_len, args):
def find_best_answer_for_inst(sample, start_prob, end_prob, inst_lod):
"""
Finds the best answer for a sample given start_prob and end_prob for each position.
This will call find_best_answer_for_passage because there are multiple passages in a sample
......@@ -134,11 +179,16 @@ def find_best_answer(sample, start_prob, end_prob, padded_p_len, args):
for p_idx, passage in enumerate(sample['passages']):
if p_idx >= args.max_p_num:
continue
if len(start_prob) != len(end_prob):
logger.info('error: {}'.format(sample['question']))
continue
passage_start = inst_lod[p_idx] - inst_lod[0]
passage_end = inst_lod[p_idx + 1] - inst_lod[0]
passage_len = passage_end - passage_start
passage_len = min(args.max_p_len, len(passage['passage_tokens']))
answer_span, score = find_best_answer_for_passage(
start_prob[p_idx * padded_p_len:(p_idx + 1) * padded_p_len],
end_prob[p_idx * padded_p_len:(p_idx + 1) * padded_p_len],
passage_len, args)
start_prob[passage_start:passage_end],
end_prob[passage_start:passage_end], passage_len)
if score > best_score:
best_score = score
best_p_idx = p_idx
......@@ -148,11 +198,11 @@ def find_best_answer(sample, start_prob, end_prob, padded_p_len, args):
else:
best_answer = ''.join(sample['passages'][best_p_idx]['passage_tokens'][
best_span[0]:best_span[1] + 1])
return best_answer
return best_answer, best_span
def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
vocab, brc_data, logger, args):
def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
place, dev_count, vocab, brc_data, logger, args):
"""
"""
......@@ -165,6 +215,8 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
# Use test set as validation each pass
total_loss = 0.0
count = 0
n_batch_cnt = 0
n_batch_loss = 0.0
pred_answers, ref_answers = [], []
val_feed_list = [
inference_program.global_block().var(var_name)
......@@ -172,55 +224,80 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
]
val_feeder = fluid.DataFeeder(val_feed_list, place)
pad_id = vocab.get_id(vocab.pad_token)
dev_batches = brc_data.gen_mini_batches(
'dev', args.batch_size, pad_id, shuffle=False)
dev_reader = lambda:brc_data.gen_mini_batches('dev', args.batch_size, pad_id, shuffle=False)
dev_reader = read_multiple(dev_reader, dev_count)
for batch_id, batch in enumerate(dev_batches, 1):
feed_data = prepare_batch_input(batch, args)
for batch_id, batch_list in enumerate(dev_reader(), 1):
feed_data = batch_reader(batch_list, args)
val_fetch_outs = parallel_executor.run(
feed=val_feeder.feed(feed_data),
fetch_list=[avg_cost.name, s_probs.name, e_probs.name],
feed=list(val_feeder.feed_parallel(feed_data, dev_count)),
fetch_list=[avg_cost.name, s_probs.name, e_probs.name, match.name],
return_numpy=False)
total_loss += np.array(val_fetch_outs[0])[0]
start_probs = LodTensor_Array(val_fetch_outs[1])
end_probs = LodTensor_Array(val_fetch_outs[2])
count += len(batch['raw_data'])
padded_p_len = len(batch['passage_token_ids'][0])
for sample, start_prob, end_prob in zip(batch['raw_data'], start_probs,
end_probs):
best_answer = find_best_answer(sample, start_prob, end_prob,
padded_p_len, args)
pred_answers.append({
'question_id': sample['question_id'],
'question_type': sample['question_type'],
'answers': [best_answer],
'entity_answers': [[]],
'yesno_answers': []
})
if 'answers' in sample:
ref_answers.append({
total_loss += np.array(val_fetch_outs[0]).sum()
start_probs_m = LodTensor_Array(val_fetch_outs[1])
end_probs_m = LodTensor_Array(val_fetch_outs[2])
match_lod = val_fetch_outs[3].lod()
count += len(np.array(val_fetch_outs[0]))
n_batch_cnt += len(np.array(val_fetch_outs[0]))
n_batch_loss += np.array(val_fetch_outs[0]).sum()
log_every_n_batch = args.log_interval
if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
logger.info('Average dev loss from batch {} to {} is {}'.format(
batch_id - log_every_n_batch + 1, batch_id, "%.10f" % (
n_batch_loss / n_batch_cnt)))
n_batch_loss = 0.0
n_batch_cnt = 0
for idx, batch in enumerate(batch_list):
#one batch
batch_size = len(batch['raw_data'])
batch_range = match_lod[0][idx * batch_size:(idx + 1) * batch_size +
1]
batch_lod = [[batch_range[x], batch_range[x + 1]]
for x in range(len(batch_range[:-1]))]
start_prob_batch = start_probs_m[idx * batch_size:(idx + 1) *
batch_size]
end_prob_batch = end_probs_m[idx * batch_size:(idx + 1) *
batch_size]
for sample, start_prob_inst, end_prob_inst, inst_range in zip(
batch['raw_data'], start_prob_batch, end_prob_batch,
batch_lod):
#one instance
inst_lod = match_lod[1][inst_range[0]:inst_range[1] + 1]
best_answer, best_span = find_best_answer_for_inst(
sample, start_prob_inst, end_prob_inst, inst_lod)
pred = {
'question_id': sample['question_id'],
'question_type': sample['question_type'],
'answers': sample['answers'],
'answers': [best_answer],
'entity_answers': [[]],
'yesno_answers': []
})
if args.result_dir is not None and args.result_name is not None:
'yesno_answers': [best_span]
}
pred_answers.append(pred)
if 'answers' in sample:
ref = {
'question_id': sample['question_id'],
'question_type': sample['question_type'],
'answers': sample['answers'],
'entity_answers': [[]],
'yesno_answers': []
}
ref_answers.append(ref)
result_dir = args.result_dir
result_prefix = args.result_name
if result_dir is not None and result_prefix is not None:
if not os.path.exists(args.result_dir):
os.makedirs(args.result_dir)
result_file = os.path.join(args.result_dir, args.result_name + '.json')
result_file = os.path.join(result_dir, result_prefix + 'json')
with open(result_file, 'w') as fout:
for pred_answer in pred_answers:
fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
logger.info('Saving {} results to {}'.format(args.result_name,
logger.info('Saving {} results to {}'.format(result_prefix,
result_file))
ave_loss = 1.0 * total_loss / count
# compute the bleu and rouge scores if reference answers is provided
if len(ref_answers) > 0:
pred_dict, ref_dict = {}, {}
......@@ -250,6 +327,13 @@ def train(logger, args):
brc_data.convert_to_ids(vocab)
logger.info('Initialize the model...')
if not args.use_gpu:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
else:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
# build model
main_program = fluid.Program()
startup_prog = fluid.Program()
......@@ -257,7 +341,7 @@ def train(logger, args):
startup_prog.random_seed = args.random_seed
with fluid.program_guard(main_program, startup_prog):
with fluid.unique_name.guard():
avg_cost, s_probs, e_probs, feed_order = rc_model.rc_model(
avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
args.hidden_size, vocab, args)
# clone from default main program and use it as the validation program
inference_program = main_program.clone(for_test=True)
......@@ -314,20 +398,21 @@ def train(logger, args):
for pass_id in range(1, args.pass_num + 1):
pass_start_time = time.time()
pad_id = vocab.get_id(vocab.pad_token)
train_batches = brc_data.gen_mini_batches(
'train', args.batch_size, pad_id, shuffle=True)
train_reader = lambda:brc_data.gen_mini_batches('train', args.batch_size, pad_id, shuffle=False)
train_reader = read_multiple(train_reader, dev_count)
log_every_n_batch, n_batch_loss = args.log_interval, 0
total_num, total_loss = 0, 0
for batch_id, batch in enumerate(train_batches, 1):
input_data_dict = prepare_batch_input(batch, args)
for batch_id, batch_list in enumerate(train_reader(), 1):
feed_data = batch_reader(batch_list, args)
fetch_outs = parallel_executor.run(
feed=feeder.feed(input_data_dict),
feed=list(feeder.feed_parallel(feed_data, dev_count)),
fetch_list=[avg_cost.name],
return_numpy=False)
cost_train = np.array(fetch_outs[0])[0]
total_num += len(batch['raw_data'])
cost_train = np.array(fetch_outs[0]).mean()
total_num += args.batch_size * dev_count
n_batch_loss += cost_train
total_loss += cost_train * len(batch['raw_data'])
total_loss += cost_train * args.batch_size * dev_count
if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
print_para(main_program, parallel_executor, logger,
args)
......@@ -337,19 +422,23 @@ def train(logger, args):
"%.10f" % (n_batch_loss / log_every_n_batch)))
n_batch_loss = 0
if args.dev_interval > 0 and batch_id % args.dev_interval == 0:
eval_loss, bleu_rouge = validation(
inference_program, avg_cost, s_probs, e_probs,
feed_order, place, vocab, brc_data, logger, args)
logger.info('Dev eval loss {}'.format(eval_loss))
logger.info('Dev eval result: {}'.format(bleu_rouge))
if brc_data.dev_set is not None:
eval_loss, bleu_rouge = validation(
inference_program, avg_cost, s_probs, e_probs,
match, feed_order, place, dev_count, vocab,
brc_data, logger, args)
logger.info('Dev eval loss {}'.format(eval_loss))
logger.info('Dev eval result: {}'.format(
bleu_rouge))
pass_end_time = time.time()
logger.info('Evaluating the model after epoch {}'.format(
pass_id))
if brc_data.dev_set is not None:
eval_loss, bleu_rouge = validation(
inference_program, avg_cost, s_probs, e_probs,
feed_order, place, vocab, brc_data, logger, args)
inference_program, avg_cost, s_probs, e_probs, match,
feed_order, place, dev_count, vocab, brc_data, logger,
args)
logger.info('Dev eval loss {}'.format(eval_loss))
logger.info('Dev eval result: {}'.format(bleu_rouge))
else:
......@@ -389,10 +478,17 @@ def evaluate(logger, args):
startup_prog.random_seed = args.random_seed
with fluid.program_guard(main_program, startup_prog):
with fluid.unique_name.guard():
avg_cost, s_probs, e_probs, feed_order = rc_model.rc_model(
avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
args.hidden_size, vocab, args)
# initialize parameters
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
if not args.use_gpu:
place = fluid.CPUPlace()
dev_count = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
else:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
exe = Executor(place)
if args.load_dir:
logger.info('load from {}'.format(args.load_dir))
......@@ -402,17 +498,10 @@ def evaluate(logger, args):
logger.error('No model file to load ...')
return
# prepare data
feed_list = [
main_program.global_block().var(var_name)
for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
inference_program = main_program.clone(for_test=True)
eval_loss, bleu_rouge = validation(
inference_program, avg_cost, s_probs, e_probs, feed_order,
place, vocab, brc_data, logger, args)
place, dev_count, vocab, brc_data, logger, args)
logger.info('Dev eval loss {}'.format(eval_loss))
logger.info('Dev eval result: {}'.format(bleu_rouge))
logger.info('Predicted answers are saved to {}'.format(
......@@ -438,10 +527,17 @@ def predict(logger, args):
startup_prog.random_seed = args.random_seed
with fluid.program_guard(main_program, startup_prog):
with fluid.unique_name.guard():
avg_cost, s_probs, e_probs, feed_order = rc_model.rc_model(
avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
args.hidden_size, vocab, args)
# initialize parameters
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
if not args.use_gpu:
place = fluid.CPUPlace()
dev_count = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
else:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
exe = Executor(place)
if args.load_dir:
logger.info('load from {}'.format(args.load_dir))
......@@ -451,17 +547,10 @@ def predict(logger, args):
logger.error('No model file to load ...')
return
# prepare data
feed_list = [
main_program.global_block().var(var_name)
for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
inference_program = main_program.clone(for_test=True)
eval_loss, bleu_rouge = validation(
inference_program, avg_cost, s_probs, e_probs, feed_order,
place, vocab, brc_data, logger, args)
inference_program, avg_cost, s_probs, e_probs, match,
feed_order, place, dev_count, vocab, brc_data, logger, args)
def prepare(logger, args):
......
export CUDA_VISIBLE_DEVICES=1
export CUDA_VISIBLE_DEVICES=0
python run.py \
--trainset 'data/preprocessed/trainset/search.train.json' \
'data/preprocessed/trainset/zhidao.train.json' \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册