diff --git a/globally_normalized_reader/.gitignore b/globally_normalized_reader/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..5707959556bb6b21e88a3a77d11f9222bba01485
--- /dev/null
+++ b/globally_normalized_reader/.gitignore
@@ -0,0 +1,2 @@
+*.txt
+*.pyc
diff --git a/globally_normalized_reader/README.md b/globally_normalized_reader/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..583f53f7519f3a781ea481f87077106a2f329a4a
--- /dev/null
+++ b/globally_normalized_reader/README.md
@@ -0,0 +1,52 @@
+# Globally Normalized Reader
+
+This model implements the work in the following paper:
+
+Jonathan Raiman and John Miller. Globally Normalized Reader. Empirical Methods in Natural Language Processing (EMNLP), 2017.
+
+If you use the dataset/code in your research, please cite the above paper:
+
+```text
+@inproceedings{raiman2017gnr,
+    author={Raiman, Jonathan and Miller, John},
+    booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
+    title={Globally Normalized Reader},
+    year={2017},
+}
+```
+
+You can also visit https://github.com/baidu-research/GloballyNormalizedReader for more information.
+
+
+# Installation
+
+1. Please use the [docker image](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html) to install the latest PaddlePaddle, by running:
+    ```bash
+    docker pull paddledev/paddle
+    ```
+2. Download all necessary data by running:
+    ```bash
+    cd data && ./download.sh
+    ```
+3. **(TODO) add the preprocess and featurizer scripts.**
+
+# Training a Model
+
+- Configure the model by modifying `config.py` if needed, then run:
+
+    ```bash
+    python train.py 2>&1 | tee train.log
+    ```
+
+# Inferring with a Trained Model
+
+- Infer with a trained model by running:
+    ```bash
+    python infer.py \
+        --model_path models/pass_00000.tar.gz \
+        --data_dir data/featurized/ \
+        --batch_size 2 \
+        --use_gpu 0 \
+        --trainer_count 1 \
+        2>&1 | tee infer.log
+    ```
diff --git a/globally_normalized_reader/basic_modules.py b/globally_normalized_reader/basic_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..91aefc2f625e12ab8da9232332ad4810abd4e578
--- /dev/null
+++ b/globally_normalized_reader/basic_modules.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python
+#coding=utf-8
+
+import collections
+
+import paddle.v2 as paddle
+from paddle.v2.layer import parse_network
+
+__all__ = [
+    "stacked_bidirectional_lstm",
+    "stacked_bidirectional_lstm_by_nested_seq",
+    "lstm_by_nested_sequence",
+]
+
+
+def stacked_bidirectional_lstm(inputs,
+                               hidden_dim,
+                               depth,
+                               drop_rate=0.,
+                               prefix=""):
+    """The stacked bi-directional LSTM.
+
+    Recurrent layers have two different implementations in PaddlePaddle:
+    1. a recurrent layer implemented by recurrent_group: any intermediate
+       state a recurrent unit computes during one time step, such as the
+       hidden states, input-to-hidden mappings, memory cells and so on, is
+       accessible;
+    2. a recurrent layer as a whole: only the outputs of the recurrent
+       layer are accessible.
+
+    The second type (the recurrent layer as a whole) is more
+    computationally efficient, because recurrent_group is made up of many
+    basic layers (additions, element-wise multiplications, matrix
+    multiplications and so on).
+
+    This function uses the second type to implement the stacked
+    bi-directional LSTM.
+
+    Arguments:
+        - inputs: The input layer to the bi-directional LSTM.
+        - hidden_dim: The dimension of the hidden state of the LSTM.
+        - depth: Depth of the stacked bi-directional LSTM.
+        - drop_rate: The rate for dropping the LSTM output states.
+        - prefix: A string prepended to the name of each layer created in
+                  this function. Each layer in a network should have a
+                  unique name, so the prefix makes it possible to call this
+                  function multiple times.
+    """
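+
+    # A sketch of the wiring below: the first LSTM in each stack projects
+    # from `inputs`, and every deeper LSTM projects from the previous
+    # LSTM's output:
+    #
+    #     inputs -> proj -> lstm_0 -> proj -> lstm_1 -> ... -> lstm_{depth-1}
+    #
+    # Only the last LSTM of each direction is kept in `lstm_last`.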
+
+    if not isinstance(inputs, collections.Sequence):
+        inputs = [inputs]
+
+    lstm_last = []
+    for dirt in ["fwd", "bwd"]:
+        for i in range(depth):
+            input_proj = paddle.layer.mixed(
+                name="%s_in_proj_%0d_%s__" % (prefix, i, dirt),
+                size=hidden_dim * 4,
+                bias_attr=paddle.attr.Param(initial_std=0.),
+                input=[paddle.layer.full_matrix_projection(lstm)] if i else [
+                    paddle.layer.full_matrix_projection(in_layer)
+                    for in_layer in inputs
+                ])
+            lstm = paddle.layer.lstmemory(
+                input=input_proj,
+                bias_attr=paddle.attr.Param(initial_std=0.),
+                param_attr=paddle.attr.Param(initial_std=5e-4),
+                reverse=(dirt == "bwd"))
+        lstm_last.append(lstm)
+
+    final_states = paddle.layer.concat(input=[
+        paddle.layer.last_seq(input=lstm_last[0]),
+        paddle.layer.first_seq(input=lstm_last[1]),
+    ])
+
+    lstm_outs = paddle.layer.concat(
+        input=lstm_last,
+        layer_attr=paddle.attr.ExtraLayerAttribute(drop_rate=drop_rate))
+    return final_states, lstm_outs
+
+
+def lstm_by_nested_sequence(input_layer, hidden_dim, name="", reverse=False):
+    """An LSTM implemented with a nested recurrent_group.
+
+    A paragraph is a naturally nested sequence:
+    1. each paragraph is a sequence of sentences;
+    2. each sentence is a sequence of words.
+
+    This function uses a nested recurrent_group to implement an LSTM over
+    such a nested sequence:
+    1. The outer group iterates over the sentences in a paragraph.
+    2. The inner group iterates over the words in a sentence.
+    3. An LSTM encodes each sentence, and its final output initializes the
+       memory of the LSTM that encodes the next sentence.
+    4. Parameters are shared among these sentence-encoding LSTMs.
+    5. Consequently, this function is equivalent to concatenating all the
+       sentences in a paragraph into one (long) sentence and encoding it
+       with a single LSTM.
+
+    Arguments:
+        - input_layer: The input layer to the LSTM.
+        - hidden_dim: The dimension of the hidden state of the LSTM.
+        - name: The name of the LSTM.
+        - reverse: A boolean indicating whether to process the input
+                   sequence in reverse order.
+    """
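+
+    # The inner LSTM's memory is bootstrapped (boot_layer) from
+    # `outer_memory`, which holds the last output of the previous
+    # sentence's LSTM, so state flows across sentence boundaries exactly
+    # as if the paragraph were one flat sequence.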
+ """ + + def lstm_outer_step(lstm_group_input, hidden_dim, reverse, name=''): + outer_memory = paddle.layer.memory( + name="__inner_%s_last__" % name, size=hidden_dim) + + def lstm_inner_step(input_layer, hidden_dim, reverse, name): + inner_memory = paddle.layer.memory( + name="__inner_state_%s__" % name, + size=hidden_dim, + boot_layer=outer_memory) + input_proj = paddle.layer.fc( + size=hidden_dim * 4, bias_attr=False, input=input_layer) + return paddle.networks.lstmemory_unit( + input=input_proj, + name="__inner_state_%s__" % name, + out_memory=inner_memory, + size=hidden_dim, + act=paddle.activation.Tanh(), + gate_act=paddle.activation.Sigmoid(), + state_act=paddle.activation.Tanh()) + + inner_out = paddle.layer.recurrent_group( + name="__inner_%s__" % name, + step=lstm_inner_step, + reverse=reverse, + input=[lstm_group_input, hidden_dim, reverse, name]) + + if reverse: + inner_last_output = paddle.layer.first_seq( + input=inner_out, + name="__inner_%s_last__" % name, + agg_level=paddle.layer.AggregateLevel.TO_NO_SEQUENCE) + else: + inner_last_output = paddle.layer.last_seq( + input=inner_out, + name="__inner_%s_last__" % name, + agg_level=paddle.layer.AggregateLevel.TO_NO_SEQUENCE) + return inner_out + + return paddle.layer.recurrent_group( + input=[ + paddle.layer.SubsequenceInput(input_layer), hidden_dim, reverse, + name + ], + step=lstm_outer_step, + name="__outter_%s__" % name, + reverse=reverse) + + +def stacked_bidirectional_lstm_by_nested_seq(input_layer, + depth, + hidden_dim, + prefix=""): + """ The stacked bi-directional LSTM to process a nested sequence. + + The modules defined in this function is exactly equivalent to + that defined in stacked_bidirectional_lstm, the only difference is the + bi-directional LSTM defined in this function implemented by recurrent_group + in PaddlePaddle, and receive a nested sequence as its input. + + Arguments: + - inputs: The input layer to the bi-directional LSTM. + - hidden_dim: The dimension of the hidden state of the LSTM. + - depth: Depth of the stacked bi-directional LSTM. + - prefix: A string which will be appended to name of each layer + created in this function. Each layer in a network should + has a unique name. The prefix makes this fucntion can be + called multiple times. + """ + + lstm_final_outs = [] + for dirt in ["fwd", "bwd"]: + for i in range(depth): + lstm_out = lstm_by_nested_sequence( + input_layer=(lstm_out if i else input_layer), + hidden_dim=hidden_dim, + name="__%s_%s_%02d__" % (prefix, dirt, i), + reverse=(dirt == "bwd")) + lstm_final_outs.append(lstm_out) + return paddle.layer.concat(input=lstm_final_outs) diff --git a/globally_normalized_reader/beam_decoding.py b/globally_normalized_reader/beam_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..d072ca17d1a3830c6b56d5fa9c5f2a256952a819 --- /dev/null +++ b/globally_normalized_reader/beam_decoding.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python +#coding=utf-8 + +import numpy as np + +__all__ = ["BeamDecoding"] + + +class BeamDecoding(object): + """ + Decode outputs of the PaddlePaddle layers into readable answers. + """ + + def __init__(self, documents, sentence_scores, selected_sentences, + start_scores, selected_starts, end_scores, selected_ends): + """ The constructor. + + Arguments: + - documents: The one-hot input of the document words. + - sentence_scores: The score for each sentece in a document. + - selected_sentences: The top k seleceted sentence. This is the + output of the paddle.layer.kmax_seq_score + layer in the model. 
+
+        self.documents = documents
+
+        self.sentence_scores = sentence_scores
+        self.selected_sentences = selected_sentences
+
+        self.start_scores = start_scores
+        self.selected_starts = selected_starts
+
+        self.end_scores = end_scores
+        self.selected_ends = selected_ends
+
+        # Sequence start position information for the three-step search:
+        # beam1 searches the sentence index,
+        self.beam1_seq_start_positions = []
+        # beam2 searches the start answer span,
+        self.beam2_seq_start_positions = []
+        # and beam3 searches the end answer span.
+        self.beam3_seq_start_positions = []
+
+        self.ans_per_sample_in_a_batch = [0]
+        self.all_searched_ans = []
+
+        self.final_ans = [[] for i in range(len(documents))]
+
+    def _build_beam1_seq_info(self):
+        """
+        The internal function to calculate the offset of each test sequence
+        in a batch for the first beam, which searches the answer sentence.
+        """
+
+        self.beam1_seq_start_positions.append([0])
+        for idx, one_doc in enumerate(self.documents):
+            for sentence in one_doc:
+                self.beam1_seq_start_positions[-1].append(
+                    self.beam1_seq_start_positions[-1][-1] + len(sentence))
+
+            if len(self.beam1_seq_start_positions) != len(self.documents):
+                self.beam1_seq_start_positions.append(
+                    [self.beam1_seq_start_positions[-1][-1]])
+
+    def _build_beam2_seq_info(self):
+        """
+        The internal function to calculate the offset of each test sequence
+        in a batch for the second beam, which searches the start spans.
+        """
+
+        seq_num, beam_size = self.selected_sentences.shape
+        self.beam2_seq_start_positions.append([0])
+        for i in range(seq_num):
+            for j in range(beam_size):
+                selected_id = int(self.selected_sentences[i][j])
+                if selected_id == -1: break
+                seq_len = self.beam1_seq_start_positions[i][
+                    selected_id + 1] - self.beam1_seq_start_positions[i][
+                        selected_id]
+                self.beam2_seq_start_positions[-1].append(
+                    self.beam2_seq_start_positions[-1][-1] + seq_len)
+
+            if len(self.beam2_seq_start_positions) != seq_num:
+                self.beam2_seq_start_positions.append(
+                    [self.beam2_seq_start_positions[-1][-1]])
+
+    def _build_beam3_seq_info(self):
+        """
+        The internal function to calculate the offset of each test sequence
+        in a batch for the third beam, which searches the end spans.
+        """
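+
+        # Each (sentence, start) pair kept by beam2 yields a sub-sequence
+        # running from the start position to the end of its sentence, so
+        # its length is seq_len - start_id.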
+ """ + + seq_num_in_a_batch = len(self.documents) + + seq_id = 0 + sub_seq_id = 0 + sub_seq_count = len(self.beam2_seq_start_positions[seq_id]) - 1 + + self.beam3_seq_start_positions.append([0]) + sub_seq_num, beam_size = self.selected_starts.shape + for i in range(sub_seq_num): + seq_len = self.beam2_seq_start_positions[ + seq_id][sub_seq_id + + 1] - self.beam2_seq_start_positions[seq_id][sub_seq_id] + for j in range(beam_size): + start_id = int(self.selected_starts[i][j]) + if start_id == -1: break + + self.beam3_seq_start_positions[-1].append( + self.beam3_seq_start_positions[-1][-1] + seq_len - start_id) + + sub_seq_id += 1 + if sub_seq_id == sub_seq_count: + if len(self.beam3_seq_start_positions) != seq_num_in_a_batch: + self.beam3_seq_start_positions.append( + [self.beam3_seq_start_positions[-1][-1]]) + sub_seq_id = 0 + seq_id += 1 + sub_seq_count = len( + self.beam2_seq_start_positions[seq_id]) - 1 + assert ( + self.beam3_seq_start_positions[-1][-1] == self.end_scores.shape[0]) + + def _build_seq_info_for_each_beam(self): + """ + The internal function to calculate the offset of each test sequence + in a batch for beams expanded at all the three search steps. + """ + + self._build_beam1_seq_info() + self._build_beam2_seq_info() + self._build_beam3_seq_info() + + def _cal_ans_per_sample_in_a_batch(self): + """ + The internal function to calculate there are how many candidate answers + for each of the test sequemce in a batch. + """ + + start_row = 0 + for seq in self.beam3_seq_start_positions: + end_row = start_row + len(seq) - 1 + ans_count = np.sum(self.selected_ends[start_row:end_row, :] != -1.) + + self.ans_per_sample_in_a_batch.append( + self.ans_per_sample_in_a_batch[-1] + ans_count) + start_row = end_row + + def _get_valid_seleceted_ids(slef, mat): + """ + The internal function to post-process the output matrix of + paddle.layer.kmax_seq_score layer. This function takes off the special + dilimeter -1 away and flattens the original two-dimensional output + matrix into a python list. + """ + + flattened = [] + height, width = mat.shape + for i in range(height): + for j in range(width): + if mat[i][j] == -1.: break + flattened.append([int(mat[i][j]), [i, j]]) + return flattened + + def decoding(self): + """ + The internal function to decode forward results of the GNR network into + readable answers. 
+ """ + + self._build_seq_info_for_each_beam() + self._cal_ans_per_sample_in_a_batch() + + seq_id = 0 + sub_seq_id = 0 + sub_seq_count = len(self.beam3_seq_start_positions[seq_id]) - 1 + + sub_seq_num, beam_size = self.selected_ends.shape + for i in xrange(sub_seq_num): + seq_offset_in_batch = self.beam3_seq_start_positions[seq_id][ + sub_seq_id] + for j in xrange(beam_size): + end_pos = int(self.selected_ends[i][j]) + if end_pos == -1: break + + self.all_searched_ans.append({ + "score": + self.end_scores[seq_offset_in_batch + end_pos], + "sentence_pos": + -1, + "start_span_pos": + -1, + "end_span_pos": + end_pos, + "parent_ids_in_prev_beam": + i + }) + + sub_seq_id += 1 + if sub_seq_id == sub_seq_count: + seq_id += 1 + if seq_id == len(self.beam3_seq_start_positions): break + + sub_seq_id = 0 + sub_seq_count = len(self.beam3_seq_start_positions[seq_id]) - 1 + + assert len(self.all_searched_ans) == self.ans_per_sample_in_a_batch[-1] + + seq_id = 0 + sub_seq_id = 0 + sub_seq_count = len(self.beam2_seq_start_positions[seq_id]) - 1 + last_row_id = None + + starts = self._get_valid_seleceted_ids(self.selected_starts) + for i, ans in enumerate(self.all_searched_ans): + ans["start_span_pos"] = starts[ans["parent_ids_in_prev_beam"]][0] + + seq_offset_in_batch = ( + self.beam2_seq_start_positions[seq_id][sub_seq_id]) + ans["score"] += self.start_scores[( + seq_offset_in_batch + ans["start_span_pos"])] + ans["parent_ids_in_prev_beam"] = starts[ans[ + "parent_ids_in_prev_beam"]][1][0] + + if last_row_id and last_row_id != ans["parent_ids_in_prev_beam"]: + sub_seq_id += 1 + + if sub_seq_id == sub_seq_count: + seq_id += 1 + if seq_id == len(self.beam2_seq_start_positions): break + sub_seq_count = len(self.beam2_seq_start_positions[seq_id]) - 1 + sub_seq_id = 0 + last_row_id = ans["parent_ids_in_prev_beam"] + + offset_info = [0] + for sen in self.beam1_seq_start_positions[:-1]: + offset_info.append(offset_info[-1] + len(sen) - 1) + sen_ids = self._get_valid_seleceted_ids(self.selected_sentences) + for ans in self.all_searched_ans: + ans["sentence_pos"] = sen_ids[ans["parent_ids_in_prev_beam"]][0] + row_id = ans["parent_ids_in_prev_beam"] / beam_size + offset = offset_info[row_id - 1] if row_id else 0 + ans["score"] += self.sentence_scores[offset + ans["sentence_pos"]] + + for i in range(len(self.ans_per_sample_in_a_batch) - 1): + start_pos = self.ans_per_sample_in_a_batch[i] + end_pos = self.ans_per_sample_in_a_batch[i + 1] + + for ans in sorted( + self.all_searched_ans[start_pos:end_pos], + key=lambda x: x["score"], + reverse=True): + self.final_ans[i].append({ + "score": + ans["score"], + "label": [ + ans["sentence_pos"], ans["start_span_pos"], + ans["end_span_pos"] + ] + }) + + return self.final_ans diff --git a/globally_normalized_reader/config.py b/globally_normalized_reader/config.py new file mode 100644 index 0000000000000000000000000000000000000000..d89fd0e48535cb6c99eb9b679b52024dac00c5dd --- /dev/null +++ b/globally_normalized_reader/config.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +#coding=utf-8 + +__all__ = ["ModelConfig", "TrainerConfig"] + + +class ModelConfig(object): + vocab_size = 104808 + embedding_dim = 300 + embedding_droprate = 0.3 + + lstm_depth = 3 + lstm_hidden_dim = 300 + lstm_hidden_droprate = 0.3 + + passage_indep_embedding_dim = 300 + passage_aligned_embedding_dim = 300 + + beam_size = 32 + + dict_path = "data/featurized/vocab.txt" + pretrained_emb_path = "data/featurized/embeddings.npy" + + +class TrainerConfig(object): + learning_rate = 1e-3 + l2_decay_rate = 5e-4 + 
+    gradient_clipping_threshold = 20
+
+    data_dir = "data/featurized"
+    save_dir = "models"
+
+    use_gpu = False
+    trainer_count = 1
+    train_batch_size = trainer_count * 8
+
+    epochs = 20
+
+    # For debug printing; if set to 0, no information will be printed.
+    show_parameter_status_period = 0
+    checkpoint_period = 100
+    log_period = 1
+
+    # This is used to resume training. It can be set to the path of a
+    # previously trained model.
+    init_model_path = None
diff --git a/globally_normalized_reader/data/download.sh b/globally_normalized_reader/data/download.sh
new file mode 100755
index 0000000000000000000000000000000000000000..4782dd55590272e29d5723076b5141887a438278
--- /dev/null
+++ b/globally_normalized_reader/data/download.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+wget --no-check-certificate https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
+wget --no-check-certificate https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
diff --git a/globally_normalized_reader/index.html b/globally_normalized_reader/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..2b0768ff173cc581ac558edcc64a78a44874b1ed
--- /dev/null
+++ b/globally_normalized_reader/index.html
@@ -0,0 +1,116 @@
[The 116 added lines of HTML markup were lost in extraction and are omitted here.]
diff --git a/globally_normalized_reader/infer.py b/globally_normalized_reader/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..351b2659fb974123feed1e16636171399ae89ea7
--- /dev/null
+++ b/globally_normalized_reader/infer.py
@@ -0,0 +1,214 @@
+#!/usr/bin/env python
+#coding=utf-8
+
+import os
+import sys
+import argparse
+import gzip
+import logging
+import numpy as np
+
+import paddle.v2 as paddle
+from paddle.v2.layer import parse_network
+import reader
+
+from model import GNR
+from train import choose_samples
+from config import ModelConfig
+from beam_decoding import BeamDecoding
+
+logger = logging.getLogger("paddle")
+logger.setLevel(logging.INFO)
+
+
+def parse_cmd():
+    """Build the command line argument parser for the inferring task."""
+    parser = argparse.ArgumentParser(
+        description="Globally Normalized Reader in PaddlePaddle.")
+    parser.add_argument(
+        "--model_path",
+        required=True,
+        type=str,
+        help="Path of the trained model to evaluate.",
+        default="")
+    parser.add_argument(
+        "--data_dir",
+        type=str,
+        required=True,
+        help="Path of the training and testing data.",
+        default="")
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        required=False,
+        help="The batch size for inferring.",
+        default=1)
+    parser.add_argument(
+        "--use_gpu",
+        type=int,
+        required=False,
+        help="Whether to run the inferring on GPU.",
+        default=0)
+    parser.add_argument(
+        "--trainer_count",
+        type=int,
+        required=False,
+        help=("The thread number used in inferring. When set "
+              "use_gpu=True, the trainer_count cannot exceed "
+              "the gpu device number in your computer."),
+        default=1)
+    return parser.parse_args()
+
+
+def load_reverse_dict(dict_file):
+    """Build the dict which maps word indices to word strings.
+
+    The keys are word indices and the values are word strings.
+
+    Arguments:
+        - dict_file: The path of a word dictionary.
+    """
+    word_dict = {}
+    with open(dict_file, "r") as fin:
+        for idx, line in enumerate(fin):
+            word_dict[idx] = line.strip()
+    return word_dict
+
+
+def print_result(test_batch, predicted_ans, ids_2_word, print_top_k=1):
+    """Print the readable predicted answers.
+
+    Format of the output:
+        query:\tthe input query.
+        documents:
+        0\tthe first sentence in the document.
+        1\tthe second sentence in the document.
+        ...
+        gold:\t[i j k] the answer words.
+            (i: the sentence index;
+             j: the start span index;
+             k: the end span index)
+        top answers:
+        score0\t[i j k] the answer with the highest score.
+        score1\t[i j k] the answer with the second highest score.
+            (i, j, k have the same meaning as in gold.)
+        ...
+
+    By default the caller passes print_top_k=10, so the top 10 answers
+    are printed.
+
+    Arguments:
+        - test_batch: A test batch returned by the reader.
+        - predicted_ans: The beam decoding results.
+        - ids_2_word: The dict whose keys are word indices and whose
+                      values are word strings.
+        - print_top_k: How many answers will be printed.
+    """
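+
+    # Sample layout, following reader.data_reader: sample[0] holds the
+    # question ids, sample[1] the list of sentences, sample[2] the
+    # same-as-question features, sample[3] the gold sentence index,
+    # sample[4] the gold start, and sample[5] the gold answer length.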
+ """ + + for i, sample in enumerate(test_batch): + query_words = [ids_2_word[ids] for ids in sample[0]] + print("query:\t%s" % (" ".join(query_words))) + + print("documents:") + for j, sen in enumerate(sample[1]): + sen_words = [ids_2_word[ids] for ids in sen] + start = sample[4] + end = sample[4] + sample[5] + 1 + print("%d\t%s" % (j, " ".join(sen_words))) + print("gold:\t[%d %d %d] %s" % ( + sample[3], sample[4], sample[5], " ".join( + [ids_2_word[ids] for ids in sample[1][sample[3]][start:end]]))) + + print("top answers:") + for k in range(print_top_k): + label = predicted_ans[i][k]["label"] + start = label[1] + end = label[1] + label[2] + 1 + ans_words = [ + ids_2_word[ids] for ids in sample[1][label[0]][start:end] + ] + print("%.4f\t[%d %d %d] %s" % + (predicted_ans[i][k]["score"], label[0], label[1], label[2], + " ".join(ans_words))) + print("\n") + + +def infer_a_batch(inferer, test_batch, ids_2_word, out_layer_count): + """ Call the PaddlePaddle's infer interface to infer by batch. + + Arguments: + - inferer: The PaddlePaddle Inference object. + - test_batch: A test batch returned by reader. + - ids_2_word: The dict whose key is word index and the values are + word strings. + - out_layer_count: The number of output layers in the inferring process. + """ + + outs = inferer.infer(input=test_batch, flatten_result=False, field="value") + decoder = BeamDecoding([sample[1] for sample in test_batch], *outs) + print_result(test_batch, decoder.decoding(), ids_2_word, print_top_k=10) + + +def infer(model_path, + data_dir, + batch_size, + config, + use_gpu=False, + trainer_count=1): + """ The inferring process. + + Arguments: + - model_path: The path of trained model. + - data_dir: The directory path of test data. + - batch_size: The batch_size. + - config: The model configuration. + - use_gpu: Whether to run the inferring on GPU. + - trainer_count: The thread number used in inferring. When set + use_gpu=True, the trainer_count cannot excess + the gpu device number in your computer. + """ + + assert os.path.exists(model_path), "The model does not exist." 
+    paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)
+
+    ids_2_word = load_reverse_dict(config.dict_path)
+
+    outputs = GNR(config, is_infer=True)
+
+    # load the trained model
+    parameters = paddle.parameters.Parameters.from_tar(
+        gzip.open(model_path, "r"))
+    logger.info("loading parameters is done.")
+
+    inferer = paddle.inference.Inference(
+        output_layer=outputs, parameters=parameters)
+
+    _, valid_samples = choose_samples(data_dir)
+    test_reader = reader.data_reader(valid_samples, is_train=False)
+
+    test_batch = []
+    for i, item in enumerate(test_reader()):
+        test_batch.append(item)
+        if len(test_batch) == batch_size:
+            infer_a_batch(inferer, test_batch, ids_2_word, len(outputs))
+            test_batch = []
+
+    if len(test_batch):
+        infer_a_batch(inferer, test_batch, ids_2_word, len(outputs))
+        test_batch = []
+
+
+def main(args):
+    infer(
+        model_path=args.model_path,
+        data_dir=args.data_dir,
+        batch_size=args.batch_size,
+        config=ModelConfig,
+        use_gpu=args.use_gpu,
+        trainer_count=args.trainer_count)
+
+
+if __name__ == "__main__":
+    args = parse_cmd()
+    main(args)
diff --git a/globally_normalized_reader/model.py b/globally_normalized_reader/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..db5cccae720e947fcf75164edbc8ef395a6b8105
--- /dev/null
+++ b/globally_normalized_reader/model.py
@@ -0,0 +1,327 @@
+#!/usr/bin/env python
+#coding=utf-8
+
+import paddle.v2 as paddle
+from paddle.v2.layer import parse_network
+import basic_modules
+from config import ModelConfig
+
+__all__ = ["GNR"]
+
+
+def build_pretrained_embedding(name, data_type, emb_dim, emb_drop=0.):
+    """Create a word embedding layer which loads pre-trained embeddings.
+
+    Arguments:
+        - name: The name of the data layer which accepts one-hot input.
+        - data_type: PaddlePaddle's data type for the data layer.
+        - emb_dim: The dimension of the word embeddings.
+        - emb_drop: The rate for dropping the embedding outputs.
+    """
+
+    return paddle.layer.embedding(
+        input=paddle.layer.data(name=name, type=data_type),
+        size=emb_dim,
+        param_attr=paddle.attr.Param(name="GloveVectors", is_static=True),
+        layer_attr=paddle.attr.ExtraLayerAttribute(drop_rate=emb_drop), )
+
+
+def encode_question(input_embedding,
+                    lstm_hidden_dim,
+                    depth,
+                    passage_indep_embedding_dim,
+                    prefix=""):
+    """Build the question encoding with a bidirectional LSTM.
+
+    Each question word is encoded by running a stack of bidirectional
+    LSTMs over the word embeddings in the question, producing hidden
+    states. The hidden states are used to compute a passage-independent
+    question embedding.
+
+    The final question encoding is constructed by concatenating the final
+    hidden states of the forward and backward LSTMs and the
+    passage-independent embedding.
+
+    Arguments:
+        - input_embedding: The question word embeddings.
+        - lstm_hidden_dim: The dimension of the bi-directional LSTM.
+        - depth: The depth of the stacked bi-directional LSTM.
+        - passage_indep_embedding_dim: The dimension of the
+                                       passage-independent embedding.
+        - prefix: A string prepended to the name of each layer created in
+                  this function. Each layer in a network should have a
+                  unique name, so the prefix makes it possible to call this
+                  function multiple times.
+    """
+    # stacked bi-directional LSTM to process question embeddings.
+    lstm_final, lstm_outs = basic_modules.stacked_bidirectional_lstm(
+        input_embedding, lstm_hidden_dim, depth, 0., prefix)
+
+    # compute the passage-independent embedding.
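+    # This is self-attention over the question: one scalar score per word,
+    # a SequenceSoftmax over the whole question, then a weighted sum of
+    # the projected per-word candidate vectors.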
+    candidates = paddle.layer.fc(
+        input=lstm_outs,
+        bias_attr=False,
+        size=passage_indep_embedding_dim,
+        act=paddle.activation.Linear())
+    weights = paddle.layer.fc(
+        input=lstm_outs,
+        size=1,
+        bias_attr=False,
+        act=paddle.activation.SequenceSoftmax())
+    weighted_candidates = paddle.layer.scaling(
+        input=candidates, weight=weights)
+    passage_indep_embedding = paddle.layer.pooling(
+        input=weighted_candidates, pooling_type=paddle.pooling.Sum())
+
+    return paddle.layer.concat(
+        input=[lstm_final, passage_indep_embedding]), lstm_outs
+
+
+def question_aligned_passage_embedding(question_lstm_outs, document_embeddings,
+                                       passage_aligned_embedding_dim):
+    """Create the question-aligned passage embedding.
+
+    Arguments:
+        - question_lstm_outs: The outputs of the LSTM that processes the
+                              question word embeddings.
+        - document_embeddings: The document embeddings.
+        - passage_aligned_embedding_dim: The dimension of the
+                                         passage-aligned embedding.
+    """
+
+    def outer_sentence_step(document_embeddings, question_lstm_outs,
+                            passage_aligned_embedding_dim):
+        """The step function for PaddlePaddle's recurrent_group.
+
+        In this function, the original input document_embeddings are
+        scattered from a nested sequence into a sequence by
+        recurrent_group in PaddlePaddle. The step function iterates over
+        each sentence in the document.
+
+        Arguments:
+            - document_embeddings: The word embeddings of the document.
+            - question_lstm_outs: The outputs of the LSTM that processes
+                                  the question word embeddings.
+            - passage_aligned_embedding_dim: The dimension of the
+                                             passage-aligned embedding.
+        """
+
+        def inner_word_step(word_embedding, question_lstm_outs,
+                            question_outs_proj,
+                            passage_aligned_embedding_dim):
+            """
+            In this recurrent_group, the sentence embedding has been
+            scattered into word embeddings. The step function iterates
+            over each word in one sentence of the document.
+
+            Arguments:
+                - word_embedding: The word embeddings of the document.
+                - question_lstm_outs: The outputs of the LSTM that
+                                      processes the question word
+                                      embeddings.
+                - question_outs_proj: The projection of
+                                      question_lstm_outs into a new
+                                      hidden space.
+                - passage_aligned_embedding_dim: The dimension of the
+                                                 passage-aligned
+                                                 embedding.
+            """
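+
+            # Soft attention: expand the single document word to the
+            # question length, score each question position against it
+            # (SequenceSoftmax), and sum the projected question states
+            # weighted by those scores.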
+ """ + + doc_word_expand = paddle.layer.expand( + input=word_embedding, + expand_as=question_lstm_outs, + expand_level=paddle.layer.ExpandLevel.FROM_NO_SEQUENCE) + + weights = paddle.layer.fc( + input=[question_lstm_outs, doc_word_expand], + size=1, + bias_attr=False, + act=paddle.activation.SequenceSoftmax()) + weighted_candidates = paddle.layer.scaling( + input=question_outs_proj, weight=weights) + return paddle.layer.pooling( + input=weighted_candidates, pooling_type=paddle.pooling.Sum()) + + question_outs_proj = paddle.layer.fc( + input=question_lstm_outs, + bias_attr=False, + size=passage_aligned_embedding_dim) + return paddle.layer.recurrent_group( + input=[ + paddle.layer.SubsequenceInput(document_embeddings), + paddle.layer.StaticInput(question_lstm_outs), + paddle.layer.StaticInput(question_outs_proj), + passage_aligned_embedding_dim, + ], + step=inner_word_step, + name="iter_over_word") + + return paddle.layer.recurrent_group( + input=[ + paddle.layer.SubsequenceInput(document_embeddings), + paddle.layer.StaticInput(question_lstm_outs), + passage_aligned_embedding_dim + ], + step=outer_sentence_step, + name="iter_over_sen") + + +def encode_documents(input_embedding, same_as_question, question_vector, + question_lstm_outs, passage_indep_embedding_dim, prefix): + """Build the final question-aware document embeddings. + + Each word in the document is represented as concatenation of its word + vector, the question vector, boolean features indicating if a word appers + in the question or is repeated, and a question aligned embedding. + + + Arguments: + - input_embedding: The word embeddings of the document. + - same_as_question: The boolean features indicating if a word appears + in the question or is repeated. + - question_lstm_outs: The final question encoding. + - passage_indep_embedding_dim: The dimension of passage independent + embedding. + - prefix: The prefix which will be appended to name of each layer in + This function. + """ + + question_expanded = paddle.layer.expand( + input=question_vector, + expand_as=input_embedding, + expand_level=paddle.layer.ExpandLevel.FROM_NO_SEQUENCE) + question_aligned_embedding = question_aligned_passage_embedding( + question_lstm_outs, input_embedding, passage_indep_embedding_dim) + return paddle.layer.concat(input=[ + input_embedding, question_expanded, same_as_question, + question_aligned_embedding + ]) + + +def search_answer(doc_lstm_outs, sentence_idx, start_idx, end_idx, config, + is_infer): + """Search the answer from the document. + + The search process for this layer begins with searching a target sequence + from a nested sequence by using paddle.lauer.kmax_seq_score and + paddle.layer.sub_nested_seq_layer. In the first search step, top beam size + sequences with highest scores, indices of these top k sequences in the + original nested sequence, and the ground truth (also called gold) + altogether (a triple) make up of the first beam. + + Then, start and end positions are searched. In these searches, top k + positions with highest scores are selected, and then sequence, starting + from the selected starts till ends of the sequences are taken to search + next by using paddle.layer.seq_slice. + + Finally, the layer paddle.layer.cross_entropy_over_beam takes all the beam + expansions which contain several candidate targets found along the + three-step search. cross_entropy_over_beam calculates cross entropy over + the expanded beams which all the candidates in the beam as the normalized + factor. 
+
+    last_state_of_sentence = paddle.layer.last_seq(
+        input=doc_lstm_outs,
+        agg_level=paddle.layer.AggregateLevel.TO_SEQUENCE)
+    sentence_scores = paddle.layer.fc(
+        input=last_state_of_sentence,
+        size=1,
+        bias_attr=False,
+        act=paddle.activation.Linear())
+    topk_sentence_ids = paddle.layer.kmax_seq_score(
+        input=sentence_scores, beam_size=config.beam_size)
+    topk_sen = paddle.layer.sub_nested_seq(
+        input=doc_lstm_outs, selected_indices=topk_sentence_ids)
+
+    # expand the beam to search the start positions on the selected sentences
+    start_pos_scores = paddle.layer.fc(
+        input=topk_sen,
+        size=1,
+        layer_attr=paddle.attr.ExtraLayerAttribute(
+            error_clipping_threshold=5.0),
+        bias_attr=False,
+        act=paddle.activation.Linear())
+    topk_start_pos_ids = paddle.layer.kmax_seq_score(
+        input=start_pos_scores, beam_size=config.beam_size)
+    topk_start_spans = paddle.layer.seq_slice(
+        input=topk_sen, starts=topk_start_pos_ids, ends=None)
+
+    # expand the beam to search the end positions on the selected start spans
+    _, end_span_embedding = basic_modules.stacked_bidirectional_lstm(
+        topk_start_spans, config.lstm_hidden_dim, config.lstm_depth,
+        config.lstm_hidden_droprate, "__end_span_embeddings__")
+    end_pos_scores = paddle.layer.fc(
+        input=end_span_embedding,
+        size=1,
+        bias_attr=False,
+        act=paddle.activation.Linear())
+    topk_end_pos_ids = paddle.layer.kmax_seq_score(
+        input=end_pos_scores, beam_size=config.beam_size)
+
+    if is_infer:
+        return [
+            sentence_scores, topk_sentence_ids, start_pos_scores,
+            topk_start_pos_ids, end_pos_scores, topk_end_pos_ids
+        ]
+    else:
+        return paddle.layer.cross_entropy_over_beam(input=[
+            paddle.layer.BeamInput(sentence_scores, topk_sentence_ids,
+                                   sentence_idx),
+            paddle.layer.BeamInput(start_pos_scores, topk_start_pos_ids,
+                                   start_idx),
+            paddle.layer.BeamInput(end_pos_scores, topk_end_pos_ids, end_idx)
+        ])
+
+
+def GNR(config, is_infer=False):
+    """Build the globally normalized reader model.
+
+    Arguments:
+        - config: The model configuration.
+        - is_infer: A boolean indicating inferring or training.
+    """
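+
+    # The data layers declared below ("question", "documents",
+    # "same_as_question", "sen_idx", "start_idx", "end_idx") are assumed
+    # to line up, in order, with the tuples yielded by reader.data_reader.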
+ """ + + # encode question words + question_embeddings = build_pretrained_embedding( + "question", + paddle.data_type.integer_value_sequence(config.vocab_size), + config.embedding_dim, config.embedding_droprate) + question_vector, question_lstm_outs = encode_question( + question_embeddings, config.lstm_hidden_dim, config.lstm_depth, + config.passage_indep_embedding_dim, "__ques") + + # encode document words + document_embeddings = build_pretrained_embedding( + "documents", + paddle.data_type.integer_value_sub_sequence(config.vocab_size), + config.embedding_dim, config.embedding_droprate) + same_as_question = paddle.layer.data( + name="same_as_question", + type=paddle.data_type.dense_vector_sub_sequence(1)) + + document_words_ecoding = encode_documents( + document_embeddings, same_as_question, question_vector, + question_lstm_outs, config.passage_indep_embedding_dim, "__doc") + + doc_lstm_outs = basic_modules.stacked_bidirectional_lstm_by_nested_seq( + document_words_ecoding, config.lstm_depth, config.lstm_hidden_dim, + "__doc_lstm") + + # search the answer. + sentence_idx = paddle.layer.data( + name="sen_idx", type=paddle.data_type.integer_value(1)) + start_idx = paddle.layer.data( + name="start_idx", type=paddle.data_type.integer_value(1)) + end_idx = paddle.layer.data( + name="end_idx", type=paddle.data_type.integer_value(1)) + return search_answer(doc_lstm_outs, sentence_idx, start_idx, end_idx, + config, is_infer) + + +if __name__ == "__main__": + print(parse_network(GNR(ModelConfig))) diff --git a/globally_normalized_reader/reader.py b/globally_normalized_reader/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..c6642aa9242ebebdc758a44d6c1d09b5291f73e7 --- /dev/null +++ b/globally_normalized_reader/reader.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +#coding=utf-8 + +import os +import random +import json +import logging + +logger = logging.getLogger("paddle") +logger.setLevel(logging.INFO) + + +def data_reader(data_list, is_train=True): + """ Data reader. + + Arguments: + - data_list: A python list which contains path of training samples. + - is_train: A boolean parameter indicating this function is called + in training or in inferring. 
+ """ + + def reader(): + """shuffle the data list again at the begining of every pass""" + if is_train: + random.shuffle(data_list) + + for train_sample in data_list: + data = json.load(open(train_sample, "r")) + + start_pos = 0 + doc = [] + same_as_question_word = [] + for l in data['sent_lengths']: + doc.append(data['context'][start_pos:start_pos + l]) + same_as_question_word.append([ + [[x]] for x in data['same_as_question_word'] + ][start_pos:start_pos + l]) + start_pos += l + + yield (data['question'], doc, same_as_question_word, + data['ans_sentence'], data['ans_start'], + data['ans_end'] - data['ans_start']) + + return reader diff --git a/globally_normalized_reader/train.py b/globally_normalized_reader/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e377fa1c98b8bfcbd5014d769270146febc96ace --- /dev/null +++ b/globally_normalized_reader/train.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python +#coding=utf-8 + +from __future__ import print_function + +import os +import sys +import logging +import random +import glob +import gzip +import numpy as np + +import reader +import paddle.v2 as paddle +from paddle.v2.layer import parse_network +from model import GNR +from config import ModelConfig, TrainerConfig + +logger = logging.getLogger("paddle") +logger.setLevel(logging.INFO) + + +def load_initial_model(model_path, parameters): + """ Initalize parameters in the network from a trained model. + + This is useful in resuming the training from previously saved models. + + Arguments: + - model_path: The path of a trained model. + - parameters: The parameters in a network which will be initialized + from the specified model. + """ + with gzip.open(model_path, "rb") as f: + parameters.init_from_tar(f) + + +def load_pretrained_parameters(path): + """ Load one pre-trained parameter. + + Arguments: + - path: The path of the pre-trained parameter. + """ + return np.load(path) + + +def save_model(save_path, parameters): + """ Save the trained parameters. + + Arguments: + - save_path: The path to save the trained parameters. + - parameters: The trained model parameters. + """ + with gzip.open(save_path, "w") as f: + parameters.to_tar(f) + + +def show_parameter_init_info(parameters): + """ Print the information of initialization mean and std of parameters. + + Arguments: + - parameters: The parameters created in a model. + """ + for p in parameters: + logger.info("%s : initial_mean %.4f initial_std %.4f" % + (p, parameters.__param_conf__[p].initial_mean, + parameters.__param_conf__[p].initial_std)) + + +def show_parameter_status(parameters): + """ Print some statistical information of parameters in a network. + + This is used for debugging the model. + + Arguments: + - parameters: The parameters created in a model. + """ + for p in parameters: + + value = parameters.get(p) + grad = parameters.get_grad(p) + + avg_abs_value = np.average(np.abs(value)) + avg_abs_grad = np.average(np.abs(grad)) + + logger.info( + ("%s avg_abs_value=%.6f avg_abs_grad=%.6f " + "min_value=%.6f max_value=%.6f min_grad=%.6f max_grad=%.6f") % + (p, avg_abs_value, avg_abs_grad, value.min(), value.max(), + grad.min(), grad.max())) + + +def choose_samples(path): + """Load filenames for train, dev, and augmented samples. + + Arguments: + - path: The path of training data. + """ + if not os.path.exists(os.path.join(path, "train")): + print( + "Non-existent directory as input path: {}".format(path), + file=sys.stderr) + sys.exit(1) + + # Get paths to all samples that we want to load. 
+    train_samples = glob.glob(os.path.join(path, "train", "*"))
+    valid_samples = glob.glob(os.path.join(path, "dev", "*"))
+
+    train_samples.sort()
+    valid_samples.sort()
+
+    random.shuffle(train_samples)
+
+    return train_samples, valid_samples
+
+
+def build_reader(data_dir, batch_size):
+    """Build the data reader for this model.
+
+    Arguments:
+        - data_dir: The path of the training data.
+        - batch_size: The batch size for the training task.
+    """
+    train_samples, valid_samples = choose_samples(data_dir)
+
+    train_reader = paddle.batch(
+        paddle.reader.shuffle(
+            reader.data_reader(train_samples), buf_size=102400),
+        batch_size=batch_size)
+
+    # testing data is not shuffled
+    test_reader = paddle.batch(
+        reader.data_reader(valid_samples, is_train=False),
+        batch_size=batch_size)
+    return train_reader, test_reader, len(train_samples)
+
+
+def build_event_handler(config, parameters, trainer):
+    """Build the event handler for this model.
+
+    Arguments:
+        - config: The training task configuration for this model.
+        - parameters: The parameters in the network.
+        - trainer: The trainer object.
+    """
+
+    # end-batch and end-pass event handler
+    def event_handler(event):
+        """The event handler."""
+
+        if isinstance(event, paddle.event.EndIteration):
+            if event.batch_id and \
+                    not event.batch_id % config.checkpoint_period:
+                save_path = os.path.join(config.save_dir,
+                                         "checkpoint_param.latest.tar.gz")
+                save_model(save_path, parameters)
+
+            if event.batch_id and not event.batch_id % config.log_period:
+                logger.info("Pass %d, Batch %d, Cost %f" %
+                            (event.pass_id, event.batch_id, event.cost))
+
+            if config.show_parameter_status_period and event.batch_id and \
+                    not event.batch_id % config.show_parameter_status_period:
+                show_parameter_status(parameters)
+
+        if isinstance(event, paddle.event.EndPass):
+            save_path = os.path.join(config.save_dir,
+                                     "pass_%05d.tar.gz" % event.pass_id)
+            save_model(save_path, parameters)
+
+    return event_handler
+
+
+def train(model_config, trainer_config):
+    """Train the GNR model.
+
+    Arguments:
+        - model_config: The model configuration for this model.
+        - trainer_config: The training task configuration for this model.
+    """
+
+    if not os.path.exists(trainer_config.save_dir):
+        os.mkdir(trainer_config.save_dir)
+
+    paddle.init(
+        use_gpu=trainer_config.use_gpu,
+        trainer_count=trainer_config.trainer_count)
+
+    train_reader, test_reader, train_sample_count = build_reader(
+        trainer_config.data_dir, trainer_config.train_batch_size)
+
+    # Define the optimizer. The learning rate will decrease according to
+    # the following formula:
+    #
+    #     lr = learning_rate * pow(learning_rate_decay_a,
+    #                              floor(num_samples_processed /
+    #                                    learning_rate_decay_b))
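+    #
+    # With learning_rate_decay_a=0.5 and learning_rate_decay_b set to the
+    # number of training samples, this halves the learning rate after
+    # each pass over the data, e.g. 1e-3 -> 5e-4 -> 2.5e-4 for the
+    # default learning_rate.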
+    optimizer = paddle.optimizer.Adam(
+        learning_rate=trainer_config.learning_rate,
+        gradient_clipping_threshold=trainer_config.gradient_clipping_threshold,
+        regularization=paddle.optimizer.L2Regularization(
+            trainer_config.l2_decay_rate),
+        learning_rate_decay_a=0.5,
+        learning_rate_decay_b=train_sample_count,
+        learning_rate_schedule="discexp")
+
+    # define the network topology
+    loss = GNR(model_config)
+
+    parameters = paddle.parameters.create(loss)
+
+    if trainer_config.init_model_path:
+        load_initial_model(trainer_config.init_model_path, parameters)
+    else:
+        show_parameter_init_info(parameters)
+        parameters.set(
+            "GloveVectors",
+            load_pretrained_parameters(ModelConfig.pretrained_emb_path))
+
+    trainer = paddle.trainer.SGD(
+        cost=loss, parameters=parameters, update_equation=optimizer)
+
+    event_handler = build_event_handler(trainer_config, parameters, trainer)
+    trainer.train(
+        reader=train_reader,
+        num_passes=trainer_config.epochs,
+        event_handler=event_handler)
+
+
+if __name__ == "__main__":
+    train(ModelConfig, TrainerConfig)