提交 833a0157 编写于 作者: G guosheng

Use create_global_var instead of fill_constant in __init__ to make it...

Use create_global_var instead of fill_constant in __init__ to make it compatible between dygraph and static-graph.
上级 bc039c59
...@@ -113,7 +113,7 @@ def do_predict(args): ...@@ -113,7 +113,7 @@ def do_predict(args):
for data in data_loader(): for data in data_loader():
finished_seq = model.test(inputs=flatten(data))[0] finished_seq = model.test(inputs=flatten(data))[0]
finished_seq = finished_seq[:, :, np.newaxis] if len( finished_seq = finished_seq[:, :, np.newaxis] if len(
finished_seq.shape == 2) else finished_seq finished_seq.shape) == 2 else finished_seq
finished_seq = np.transpose(finished_seq, [0, 2, 1]) finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq: for ins in finished_seq:
for beam_idx, beam in enumerate(ins): for beam_idx, beam in enumerate(ins):
......
...@@ -168,6 +168,7 @@ class SampleInfo(object): ...@@ -168,6 +168,7 @@ class SampleInfo(object):
def __init__(self, i, lens): def __init__(self, i, lens):
self.i = i self.i = i
self.lens = lens self.lens = lens
self.max_len = lens[0]
def get_ranges(self, min_length=None, max_length=None, truncate=False): def get_ranges(self, min_length=None, max_length=None, truncate=False):
ranges = [] ranges = []
......
...@@ -247,6 +247,8 @@ class GreedyEmbeddingHelper(fluid.layers.GreedyEmbeddingHelper): ...@@ -247,6 +247,8 @@ class GreedyEmbeddingHelper(fluid.layers.GreedyEmbeddingHelper):
self.start_token_value = start_tokens self.start_token_value = start_tokens
super(GreedyEmbeddingHelper, self).__init__(embedding_fn, start_tokens, super(GreedyEmbeddingHelper, self).__init__(embedding_fn, start_tokens,
end_token) end_token)
self.end_token = fluid.layers.create_global_var(
shape=[1], dtype="int64", value=end_token, persistable=True)
def initialize(self, batch_ref=None): def initialize(self, batch_ref=None):
if getattr(self, "need_convert_start_tokens", False): if getattr(self, "need_convert_start_tokens", False):
...@@ -319,7 +321,7 @@ class AttentionGreedyInferModel(AttentionModel): ...@@ -319,7 +321,7 @@ class AttentionGreedyInferModel(AttentionModel):
encoder_padding_mask = (src_mask - 1.0) * 1e9 encoder_padding_mask = (src_mask - 1.0) * 1e9
encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1]) encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1])
# dynamic decoding with beam search # dynamic decoding with greedy search
rs, _ = self.greedy_search_decoder( rs, _ = self.greedy_search_decoder(
inits=decoder_initial_states, inits=decoder_initial_states,
encoder_output=encoder_output, encoder_output=encoder_output,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册