Commit 4eda2803 authored by guosheng

Code clean for fast_decoder of Transformer

Parent 3cbf0f73
@@ -33,9 +33,9 @@ class TrainTaskConfig(object):
 class InferTaskConfig(object):
-    use_gpu = False
+    use_gpu = True
     # the number of examples in one run for sequence generation.
-    batch_size = 2
+    batch_size = 10
     # the parameters for beam search.
     beam_size = 5
     max_out_len = 256
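
Editor's note: the next hunk's header names the override helper merge_cfg_from_list(cfg_list, g_cfgs), which is how values such as beam_size above are typically replaced from the command line. Below is a hypothetical, minimal sketch of such a helper; the parsing and type casting are assumptions for illustration, not the repo's exact implementation.

```python
# Hypothetical sketch (not the repo's exact code): apply flat
# ["key", "value", ...] overrides onto config classes like InferTaskConfig.
class InferTaskConfig(object):
    use_gpu = True
    batch_size = 10
    beam_size = 5
    max_out_len = 256


def merge_cfg_from_list(cfg_list, g_cfgs):
    """Set each "key value" pair on the first config class owning `key`."""
    assert len(cfg_list) % 2 == 0
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                # Cast the string to the type of the existing default.
                setattr(g_cfg, key, type(getattr(g_cfg, key))(value))
                break


merge_cfg_from_list(["beam_size", "4", "max_out_len", "128"], [InferTaskConfig])
print(InferTaskConfig.beam_size, InferTaskConfig.max_out_len)  # 4 128
```
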
@@ -108,7 +108,7 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
     # consistent with some ops' infer-shape output in compile time, such as the
     # sequence_expand op used in beam search decoder.
     batch_size = -1
     # The placeholder for sequence length in compile time.
     seq_len = ModelHyperParams.max_length
     # Here list the data shapes and data types of all inputs.
     # The shapes here act as placeholder and are set to pass the infer-shape in
...
@@ -88,7 +88,8 @@ def translate_batch(exe,
                     output_unk=True):
     """
     Run the encoder program once and run the decoder program multiple times to
-    implement beam search externally.
+    implement beam search externally. This is deprecated since a faster beam
+    search decoder based solely on Fluid operators has been added.
     """
     # Prepare data for encoder and run the encoder.
     enc_in_data = pad_batch_data(
@@ -255,8 +256,6 @@ def translate_batch(exe,
         predict_all = exe.run(decoder,
                               feed=dict(zip(dec_in_names, dec_in_data)),
                               fetch_list=dec_out_names)[0]
-        print predict_all.reshape(
-            [len(beam_inst_map) * beam_size, i + 1, -1])[:, -1, :]
         predict_all = np.log(
             predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
             [:, -1, :])
@@ -275,19 +274,11 @@ def translate_batch(exe,
             top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
             top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::
                                                                             -1]]
-            # top_scores_ids = np.asarray(
-            #     sorted(
-            #         top_scores_ids,
-            #         lambda x, y: x / predict_all.shape[-1] - y / predict_all.shape[-1]
-            #     ))  # sort by pre_branch and score to compare with fast_infer
             top_scores = predict[top_scores_ids]
             scores[beam_idx] = top_scores
             prev_branchs[beam_idx].append(top_scores_ids /
                                           predict_all.shape[-1])
             next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
-            print prev_branchs[beam_idx][-1]
-            print next_ids[beam_idx][-1]
-            print top_scores
             if next_ids[beam_idx][-1][0] != eos_idx:
                 active_beams.append(beam_idx)
         if len(active_beams) == 0:
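
The two numpy lines kept at the top of this hunk are a top-k trick worth spelling out: np.argpartition selects the beam_size largest scores in O(V) without fully sorting the vocabulary, and argsort over just those k entries orders them best-first. Each flat index over the [beam_size * vocab] score matrix then decomposes into a parent beam and a next word, which is what prev_branchs and next_ids record. A self-contained sketch with made-up scores (// makes explicit the integer division the Python 2 code gets from /):

```python
import numpy as np

np.random.seed(0)
beam_size, vocab_size = 3, 5
# Flattened scores over all (parent beam, next word) pairs.
predict = np.random.rand(beam_size * vocab_size)

# Top-k without a full sort: argpartition is O(V); only k entries get sorted.
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::-1]]

# Decompose each flat index into (parent beam, next word id).
prev_branch = top_scores_ids // vocab_size  # which beam the candidate extends
next_id = top_scores_ids % vocab_size       # which word it appends
print(prev_branch, next_id, predict[top_scores_ids])
```
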
@@ -308,7 +299,32 @@ def translate_batch(exe,
     return seqs, scores[:, :n_best].tolist()
 
 
-def infer(args):
+def post_process_seq(seq,
+                     bos_idx=ModelHyperParams.bos_idx,
+                     eos_idx=ModelHyperParams.eos_idx,
+                     output_bos=InferTaskConfig.output_bos,
+                     output_eos=InferTaskConfig.output_eos):
+    """
+    Post-process the beam-search decoded sequence. Truncate from the first
+    <eos> and remove the <bos> and <eos> tokens currently.
+    """
+    eos_pos = len(seq) - 1
+    for i, idx in enumerate(seq):
+        if idx == eos_idx:
+            eos_pos = i
+            break
+    seq = seq[:eos_pos + 1]
+    return filter(
+        lambda idx: (output_bos or idx != bos_idx) and \
+            (output_eos or idx != eos_idx),
+        seq)
+
+
+def py_infer(test_data, trg_idx2word):
+    """
+    Inference by beam search implemented in Python, while the calculations
+    from symbols to probabilities are executed by Fluid operators.
+    """
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
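
Since post_process_seq is pure Python, its contract is easy to check in isolation: truncate at the first <eos>, then drop <bos>/<eos> unless the config asks to keep them. A small sketch with made-up ids (a list comprehension replaces the Python 2 filter so the result is explicit):

```python
def post_process_seq(seq, bos_idx=0, eos_idx=1,
                     output_bos=False, output_eos=False):
    # Truncate at the first <eos>, then drop <bos>/<eos> unless requested.
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = seq[:eos_pos + 1]
    return [idx for idx in seq
            if (output_bos or idx != bos_idx) and
               (output_eos or idx != eos_idx)]


# 0/1 stand for <bos>/<eos>; everything after the first <eos> is discarded.
print(post_process_seq([0, 7, 8, 9, 1, 5, 6]))  # [7, 8, 9]
```
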
@@ -355,48 +371,7 @@ def infer(args):
     encoder_program = encoder_program.inference_optimize()
     decoder_program = decoder_program.inference_optimize()
 
-    test_data = reader.DataReader(
-        src_vocab_fpath=args.src_vocab_fpath,
-        trg_vocab_fpath=args.trg_vocab_fpath,
-        fpattern=args.test_file_pattern,
-        batch_size=args.batch_size,
-        use_token_batch=False,
-        pool_size=args.pool_size,
-        sort_type=reader.SortType.NONE,
-        shuffle=False,
-        shuffle_batch=False,
-        start_mark=args.special_token[0],
-        end_mark=args.special_token[1],
-        unk_mark=args.special_token[2],
-        max_length=ModelHyperParams.max_length,
-        clip_last_batch=False)
-    trg_idx2word = test_data.load_dict(
-        dict_path=args.trg_vocab_fpath, reverse=True)
-
-    def post_process_seq(seq,
-                         bos_idx=ModelHyperParams.bos_idx,
-                         eos_idx=ModelHyperParams.eos_idx,
-                         output_bos=InferTaskConfig.output_bos,
-                         output_eos=InferTaskConfig.output_eos):
-        """
-        Post-process the beam-search decoded sequence. Truncate from the first
-        <eos> and remove the <bos> and <eos> tokens currently.
-        """
-        eos_pos = len(seq) - 1
-        for i, idx in enumerate(seq):
-            if idx == eos_idx:
-                eos_pos = i
-                break
-        seq = seq[:eos_pos + 1]
-        return filter(
-            lambda idx: (output_bos or idx != bos_idx) and \
-                (output_eos or idx != eos_idx),
-            seq)
-
     for batch_id, data in enumerate(test_data.batch_generator()):
-        if batch_id != 0:
-            continue
         batch_seqs, batch_scores = translate_batch(
             exe,
             [item[0] for item in data],
@@ -425,14 +400,12 @@ def infer(args):
         scores = batch_scores[i]
         for seq in seqs:
             print(" ".join([trg_idx2word[idx] for idx in seq]))
-        print scores
-        exit(0)
 
 
 def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
                         bos_idx, n_head, d_model, place):
     """
-    Put all padded data needed by inference into a dict.
+    Put all padded data needed by beam search decoder into a dict.
     """
     src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
         [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
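
pad_batch_data itself is not part of this diff; the sketch below is a simplified stand-in showing what the call above produces: ids padded to the batch max length, 1-based positions with 0 for padding, and a self-attention bias that masks padded positions with a large negative number, tiled over n_head heads. The exact shapes and position convention are assumptions.

```python
import numpy as np


def pad_batch_data_sketch(insts, pad_idx, n_head):
    # Simplified stand-in for pad_batch_data (shapes are assumptions).
    max_len = max(len(inst) for inst in insts)
    word = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts],
        dtype="int64")
    pos = np.array(
        [list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
         for inst in insts], dtype="int64")
    # -1e9 on padded positions drives their softmax weights to ~0.
    bias = np.array(
        [[0.] * len(inst) + [-1e9] * (max_len - len(inst)) for inst in insts])
    bias = np.tile(bias[:, None, None, :], [1, n_head, max_len, 1])
    return word, pos, bias, max_len


word, pos, bias, max_len = pad_batch_data_sketch([[4, 5, 6], [7, 8]], 0, 2)
print(word.shape, pos.shape, bias.shape)  # (2, 3) (2, 3) (2, 2, 3, 3)
```
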
@@ -492,18 +465,21 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
     return input_dict
 
 
-def fast_infer(args):
+def fast_infer(test_data, trg_idx2word):
+    """
+    Inference by beam search decoder based solely on Fluid operators.
+    """
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
 
-    ids, scores = fast_decoder(
+    out_ids, out_scores = fast_decoder(
         ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
         ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
         ModelHyperParams.n_head, ModelHyperParams.d_key,
         ModelHyperParams.d_value, ModelHyperParams.d_model,
         ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
-        InferTaskConfig.beam_size, InferTaskConfig.max_out_len,
-        ModelHyperParams.eos_idx)
+        ModelHyperParams.weight_sharing, InferTaskConfig.beam_size,
+        InferTaskConfig.max_out_len, ModelHyperParams.eos_idx)
 
     fluid.io.load_vars(
         exe,
@@ -514,28 +490,7 @@ def fast_infer(args):
     # This is used here to set dropout to the test mode.
     infer_program = fluid.default_main_program().inference_optimize()
 
-    test_data = reader.DataReader(
-        src_vocab_fpath=args.src_vocab_fpath,
-        trg_vocab_fpath=args.trg_vocab_fpath,
-        fpattern=args.test_file_pattern,
-        batch_size=args.batch_size,
-        use_token_batch=False,
-        pool_size=args.pool_size,
-        sort_type=reader.SortType.NONE,
-        shuffle=False,
-        shuffle_batch=False,
-        start_mark=args.special_token[0],
-        end_mark=args.special_token[1],
-        unk_mark=args.special_token[2],
-        max_length=ModelHyperParams.max_length,
-        clip_last_batch=False)
-    trg_idx2word = test_data.load_dict(
-        dict_path=args.trg_vocab_fpath, reverse=True)
-
     for batch_id, data in enumerate(test_data.batch_generator()):
-        if batch_id != 0:
-            continue
         data_input = prepare_batch_input(
             data, encoder_data_input_fields + fast_decoder_data_input_fields,
             encoder_util_input_fields + fast_decoder_util_input_fields,
@@ -543,10 +498,16 @@ def fast_infer(args):
             ModelHyperParams.n_head, ModelHyperParams.d_model, place)
         seq_ids, seq_scores = exe.run(infer_program,
                                       feed=data_input,
-                                      fetch_list=[ids, scores],
+                                      fetch_list=[out_ids, out_scores],
                                       return_numpy=False)
-        # print np.array(seq_ids)#, np.array(seq_scores)
-        # print seq_ids.lod()#, seq_scores.lod()
+        # How to parse the results:
+        #   Suppose the lod of seq_ids is:
+        #     [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
+        #   then from lod[0]:
+        #     there are 2 source sentences, beam width is 3.
+        #   from lod[1]:
+        #     the first source sentence has 3 hyps; the lengths are 12, 12, 16
+        #     the second source sentence has 3 hyps; the lengths are 14, 13, 15
         hyps = [[] for i in range(len(data))]
         scores = [[] for i in range(len(data))]
         for i in range(len(seq_ids.lod()[0]) - 1):  # for each source sentence
@@ -557,16 +518,38 @@ def fast_infer(args):
                 sub_end = seq_ids.lod()[1][start + j + 1]
                 hyps[i].append(" ".join([
                     trg_idx2word[idx]
-                    for idx in np.array(seq_ids)[sub_start:sub_end]
+                    for idx in post_process_seq(
+                        np.array(seq_ids)[sub_start:sub_end])
                 ]))
                 scores[i].append(np.array(seq_scores)[sub_end - 1])
-                print hyps[i]
-                print scores[i]
-                print len(hyps[i]), [len(hyp.split()) for hyp in hyps[i]]
-        exit(0)
+                print hyps[i][-1]
+                if len(hyps[i]) >= InferTaskConfig.n_best:
+                    break
+
+
+def infer(args, inferencer=fast_infer):
+    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    test_data = reader.DataReader(
+        src_vocab_fpath=args.src_vocab_fpath,
+        trg_vocab_fpath=args.trg_vocab_fpath,
+        fpattern=args.test_file_pattern,
+        batch_size=args.batch_size,
+        use_token_batch=False,
+        pool_size=args.pool_size,
+        sort_type=reader.SortType.NONE,
+        shuffle=False,
+        shuffle_batch=False,
+        start_mark=args.special_token[0],
+        end_mark=args.special_token[1],
+        unk_mark=args.special_token[2],
+        max_length=ModelHyperParams.max_length,
+        clip_last_batch=False)
+    trg_idx2word = test_data.load_dict(
+        dict_path=args.trg_vocab_fpath, reverse=True)
+
+    inferencer(test_data, trg_idx2word)
 
 
 if __name__ == "__main__":
     args = parse_args()
-    fast_infer(args)
-    # infer(args)
+    infer(args)
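
The lod parsing described in the new comment is plain index arithmetic once the LoDTensor is on the host. A numpy walk-through using the exact lod from the comment (token ids and scores are fabricated stand-ins for np.array(seq_ids) and np.array(seq_scores)):

```python
import numpy as np

# Two-level lod from the comment: lod[0] delimits source sentences in units
# of hypotheses; lod[1] delimits hypotheses in units of tokens.
lod = [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
flat_ids = np.arange(82)          # stand-in for np.array(seq_ids)
flat_scores = np.random.rand(82)  # stand-in for np.array(seq_scores)

for i in range(len(lod[0]) - 1):           # for each source sentence
    start, end = lod[0][i], lod[0][i + 1]  # its span of hypotheses
    for j in range(end - start):           # for each hypothesis
        sub_start = lod[1][start + j]
        sub_end = lod[1][start + j + 1]
        hyp = flat_ids[sub_start:sub_end]
        score = flat_scores[sub_end - 1]   # accumulated score at last token
        print("sent %d, hyp %d: len=%d score=%.3f" % (i, j, len(hyp), score))
```

Run on this lod, the loop prints hypothesis lengths 12, 12, 16 for the first sentence and 14, 13, 15 for the second, matching the comment.
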
@@ -6,8 +6,6 @@ import paddle.fluid.layers as layers
 
 from config import *
 
-FLAG = False
-
 
 def position_encoding_init(n_position, d_pos_vec):
     """
@@ -103,12 +101,6 @@ def multi_head_attention(queries,
         """
         scaled_q = layers.scale(x=q, scale=d_model**-0.5)
         product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
-        # global FLAG
-        # if FLAG and attn_bias:
-        #     print "hehehehehe"
-        #     layers.Print(product, message="product")
-        #     layers.Print(attn_bias, message="bias")
-        #     FLAG = False
         weights = layers.reshape(
             x=layers.elementwise_add(
                 x=product, y=attn_bias) if attn_bias else product,
@@ -117,19 +109,9 @@ def multi_head_attention(queries,
             act="softmax")
         weights = layers.reshape(
             x=weights, shape=product.shape, actual_shape=post_softmax_shape)
-        # global FLAG
-        # if FLAG:
-        #     print "hehehehehe"
-        #     layers.Print(scaled_q)
-        #     layers.Print(k)
-        #     layers.Print(v)
-        #     layers.Print(product)
-        #     layers.Print(weights)
-        #     FLAG = False
         if dropout_rate:
             weights = layers.dropout(
                 weights, dropout_prob=dropout_rate, is_test=False)
         out = layers.matmul(weights, v)
         return out
@@ -138,13 +120,7 @@ def multi_head_attention(queries,
     if cache is not None:  # use cache and concat time steps
         k = cache["k"] = layers.concat([cache["k"], k], axis=1)
         v = cache["v"] = layers.concat([cache["v"], v], axis=1)
-    # global FLAG
-    # if FLAG:
-    #     print "hehehehehe"
-    #     layers.Print(q)
-    #     layers.Print(k)
-    #     layers.Print(v)
-    #     FLAG = False
 
     q = __split_heads(q, n_head)
     k = __split_heads(k, n_head)
     v = __split_heads(v, n_head)
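
The cache branch above is the incremental-decoding pattern that gives the fast decoder its speed: at each step only the newest token's keys and values are computed, and they are concatenated onto the cached [batch, steps, d] history so attention spans the whole prefix without recomputation. A numpy analogue of the bookkeeping (illustration only, not Fluid code):

```python
import numpy as np

batch, d = 2, 4
cache = {"k": np.zeros([batch, 0, d]), "v": np.zeros([batch, 0, d])}


def attend_one_step(q_t, k_t, v_t, cache):
    # Concatenate the current step onto the cached history along time axis 1,
    # mirroring layers.concat([cache["k"], k], axis=1) in the diff.
    cache["k"] = k = np.concatenate([cache["k"], k_t], axis=1)
    cache["v"] = v = np.concatenate([cache["v"], v_t], axis=1)
    weights = np.einsum("bqd,btd->bqt", q_t, k)  # scores over all past steps
    weights = np.exp(weights) / np.exp(weights).sum(-1, keepdims=True)
    return np.einsum("bqt,btd->bqd", weights, v)


for step in range(3):
    q = k = v = np.random.rand(batch, 1, d)
    out = attend_one_step(q, k, v, cache)
print(cache["k"].shape)  # (2, 3, 4): the history grows by one step each call
```
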
@@ -153,16 +129,12 @@ def multi_head_attention(queries,
                                                   dropout_rate)
 
     out = __combine_heads(ctx_multiheads)
 
     # Project back to the model size.
     proj_out = layers.fc(input=out,
                          size=d_model,
                          bias_attr=False,
                          num_flatten_dims=2)
-    # global FLAG
-    # if FLAG:
-    #     print "hehehehehe"
-    #     layers.Print(proj_out)
-    #     FLAG = False
 
     return proj_out
@@ -391,15 +363,22 @@ def decoder(dec_input,
     The decoder is composed of a stack of identical decoder_layer layers.
     """
     for i in range(n_layer):
-        if i == 0:  #n_layer-1:
-            global FLAG
-            FLAG = True
         dec_output = decoder_layer(
-            dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias, n_head,
-            d_key, d_value, d_model, d_inner_hid, dropout_rate,
-            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
-            src_attn_pre_softmax_shape, src_attn_post_softmax_shape, None
-            if caches is None else caches[i])
+            dec_input,
+            enc_output,
+            dec_slf_attn_bias,
+            dec_enc_attn_bias,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            dropout_rate,
+            slf_attn_pre_softmax_shape,
+            slf_attn_post_softmax_shape,
+            src_attn_pre_softmax_shape,
+            src_attn_post_softmax_shape,
+            None if caches is None else caches[i], )
         dec_input = dec_output
     return dec_output
@@ -625,12 +604,17 @@ def fast_decode(
         d_model,
         d_inner_hid,
         dropout_rate,
+        weight_sharing,
         beam_size,
         max_out_len,
         eos_idx, ):
+    """
+    Use beam search to decode. Caches will be used to store states of history
+    steps, which can make the decoding faster.
+    """
     enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
                               d_key, d_value, d_model, d_inner_hid,
-                              dropout_rate)
+                              dropout_rate, weight_sharing)
     start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
         slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
         src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
@@ -643,16 +627,14 @@ def fast_decode(
         shape=[1], dtype=start_tokens.dtype, value=max_out_len)
     step_idx = layers.fill_constant(
         shape=[1], dtype=start_tokens.dtype, value=0)
-    # cond = layers.fill_constant(
-    #     shape=[1], dtype='bool', value=1, force_cpu=True)
     cond = layers.less_than(x=step_idx, y=max_len)
     while_op = layers.While(cond)
-    # init_scores = layers.fill_constant_batch_size_like(
-    #     input=start_tokens, shape=[-1, 1], dtype="float32", value=0)
-    # array states
+    # array states will be stored for each step.
     ids = layers.array_write(start_tokens, step_idx)
     scores = layers.array_write(init_scores, step_idx)
-    # cell states (can be overwrited)
+    # cell states will be overwritten at each step.
+    # caches contain states of history steps to reduce redundant
+    # computation in the decoder.
     caches = [{
         "k": layers.fill_constant_batch_size_like(
             input=start_tokens,
@@ -668,9 +650,10 @@ def fast_decode(
     with while_op.block():
         pre_ids = layers.array_read(array=ids, i=step_idx)
         pre_scores = layers.array_read(array=scores, i=step_idx)
+        # sequence_expand can gather sequences according to lod thus can be
+        # used in beam search to sift states corresponding to selected ids.
         pre_src_attn_bias = layers.sequence_expand(
             x=trg_src_attn_bias, y=pre_scores)
-        # layers.Print(pre_src_attn_bias)
         pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
         pre_caches = [{
             "k": layers.sequence_expand(
@@ -687,13 +670,6 @@ def fast_decode(
             y=layers.increment(
                 x=step_idx, value=1.0, in_place=False),
             axis=0)
-        # layers.Print(pre_ids, summarize=10)
-        # layers.Print(pre_pos, summarize=10)
-        # layers.Print(pre_enc_output, summarize=10)
-        # layers.Print(pre_src_attn_bias, summarize=10)
-        # layers.Print(pre_caches[0]["k"], summarize=10)
-        # layers.Print(pre_caches[0]["v"], summarize=10)
-        # layers.Print(slf_attn_post_softmax_shape)
         logits = wrap_decoder(
             trg_vocab_size,
             max_in_len,
@@ -704,19 +680,16 @@ def fast_decode(
             d_model,
             d_inner_hid,
             dropout_rate,
+            weight_sharing,
             dec_inputs=(
                 pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
                 slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
                 src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
             enc_output=pre_enc_output,
             caches=pre_caches)
-        # layers.Print(logits)
         topk_scores, topk_indices = layers.topk(
             input=layers.softmax(logits), k=beam_size)
-        # layers.Print(topk_scores)
-        # layers.Print(topk_indices)
         accu_scores = layers.elementwise_add(
-            # x=layers.log(x=layers.softmax(topk_scores)),
             x=layers.log(topk_scores),
             y=layers.reshape(
                 pre_scores, shape=[-1]),
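
accu_scores implements the usual beam-search recurrence score(prefix + w) = score(prefix) + log p(w | prefix): topk runs on the softmax output, log is applied to just the k kept probabilities, and each is added to its parent prefix's accumulated score (the axis=0 add broadcasts one pre_score over that beam's k candidates). In numpy, with made-up numbers:

```python
import numpy as np

pre_scores = np.array([-1.2, -1.5])  # accumulated log-probs of 2 live beams
probs = np.array([[0.7, 0.2, 0.1],   # softmax output over a toy vocabulary
                  [0.5, 0.3, 0.2]])
k = 2

# Keep the k most probable words per beam, then take logs of only those.
topk_idx = np.argsort(-probs, axis=1)[:, :k]
topk_scores = np.take_along_axis(probs, topk_idx, axis=1)

# log p(w | prefix) + score(prefix), broadcast per parent beam, mirroring
# layers.elementwise_add(..., axis=0) in the diff.
accu_scores = np.log(topk_scores) + pre_scores[:, None]
print(accu_scores)
```
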
@@ -739,9 +712,6 @@ def fast_decode(
         for i in range(n_layer):
             layers.assign(pre_caches[i]["k"], caches[i]["k"])
             layers.assign(pre_caches[i]["v"], caches[i]["v"])
-        layers.Print(selected_ids)
-        layers.Print(selected_scores)
-        # layers.Print(caches[-1]["k"])
         layers.assign(
             layers.elementwise_add(
                 x=slf_attn_pre_softmax_shape,
@@ -755,12 +725,8 @@ def fast_decode(
         length_cond = layers.less_than(x=step_idx, y=max_len)
         finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
-        # layers.Print(length_cond)
-        # layers.Print(finish_cond)
         layers.logical_and(x=length_cond, y=finish_cond, out=cond)
-        layers.Print(step_idx)
 
-    # finished_ids, finished_scores = layers.beam_search_decode(ids, scores,
-    #                                                           eos_idx)
     finished_ids, finished_scores = layers.beam_search_decode(
         ids, scores, beam_size=beam_size, end_id=eos_idx)
     return finished_ids, finished_scores
...
@@ -198,7 +198,8 @@ class DataReader(object):
             for line in f_obj:
                 fields = line.strip().split(self._delimiter)
-                if len(fields) != 2 or (self._only_src and len(fields) != 1):
+                if (not self._only_src and len(fields) != 2) or (self._only_src and
+                                                                 len(fields) != 1):
                     continue
 
                 sample_words = []
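
The old predicate silently dropped every line in source-only mode: a valid one-field line already fails len(fields) != 2, so nothing survived. The rewritten condition checks the field count expected for each mode. A quick check of the fixed logic, inverted to a keep-this-line predicate:

```python
def keep(fields, only_src):
    # The fixed condition from the diff, negated to mean "keep this line".
    return not ((not only_src and len(fields) != 2) or
                (only_src and len(fields) != 1))


print(keep(["src", "trg"], only_src=False))  # True
print(keep(["src"], only_src=True))          # True  (the old code dropped this)
print(keep(["src"], only_src=False))         # False
```
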
@@ -275,7 +276,7 @@ class DataReader(object):
         for sample_idx in self._sample_idxs:
             if self._only_src:
-                yield (self._src_seq_ids[sample_idx])
+                yield (self._src_seq_ids[sample_idx], )
             else:
                 yield (self._src_seq_ids[sample_idx],
                        self._trg_seq_ids[sample_idx][:-1],
...