Commit 4eda2803 authored by: G guosheng

Code cleanup for fast_decoder of Transformer

Parent 3cbf0f73
@@ -33,9 +33,9 @@ class TrainTaskConfig(object):
class InferTaskConfig(object):
use_gpu = False
use_gpu = True
# the number of examples in one run for sequence generation.
batch_size = 2
batch_size = 10
# the parameters for beam search.
beam_size = 5
max_out_len = 256
@@ -108,7 +108,7 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for sequence length in compile time.
seq_len = ModelHyperParams.max_length
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
@@ -88,7 +88,8 @@ def translate_batch(exe,
output_unk=True):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
implement beam search externally. This is deprecated since a faster beam
search decoder based solely on Fluid operators has been added.
"""
# Prepare data for encoder and run the encoder.
enc_in_data = pad_batch_data(
@@ -255,8 +256,6 @@ def translate_batch(exe,
predict_all = exe.run(decoder,
feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0]
print predict_all.reshape(
[len(beam_inst_map) * beam_size, i + 1, -1])[:, -1, :]
predict_all = np.log(
predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
[:, -1, :])
@@ -275,19 +274,11 @@ def translate_batch(exe,
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::
-1]]
# top_scores_ids = np.asarray(
# sorted(
# top_scores_ids,
# lambda x, y: x / predict_all.shape[-1] - y / predict_all.shape[-1]
# )) # sort by pre_branch and score to compare with fast_infer
top_scores = predict[top_scores_ids]
scores[beam_idx] = top_scores
prev_branchs[beam_idx].append(top_scores_ids /
predict_all.shape[-1])
next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
print prev_branchs[beam_idx][-1]
print next_ids[beam_idx][-1]
print top_scores
if next_ids[beam_idx][-1][0] != eos_idx:
active_beams.append(beam_idx)
if len(active_beams) == 0:
@@ -308,7 +299,32 @@ def translate_batch(exe,
return seqs, scores[:, :n_best].tolist()
def infer(args):
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate at the first
<eos> and remove the <bos> and <eos> tokens unless output_bos / output_eos are set.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = seq[:eos_pos + 1]
return filter(
lambda idx: (output_bos or idx != bos_idx) and \
(output_eos or idx != eos_idx),
seq)
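For illustration, a minimal usage sketch of post_process_seq; the token ids and the bos/eos indices below are made-up assumptions, not values from the model config:

# Hypothetical sequence with assumed bos_idx=0 and eos_idx=1:
# [0, 5, 7, 1, 9] -> truncate at the first <eos>, then drop <bos>/<eos>.
decoded = [0, 5, 7, 1, 9]
print(list(post_process_seq(decoded, bos_idx=0, eos_idx=1,
                            output_bos=False, output_eos=False)))
# expected output: [5, 7]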
def py_infer(test_data, trg_idx2word):
"""
Inference by beam search implemented in Python, while the calculations from
symbols to probabilities are executed by Fluid operators.
"""
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
@@ -355,48 +371,7 @@ def infer(args):
encoder_program = encoder_program.inference_optimize()
decoder_program = decoder_program.inference_optimize()
test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern,
batch_size=args.batch_size,
use_token_batch=False,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = seq[:eos_pos + 1]
return filter(
lambda idx: (output_bos or idx != bos_idx) and \
(output_eos or idx != eos_idx),
seq)
for batch_id, data in enumerate(test_data.batch_generator()):
if batch_id != 0:
continue
batch_seqs, batch_scores = translate_batch(
exe,
[item[0] for item in data],
@@ -425,14 +400,12 @@ def infer(args):
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
print scores
exit(0)
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
bos_idx, n_head, d_model, place):
"""
Put all padded data needed by inference into a dict.
Put all padded data needed by the beam search decoder into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
@@ -492,18 +465,21 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
return input_dict
def fast_infer(args):
def fast_infer(test_data, trg_idx2word):
"""
Inference by a beam search decoder based solely on Fluid operators.
"""
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
ids, scores = fast_decoder(
out_ids, out_scores = fast_decoder(
ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
InferTaskConfig.beam_size, InferTaskConfig.max_out_len,
ModelHyperParams.eos_idx)
ModelHyperParams.weight_sharing, InferTaskConfig.beam_size,
InferTaskConfig.max_out_len, ModelHyperParams.eos_idx)
fluid.io.load_vars(
exe,
@@ -514,28 +490,7 @@ def fast_infer(args):
# This is used here to set dropout to the test mode.
infer_program = fluid.default_main_program().inference_optimize()
test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern,
batch_size=args.batch_size,
use_token_batch=False,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
for batch_id, data in enumerate(test_data.batch_generator()):
if batch_id != 0:
continue
data_input = prepare_batch_input(
data, encoder_data_input_fields + fast_decoder_data_input_fields,
encoder_util_input_fields + fast_decoder_util_input_fields,
@@ -543,10 +498,16 @@
ModelHyperParams.n_head, ModelHyperParams.d_model, place)
seq_ids, seq_scores = exe.run(infer_program,
feed=data_input,
fetch_list=[ids, scores],
fetch_list=[out_ids, out_scores],
return_numpy=False)
# print np.array(seq_ids)#, np.array(seq_scores)
# print seq_ids.lod()#, seq_scores.lod()
# How to parse the results:
# Suppose the lod of seq_ids is:
# [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
# then from lod[0]:
# there are 2 source sentences, beam width is 3.
# from lod[1]:
# the first source sentence has 3 hyps; the lengths are 12, 12, 16
# the second source sentence has 3 hyps; the lengths are 14, 13, 15
hyps = [[] for i in range(len(data))]
scores = [[] for i in range(len(data))]
for i in range(len(seq_ids.lod()[0]) - 1): # for each source sentence
@@ -557,16 +518,38 @@
sub_end = seq_ids.lod()[1][start + j + 1]
hyps[i].append(" ".join([
trg_idx2word[idx]
for idx in np.array(seq_ids)[sub_start:sub_end]
for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end])
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print hyps[i]
print scores[i]
print len(hyps[i]), [len(hyp.split()) for hyp in hyps[i]]
exit(0)
print hyps[i][-1]
if len(hyps[i]) >= InferTaskConfig.n_best:
break
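To make the lod comment above concrete, here is a small standalone sketch in plain Python that recovers the per-hypothesis slices; it uses the example lod values from the comment rather than real decoder output:

# Parse a two-level lod into per-hypothesis token slices (illustration only).
lod = [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
for i in range(len(lod[0]) - 1):              # each source sentence
    start, end = lod[0][i], lod[0][i + 1]     # its beam_size hypotheses
    for j in range(end - start):              # each hypothesis
        sub_start, sub_end = lod[1][start + j], lod[1][start + j + 1]
        print("sentence %d, hyp %d: token slice [%d, %d), length %d" %
              (i, j, sub_start, sub_end, sub_end - sub_start))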
def infer(args, inferencer=fast_infer):
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern,
batch_size=args.batch_size,
use_token_batch=False,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
inferencer(test_data, trg_idx2word)
if __name__ == "__main__":
args = parse_args()
fast_infer(args)
# infer(args)
infer(args)
@@ -6,8 +6,6 @@ import paddle.fluid.layers as layers
from config import *
FLAG = False
def position_encoding_init(n_position, d_pos_vec):
"""
@@ -103,12 +101,6 @@ def multi_head_attention(queries,
"""
scaled_q = layers.scale(x=q, scale=d_model**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
# global FLAG
# if FLAG and attn_bias:
# print "hehehehehe"
# layers.Print(product, message="product")
# layers.Print(attn_bias, message="bias")
# FLAG = False
weights = layers.reshape(
x=layers.elementwise_add(
x=product, y=attn_bias) if attn_bias else product,
@@ -117,19 +109,9 @@ def multi_head_attention(queries,
act="softmax")
weights = layers.reshape(
x=weights, shape=product.shape, actual_shape=post_softmax_shape)
# global FLAG
# if FLAG:
# print "hehehehehe"
# layers.Print(scaled_q)
# layers.Print(k)
# layers.Print(v)
# layers.Print(product)
# layers.Print(weights)
# FLAG = False
if dropout_rate:
weights = layers.dropout(
weights, dropout_prob=dropout_rate, is_test=False)
out = layers.matmul(weights, v)
return out
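For reference, a compact numpy sketch of the computation this block builds in the Fluid graph (an illustration with made-up shapes, not part of the commit); it mirrors the scaling by d_model**-0.5, the optional attention bias, the softmax, and the weighted sum over values:

import numpy as np

def scaled_dot_product_attention_np(q, k, v, d_model, attn_bias=None):
    # scale queries, compute similarity scores, add mask bias, softmax, weighted sum
    product = np.dot(q * d_model ** -0.5, k.T)
    if attn_bias is not None:
        product = product + attn_bias       # e.g. large negatives on padded positions
    weights = np.exp(product - product.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.dot(weights, v)

q = np.random.rand(2, 8)
k = np.random.rand(3, 8)
v = np.random.rand(3, 8)
print(scaled_dot_product_attention_np(q, k, v, d_model=8).shape)  # (2, 8)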
@@ -138,13 +120,7 @@ def multi_head_attention(queries,
if cache is not None: # use cache and concat time steps
k = cache["k"] = layers.concat([cache["k"], k], axis=1)
v = cache["v"] = layers.concat([cache["v"], v], axis=1)
# global FLAG
# if FLAG:
# print "hehehehehe"
# layers.Print(q)
# layers.Print(k)
# layers.Print(v)
# FLAG = False
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
@@ -153,16 +129,12 @@ def multi_head_attention(queries,
dropout_rate)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
bias_attr=False,
num_flatten_dims=2)
# global FLAG
# if FLAG:
# print "hehehehehe"
# layers.Print(proj_out)
# FLAG = False
return proj_out
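The cache handling above (concatenating the new key/value onto the cached history along the time axis) is what makes step-by-step decoding incremental. A minimal numpy sketch of that idea, with assumed shapes rather than the real tensors:

import numpy as np

d_model = 4
cache = {"k": np.zeros((1, 0, d_model)), "v": np.zeros((1, 0, d_model))}
for step in range(3):
    # only the current token's key/value is computed at each step ...
    new_k = np.random.rand(1, 1, d_model)
    new_v = np.random.rand(1, 1, d_model)
    # ... and appended along the time axis, so earlier steps are never recomputed
    cache["k"] = np.concatenate([cache["k"], new_k], axis=1)
    cache["v"] = np.concatenate([cache["v"], new_v], axis=1)
print(cache["k"].shape)  # (1, 3, 4): keys of all decoded steps so far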
@@ -391,15 +363,22 @@ def decoder(dec_input,
The decoder is composed of a stack of identical decoder_layer layers.
"""
for i in range(n_layer):
if i == 0: #n_layer-1:
global FLAG
FLAG = True
dec_output = decoder_layer(
dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias, n_head,
d_key, d_value, d_model, d_inner_hid, dropout_rate,
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, None
if caches is None else caches[i])
dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape,
None if caches is None else caches[i], )
dec_input = dec_output
return dec_output
@@ -625,12 +604,17 @@ def fast_decode(
d_model,
d_inner_hid,
dropout_rate,
weight_sharing,
beam_size,
max_out_len,
eos_idx, ):
"""
Use beam search to decode. Caches are used to store the states of previous
steps, which makes decoding faster.
"""
enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
dropout_rate)
dropout_rate, weight_sharing)
start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
@@ -643,16 +627,14 @@ def fast_decode(
shape=[1], dtype=start_tokens.dtype, value=max_out_len)
step_idx = layers.fill_constant(
shape=[1], dtype=start_tokens.dtype, value=0)
# cond = layers.fill_constant(
# shape=[1], dtype='bool', value=1, force_cpu=True)
cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
# init_scores = layers.fill_constant_batch_size_like(
# input=start_tokens, shape=[-1, 1], dtype="float32", value=0)
# array states
# array states will be stored for each step.
ids = layers.array_write(start_tokens, step_idx)
scores = layers.array_write(init_scores, step_idx)
# cell states (can be overwrited)
# cell states will be overwritten at each step.
# caches contain the states of history steps to reduce redundant
# computation in the decoder.
caches = [{
"k": layers.fill_constant_batch_size_like(
input=start_tokens,
@@ -668,9 +650,10 @@ def fast_decode(
with while_op.block():
pre_ids = layers.array_read(array=ids, i=step_idx)
pre_scores = layers.array_read(array=scores, i=step_idx)
# sequence_expand can gather sequences according to lod and thus can be
# used in beam search to select the states corresponding to the chosen ids.
pre_src_attn_bias = layers.sequence_expand(
x=trg_src_attn_bias, y=pre_scores)
# layers.Print(pre_src_attn_bias)
pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
pre_caches = [{
"k": layers.sequence_expand(
@@ -687,13 +670,6 @@ def fast_decode(
y=layers.increment(
x=step_idx, value=1.0, in_place=False),
axis=0)
# layers.Print(pre_ids, summarize=10)
# layers.Print(pre_pos, summarize=10)
# layers.Print(pre_enc_output, summarize=10)
# layers.Print(pre_src_attn_bias, summarize=10)
# layers.Print(pre_caches[0]["k"], summarize=10)
# layers.Print(pre_caches[0]["v"], summarize=10)
# layers.Print(slf_attn_post_softmax_shape)
logits = wrap_decoder(
trg_vocab_size,
max_in_len,
@@ -704,19 +680,16 @@ def fast_decode(
d_model,
d_inner_hid,
dropout_rate,
weight_sharing,
dec_inputs=(
pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
enc_output=pre_enc_output,
caches=pre_caches)
# layers.Print(logits)
topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size)
# layers.Print(topk_scores)
# layers.Print(topk_indices)
accu_scores = layers.elementwise_add(
# x=layers.log(x=layers.softmax(topk_scores)),
x=layers.log(topk_scores),
y=layers.reshape(
pre_scores, shape=[-1]),
@@ -739,9 +712,6 @@ def fast_decode(
for i in range(n_layer):
layers.assign(pre_caches[i]["k"], caches[i]["k"])
layers.assign(pre_caches[i]["v"], caches[i]["v"])
layers.Print(selected_ids)
layers.Print(selected_scores)
# layers.Print(caches[-1]["k"])
layers.assign(
layers.elementwise_add(
x=slf_attn_pre_softmax_shape,
@@ -755,12 +725,8 @@ def fast_decode(
length_cond = layers.less_than(x=step_idx, y=max_len)
finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
# layers.Print(length_cond)
# layers.Print(finish_cond)
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
layers.Print(step_idx)
# finished_ids, finished_scores = layers.beam_search_decode(ids, scores,
# eos_idx)
finished_ids, finished_scores = layers.beam_search_decode(
ids, scores, beam_size=beam_size, end_id=eos_idx)
return finished_ids, finished_scores
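As a rough picture of why sequence_expand works for the state selection mentioned above (an assumption-level numpy illustration, not the full Fluid semantics): the lod carried through beam search records how many surviving branches each parent beam contributes, so expanding the parent states by that lod repeats exactly the rows the next step needs.

import numpy as np

parent_states = np.array([[0.1], [0.2], [0.3]])  # one row per parent beam
lod = [0, 2, 2, 3]   # parent 0 kept twice, parent 1 pruned, parent 2 kept once
expanded = np.concatenate([
    np.repeat(parent_states[i:i + 1], lod[i + 1] - lod[i], axis=0)
    for i in range(len(lod) - 1)
])
print(expanded)      # [[0.1], [0.1], [0.3]]: states for the surviving branches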
@@ -198,7 +198,8 @@ class DataReader(object):
for line in f_obj:
fields = line.strip().split(self._delimiter)
if len(fields) != 2 or (self._only_src and len(fields) != 1):
if (not self._only_src and len(fields) != 2) or (self._only_src and
len(fields) != 1):
continue
sample_words = []
@@ -275,7 +276,7 @@ class DataReader(object):
for sample_idx in self._sample_idxs:
if self._only_src:
yield (self._src_seq_ids[sample_idx])
yield (self._src_seq_ids[sample_idx], )
else:
yield (self._src_seq_ids[sample_idx],
self._trg_seq_ids[sample_idx][:-1],