未验证 提交 6fc865e1 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #1013 from guoshengCS/add-transformer-BeamsearchDecoder-clean

Add transformer beamsearch decoder
......@@ -38,7 +38,7 @@ class InferTaskConfig(object):
batch_size = 10
# the parameters for beam search.
beam_size = 5
max_length = 256
max_out_len = 256
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
......@@ -104,23 +104,28 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
break
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for squence length in compile time.
seq_len = ModelHyperParams.max_length
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1]
"src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
"src_word": [(batch_size * seq_len, 1L), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1]
"src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
"src_pos": [(batch_size * seq_len, 1L), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias":
[(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
"src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This shape input is used to reshape the output of embedding layer.
"src_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
......@@ -129,24 +134,23 @@ input_descs = {
"src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
"trg_word": [(batch_size * seq_len, 1L), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
"trg_pos": [(batch_size * seq_len, 1L), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
(ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
"trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(1, ModelHyperParams.n_head,
(ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
"trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This shape input is used to reshape the output of embedding layer.
"trg_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
......@@ -162,15 +166,18 @@ input_descs = {
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(1, (ModelHyperParams.max_length + 1),
ModelHyperParams.d_model), "float32"],
"enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
"lbl_word": [(batch_size * seq_len, 1L), "int64"],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
"lbl_weight": [(batch_size * seq_len, 1L), "float32"],
# These inputs are used to change the shape tensor in beam-search decoder.
"trg_slf_attn_pre_softmax_shape_delta": [(2L, ), "int32"],
"trg_slf_attn_post_softmax_shape_delta": [(4L, ), "int32"],
"init_score": [(batch_size, 1L), "float32"],
}
# Names of word embedding table which might be reused for weight sharing.
......@@ -205,3 +212,12 @@ decoder_util_input_fields = (
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
"trg_word",
"init_score",
"trg_src_attn_bias", )
fast_decoder_util_input_fields = decoder_util_input_fields + (
"trg_slf_attn_pre_softmax_shape_delta",
"trg_slf_attn_post_softmax_shape_delta", )
......@@ -7,6 +7,7 @@ import paddle.fluid as fluid
import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from model import fast_decode as fast_decoder
from config import *
from train import pad_batch_data
import reader
......@@ -87,7 +88,8 @@ def translate_batch(exe,
output_unk=True):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
implement beam search externally. This is deprecated since a faster beam
search decoder based solely on Fluid operators has been added.
"""
# Prepare data for encoder and run the encoder.
enc_in_data = pad_batch_data(
......@@ -297,7 +299,32 @@ def translate_batch(exe,
return seqs, scores[:, :n_best].tolist()
def infer(args):
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = seq[:eos_pos + 1]
return filter(
lambda idx: (output_bos or idx != bos_idx) and \
(output_eos or idx != eos_idx),
seq)
def py_infer(test_data, trg_idx2word):
"""
Inference by beam search implented by python, while the calculations from
symbols to probilities execute by Fluid operators.
"""
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -341,49 +368,8 @@ def infer(args):
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)
# This is used here to set dropout to the test mode.
encoder_program = fluid.io.get_inference_program(
target_vars=[enc_output], main_program=encoder_program)
decoder_program = fluid.io.get_inference_program(
target_vars=[predict], main_program=decoder_program)
test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern,
batch_size=args.batch_size,
use_token_batch=False,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = seq[:eos_pos + 1]
return filter(
lambda idx: (output_bos or idx != bos_idx) and \
(output_eos or idx != eos_idx),
seq)
encoder_program = encoder_program.inference_optimize()
decoder_program = decoder_program.inference_optimize()
for batch_id, data in enumerate(test_data.batch_generator()):
batch_seqs, batch_scores = translate_batch(
......@@ -397,7 +383,7 @@ def infer(args):
(decoder_data_input_fields[-1], ),
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.max_out_len,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
......@@ -416,6 +402,154 @@ def infer(args):
print(" ".join([trg_idx2word[idx] for idx in seq]))
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
bos_idx, n_head, d_model, place):
"""
Put all padded data needed by beam search decoder into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
# start tokens
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
# These shape tensors are used in reshape_op.
src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([-1, 1, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
[-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, 1], dtype="int32") # only the first time step
trg_slf_attn_post_softmax_shape = np.array(
[-1, n_head, 1, 1], dtype="int32") # only the first time step
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
[-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
# These inputs are used to change the shapes in the loop of while op.
attn_pre_softmax_shape_delta = np.array([0, 1], dtype="int32")
attn_post_softmax_shape_delta = np.array([0, 0, 0, 1], dtype="int32")
def to_lodtensor(data, place, lod=None):
data_tensor = fluid.LoDTensor()
data_tensor.set(data, place)
if lod is not None:
data_tensor.set_lod(lod)
return data_tensor
# beamsearch_op must use tensors with lod
init_score = to_lodtensor(
np.zeros_like(
trg_word, dtype="float32"),
place, [range(trg_word.shape[0] + 1)] * 2)
trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2)
data_input_dict = dict(
zip(data_input_names, [
src_word, src_pos, src_slf_attn_bias, trg_word, init_score,
trg_src_attn_bias
]))
util_input_dict = dict(
zip(util_input_names, [
src_data_shape, src_slf_attn_pre_softmax_shape,
src_slf_attn_post_softmax_shape, trg_data_shape,
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta
]))
input_dict = dict(data_input_dict.items() + util_input_dict.items())
return input_dict
def fast_infer(test_data, trg_idx2word):
"""
Inference by beam search decoder based solely on Fluid operators.
"""
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
out_ids, out_scores = fast_decoder(
ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.weight_sharing, InferTaskConfig.beam_size,
InferTaskConfig.max_out_len, ModelHyperParams.eos_idx)
fluid.io.load_vars(
exe,
InferTaskConfig.model_path,
vars=filter(lambda var: isinstance(var, fluid.framework.Parameter),
fluid.default_main_program().list_vars()))
# This is used here to set dropout to the test mode.
infer_program = fluid.default_main_program().inference_optimize()
for batch_id, data in enumerate(test_data.batch_generator()):
data_input = prepare_batch_input(
data, encoder_data_input_fields + fast_decoder_data_input_fields,
encoder_util_input_fields + fast_decoder_util_input_fields,
ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model, place)
seq_ids, seq_scores = exe.run(infer_program,
feed=data_input,
fetch_list=[out_ids, out_scores],
return_numpy=False)
# How to parse the results:
# Suppose the lod of seq_ids is:
# [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
# then from lod[0]:
# there are 2 source sentences, beam width is 3.
# from lod[1]:
# the first source sentence has 3 hyps; the lengths are 12, 12, 16
# the second source sentence has 3 hyps; the lengths are 14, 13, 15
hyps = [[] for i in range(len(data))]
scores = [[] for i in range(len(data))]
for i in range(len(seq_ids.lod()[0]) - 1): # for each source sentence
start = seq_ids.lod()[0][i]
end = seq_ids.lod()[0][i + 1]
for j in range(end - start): # for each candidate
sub_start = seq_ids.lod()[1][start + j]
sub_end = seq_ids.lod()[1][start + j + 1]
hyps[i].append(" ".join([
trg_idx2word[idx]
for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end])
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print hyps[i][-1]
if len(hyps[i]) >= InferTaskConfig.n_best:
break
def infer(args, inferencer=fast_infer):
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern,
batch_size=args.batch_size,
use_token_batch=False,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
inferencer(test_data, trg_idx2word)
if __name__ == "__main__":
args = parse_args()
infer(args)
......@@ -30,7 +30,8 @@ def multi_head_attention(queries,
n_head=1,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
post_softmax_shape=None,
cache=None):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
......@@ -116,6 +117,10 @@ def multi_head_attention(queries,
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
k = cache["k"] = layers.concat([cache["k"], k], axis=1)
v = cache["v"] = layers.concat([cache["v"], v], axis=1)
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
......@@ -203,7 +208,7 @@ def prepare_encoder(src_word,
enc_input = src_word_emb + src_pos_enc
enc_input = layers.reshape(
x=enc_input,
shape=[-1, src_max_len, src_emb_dim],
shape=[batch_size, seq_len, src_emb_dim],
actual_shape=src_data_shape)
return layers.dropout(
enc_input, dropout_prob=dropout_rate,
......@@ -285,7 +290,8 @@ def decoder_layer(dec_input,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None):
src_attn_post_softmax_shape=None,
cache=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
......@@ -301,7 +307,8 @@ def decoder_layer(dec_input,
n_head,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape, )
slf_attn_post_softmax_shape,
cache, )
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
......@@ -350,7 +357,8 @@ def decoder(dec_input,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None):
src_attn_post_softmax_shape=None,
caches=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
......@@ -369,7 +377,8 @@ def decoder(dec_input,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
src_attn_post_softmax_shape,
None if caches is None else caches[i], )
dec_input = dec_output
return dec_output
......@@ -384,6 +393,8 @@ def make_all_inputs(input_fields):
name=input_field,
shape=input_descs[input_field][0],
dtype=input_descs[input_field][1],
lod_level=input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0,
append_batch_size=False)
inputs.append(input_var)
return inputs
......@@ -517,7 +528,8 @@ def wrap_decoder(trg_vocab_size,
dropout_rate,
weight_sharing,
dec_inputs=None,
enc_output=None):
enc_output=None,
caches=None):
"""
The wrapper assembles together all needed layers for the decoder.
"""
......@@ -559,7 +571,8 @@ def wrap_decoder(trg_vocab_size,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
src_attn_post_softmax_shape,
caches, )
# Return logits for training and probs for inference.
if weight_sharing:
predict = layers.reshape(
......@@ -578,3 +591,145 @@ def wrap_decoder(trg_vocab_size,
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
return predict
def fast_decode(
src_vocab_size,
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
weight_sharing,
beam_size,
max_out_len,
eos_idx, ):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
dropout_rate, weight_sharing)
start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
make_all_inputs(fast_decoder_data_input_fields +
fast_decoder_util_input_fields)
def beam_search():
max_len = layers.fill_constant(
shape=[1], dtype=start_tokens.dtype, value=max_out_len)
step_idx = layers.fill_constant(
shape=[1], dtype=start_tokens.dtype, value=0)
cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
# array states will be stored for each step.
ids = layers.array_write(start_tokens, step_idx)
scores = layers.array_write(init_scores, step_idx)
# cell states will be overwrited at each step.
# caches contains states of history steps to reduce redundant
# computation in decoder.
caches = [{
"k": layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, 0, d_model],
dtype=enc_output.dtype,
value=0),
"v": layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, 0, d_model],
dtype=enc_output.dtype,
value=0)
} for i in range(n_layer)]
with while_op.block():
pre_ids = layers.array_read(array=ids, i=step_idx)
pre_scores = layers.array_read(array=scores, i=step_idx)
# sequence_expand can gather sequences according to lod thus can be
# used in beam search to sift states corresponding to selected ids.
pre_src_attn_bias = layers.sequence_expand(
x=trg_src_attn_bias, y=pre_scores)
pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
pre_caches = [{
"k": layers.sequence_expand(
x=cache["k"], y=pre_scores),
"v": layers.sequence_expand(
x=cache["v"], y=pre_scores),
} for cache in caches]
pre_pos = layers.elementwise_mul(
x=layers.fill_constant_batch_size_like(
input=pre_enc_output, # cann't use pre_ids here since it has lod
value=1,
shape=[-1, 1],
dtype=pre_ids.dtype),
y=layers.increment(
x=step_idx, value=1.0, in_place=False),
axis=0)
logits = wrap_decoder(
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
weight_sharing,
dec_inputs=(
pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
enc_output=pre_enc_output,
caches=pre_caches)
topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size)
accu_scores = layers.elementwise_add(
x=layers.log(topk_scores),
y=layers.reshape(
pre_scores, shape=[-1]),
axis=0)
# beam_search op uses lod to distinguish branches.
topk_indices = layers.lod_reset(topk_indices, pre_ids)
selected_ids, selected_scores = layers.beam_search(
pre_ids=pre_ids,
pre_scores=pre_scores,
ids=topk_indices,
scores=accu_scores,
beam_size=beam_size,
end_id=eos_idx)
layers.increment(x=step_idx, value=1.0, in_place=True)
# update states
layers.array_write(selected_ids, i=step_idx, array=ids)
layers.array_write(selected_scores, i=step_idx, array=scores)
layers.assign(pre_src_attn_bias, trg_src_attn_bias)
layers.assign(pre_enc_output, enc_output)
for i in range(n_layer):
layers.assign(pre_caches[i]["k"], caches[i]["k"])
layers.assign(pre_caches[i]["v"], caches[i]["v"])
layers.assign(
layers.elementwise_add(
x=slf_attn_pre_softmax_shape,
y=attn_pre_softmax_shape_delta),
slf_attn_pre_softmax_shape)
layers.assign(
layers.elementwise_add(
x=slf_attn_post_softmax_shape,
y=attn_post_softmax_shape_delta),
slf_attn_post_softmax_shape)
length_cond = layers.less_than(x=step_idx, y=max_len)
finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
finished_ids, finished_scores = layers.beam_search_decode(
ids, scores, beam_size=beam_size, end_id=eos_idx)
return finished_ids, finished_scores
finished_ids, finished_scores = beam_search()
return finished_ids, finished_scores
......@@ -198,7 +198,8 @@ class DataReader(object):
for line in f_obj:
fields = line.strip().split(self._delimiter)
if len(fields) != 2 or (self._only_src and len(fields) != 1):
if (not self._only_src and len(fields) != 2) or (self._only_src and
len(fields) != 1):
continue
sample_words = []
......@@ -275,7 +276,7 @@ class DataReader(object):
for sample_idx in self._sample_idxs:
if self._only_src:
yield (self._src_seq_ids[sample_idx])
yield (self._src_seq_ids[sample_idx], )
else:
yield (self._src_seq_ids[sample_idx],
self._trg_seq_ids[sample_idx][:-1],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册