Commit 1853d687 authored by Xing Wu, committed by Guo Sheng

dygraph_seq2seq_fix_infer (#4191)

Parent 9ab3fa06
@@ -269,9 +269,8 @@ class AttentionModel(fluid.dygraph.Layer):
             enc_outputs = self.tile_beam_merge_with_batch(enc_outputs)
             enc_padding_mask = self.tile_beam_merge_with_batch(enc_padding_mask)
             batch_beam_shape = (self.batch_size, self.beam_size)
-            batch_beam_shape_1 = (self.batch_size, self.beam_size, 1)
             vocab_size_tensor = to_variable(np.full((1), self.tar_vocab_size))
-            start_token_tensor = to_variable(np.full(batch_beam_shape_1, self.beam_start_token, dtype='int64')) # remove last dim 1 in v1.7
+            start_token_tensor = to_variable(np.full(batch_beam_shape, self.beam_start_token, dtype='int64'))
             end_token_tensor = to_variable(np.full(batch_beam_shape, self.beam_end_token, dtype='int64'))
             step_input = self.tar_embeder(start_token_tensor)
             input_feed = to_variable(np.zeros((self.batch_size, self.hidden_size), dtype='float32'))
@@ -348,7 +347,7 @@ class AttentionModel(fluid.dygraph.Layer):
                 dec_hidden, dec_cell = new_dec_hidden, new_dec_cell
                 beam_finished = next_finished
                 beam_state_log_probs = next_log_probs
-                step_input = self.tar_embeder(fluid.layers.unsqueeze(token_indices, 2)) # remove unsqueeze in v1.7
+                step_input = self.tar_embeder(token_indices)
                 predicted_ids.append(token_indices)
                 parent_ids.append(beam_indices)
@@ -359,4 +358,4 @@ class AttentionModel(fluid.dygraph.Layer):
             return predicted_ids
         else:
             print("not support mode ", self.mode)
             raise Exception("not support mode: " + self.mode)
\ No newline at end of file
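These hunks (and the matching BaseModel hunks below) rely on the Paddle 1.7 dygraph `Embedding` layer accepting plain `[batch_size, beam_size]` int64 ids, which is why the trailing dimension of 1 and the `fluid.layers.unsqueeze` call before `tar_embeder` can be dropped. A minimal sketch of that shape behavior under Paddle 1.7, with illustrative sizes (the vocabulary and hidden sizes below are made up):

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import Embedding, to_variable

# Sketch: in Paddle 1.7 the dygraph Embedding layer takes int64 ids of any
# shape and appends the embedding dim, so [batch_size, beam_size] ids no
# longer need a trailing dim of 1 or an unsqueeze before the lookup.
with fluid.dygraph.guard():
    embedder = Embedding(size=[100, 16])  # [vocab_size, hidden_size], illustrative
    ids = to_variable(np.full((4, 8), 1, dtype='int64'))  # [batch_size, beam_size]
    step_input = embedder(ids)
    print(step_input.shape)  # [4, 8, 16]
```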
@@ -202,7 +202,7 @@ class BaseModel(fluid.dygraph.Layer):
             batch_beam_shape = (self.batch_size, self.beam_size)
             #batch_beam_shape_1 = (self.batch_size, self.beam_size, 1)
             vocab_size_tensor = to_variable(np.full((1), self.tar_vocab_size))
-            start_token_tensor = to_variable(np.full(batch_beam_shape, self.beam_start_token, dtype='int64')) # remove last dim 1 in v1.7
+            start_token_tensor = to_variable(np.full(batch_beam_shape, self.beam_start_token, dtype='int64'))
             end_token_tensor = to_variable(np.full(batch_beam_shape, self.beam_end_token, dtype='int64'))
             step_input = self.tar_embeder(start_token_tensor)
             beam_finished = to_variable(np.full(batch_beam_shape, 0, dtype='float32'))
@@ -271,7 +271,7 @@ class BaseModel(fluid.dygraph.Layer):
                 dec_hidden, dec_cell = new_dec_hidden, new_dec_cell
                 beam_finished = next_finished
                 beam_state_log_probs = next_log_probs
-                step_input = self.tar_embeder(fluid.layers.unsqueeze(token_indices, 2)) # remove unsqueeze in v1.7
+                step_input = self.tar_embeder(token_indices) # remove unsqueeze in v1.7
                 predicted_ids.append(token_indices)
                 parent_ids.append(beam_indices)
@@ -282,4 +282,4 @@ class BaseModel(fluid.dygraph.Layer):
             return predicted_ids
         else:
             print("not support mode ", self.mode)
             raise Exception("not support mode: " + self.mode)
\ No newline at end of file
@@ -70,7 +70,6 @@ def infer():
     # So we can set dropout to 0
     if args.attention:
         model = AttentionModel(
-            "attention_model",
             hidden_size,
             src_vocab_size,
             tar_vocab_size,
@@ -82,7 +81,6 @@ def infer():
             mode='beam_search')
     else:
         model = BaseModel(
-            "base_model",
             hidden_size,
             src_vocab_size,
             tar_vocab_size,
@@ -134,11 +132,11 @@ def infer():
     for batch_id, batch in enumerate(train_data_iter):
         input_data_feed, word_num = prepare_input(batch, epoch_id=0)
-        # import ipdb; ipdb.set_trace()
         outputs = model(input_data_feed)
         for i in range(outputs.shape[0]):
-            ins = fluid.Variable.numpy(outputs[i])
-            res = [tar_id2vocab[e] for e in ins[:, 0].reshape(-1)]
+            ins = outputs[i].numpy()
+            res = [tar_id2vocab[int(e)] for e in ins[:, 0].reshape(-1)]
             new_res = []
             for ele in res:
                 if ele == "</s>":
...
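The infer() hunks drop the name-scope string from the model constructors (the updated Layers no longer take it) and read results with the variable's own `.numpy()` method instead of `fluid.Variable.numpy(...)`, casting each id to `int` before the vocabulary lookup. A minimal sketch of that decode-and-lookup step; `tar_id2vocab` and the output tensor shape below are made-up stand-ins:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable

# Sketch: illustrative vocabulary and a fake beam-search result of shape
# [batch_size, seq_len, beam_size].
tar_id2vocab = {0: "<s>", 1: "</s>", 2: "hello", 3: "world"}

with fluid.dygraph.guard():
    outputs = to_variable(np.array([[[2, 2], [3, 3], [1, 1]]], dtype='int64'))
    for i in range(outputs.shape[0]):
        ins = outputs[i].numpy()  # dygraph variables expose .numpy() directly
        # take the top beam (column 0) and cast ids to int for the dict lookup
        res = [tar_id2vocab[int(e)] for e in ins[:, 0].reshape(-1)]
        print(res)  # expected: ['hello', 'world', '</s>']
```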