Commit bca3c03d authored by: G guosheng

Add reader, ParallelExecutor and refine for Transformer

Parent e7684f07
class TrainTaskConfig(object):
use_gpu = False
use_gpu = True
# the epoch number to train.
pass_num = 2
pass_num = 30
# the number of sequences contained in a mini-batch.
batch_size = 64
batch_size = 32
# the hyper parameters for Adam optimizer.
learning_rate = 0.001
# This static learning_rate will be multiplied by the LearningRateScheduler-
# derived learning rate to get the final learning rate.
learning_rate = 1
beta1 = 0.9
beta2 = 0.98
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 4000
# the flag indicating whether to use average loss or sum loss when training.
use_avg_cost = False
use_avg_cost = True
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps = 0.1
# the directory for saving trained models.
model_dir = "trained_models"
# the directory for saving checkpoints.
ckpt_dir = "trained_ckpts"
# the directory for loading checkpoint.
# If provided, continue training from the checkpoint.
ckpt_path = None
# the parameter to initialize the learning rate scheduler.
# It should be provided when using checkpoints, since the checkpoint
# doesn't currently include the training step counter.
start_step = 0
class InferTaskConfig(object):
use_gpu = False
use_gpu = True
# the number of examples in one run for sequence generation.
batch_size = 10
# the parameters for beam search.
beam_size = 5
max_length = 30
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False
# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"
@@ -47,30 +54,24 @@ class ModelHyperParams(object):
# <unk> token has already been added. As for the <pad> token, any token
# included in dict can be used to pad, since the paddings' loss will be
# masked out and have no effect on parameter gradients.
# size of source word dictionary.
src_vocab_size = 10000
# size of target word dictionary.
trg_vocab_size = 10000
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# max length of sequences.
# The size of the position encoding table should be at least max_length
# plus 1, since the sinusoid position encoding starts from 1 and 0 can be
# used as the padding token for position encoding.
max_length = 50
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 1024
@@ -86,34 +87,116 @@ class ModelHyperParams(object):
dropout = 0.1
def merge_cfg_from_list(cfg_list, g_cfgs):
"""
Set the above global configurations using the cfg_list.
"""
assert len(cfg_list) % 2 == 0
for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
for g_cfg in g_cfgs:
if hasattr(g_cfg, key):
try:
value = eval(value)
except (NameError, SyntaxError): # for file paths and plain strings
pass
setattr(g_cfg, key, value)
break
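A minimal usage sketch of merge_cfg_from_list (the keys are real config
attributes, but the values here are made up): "64" evaluates to the int 64,
while a bare path such as "trained_models" is not a valid expression and is
kept as the raw string by the except branch.

    cfg_list = ["batch_size", "64", "model_dir", "trained_models"]
    merge_cfg_from_list(cfg_list, [TrainTaskConfig, ModelHyperParams])
    assert TrainTaskConfig.batch_size == 64
    assert TrainTaskConfig.model_dir == "trained_models"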
# Here we list the data shapes and data types of all inputs.
# The shapes here act as placeholders and are set to pass the infer-shape
# checks at compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1]
"src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
# The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1]
"src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias":
[(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
# This shape input is used to reshape the output of embedding layer.
"src_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
"src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in self attention.
"src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
# The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
(ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(1, ModelHyperParams.n_head,
(ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
# This shape input is used to reshape the output of embedding layer.
"trg_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
"trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in self attention.
"trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
"trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
"trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
# This input is used in the independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(1, (ModelHyperParams.max_length + 1),
ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
# This input is used to mask out the loss of padding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
}
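As a minimal numpy sketch (toy values, not the real helper that builds these
inputs) of what the *_attn_bias inputs achieve: padding positions receive a
large negative bias, so softmax assigns them near-zero attention weight.

    import numpy as np
    # One attention row over a length-4 sequence whose last two tokens are pads.
    bias = np.array([0., 0., -1e9, -1e9], dtype="float32")
    scores = np.array([1.2, 0.8, 0.5, 0.3], dtype="float32") + bias
    weights = np.exp(scores) / np.exp(scores).sum()
    # weights[2:] are ~0, so the padded positions contribute nothing.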
# Names of the position encoding tables which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )
# Names of all data layers in encoder listed in order.
encoder_input_data_names = (
# Separated inputs for different usages.
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias",
"src_slf_attn_bias", )
encoder_util_input_fields = (
"src_data_shape",
"src_slf_attn_pre_softmax_shape",
"src_slf_attn_post_softmax_shape", )
# Names of all data layers in decoder listed in order.
decoder_input_data_names = (
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
decoder_util_input_fields = (
"trg_data_shape",
"trg_slf_attn_pre_softmax_shape",
"trg_slf_attn_post_softmax_shape",
"trg_src_attn_pre_softmax_shape",
"trg_src_attn_post_softmax_shape",
"enc_output", )
# Names of label related data layers listed in order.
label_data_names = (
"trg_src_attn_post_softmax_shape", )
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
import argparse
import numpy as np
import paddle
@@ -6,9 +7,52 @@ import paddle.fluid as fluid
import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from config import InferTaskConfig, ModelHyperParams, \
encoder_input_data_names, decoder_input_data_names
from config import *
from train import pad_batch_data
import reader
def parse_args():
parser = argparse.ArgumentParser("Training for Transformer.")
parser.add_argument(
"--src_vocab_fpath",
type=str,
required=True,
help="The path of vocabulary file of source language.")
parser.add_argument(
"--trg_vocab_fpath",
type=str,
required=True,
help="The path of vocabulary file of target language.")
parser.add_argument(
"--test_file_pattern",
type=str,
required=True,
help="The pattern to match test data files.")
parser.add_argument(
"--batch_size",
type=int,
default=50,
help="The number of examples in one run for sequence generation.")
parser.add_argument(
"--pool_size",
type=int,
default=10000,
help="The buffer size to pool data.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
'opts',
help='See config.py for all options',
default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
merge_cfg_from_list(args.opts, [InferTaskConfig, ModelHyperParams])
return args
def translate_batch(exe,
@@ -243,7 +287,7 @@ def translate_batch(exe,
return seqs, scores[:, :n_best].tolist()
def main():
def infer(args):
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
@@ -292,13 +336,23 @@ def main():
decoder_program = fluid.io.get_inference_program(
target_vars=[predict], main_program=decoder_program)
test_data = paddle.batch(
paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=InferTaskConfig.batch_size)
test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern,
batch_size=args.batch_size,
use_token_batch=False,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
clip_last_batch=False)
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
trg_idx2word = test_data._load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
@@ -320,15 +374,16 @@ def main():
(output_eos or idx != eos_idx),
seq)
for batch_id, data in enumerate(test_data()):
for batch_id, data in enumerate(test_data.batch_generator()):
batch_seqs, batch_scores = translate_batch(
exe,
[item[0] for item in data],
encoder_program,
encoder_input_data_names,
encoder_data_input_fields + encoder_util_input_fields,
[enc_output.name],
decoder_program,
decoder_input_data_names,
decoder_data_input_fields[:-1] + decoder_util_input_fields +
(decoder_data_input_fields[-1], ),
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
@@ -351,4 +406,5 @@ def main():
if __name__ == "__main__":
main()
args = parse_args()
infer(args)
@@ -4,8 +4,7 @@ import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from config import TrainTaskConfig, pos_enc_param_names, \
encoder_input_data_names, decoder_input_data_names, label_data_names
from config import *
def position_encoding_init(n_position, d_pos_vec):
@@ -171,7 +170,6 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
Add residual connection, layer normalization and dropout to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
@@ -206,7 +204,6 @@ def prepare_encoder(src_word,
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb = layers.embedding(
@@ -245,7 +242,6 @@ def encoder_layer(enc_input,
pre_softmax_shape=None,
post_softmax_shape=None):
"""The encoder layers that can be stacked to form a deep encoder.
This module consists of a multi-head (self) attention sublayer followed by
position-wise feed-forward networks, with both components accompanied by
post_process_layer to add residual connection, layer normalization
@@ -306,7 +302,6 @@ def decoder_layer(dec_input,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
@@ -394,116 +389,19 @@ def decoder(dec_input,
return dec_output
def make_inputs(input_data_names,
n_head,
d_model,
max_length,
is_pos,
slf_attn_bias_flag,
src_attn_bias_flag,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=True):
def make_all_inputs(input_fields):
"""
Define the input data layers for the transformer model.
"""
input_layers = []
batch_size = 1 # Only for the infer-shape in compile time.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
# The actual data shape of word is:
# [batch_size * max_len_in_batch, 1]
word = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
input_layers += [word]
# This is used for position data or label weight.
# The actual data shape of pos is:
# [batch_size * max_len_in_batch, 1]
pos = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
dtype="int64" if is_pos else "float32",
append_batch_size=False)
input_layers += [pos]
if slf_attn_bias_flag:
# This input is used to remove attention weights on paddings for the
# encoder and to remove attention weights on subsequent words for the
# decoder.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, n_head, max_len_in_batch, max_len_in_batch]
slf_attn_bias = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
input_layers += [slf_attn_bias]
if src_attn_bias_flag:
# This input is used to remove attention weights on paddings. It's used
# in encoder-decoder attention.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch]
src_attn_bias = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
input_layers += [src_attn_bias]
if data_shape_flag:
# This input is used to reshape the output of embedding layer.
data_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
dtype="int32",
inputs = []
for input_field in input_fields:
input_var = layers.data(
name=input_field,
shape=input_descs[input_field][0],
dtype=input_descs[input_field][1],
append_batch_size=False)
input_layers += [data_shape]
if slf_attn_shape_flag:
# This shape input is used to reshape before softmax in self attention.
slf_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[2],
dtype="int32",
append_batch_size=False)
input_layers += [slf_attn_pre_softmax_shape]
# This shape input is used to reshape after softmax in self attention.
slf_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[4],
dtype="int32",
append_batch_size=False)
input_layers += [slf_attn_post_softmax_shape]
if src_attn_shape_flag:
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
src_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[2],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_pre_softmax_shape]
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
src_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[4],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_post_softmax_shape]
if enc_output_flag:
# This input is used in independent decoder program for inference.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, max_len_in_batch, d_model]
enc_output = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, max_length, d_model],
dtype="float32",
append_batch_size=False)
input_layers += [enc_output]
return input_layers
inputs.append(input_var)
return inputs
def transformer(
@@ -516,19 +414,10 @@ def transformer(
d_value,
d_model,
d_inner_hid,
dropout_rate, ):
enc_inputs = make_inputs(
encoder_input_data_names,
n_head,
d_model,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=False,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=False)
dropout_rate,
label_smooth_eps, ):
enc_inputs = make_all_inputs(encoder_data_input_fields +
encoder_util_input_fields)
enc_output = wrap_encoder(
src_vocab_size,
@@ -542,18 +431,8 @@ def transformer(
dropout_rate,
enc_inputs, )
dec_inputs = make_inputs(
decoder_input_data_names,
n_head,
d_model,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=True,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=True)
dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
decoder_util_input_fields)
predict = wrap_decoder(
trg_vocab_size,
@@ -570,19 +449,17 @@ def transformer(
# Padding indices do not contribute to the total loss. The weights are
# used to cancel the effect of padding indices when calculating the loss.
gold, weights = make_inputs(
label_data_names,
n_head,
d_model,
max_length,
is_pos=False,
slf_attn_bias_flag=False,
src_attn_bias_flag=False,
enc_output_flag=False,
data_shape_flag=False,
slf_attn_shape_flag=False,
src_attn_shape_flag=False)
cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
label, weights = make_all_inputs(label_data_input_fields)
if label_smooth_eps:
label = layers.label_smooth(
label=layers.one_hot(
input=label, depth=trg_vocab_size),
epsilon=label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
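A minimal numpy sketch of the label smoothing above, with a toy vocabulary
size: the one-hot ground truth is mixed with a uniform distribution and
still sums to one.

    import numpy as np
    eps, vocab_size = 0.1, 5  # toy vocabulary size for illustration
    one_hot = np.eye(vocab_size)[2]  # gold token id 2
    smoothed = (1. - eps) * one_hot + eps / vocab_size
    # smoothed == [0.02, 0.02, 0.92, 0.02, 0.02]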
@@ -607,18 +484,8 @@ def wrap_encoder(src_vocab_size,
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias, src_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
make_inputs(
encoder_input_data_names,
n_head,
d_model,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=False,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=False)
make_all_inputs(encoder_data_input_fields +
encoder_util_input_fields)
else:
src_word, src_pos, src_slf_attn_bias, src_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
@@ -663,20 +530,10 @@ def wrap_decoder(trg_vocab_size,
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, slf_attn_pre_softmax_shape, \
enc_output, trg_data_shape, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape, enc_output = make_inputs(
decoder_input_data_names,
n_head,
d_model,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=True,
enc_output_flag=True,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=True)
src_attn_post_softmax_shape = make_all_inputs(
decoder_data_input_fields + decoder_util_input_fields)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, slf_attn_pre_softmax_shape, \
@@ -14,27 +14,24 @@ class LearningRateScheduler(object):
def __init__(self,
d_model,
warmup_steps,
place,
learning_rate=0.001,
current_steps=0,
name="learning_rate"):
self.current_steps = current_steps
self.warmup_steps = warmup_steps
self.d_model = d_model
self.static_lr = learning_rate
self.learning_rate = layers.create_global_var(
name=name,
shape=[1],
value=float(learning_rate),
dtype="float32",
persistable=True)
self.place = place
def update_learning_rate(self, data_input):
def update_learning_rate(self):
self.current_steps += 1
lr_value = np.power(self.d_model, -0.5) * np.min([
np.power(self.current_steps, -0.5),
np.power(self.warmup_steps, -1.5) * self.current_steps
])
lr_tensor = fluid.LoDTensor()
lr_tensor.set(np.array([lr_value], dtype="float32"), self.place)
data_input[self.learning_rate.name] = lr_tensor
]) * self.static_lr
return np.array([lr_value], dtype="float32")
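A worked check of this schedule with the config defaults (d_model=512,
warmup_steps=4000, static_lr=1):

    d_model, warmup, static_lr = 512, 4000, 1.0
    for step in (1, 4000, 16000):
        lr = d_model ** -0.5 * min(step ** -0.5, warmup ** -1.5 * step) * static_lr
        # step=1: ~1.7e-7 (linear warmup); step=4000: ~7.0e-4 (the peak,
        # where the two terms of min() meet); step=16000: ~3.5e-4 (inverse
        # square-root decay).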
import os
import tarfile
import glob
import random
class SortType(object):
GLOBAL = 'global'
POOL = 'pool'
NONE = "none"
class EndEpoch():
pass
class Pool(object):
def __init__(self, sample_generator, pool_size, sort):
self._pool_size = pool_size
self._pool = []
self._sample_generator = sample_generator()
self._end = False
self._sort = sort
def _fill(self):
while len(self._pool) < self._pool_size and not self._end:
try:
sample = self._sample_generator.next()
self._pool.append(sample)
except StopIteration as e:
self._end = True
break
if self._sort:
self._pool.sort(
key=lambda sample: max(len(sample[0]), len(sample[1])) if len(sample) > 1 else len(sample[0])
)
if self._end and len(self._pool) < self._pool_size:
self._pool.append(EndEpoch())
def push_back(self, samples):
if len(self._pool) != 0:
raise Exception("Pool should be empty.")
if len(samples) >= self._pool_size:
raise Exception("Capacity of pool should be greater than a batch. "
"Please enlarge `pool_size`.")
for sample in samples:
self._pool.append(sample)
self._fill()
def next(self, look=False):
if len(self._pool) == 0:
return None
else:
return self._pool[0] if look else self._pool.pop(0)
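A toy walk-through of the Pool protocol with a made-up generator; note that
the pool starts empty and is only filled by push_back(), which is how
DataReader.batch_generator() below drives it.

    def toy_gen():
        for sample in [([1, 2, 3], ), ([4], ), ([5, 6], )]:
            yield sample

    pool = Pool(toy_gen, pool_size=2, sort=True)
    assert pool.next(look=True) is None  # nothing buffered yet
    pool.push_back([])                   # an empty push triggers _fill()
    shortest = pool.next()               # ([4], ), since sort orders by length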
class DataReader(object):
"""
The data reader loads all data from files and produces batches of data
according to the settings, where the batch size is measured by either the
number of tokens or the number of sequences.
"""
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
batch_size,
pool_size,
sort_type=SortType.NONE,
clip_last_batch=True,
tar_fname=None,
min_length=0,
max_length=100,
shuffle=True,
shuffle_batch=False,
use_token_batch=False,
delimiter="\t",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
seed=0):
"""
Load all data from files and set the settings to make mini-batches.
:param src_vocab_fpath: The path of vocabulary file of source language.
:type src_vocab_fpath: basestring
:param trg_vocab_fpath: The path of vocabulary file of target language.
:type trg_vocab_fpath: basestring
:param fpattern: The pattern to match data files.
:type fpattern: basestring
:param batch_size: The number of sequences contained in a mini-batch,
or the maximum number of tokens (including paddings) contained in a
mini-batch.
:type batch_size: int
:param pool_size: The buffer size to pool data.
:type pool_size: int
:param sort_type: The grain to sort by length: 'global' for all
instances; 'pool' for instances in pool; 'none' for no sort.
:type sort_type: basestring
:param clip_last_batch: Whether to clip the last uncompleted batch.
:type clip_last_batch: bool
:param tar_fname: The data file in tar if fpattern matches a tar file.
:type tar_fname: basestring
:param min_length: The minimum length used to filter sequences.
:type min_length: int
:param max_length: The maximum length used to filter sequences.
:type max_length: int
:param shuffle: Whether to shuffle all instances.
:type shuffle: bool
:param shuffle_batch: Whether to shuffle the generated batches.
:type shuffle_batch: bool
:param use_token_batch: Whether to produce batch data according to
token number.
:type use_token_batch: bool
:param delimiter: The delimiter used to split source and target in each
line of data file.
:type delimiter: basestring
:param start_mark: The token representing the beginning of
sentences in the dictionary.
:type start_mark: basestring
:param end_mark: The token representing the end of sentences
in the dictionary.
:type end_mark: basestring
:param unk_mark: The token representing an unknown word in the dictionary.
:type unk_mark: basestring
:param seed: The seed for random.
:type seed: int
"""
self._src_vocab = self._load_dict(src_vocab_fpath)
self._only_src = True
if trg_vocab_fpath is not None:
self._trg_vocab = self._load_dict(trg_vocab_fpath)
self._only_src = False
self._pool_size = pool_size
self._batch_size = batch_size
self._use_token_batch = use_token_batch
self._sort_type = sort_type
self._clip_last_batch = clip_last_batch
self._shuffle = shuffle
self._shuffle_batch = shuffle_batch
self._min_length = min_length
self._max_length = max_length
self._delimiter = delimiter
self._epoch_batches = []
src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
self._src_seq_ids = [[
self._src_vocab.get(word, self._src_vocab.get(unk_mark))
for word in ([start_mark] + src_seq + [end_mark])
] for src_seq in src_seq_words]
self._sample_count = len(self._src_seq_ids)
if not self._only_src:
self._trg_seq_ids = [[
self._trg_vocab.get(word, self._trg_vocab.get(unk_mark))
for word in ([start_mark] + trg_seq + [end_mark])
] for trg_seq in trg_seq_words]
if len(self._trg_seq_ids) != self._sample_count:
raise Exception("Inconsistent sample count between "
"source sequences and target sequences.")
else:
self._trg_seq_ids = None
self._sample_idxs = [i for i in xrange(self._sample_count)]
self._sorted = False
random.seed(seed)
def _parse_file(self, f_obj):
src_seq_words = []
trg_seq_words = []
for line in f_obj:
fields = line.strip().split(self._delimiter)
if (not self._only_src and len(fields) != 2) or (self._only_src and len(fields) != 1):
continue
sample_words = []
is_valid_sample = True
max_len = -1
for i, seq in enumerate(fields):
seq_words = seq.split()
max_len = max(max_len, len(seq_words))
if len(seq_words) == 0 or \
len(seq_words) < self._min_length or \
len(seq_words) > self._max_length or \
(self._use_token_batch and max_len > self._batch_size):
is_valid_sample = False
break
sample_words.append(seq_words)
if not is_valid_sample: continue
src_seq_words.append(sample_words[0])
if not self._only_src:
trg_seq_words.append(sample_words[1])
return (src_seq_words, trg_seq_words)
def _load_data(self, fpattern, tar_fname):
fpaths = glob.glob(fpattern)
src_seq_words = []
trg_seq_words = []
if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], 'r')
part_file_data = self._parse_file(f.extractfile(tar_fname))
src_seq_words = part_file_data[0]
trg_seq_words = part_file_data[1]
else:
for fpath in fpaths:
if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath)
part_file_data = self._parse_file(open(fpath, 'r'))
src_seq_words.extend(part_file_data[0])
trg_seq_words.extend(part_file_data[1])
return src_seq_words, trg_seq_words
def _load_dict(self, dict_path, reverse=False):
word_dict = {}
with open(dict_path, "r") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip()
else:
word_dict[line.strip()] = idx
return word_dict
def _sample_generator(self):
if self._sort_type == SortType.GLOBAL:
if not self._sorted:
self._sample_idxs.sort(
key=lambda idx: max(len(self._src_seq_ids[idx]),
len(self._trg_seq_ids[idx]) if not self._only_src else 0)
)
self._sorted = True
elif self._shuffle:
random.shuffle(self._sample_idxs)
for sample_idx in self._sample_idxs:
if self._only_src:
yield (self._src_seq_ids[sample_idx])
else:
yield (self._src_seq_ids[sample_idx],
self._trg_seq_ids[sample_idx][:-1],
self._trg_seq_ids[sample_idx][1:])
def batch_generator(self):
pool = Pool(self._sample_generator, self._pool_size,
self._sort_type == SortType.POOL)
def next_batch():
batch_data = []
max_len = -1
batch_max_seq_len = -1
while True:
sample = pool.next(look=True)
if sample is None:
pool.push_back(batch_data)
batch_data = []
continue
if isinstance(sample, EndEpoch):
return batch_data, batch_max_seq_len, True
max_len = max(max_len, len(sample[0]))
if not self._only_src:
max_len = max(max_len, len(sample[1]))
if self._use_token_batch:
if max_len * (len(batch_data) + 1) < self._batch_size:
batch_max_seq_len = max_len
batch_data.append(pool.next())
else:
return batch_data, batch_max_seq_len, False
else:
if len(batch_data) < self._batch_size:
batch_max_seq_len = max_len
batch_data.append(pool.next())
else:
return batch_data, batch_max_seq_len, False
if not self._shuffle_batch:
batch_data, batch_max_seq_len, last_batch = next_batch()
while not last_batch:
yield batch_data
batch_data, batch_max_seq_len, last_batch = next_batch()
batch_size = len(batch_data)
if self._use_token_batch:
batch_size *= batch_max_seq_len
if (not self._clip_last_batch and len(batch_data) > 0) \
or (batch_size == self._batch_size):
yield batch_data
else:
# should re-generate batches
if self._sort_type == SortType.POOL \
or len(self._epoch_batches) == 0:
self._epoch_batches = []
batch_data, batch_max_seq_len, last_batch = next_batch()
while not last_batch:
self._epoch_batches.append(batch_data)
batch_data, batch_max_seq_len, last_batch = next_batch()
batch_size = len(batch_data)
if self._use_token_batch:
batch_size *= batch_max_seq_len
if (not self._clip_last_batch and len(batch_data) > 0) \
or (batch_size == self._batch_size):
self._epoch_batches.append(batch_data)
random.shuffle(self._epoch_batches)
for batch_data in self._epoch_batches:
yield batch_data
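A worked sketch of the token-batch condition in next_batch() above, with
made-up numbers: when use_token_batch is True, batch_size budgets padded
tokens rather than sequences.

    batch_size_tokens, max_len = 4096, 32
    n = 0
    while max_len * (n + 1) < batch_size_tokens:
        n += 1
    # n == 127: the batch closes once padding all n + 1 sequences to
    # max_len would reach the 4096-token budget.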