From 44a72d3de70b09d6b4cd58a264791b20ae45aed8 Mon Sep 17 00:00:00 2001
From: caoying03
Date: Fri, 1 Sep 2017 12:25:59 +0800
Subject: [PATCH] clean codes and add comments.

---
 globally_normalized_reader/.gitignore       |   1 -
 globally_normalized_reader/README.md        |  56 ++++-
 globally_normalized_reader/basic_modules.py |  90 +++++--
 globally_normalized_reader/beam_decoding.py | 111 +++++++--
 globally_normalized_reader/config.py        |  13 +-
 globally_normalized_reader/data/download.sh |   4 +
 globally_normalized_reader/index.html       | 119 ++++++++++
 globally_normalized_reader/infer.py         | 131 ++++++++++-
 globally_normalized_reader/model.py         | 245 ++++++++++++++------
 globally_normalized_reader/reader.py        |  20 +-
 globally_normalized_reader/train.py         | 138 ++++++-----
 11 files changed, 736 insertions(+), 192 deletions(-)
 mode change 100755 => 100644 globally_normalized_reader/basic_modules.py
 create mode 100755 globally_normalized_reader/data/download.sh
 create mode 100644 globally_normalized_reader/index.html
 mode change 100755 => 100644 globally_normalized_reader/infer.py
 mode change 100755 => 100644 globally_normalized_reader/model.py
 mode change 100755 => 100644 globally_normalized_reader/reader.py
 mode change 100755 => 100644 globally_normalized_reader/train.py

diff --git a/globally_normalized_reader/.gitignore b/globally_normalized_reader/.gitignore
index c345f460..57079595 100644
--- a/globally_normalized_reader/.gitignore
+++ b/globally_normalized_reader/.gitignore
@@ -1,3 +1,2 @@
-data
 *.txt
 *.pyc
diff --git a/globally_normalized_reader/README.md b/globally_normalized_reader/README.md
index a0990367..363f4cb9 100644
--- a/globally_normalized_reader/README.md
+++ b/globally_normalized_reader/README.md
@@ -1 +1,55 @@
-TBD
+# Globally Normalized Reader
+
+This model implements the work in the following paper:
+
+Jonathan Raiman and John Miller. Globally Normalized Reader. Empirical Methods in Natural Language Processing (EMNLP), 2017.
+
+If you use the dataset/code in your research, please cite the above paper:
+
+```text
+@inproceedings{raiman2015gnr,
+    author={Raiman, Jonathan and Miller, John},
+    booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
+    title={Globally Normalized Reader},
+    year={2017},
+}
+```
+
+You can also visit https://github.com/baidu-research/GloballyNormalizedReader to get more information.
+
+
+# Installation
+
+1. Please use the [docker image](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html) to install the latest PaddlePaddle, by running:
+   ```bash
+   docker pull paddledev/paddle
+   ```
+2. Download all necessary data by running:
+   ```bash
+   cd data && ./download.sh
+   ```
+3. Featurize the data by running:
+   ```bash
+   python featurize.py --datadir data --outdir featurized
+   ```
+
+# Training a Model
+
+- Configure the model by modifying `config.py` if needed, and then run:
+
+  ```bash
+  python train.py 2>&1 | tee train.log
+  ```
+
+# Inferring with a Trained Model
+
+- Infer with a trained model by running:
+  ```bash
+  python infer.py \
+    --model_path models/pass_00000.tar.gz \
+    --data_dir data/featurized/ \
+    --batch_size 2 \
+    --use_gpu 0 \
+    --trainer_count 1 \
+    2>&1 | tee infer.log
+  ```
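
`featurize.py` itself is not part of this patch. Judging from `choose_samples` in `train.py` below, which aborts when a `train` sub-directory is missing and also loads dev samples, the featurized directory is expected to contain `train/` and `dev/` sub-directories; a few lines of Python can sanity-check this (the exact names are an assumption read off that function, not documented elsewhere):

```python
import os

data_dir = "data/featurized"
# choose_samples() in train.py looks for these sub-directories (assumption
# based on the code below, not on separate documentation).
for sub in ("train", "dev"):
    path = os.path.join(data_dir, sub)
    print("%s: %s" % (path, "ok" if os.path.exists(path) else "missing"))
```
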
diff --git a/globally_normalized_reader/basic_modules.py b/globally_normalized_reader/basic_modules.py
old mode 100755
new mode 100644
index 66cb28c6..91aefc2f
--- a/globally_normalized_reader/basic_modules.py
+++ b/globally_normalized_reader/basic_modules.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 #coding=utf-8
+
 import collections
 
 import paddle.v2 as paddle
@@ -7,11 +8,43 @@ from paddle.v2.layer import parse_network
 
 __all__ = [
     "stacked_bidirectional_lstm",
+    "stacked_bidirectional_lstm_by_nested_seq",
     "lstm_by_nested_sequence",
 ]
 
 
-def stacked_bidirectional_lstm(inputs, size, depth, drop_rate=0., prefix=""):
+def stacked_bidirectional_lstm(inputs,
+                               hidden_dim,
+                               depth,
+                               drop_rate=0.,
+                               prefix=""):
+    """The stacked bi-directional LSTM.
+
+    In PaddlePaddle, recurrent layers have two different implementations:
+    1. a recurrent layer implemented by recurrent_group: any intermediate
+       state a recurrent unit computes during one time step, such as hidden
+       states, input-to-hidden mappings, memory cells and so on, is
+       accessible.
+    2. a recurrent layer as a whole: only the outputs of the recurrent layer
+       are accessible.
+
+    The second type (the recurrent layer as a whole) is more computationally
+    efficient, because recurrent_group is made up of many basic layers
+    (including additions, element-wise multiplications, matrix
+    multiplications and so on).
+
+    This function uses the second type to implement the stacked
+    bi-directional LSTM.
+
+    Arguments:
+        - inputs:     The input layer to the bi-directional LSTM.
+        - hidden_dim: The dimension of the hidden state of the LSTM.
+        - depth:      Depth of the stacked bi-directional LSTM.
+        - drop_rate:  The drop rate applied to the LSTM output states.
+        - prefix:     A string appended to the name of each layer created in
+                      this function. Each layer in a network should have a
+                      unique name, so the prefix makes it possible to call
+                      this function multiple times.
+    """
+
     if not isinstance(inputs, collections.Sequence):
         inputs = [inputs]
 
@@ -20,7 +53,7 @@ def stacked_bidirectional_lstm(inputs, size, depth, drop_rate=0., prefix=""):
         for i in range(depth):
             input_proj = paddle.layer.mixed(
                 name="%s_in_proj_%0d_%s__" % (prefix, i, dirt),
-                size=size * 4,
+                size=hidden_dim * 4,
                 bias_attr=paddle.attr.Param(initial_std=0.),
                 input=[paddle.layer.full_matrix_projection(lstm)] if i else [
                     paddle.layer.full_matrix_projection(in_layer)
@@ -45,8 +78,8 @@ def stacked_bidirectional_lstm(inputs, size, depth, drop_rate=0., prefix=""):
 
 
 def lstm_by_nested_sequence(input_layer, hidden_dim, name="", reverse=False):
-    '''
-    This is a LSTM implemended by nested recurrent_group.
+    """This is an LSTM implemented by nested recurrent_group.
+
     Paragraph is a natural nested sequence:
     1. each paragraph is a sequence of sentences.
     2. each sentence is a sequence of words.
     This function implements an LSTM over the nested sequence with a nested
     recurrent_group:
     3. the outer recurrent_group iterates over sentences in a paragraph;
     4. the inner recurrent_group iterates over words in a sentence, carrying
        the final LSTM state of one sentence over as the initial state of the
        next sentence;
     5. Consequently, this function is just equivalent to concatenating all
        sentences in a paragraph into one (long) sentence, and using one LSTM
        to encode this new long sentence.
-    '''
+
+    Arguments:
+        - input_layer: The input layer to the LSTM.
+        - hidden_dim:  The dimension of the hidden state of the LSTM.
+        - name:        The name of the LSTM.
+        - reverse:     The boolean parameter indicating whether to process
+                       the input sequence in reverse order.
+    """
 
     def lstm_outer_step(lstm_group_input, hidden_dim, reverse, name=''):
         outer_memory = paddle.layer.memory(
             name="__outer_state_%s__" % name, size=hidden_dim)
 
         def lstm_inner_step(input_layer, hidden_dim, reverse, name):
             inner_memory = paddle.layer.memory(
                 name="__inner_state_%s__" % name,
                 size=hidden_dim,
                 boot_layer=outer_memory)
-            input_proj = paddle.layer.fc(size=hidden_dim * 4,
-                                         bias_attr=False,
-                                         input=input_layer)
+            input_proj = paddle.layer.fc(
+                size=hidden_dim * 4, bias_attr=False, input=input_layer)
             return paddle.networks.lstmemory_unit(
                 input=input_proj,
                 name="__inner_state_%s__" % name,
@@ -111,7 +150,27 @@ def lstm_by_nested_sequence(input_layer, hidden_dim, name="", reverse=False):
         reverse=reverse)
 
 
-def stacked_bi_lstm_by_nested_seq(input_layer, depth, hidden_dim, prefix=""):
+def stacked_bidirectional_lstm_by_nested_seq(input_layer,
+                                             depth,
+                                             hidden_dim,
+                                             prefix=""):
+    """The stacked bi-directional LSTM to process a nested sequence.
+
+    The network defined in this function is exactly equivalent to the one
+    defined in stacked_bidirectional_lstm; the only difference is that the
+    bi-directional LSTM defined here is implemented with recurrent_group in
+    PaddlePaddle and receives a nested sequence as its input.
+
+    Arguments:
+        - input_layer: The input layer to the bi-directional LSTM.
+        - hidden_dim:  The dimension of the hidden state of the LSTM.
+        - depth:       Depth of the stacked bi-directional LSTM.
+        - prefix:      A string appended to the name of each layer created in
+                       this function. Each layer in a network should have a
+                       unique name, so the prefix makes it possible to call
+                       this function multiple times.
+    """
+
     lstm_final_outs = []
     for dirt in ["fwd", "bwd"]:
         for i in range(depth):
@@ -122,16 +181,3 @@ def stacked_bi_lstm_by_nested_seq(input_layer, depth, hidden_dim, prefix=""):
                 reverse=(dirt == "bwd"))
             lstm_final_outs.append(lstm_out)
     return paddle.layer.concat(input=lstm_final_outs)
-
-
-if __name__ == "__main__":
-    vocab_size = 1024
-    emb_dim = 128
-    embedding = paddle.layer.embedding(
-        input=paddle.layer.data(
-            name="word",
-            type=paddle.data_type.integer_value_sub_sequence(vocab_size)),
-        size=emb_dim)
-    print(parse_network(
-        stacked_bi_lstm_by_nested_seq(
-            input_layer=embedding, depth=3, hidden_dim=128, prefix="__lstm")))
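
The deleted `__main__` block above still works as a quick smoke test for the renamed module; run standalone it is equivalent to the following sketch (the vocabulary size and dimensions are arbitrary, exactly as in the removed demo):

```python
import paddle.v2 as paddle
from paddle.v2.layer import parse_network

from basic_modules import stacked_bidirectional_lstm_by_nested_seq

vocab_size = 1024
emb_dim = 128
embedding = paddle.layer.embedding(
    input=paddle.layer.data(
        name="word",
        type=paddle.data_type.integer_value_sub_sequence(vocab_size)),
    size=emb_dim)
# parse_network prints the serialized topology, which proves the stacked
# nested-sequence LSTM wires up without building a trainer.
print(parse_network(
    stacked_bidirectional_lstm_by_nested_seq(
        input_layer=embedding, depth=3, hidden_dim=128, prefix="__lstm")))
```
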
diff --git a/globally_normalized_reader/beam_decoding.py b/globally_normalized_reader/beam_decoding.py
index 5d651fcd..d072ca17 100644
--- a/globally_normalized_reader/beam_decoding.py
+++ b/globally_normalized_reader/beam_decoding.py
@@ -1,14 +1,41 @@
 #!/usr/bin/env python
 #coding=utf-8
-import pdb
+
 import numpy as np
 
 __all__ = ["BeamDecoding"]
 
 
 class BeamDecoding(object):
+    """
+    Decode outputs of the PaddlePaddle layers into readable answers.
+    """
+
     def __init__(self, documents, sentence_scores, selected_sentences,
                  start_scores, selected_starts, end_scores, selected_ends):
+        """The constructor.
+
+        Arguments:
+            - documents:          The one-hot input of the document words.
+            - sentence_scores:    The score for each sentence in a document.
+            - selected_sentences: The top k selected sentences.
+                                  This is the output of the
+                                  paddle.layer.kmax_seq_score layer in the
+                                  model.
+            - start_scores:       The score for each word in the selected
+                                  sentences, indicating whether it is the
+                                  start of the answer.
+            - selected_starts:    The top k selected start spans. This is the
+                                  output of the paddle.layer.kmax_seq_score
+                                  layer in the model.
+            - end_scores:         The score for each word in the
+                                  sub-sequences running from the selected
+                                  starts to the ends of the selected
+                                  sentences.
+            - selected_ends:      The top k selected end spans. This is the
+                                  output of the paddle.layer.kmax_seq_score
+                                  layer in the model.
+        """
+
         self.documents = documents
 
         self.sentence_scores = sentence_scores
@@ -19,13 +46,14 @@ class BeamDecoding(object):
 
         self.end_scores = end_scores
         self.selected_ends = selected_ends
-
-        # sequence start position information for the three step search
-        # beam1 is to search the sequence index
+
+        # Sequence start position information for the three-step search:
+        # beam1 searches the answer sentence.
         self.beam1_seq_start_positions = []
-        # beam2 is to search the start answer span
+        # beam2 searches the start answer span.
         self.beam2_seq_start_positions = []
-        # beam3 is to search the end answer span
+        # beam3 searches the end answer span.
         self.beam3_seq_start_positions = []
 
         self.ans_per_sample_in_a_batch = [0]
@@ -34,6 +62,11 @@ class BeamDecoding(object):
         self.final_ans = [[] for i in range(len(documents))]
 
     def _build_beam1_seq_info(self):
+        """
+        The internal function to calculate the offset of each test sequence
+        in a batch for the first beam in searching the answer sentence.
+        """
+
         self.beam1_seq_start_positions.append([0])
         for idx, one_doc in enumerate(self.documents):
             for sentence in one_doc:
                 self.beam1_seq_start_positions[-1].append(
                     self.beam1_seq_start_positions[-1][-1] + len(sentence))
             if idx < len(self.documents) - 1:
                 self.beam1_seq_start_positions.append(
                     [self.beam1_seq_start_positions[-1][-1]])
 
     def _build_beam2_seq_info(self):
+        """
+        The internal function to calculate the offset of each test sequence
+        in a batch for the second beam in searching the start spans.
+        """
+
         seq_num, beam_size = self.selected_sentences.shape
         self.beam2_seq_start_positions.append([0])
         for i in range(seq_num):
             for j in range(beam_size):
                 selected_id = int(self.selected_sentences[i][j])
                 if selected_id == -1: break
-                seq_len = self.beam1_seq_start_positions[i][
-                    selected_id + 1] - self.beam1_seq_start_positions[i][
-                        selected_id]
+                seq_len = self.beam1_seq_start_positions[
+                    i][selected_id +
+                       1] - self.beam1_seq_start_positions[i][selected_id]
                 self.beam2_seq_start_positions[-1].append(
                     self.beam2_seq_start_positions[-1][-1] + seq_len)
 
             if i < seq_num - 1:
                 self.beam2_seq_start_positions.append(
                     [self.beam2_seq_start_positions[-1][-1]])
 
     def _build_beam3_seq_info(self):
+        """
+        The internal function to calculate the offset of each test sequence
+        in a batch for the third beam in searching the end spans.
+ """ + seq_num_in_a_batch = len(self.documents) seq_id = 0 @@ -71,9 +114,9 @@ class BeamDecoding(object): self.beam3_seq_start_positions.append([0]) sub_seq_num, beam_size = self.selected_starts.shape for i in range(sub_seq_num): - seq_len = self.beam2_seq_start_positions[seq_id][ - sub_seq_id + 1] - self.beam2_seq_start_positions[seq_id][ - sub_seq_id] + seq_len = self.beam2_seq_start_positions[ + seq_id][sub_seq_id + + 1] - self.beam2_seq_start_positions[seq_id][sub_seq_id] for j in range(beam_size): start_id = int(self.selected_starts[i][j]) if start_id == -1: break @@ -88,17 +131,27 @@ class BeamDecoding(object): [self.beam3_seq_start_positions[-1][-1]]) sub_seq_id = 0 seq_id += 1 - sub_seq_count = len(self.beam2_seq_start_positions[ - seq_id]) - 1 + sub_seq_count = len( + self.beam2_seq_start_positions[seq_id]) - 1 assert ( self.beam3_seq_start_positions[-1][-1] == self.end_scores.shape[0]) def _build_seq_info_for_each_beam(self): + """ + The internal function to calculate the offset of each test sequence + in a batch for beams expanded at all the three search steps. + """ + self._build_beam1_seq_info() self._build_beam2_seq_info() self._build_beam3_seq_info() def _cal_ans_per_sample_in_a_batch(self): + """ + The internal function to calculate there are how many candidate answers + for each of the test sequemce in a batch. + """ + start_row = 0 for seq in self.beam3_seq_start_positions: end_row = start_row + len(seq) - 1 @@ -109,6 +162,13 @@ class BeamDecoding(object): start_row = end_row def _get_valid_seleceted_ids(slef, mat): + """ + The internal function to post-process the output matrix of + paddle.layer.kmax_seq_score layer. This function takes off the special + dilimeter -1 away and flattens the original two-dimensional output + matrix into a python list. + """ + flattened = [] height, width = mat.shape for i in range(height): @@ -118,6 +178,11 @@ class BeamDecoding(object): return flattened def decoding(self): + """ + The internal function to decode forward results of the GNR network into + readable answers. 
+ """ + self._build_seq_info_for_each_beam() self._cal_ans_per_sample_in_a_batch() @@ -134,11 +199,16 @@ class BeamDecoding(object): if end_pos == -1: break self.all_searched_ans.append({ - "score": self.end_scores[seq_offset_in_batch + end_pos], - "sentence_pos": -1, - "start_span_pos": -1, - "end_span_pos": end_pos, - "parent_ids_in_prev_beam": i + "score": + self.end_scores[seq_offset_in_batch + end_pos], + "sentence_pos": + -1, + "start_span_pos": + -1, + "end_span_pos": + end_pos, + "parent_ids_in_prev_beam": + i }) sub_seq_id += 1 @@ -196,7 +266,8 @@ class BeamDecoding(object): key=lambda x: x["score"], reverse=True): self.final_ans[i].append({ - "score": ans["score"], + "score": + ans["score"], "label": [ ans["sentence_pos"], ans["start_span_pos"], ans["end_span_pos"] diff --git a/globally_normalized_reader/config.py b/globally_normalized_reader/config.py index 59f6853f..d89fd0e4 100644 --- a/globally_normalized_reader/config.py +++ b/globally_normalized_reader/config.py @@ -1,7 +1,7 @@ #!/usr/bin/env python #coding=utf-8 -__all__ = ["ModelConfig"] +__all__ = ["ModelConfig", "TrainerConfig"] class ModelConfig(object): @@ -24,14 +24,15 @@ class ModelConfig(object): class TrainerConfig(object): learning_rate = 1e-3 + l2_decay_rate = 5e-4 + gradient_clipping_threshold = 20 + data_dir = "data/featurized" save_dir = "models" - use_gpu = True - trainer_count = 4 - train_batch_size = trainer_count * 10 - - test_batch_size = 4 + use_gpu = False + trainer_count = 1 + train_batch_size = trainer_count * 8 epochs = 20 diff --git a/globally_normalized_reader/data/download.sh b/globally_normalized_reader/data/download.sh new file mode 100755 index 00000000..4782dd55 --- /dev/null +++ b/globally_normalized_reader/data/download.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +wget --no-check-certificate https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json +wget --no-check-certificate https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json diff --git a/globally_normalized_reader/index.html b/globally_normalized_reader/index.html new file mode 100644 index 00000000..78cb24ee --- /dev/null +++ b/globally_normalized_reader/index.html @@ -0,0 +1,119 @@ + + + + + + + + + + + + + + + + + +
diff --git a/globally_normalized_reader/data/download.sh b/globally_normalized_reader/data/download.sh
new file mode 100755
index 00000000..4782dd55
--- /dev/null
+++ b/globally_normalized_reader/data/download.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+wget --no-check-certificate https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
+wget --no-check-certificate https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
diff --git a/globally_normalized_reader/index.html b/globally_normalized_reader/index.html
new file mode 100644
index 00000000..78cb24ee
--- /dev/null
+++ b/globally_normalized_reader/index.html
@@ -0,0 +1,119 @@
+[The 119 added lines of index.html did not survive in this copy of the patch and are omitted.]
diff --git a/globally_normalized_reader/infer.py b/globally_normalized_reader/infer.py
old mode 100755
new mode 100644
index c919016a..351b2659
--- a/globally_normalized_reader/infer.py
+++ b/globally_normalized_reader/infer.py
@@ -1,7 +1,9 @@
 #!/usr/bin/env python
 #coding=utf-8
+
 import os
 import sys
+import argparse
 import gzip
 import logging
 import numpy as np
@@ -12,14 +14,62 @@ import reader
 
 from model import GNR
 from train import choose_samples
-from config import ModelConfig, TrainerConfig
+from config import ModelConfig
 from beam_decoding import BeamDecoding
 
 logger = logging.getLogger("paddle")
 logger.setLevel(logging.INFO)
 
 
+def parse_cmd():
+    """
+    Build the command line argument parser for the inferring task.
+    """
+    parser = argparse.ArgumentParser(
+        description="Globally Normalized Reader in PaddlePaddle.")
+    parser.add_argument(
+        "--model_path",
+        required=True,
+        type=str,
+        help="Path of the trained model to evaluate.",
+        default="")
+    parser.add_argument(
+        "--data_dir",
+        type=str,
+        required=True,
+        help="Path of the training and testing data.",
+        default="")
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        required=False,
+        help="The batch size for inferring.",
+        default=1)
+    parser.add_argument(
+        "--use_gpu",
+        type=int,
+        required=False,
+        help="Whether to run the inferring on GPU.",
+        default=0)
+    parser.add_argument(
+        "--trainer_count",
+        type=int,
+        required=False,
+        help=("The thread number used in inferring. When set "
+              "use_gpu=True, the trainer_count cannot exceed "
+              "the GPU device count in your computer."),
+        default=1)
+    return parser.parse_args()
+
+
 def load_reverse_dict(dict_file):
+    """Build the dict which maps a word index to the word string.
+
+    The keys are word indices and the values are word strings.
+
+    Arguments:
+        - dict_file: The path of a word dictionary.
+    """
     word_dict = {}
     with open(dict_file, "r") as fin:
         for idx, line in enumerate(fin):
@@ -28,6 +78,34 @@ def load_reverse_dict(dict_file):
 
 
 def print_result(test_batch, predicted_ans, ids_2_word, print_top_k=1):
+    """Print the readable predicted answers.
+
+    Format of the output:
+        query:\tthe input query.
+        documents:
+        0\tthe first sentence in the document.
+        1\tthe second sentence in the document.
+        ...
+        gold:\t[i j k] the answer words.
+            (i: the sentence index;
+             j: the start span index;
+             k: the end span index)
+        top answers:
+        score0\t[i j k] the answer with the highest score.
+        score1\t[i j k] the answer with the second highest score.
+            (i, j, k have the same meaning as in gold.)
+        ...
+
+    infer_a_batch below requests the top 10 answers, so ten answers are
+    printed for each sample in this script.
+
+    Arguments:
+        - test_batch:    A test batch returned by the reader.
+        - predicted_ans: The beam decoding results.
+        - ids_2_word:    The dict whose keys are word indices and whose
+                         values are word strings.
+        - print_top_k:   Indicating how many answers will be printed.
+ """ + for i, sample in enumerate(test_batch): query_words = [ids_2_word[ids] for ids in sample[0]] print("query:\t%s" % (" ".join(query_words))) @@ -42,7 +120,7 @@ def print_result(test_batch, predicted_ans, ids_2_word, print_top_k=1): sample[3], sample[4], sample[5], " ".join( [ids_2_word[ids] for ids in sample[1][sample[3]][start:end]]))) - print("predicted:") + print("top answers:") for k in range(print_top_k): label = predicted_ans[i][k]["label"] start = label[1] @@ -57,14 +135,42 @@ def print_result(test_batch, predicted_ans, ids_2_word, print_top_k=1): def infer_a_batch(inferer, test_batch, ids_2_word, out_layer_count): + """ Call the PaddlePaddle's infer interface to infer by batch. + + Arguments: + - inferer: The PaddlePaddle Inference object. + - test_batch: A test batch returned by reader. + - ids_2_word: The dict whose key is word index and the values are + word strings. + - out_layer_count: The number of output layers in the inferring process. + """ + outs = inferer.infer(input=test_batch, flatten_result=False, field="value") decoder = BeamDecoding([sample[1] for sample in test_batch], *outs) print_result(test_batch, decoder.decoding(), ids_2_word, print_top_k=10) -def infer(model_path, data_dir, test_batch_size, config): +def infer(model_path, + data_dir, + batch_size, + config, + use_gpu=False, + trainer_count=1): + """ The inferring process. + + Arguments: + - model_path: The path of trained model. + - data_dir: The directory path of test data. + - batch_size: The batch_size. + - config: The model configuration. + - use_gpu: Whether to run the inferring on GPU. + - trainer_count: The thread number used in inferring. When set + use_gpu=True, the trainer_count cannot excess + the gpu device number in your computer. + """ + assert os.path.exists(model_path), "The model does not exist." - paddle.init(use_gpu=True, trainer_count=1) + paddle.init(use_gpu=use_gpu, trainer_count=trainer_count) ids_2_word = load_reverse_dict(config.dict_path) @@ -84,7 +190,7 @@ def infer(model_path, data_dir, test_batch_size, config): test_batch = [] for i, item in enumerate(test_reader()): test_batch.append(item) - if len(test_batch) == test_batch_size: + if len(test_batch) == batch_size: infer_a_batch(inferer, test_batch, ids_2_word, len(outputs)) test_batch = [] @@ -93,7 +199,16 @@ def infer(model_path, data_dir, test_batch_size, config): test_batch = [] +def main(args): + infer( + model_path=args.model_path, + data_dir=args.data_dir, + batch_size=args.batch_size, + config=ModelConfig, + use_gpu=args.use_gpu, + trainer_count=args.trainer_count) + + if __name__ == "__main__": - # infer("models/round1/pass_00000.tar.gz", TrainerConfig.data_dir, - infer("models/round2_on_cpu/pass_00000.tar.gz", TrainerConfig.data_dir, - TrainerConfig.test_batch_size, ModelConfig) + args = parse_cmd() + main(args) diff --git a/globally_normalized_reader/model.py b/globally_normalized_reader/model.py old mode 100755 new mode 100644 index b3e1aec9..db5cccae --- a/globally_normalized_reader/model.py +++ b/globally_normalized_reader/model.py @@ -1,5 +1,6 @@ #!/usr/bin/env python #coding=utf-8 + import paddle.v2 as paddle from paddle.v2.layer import parse_network import basic_modules @@ -9,52 +10,115 @@ __all__ = ["GNR"] def build_pretrained_embedding(name, data_type, emb_dim, emb_drop=0.): + """create word a embedding layer which loads pre-trained embeddings. + + Arguments: + - name: The name of the data layer which accepts one-hot input. + - data_type: PaddlePaddle's data type for data layer. 
diff --git a/globally_normalized_reader/model.py b/globally_normalized_reader/model.py
old mode 100755
new mode 100644
index b3e1aec9..db5cccae
--- a/globally_normalized_reader/model.py
+++ b/globally_normalized_reader/model.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 #coding=utf-8
+
 import paddle.v2 as paddle
 from paddle.v2.layer import parse_network
 import basic_modules
@@ -9,52 +10,115 @@ __all__ = ["GNR"]
 
 
 def build_pretrained_embedding(name, data_type, emb_dim, emb_drop=0.):
+    """Create a word embedding layer which loads pre-trained embeddings.
+
+    Arguments:
+        - name:      The name of the data layer which accepts one-hot input.
+        - data_type: PaddlePaddle's data type for the data layer.
+        - emb_dim:   The dimension of the word embedding.
+        - emb_drop:  The drop rate applied to the embedding output.
+    """
+
     return paddle.layer.embedding(
-        input=paddle.layer.data(
-            name=name, type=data_type),
+        input=paddle.layer.data(name=name, type=data_type),
         size=emb_dim,
-        param_attr=paddle.attr.Param(
-            name="GloveVectors", is_static=True),
+        param_attr=paddle.attr.Param(name="GloveVectors", is_static=True),
         layer_attr=paddle.attr.ExtraLayerAttribute(drop_rate=emb_drop), )
 
 
-def encode_question(input_embedding, config, prefix):
+def encode_question(input_embedding,
+                    lstm_hidden_dim,
+                    depth,
+                    passage_indep_embedding_dim,
+                    prefix=""):
+    """Build the question encoding by using a bidirectional LSTM.
+
+    Each question word is encoded by running a stack of bidirectional LSTMs
+    over the word embeddings in the question, producing hidden states. The
+    hidden states are used to compute a passage-independent question
+    embedding.
+
+    The final question encoding is constructed by concatenating the final
+    hidden states of the forward and backward LSTMs and the
+    passage-independent embedding.
+
+    Arguments:
+        - input_embedding: The question word embeddings.
+        - lstm_hidden_dim: The dimension of the hidden state of the
+                           bi-directional LSTM.
+        - depth:           The depth of the stacked bi-directional LSTM.
+        - passage_indep_embedding_dim: The dimension of the
+                           passage-independent embedding.
+        - prefix:          A string appended to the name of each layer
+                           created in this function. Each layer in a network
+                           should have a unique name, so the prefix makes it
+                           possible to call this function multiple times.
+    """
+
     # stacked bi-directional LSTM to process question embeddings.
     lstm_final, lstm_outs = basic_modules.stacked_bidirectional_lstm(
-        inputs=input_embedding,
-        size=config.lstm_hidden_dim,
-        depth=config.lstm_depth,
-        drop_rate=config.lstm_hidden_droprate,
-        prefix=prefix)
-
-    # passage-independent embeddings
-    candidates = paddle.layer.fc(input=lstm_outs,
-                                 bias_attr=False,
-                                 size=config.passage_indep_embedding_dim,
-                                 act=paddle.activation.Linear())
-    weights = paddle.layer.fc(input=lstm_outs,
-                              size=1,
-                              bias_attr=False,
-                              act=paddle.activation.SequenceSoftmax())
+        input_embedding, lstm_hidden_dim, depth, 0., prefix)
+
+    # compute passage-independent embeddings.
+    candidates = paddle.layer.fc(
+        input=lstm_outs,
+        bias_attr=False,
+        size=passage_indep_embedding_dim,
+        act=paddle.activation.Linear())
+    weights = paddle.layer.fc(
+        input=lstm_outs,
+        size=1,
+        bias_attr=False,
+        act=paddle.activation.SequenceSoftmax())
     weighted_candidates = paddle.layer.scaling(input=candidates, weight=weights)
     passage_indep_embedding = paddle.layer.pooling(
         input=weighted_candidates, pooling_type=paddle.pooling.Sum())
+
     return paddle.layer.concat(
         input=[lstm_final, passage_indep_embedding]), lstm_outs
 
 
 def question_aligned_passage_embedding(question_lstm_outs, document_embeddings,
-                                       config):
-    def outer_sentence_step(document_embeddings, question_lstm_outs, config):
-        '''
-        in this recurrent_group, document_embeddings has scattered into sequence,
-        '''
+                                       passage_aligned_embedding_dim):
+    """Create the question aligned passage embedding.
+
+    Arguments:
+        - question_lstm_outs:  The outputs of the LSTM that processes the
+                               question word embeddings.
+        - document_embeddings: The document embeddings.
+        - passage_aligned_embedding_dim: The dimension of the passage aligned
+                               embedding.
+    """
+
+    def outer_sentence_step(document_embeddings, question_lstm_outs,
+                            passage_aligned_embedding_dim):
+        """The step function for PaddlePaddle's recurrent_group.
+
+        In this function, the original input document_embeddings is scattered
+        from a nested sequence into a sequence by recurrent_group in
+        PaddlePaddle. The step function iterates over each sentence in the
+        document.
+
+        Arguments:
+            - document_embeddings: The word embeddings of the document.
+            - question_lstm_outs:  The outputs of the LSTM that processes
+                                   the question word embeddings.
+            - passage_aligned_embedding_dim: The dimension of the passage
+                                   aligned embedding.
+        """
 
         def inner_word_step(word_embedding, question_lstm_outs,
-                            question_outs_proj, config):
-            '''
-            in this recurrent_group, sentence embedding has scattered into word
-            embeddings.
-            '''
+                            question_outs_proj, passage_aligned_embedding_dim):
+            """
+            In this recurrent_group, the sentence embedding has been scattered
+            into word embeddings. The step function iterates over each word
+            in one sentence of the document.
+
+            Arguments:
+                - word_embedding:     The word embeddings of the document.
+                - question_lstm_outs: The outputs of the LSTM that processes
+                                      the question word embeddings.
+                - question_outs_proj: The projection of question_lstm_outs
+                                      into a new hidden space.
+                - passage_aligned_embedding_dim: The dimension of the passage
+                                      aligned embedding.
+            """
+
             doc_word_expand = paddle.layer.expand(
                 input=word_embedding,
                 expand_as=question_lstm_outs,
@@ -62,10 +126,6 @@ def question_aligned_passage_embedding(question_lstm_outs, document_embeddings,
 
             weights = paddle.layer.fc(
                 input=[question_lstm_outs, doc_word_expand],
-                param_attr=[
-                    paddle.attr.Param(initial_std=1e-3),
-                    paddle.attr.Param(initial_std=1e-3)
-                ],
                 size=1,
                 bias_attr=False,
                 act=paddle.activation.SequenceSoftmax())
@@ -77,13 +137,13 @@ def question_aligned_passage_embedding(question_lstm_outs, document_embeddings,
         question_outs_proj = paddle.layer.fc(
             input=question_lstm_outs,
             bias_attr=False,
-            size=config.passage_aligned_embedding_dim)
+            size=passage_aligned_embedding_dim)
         return paddle.layer.recurrent_group(
             input=[
                 paddle.layer.SubsequenceInput(document_embeddings),
                 paddle.layer.StaticInput(question_lstm_outs),
                 paddle.layer.StaticInput(question_outs_proj),
-                config,
+                passage_aligned_embedding_dim,
             ],
             step=inner_word_step,
             name="iter_over_word")
 
     return paddle.layer.recurrent_group(
         input=[
             paddle.layer.SubsequenceInput(document_embeddings),
-            paddle.layer.StaticInput(question_lstm_outs), config
+            paddle.layer.StaticInput(question_lstm_outs),
+            passage_aligned_embedding_dim
         ],
         step=outer_sentence_step,
         name="iter_over_sen")
 
 
 def encode_documents(input_embedding, same_as_question, question_vector,
-                     question_lstm_outs, config, prefix):
+                     question_lstm_outs, passage_aligned_embedding_dim, prefix):
+    """Build the final question-aware document embeddings.
+
+    Each word in the document is represented as the concatenation of its word
+    vector, the question vector, boolean features indicating whether the word
+    appears in the question or is repeated, and a question aligned embedding.
+
+    Arguments:
+        - input_embedding:  The word embeddings of the document.
+        - same_as_question: The boolean features indicating whether a word
+                            appears in the question or is repeated.
+        - question_vector:  The final question encoding.
+        - question_lstm_outs: The outputs of the LSTM that processes the
+                            question word embeddings.
+        - passage_aligned_embedding_dim: The dimension of the passage aligned
+                            embedding.
+        - prefix:           The prefix appended to the name of each layer
+                            created in this function.
+ """ + question_expanded = paddle.layer.expand( input=question_vector, expand_as=input_embedding, expand_level=paddle.layer.ExpandLevel.FROM_NO_SEQUENCE) question_aligned_embedding = question_aligned_passage_embedding( - question_lstm_outs, input_embedding, config) + question_lstm_outs, input_embedding, passage_indep_embedding_dim) return paddle.layer.concat(input=[ input_embedding, question_expanded, same_as_question, question_aligned_embedding @@ -113,15 +192,48 @@ def encode_documents(input_embedding, same_as_question, question_vector, def search_answer(doc_lstm_outs, sentence_idx, start_idx, end_idx, config, is_infer): + """Search the answer from the document. + + The search process for this layer begins with searching a target sequence + from a nested sequence by using paddle.lauer.kmax_seq_score and + paddle.layer.sub_nested_seq_layer. In the first search step, top beam size + sequences with highest scores, indices of these top k sequences in the + original nested sequence, and the ground truth (also called gold) + altogether (a triple) make up of the first beam. + + Then, start and end positions are searched. In these searches, top k + positions with highest scores are selected, and then sequence, starting + from the selected starts till ends of the sequences are taken to search + next by using paddle.layer.seq_slice. + + Finally, the layer paddle.layer.cross_entropy_over_beam takes all the beam + expansions which contain several candidate targets found along the + three-step search. cross_entropy_over_beam calculates cross entropy over + the expanded beams which all the candidates in the beam as the normalized + factor. + + Note that, if gold falls off the beam at search step t, then the cost is + calculated over the beam at step t. + + Arguments: + - doc_lstm_outs: The output of LSTM that process each document words. + - sentence_idx: Ground-truth indicating sentence index of the answer + in the document. + - start_idx: Ground-truth indicating start span index of the answer + in the sentence. + - end_idx: Ground-truth indicating end span index of the answer + in the sentence. + - is_infer: The boolean parameter indicating inferring or training. 
+ """ + last_state_of_sentence = paddle.layer.last_seq( input=doc_lstm_outs, agg_level=paddle.layer.AggregateLevel.TO_SEQUENCE) sentence_scores = paddle.layer.fc( input=last_state_of_sentence, size=1, bias_attr=False, - param_attr=paddle.attr.Param(initial_std=1e-3), act=paddle.activation.Linear()) - topk_sentence_ids = paddle.layer.kmax_sequence_score( + topk_sentence_ids = paddle.layer.kmax_seq_score( input=sentence_scores, beam_size=config.beam_size) topk_sen = paddle.layer.sub_nested_seq( input=doc_lstm_outs, selected_indices=topk_sentence_ids) @@ -131,29 +243,24 @@ def search_answer(doc_lstm_outs, sentence_idx, start_idx, end_idx, config, input=topk_sen, size=1, layer_attr=paddle.attr.ExtraLayerAttribute( - error_clipping_threshold=10.0), + error_clipping_threshold=5.0), bias_attr=False, - param_attr=paddle.attr.Param(initial_std=1e-3), act=paddle.activation.Linear()) - topk_start_pos_ids = paddle.layer.kmax_sequence_score( + topk_start_pos_ids = paddle.layer.kmax_seq_score( input=start_pos_scores, beam_size=config.beam_size) topk_start_spans = paddle.layer.seq_slice( input=topk_sen, starts=topk_start_pos_ids, ends=None) # expand beam to search end positions on selected start spans _, end_span_embedding = basic_modules.stacked_bidirectional_lstm( - inputs=topk_start_spans, - size=config.lstm_hidden_dim, - depth=config.lstm_depth, - drop_rate=config.lstm_hidden_droprate, - prefix="__end_span_embeddings__") + topk_start_spans, config.lstm_hidden_dim, config.lstm_depth, + config.lstm_hidden_droprate, "__end_span_embeddings__") end_pos_scores = paddle.layer.fc( input=end_span_embedding, size=1, bias_attr=False, - param_attr=paddle.attr.Param(initial_std=1e-3), act=paddle.activation.Linear()) - topk_end_pos_ids = paddle.layer.kmax_sequence_score( + topk_end_pos_ids = paddle.layer.kmax_seq_score( input=end_pos_scores, beam_size=config.beam_size) if is_infer: @@ -172,15 +279,23 @@ def search_answer(doc_lstm_outs, sentence_idx, start_idx, end_idx, config, def GNR(config, is_infer=False): - # encoding question words + """Build the globally normalized reader model. + + Arguments: + - config: The model configuration. + - is_infer: The boolean parameter indicating inferring or training. 
+ """ + + # encode question words question_embeddings = build_pretrained_embedding( "question", paddle.data_type.integer_value_sequence(config.vocab_size), config.embedding_dim, config.embedding_droprate) question_vector, question_lstm_outs = encode_question( - input_embedding=question_embeddings, config=config, prefix="__ques") + question_embeddings, config.lstm_hidden_dim, config.lstm_depth, + config.passage_indep_embedding_dim, "__ques") - # encoding document words + # encode document words document_embeddings = build_pretrained_embedding( "documents", paddle.data_type.integer_value_sub_sequence(config.vocab_size), @@ -190,20 +305,14 @@ def GNR(config, is_infer=False): type=paddle.data_type.dense_vector_sub_sequence(1)) document_words_ecoding = encode_documents( - input_embedding=document_embeddings, - question_vector=question_vector, - question_lstm_outs=question_lstm_outs, - same_as_question=same_as_question, - config=config, - prefix="__doc") - - doc_lstm_outs = basic_modules.stacked_bi_lstm_by_nested_seq( - input_layer=document_words_ecoding, - hidden_dim=config.lstm_hidden_dim, - depth=config.lstm_depth, - prefix="__doc_lstm") - - # define labels + document_embeddings, same_as_question, question_vector, + question_lstm_outs, config.passage_indep_embedding_dim, "__doc") + + doc_lstm_outs = basic_modules.stacked_bidirectional_lstm_by_nested_seq( + document_words_ecoding, config.lstm_depth, config.lstm_hidden_dim, + "__doc_lstm") + + # search the answer. sentence_idx = paddle.layer.data( name="sen_idx", type=paddle.data_type.integer_value(1)) start_idx = paddle.layer.data( diff --git a/globally_normalized_reader/reader.py b/globally_normalized_reader/reader.py old mode 100755 new mode 100644 index e17f2a7c..c6642aa9 --- a/globally_normalized_reader/reader.py +++ b/globally_normalized_reader/reader.py @@ -1,5 +1,6 @@ #!/usr/bin/env python #coding=utf-8 + import os import random import json @@ -10,8 +11,16 @@ logger.setLevel(logging.INFO) def data_reader(data_list, is_train=True): + """ Data reader. + + Arguments: + - data_list: A python list which contains path of training samples. + - is_train: A boolean parameter indicating this function is called + in training or in inferring. + """ + def reader(): - # every pass shuffle the data list again + """shuffle the data list again at the begining of every pass""" if is_train: random.shuffle(data_list) @@ -33,12 +42,3 @@ def data_reader(data_list, is_train=True): data['ans_end'] - data['ans_start']) return reader - - -if __name__ == "__main__": - from train import choose_samples - - train_list, dev_list = choose_samples("data/featurized") - for i, item in enumerate(data_reader(train_list)()): - print(item) - if i > 5: break diff --git a/globally_normalized_reader/train.py b/globally_normalized_reader/train.py old mode 100755 new mode 100644 index 6fdd523e..e377fa1c --- a/globally_normalized_reader/train.py +++ b/globally_normalized_reader/train.py @@ -1,5 +1,6 @@ #!/usr/bin/env python #coding=utf-8 + from __future__ import print_function import os @@ -21,44 +22,59 @@ logger.setLevel(logging.INFO) def load_initial_model(model_path, parameters): - """ - initalize parameters in the network from a trained model. + """ Initalize parameters in the network from a trained model. + + This is useful in resuming the training from previously saved models. + + Arguments: + - model_path: The path of a trained model. + - parameters: The parameters in a network which will be initialized + from the specified model. 
""" with gzip.open(model_path, "rb") as f: parameters.init_from_tar(f) -def load_pretrained_parameters(path, height, width): +def load_pretrained_parameters(path): + """ Load one pre-trained parameter. + + Arguments: + - path: The path of the pre-trained parameter. + """ return np.load(path) def save_model(save_path, parameters): + """ Save the trained parameters. + + Arguments: + - save_path: The path to save the trained parameters. + - parameters: The trained model parameters. + """ with gzip.open(save_path, "w") as f: parameters.to_tar(f) -def load_initial_model(model_path, parameters): - with gzip.open(model_path, "rb") as f: - parameters.init_from_tar(f) - - def show_parameter_init_info(parameters): + """ Print the information of initialization mean and std of parameters. + + Arguments: + - parameters: The parameters created in a model. + """ for p in parameters: logger.info("%s : initial_mean %.4f initial_std %.4f" % (p, parameters.__param_conf__[p].initial_mean, parameters.__param_conf__[p].initial_std)) -def dump_value_matrix(param_name, dims, value): - np.savetxt( - param_name + ".txt", - value.reshape(dims[0], dims[1]), - fmt="%.4f", - delimiter=",") +def show_parameter_status(parameters): + """ Print some statistical information of parameters in a network. + This is used for debugging the model. -def show_parameter_status(parameters): - # for debug print + Arguments: + - parameters: The parameters created in a model. + """ for p in parameters: value = parameters.get(p) @@ -75,8 +91,10 @@ def show_parameter_status(parameters): def choose_samples(path): - """ - Load filenames for train, dev, and augmented samples. + """Load filenames for train, dev, and augmented samples. + + Arguments: + - path: The path of training data. """ if not os.path.exists(os.path.join(path, "train")): print( @@ -97,8 +115,11 @@ def choose_samples(path): def build_reader(data_dir, batch_size): - """ - Build the data reader for this model. + """Build the data reader for this model. + + Arguments: + - data_dir: The path of training data. + - batch_size: batch size for the training task. """ train_samples, valid_samples = choose_samples(data_dir) @@ -109,15 +130,18 @@ def build_reader(data_dir, batch_size): # testing data is not shuffled test_reader = paddle.batch( - reader.data_reader( - valid_samples, is_train=False), + reader.data_reader(valid_samples, is_train=False), batch_size=batch_size) - return train_reader, test_reader + return train_reader, test_reader, len(train_samples) -def build_event_handler(config, parameters, trainer, test_reader): - """ - Build the event handler for this model. +def build_event_handler(config, parameters, trainer): + """Build the event handler for this model. + + Arguments: + - config: The training task configuration for this model. + - parameters: The parameters in the network. + - trainer: The trainer object. 
""" # End batch and end pass event handler @@ -127,12 +151,8 @@ def build_event_handler(config, parameters, trainer, test_reader): if isinstance(event, paddle.event.EndIteration): if event.batch_id and \ (not event.batch_id % config.checkpoint_period): - # save_path = os.path.join(config.save_dir, - # "checkpoint_param.latest.tar.gz") - save_path = os.path.join(config.save_dir, - "pass_%05d_%03d.tar.gz" % - (event.pass_id, event.batch_id)) + "checkpoint_param.latest.tar.gz") save_model(save_path, parameters) if event.batch_id and not event.batch_id % config.log_period: @@ -148,14 +168,17 @@ def build_event_handler(config, parameters, trainer, test_reader): "pass_%05d.tar.gz" % event.pass_id) save_model(save_path, parameters) - # result = trainer.test(reader=test_reader) - # logger.info("Test with Pass %d, %s" % - # (event.pass_id, result.metrics)) - return event_handler def train(model_config, trainer_config): + """Training the GNR model. + + Arguments: + - modle_config: The model configuration for this model. + - trainer_config: The training task configuration for this model. + """ + if not os.path.exists(trainer_config.save_dir): os.mkdir(trainer_config.save_dir) @@ -163,13 +186,24 @@ def train(model_config, trainer_config): use_gpu=trainer_config.use_gpu, trainer_count=trainer_config.trainer_count) - # define the optimizer + train_reader, test_reader, train_sample_count = build_reader( + trainer_config.data_dir, trainer_config.train_batch_size) + """ + Define the optimizer. The learning rate will decrease according to + the following formula: + + lr = learning_rate * pow(learning_rate_decay_a, + floor(num_samples_processed / + learning_rate_decay_b)) + """ optimizer = paddle.optimizer.Adam( learning_rate=trainer_config.learning_rate, - gradient_clipping_threshold=50, - regularization=paddle.optimizer.L2Regularization(rate=5e-4), - model_average=paddle.optimizer.ModelAverage( - average_window=0.5, max_average_window=1000)) + gradient_clipping_threshold=trainer_config.gradient_clipping_threshold, + regularization=paddle.optimizer.L2Regularization( + trainer_config.l2_decay_rate), + learning_rate_decay_a=0.5, + learning_rate_decay_b=train_sample_count, + learning_rate_schedule="discexp") # define network topology loss = GNR(model_config) @@ -180,22 +214,14 @@ def train(model_config, trainer_config): load_initial_model(trainer_config.init_model_path, parameters) else: show_parameter_init_info(parameters) - # load the pre-trained embeddings - parameters.set("GloveVectors", - load_pretrained_parameters( - ModelConfig.pretrained_emb_path, - height=ModelConfig.vocab_size, - width=ModelConfig.embedding_dim)) - - trainer = paddle.trainer.SGD(cost=loss, - parameters=parameters, - update_equation=optimizer) - - train_reader, test_reader = build_reader(trainer_config.data_dir, - trainer_config.train_batch_size) - - event_handler = build_event_handler(trainer_config, parameters, trainer, - test_reader) + parameters.set( + "GloveVectors", + load_pretrained_parameters(ModelConfig.pretrained_emb_path)) + + trainer = paddle.trainer.SGD( + cost=loss, parameters=parameters, update_equation=optimizer) + + event_handler = build_event_handler(trainer_config, parameters, trainer) trainer.train( reader=train_reader, num_passes=trainer_config.epochs, -- GitLab