From 71745b889d6796c32b073d765b15b9c857e8359e Mon Sep 17 00:00:00 2001 From: ranqiu Date: Mon, 27 Nov 2017 17:25:19 +0800 Subject: [PATCH] Update conv_seq2seq --- conv_seq2seq/README.md | 40 +++--- conv_seq2seq/beamsearch.py | 243 +++++++++++++++++++++---------------- conv_seq2seq/download.sh | 22 ++++ conv_seq2seq/infer.py | 75 +++++++++--- conv_seq2seq/model.py | 67 ++++++---- conv_seq2seq/preprocess.py | 30 +++++ conv_seq2seq/reader.py | 4 +- conv_seq2seq/train.py | 17 ++- 8 files changed, 332 insertions(+), 166 deletions(-) create mode 100644 conv_seq2seq/download.sh create mode 100644 conv_seq2seq/preprocess.py diff --git a/conv_seq2seq/README.md b/conv_seq2seq/README.md index 75ea8770..920c6645 100644 --- a/conv_seq2seq/README.md +++ b/conv_seq2seq/README.md @@ -4,55 +4,63 @@ This model implements the work in the following paper: Jonas Gehring, Michael Auli, David Grangier, et al. Convolutional Sequence to Sequence Learning. Association for Computational Linguistics (ACL), 2017 # Data Preparation +- The data used in this tutorial can be downloaded by running: -- In this tutorial, each line in a data file contains one sample and each sample consists of a source sentence and a target sentence. And the two sentences are seperated by '\t'. So, to use your own data, it should be organized as follows: + ```bash + sh download.sh + ``` - ``` - <source sentence>\t<target sentence> - ``` +- Each line in the data file contains one sample, and each sample consists of a source sentence and a target sentence separated by '\t'. To use your own data, organize it as follows: + + ``` + <source sentence>\t<target sentence> + ``` # Training a Model - Modify the following script if needed and then run: - ```bash - python train.py \ - --train_data_path ./data/train_data \ - --test_data_path ./data/test_data \ + ```bash + python train.py \ + --train_data_path ./data/train \ + --test_data_path ./data/test \ --src_dict_path ./data/src_dict \ --trg_dict_path ./data/trg_dict \ --enc_blocks "[(256, 3)] * 5" \ --dec_blocks "[(256, 3)] * 3" \ --emb_size 256 \ --pos_size 200 \ - --drop_rate 0.1 \ + --drop_rate 0.2 \ + --use_bn False \ --use_gpu False \ --trainer_count 1 \ --batch_size 32 \ --num_passes 20 \ >train.log 2>&1 - ``` + ``` # Inferring by a Trained Model - Infer by a trained model by running: - ```bash - python infer.py \ - --infer_data_path ./data/infer_data \ + ```bash + python infer.py \ + --infer_data_path ./data/dev \ --src_dict_path ./data/src_dict \ --trg_dict_path ./data/trg_dict \ --enc_blocks "[(256, 3)] * 5" \ --dec_blocks "[(256, 3)] * 3" \ --emb_size 256 \ --pos_size 200 \ - --drop_rate 0.1 \ + --drop_rate 0.2 \ + --use_bn False \ --use_gpu False \ --trainer_count 1 \ --max_len 100 \ + --batch_size 256 \ --beam_size 1 \ + --is_show_attention False \ --model_path ./params.pass-0.tar.gz \ 1>infer_result 2>infer.log ``` # Notes - -Currently, beam search will forward the encoder multiple times when predicting each target word, which requires extra computations. And we will fix it later. +Since the current version of PaddlePaddle doesn't support weight normalization, we use batch normalization instead to ensure convergence when the network is deep.
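As a small aside on the data format above, here is a minimal sketch (not part of the patch) of how a `<source sentence>\t<target sentence>` file can be read into token lists; the helper name and the `./data/train` path (taken from the training command) are illustrative only:

```python
# Minimal sketch: read the tab-separated parallel corpus described in the
# README. Each line holds "<source sentence>\t<target sentence>", and both
# sides are plain space-separated tokens. `parse_corpus` is an illustrative
# helper, not part of the model code.
def parse_corpus(path):
    pairs = []
    with open(path) as f:
        for line in f:
            src, trg = line.rstrip('\n').split('\t', 1)
            pairs.append((src.split(), trg.split()))
    return pairs

# e.g. pairs = parse_corpus('./data/train')
```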
diff --git a/conv_seq2seq/beamsearch.py b/conv_seq2seq/beamsearch.py index 45656e80..22318c29 100644 --- a/conv_seq2seq/beamsearch.py +++ b/conv_seq2seq/beamsearch.py @@ -2,8 +2,11 @@ import sys import time +import math import numpy as np +import reader + class BeamSearch(object): """ @@ -16,44 +19,42 @@ class BeamSearch(object): trg_dict, pos_size, padding_num, + batch_size=1, beam_size=1, max_len=100): self.inferer = inferer self.trg_dict = trg_dict + self.reverse_trg_dict = reader.get_reverse_dict(trg_dict) self.word_padding = trg_dict.__len__() self.pos_size = pos_size self.pos_padding = pos_size self.padding_num = padding_num self.win_len = padding_num + 1 self.max_len = max_len + self.batch_size = batch_size self.beam_size = beam_size - def get_beam_input(self, pre_beam_list, infer_data): + def get_beam_input(self, batch, sample_list): """ Get input for generation at the current iteration. """ beam_input = [] - if len(pre_beam_list) == 0: - cur_trg = [self.word_padding - ] * self.padding_num + [self.trg_dict['<s>']] - cur_trg_pos = [self.pos_padding] * self.padding_num + [0] - beam_input.append(infer_data + [cur_trg] + [cur_trg_pos]) - else: - for seq in pre_beam_list: - if len(seq) < self.win_len: - cur_trg = [self.word_padding] * ( - self.win_len - len(seq) - 1 - ) + [self.trg_dict['<s>']] + seq - cur_trg_pos = [self.pos_padding] * ( - self.win_len - len(seq) - 1) + [0] + range(1, - len(seq) + 1) + for sample_id in sample_list: + for path in self.candidate_path[sample_id]: + if len(path['seq']) < self.win_len: + cur_trg = [self.word_padding] * (self.win_len - len( + path['seq']) - 1) + [self.trg_dict['<s>']] + path['seq'] + cur_trg_pos = [self.pos_padding] * (self.win_len - len( + path['seq']) - 1) + [0] + range(1, len(path['seq']) + 1) else: - cur_trg = seq[-self.win_len:] + cur_trg = path['seq'][-self.win_len:] cur_trg_pos = range( - len(seq) + 1 - self.win_len, len(seq) + 1) + len(path['seq']) + 1 - self.win_len, + len(path['seq']) + 1) + + beam_input.append(batch[sample_id] + [cur_trg] + [cur_trg_pos]) - beam_input.append(infer_data + [cur_trg] + [cur_trg_pos]) return beam_input def get_prob(self, beam_input): @@ -64,100 +65,136 @@ class BeamSearch(object): prob = self.inferer.infer(beam_input, field='value')[row_list, :] return prob - def get_candidate(self, pre_beam_list, pre_beam_score, prob): + def _top_k(self, prob, k): """ - Get top beam_size tokens and their scores for each beam. + Get indices of the words with the k highest probabilities. """ - if prob.ndim == 1: - candidate_id = prob.argsort()[-self.beam_size:][::-1] - candidate_log_prob = np.log(prob[candidate_id]) - else: - candidate_id = prob.argsort()[:, -self.beam_size:][:, ::-1] - candidate_log_prob = np.zeros_like(candidate_id).astype('float32') - for j in range(len(pre_beam_list)): - candidate_log_prob[j, :] = np.log(prob[j, candidate_id[j, :]]) - - if pre_beam_score.size > 0: - candidate_score = candidate_log_prob + pre_beam_score.reshape( - (pre_beam_score.size, 1)) - else: - candidate_score = candidate_log_prob - - return candidate_id, candidate_score - - def prune(self, candidate_id, candidate_score, pre_beam_list, - completed_seq_list, completed_seq_score, completed_seq_min_score): + return prob.argsort()[-k:][::-1] + + def beam_expand(self, prob, sample_list): """ - Pruning process of the beam search. During the process, beam_size most possible sequences - are selected for the beam in the next iteration. Besides, their scores and the minimum score - of the completed sequences are updated.
+ At each step of generation, the model predicts the possible next words. + For each candidate sequence, the top beam_size words are selected as candidates. """ - candidate_id = candidate_id.flatten() - candidate_score = candidate_score.flatten() - - topk_idx = candidate_score.argsort()[-self.beam_size:][::-1].tolist() - topk_seq_idx = [idx / self.beam_size for idx in topk_idx] - - next_beam = [] - beam_score = [] - for j in range(len(topk_idx)): - if candidate_id[topk_idx[j]] == self.trg_dict['<e>']: - if len( - completed_seq_list - ) < self.beam_size or completed_seq_min_score <= candidate_score[ - topk_idx[j]]: - completed_seq_list.append(pre_beam_list[topk_seq_idx[j]]) - completed_seq_score.append(candidate_score[topk_idx[j]]) - - if completed_seq_min_score is None or ( - completed_seq_min_score >= - candidate_score[topk_idx[j]] and - len(completed_seq_list) < self.beam_size): - completed_seq_min_score = candidate_score[topk_idx[j]] - else: - seq = pre_beam_list[topk_seq_idx[ - j]] + [candidate_id[topk_idx[j]]] - score = candidate_score[topk_idx[j]] - next_beam.append(seq) - beam_score.append(score) - - beam_score = np.array(beam_score) - return next_beam, beam_score, completed_seq_min_score - - def search_one_sample(self, infer_data): + top_words = np.apply_along_axis(self._top_k, 1, prob, self.beam_size) + + candidate_words = [[]] * len(self.candidate_path) + idx = 0 + + for sample_id in sample_list: + for seq_id, path in enumerate(self.candidate_path[sample_id]): + for w in top_words[idx, :]: + score = path['score'] + math.log(prob[idx, w]) + candidate_words[sample_id] = candidate_words[sample_id] + [ + { + 'word': w, + 'score': score, + 'seq_id': seq_id + } + ] + idx = idx + 1 + + return candidate_words + + def beam_shrink(self, candidate_words, sample_list): """ - Beam search process for one sample. + Pruning process of the beam search. During the process, the beam_size most probable + sequences are selected for the beam in the next iteration. """ - completed_seq_list = [] - completed_seq_score = [] - completed_seq_min_score = None - uncompleted_seq_list = [[]] - uncompleted_seq_score = np.zeros(0) + new_path = [[]] * len(self.candidate_path) + + for sample_id in sample_list: + beam_words = sorted( + candidate_words[sample_id], + key=lambda x: x['score'], + reverse=True)[:self.beam_size] + + complete_seq_min_score = None + complete_path_num = len(self.complete_path[sample_id]) + + if complete_path_num > 0: + complete_seq_min_score = min(self.complete_path[sample_id], + key=lambda x: x['score'])['score'] + if complete_path_num >= self.beam_size: + beam_words_max_score = beam_words[0]['score'] + if beam_words_max_score < complete_seq_min_score: + continue + + for w in beam_words: + + if w['word'] == self.trg_dict['<e>']: + if complete_path_num < self.beam_size or complete_seq_min_score <= w[ + 'score']: + + seq = self.candidate_path[sample_id][w['seq_id']]['seq'] + self.complete_path[sample_id] = self.complete_path[ + sample_id] + [{ + 'seq': seq, + 'score': w['score'] + }] + + if complete_seq_min_score is None or complete_seq_min_score > w[ + 'score']: + complete_seq_min_score = w['score'] + else: + seq = self.candidate_path[sample_id][w['seq_id']]['seq'] + [ + w['word'] + ] + new_path[sample_id] = new_path[sample_id] + [{ + 'seq': + seq, + 'score': + w['score'] + }] + + return new_path + + def search_one_batch(self, batch): + """ + Perform beam search on one mini-batch.
+ """ + real_size = len(batch) + self.candidate_path = [[{'seq': [], 'score': 0.}]] * real_size + self.complete_path = [[]] * real_size + sample_list = range(real_size) for i in xrange(self.max_len): - beam_input = self.get_beam_input(uncompleted_seq_list, infer_data) - + beam_input = self.get_beam_input(batch, sample_list) prob = self.get_prob(beam_input) - candidate_id, candidate_score = self.get_candidate( - uncompleted_seq_list, uncompleted_seq_score, prob) + candidate_words = self.beam_expand(prob, sample_list) + new_path = self.beam_shrink(candidate_words, sample_list) + self.candidate_path = new_path + sample_list = [ + sample_id for sample_id in sample_list + if len(new_path[sample_id]) > 0 + ] - uncompleted_seq_list, uncompleted_seq_score, completed_seq_min_score = self.prune( - candidate_id, candidate_score, uncompleted_seq_list, - completed_seq_list, completed_seq_score, - completed_seq_min_score) - - if len(uncompleted_seq_list) == 0: + if len(sample_list) == 0: break - if len(completed_seq_list) >= self.beam_size: - seq_max_score = uncompleted_seq_score.max() - if seq_max_score < completed_seq_min_score: - uncompleted_seq_list = [] - break - - final_seq_list = completed_seq_list + uncompleted_seq_list - final_score = np.concatenate( - (np.array(completed_seq_score), uncompleted_seq_score)) - max_id = final_score.argmax() - top_seq = final_seq_list[max_id] - return top_seq + + final_path = [] + for i in xrange(real_size): + top_path = sorted( + self.complete_path[i] + self.candidate_path[i], + key=lambda x: x['score'], + reverse=True)[:self.beam_size] + final_path.append(top_path) + return final_path + + def search(self, infer_data): + """ + Perform beam search on all data. + """ + + def _to_sentence(seq): + raw_sentence = [self.reverse_trg_dict[id] for id in seq] + sentence = " ".join(raw_sentence) + return sentence + + for pos in xrange(0, len(infer_data), self.batch_size): + batch = infer_data[pos:min(pos + self.batch_size, len(infer_data))] + self.final_path = self.search_one_batch(batch) + for top_path in self.final_path: + print _to_sentence(top_path[0]['seq']) + sys.stdout.flush() diff --git a/conv_seq2seq/download.sh b/conv_seq2seq/download.sh new file mode 100644 index 00000000..b1a924d2 --- /dev/null +++ b/conv_seq2seq/download.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +CUR_PATH=`pwd` +git clone https://github.com/moses-smt/mosesdecoder.git +git clone https://github.com/rizar/actor-critic-public + +export MOSES=`pwd`/mosesdecoder +export LVSR=`pwd`/actor-critic-public + +cd actor-critic-public/exp/ted +sh create_dataset.sh + +cd $CUR_PATH +mkdir data +cp actor-critic-public/exp/ted/prep/*-* data/ +cp actor-critic-public/exp/ted/vocab.* data/ + +cd data +python ../preprocess.py + +cd .. +rm -rf actor-critic-public mosesdecoder diff --git a/conv_seq2seq/infer.py b/conv_seq2seq/infer.py index eb46df55..c804a84e 100644 --- a/conv_seq2seq/infer.py +++ b/conv_seq2seq/infer.py @@ -36,7 +36,7 @@ def parse_args(): parser.add_argument( '--emb_size', type=int, - default=512, + default=256, help='Dimension of word embedding. (default: %(default)s)') parser.add_argument( '--pos_size', @@ -48,6 +48,11 @@ def parse_args(): type=float, default=0., help='Dropout rate. (default: %(default)s)') + parser.add_argument( + "--use_bn", + default=False, + type=distutils.util.strtobool, + help="Use batch normalization or not. 
(default: %(default)s)") parser.add_argument( "--use_gpu", default=False, @@ -64,36 +69,43 @@ def parse_args(): default=100, help="The maximum length of the sentence to be generated. (default: %(default)s)" ) + parser.add_argument( + "--batch_size", + default=1, + type=int, + help="Size of a mini-batch. (default: %(default)s)") parser.add_argument( "--beam_size", default=1, type=int, - help="The width of beam expasion. (default: %(default)s)") + help="The width of beam expansion. (default: %(default)s)") parser.add_argument( "--model_path", type=str, required=True, help="The path of trained model. (default: %(default)s)") + parser.add_argument( + "--is_show_attention", + default=False, + type=distutils.util.strtobool, + help="Whether to show attention weight or not. (default: %(default)s)") return parser.parse_args() -def to_sentence(seq, dictionary): - raw_sentence = [dictionary[id] for id in seq] - sentence = " ".join(raw_sentence) - return sentence - - def infer(infer_data_path, src_dict_path, trg_dict_path, model_path, enc_conv_blocks, dec_conv_blocks, - emb_dim=512, + emb_dim=256, pos_size=200, drop_rate=0., + use_bn=False, max_len=100, - beam_size=1): + batch_size=1, + beam_size=1, + is_show_attention=False): """ Inference. @@ -120,10 +132,14 @@ def infer(infer_data_path, :type pos_size: int :param drop_rate: Dropout rate. :type drop_rate: float + :param use_bn: Whether to use batch normalization or not. False is the default value. + :type use_bn: bool :param max_len: The maximum length of the sentence to be generated. :type max_len: int :param beam_size: The width of beam expansion. :type beam_size: int + :param is_show_attention: Whether to show attention weight or not. False is the default value. + :type is_show_attention: bool """ # load dict src_dict = reader.load_dict(src_dict_path) @@ -131,7 +147,7 @@ def infer(infer_data_path, src_dict_size = src_dict.__len__() trg_dict_size = trg_dict.__len__() - prob = conv_seq2seq( + prob, weight = conv_seq2seq( src_dict_size=src_dict_size, trg_dict_size=trg_dict_size, pos_size=pos_size, @@ -139,6 +155,7 @@ def infer(infer_data_path, enc_conv_blocks=enc_conv_blocks, dec_conv_blocks=dec_conv_blocks, drop_rate=drop_rate, + with_bn=use_bn, is_infer=True) # load parameters @@ -153,6 +170,26 @@ def infer(infer_data_path, pos_size=pos_size, padding_num=padding_num) + if is_show_attention: + attention_inferer = paddle.inference.Inference( + output_layer=weight, parameters=parameters) + for i, data in enumerate(infer_reader()): + src_len = len(data[0]) + trg_len = len(data[2]) + attention_weight = attention_inferer.infer( + [data], field='value', flatten_result=False) + attention_weight = [ + weight.reshape((trg_len, src_len)) + for weight in attention_weight + ] + print attention_weight + break + return + + infer_data = [] + for i, raw_data in enumerate(infer_reader()): + infer_data.append([raw_data[0], raw_data[1]]) + inferer = paddle.inference.Inference( output_layer=prob, parameters=parameters) @@ -162,15 +199,10 @@ def infer(infer_data_path, pos_size=pos_size, padding_num=padding_num, max_len=max_len, + batch_size=batch_size, beam_size=beam_size) - reverse_trg_dict = reader.get_reverse_dict(trg_dict) - for i, raw_data in enumerate(infer_reader()): - infer_data = [raw_data[0], raw_data[1]] - result = searcher.search_one_sample(infer_data) - sentence = to_sentence(result, reverse_trg_dict) - print sentence - sys.stdout.flush() + searcher.search(infer_data) return @@ -179,6 +211,8 @@ def main(): enc_conv_blocks = eval(args.enc_blocks) 
dec_conv_blocks = eval(args.dec_blocks) + sys.setrecursionlimit(10000) + paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) infer( @@ -191,8 +225,11 @@ def main(): emb_dim=args.emb_size, pos_size=args.pos_size, drop_rate=args.drop_rate, + use_bn=args.use_bn, max_len=args.max_len, - beam_size=args.beam_size) + batch_size=args.batch_size, + beam_size=args.beam_size, + is_show_attention=args.is_show_attention) if __name__ == '__main__': diff --git a/conv_seq2seq/model.py b/conv_seq2seq/model.py index 85f23862..21813af4 100644 --- a/conv_seq2seq/model.py +++ b/conv_seq2seq/model.py @@ -12,7 +12,8 @@ def gated_conv_with_batchnorm(input, context_len, context_start=None, learning_rate=1.0, - drop_rate=0.): + drop_rate=0., + with_bn=False): """ Definition of the convolution block. @@ -30,6 +31,9 @@ def gated_conv_with_batchnorm(input, :type learning_rate: float :param drop_rate: Dropout rate. :type drop_rate: float + :param with_bn: Whether to use batch normalization or not. False is the default + value. + :type with_bn: bool :return: The output of the convolution block. :rtype: LayerOutput """ @@ -50,18 +54,18 @@ def gated_conv_with_batchnorm(input, learning_rate=learning_rate), bias_attr=False) - batch_norm_conv = paddle.layer.batch_norm( - input=raw_conv, - act=paddle.activation.Linear(), - param_attr=paddle.attr.Param(learning_rate=learning_rate)) + if with_bn: + raw_conv = paddle.layer.batch_norm( + input=raw_conv, + act=paddle.activation.Linear(), + param_attr=paddle.attr.Param(learning_rate=learning_rate)) with paddle.layer.mixed(size=size) as conv: - conv += paddle.layer.identity_projection( - batch_norm_conv, size=size, offset=0) + conv += paddle.layer.identity_projection(raw_conv, size=size, offset=0) with paddle.layer.mixed(size=size, act=paddle.activation.Sigmoid()) as gate: gate += paddle.layer.identity_projection( - batch_norm_conv, size=size, offset=size) + raw_conv, size=size, offset=size) with paddle.layer.mixed(size=size) as gated_conv: gated_conv += paddle.layer.dotmul_operator(conv, gate) @@ -73,7 +77,8 @@ def encoder(token_emb, pos_emb, conv_blocks=[(256, 3)] * 5, num_attention=3, - drop_rate=0.1): + drop_rate=0., + with_bn=False): """ Definition of the encoder. @@ -89,6 +94,9 @@ def encoder(token_emb, :type num_attention: int :param drop_rate: Dropout rate. :type drop_rate: float + :param with_bn: Whether to use batch normalization or not. False is the default + value. + :type with_bn: bool :return: The input token encoding. :rtype: LayerOutput """ @@ -124,7 +132,8 @@ def encoder(token_emb, size=size, context_len=context_len, learning_rate=1.0 / (2.0 * num_attention), - drop_rate=drop_rate) + drop_rate=drop_rate, + with_bn=with_bn) with paddle.layer.mixed(size=size) as block_output: block_output += paddle.layer.identity_projection(residual) @@ -165,7 +174,7 @@ def attention(decoder_state, cur_embedding, encoded_vec, encoded_sum): :type encoded_vec: LayerOutput :param encoded_sum: The sum of the source token's encoding and embedding. :type encoded_sum: LayerOutput - :return: A context vector. + :return: A context vector and the attention weight. 
:rtype: LayerOutput """ residual = decoder_state @@ -182,7 +191,7 @@ def attention(decoder_state, cur_embedding, encoded_vec, encoded_sum): expanded = paddle.layer.expand(input=state_summary, expand_as=encoded_vec) - m = paddle.layer.linear_comb(weights=expanded, vectors=encoded_vec) + m = paddle.layer.dot_prod(input1=expanded, input2=encoded_vec) attention_weight = paddle.layer.fc( input=m, @@ -206,7 +215,7 @@ def attention(decoder_state, cur_embedding, encoded_vec, encoded_sum): # halve the variance of the sum attention_result = paddle.layer.slope_intercept( input=attention_result, slope=math.sqrt(0.5)) - return attention_result + return attention_result, attention_weight def decoder(token_emb, @@ -215,7 +224,8 @@ def decoder(token_emb, encoded_sum, dict_size, conv_blocks=[(256, 3)] * 3, - drop_rate=0.1): + drop_rate=0., + with_bn=False): """ Definition of the decoder. @@ -235,7 +245,10 @@ def decoder(token_emb, :type conv_blocks: list of tuple :param drop_rate: Dropout rate. :type drop_rate: float - :return: The probability of the predicted token. + :param with_bn: Whether to use batch normalization or not. False is the default + value. + :type with_bn: bool + :return: The probability of the predicted token and the attention weights. :rtype: LayerOutput """ @@ -261,6 +274,7 @@ def decoder(token_emb, initial_std=math.sqrt((1.0 - drop_rate) / embedding.size)), bias_attr=True, ) + weight = [] for (size, context_len) in conv_blocks: if block_input.size == size: residual = block_input @@ -276,7 +290,8 @@ def decoder(token_emb, size=size, context_len=context_len, context_start=0, - drop_rate=drop_rate) + drop_rate=drop_rate, + with_bn=with_bn) group_inputs = [ decoder_state, @@ -285,8 +300,9 @@ def decoder(token_emb, paddle.layer.StaticInput(input=encoded_sum), ] - conditional = paddle.layer.recurrent_group( + conditional, attention_weight = paddle.layer.recurrent_group( step=attention_step, input=group_inputs) + weight.append(attention_weight) block_output = paddle.layer.addto(input=[conditional, residual]) @@ -312,7 +328,7 @@ def decoder(token_emb, initial_std=math.sqrt((1.0 - drop_rate) / block_output.size)), bias_attr=True) - return decoder_out + return decoder_out, weight def conv_seq2seq(src_dict_size, @@ -321,7 +337,8 @@ def conv_seq2seq(src_dict_size, emb_dim, enc_conv_blocks=[(256, 3)] * 5, dec_conv_blocks=[(256, 3)] * 3, - drop_rate=0.1, + drop_rate=0., + with_bn=False, is_infer=False): """ Definition of convolutional sequence-to-sequence network. @@ -345,6 +362,8 @@ def conv_seq2seq(src_dict_size, :type dec_conv_blocks: list of tuple :param drop_rate: Dropout rate. :type drop_rate: float + :param with_bn: Whether to use batch normalization or not. False is the default value. + :type with_bn: bool :param is_infer: Whether infer or not. :type is_infer: bool :return: Cost or output layer. 
@@ -375,7 +394,8 @@ pos_emb=src_pos_emb, conv_blocks=enc_conv_blocks, num_attention=num_attention, - drop_rate=drop_rate) + drop_rate=drop_rate, + with_bn=with_bn) trg = paddle.layer.data( name='trg_word', @@ -397,17 +417,18 @@ name='trg_pos_emb', param_attr=paddle.attr.Param(initial_mean=0., initial_std=0.1)) - decoder_out = decoder( + decoder_out, weight = decoder( token_emb=trg_emb, pos_emb=trg_pos_emb, encoded_vec=encoded_vec, encoded_sum=encoded_sum, dict_size=trg_dict_size, conv_blocks=dec_conv_blocks, - drop_rate=drop_rate) + drop_rate=drop_rate, + with_bn=with_bn) if is_infer: - return decoder_out + return decoder_out, weight trg_next_word = paddle.layer.data( name='trg_next_word', diff --git a/conv_seq2seq/preprocess.py b/conv_seq2seq/preprocess.py new file mode 100644 index 00000000..1d5c7cdd --- /dev/null +++ b/conv_seq2seq/preprocess.py @@ -0,0 +1,30 @@ +#coding=utf-8 + +import cPickle + + +def concat_file(file1, file2, dst_file): + with open(dst_file, 'w') as dst: + with open(file1) as f1: + with open(file2) as f2: + for i, (line1, line2) in enumerate(zip(f1, f2)): + line1 = line1.strip() + line = line1 + '\t' + line2 + dst.write(line) + + +if __name__ == '__main__': + concat_file('dev.de-en.de', 'dev.de-en.en', 'dev') + concat_file('test.de-en.de', 'test.de-en.en', 'test') + concat_file('train.de-en.de', 'train.de-en.en', 'train') + + src_dict = cPickle.load(open('vocab.de')) + trg_dict = cPickle.load(open('vocab.en')) + + with open('src_dict', 'w') as f: + f.write('<s>\n<e>\nUNK\n') + f.writelines('\n'.join(src_dict.keys())) + + with open('trg_dict', 'w') as f: + f.write('<s>\n<e>\nUNK\n') + f.writelines('\n'.join(trg_dict.keys())) diff --git a/conv_seq2seq/reader.py b/conv_seq2seq/reader.py index 6d4db49f..ad420af5 100644 --- a/conv_seq2seq/reader.py +++ b/conv_seq2seq/reader.py @@ -18,7 +18,7 @@ def get_reverse_dict(dictionary): def load_data(data_file, src_dict, trg_dict): - UNK_IDX = src_dict['<unk>'] + UNK_IDX = src_dict['UNK'] with open(data_file, 'r') as f: for line in f: line_split = line.strip().split('\t') @@ -34,7 +34,7 @@ def load_data(data_file, src_dict, trg_dict): def data_reader(data_file, src_dict, trg_dict, pos_size, padding_num): def reader(): - UNK_IDX = src_dict['<unk>'] + UNK_IDX = src_dict['UNK'] word_padding = trg_dict.__len__() pos_padding = pos_size diff --git a/conv_seq2seq/train.py b/conv_seq2seq/train.py index c6ce0dff..e23d9625 100644 --- a/conv_seq2seq/train.py +++ b/conv_seq2seq/train.py @@ -40,7 +40,7 @@ def parse_args(): parser.add_argument( '--emb_size', type=int, - default=512, + default=256, help='Dimension of word embedding. (default: %(default)s)') parser.add_argument( '--pos_size', @@ -52,6 +52,11 @@ def parse_args(): type=float, default=0., help='Dropout rate. (default: %(default)s)') + parser.add_argument( + "--use_bn", + default=False, + type=distutils.util.strtobool, + help="Use batch normalization or not. (default: %(default)s)") parser.add_argument( "--use_gpu", default=False, @@ -116,9 +121,10 @@ def train(train_data_path, trg_dict_path, enc_conv_blocks, dec_conv_blocks, - emb_dim=512, + emb_dim=256, pos_size=200, drop_rate=0., + use_bn=False, batch_size=32, num_passes=15): """ @@ -147,6 +153,8 @@ def train(train_data_path, :type pos_size: int :param drop_rate: Dropout rate. :type drop_rate: float + :param use_bn: Whether to use batch normalization or not. False is the default value. + :type use_bn: bool :param batch_size: The size of a mini-batch.
:type batch_size: int :param num_passes: The total number of the passes to train. @@ -169,6 +177,7 @@ def train(train_data_path, enc_conv_blocks=enc_conv_blocks, dec_conv_blocks=dec_conv_blocks, drop_rate=drop_rate, + with_bn=use_bn, is_infer=False) # create parameters and trainer @@ -203,7 +212,6 @@ def train(train_data_path, print "[%s]: Pass: %d, Batch: %d, TrainCost: %f, %s" % ( cur_time, event.pass_id, event.batch_id, event.cost, event.metrics) - else: sys.stdout.flush() if isinstance(event, paddle.event.EndPass): @@ -232,6 +240,8 @@ def main(): enc_conv_blocks = eval(args.enc_blocks) dec_conv_blocks = eval(args.dec_blocks) + sys.setrecursionlimit(10000) + paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) train( @@ -244,6 +254,7 @@ def main(): emb_dim=args.emb_size, pos_size=args.pos_size, drop_rate=args.drop_rate, + use_bn=args.use_bn, batch_size=args.batch_size, num_passes=args.num_passes) -- GitLab
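As a reading aid for the new `beam_expand`/`beam_shrink` methods in `beamsearch.py`, the sketch below condenses one expand-then-prune step for a single sample under simplified assumptions (no batching, and finished paths are merely separated out rather than tracked with the patch's min-score early stop); all names here are illustrative, not the patch's API:

```python
import math

def expand_and_prune(paths, prob, beam_size, eos_id):
    """One beam-search step for a single sample. `paths` is a list of
    {'seq': [...], 'score': float} candidates; `prob` is a numpy array of
    shape (len(paths), vocab_size) holding next-word probabilities."""
    expanded = []
    for i, path in enumerate(paths):
        # Same top-k trick as _top_k: argsort ascending, take the tail, reverse.
        top_words = prob[i].argsort()[-beam_size:][::-1]
        for w in top_words:
            expanded.append({'seq': path['seq'] + [int(w)],
                             'score': path['score'] + math.log(prob[i, w])})
    # Keep the beam_size best extensions by cumulative log-probability,
    # separating finished (ending in <e>) from unfinished paths.
    expanded.sort(key=lambda p: p['score'], reverse=True)
    kept = expanded[:beam_size]
    finished = [p for p in kept if p['seq'][-1] == eos_id]
    unfinished = [p for p in kept if p['seq'][-1] != eos_id]
    return unfinished, finished
```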