Commit bc039c59 authored by guosheng

Add greedy search.

Add PPL metric.
Parent ae47e2a8
@@ -28,7 +28,7 @@ from paddle.fluid.io import DataLoader
 from model import Input, set_device
 from args import parse_args
 from seq2seq_base import BaseInferModel
-from seq2seq_attn import AttentionInferModel
+from seq2seq_attn import AttentionInferModel, AttentionGreedyInferModel
 from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_infer_input
@@ -87,7 +87,8 @@ def do_predict(args):
         num_workers=0,
         return_list=True)
-    model_maker = AttentionInferModel if args.attention else BaseInferModel
+    # model_maker = AttentionInferModel if args.attention else BaseInferModel
+    model_maker = AttentionGreedyInferModel if args.attention else BaseInferModel
     model = model_maker(
         args.src_vocab_size,
         args.tar_vocab_size,
@@ -111,6 +112,8 @@ def do_predict(args):
     with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
         for data in data_loader():
             finished_seq = model.test(inputs=flatten(data))[0]
+            finished_seq = finished_seq[:, :, np.newaxis] if len(
+                finished_seq.shape) == 2 else finished_seq
             finished_seq = np.transpose(finished_seq, [0, 2, 1])
             for ins in finished_seq:
                 for beam_idx, beam in enumerate(ins):
......
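The two lines added above normalize the decoder output before writing results: greedy search returns ids of shape [batch, time], while beam search returns [batch, time, beam], so a singleton beam axis is inserted when needed and the array is transposed to [batch, beam, time] for the per-beam writing loop. A minimal NumPy sketch of that normalization (hypothetical helper name and toy shapes, not repository code):

```python
import numpy as np

def normalize_finished_seq(finished_seq):
    # Greedy search yields [batch, time]; add a singleton beam axis so the
    # layout matches beam-search output [batch, time, beam].
    if len(finished_seq.shape) == 2:
        finished_seq = finished_seq[:, :, np.newaxis]
    # Reorder to [batch, beam, time] so each row of a sample is one beam.
    return np.transpose(finished_seq, [0, 2, 1])

greedy_out = np.array([[4, 7, 1], [5, 2, 1]])        # [batch=2, time=3]
beam_out = np.random.randint(0, 10, size=(2, 3, 4))  # [batch=2, time=3, beam=4]
print(normalize_finished_seq(greedy_out).shape)      # (2, 1, 3)
print(normalize_finished_seq(beam_out).shape)        # (2, 4, 3)
```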
@@ -44,7 +44,8 @@ def create_data_loader(args, device, for_train=True):
             end_mark="</s>",
             unk_mark="<unk>",
             max_length=args.max_len if i == 0 else None,
-            truncate=True)
+            truncate=True,
+            trg_add_bos_eos=True)
         (args.src_vocab_size, args.tar_vocab_size, bos_id, eos_id,
          unk_id) = dataset.get_vocab_summary()
         batch_sampler = Seq2SeqBatchSampler(
@@ -53,7 +54,8 @@ def create_data_loader(args, device, for_train=True):
             batch_size=args.batch_size,
             pool_size=args.batch_size * 20,
             sort_type=SortType.POOL,
-            shuffle=False if args.enable_ce else True)
+            shuffle=False if args.enable_ce else True,
+            distribute_mode=True if i == 0 else False)
         data_loader = DataLoader(
             dataset=dataset,
             batch_sampler=batch_sampler,
@@ -73,7 +75,7 @@ def prepare_train_input(insts, bos_id, eos_id, pad_id):
     src, src_length = pad_batch_data(
         [inst[0] for inst in insts], pad_id=pad_id)
     trg, trg_length = pad_batch_data(
-        [[bos_id] + inst[1] + [eos_id] for inst in insts], pad_id=pad_id)
+        [inst[1] for inst in insts], pad_id=pad_id)
     trg_length = trg_length - 1
     return src, src_length, trg[:, :-1], trg_length, trg[:, 1:, np.newaxis]
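Since the dataset now wraps targets with bos/eos (via trg_add_bos_eos above), prepare_train_input no longer adds the markers itself; it only pads and splits the already-wrapped sequence into decoder input and label. A hedged sketch of that split with illustrative ids (not repository data):

```python
import numpy as np

bos_id, eos_id, pad_id = 0, 1, 2
# Targets arrive from the dataset already wrapped as [bos] + ids + [eos].
trg = np.array([[bos_id, 5, 6, 7, eos_id],
                [bos_id, 8, 9, eos_id, pad_id]])  # padded batch
trg_length = np.array([5, 4]) - 1                 # lengths minus one, as in the hunk above
decoder_input = trg[:, :-1]                       # [bos] + ids, fed to the decoder
label = trg[:, 1:, np.newaxis]                    # ids + [eos], used as the label
print(decoder_input)
print(label[:, :, 0])
```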
@@ -165,9 +167,24 @@ class TokenBatchCreator(object):
 class SampleInfo(object):
     def __init__(self, i, lens):
         self.i = i
-        # to be consistent with origianl reader implementation
-        self.min_len = lens[0]
-        self.max_len = lens[0]
+        self.lens = lens
+
+    def get_ranges(self, min_length=None, max_length=None, truncate=False):
+        ranges = []
+        # source
+        if (min_length is None or self.lens[0] >= min_length) and (
+                max_length is None or self.lens[0] <= max_length or truncate):
+            end = max_length if truncate and max_length else self.lens[0]
+            ranges.append([0, end])
+        # target
+        if len(self.lens) == 2:
+            if (min_length is None or self.lens[1] >= min_length) and (
+                    max_length is None or self.lens[1] <= max_length + 2 or
+                    truncate):
+                end = max_length + 2 if truncate and max_length else self.lens[1]
+                ranges.append([0, end])
+        return ranges if len(ranges) == len(self.lens) else None


 class MinMaxFilter(object):
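The new SampleInfo.get_ranges replaces the old min_len/max_len fields: it returns one [start, end] slice per field (source, and target when present) only if every field passes the length filter, and None otherwise; with truncate=True the end is clamped to max_length (max_length + 2 for the target, which now carries bos/eos). A small standalone illustration of the expected behaviour with toy lengths:

```python
# Minimal standalone copy of the added SampleInfo, for illustration only.
class SampleInfo(object):
    def __init__(self, i, lens):
        self.i = i
        self.lens = lens

    def get_ranges(self, min_length=None, max_length=None, truncate=False):
        ranges = []
        if (min_length is None or self.lens[0] >= min_length) and (
                max_length is None or self.lens[0] <= max_length or truncate):
            end = max_length if truncate and max_length else self.lens[0]
            ranges.append([0, end])
        if len(self.lens) == 2:
            if (min_length is None or self.lens[1] >= min_length) and (
                    max_length is None or self.lens[1] <= max_length + 2 or
                    truncate):
                end = max_length + 2 if truncate and max_length else self.lens[1]
                ranges.append([0, end])
        return ranges if len(ranges) == len(self.lens) else None

print(SampleInfo(0, [10, 12]).get_ranges(max_length=50))                 # [[0, 10], [0, 12]]
print(SampleInfo(1, [60, 70]).get_ranges(max_length=50))                 # None: too long, dropped
print(SampleInfo(2, [60, 70]).get_ranges(max_length=50, truncate=True))  # [[0, 50], [0, 52]]
```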
@@ -197,6 +214,7 @@ class Seq2SeqDataset(Dataset):
                  end_mark="<e>",
                  unk_mark="<unk>",
                  trg_fpattern=None,
+                 trg_add_bos_eos=False,
                  byte_data=False,
                  min_length=None,
                  max_length=None,
@@ -220,6 +238,7 @@ class Seq2SeqDataset(Dataset):
         self._min_length = min_length
         self._max_length = max_length
         self._truncate = truncate
+        self._trg_add_bos_eos = trg_add_bos_eos
         self.load_src_trg_ids(fpattern, trg_fpattern)

     def load_src_trg_ids(self, fpattern, trg_fpattern=None):
@@ -238,8 +257,8 @@ class Seq2SeqDataset(Dataset):
             end=self._eos_idx,
             unk=self._unk_idx,
             delimiter=self._token_delimiter,
-            add_beg=False,
-            add_end=False)
+            add_beg=True if self._trg_add_bos_eos else False,
+            add_end=True if self._trg_add_bos_eos else False)

         converters = ComposedConverter([src_converter, trg_converter])
@@ -252,13 +271,12 @@ class Seq2SeqDataset(Dataset):
                 fields = converters(line)
                 lens = [len(field) for field in fields]
                 sample = SampleInfo(i, lens)
-                if (self._min_length is None or
-                        sample.min_len >= self._min_length) and (
-                            self._max_length is None or
-                            sample.max_len <= self._max_length or self._truncate):
-                    for field, slot in zip(fields, slots):
-                        slot.append(field[:self._max_length] if self._truncate and
-                                    self._max_length is not None else field)
+                field_ranges = sample.get_ranges(self._min_length,
+                                                 self._max_length, self._truncate)
+                if field_ranges:
+                    for field, field_range, slot in zip(fields, field_ranges,
+                                                        slots):
+                        slot.append(field[field_range[0]:field_range[1]])
                     self._sample_infos.append(sample)

     def _load_lines(self, fpattern, trg_fpattern=None):
......
@@ -152,7 +152,7 @@ class AttentionModel(Model):
         self.decoder = Decoder(trg_vocab_size, embed_dim, hidden_size,
                                num_layers, dropout_prob, init_scale)

-    def forward(self, src, src_length, trg, trg_length):
+    def forward(self, src, src_length, trg):
         # encoder
         encoder_output, encoder_final_state = self.encoder(src, src_length)
@@ -174,11 +174,7 @@ class AttentionModel(Model):
         # decoder with attentioon
         predict = self.decoder(trg, decoder_initial_states, encoder_output,
                                encoder_padding_mask)
-
-        # for target padding mask
-        mask = layers.sequence_mask(
-            trg_length, maxlen=layers.shape(trg)[1], dtype=predict.dtype)
-        return predict, mask
+        return predict


 class AttentionInferModel(AttentionModel):
@@ -242,3 +238,90 @@ class AttentionInferModel(AttentionModel):
             encoder_output=encoder_output,
             encoder_padding_mask=encoder_padding_mask)
         return rs
+
+
+class GreedyEmbeddingHelper(fluid.layers.GreedyEmbeddingHelper):
+    def __init__(self, embedding_fn, start_tokens, end_token):
+        if isinstance(start_tokens, int):
+            self.need_convert_start_tokens = True
+            self.start_token_value = start_tokens
+        super(GreedyEmbeddingHelper, self).__init__(embedding_fn, start_tokens,
+                                                    end_token)
+
+    def initialize(self, batch_ref=None):
+        if getattr(self, "need_convert_start_tokens", False):
+            assert batch_ref is not None, (
+                "Need to give batch_ref to get batch size "
+                "to initialize the tensor for start tokens.")
+            self.start_tokens = fluid.layers.fill_constant_batch_size_like(
+                input=fluid.layers.utils.flatten(batch_ref)[0],
+                shape=[-1],
+                dtype="int64",
+                value=self.start_token_value,
+                input_dim_idx=0)
+        return super(GreedyEmbeddingHelper, self).initialize()
+
+
+class BasicDecoder(fluid.layers.BasicDecoder):
+    def initialize(self, initial_cell_states):
+        (initial_inputs,
+         initial_finished) = self.helper.initialize(initial_cell_states)
+        return initial_inputs, initial_cell_states, initial_finished
+
+
+class AttentionGreedyInferModel(AttentionModel):
+    def __init__(self,
+                 src_vocab_size,
+                 trg_vocab_size,
+                 embed_dim,
+                 hidden_size,
+                 num_layers,
+                 dropout_prob=0.,
+                 bos_id=0,
+                 eos_id=1,
+                 beam_size=1,
+                 max_out_len=256):
+        args = dict(locals())
+        args.pop("self")
+        args.pop("__class__", None)  # py3
+        args.pop("beam_size", None)
+        self.bos_id = args.pop("bos_id")
+        self.eos_id = args.pop("eos_id")
+        self.max_out_len = args.pop("max_out_len")
+        super(AttentionGreedyInferModel, self).__init__(**args)
+        # dynamic decoder for inference
+        decoder_helper = GreedyEmbeddingHelper(
+            start_tokens=bos_id,
+            end_token=eos_id,
+            embedding_fn=self.decoder.embedder)
+        decoder = BasicDecoder(
+            cell=self.decoder.lstm_attention.cell,
+            helper=decoder_helper,
+            output_fn=self.decoder.output_layer)
+        self.greedy_search_decoder = DynamicDecode(
+            decoder, max_step_num=max_out_len, is_test=True)
+
+    def forward(self, src, src_length):
+        # encoding
+        encoder_output, encoder_final_state = self.encoder(src, src_length)
+        # decoder initial states
+        decoder_initial_states = [
+            encoder_final_state,
+            self.decoder.lstm_attention.cell.get_initial_states(
+                batch_ref=encoder_output, shape=[self.hidden_size])
+        ]
+        # attention mask to avoid paying attention to paddings
+        src_mask = layers.sequence_mask(
+            src_length,
+            maxlen=layers.shape(src)[1],
+            dtype=encoder_output.dtype)
+        encoder_padding_mask = (src_mask - 1.0) * 1e9
+        encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1])
+        # dynamic decoding with greedy search
+        rs, _ = self.greedy_search_decoder(
+            inits=decoder_initial_states,
+            encoder_output=encoder_output,
+            encoder_padding_mask=encoder_padding_mask)
+        return rs.sample_ids
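AttentionGreedyInferModel reuses the attention decoder cell but swaps the beam-search decoder for one driven by GreedyEmbeddingHelper, so at every step the next token is simply the argmax of the output logits rather than one of beam_size candidates. A hedged, framework-free sketch of that greedy loop (pure NumPy with a hypothetical step function, not the Paddle implementation):

```python
import numpy as np

def greedy_decode(step_fn, init_state, bos_id, eos_id, max_len):
    """Greedy search: feed back the argmax token until <eos> or max_len.

    step_fn(token_ids, state) -> (logits [batch, vocab], new_state) stands in
    for one decoder step (embedding + attention cell + output layer).
    """
    batch = init_state.shape[0]
    tokens = np.full([batch], bos_id, dtype=np.int64)   # start tokens
    finished = np.zeros([batch], dtype=bool)
    state, outputs = init_state, []
    for _ in range(max_len):
        logits, state = step_fn(tokens, state)
        tokens = logits.argmax(axis=-1)                  # greedy choice per step
        outputs.append(tokens)
        finished |= tokens == eos_id
        if finished.all():
            break
    return np.stack(outputs, axis=1)                     # [batch, time] sample ids

# Toy step function: logits depend only on the state, which drifts each step.
rng = np.random.default_rng(0)
def toy_step(tokens, state):
    return state @ rng.standard_normal((4, 10)), state + 0.1

print(greedy_decode(toy_step, rng.standard_normal((2, 4)),
                    bos_id=0, eos_id=1, max_len=5).shape)
```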
@@ -27,7 +27,10 @@ class CrossEntropyCriterion(Loss):
         super(CrossEntropyCriterion, self).__init__()

     def forward(self, outputs, labels):
-        (predict, mask), label = outputs, labels[0]
+        predict, (trg_length, label) = outputs[0], labels
+        # for target padding mask
+        mask = layers.sequence_mask(
+            trg_length, maxlen=layers.shape(predict)[1], dtype=predict.dtype)
         cost = layers.softmax_with_cross_entropy(
             logits=predict, label=label, soft_label=False)
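With the models no longer returning a target mask, the criterion rebuilds it from trg_length so padded target positions do not contribute to the loss. A hedged NumPy sketch of the idea; the final reduction over the masked per-token cost is an assumption here, since that part of the file is not shown in this diff:

```python
import numpy as np

def masked_cross_entropy(logits, labels, trg_length):
    # logits: [batch, time, vocab]; labels: [batch, time]; trg_length: [batch]
    batch, time, _ = logits.shape
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    token_loss = -np.log(probs[np.arange(batch)[:, None], np.arange(time), labels])
    # Equivalent of layers.sequence_mask: 1.0 up to each length, 0.0 on padding.
    mask = (np.arange(time)[None, :] < trg_length[:, None]).astype(token_loss.dtype)
    # Assumption: sum over valid tokens per sentence, then average over the batch.
    return (token_loss * mask).sum(axis=1).mean()

logits = np.random.randn(2, 4, 6)
labels = np.random.randint(0, 6, size=(2, 4))
print(masked_cross_entropy(logits, labels, np.array([4, 2])))
```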
@@ -151,17 +154,13 @@ class BaseModel(Model):
         self.decoder = Decoder(trg_vocab_size, embed_dim, hidden_size,
                                num_layers, dropout_prob, init_scale)

-    def forward(self, src, src_length, trg, trg_length):
+    def forward(self, src, src_length, trg):
         # encoder
         encoder_output, encoder_final_states = self.encoder(src, src_length)

         # decoder
         predict = self.decoder(trg, encoder_final_states)
-
-        # for target padding mask
-        mask = layers.sequence_mask(
-            trg_length, maxlen=layers.shape(trg)[1], dtype=predict.dtype)
-        return predict, mask
+        return predict


 class BaseInferModel(BaseModel):
......
@@ -24,6 +24,7 @@ import paddle.fluid as fluid
 from paddle.fluid.io import DataLoader

 from model import Input, set_device
+from metrics import Metric
 from callbacks import ProgBarLogger
 from args import parse_args
 from seq2seq_base import BaseModel, CrossEntropyCriterion
@@ -31,6 +32,65 @@ from seq2seq_attn import AttentionModel
 from reader import create_data_loader


+class TrainCallback(ProgBarLogger):
+    def __init__(self, args, ppl, verbose=2):
+        super(TrainCallback, self).__init__(1, verbose)
+        # control metric
+        self.ppl = ppl
+        self.batch_size = args.batch_size
+
+    def on_train_begin(self, logs=None):
+        super(TrainCallback, self).on_train_begin(logs)
+        self.train_metrics += ["ppl"]  # remove loss to not print it
+        self.ppl.reset()
+
+    def on_train_batch_end(self, step, logs=None):
+        batch_loss = logs["loss"][0]
+        self.ppl.total_loss += batch_loss * self.batch_size
+        logs["ppl"] = np.exp(self.ppl.total_loss / self.ppl.word_count)
+        if step > 0 and step % self.ppl.reset_freq == 0:
+            self.ppl.reset()
+        super(TrainCallback, self).on_train_batch_end(step, logs)
+
+    def on_eval_begin(self, logs=None):
+        super(TrainCallback, self).on_eval_begin(logs)
+        self.eval_metrics = ["ppl"]
+        self.ppl.reset()
+
+    def on_eval_batch_end(self, step, logs=None):
+        batch_loss = logs["loss"][0]
+        self.ppl.total_loss += batch_loss * self.batch_size
+        logs["ppl"] = np.exp(self.ppl.total_loss / self.ppl.word_count)
+        super(TrainCallback, self).on_eval_batch_end(step, logs)
+
+
+class PPL(Metric):
+    def __init__(self, reset_freq=100, name=None):
+        super(PPL, self).__init__()
+        self._name = name or "ppl"
+        self.reset_freq = reset_freq
+        self.reset()
+
+    def add_metric_op(self, pred, label):
+        seq_length = label[0]
+        word_num = fluid.layers.reduce_sum(seq_length)
+        return word_num
+
+    def update(self, word_num):
+        self.word_count += word_num
+        return word_num
+
+    def reset(self):
+        self.total_loss = 0
+        self.word_count = 0
+
+    def accumulate(self):
+        return self.word_count
+
+    def name(self):
+        return self._name
+
+
 def do_train(args):
     device = set_device("gpu" if args.use_gpu else "cpu")
     fluid.enable_dygraph(device) if args.eager_run else None
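The PPL metric and TrainCallback together report perplexity as the exponential of the average per-word loss: add_metric_op counts target words per batch (reduce_sum over trg_length), the callback accumulates batch_loss * batch_size into total_loss, and logs ppl = exp(total_loss / word_count), resetting both counters every reset_freq training steps. A small worked example of that bookkeeping with made-up numbers:

```python
import numpy as np

total_loss, word_count = 0.0, 0
batch_size = 2
# (mean loss reported for the batch, number of target words in the batch)
batches = [(4.2, 37), (3.9, 41), (3.6, 40)]
for batch_loss, word_num in batches:
    word_count += word_num                 # PPL.update adds reduce_sum(trg_length)
    total_loss += batch_loss * batch_size  # TrainCallback scales by batch size
    print("ppl = %.2f" % np.exp(total_loss / word_count))
```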
@@ -47,10 +107,13 @@ def do_train(args):
             [None], "int64", name="src_length"),
         Input(
             [None, None], "int64", name="trg_word"),
+    ]
+    labels = [
         Input(
             [None], "int64", name="trg_length"),
+        Input(
+            [None, None, 1], "int64", name="label"),
     ]
-    labels = [Input([None, None, 1], "int64", name="label"), ]

     # def dataloader
     train_loader, eval_loader = create_data_loader(args, device)
@@ -63,15 +126,20 @@ def do_train(args):
         learning_rate=args.learning_rate, parameter_list=model.parameters())
     optimizer._grad_clip = fluid.clip.GradientClipByGlobalNorm(
         clip_norm=args.max_grad_norm)
+    ppl_metric = PPL()
     model.prepare(
-        optimizer, CrossEntropyCriterion(), inputs=inputs, labels=labels)
+        optimizer,
+        CrossEntropyCriterion(),
+        ppl_metric,
+        inputs=inputs,
+        labels=labels)
     model.fit(train_data=train_loader,
               eval_data=eval_loader,
               epochs=args.max_epoch,
               eval_freq=1,
               save_freq=1,
               save_dir=args.model_path,
-              log_freq=1)
+              callbacks=[TrainCallback(args, ppl_metric)])


 if __name__ == "__main__":
......