提交 eb25dcec 编写于 作者: Y Yu Yang

Polish Infer

上级 937cf5f0
...@@ -190,6 +190,6 @@ fast_decoder_data_input_fields = ( ...@@ -190,6 +190,6 @@ fast_decoder_data_input_fields = (
"trg_word", "trg_word",
"init_score", "init_score",
"trg_src_attn_bias", ) "trg_src_attn_bias", )
fast_decoder_util_input_fields = ( # fast_decoder_util_input_fields = (
"trg_slf_attn_pre_softmax_shape_delta", # "trg_slf_attn_pre_softmax_shape_delta",
"trg_slf_attn_post_softmax_shape_delta", ) # "trg_slf_attn_post_softmax_shape_delta", )
...@@ -424,8 +424,8 @@ def py_infer(test_data, trg_idx2word, use_wordpiece): ...@@ -424,8 +424,8 @@ def py_infer(test_data, trg_idx2word, use_wordpiece):
print(" ".join([trg_idx2word[idx] for idx in seq])) print(" ".join([trg_idx2word[idx] for idx in seq]))
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
bos_idx, n_head, d_model, place): d_model, place):
""" """
Put all padded data needed by beam search decoder into a dict. Put all padded data needed by beam search decoder into a dict.
""" """
...@@ -435,25 +435,9 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, ...@@ -435,25 +435,9 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64") trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32") [1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1, 1)
# These shape tensors are used in reshape_op. src_word = src_word.reshape(-1, src_max_len, 1)
src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32") src_pos = src_pos.reshape(-1, src_max_len, 1)
trg_data_shape = np.array([-1, 1, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
[-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, 1], dtype="int32") # only the first time step
trg_slf_attn_post_softmax_shape = np.array(
[-1, n_head, 1, 1], dtype="int32") # only the first time step
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
[-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
# These inputs are used to change the shapes in the loop of while op.
attn_pre_softmax_shape_delta = np.array([0, 1], dtype="int32")
attn_post_softmax_shape_delta = np.array([0, 0, 0, 1], dtype="int32")
def to_lodtensor(data, place, lod=None): def to_lodtensor(data, place, lod=None):
data_tensor = fluid.LoDTensor() data_tensor = fluid.LoDTensor()
...@@ -465,7 +449,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, ...@@ -465,7 +449,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
# beamsearch_op must use tensors with lod # beamsearch_op must use tensors with lod
init_score = to_lodtensor( init_score = to_lodtensor(
np.zeros_like( np.zeros_like(
trg_word, dtype="float32"), trg_word, dtype="float32").reshape(-1, 1),
place, [range(trg_word.shape[0] + 1)] * 2) place, [range(trg_word.shape[0] + 1)] * 2)
trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2) trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2)
...@@ -474,16 +458,8 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, ...@@ -474,16 +458,8 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
src_word, src_pos, src_slf_attn_bias, trg_word, init_score, src_word, src_pos, src_slf_attn_bias, trg_word, init_score,
trg_src_attn_bias trg_src_attn_bias
])) ]))
util_input_dict = dict(
zip(util_input_names, [
src_data_shape, src_slf_attn_pre_softmax_shape,
src_slf_attn_post_softmax_shape, trg_data_shape,
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta
]))
input_dict = dict(data_input_dict.items() + util_input_dict.items()) input_dict = dict(data_input_dict.items())
return input_dict return input_dict
...@@ -515,7 +491,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece): ...@@ -515,7 +491,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
for batch_id, data in enumerate(test_data.batch_generator()): for batch_id, data in enumerate(test_data.batch_generator()):
data_input = prepare_batch_input( data_input = prepare_batch_input(
data, encoder_data_input_fields + fast_decoder_data_input_fields, data, encoder_data_input_fields + fast_decoder_data_input_fields,
encoder_util_input_fields + fast_decoder_util_input_fields,
ModelHyperParams.eos_idx, ModelHyperParams.bos_idx, ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model, place) ModelHyperParams.n_head, ModelHyperParams.d_model, place)
seq_ids, seq_scores = exe.run(infer_program, seq_ids, seq_scores = exe.run(infer_program,
......
...@@ -197,6 +197,7 @@ def prepare_encoder(src_word, ...@@ -197,6 +197,7 @@ def prepare_encoder(src_word,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name=word_emb_param_name, name=word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))) initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
src_pos_enc = layers.embedding( src_pos_enc = layers.embedding(
src_pos, src_pos,
...@@ -453,8 +454,7 @@ def wrap_encoder(src_vocab_size, ...@@ -453,8 +454,7 @@ def wrap_encoder(src_vocab_size,
if enc_inputs is None: if enc_inputs is None:
# This is used to implement independent encoder program in inference. # This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias = \ src_word, src_pos, src_slf_attn_bias = \
make_all_inputs(encoder_data_input_fields + make_all_inputs(encoder_data_input_fields)
encoder_util_input_fields)
else: else:
src_word, src_pos, src_slf_attn_bias = \ src_word, src_pos, src_slf_attn_bias = \
enc_inputs enc_inputs
...@@ -554,12 +554,8 @@ def fast_decode( ...@@ -554,12 +554,8 @@ def fast_decode(
enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head, enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, d_key, d_value, d_model, d_inner_hid,
dropout_rate, weight_sharing) dropout_rate, weight_sharing)
start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \ start_tokens, init_scores, trg_src_attn_bias = \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \ make_all_inputs(fast_decoder_data_input_fields )
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
make_all_inputs(fast_decoder_data_input_fields +
fast_decoder_util_input_fields)
def beam_search(): def beam_search():
max_len = layers.fill_constant( max_len = layers.fill_constant(
...@@ -570,6 +566,8 @@ def fast_decode( ...@@ -570,6 +566,8 @@ def fast_decode(
while_op = layers.While(cond) while_op = layers.While(cond)
# array states will be stored for each step. # array states will be stored for each step.
ids = layers.array_write(start_tokens, step_idx) ids = layers.array_write(start_tokens, step_idx)
ids_flatten = layers.array_write(
layers.reshape(start_tokens, (-1, 1)), step_idx)
scores = layers.array_write(init_scores, step_idx) scores = layers.array_write(init_scores, step_idx)
# cell states will be overwrited at each step. # cell states will be overwrited at each step.
# caches contains states of history steps to reduce redundant # caches contains states of history steps to reduce redundant
...@@ -604,7 +602,7 @@ def fast_decode( ...@@ -604,7 +602,7 @@ def fast_decode(
x=layers.fill_constant_batch_size_like( x=layers.fill_constant_batch_size_like(
input=pre_enc_output, # cann't use pre_ids here since it has lod input=pre_enc_output, # cann't use pre_ids here since it has lod
value=1, value=1,
shape=[-1, 1], shape=[-1, 1, 1],
dtype=pre_ids.dtype), dtype=pre_ids.dtype),
y=layers.increment( y=layers.increment(
x=step_idx, value=1.0, in_place=False), x=step_idx, value=1.0, in_place=False),
...@@ -620,12 +618,11 @@ def fast_decode( ...@@ -620,12 +618,11 @@ def fast_decode(
d_inner_hid, d_inner_hid,
dropout_rate, dropout_rate,
weight_sharing, weight_sharing,
dec_inputs=( dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
enc_output=pre_enc_output, enc_output=pre_enc_output,
caches=pre_caches) caches=pre_caches)
logits = layers.reshape(logits, (-1, trg_vocab_size))
topk_scores, topk_indices = layers.topk( topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size) input=layers.softmax(logits), k=beam_size)
accu_scores = layers.elementwise_add( accu_scores = layers.elementwise_add(
...@@ -642,8 +639,11 @@ def fast_decode( ...@@ -642,8 +639,11 @@ def fast_decode(
scores=accu_scores, scores=accu_scores,
beam_size=beam_size, beam_size=beam_size,
end_id=eos_idx) end_id=eos_idx)
layers.increment(x=step_idx, value=1.0, in_place=True) layers.increment(x=step_idx, value=1.0, in_place=True)
# update states # update states
layers.array_write(selected_ids, i=step_idx, array=ids_flatten)
selected_ids = layers.reshape(selected_ids, shape=(-1, 1, 1))
layers.array_write(selected_ids, i=step_idx, array=ids) layers.array_write(selected_ids, i=step_idx, array=ids)
layers.array_write(selected_scores, i=step_idx, array=scores) layers.array_write(selected_scores, i=step_idx, array=scores)
layers.assign(pre_src_attn_bias, trg_src_attn_bias) layers.assign(pre_src_attn_bias, trg_src_attn_bias)
...@@ -651,23 +651,12 @@ def fast_decode( ...@@ -651,23 +651,12 @@ def fast_decode(
for i in range(n_layer): for i in range(n_layer):
layers.assign(pre_caches[i]["k"], caches[i]["k"]) layers.assign(pre_caches[i]["k"], caches[i]["k"])
layers.assign(pre_caches[i]["v"], caches[i]["v"]) layers.assign(pre_caches[i]["v"], caches[i]["v"])
layers.assign(
layers.elementwise_add(
x=slf_attn_pre_softmax_shape,
y=attn_pre_softmax_shape_delta),
slf_attn_pre_softmax_shape)
layers.assign(
layers.elementwise_add(
x=slf_attn_post_softmax_shape,
y=attn_post_softmax_shape_delta),
slf_attn_post_softmax_shape)
length_cond = layers.less_than(x=step_idx, y=max_len) length_cond = layers.less_than(x=step_idx, y=max_len)
finish_cond = layers.logical_not(layers.is_empty(x=selected_ids)) finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
layers.logical_and(x=length_cond, y=finish_cond, out=cond) layers.logical_and(x=length_cond, y=finish_cond, out=cond)
finished_ids, finished_scores = layers.beam_search_decode( finished_ids, finished_scores = layers.beam_search_decode(
ids, scores, beam_size=beam_size, end_id=eos_idx) ids_flatten, scores, beam_size=beam_size, end_id=eos_idx)
return finished_ids, finished_scores return finished_ids, finished_scores
finished_ids, finished_scores = beam_search() finished_ids, finished_scores = beam_search()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册