提交 05541295 编写于 作者: Y Yi Wang 提交者: GitHub

Merge pull request #1560 from jacquesqiao/seq2seq-dataset

optimize Seq2seq dataset
import os
import paddle.v2 as paddle import paddle.v2 as paddle
from seqToseq_net_v2 import seqToseq_net_v2
# Data Definiation.
# TODO:This code should be merged to dataset package.
data_dir = "./data/pre-wmt14"
src_lang_dict = os.path.join(data_dir, 'src.dict')
trg_lang_dict = os.path.join(data_dir, 'trg.dict')
source_dict_dim = len(open(src_lang_dict, "r").readlines())
target_dict_dim = len(open(trg_lang_dict, "r").readlines())
def read_to_dict(dict_path):
with open(dict_path, "r") as fin:
out_dict = {
line.strip(): line_count
for line_count, line in enumerate(fin)
}
return out_dict
src_dict = read_to_dict(src_lang_dict)
trg_dict = read_to_dict(trg_lang_dict)
train_list = os.path.join(data_dir, 'train.list')
test_list = os.path.join(data_dir, 'test.list')
UNK_IDX = 2 def seqToseq_net(source_dict_dim, target_dict_dim):
START = "<s>" ### Network Architecture
END = "<e>" word_vector_dim = 512 # dimension of word vector
decoder_size = 512 # dimension of hidden unit in GRU Decoder network
encoder_size = 512 # dimension of hidden unit in GRU Encoder network
def _get_ids(s, dictionary):
words = s.strip().split() #### Encoder
return [dictionary[START]] + \ src_word_id = paddle.layer.data(
[dictionary.get(w, UNK_IDX) for w in words] + \ name='source_language_word',
[dictionary[END]] type=paddle.data_type.integer_value_sequence(source_dict_dim))
src_embedding = paddle.layer.embedding(
input=src_word_id,
def train_reader(file_name): size=word_vector_dim,
def reader(): param_attr=paddle.attr.ParamAttr(name='_source_language_embedding'))
with open(file_name, 'r') as f: src_forward = paddle.networks.simple_gru(
for line_count, line in enumerate(f): input=src_embedding, size=encoder_size)
line_split = line.strip().split('\t') src_backward = paddle.networks.simple_gru(
if len(line_split) != 2: input=src_embedding, size=encoder_size, reverse=True)
continue encoded_vector = paddle.layer.concat(input=[src_forward, src_backward])
src_seq = line_split[0] # one source sequence
src_ids = _get_ids(src_seq, src_dict) #### Decoder
with paddle.layer.mixed(size=decoder_size) as encoded_proj:
trg_seq = line_split[1] # one target sequence encoded_proj += paddle.layer.full_matrix_projection(
trg_words = trg_seq.split() input=encoded_vector)
trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]
backward_first = paddle.layer.first_seq(input=src_backward)
# remove sequence whose length > 80 in training mode
if len(src_ids) > 80 or len(trg_ids) > 80: with paddle.layer.mixed(
continue size=decoder_size, act=paddle.activation.Tanh()) as decoder_boot:
trg_ids_next = trg_ids + [trg_dict[END]] decoder_boot += paddle.layer.full_matrix_projection(
trg_ids = [trg_dict[START]] + trg_ids input=backward_first)
yield src_ids, trg_ids, trg_ids_next def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
return reader decoder_mem = paddle.layer.memory(
name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)
context = paddle.networks.simple_attention(
encoded_sequence=enc_vec,
encoded_proj=enc_proj,
decoder_state=decoder_mem)
with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs:
decoder_inputs += paddle.layer.full_matrix_projection(input=context)
decoder_inputs += paddle.layer.full_matrix_projection(
input=current_word)
gru_step = paddle.layer.gru_step(
name='gru_decoder',
input=decoder_inputs,
output_mem=decoder_mem,
size=decoder_size)
with paddle.layer.mixed(
size=target_dict_dim,
bias_attr=True,
act=paddle.activation.Softmax()) as out:
out += paddle.layer.full_matrix_projection(input=gru_step)
return out
decoder_group_name = "decoder_group"
group_input1 = paddle.layer.StaticInputV2(input=encoded_vector, is_seq=True)
group_input2 = paddle.layer.StaticInputV2(input=encoded_proj, is_seq=True)
group_inputs = [group_input1, group_input2]
trg_embedding = paddle.layer.embedding(
input=paddle.layer.data(
name='target_language_word',
type=paddle.data_type.integer_value_sequence(target_dict_dim)),
size=word_vector_dim,
param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
group_inputs.append(trg_embedding)
# For decoder equipped with attention mechanism, in training,
# target embeding (the groudtruth) is the data input,
# while encoded source sequence is accessed to as an unbounded memory.
# Here, the StaticInput defines a read-only memory
# for the recurrent_group.
decoder = paddle.layer.recurrent_group(
name=decoder_group_name,
step=gru_decoder_with_attention,
input=group_inputs)
lbl = paddle.layer.data(
name='target_language_next_word',
type=paddle.data_type.integer_value_sequence(target_dict_dim))
cost = paddle.layer.classification_cost(input=decoder, label=lbl)
return cost
def main(): def main():
paddle.init(use_gpu=False, trainer_count=1) paddle.init(use_gpu=False, trainer_count=1)
# source and target dict dim.
dict_size = 30000
source_dict_dim = target_dict_dim = dict_size
# define network topology # define network topology
cost = seqToseq_net_v2(source_dict_dim, target_dict_dim) cost = seqToseq_net(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# define optimize method and trainer # define optimize method and trainer
...@@ -88,7 +118,7 @@ def main(): ...@@ -88,7 +118,7 @@ def main():
wmt14_reader = paddle.batch( wmt14_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
train_reader("data/pre-wmt14/train/train"), buf_size=8192), paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192),
batch_size=5) batch_size=5)
# define event_handler callback # define event_handler callback
......
import paddle.v2 as paddle
def seqToseq_net_v2(source_dict_dim, target_dict_dim):
### Network Architecture
word_vector_dim = 512 # dimension of word vector
decoder_size = 512 # dimension of hidden unit in GRU Decoder network
encoder_size = 512 # dimension of hidden unit in GRU Encoder network
#### Encoder
src_word_id = paddle.layer.data(
name='source_language_word',
type=paddle.data_type.integer_value_sequence(source_dict_dim))
src_embedding = paddle.layer.embedding(
input=src_word_id,
size=word_vector_dim,
param_attr=paddle.attr.ParamAttr(name='_source_language_embedding'))
src_forward = paddle.networks.simple_gru(
input=src_embedding, size=encoder_size)
src_backward = paddle.networks.simple_gru(
input=src_embedding, size=encoder_size, reverse=True)
encoded_vector = paddle.layer.concat(input=[src_forward, src_backward])
#### Decoder
with paddle.layer.mixed(size=decoder_size) as encoded_proj:
encoded_proj += paddle.layer.full_matrix_projection(
input=encoded_vector)
backward_first = paddle.layer.first_seq(input=src_backward)
with paddle.layer.mixed(
size=decoder_size, act=paddle.activation.Tanh()) as decoder_boot:
decoder_boot += paddle.layer.full_matrix_projection(
input=backward_first)
def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
decoder_mem = paddle.layer.memory(
name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)
context = paddle.networks.simple_attention(
encoded_sequence=enc_vec,
encoded_proj=enc_proj,
decoder_state=decoder_mem)
with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs:
decoder_inputs += paddle.layer.full_matrix_projection(input=context)
decoder_inputs += paddle.layer.full_matrix_projection(
input=current_word)
gru_step = paddle.layer.gru_step(
name='gru_decoder',
input=decoder_inputs,
output_mem=decoder_mem,
size=decoder_size)
with paddle.layer.mixed(
size=target_dict_dim,
bias_attr=True,
act=paddle.activation.Softmax()) as out:
out += paddle.layer.full_matrix_projection(input=gru_step)
return out
decoder_group_name = "decoder_group"
group_input1 = paddle.layer.StaticInputV2(input=encoded_vector, is_seq=True)
group_input2 = paddle.layer.StaticInputV2(input=encoded_proj, is_seq=True)
group_inputs = [group_input1, group_input2]
trg_embedding = paddle.layer.embedding(
input=paddle.layer.data(
name='target_language_word',
type=paddle.data_type.integer_value_sequence(target_dict_dim)),
size=word_vector_dim,
param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
group_inputs.append(trg_embedding)
# For decoder equipped with attention mechanism, in training,
# target embeding (the groudtruth) is the data input,
# while encoded source sequence is accessed to as an unbounded memory.
# Here, the StaticInput defines a read-only memory
# for the recurrent_group.
decoder = paddle.layer.recurrent_group(
name=decoder_group_name,
step=gru_decoder_with_attention,
input=group_inputs)
lbl = paddle.layer.data(
name='target_language_next_word',
type=paddle.data_type.integer_value_sequence(target_dict_dim))
cost = paddle.layer.classification_cost(input=decoder, label=lbl)
return cost
...@@ -14,129 +14,92 @@ ...@@ -14,129 +14,92 @@
""" """
wmt14 dataset wmt14 dataset
""" """
import paddle.v2.dataset.common
import tarfile import tarfile
import os.path
import itertools import paddle.v2.dataset.common
__all__ = ['train', 'test', 'build_dict'] __all__ = ['train', 'test', 'build_dict']
URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz' URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
URL_TRAIN = 'http://localhost:8000/train.tgz' # this is a small set of data for test. The original data is too large and will be add later.
MD5_TRAIN = '72de99da2830ea5a3a2c4eb36092bbc7' URL_TRAIN = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
def word_count(f, word_freq=None): START = "<s>"
add = paddle.v2.dataset.common.dict_add END = "<e>"
if word_freq == None: UNK = "<unk>"
word_freq = {} UNK_IDX = 2
for l in f:
for w in l.strip().split(): def __read_to_dict__(tar_file, dict_size):
add(word_freq, w) def __to_dict__(fd, size):
add(word_freq, '<s>') out_dict = dict()
add(word_freq, '<e>') for line_count, line in enumerate(fd):
if line_count < size:
return word_freq out_dict[line.strip()] = line_count
else:
break
def get_word_dix(word_freq): return out_dict
TYPO_FREQ = 50
word_freq = filter(lambda x: x[1] > TYPO_FREQ, word_freq.items()) with tarfile.open(tar_file, mode='r') as f:
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) names = [
words, _ = list(zip(*word_freq_sorted)) each_item.name for each_item in f
word_idx = dict(zip(words, xrange(len(words)))) if each_item.name.endswith("src.dict")
word_idx['<unk>'] = len(words) ]
return word_idx assert len(names) == 1
src_dict = __to_dict__(f.extractfile(names[0]), dict_size)
names = [
def get_word_freq(train, dev): each_item.name for each_item in f
word_freq = word_count(train, word_count(dev)) if each_item.name.endswith("trg.dict")
if '<unk>' in word_freq: ]
# remove <unk> for now, since we will set it as last index assert len(names) == 1
del word_freq['<unk>'] trg_dict = __to_dict__(f.extractfile(names[0]), dict_size)
return word_freq return src_dict, trg_dict
def build_dict():
base_dir = './wmt14-data'
train_en_filename = base_dir + '/train/train.en'
train_fr_filename = base_dir + '/train/train.fr'
dev_en_filename = base_dir + '/dev/ntst1213.en'
dev_fr_filename = base_dir + '/dev/ntst1213.fr'
if not os.path.exists(train_en_filename) or not os.path.exists(
train_fr_filename):
with tarfile.open(
paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14',
MD5_TRAIN)) as tf:
tf.extractall(base_dir)
if not os.path.exists(dev_en_filename) or not os.path.exists(
dev_fr_filename):
with tarfile.open(
paddle.v2.dataset.common.download(URL_DEV_TEST, 'wmt14',
MD5_DEV_TEST)) as tf:
tf.extractall(base_dir)
f_en = open(train_en_filename)
f_fr = open(train_fr_filename)
f_en_dev = open(dev_en_filename)
f_fr_dev = open(dev_fr_filename)
word_freq_en = get_word_freq(f_en, f_en_dev)
word_freq_fr = get_word_freq(f_fr, f_fr_dev)
f_en.close()
f_fr.close()
f_en_dev.close()
f_fr_dev.close()
return get_word_dix(word_freq_en), get_word_dix(word_freq_fr)
def reader_creator(directory, path_en, path_fr, URL, MD5, dict_en, dict_fr): def reader_creator(tar_file, file_name, dict_size):
def reader(): def reader():
if not os.path.exists(path_en) or not os.path.exists(path_fr): src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
with tarfile.open( with tarfile.open(tar_file, mode='r') as f:
paddle.v2.dataset.common.download(URL, 'wmt14', MD5)) as tf: names = [
tf.extractall(directory) each_item.name for each_item in f
if each_item.name.endswith(file_name)
f_en = open(path_en) ]
f_fr = open(path_fr) for name in names:
UNK_en = dict_en['<unk>'] for line in f.extractfile(name):
UNK_fr = dict_fr['<unk>'] line_split = line.strip().split('\t')
if len(line_split) != 2:
for en, fr in itertools.izip(f_en, f_fr): continue
src_ids = [dict_en.get(w, UNK_en) for w in en.strip().split()] src_seq = line_split[0] # one source sequence
tar_ids = [ src_words = src_seq.split()
dict_fr.get(w, UNK_fr) src_ids = [
for w in ['<s>'] + fr.strip().split() + ['<e>'] src_dict.get(w, UNK_IDX)
for w in [START] + src_words + [END]
] ]
trg_seq = line_split[1] # one target sequence
trg_words = trg_seq.split()
trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]
# remove sequence whose length > 80 in training mode # remove sequence whose length > 80 in training mode
if len(src_ids) == 0 or len(tar_ids) <= 1 or len( if len(src_ids) > 80 or len(trg_ids) > 80:
src_ids) > 80 or len(tar_ids) > 80:
continue continue
trg_ids_next = trg_ids + [trg_dict[END]]
trg_ids = [trg_dict[START]] + trg_ids
yield src_ids, tar_ids[:-1], tar_ids[1:] yield src_ids, trg_ids, trg_ids_next
f_en.close()
f_fr.close()
return reader return reader
def train(dict_en, dict_fr): def train(dict_size):
directory = './wmt14-data' return reader_creator(
return reader_creator(directory, directory + '/train/train.en', paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
directory + '/train/train.fr', URL_TRAIN, MD5_TRAIN, 'train/train', dict_size)
dict_en, dict_fr)
def test(dict_en, dict_fr): def test(dict_size):
directory = './wmt14-data' return reader_creator(
return reader_creator(directory, directory + '/dev/ntst1213.en', paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
directory + '/dev/ntst1213.fr', URL_DEV_TEST, 'test/test', dict_size)
MD5_DEV_TEST, dict_en, dict_fr)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册