From 20dac612546b0c63cafcec62a3d686dd2e6ce90d Mon Sep 17 00:00:00 2001
From: xuezhong
Date: Tue, 23 Oct 2018 08:17:23 +0000
Subject: [PATCH] Use dynamic RNN for multi-doc

---
 .../machine_reading_comprehension/dataset.py  |  28 +-
 .../machine_reading_comprehension/rc_model.py |  68 +++--
 fluid/machine_reading_comprehension/run.py    | 269 ++++++++++++------
 fluid/machine_reading_comprehension/run.sh    |   2 +-
 4 files changed, 231 insertions(+), 136 deletions(-)

diff --git a/fluid/machine_reading_comprehension/dataset.py b/fluid/machine_reading_comprehension/dataset.py
index 7a5cea18..0aa91896 100644
--- a/fluid/machine_reading_comprehension/dataset.py
+++ b/fluid/machine_reading_comprehension/dataset.py
@@ -132,35 +132,33 @@ class BRCDataset(object):
             'passage_token_ids': [],
             'passage_length': [],
             'start_id': [],
-            'end_id': []
+            'end_id': [],
+            'passage_num': []
         }
         max_passage_num = max(
             [len(sample['passages']) for sample in batch_data['raw_data']])
-        #max_passage_num = min(self.max_p_num, max_passage_num)
-        max_passage_num = self.max_p_num
+        max_passage_num = min(self.max_p_num, max_passage_num)
         for sidx, sample in enumerate(batch_data['raw_data']):
+            count = 0
             for pidx in range(max_passage_num):
                 if pidx < len(sample['passages']):
+                    count += 1
                     batch_data['question_token_ids'].append(sample[
-                        'question_token_ids'])
+                        'question_token_ids'][0:self.max_q_len])
                     batch_data['question_length'].append(
-                        len(sample['question_token_ids']))
+                        min(len(sample['question_token_ids']), self.max_q_len))
                     passage_token_ids = sample['passages'][pidx][
-                        'passage_token_ids']
+                        'passage_token_ids'][0:self.max_p_len]
                     batch_data['passage_token_ids'].append(passage_token_ids)
                     batch_data['passage_length'].append(
                         min(len(passage_token_ids), self.max_p_len))
-                else:
-                    batch_data['question_token_ids'].append([])
-                    batch_data['question_length'].append(0)
-                    batch_data['passage_token_ids'].append([])
-                    batch_data['passage_length'].append(0)
-        batch_data, padded_p_len, padded_q_len = self._dynamic_padding(
-            batch_data, pad_id)
+            batch_data['passage_num'].append(count)
-        for sample in batch_data['raw_data']:
+        passage_base = 0
+        for sidx, sample in enumerate(batch_data['raw_data']):
+            if sidx > 0:
+                # the gold span offsets into this sample's own passages, so
+                # count token lengths from the sample's first passage onwards
+                passage_base += batch_data['passage_num'][sidx - 1]
+            gold_passage_offset = 0
             if 'answer_passages' in sample and len(sample['answer_passages']):
-                gold_passage_offset = padded_p_len * sample['answer_passages'][
-                    0]
+                for i in range(sample['answer_passages'][0]):
+                    gold_passage_offset += len(batch_data['passage_token_ids'][
+                        passage_base + i])
                 batch_data['start_id'].append(gold_passage_offset + sample[
                     'answer_spans'][0][0])
                 batch_data['end_id'].append(gold_passage_offset + sample[
diff --git a/fluid/machine_reading_comprehension/rc_model.py b/fluid/machine_reading_comprehension/rc_model.py
index 11d5b5d9..932ccd9c 100644
--- a/fluid/machine_reading_comprehension/rc_model.py
+++ b/fluid/machine_reading_comprehension/rc_model.py
@@ -68,16 +68,23 @@ def bi_lstm_encoder(input_seq, gate_size, para_name, args):
     return encoder_out
 
 
-def encoder(input_name, para_name, shape, hidden_size, args):
+def get_data(input_name, lod_level, args):
     input_ids = layers.data(
-        name=input_name, shape=[1], dtype='int64', lod_level=1)
+        name=input_name, shape=[1], dtype='int64', lod_level=lod_level)
+    return input_ids
+
+
+def embedding(input_ids, shape, args):
     input_embedding = layers.embedding(
         input=input_ids,
         size=shape,
         dtype='float32',
         is_sparse=True,
         param_attr=fluid.ParamAttr(name='embedding_para'))
+    return input_embedding
 
+
+def encoder(input_embedding, para_name, hidden_size, args):
     encoder_out = bi_lstm_encoder(
         input_seq=input_embedding,
         gate_size=hidden_size,
@@ -259,40 +266,41 @@ def fusion(g, args):
 def rc_model(hidden_size, vocab, args):
     emb_shape = [vocab.size(), vocab.embed_dim]
+    start_labels = layers.data(
+        name="start_lables", shape=[1], dtype='float32', lod_level=1)
+    end_labels = layers.data(
+        name="end_lables", shape=[1], dtype='float32', lod_level=1)
+
     # stage 1:encode
-    p_ids_names = []
-    q_ids_names = []
-    ms = []
-    gs = []
-    qs = []
-    for i in range(args.doc_num):
-        p_ids_name = "pids_%d" % i
-        p_ids_names.append(p_ids_name)
-        p_enc_i = encoder(p_ids_name, 'p_enc', emb_shape, hidden_size, args)
-
-        q_ids_name = "qids_%d" % i
-        q_ids_names.append(q_ids_name)
-        q_enc_i = encoder(q_ids_name, 'q_enc', emb_shape, hidden_size, args)
+    q_id0 = get_data('q_id0', 1, args)
+
+    q_ids = get_data('q_ids', 2, args)
+    p_ids_name = 'p_ids'
+
+    p_ids = get_data('p_ids', 2, args)
+    p_embs = embedding(p_ids, emb_shape, args)
+    q_embs = embedding(q_ids, emb_shape, args)
+    drnn = layers.DynamicRNN()
+    with drnn.block():
+        p_emb = drnn.step_input(p_embs)
+        q_emb = drnn.step_input(q_embs)
+
+        p_enc = encoder(p_emb, 'p_enc', hidden_size, args)
+        q_enc = encoder(q_emb, 'q_enc', hidden_size, args)
 
         # stage 2:match
-        g_i = attn_flow(q_enc_i, p_enc_i, p_ids_name, args)
+        g_i = attn_flow(q_enc, p_enc, p_ids_name, args)
         # stage 3:fusion
         m_i = fusion(g_i, args)
-        ms.append(m_i)
-        gs.append(g_i)
-        qs.append(q_enc_i)
-    m = layers.sequence_concat(input=ms)
-    g = layers.sequence_concat(input=gs)
-    q_vec = layers.sequence_concat(input=qs)
+        drnn.output(m_i, q_enc)
+
+    ms, q_encs = drnn()
+    p_vec = layers.lod_reset(x=ms, y=start_labels)
+    q_vec = layers.lod_reset(x=q_encs, y=q_id0)
 
     # stage 4:decode
     start_probs, end_probs = point_network_decoder(
-        p_vec=m, q_vec=q_vec, hidden_size=hidden_size, args=args)
-
-    start_labels = layers.data(
-        name="start_lables", shape=[1], dtype='float32', lod_level=1)
-    end_labels = layers.data(
-        name="end_lables", shape=[1], dtype='float32', lod_level=1)
+        p_vec=p_vec, q_vec=q_vec, hidden_size=hidden_size, args=args)
 
     cost0 = layers.sequence_pool(
         layers.cross_entropy(
@@ -308,5 +316,5 @@ def rc_model(hidden_size, vocab, args):
     cost = cost0 + cost1
     cost.persistable = True
 
-    feeding_list = q_ids_names + ["start_lables", "end_lables"] + p_ids_names
-    return cost, start_probs, end_probs, feeding_list
+    feeding_list = ["q_ids", "start_lables", "end_lables", "p_ids", "q_id0"]
+    return cost, start_probs, end_probs, ms, feeding_list
diff --git a/fluid/machine_reading_comprehension/run.py b/fluid/machine_reading_comprehension/run.py
index 1b68d79f..0ab05b90 100644
--- a/fluid/machine_reading_comprehension/run.py
+++ b/fluid/machine_reading_comprehension/run.py
@@ -46,22 +46,32 @@ from vocab import Vocab
 
 
 def prepare_batch_input(insts, args):
-    doc_num = args.doc_num
     batch_size = len(insts['raw_data'])
+    inst_num = len(insts['passage_num'])
+    if batch_size != inst_num:
+        print("data error %d, %d" % (batch_size, inst_num))
+        return None
     new_insts = []
+    passage_idx = 0
     for i in range(batch_size):
+        p_len = 0
         p_id = []
-        q_id = []
         p_ids = []
         q_ids = []
-        p_len = 0
-        for j in range(i * doc_num, (i + 1) * doc_num):
-            p_ids.append(insts['passage_token_ids'][j])
-            p_id = p_id + insts['passage_token_ids'][j]
-            q_ids.append(insts['question_token_ids'][j])
-            q_id = q_id + insts['question_token_ids'][j]
+        q_id = []
+
+        for j in range(insts['passage_num'][i]):
+            p_ids.append(insts['passage_token_ids'][passage_idx + j])
+            p_id = p_id + insts['passage_token_ids'][passage_idx + j]
+            q_ids.append(insts['question_token_ids'][passage_idx + j])
+            q_id = q_id + insts['question_token_ids'][passage_idx + j]
+
+        passage_idx += insts['passage_num'][i]
         p_len = len(p_id)
 
         def _get_label(idx, ref_len):
@@ -72,11 +82,46 @@ def prepare_batch_input(insts, args):
         start_label = _get_label(insts['start_id'][i], p_len)
         end_label = _get_label(insts['end_id'][i], p_len)
-        new_inst = q_ids + [start_label, end_label] + p_ids
+        new_inst = [q_ids, start_label, end_label, p_ids, q_id]
         new_insts.append(new_inst)
     return new_insts
 
 
+def batch_reader(batch_list, args):
+    res = []
+    for batch in batch_list:
+        data = prepare_batch_input(batch, args)
+        # prepare_batch_input returns None for malformed batches; skip them
+        if data is not None:
+            res.append(data)
+    return res
+
+
+def read_multiple(reader, count, clip_last=True):
+    """
+    Stack data from the reader for multiple devices.
+    """
+
+    def __impl__():
+        res = []
+        for item in reader():
+            res.append(item)
+            if len(res) == count:
+                yield res
+                res = []
+        if len(res) == count:
+            yield res
+        elif not clip_last:
+            data = []
+            for item in res:
+                data += item
+            if len(data) > count:
+                inst_num_per_part = len(data) // count
+                yield [
+                    data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
+                    for i in range(count)
+                ]
+
+    return __impl__
+
+
 def LodTensor_Array(lod_tensor):
     lod = lod_tensor.lod()
     array = np.array(lod_tensor)
@@ -103,7 +148,7 @@ def print_para(train_prog, train_exe, logger, args):
     logger.info("total param num: {0}".format(num_sum))
 
 
-def find_best_answer_for_passage(start_probs, end_probs, passage_len, args):
+def find_best_answer_for_passage(start_probs, end_probs, passage_len):
     """
     Finds the best answer with the maximum start_prob * end_prob from a single passage
     """
@@ -125,7 +170,7 @@ def find_best_answer_for_passage(start_probs, end_probs, passage_len, args):
     return (best_start, best_end), max_prob
 
 
-def find_best_answer(sample, start_prob, end_prob, padded_p_len, args):
+def find_best_answer_for_inst(sample, start_prob, end_prob, inst_lod):
     """
     Finds the best answer for a sample given start_prob and end_prob for each position.
     This will call find_best_answer_for_passage because there are multiple passages in a sample
     """
     best_p_idx, best_span, best_score = None, None, 0
     for p_idx, passage in enumerate(sample['passages']):
         if p_idx >= args.max_p_num:
             continue
+        if len(start_prob) != len(end_prob):
+            logger.info('error: {}'.format(sample['question']))
+            continue
-        passage_len = min(args.max_p_len, len(passage['passage_tokens']))
+        passage_start = inst_lod[p_idx] - inst_lod[0]
+        passage_end = inst_lod[p_idx + 1] - inst_lod[0]
+        passage_len = passage_end - passage_start
+        passage_len = min(passage_len, len(passage['passage_tokens']))
         answer_span, score = find_best_answer_for_passage(
-            start_prob[p_idx * padded_p_len:(p_idx + 1) * padded_p_len],
-            end_prob[p_idx * padded_p_len:(p_idx + 1) * padded_p_len],
-            passage_len, args)
+            start_prob[passage_start:passage_end],
+            end_prob[passage_start:passage_end], passage_len)
         if score > best_score:
             best_score = score
             best_p_idx = p_idx
@@ -148,11 +198,11 @@
     else:
         best_answer = ''.join(sample['passages'][best_p_idx]['passage_tokens'][
             best_span[0]:best_span[1] + 1])
-    return best_answer
+    return best_answer, best_span
 
 
-def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
-               vocab, brc_data, logger, args):
+def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
+               place, dev_count, vocab, brc_data, logger, args):
     """
+    Do evaluation on the dev set: compute the average loss and, when
+    reference answers are available, the BLEU and Rouge scores of the
+    predicted answers.
     """
@@ -165,6 +215,8 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
     # Use test set as validation each pass
     total_loss = 0.0
     count = 0
+    n_batch_cnt = 0
+    n_batch_loss = 0.0
     pred_answers, ref_answers = [], []
     val_feed_list = [
         inference_program.global_block().var(var_name)
@@ -172,55 +224,80 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
         for var_name in feed_order
     ]
     val_feeder = fluid.DataFeeder(val_feed_list, place)
     pad_id = vocab.get_id(vocab.pad_token)
-    dev_batches = brc_data.gen_mini_batches(
-        'dev', args.batch_size, pad_id, shuffle=False)
+    dev_reader = lambda: brc_data.gen_mini_batches(
+        'dev', args.batch_size, pad_id, shuffle=False)
+    dev_reader = read_multiple(dev_reader, dev_count)
 
-    for batch_id, batch in enumerate(dev_batches, 1):
-        feed_data = prepare_batch_input(batch, args)
+    for batch_id, batch_list in enumerate(dev_reader(), 1):
+        feed_data = batch_reader(batch_list, args)
         val_fetch_outs = parallel_executor.run(
-            feed=val_feeder.feed(feed_data),
-            fetch_list=[avg_cost.name, s_probs.name, e_probs.name],
+            feed=list(val_feeder.feed_parallel(feed_data, dev_count)),
+            fetch_list=[avg_cost.name, s_probs.name, e_probs.name, match.name],
             return_numpy=False)
-
-        total_loss += np.array(val_fetch_outs[0])[0]
-
-        start_probs = LodTensor_Array(val_fetch_outs[1])
-        end_probs = LodTensor_Array(val_fetch_outs[2])
-        count += len(batch['raw_data'])
-
-        padded_p_len = len(batch['passage_token_ids'][0])
-        for sample, start_prob, end_prob in zip(batch['raw_data'], start_probs,
-                                                end_probs):
-
-            best_answer = find_best_answer(sample, start_prob, end_prob,
-                                           padded_p_len, args)
-            pred_answers.append({
-                'question_id': sample['question_id'],
-                'question_type': sample['question_type'],
-                'answers': [best_answer],
-                'entity_answers': [[]],
-                'yesno_answers': []
-            })
-            if 'answers' in sample:
-                ref_answers.append({
-                    'question_id': sample['question_id'],
-                    'question_type': sample['question_type'],
-                    'answers': sample['answers'],
-                    'entity_answers': [[]],
-                    'yesno_answers': []
-                })
+        total_loss += np.array(val_fetch_outs[0]).sum()
+        start_probs_m = LodTensor_Array(val_fetch_outs[1])
+        end_probs_m = LodTensor_Array(val_fetch_outs[2])
+        match_lod = val_fetch_outs[3].lod()
+        count += len(np.array(val_fetch_outs[0]))
+
+        n_batch_cnt += len(np.array(val_fetch_outs[0]))
+        n_batch_loss += np.array(val_fetch_outs[0]).sum()
+        log_every_n_batch = args.log_interval
+        if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
+            logger.info('Average dev loss from batch {} to {} is {}'.format(
+                batch_id - log_every_n_batch + 1, batch_id, "%.10f" % (
+                    n_batch_loss / n_batch_cnt)))
+            n_batch_loss = 0.0
+            n_batch_cnt = 0
+
+        for idx, batch in enumerate(batch_list):
+            # one batch per device
+            batch_size = len(batch['raw_data'])
+            batch_range = match_lod[0][idx * batch_size:(idx + 1) * batch_size
+                                       + 1]
+            batch_lod = [[batch_range[x], batch_range[x + 1]]
+                         for x in range(len(batch_range[:-1]))]
+            start_prob_batch = start_probs_m[idx * batch_size:(idx + 1) *
+                                             batch_size]
+            end_prob_batch = end_probs_m[idx * batch_size:(idx + 1) *
+                                         batch_size]
+            for sample, start_prob_inst, end_prob_inst, inst_range in zip(
+                    batch['raw_data'], start_prob_batch, end_prob_batch,
+                    batch_lod):
+                # one instance
+                inst_lod = match_lod[1][inst_range[0]:inst_range[1] + 1]
+                best_answer, best_span = find_best_answer_for_inst(
+                    sample, start_prob_inst, end_prob_inst, inst_lod)
+                pred = {
+                    'question_id': sample['question_id'],
+                    'question_type': sample['question_type'],
+                    'answers': [best_answer],
+                    'entity_answers': [[]],
+                    'yesno_answers': [best_span]
+                }
+                pred_answers.append(pred)
+                if 'answers' in sample:
+                    ref = {
+                        'question_id': sample['question_id'],
+                        'question_type': sample['question_type'],
+                        'answers': sample['answers'],
+                        'entity_answers': [[]],
+                        'yesno_answers': []
+                    }
+                    ref_answers.append(ref)
+
-    if args.result_dir is not None and args.result_name is not None:
-        if not os.path.exists(args.result_dir):
-            os.makedirs(args.result_dir)
-        result_file = os.path.join(args.result_dir, args.result_name + '.json')
+    result_dir = args.result_dir
+    result_prefix = args.result_name
+    if result_dir is not None and result_prefix is not None:
+        if not os.path.exists(result_dir):
+            os.makedirs(result_dir)
+        result_file = os.path.join(result_dir, result_prefix + '.json')
         with open(result_file, 'w') as fout:
             for pred_answer in pred_answers:
                 fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
-        logger.info('Saving {} results to {}'.format(args.result_name,
+        logger.info('Saving {} results to {}'.format(result_prefix,
                                                      result_file))
 
     ave_loss = 1.0 * total_loss / count
-
     # compute the bleu and rouge scores if reference answers is provided
     if len(ref_answers) > 0:
         pred_dict, ref_dict = {}, {}
@@ -250,6 +327,13 @@ def train(logger, args):
     brc_data.convert_to_ids(vocab)
 
     logger.info('Initialize the model...')
+    if not args.use_gpu:
+        place = fluid.CPUPlace()
+        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+    else:
+        place = fluid.CUDAPlace(0)
+        dev_count = fluid.core.get_cuda_device_count()
+
     # build model
     main_program = fluid.Program()
     startup_prog = fluid.Program()
@@ -257,7 +341,7 @@ def train(logger, args):
         startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
-            avg_cost, s_probs, e_probs, feed_order = rc_model.rc_model(
+            avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
                 args.hidden_size, vocab, args)
             # clone from default main program and use it as the validation program
             inference_program = main_program.clone(for_test=True)
@@ -314,20 +398,21 @@ def train(logger, args):
             for pass_id in range(1, args.pass_num + 1):
                 pass_start_time = time.time()
                 pad_id = vocab.get_id(vocab.pad_token)
-                train_batches = brc_data.gen_mini_batches(
-                    'train', args.batch_size, pad_id, shuffle=True)
+                train_reader = lambda: brc_data.gen_mini_batches(
+                    'train', args.batch_size, pad_id, shuffle=False)
+                train_reader = read_multiple(train_reader, dev_count)
 
                 log_every_n_batch, n_batch_loss = args.log_interval, 0
                 total_num, total_loss = 0, 0
-                for batch_id, batch in enumerate(train_batches, 1):
-                    input_data_dict = prepare_batch_input(batch, args)
+                for batch_id, batch_list in enumerate(train_reader(), 1):
+                    feed_data = batch_reader(batch_list, args)
                     fetch_outs = parallel_executor.run(
-                        feed=feeder.feed(input_data_dict),
+                        feed=list(feeder.feed_parallel(feed_data, dev_count)),
                         fetch_list=[avg_cost.name],
                         return_numpy=False)
-                    cost_train = np.array(fetch_outs[0])[0]
-                    total_num += len(batch['raw_data'])
+                    cost_train = np.array(fetch_outs[0]).mean()
+                    total_num += args.batch_size * dev_count
                     n_batch_loss += cost_train
-                    total_loss += cost_train * len(batch['raw_data'])
+                    total_loss += cost_train * args.batch_size * dev_count
+
                     if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
                         print_para(main_program, parallel_executor, logger, args)
@@ -337,19 +422,23 @@ def train(logger, args):
                             "%.10f" % (n_batch_loss / log_every_n_batch)))
                         n_batch_loss = 0
                     if args.dev_interval > 0 and batch_id % args.dev_interval == 0:
-                        eval_loss, bleu_rouge = validation(
-                            inference_program, avg_cost, s_probs, e_probs,
-                            feed_order, place, vocab, brc_data, logger, args)
-                        logger.info('Dev eval loss {}'.format(eval_loss))
-                        logger.info('Dev eval result: {}'.format(bleu_rouge))
+                        if brc_data.dev_set is not None:
+                            eval_loss, bleu_rouge = validation(
+                                inference_program, avg_cost, s_probs, e_probs,
+                                match, feed_order, place, dev_count, vocab,
+                                brc_data, logger, args)
+                            logger.info('Dev eval loss {}'.format(eval_loss))
+                            logger.info('Dev eval result: {}'.format(
+                                bleu_rouge))
                 pass_end_time = time.time()
                 logger.info('Evaluating the model after epoch {}'.format(
                     pass_id))
                 if brc_data.dev_set is not None:
                     eval_loss, bleu_rouge = validation(
-                        inference_program, avg_cost, s_probs, e_probs,
-                        feed_order, place, vocab, brc_data, logger, args)
+                        inference_program, avg_cost, s_probs, e_probs, match,
+                        feed_order, place, dev_count, vocab, brc_data, logger,
+                        args)
                     logger.info('Dev eval loss {}'.format(eval_loss))
                     logger.info('Dev eval result: {}'.format(bleu_rouge))
                 else:
@@ -389,10 +478,17 @@ def evaluate(logger, args):
         startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
-            avg_cost, s_probs, e_probs, feed_order = rc_model.rc_model(
+            avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
                 args.hidden_size, vocab, args)
             # initialize parameters
-            place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
+            if not args.use_gpu:
+                place = fluid.CPUPlace()
+                dev_count = int(
+                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+            else:
+                place = fluid.CUDAPlace(0)
+                dev_count = fluid.core.get_cuda_device_count()
+
             exe = Executor(place)
             if args.load_dir:
                 logger.info('load from {}'.format(args.load_dir))
@@ -402,17 +498,10 @@ def evaluate(logger, args):
                 logger.error('No model file to load ...')
                 return
 
-            # prepare data
-            feed_list = [
-                main_program.global_block().var(var_name)
-                for var_name in feed_order
-            ]
-            feeder = fluid.DataFeeder(feed_list, place)
-
             inference_program = main_program.clone(for_test=True)
             eval_loss, bleu_rouge = validation(
-                inference_program, avg_cost, s_probs, e_probs, feed_order,
-                place, vocab, brc_data, logger, args)
+                inference_program, avg_cost, s_probs, e_probs, match,
+                feed_order, place, dev_count, vocab, brc_data, logger, args)
             logger.info('Dev eval loss {}'.format(eval_loss))
             logger.info('Dev eval result: {}'.format(bleu_rouge))
             logger.info('Predicted answers are saved to {}'.format(
@@ -438,10 +527,17 @@ def predict(logger, args):
         startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
-            avg_cost, s_probs, e_probs, feed_order = rc_model.rc_model(
+            avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
                 args.hidden_size, vocab, args)
            # initialize parameters
-            place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
+            if not args.use_gpu:
+                place = fluid.CPUPlace()
+                dev_count = int(
+                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+            else:
+                place = fluid.CUDAPlace(0)
+                dev_count = fluid.core.get_cuda_device_count()
+
             exe = Executor(place)
             if args.load_dir:
                 logger.info('load from {}'.format(args.load_dir))
@@ -451,17 +547,10 @@ def predict(logger, args):
                 logger.error('No model file to load ...')
                 return
 
-            # prepare data
-            feed_list = [
-                main_program.global_block().var(var_name)
-                for var_name in feed_order
-            ]
-            feeder = fluid.DataFeeder(feed_list, place)
-
             inference_program = main_program.clone(for_test=True)
             eval_loss, bleu_rouge = validation(
-                inference_program, avg_cost, s_probs, e_probs, feed_order,
-                place, vocab, brc_data, logger, args)
+                inference_program, avg_cost, s_probs, e_probs, match,
+                feed_order, place, dev_count, vocab, brc_data, logger, args)
 
 
 def prepare(logger, args):
diff --git a/fluid/machine_reading_comprehension/run.sh b/fluid/machine_reading_comprehension/run.sh
index 4bcab2be..58bd1a21 100644
--- a/fluid/machine_reading_comprehension/run.sh
+++ b/fluid/machine_reading_comprehension/run.sh
@@ -1,4 +1,4 @@
-export CUDA_VISIBLE_DEVICES=1
+export CUDA_VISIBLE_DEVICES=0
 python run.py \
   --trainset 'data/preprocessed/trainset/search.train.json' \
              'data/preprocessed/trainset/zhidao.train.json' \
-- 
GitLab
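
Notes on the two mechanisms this patch relies on.

The model change replaces the fixed `args.doc_num` unrolled loop with a
`layers.DynamicRNN` that steps over the outer LoD level of its input, so each
step sees exactly one passage per instance. A minimal sketch of that pattern,
assuming the Paddle 1.x Fluid API this repo targets (`toy_passage_encoder` and
the `fc` stand-in are illustrative, not part of the patch):

    import paddle.fluid.layers as layers

    def toy_passage_encoder(vocab_size, emb_dim):
        # 'p_ids' carries two LoD levels: instance -> passage -> token
        p_ids = layers.data(
            name='p_ids', shape=[1], dtype='int64', lod_level=2)
        p_embs = layers.embedding(input=p_ids, size=[vocab_size, emb_dim])
        drnn = layers.DynamicRNN()
        with drnn.block():
            # each step consumes one passage per instance; the outer LoD
            # level is peeled off, so p_emb is a plain lod_level=1 sequence
            p_emb = drnn.step_input(p_embs)
            # stand-in for bi_lstm_encoder in rc_model.py
            enc = layers.fc(input=p_emb, size=emb_dim)
            drnn.output(enc)
        # per-step outputs, concatenated back into one sequence
        return drnn()

In rc_model.py the real block additionally feeds q_embs as a second
step_input, runs attn_flow and fusion inside the block, and the lod_reset
calls afterwards restore the per-instance LoD that the pointer-network
decoder expects.

The multi-device feeding added to run.py is plain Python: read_multiple
groups `count` consecutive mini-batches, one per device, clipping a short
remainder, and each group is then converted per device by batch_reader and
fed through DataFeeder.feed_parallel. A self-contained illustration with a
toy reader (the names here are made up for the example):

    def toy_reader():
        for i in range(5):
            yield [i]  # one mini-batch

    multi_reader = read_multiple(toy_reader, count=2)  # as defined in run.py
    print(list(multi_reader()))
    # [[[0], [1]], [[2], [3]]] -- the leftover fifth batch is clipped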