Commit bc039c59 authored by G guosheng

Add greedy search.

Add PPL metric.
Parent ae47e2a8
......@@ -28,7 +28,7 @@ from paddle.fluid.io import DataLoader
from model import Input, set_device
from args import parse_args
from seq2seq_base import BaseInferModel
from seq2seq_attn import AttentionInferModel
from seq2seq_attn import AttentionInferModel, AttentionGreedyInferModel
from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_infer_input
......@@ -87,7 +87,8 @@ def do_predict(args):
num_workers=0,
return_list=True)
model_maker = AttentionInferModel if args.attention else BaseInferModel
# model_maker = AttentionInferModel if args.attention else BaseInferModel
model_maker = AttentionGreedyInferModel if args.attention else BaseInferModel
model = model_maker(
args.src_vocab_size,
args.tar_vocab_size,
......@@ -111,6 +112,8 @@ def do_predict(args):
with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
for data in data_loader():
finished_seq = model.test(inputs=flatten(data))[0]
finished_seq = finished_seq[:, :, np.newaxis] if len(
    finished_seq.shape) == 2 else finished_seq
finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
......
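A minimal numpy sketch (toy shapes assumed) of the reshape above: greedy decoding returns token ids of shape (batch, time), while beam search returns (batch, time, beam); adding a trailing axis and transposing to (batch, beam, time) lets the same write-out loop handle both.

```python
import numpy as np

# Toy greedy output: (batch=2, time=3); beam search would already be 3-D.
finished_seq = np.array([[2, 7, 1], [5, 3, 1]])
if len(finished_seq.shape) == 2:                       # greedy case
    finished_seq = finished_seq[:, :, np.newaxis]      # -> (2, 3, 1)
finished_seq = np.transpose(finished_seq, [0, 2, 1])   # -> (2, 1, 3)
for ins in finished_seq:
    for beam_idx, beam in enumerate(ins):
        print(beam_idx, beam.tolist())  # one id sequence per (pseudo-)beam
```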
......@@ -44,7 +44,8 @@ def create_data_loader(args, device, for_train=True):
end_mark="</s>",
unk_mark="<unk>",
max_length=args.max_len if i == 0 else None,
truncate=True)
truncate=True,
trg_add_bos_eos=True)
(args.src_vocab_size, args.tar_vocab_size, bos_id, eos_id,
unk_id) = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
......@@ -53,7 +54,8 @@ def create_data_loader(args, device, for_train=True):
batch_size=args.batch_size,
pool_size=args.batch_size * 20,
sort_type=SortType.POOL,
shuffle=False if args.enable_ce else True)
shuffle=not args.enable_ce,
distribute_mode=(i == 0))
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
......@@ -73,7 +75,7 @@ def prepare_train_input(insts, bos_id, eos_id, pad_id):
src, src_length = pad_batch_data(
[inst[0] for inst in insts], pad_id=pad_id)
trg, trg_length = pad_batch_data(
[[bos_id] + inst[1] + [eos_id] for inst in insts], pad_id=pad_id)
[inst[1] for inst in insts], pad_id=pad_id)
trg_length = trg_length - 1
return src, src_length, trg[:, :-1], trg_length, trg[:, 1:, np.newaxis]
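A hedged illustration of the new input pipeline (toy ids assumed): the reader now attaches `<bos>`/`<eos>` to targets, so `prepare_train_input` only pads and shifts; the decoder consumes `trg[:, :-1]` and is scored against `trg[:, 1:]`.

```python
import numpy as np

bos_id, eos_id, pad_id = 0, 1, 2  # assumed toy vocabulary ids
# Padded targets that already carry <bos>/<eos>, as the reader now emits.
trg = np.array([[bos_id, 4, 5, eos_id],
                [bos_id, 6, eos_id, pad_id]])
trg_length = np.array([4, 3]) - 1    # decoder steps = tokens - 1
dec_input = trg[:, :-1]              # teacher-forcing inputs: <bos> w w
label = trg[:, 1:, np.newaxis]       # shifted targets:        w w <eos>
print(dec_input.shape, label.shape)  # (2, 3) (2, 3, 1)
```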
......@@ -165,9 +167,24 @@ class TokenBatchCreator(object):
class SampleInfo(object):
def __init__(self, i, lens):
self.i = i
# to be consistent with the original reader implementation
self.min_len = lens[0]
self.max_len = lens[0]
self.lens = lens
def get_ranges(self, min_length=None, max_length=None, truncate=False):
ranges = []
# source
if (min_length is None or self.lens[0] >= min_length) and (
max_length is None or self.lens[0] <= max_length or truncate):
end = max_length if truncate and max_length else self.lens[0]
ranges.append([0, end])
# target
if len(self.lens) == 2:
if (min_length is None or self.lens[1] >= min_length) and (
max_length is None or self.lens[1] <= max_length + 2 or
truncate):
end = max_length + 2 if truncate and max_length else self.lens[1]
ranges.append([0, end])
return ranges if len(ranges) == len(self.lens) else None
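A usage sketch of `get_ranges` (toy lengths assumed): target lengths now include `<bos>`/`<eos>`, which is why the target check and truncation end allow `max_length + 2`; slicing past a field's end is harmless in Python.

```python
info = SampleInfo(i=0, lens=[12, 9])  # source 12 tokens, target 9 incl. bos/eos
print(info.get_ranges(min_length=1, max_length=10, truncate=True))
# [[0, 10], [0, 12]] -> source truncated to 10; target slice keeps all 9 tokens
print(info.get_ranges(min_length=1, max_length=10, truncate=False))
# None -> the over-long source disqualifies the whole sample
```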
class MinMaxFilter(object):
......@@ -197,6 +214,7 @@ class Seq2SeqDataset(Dataset):
end_mark="<e>",
unk_mark="<unk>",
trg_fpattern=None,
trg_add_bos_eos=False,
byte_data=False,
min_length=None,
max_length=None,
......@@ -220,6 +238,7 @@ class Seq2SeqDataset(Dataset):
self._min_length = min_length
self._max_length = max_length
self._truncate = truncate
self._trg_add_bos_eos = trg_add_bos_eos
self.load_src_trg_ids(fpattern, trg_fpattern)
def load_src_trg_ids(self, fpattern, trg_fpattern=None):
......@@ -238,8 +257,8 @@ class Seq2SeqDataset(Dataset):
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
add_beg=self._trg_add_bos_eos,
add_end=self._trg_add_bos_eos)
converters = ComposedConverter([src_converter, trg_converter])
......@@ -252,13 +271,12 @@ class Seq2SeqDataset(Dataset):
fields = converters(line)
lens = [len(field) for field in fields]
sample = SampleInfo(i, lens)
if (self._min_length is None or
sample.min_len >= self._min_length) and (
self._max_length is None or
sample.max_len <= self._max_length or self._truncate):
for field, slot in zip(fields, slots):
slot.append(field[:self._max_length] if self._truncate and
self._max_length is not None else field)
field_ranges = sample.get_ranges(self._min_length,
self._max_length, self._truncate)
if field_ranges:
for field, field_range, slot in zip(fields, field_ranges,
slots):
slot.append(field[field_range[0]:field_range[1]])
self._sample_infos.append(sample)
def _load_lines(self, fpattern, trg_fpattern=None):
......
......@@ -152,7 +152,7 @@ class AttentionModel(Model):
self.decoder = Decoder(trg_vocab_size, embed_dim, hidden_size,
num_layers, dropout_prob, init_scale)
def forward(self, src, src_length, trg, trg_length):
def forward(self, src, src_length, trg):
# encoder
encoder_output, encoder_final_state = self.encoder(src, src_length)
......@@ -174,11 +174,7 @@ class AttentionModel(Model):
# decoder with attention
predict = self.decoder(trg, decoder_initial_states, encoder_output,
encoder_padding_mask)
# for target padding mask
mask = layers.sequence_mask(
trg_length, maxlen=layers.shape(trg)[1], dtype=predict.dtype)
return predict, mask
return predict
class AttentionInferModel(AttentionModel):
......@@ -242,3 +238,90 @@ class AttentionInferModel(AttentionModel):
encoder_output=encoder_output,
encoder_padding_mask=encoder_padding_mask)
return rs
class GreedyEmbeddingHelper(fluid.layers.GreedyEmbeddingHelper):
def __init__(self, embedding_fn, start_tokens, end_token):
if isinstance(start_tokens, int):
self.need_convert_start_tokens = True
self.start_token_value = start_tokens
super(GreedyEmbeddingHelper, self).__init__(embedding_fn, start_tokens,
end_token)
def initialize(self, batch_ref=None):
if getattr(self, "need_convert_start_tokens", False):
assert batch_ref is not None, (
"Need to give batch_ref to get batch size "
"to initialize the tensor for start tokens.")
self.start_tokens = fluid.layers.fill_constant_batch_size_like(
input=fluid.layers.utils.flatten(batch_ref)[0],
shape=[-1],
dtype="int64",
value=self.start_token_value,
input_dim_idx=0)
return super(GreedyEmbeddingHelper, self).initialize()
class BasicDecoder(fluid.layers.BasicDecoder):
def initialize(self, initial_cell_states):
(initial_inputs,
initial_finished) = self.helper.initialize(initial_cell_states)
return initial_inputs, initial_cell_states, initial_finished
class AttentionGreedyInferModel(AttentionModel):
def __init__(self,
src_vocab_size,
trg_vocab_size,
embed_dim,
hidden_size,
num_layers,
dropout_prob=0.,
bos_id=0,
eos_id=1,
beam_size=1,
max_out_len=256):
args = dict(locals())
args.pop("self")
args.pop("__class__", None) # py3
args.pop("beam_size", None)
self.bos_id = args.pop("bos_id")
self.eos_id = args.pop("eos_id")
self.max_out_len = args.pop("max_out_len")
super(AttentionGreedyInferModel, self).__init__(**args)
# dynamic decoder for inference
decoder_helper = GreedyEmbeddingHelper(
start_tokens=bos_id,
end_token=eos_id,
embedding_fn=self.decoder.embedder)
decoder = BasicDecoder(
cell=self.decoder.lstm_attention.cell,
helper=decoder_helper,
output_fn=self.decoder.output_layer)
self.greedy_search_decoder = DynamicDecode(
decoder, max_step_num=max_out_len, is_test=True)
def forward(self, src, src_length):
# encoding
encoder_output, encoder_final_state = self.encoder(src, src_length)
# decoder initial states
decoder_initial_states = [
encoder_final_state,
self.decoder.lstm_attention.cell.get_initial_states(
batch_ref=encoder_output, shape=[self.hidden_size])
]
# attention mask to avoid attending to paddings
src_mask = layers.sequence_mask(
src_length,
maxlen=layers.shape(src)[1],
dtype=encoder_output.dtype)
encoder_padding_mask = (src_mask - 1.0) * 1e9
encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1])
# dynamic decoding with greedy search
rs, _ = self.greedy_search_decoder(
inits=decoder_initial_states,
encoder_output=encoder_output,
encoder_padding_mask=encoder_padding_mask)
return rs.sample_ids
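A hedged numpy sketch (toy projections assumed) of what one step of the greedy helper amounts to: argmax over the output-layer logits gives `sample_ids`, their embeddings become the next step's inputs, and a sequence finishes once `eos` is emitted.

```python
import numpy as np

vocab_size, embed_dim, eos_id = 6, 4, 1            # toy sizes
embedding = np.random.rand(vocab_size, embed_dim)  # stands in for embedder
logits = np.random.rand(2, vocab_size)             # (batch, vocab) from output_fn
sample_ids = np.argmax(logits, axis=-1)            # greedy choice, beam_size == 1
next_inputs = embedding[sample_ids]                # fed to the cell next step
finished = sample_ids == eos_id                    # per-sequence stop flag
```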
......@@ -27,7 +27,10 @@ class CrossEntropyCriterion(Loss):
super(CrossEntropyCriterion, self).__init__()
def forward(self, outputs, labels):
(predict, mask), label = outputs, labels[0]
predict, (trg_length, label) = outputs[0], labels
# for target padding mask
mask = layers.sequence_mask(
trg_length, maxlen=layers.shape(predict)[1], dtype=predict.dtype)
cost = layers.softmax_with_cross_entropy(
logits=predict, label=label, soft_label=False)
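The mask built above zeroes the loss at padded decoder steps; a small numpy sketch (toy losses assumed) of the effect:

```python
import numpy as np

cost = np.ones((2, 4))         # toy per-step cross entropy, (batch, time)
trg_length = np.array([4, 2])  # real decoder steps per sequence
steps = np.arange(4)
mask = (steps[None, :] < trg_length[:, None]).astype(cost.dtype)
print((cost * mask).sum())     # 6.0: padded steps contribute zero
```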
......@@ -151,17 +154,13 @@ class BaseModel(Model):
self.decoder = Decoder(trg_vocab_size, embed_dim, hidden_size,
num_layers, dropout_prob, init_scale)
def forward(self, src, src_length, trg, trg_length):
def forward(self, src, src_length, trg):
# encoder
encoder_output, encoder_final_states = self.encoder(src, src_length)
# decoder
predict = self.decoder(trg, encoder_final_states)
# for target padding mask
mask = layers.sequence_mask(
trg_length, maxlen=layers.shape(trg)[1], dtype=predict.dtype)
return predict, mask
return predict
class BaseInferModel(BaseModel):
......
......@@ -24,6 +24,7 @@ import paddle.fluid as fluid
from paddle.fluid.io import DataLoader
from model import Input, set_device
from metrics import Metric
from callbacks import ProgBarLogger
from args import parse_args
from seq2seq_base import BaseModel, CrossEntropyCriterion
......@@ -31,6 +32,65 @@ from seq2seq_attn import AttentionModel
from reader import create_data_loader
class TrainCallback(ProgBarLogger):
def __init__(self, args, ppl, verbose=2):
super(TrainCallback, self).__init__(1, verbose)
# control metric
self.ppl = ppl
self.batch_size = args.batch_size
def on_train_begin(self, logs=None):
super(TrainCallback, self).on_train_begin(logs)
self.train_metrics += ["ppl"]  # report ppl alongside the loss
self.ppl.reset()
def on_train_batch_end(self, step, logs=None):
batch_loss = logs["loss"][0]
self.ppl.total_loss += batch_loss * self.batch_size
logs["ppl"] = np.exp(self.ppl.total_loss / self.ppl.word_count)
if step > 0 and step % self.ppl.reset_freq == 0:
self.ppl.reset()
super(TrainCallback, self).on_train_batch_end(step, logs)
def on_eval_begin(self, logs=None):
super(TrainCallback, self).on_eval_begin(logs)
self.eval_metrics = ["ppl"]
self.ppl.reset()
def on_eval_batch_end(self, step, logs=None):
batch_loss = logs["loss"][0]
self.ppl.total_loss += batch_loss * self.batch_size
logs["ppl"] = np.exp(self.ppl.total_loss / self.ppl.word_count)
super(TrainCallback, self).on_eval_batch_end(step, logs)
class PPL(Metric):
def __init__(self, reset_freq=100, name=None):
super(PPL, self).__init__()
self._name = name or "ppl"
self.reset_freq = reset_freq
self.reset()
def add_metric_op(self, pred, label):
seq_length = label[0]
word_num = fluid.layers.reduce_sum(seq_length)
return word_num
def update(self, word_num):
self.word_count += word_num
return word_num
def reset(self):
self.total_loss = 0
self.word_count = 0
def accumulate(self):
return self.word_count
def name(self):
return self._name
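A worked example of the bookkeeping above (toy numbers, assuming `logs["loss"][0]` is the batch-mean loss): the callback accumulates mean loss times batch size, the metric op counts target words via `reduce_sum(trg_length)`, and ppl = exp(total_loss / word_count).

```python
import numpy as np

batch_size = 32
batch_losses = [4.2, 3.9, 3.6]  # per-batch mean loss (toy values)
word_counts = [600, 580, 610]   # reduce_sum(trg_length) per batch
total_loss = sum(loss * batch_size for loss in batch_losses)
ppl = np.exp(total_loss / sum(word_counts))
print(round(float(ppl), 4))     # perplexity over the accumulation window
```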
def do_train(args):
device = set_device("gpu" if args.use_gpu else "cpu")
fluid.enable_dygraph(device) if args.eager_run else None
......@@ -47,10 +107,13 @@ def do_train(args):
[None], "int64", name="src_length"),
Input(
[None, None], "int64", name="trg_word"),
]
labels = [
Input(
[None], "int64", name="trg_length"),
Input(
[None, None, 1], "int64", name="label"),
]
labels = [Input([None, None, 1], "int64", name="label"), ]
# def dataloader
train_loader, eval_loader = create_data_loader(args, device)
......@@ -63,15 +126,20 @@ def do_train(args):
learning_rate=args.learning_rate, parameter_list=model.parameters())
optimizer._grad_clip = fluid.clip.GradientClipByGlobalNorm(
clip_norm=args.max_grad_norm)
ppl_metric = PPL()
model.prepare(
optimizer, CrossEntropyCriterion(), inputs=inputs, labels=labels)
optimizer,
CrossEntropyCriterion(),
ppl_metric,
inputs=inputs,
labels=labels)
model.fit(train_data=train_loader,
eval_data=eval_loader,
epochs=args.max_epoch,
eval_freq=1,
save_freq=1,
save_dir=args.model_path,
log_freq=1)
callbacks=[TrainCallback(args, ppl_metric)])
if __name__ == "__main__":
......