未验证 提交 3bec90d5 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #493 from ranqiu92/convseq2seq

Update conv_seq2seq.
......@@ -4,8 +4,13 @@ This model implements the work in the following paper:
Jonas Gehring, Micheal Auli, David Grangier, et al. Convolutional Sequence to Sequence Learning. Association for Computational Linguistics (ACL), 2017
# Data Preparation
- The data used in this tutorial can be downloaded by runing:
- In this tutorial, each line in a data file contains one sample and each sample consists of a source sentence and a target sentence. And the two sentences are seperated by '\t'. So, to use your own data, it should be organized as follows:
```bash
sh download.sh
```
- Each line in the data file contains one sample and each sample consists of a source sentence and a target sentence. And the two sentences are seperated by '\t'. So, to use your own data, it should be organized as follows:
```
<source sentence>\t<target sentence>
......@@ -16,15 +21,16 @@ Jonas Gehring, Micheal Auli, David Grangier, et al. Convolutional Sequence to Se
```bash
python train.py \
--train_data_path ./data/train_data \
--test_data_path ./data/test_data \
--train_data_path ./data/train \
--test_data_path ./data/test \
--src_dict_path ./data/src_dict \
--trg_dict_path ./data/trg_dict \
--enc_blocks "[(256, 3)] * 5" \
--dec_blocks "[(256, 3)] * 3" \
--emb_size 256 \
--pos_size 200 \
--drop_rate 0.1 \
--drop_rate 0.2 \
--use_bn False \
--use_gpu False \
--trainer_count 1 \
--batch_size 32 \
......@@ -37,22 +43,24 @@ Jonas Gehring, Micheal Auli, David Grangier, et al. Convolutional Sequence to Se
```bash
python infer.py \
--infer_data_path ./data/infer_data \
--infer_data_path ./data/dev \
--src_dict_path ./data/src_dict \
--trg_dict_path ./data/trg_dict \
--enc_blocks "[(256, 3)] * 5" \
--dec_blocks "[(256, 3)] * 3" \
--emb_size 256 \
--pos_size 200 \
--drop_rate 0.1 \
--drop_rate 0.2 \
--use_bn False \
--use_gpu False \
--trainer_count 1 \
--max_len 100 \
--batch_size 256 \
--beam_size 1 \
--is_show_attention False \
--model_path ./params.pass-0.tar.gz \
1>infer_result 2>infer.log
```
# Notes
Currently, beam search will forward the encoder multiple times when predicting each target word, which requires extra computations. And we will fix it later.
Since PaddlePaddle of current version doesn't support weight normalization, we use batch normalization instead to confirm convergence when the network is deep.
......@@ -2,8 +2,11 @@
import sys
import time
import math
import numpy as np
import reader
class BeamSearch(object):
"""
......@@ -16,44 +19,42 @@ class BeamSearch(object):
trg_dict,
pos_size,
padding_num,
batch_size=1,
beam_size=1,
max_len=100):
self.inferer = inferer
self.trg_dict = trg_dict
self.reverse_trg_dict = reader.get_reverse_dict(trg_dict)
self.word_padding = trg_dict.__len__()
self.pos_size = pos_size
self.pos_padding = pos_size
self.padding_num = padding_num
self.win_len = padding_num + 1
self.max_len = max_len
self.batch_size = batch_size
self.beam_size = beam_size
def get_beam_input(self, pre_beam_list, infer_data):
def get_beam_input(self, batch, sample_list):
"""
Get input for generation at the current iteration.
"""
beam_input = []
if len(pre_beam_list) == 0:
cur_trg = [self.word_padding
] * self.padding_num + [self.trg_dict['<s>']]
cur_trg_pos = [self.pos_padding] * self.padding_num + [0]
beam_input.append(infer_data + [cur_trg] + [cur_trg_pos])
else:
for seq in pre_beam_list:
if len(seq) < self.win_len:
cur_trg = [self.word_padding] * (
self.win_len - len(seq) - 1
) + [self.trg_dict['<s>']] + seq
cur_trg_pos = [self.pos_padding] * (
self.win_len - len(seq) - 1) + [0] + range(1,
len(seq) + 1)
for sample_id in sample_list:
for path in self.candidate_path[sample_id]:
if len(path['seq']) < self.win_len:
cur_trg = [self.word_padding] * (self.win_len - len(
path['seq']) - 1) + [self.trg_dict['<s>']] + path['seq']
cur_trg_pos = [self.pos_padding] * (self.win_len - len(
path['seq']) - 1) + [0] + range(1, len(path['seq']) + 1)
else:
cur_trg = seq[-self.win_len:]
cur_trg = path['seq'][-self.win_len:]
cur_trg_pos = range(
len(seq) + 1 - self.win_len, len(seq) + 1)
len(path['seq']) + 1 - self.win_len,
len(path['seq']) + 1)
beam_input.append(batch[sample_id] + [cur_trg] + [cur_trg_pos])
beam_input.append(infer_data + [cur_trg] + [cur_trg_pos])
return beam_input
def get_prob(self, beam_input):
......@@ -64,100 +65,136 @@ class BeamSearch(object):
prob = self.inferer.infer(beam_input, field='value')[row_list, :]
return prob
def get_candidate(self, pre_beam_list, pre_beam_score, prob):
def _top_k(self, prob, k):
"""
Get top beam_size tokens and their scores for each beam.
Get indices of the words with k highest probablities.
"""
if prob.ndim == 1:
candidate_id = prob.argsort()[-self.beam_size:][::-1]
candidate_log_prob = np.log(prob[candidate_id])
else:
candidate_id = prob.argsort()[:, -self.beam_size:][:, ::-1]
candidate_log_prob = np.zeros_like(candidate_id).astype('float32')
for j in range(len(pre_beam_list)):
candidate_log_prob[j, :] = np.log(prob[j, candidate_id[j, :]])
if pre_beam_score.size > 0:
candidate_score = candidate_log_prob + pre_beam_score.reshape(
(pre_beam_score.size, 1))
else:
candidate_score = candidate_log_prob
return candidate_id, candidate_score
def prune(self, candidate_id, candidate_score, pre_beam_list,
completed_seq_list, completed_seq_score, completed_seq_min_score):
"""
Pruning process of the beam search. During the process, beam_size most possible sequences
are selected for the beam in the next iteration. Besides, their scores and the minimum score
of the completed sequences are updated.
"""
candidate_id = candidate_id.flatten()
candidate_score = candidate_score.flatten()
topk_idx = candidate_score.argsort()[-self.beam_size:][::-1].tolist()
topk_seq_idx = [idx / self.beam_size for idx in topk_idx]
next_beam = []
beam_score = []
for j in range(len(topk_idx)):
if candidate_id[topk_idx[j]] == self.trg_dict['<e>']:
if len(
completed_seq_list
) < self.beam_size or completed_seq_min_score <= candidate_score[
topk_idx[j]]:
completed_seq_list.append(pre_beam_list[topk_seq_idx[j]])
completed_seq_score.append(candidate_score[topk_idx[j]])
if completed_seq_min_score is None or (
completed_seq_min_score >=
candidate_score[topk_idx[j]] and
len(completed_seq_list) < self.beam_size):
completed_seq_min_score = candidate_score[topk_idx[j]]
else:
seq = pre_beam_list[topk_seq_idx[
j]] + [candidate_id[topk_idx[j]]]
score = candidate_score[topk_idx[j]]
next_beam.append(seq)
beam_score.append(score)
beam_score = np.array(beam_score)
return next_beam, beam_score, completed_seq_min_score
return prob.argsort()[-k:][::-1]
def search_one_sample(self, infer_data):
def beam_expand(self, prob, sample_list):
"""
In every iteration step, the model predicts the possible next words.
For each input sentence, the top beam_size words are selected as candidates.
"""
Beam search process for one sample.
top_words = np.apply_along_axis(self._top_k, 1, prob, self.beam_size)
candidate_words = [[]] * len(self.candidate_path)
idx = 0
for sample_id in sample_list:
for seq_id, path in enumerate(self.candidate_path[sample_id]):
for w in top_words[idx, :]:
score = path['score'] + math.log(prob[idx, w])
candidate_words[sample_id] = candidate_words[sample_id] + [
{
'word': w,
'score': score,
'seq_id': seq_id
}
]
idx = idx + 1
return candidate_words
def beam_shrink(self, candidate_words, sample_list):
"""
completed_seq_list = []
completed_seq_score = []
completed_seq_min_score = None
uncompleted_seq_list = [[]]
uncompleted_seq_score = np.zeros(0)
Pruning process of the beam search. During the process, beam_size most post possible
sequences are selected for the beam in the next generation.
"""
new_path = [[]] * len(self.candidate_path)
for sample_id in sample_list:
beam_words = sorted(
candidate_words[sample_id],
key=lambda x: x['score'],
reverse=True)[:self.beam_size]
complete_seq_min_score = None
complete_path_num = len(self.complete_path[sample_id])
if complete_path_num > 0:
complete_seq_min_score = min(self.complete_path[sample_id],
key=lambda x: x['score'])['score']
if complete_path_num >= self.beam_size:
beam_words_max_score = beam_words[0]['score']
if beam_words_max_score < complete_seq_min_score:
continue
for w in beam_words:
if w['word'] == self.trg_dict['<e>']:
if complete_path_num < self.beam_size or complete_seq_min_score <= w[
'score']:
seq = self.candidate_path[sample_id][w['seq_id']]['seq']
self.complete_path[sample_id] = self.complete_path[
sample_id] + [{
'seq': seq,
'score': w['score']
}]
if complete_seq_min_score is None or complete_seq_min_score > w[
'score']:
complete_seq_min_score = w['score']
else:
seq = self.candidate_path[sample_id][w['seq_id']]['seq'] + [
w['word']
]
new_path[sample_id] = new_path[sample_id] + [{
'seq':
seq,
'score':
w['score']
}]
return new_path
def search_one_batch(self, batch):
"""
Perform beam search on one mini-batch.
"""
real_size = len(batch)
self.candidate_path = [[{'seq': [], 'score': 0.}]] * real_size
self.complete_path = [[]] * real_size
sample_list = range(real_size)
for i in xrange(self.max_len):
beam_input = self.get_beam_input(uncompleted_seq_list, infer_data)
beam_input = self.get_beam_input(batch, sample_list)
prob = self.get_prob(beam_input)
candidate_id, candidate_score = self.get_candidate(
uncompleted_seq_list, uncompleted_seq_score, prob)
candidate_words = self.beam_expand(prob, sample_list)
new_path = self.beam_shrink(candidate_words, sample_list)
self.candidate_path = new_path
sample_list = [
sample_id for sample_id in sample_list
if len(new_path[sample_id]) > 0
]
uncompleted_seq_list, uncompleted_seq_score, completed_seq_min_score = self.prune(
candidate_id, candidate_score, uncompleted_seq_list,
completed_seq_list, completed_seq_score,
completed_seq_min_score)
if len(uncompleted_seq_list) == 0:
break
if len(completed_seq_list) >= self.beam_size:
seq_max_score = uncompleted_seq_score.max()
if seq_max_score < completed_seq_min_score:
uncompleted_seq_list = []
if len(sample_list) == 0:
break
final_seq_list = completed_seq_list + uncompleted_seq_list
final_score = np.concatenate(
(np.array(completed_seq_score), uncompleted_seq_score))
max_id = final_score.argmax()
top_seq = final_seq_list[max_id]
return top_seq
final_path = []
for i in xrange(real_size):
top_path = sorted(
self.complete_path[i] + self.candidate_path[i],
key=lambda x: x['score'],
reverse=True)[:self.beam_size]
final_path.append(top_path)
return final_path
def search(self, infer_data):
"""
Perform beam search on all data.
"""
def _to_sentence(seq):
raw_sentence = [self.reverse_trg_dict[id] for id in seq]
sentence = " ".join(raw_sentence)
return sentence
for pos in xrange(0, len(infer_data), self.batch_size):
batch = infer_data[pos:min(pos + self.batch_size, len(infer_data))]
self.final_path = self.search_one_batch(batch)
for top_path in self.final_path:
print _to_sentence(top_path[0]['seq'])
sys.stdout.flush()
#!/usr/bin/env bash
CUR_PATH=`pwd`
git clone https://github.com/moses-smt/mosesdecoder.git
git clone https://github.com/rizar/actor-critic-public
export MOSES=`pwd`/mosesdecoder
export LVSR=`pwd`/actor-critic-public
cd actor-critic-public/exp/ted
sh create_dataset.sh
cd $CUR_PATH
mkdir data
cp actor-critic-public/exp/ted/prep/*-* data/
cp actor-critic-public/exp/ted/vocab.* data/
cd data
python ../preprocess.py
cd ..
rm -rf actor-critic-public mosesdecoder
......@@ -36,7 +36,7 @@ def parse_args():
parser.add_argument(
'--emb_size',
type=int,
default=512,
default=256,
help='Dimension of word embedding. (default: %(default)s)')
parser.add_argument(
'--pos_size',
......@@ -48,6 +48,11 @@ def parse_args():
type=float,
default=0.,
help='Dropout rate. (default: %(default)s)')
parser.add_argument(
"--use_bn",
default=False,
type=distutils.util.strtobool,
help="Use batch normalization or not. (default: %(default)s)")
parser.add_argument(
"--use_gpu",
default=False,
......@@ -64,36 +69,43 @@ def parse_args():
default=100,
help="The maximum length of the sentence to be generated. (default: %(default)s)"
)
parser.add_argument(
"--batch_size",
default=1,
type=int,
help="Size of a mini-batch. (default: %(default)s)")
parser.add_argument(
"--beam_size",
default=1,
type=int,
help="The width of beam expasion. (default: %(default)s)")
help="The width of beam expansion. (default: %(default)s)")
parser.add_argument(
"--model_path",
type=str,
required=True,
help="The path of trained model. (default: %(default)s)")
parser.add_argument(
"--is_show_attention",
default=False,
type=distutils.util.strtobool,
help="Whether to show attention weight or not. (default: %(default)s)")
return parser.parse_args()
def to_sentence(seq, dictionary):
raw_sentence = [dictionary[id] for id in seq]
sentence = " ".join(raw_sentence)
return sentence
def infer(infer_data_path,
src_dict_path,
trg_dict_path,
model_path,
enc_conv_blocks,
dec_conv_blocks,
emb_dim=512,
emb_dim=256,
pos_size=200,
drop_rate=0.,
use_bn=False,
max_len=100,
beam_size=1):
batch_size=1,
beam_size=1,
is_show_attention=False):
"""
Inference.
......@@ -120,10 +132,14 @@ def infer(infer_data_path,
:type pos_size: int
:param drop_rate: Dropout rate.
:type drop_rate: float
:param use_bn: Whether to use batch normalization or not. False is the default value.
:type use_bn: bool
:param max_len: The maximum length of the sentence to be generated.
:type max_len: int
:param beam_size: The width of beam expansion.
:type beam_size: int
:param is_show_attention: Whether to show attention weight or not. False is the default value.
:type is_show_attention: bool
"""
# load dict
src_dict = reader.load_dict(src_dict_path)
......@@ -131,7 +147,7 @@ def infer(infer_data_path,
src_dict_size = src_dict.__len__()
trg_dict_size = trg_dict.__len__()
prob = conv_seq2seq(
prob, weight = conv_seq2seq(
src_dict_size=src_dict_size,
trg_dict_size=trg_dict_size,
pos_size=pos_size,
......@@ -139,6 +155,7 @@ def infer(infer_data_path,
enc_conv_blocks=enc_conv_blocks,
dec_conv_blocks=dec_conv_blocks,
drop_rate=drop_rate,
with_bn=use_bn,
is_infer=True)
# load parameters
......@@ -153,6 +170,26 @@ def infer(infer_data_path,
pos_size=pos_size,
padding_num=padding_num)
if is_show_attention:
attention_inferer = paddle.inference.Inference(
output_layer=weight, parameters=parameters)
for i, data in enumerate(infer_reader()):
src_len = len(data[0])
trg_len = len(data[2])
attention_weight = attention_inferer.infer(
[data], field='value', flatten_result=False)
attention_weight = [
weight.reshape((trg_len, src_len))
for weight in attention_weight
]
print attention_weight
break
return
infer_data = []
for i, raw_data in enumerate(infer_reader()):
infer_data.append([raw_data[0], raw_data[1]])
inferer = paddle.inference.Inference(
output_layer=prob, parameters=parameters)
......@@ -162,15 +199,10 @@ def infer(infer_data_path,
pos_size=pos_size,
padding_num=padding_num,
max_len=max_len,
batch_size=batch_size,
beam_size=beam_size)
reverse_trg_dict = reader.get_reverse_dict(trg_dict)
for i, raw_data in enumerate(infer_reader()):
infer_data = [raw_data[0], raw_data[1]]
result = searcher.search_one_sample(infer_data)
sentence = to_sentence(result, reverse_trg_dict)
print sentence
sys.stdout.flush()
searcher.search(infer_data)
return
......@@ -179,6 +211,8 @@ def main():
enc_conv_blocks = eval(args.enc_blocks)
dec_conv_blocks = eval(args.dec_blocks)
sys.setrecursionlimit(10000)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
infer(
......@@ -191,8 +225,11 @@ def main():
emb_dim=args.emb_size,
pos_size=args.pos_size,
drop_rate=args.drop_rate,
use_bn=args.use_bn,
max_len=args.max_len,
beam_size=args.beam_size)
batch_size=args.batch_size,
beam_size=args.beam_size,
is_show_attention=args.is_show_attention)
if __name__ == '__main__':
......
......@@ -12,7 +12,8 @@ def gated_conv_with_batchnorm(input,
context_len,
context_start=None,
learning_rate=1.0,
drop_rate=0.):
drop_rate=0.,
with_bn=False):
"""
Definition of the convolution block.
......@@ -30,6 +31,9 @@ def gated_conv_with_batchnorm(input,
:type learning_rate: float
:param drop_rate: Dropout rate.
:type drop_rate: float
:param with_bn: Whether to use batch normalization or not. False is the default
value.
:type with_bn: bool
:return: The output of the convolution block.
:rtype: LayerOutput
"""
......@@ -50,18 +54,18 @@ def gated_conv_with_batchnorm(input,
learning_rate=learning_rate),
bias_attr=False)
batch_norm_conv = paddle.layer.batch_norm(
if with_bn:
raw_conv = paddle.layer.batch_norm(
input=raw_conv,
act=paddle.activation.Linear(),
param_attr=paddle.attr.Param(learning_rate=learning_rate))
with paddle.layer.mixed(size=size) as conv:
conv += paddle.layer.identity_projection(
batch_norm_conv, size=size, offset=0)
conv += paddle.layer.identity_projection(raw_conv, size=size, offset=0)
with paddle.layer.mixed(size=size, act=paddle.activation.Sigmoid()) as gate:
gate += paddle.layer.identity_projection(
batch_norm_conv, size=size, offset=size)
raw_conv, size=size, offset=size)
with paddle.layer.mixed(size=size) as gated_conv:
gated_conv += paddle.layer.dotmul_operator(conv, gate)
......@@ -73,7 +77,8 @@ def encoder(token_emb,
pos_emb,
conv_blocks=[(256, 3)] * 5,
num_attention=3,
drop_rate=0.1):
drop_rate=0.,
with_bn=False):
"""
Definition of the encoder.
......@@ -89,6 +94,9 @@ def encoder(token_emb,
:type num_attention: int
:param drop_rate: Dropout rate.
:type drop_rate: float
:param with_bn: Whether to use batch normalization or not. False is the default
value.
:type with_bn: bool
:return: The input token encoding.
:rtype: LayerOutput
"""
......@@ -124,7 +132,8 @@ def encoder(token_emb,
size=size,
context_len=context_len,
learning_rate=1.0 / (2.0 * num_attention),
drop_rate=drop_rate)
drop_rate=drop_rate,
with_bn=with_bn)
with paddle.layer.mixed(size=size) as block_output:
block_output += paddle.layer.identity_projection(residual)
......@@ -165,7 +174,7 @@ def attention(decoder_state, cur_embedding, encoded_vec, encoded_sum):
:type encoded_vec: LayerOutput
:param encoded_sum: The sum of the source token's encoding and embedding.
:type encoded_sum: LayerOutput
:return: A context vector.
:return: A context vector and the attention weight.
:rtype: LayerOutput
"""
residual = decoder_state
......@@ -182,7 +191,7 @@ def attention(decoder_state, cur_embedding, encoded_vec, encoded_sum):
expanded = paddle.layer.expand(input=state_summary, expand_as=encoded_vec)
m = paddle.layer.linear_comb(weights=expanded, vectors=encoded_vec)
m = paddle.layer.dot_prod(input1=expanded, input2=encoded_vec)
attention_weight = paddle.layer.fc(
input=m,
......@@ -206,7 +215,7 @@ def attention(decoder_state, cur_embedding, encoded_vec, encoded_sum):
# halve the variance of the sum
attention_result = paddle.layer.slope_intercept(
input=attention_result, slope=math.sqrt(0.5))
return attention_result
return attention_result, attention_weight
def decoder(token_emb,
......@@ -215,7 +224,8 @@ def decoder(token_emb,
encoded_sum,
dict_size,
conv_blocks=[(256, 3)] * 3,
drop_rate=0.1):
drop_rate=0.,
with_bn=False):
"""
Definition of the decoder.
......@@ -235,7 +245,10 @@ def decoder(token_emb,
:type conv_blocks: list of tuple
:param drop_rate: Dropout rate.
:type drop_rate: float
:return: The probability of the predicted token.
:param with_bn: Whether to use batch normalization or not. False is the default
value.
:type with_bn: bool
:return: The probability of the predicted token and the attention weights.
:rtype: LayerOutput
"""
......@@ -261,6 +274,7 @@ def decoder(token_emb,
initial_std=math.sqrt((1.0 - drop_rate) / embedding.size)),
bias_attr=True, )
weight = []
for (size, context_len) in conv_blocks:
if block_input.size == size:
residual = block_input
......@@ -276,7 +290,8 @@ def decoder(token_emb,
size=size,
context_len=context_len,
context_start=0,
drop_rate=drop_rate)
drop_rate=drop_rate,
with_bn=with_bn)
group_inputs = [
decoder_state,
......@@ -285,8 +300,9 @@ def decoder(token_emb,
paddle.layer.StaticInput(input=encoded_sum),
]
conditional = paddle.layer.recurrent_group(
conditional, attention_weight = paddle.layer.recurrent_group(
step=attention_step, input=group_inputs)
weight.append(attention_weight)
block_output = paddle.layer.addto(input=[conditional, residual])
......@@ -312,7 +328,7 @@ def decoder(token_emb,
initial_std=math.sqrt((1.0 - drop_rate) / block_output.size)),
bias_attr=True)
return decoder_out
return decoder_out, weight
def conv_seq2seq(src_dict_size,
......@@ -321,7 +337,8 @@ def conv_seq2seq(src_dict_size,
emb_dim,
enc_conv_blocks=[(256, 3)] * 5,
dec_conv_blocks=[(256, 3)] * 3,
drop_rate=0.1,
drop_rate=0.,
with_bn=False,
is_infer=False):
"""
Definition of convolutional sequence-to-sequence network.
......@@ -345,6 +362,8 @@ def conv_seq2seq(src_dict_size,
:type dec_conv_blocks: list of tuple
:param drop_rate: Dropout rate.
:type drop_rate: float
:param with_bn: Whether to use batch normalization or not. False is the default value.
:type with_bn: bool
:param is_infer: Whether infer or not.
:type is_infer: bool
:return: Cost or output layer.
......@@ -375,7 +394,8 @@ def conv_seq2seq(src_dict_size,
pos_emb=src_pos_emb,
conv_blocks=enc_conv_blocks,
num_attention=num_attention,
drop_rate=drop_rate)
drop_rate=drop_rate,
with_bn=with_bn)
trg = paddle.layer.data(
name='trg_word',
......@@ -397,17 +417,18 @@ def conv_seq2seq(src_dict_size,
name='trg_pos_emb',
param_attr=paddle.attr.Param(initial_mean=0., initial_std=0.1))
decoder_out = decoder(
decoder_out, weight = decoder(
token_emb=trg_emb,
pos_emb=trg_pos_emb,
encoded_vec=encoded_vec,
encoded_sum=encoded_sum,
dict_size=trg_dict_size,
conv_blocks=dec_conv_blocks,
drop_rate=drop_rate)
drop_rate=drop_rate,
with_bn=with_bn)
if is_infer:
return decoder_out
return decoder_out, weight
trg_next_word = paddle.layer.data(
name='trg_next_word',
......
#coding=utf-8
import cPickle
def concat_file(file1, file2, dst_file):
with open(dst_file, 'w') as dst:
with open(file1) as f1:
with open(file2) as f2:
for i, (line1, line2) in enumerate(zip(f1, f2)):
line1 = line1.strip()
line = line1 + '\t' + line2
dst.write(line)
if __name__ == '__main__':
concat_file('dev.de-en.de', 'dev.de-en.en', 'dev')
concat_file('test.de-en.de', 'test.de-en.en', 'test')
concat_file('train.de-en.de', 'train.de-en.en', 'train')
src_dict = cPickle.load(open('vocab.de'))
trg_dict = cPickle.load(open('vocab.en'))
with open('src_dict', 'w') as f:
f.write('<s>\n<e>\nUNK\n')
f.writelines('\n'.join(src_dict.keys()))
with open('trg_dict', 'w') as f:
f.write('<s>\n<e>\nUNK\n')
f.writelines('\n'.join(trg_dict.keys()))
......@@ -18,7 +18,7 @@ def get_reverse_dict(dictionary):
def load_data(data_file, src_dict, trg_dict):
UNK_IDX = src_dict['<unk>']
UNK_IDX = src_dict['UNK']
with open(data_file, 'r') as f:
for line in f:
line_split = line.strip().split('\t')
......@@ -34,7 +34,7 @@ def load_data(data_file, src_dict, trg_dict):
def data_reader(data_file, src_dict, trg_dict, pos_size, padding_num):
def reader():
UNK_IDX = src_dict['<unk>']
UNK_IDX = src_dict['UNK']
word_padding = trg_dict.__len__()
pos_padding = pos_size
......
......@@ -40,7 +40,7 @@ def parse_args():
parser.add_argument(
'--emb_size',
type=int,
default=512,
default=256,
help='Dimension of word embedding. (default: %(default)s)')
parser.add_argument(
'--pos_size',
......@@ -52,6 +52,11 @@ def parse_args():
type=float,
default=0.,
help='Dropout rate. (default: %(default)s)')
parser.add_argument(
"--use_bn",
default=False,
type=distutils.util.strtobool,
help="Use batch normalization or not. (default: %(default)s)")
parser.add_argument(
"--use_gpu",
default=False,
......@@ -116,9 +121,10 @@ def train(train_data_path,
trg_dict_path,
enc_conv_blocks,
dec_conv_blocks,
emb_dim=512,
emb_dim=256,
pos_size=200,
drop_rate=0.,
use_bn=False,
batch_size=32,
num_passes=15):
"""
......@@ -147,6 +153,8 @@ def train(train_data_path,
:type pos_size: int
:param drop_rate: Dropout rate.
:type drop_rate: float
:param use_bn: Whether to use batch normalization or not. False is the default value.
:type use_bn: bool
:param batch_size: The size of a mini-batch.
:type batch_size: int
:param num_passes: The total number of the passes to train.
......@@ -169,6 +177,7 @@ def train(train_data_path,
enc_conv_blocks=enc_conv_blocks,
dec_conv_blocks=dec_conv_blocks,
drop_rate=drop_rate,
with_bn=use_bn,
is_infer=False)
# create parameters and trainer
......@@ -203,7 +212,6 @@ def train(train_data_path,
print "[%s]: Pass: %d, Batch: %d, TrainCost: %f, %s" % (
cur_time, event.pass_id, event.batch_id, event.cost,
event.metrics)
else:
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
......@@ -232,6 +240,8 @@ def main():
enc_conv_blocks = eval(args.enc_blocks)
dec_conv_blocks = eval(args.dec_blocks)
sys.setrecursionlimit(10000)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
train(
......@@ -244,6 +254,7 @@ def main():
emb_dim=args.emb_size,
pos_size=args.pos_size,
drop_rate=args.drop_rate,
use_bn=args.use_bn,
batch_size=args.batch_size,
num_passes=args.num_passes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册