Commit d963d69e authored by guosheng

Make fast decoder run smoothly in Transformer

Parent dead21e4
@@ -33,12 +33,12 @@ class TrainTaskConfig(object):
class InferTaskConfig(object):
use_gpu = True
use_gpu = False
# the number of examples in one run for sequence generation.
batch_size = 10
batch_size = 2
# the parameters for beam search.
beam_size = 5
max_length = 30
max_out_len = 30
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
@@ -103,24 +103,28 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
break
# The placeholder for batch_size at compile time. It must currently be -1 to be
# consistent with some ops' infer-shape output at compile time, such as the
# sequence_expand op used in the beam-search decoder.
batch_size = -1
# The placeholder for sequence length at compile time.
seq_len = ModelHyperParams.max_length
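As a minimal illustration (the layers.data call and names here are assumptions, not part of this commit), a -1 dimension passes compile-time infer-shape and is resolved from the fed data at run time:

import paddle.fluid as fluid

# Sketch only: -1 marks the batch-like dimension as unknown at compile time;
# the real size comes from the tensor fed at run time.
src_word = fluid.layers.data(
    name="src_word",
    shape=[-1, 1],
    dtype="int64",
    append_batch_size=False)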
# Here are the data shapes and data types of all inputs.
# The shapes act as placeholders and are set only to pass the infer-shape
# checks at compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1]
"src_word": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
"src_word": [(batch_size * seq_len, 1L), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1]
"src_pos": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
"src_pos": [(batch_size * seq_len, 1L), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias":
[(batch_size, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
"src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This shape input is used to reshape the output of the embedding layer.
"src_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
@@ -129,24 +133,23 @@ input_descs = {
"src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_word": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
"trg_word": [(batch_size * seq_len, 1L), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_pos": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
"trg_pos": [(batch_size * seq_len, 1L), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head,
(ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
"trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head,
(ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
"trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This shape input is used to reshape the output of the embedding layer.
"trg_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
@@ -162,17 +165,18 @@ input_descs = {
# This input is used in the independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(1, (ModelHyperParams.max_length + 1),
ModelHyperParams.d_model), "float32"],
"enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
"lbl_word": [(batch_size * seq_len, 1L), "int64"],
# This input is used to mask out the loss of padding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
# These two inputs are used for beam search decoder.
# "start_token": [(1 * 1, 1L), "int64"],
"lbl_weight": [(batch_size * seq_len, 1L), "float32"],
# These inputs are used to change the shape tensors in the beam-search decoder.
"trg_slf_attn_pre_softmax_shape_delta": [(2L, ), "int32"],
"trg_slf_attn_post_softmax_shape_delta": [(4L, ), "int32"],
"init_score": [(batch_size, 1L), "float32"],
}
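Each descriptor above is (shape, dtype[, lod_level]); a small sketch (assuming this tuple layout) of how an entry is unpacked before the data layers are created in make_all_inputs further below:

# Sketch only: the optional third element is the LoD level, defaulting to 0.
desc = input_descs["trg_word"]
shape, dtype = desc[0], desc[1]
lod_level = desc[2] if len(desc) == 3 else 0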
# Names of the position encoding table, which will be initialized externally.
@@ -203,7 +207,12 @@ decoder_util_input_fields = (
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
fast_decoder_data_fields = (
# In the fast decoder, trg_pos (containing only the current time step) is
# generated by ops, and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
"trg_word",
# "start_token",
"init_score",
"trg_src_attn_bias", )
fast_decoder_util_input_fields = decoder_util_input_fields + (
"trg_slf_attn_pre_softmax_shape_delta",
"trg_slf_attn_post_softmax_shape_delta", )
@@ -397,7 +397,7 @@ def infer(args):
(decoder_data_input_fields[-1], ),
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.max_out_len,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
@@ -416,16 +416,135 @@ def infer(args):
print(" ".join([trg_idx2word[idx] for idx in seq]))
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
bos_idx, n_head, d_model, place):
"""
Put all padded data needed by inference into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
# start tokens
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
# These shape tensors are used in reshape_op.
src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([-1, 1, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
[-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, 1], dtype="int32") # only the first time step
trg_slf_attn_post_softmax_shape = np.array(
[-1, n_head, 1, 1], dtype="int32") # only the first time step
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
[-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
# These inputs are used to change the shapes inside the loop of the while op.
attn_pre_softmax_shape_delta = np.array([0, 1], dtype="int32")
attn_post_softmax_shape_delta = np.array([0, 0, 0, 1], dtype="int32")
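# Illustration (hypothetical step t): the pre/post reshape targets start at
# [-1, 1] and [-1, n_head, 1, 1]; adding the deltas once per iteration moves
# them to [-1, 2] and [-1, n_head, 1, 2], and so on as decoding lengthens.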
def to_lodtensor(data, place, lod=None):
data_tensor = fluid.LoDTensor()
data_tensor.set(data, place)
if lod is not None:
data_tensor.set_lod(lod)
return data_tensor
# The beam_search op must consume tensors with LoD.
init_score = to_lodtensor(
np.zeros_like(
trg_word, dtype="float32"),
place, [range(trg_word.shape[0] + 1)] * 2)
trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2)
data_input_dict = dict(
zip(data_input_names, [
src_word, src_pos, src_slf_attn_bias, trg_word, init_score,
trg_src_attn_bias
]))
util_input_dict = dict(
zip(util_input_names, [
src_data_shape, src_slf_attn_pre_softmax_shape,
src_slf_attn_post_softmax_shape, trg_data_shape,
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta
]))
input_dict = dict(data_input_dict.items() + util_input_dict.items())
return input_dict
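The 2-level LoD attached to init_score and trg_word above gives each source sentence exactly one initial beam branch. A toy sketch of what [range(n + 1)] * 2 yields for a batch of three sentences (Python 2 range returns a list):

# Sketch only: level 0 maps sentences to branches, level 1 maps branches to
# tensor rows; initially both are the identity partition.
lod = [range(3 + 1)] * 2
print(lod)  # [[0, 1, 2, 3], [0, 1, 2, 3]]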
def fast_infer(args):
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
ids, scores = fast_decoder(
ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
InferTaskConfig.beam_size, InferTaskConfig.max_out_len,
ModelHyperParams.eos_idx)
fluid.io.load_vars(
exe,
InferTaskConfig.model_path,
vars=filter(lambda var: isinstance(var, fluid.framework.Parameter),
fluid.default_main_program().list_vars()))
# This is used here to switch dropout to test mode.
infer_program = fluid.default_main_program().inference_optimize()
test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern,
batch_size=args.batch_size,
use_token_batch=False,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
for batch_id, data in enumerate(test_data.batch_generator()):
data_input = prepare_batch_input(
data, encoder_data_input_fields + fast_decoder_data_input_fields,
encoder_util_input_fields + fast_decoder_util_input_fields,
ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model, place)
seq_ids, seq_scores = exe.run(infer_program,
feed=data_input,
fetch_list=[ids, scores],
return_numpy=False)
# print np.array(seq_ids)#, np.array(seq_scores)
# print seq_ids.lod()#, seq_scores.lod()
hyps = [[] for i in range(len(data))]
for i in range(len(seq_ids.lod()[0]) - 1): # for each source sentence
start = seq_ids.lod()[0][i]
end = seq_ids.lod()[0][i + 1]
for j in range(end - start): # for each candidate
sub_start = seq_ids.lod()[1][start + j]
sub_end = seq_ids.lod()[1][start + j + 1]
hyps[i].append(" ".join([
trg_idx2word[idx]
for idx in np.array(seq_ids)[sub_start:sub_end]
]))
print hyps[i]
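The nested loop above walks the 2-level LoD produced by beam_search_decode. A toy example with made-up offsets shows the layout it assumes:

# Sketch only: level 0 gives each source sentence's range of candidates,
# level 1 gives each candidate's range of token rows (offsets hypothetical).
lod = [[0, 2, 4], [0, 5, 9, 12, 18]]
for i in range(len(lod[0]) - 1):  # for each source sentence
    for j in range(lod[0][i + 1] - lod[0][i]):  # for each candidate
        sub_start = lod[1][lod[0][i] + j]
        sub_end = lod[1][lod[0][i] + j + 1]
        # tokens of this candidate occupy rows [sub_start, sub_end)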
if __name__ == "__main__":
fast_decoder(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
InferTaskConfig.beam_size, InferTaskConfig.max_length,
ModelHyperParams.eos_idx)
print(fluid.default_main_program())
exit(0)
args = parse_args()
infer(args)
fast_infer(args)
@@ -85,7 +85,7 @@ def multi_head_attention(queries,
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, -1, n_head, hidden_size // n_head])
x=x, shape=[0, 0, n_head, hidden_size // n_head])
# permute the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
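A quick numpy analogue (illustration only, not fluid code) of the 0-as-copy reshape rule used above:

import numpy as np

# shape=[0, 0, 8, 64] on a [4, 10, 512] input copies the first two dims and
# splits the hidden dim into 8 heads of size 64.
x = np.zeros((4, 10, 512), dtype="float32")
out = x.reshape(x.shape[0], x.shape[1], 8, 512 // 8)
print(out.shape)  # (4, 10, 8, 64)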
@@ -105,7 +105,7 @@ def multi_head_attention(queries,
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]]))
shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]]))
def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
"""
@@ -124,17 +124,15 @@ def multi_head_attention(queries,
if dropout_rate:
weights = layers.dropout(
weights, dropout_prob=dropout_rate, is_test=False)
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
print cache["k"].shape
print k.shape
k = cache["k"] = layers.concat([cache["k"], k], axis=1)
v = cache["v"] = layers.concat([cache["v"], v], axis=1)
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
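The concat above is the standard incremental-decoding cache: each step appends the new keys and values along the time axis so attention covers every earlier position without recomputing it. A numpy sketch with illustrative shapes:

import numpy as np

# Start from an empty time axis and grow it by one step per iteration.
cache_k = np.zeros((2, 0, 64), dtype="float32")  # [batch, steps, d_key]
for step in range(3):
    new_k = np.ones((2, 1, 64), dtype="float32")  # keys for the new step
    cache_k = np.concatenate([cache_k, new_k], axis=1)
print(cache_k.shape)  # (2, 3, 64)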
@@ -143,7 +141,6 @@ def multi_head_attention(queries,
dropout_rate)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
@@ -225,7 +222,7 @@ def prepare_encoder(src_word,
enc_input = src_word_emb + src_pos_enc
enc_input = layers.reshape(
x=enc_input,
shape=[-1, src_max_len, src_emb_dim],
shape=[batch_size, seq_len, src_emb_dim],
actual_shape=src_data_shape)
return layers.dropout(
enc_input, dropout_prob=dropout_rate,
@@ -400,6 +397,8 @@ def make_all_inputs(input_fields):
name=input_field,
shape=input_descs[input_field][0],
dtype=input_descs[input_field][1],
lod_level=input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0,
append_batch_size=False)
inputs.append(input_var)
return inputs
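A usage sketch (field names taken from the config above; illustrative, not part of this commit) of building the fast-decoder inputs from their descriptors:

# Creates one data layer per listed field, honoring each descriptor's
# lod_level via the branch added above.
trg_word, init_score, trg_src_attn_bias = make_all_inputs(
    fast_decoder_data_input_fields)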
@@ -460,7 +459,6 @@ def transformer(
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
# cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
@@ -595,19 +593,24 @@ def fast_decode(
enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
dropout_rate)
start_tokens, trg_src_attn_bias, trg_data_shape, \
start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape = \
make_all_inputs(fast_decoder_data_fields + decoder_util_input_fields)
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
make_all_inputs(fast_decoder_data_input_fields +
fast_decoder_util_input_fields)
def beam_search():
cond = layers.create_tensor(dtype='bool')
while_op = layers.While(cond)
max_len = layers.fill_constant(
shape=[1], dtype='int32', value=max_out_len)
step_idx = layers.fill_constant(shape=[1], dtype='int32', value=0)
init_scores = layers.fill_constant_batch_size_like(
input=start_tokens, shape=[-1, 1], dtype="float32", value=0)
shape=[1], dtype=start_tokens.dtype, value=max_out_len)
step_idx = layers.fill_constant(
shape=[1], dtype=start_tokens.dtype, value=0)
# cond = layers.fill_constant(
# shape=[1], dtype='bool', value=1, force_cpu=True)
cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
# init_scores = layers.fill_constant_batch_size_like(
# input=start_tokens, shape=[-1, 1], dtype="float32", value=0)
# array states
ids = layers.array_write(start_tokens, step_idx)
scores = layers.array_write(init_scores, step_idx)
@@ -616,34 +619,38 @@ def fast_decode(
"k": layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, 0, d_model],
dtype="float32",
dtype=enc_output.dtype,
value=0),
"v": layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, 0, d_model],
dtype="float32",
dtype=enc_output.dtype,
value=0)
} for i in range(n_layer)]
with while_op.block():
pre_ids = layers.array_read(array=ids, i=step_idx)
pre_scores = layers.array_read(array=scores, i=step_idx)
pre_pos = layers.elementwise_mul(
x=layers.fill_constant_batch_size_like(
input=pre_ids, value=1, shape=[-1, 1], dtype='int32'),
input=pre_ids, value=1, shape=[-1, 1], dtype=pre_ids.dtype),
y=layers.increment(
x=step_idx, value=1.0, in_place=False))
x=step_idx, value=1.0, in_place=False),
axis=0)
pre_src_attn_bias = layers.sequence_expand(
x=trg_src_attn_bias, y=pre_ids)
pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_ids)
print caches[0]["k"].shape
x=trg_src_attn_bias, y=pre_scores)
pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
pre_caches = [{
"k": layers.sequence_expand(
x=cache["k"], y=pre_ids),
x=cache["k"], y=pre_scores),
"v": layers.sequence_expand(
x=cache["v"], y=pre_ids),
x=cache["v"], y=pre_scores),
} for cache in caches]
print pre_caches[0]["k"].shape
layers.Print(pre_ids)
# layers.Print(pre_enc_output)
# layers.Print(pre_src_attn_bias)
# layers.Print(pre_caches[0]["k"])
# layers.Print(pre_caches[0]["v"])
# layers.Print(slf_attn_post_softmax_shape)
logits = wrap_decoder(
trg_vocab_size,
max_in_len,
@@ -662,26 +669,42 @@ def fast_decode(
caches=pre_caches)
topk_scores, topk_indices = layers.topk(logits, k=beam_size)
accu_scores = layers.elementwise_add(
x=pre_scores, y=layers.log(x=layers.softmax(topk_scores)))
x=layers.log(x=layers.softmax(topk_scores)),
y=layers.reshape(
pre_scores, shape=[-1]),
axis=0)
# The beam_search op uses LoD to distinguish beam branches.
topk_indices = layers.lod_reset(topk_indices, pre_ids)
selected_ids, selected_scores = layers.beam_search(
pre_ids=pre_ids,
ids=topk_indices,
scores=accu_scores,
beam_size=beam_size,
end_id=eos_idx)
layers.increment(x=step_idx, value=1.0, in_place=True)
# update states
layers.array_write(selected_ids, i=step_idx)
layers.array_write(selected_scores, i=step_idx)
layers.array_write(selected_ids, i=step_idx, array=ids)
layers.array_write(selected_scores, i=step_idx, array=scores)
layers.assign(pre_src_attn_bias, trg_src_attn_bias)
layers.assign(pre_enc_output, enc_output)
for i in range(n_layer):
layers.assign(pre_caches[i]["k"], caches[i]["k"])
layers.assign(pre_caches[i]["v"], caches[i]["v"])
layers.assign(
slf_attn_pre_softmax_shape + attn_pre_softmax_shape_delta,
slf_attn_pre_softmax_shape)
layers.assign(
layers.elementwise_add(
x=slf_attn_post_softmax_shape,
y=attn_post_softmax_shape_delta),
slf_attn_post_softmax_shape)
max_len_cond = layers.less_than(x=step_idx, y=max_len)
all_finish_cond = layers.less_than(x=step_idx, y=max_len)
layers.logical_or(x=max_len_cond, y=all_finish_cond, out=cond)
beam_search()
finished_ids, finished_scores = layers.beam_search_decode(ids, scores)
return finished_ids, finished_scores
finished_ids, finished_scores = beam_search()
return finished_ids, finished_scores
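Condensed, the decoding loop built in fast_decode follows fluid's While pattern: read the previous step's ids and scores from tensor arrays, run one decoder step, prune with beam_search, write the survivors back, and refresh the loop condition. A stripped-down skeleton under those assumptions (identifiers illustrative, step logic elided):

import paddle.fluid as fluid
import paddle.fluid.layers as layers

step_idx = layers.fill_constant(shape=[1], dtype="int64", value=0)
max_len = layers.fill_constant(shape=[1], dtype="int64", value=30)
cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
with while_op.block():
    # ... one decoder step and beam_search pruning would go here ...
    layers.increment(x=step_idx, value=1.0, in_place=True)
    layers.less_than(x=step_idx, y=max_len, cond=cond)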