提交 343c0a13 编写于 作者: A Aston Zhang

fix error

上级 e2f969de
...@@ -102,6 +102,7 @@ for i in range(len(input_seqs)): ...@@ -102,6 +102,7 @@ for i in range(len(input_seqs)):
Y[i] = nd.array(output_vocab.to_indices(output_seqs[i]), ctx=ctx) Y[i] = nd.array(output_vocab.to_indices(output_seqs[i]), ctx=ctx)
dataset = gluon.data.ArrayDataset(X, Y) dataset = gluon.data.ArrayDataset(X, Y)
``` ```
### 编码器、含注意力机制的解码器和解码器初始状态 ### 编码器、含注意力机制的解码器和解码器初始状态
...@@ -166,7 +167,7 @@ class Decoder(Block): ...@@ -166,7 +167,7 @@ class Decoder(Block):
single_layer_state = [state[0][-1].expand_dims(0)] single_layer_state = [state[0][-1].expand_dims(0)]
encoder_outputs = encoder_outputs.reshape((self.max_seq_len, 1, encoder_outputs = encoder_outputs.reshape((self.max_seq_len, 1,
self.encoder_hidden_dim)) self.encoder_hidden_dim))
# hidden尺寸: [(1, 1, decoder_hidden_dim)] # single_layer_state尺寸: [(1, 1, decoder_hidden_dim)]
# hidden_broadcast尺寸: (max_seq_len, 1, decoder_hidden_dim) # hidden_broadcast尺寸: (max_seq_len, 1, decoder_hidden_dim)
hidden_broadcast = nd.broadcast_axis(single_layer_state[0], axis=0, hidden_broadcast = nd.broadcast_axis(single_layer_state[0], axis=0,
size=self.max_seq_len) size=self.max_seq_len)
...@@ -243,7 +244,7 @@ def translate(encoder, decoder, decoder_init_state, fr_ens, ctx, max_seq_len): ...@@ -243,7 +244,7 @@ def translate(encoder, decoder, decoder_init_state, fr_ens, ctx, max_seq_len):
encoder_outputs, encoder_state = encoder(inputs.expand_dims(0), encoder_outputs, encoder_state = encoder(inputs.expand_dims(0),
encoder_state) encoder_state)
encoder_outputs = encoder_outputs.flatten() encoder_outputs = encoder_outputs.flatten()
# 码器的第一个输入为BOS字符。 # 码器的第一个输入为BOS字符。
decoder_input = nd.array([output_vocab.token_to_idx[BOS]], ctx=ctx) decoder_input = nd.array([output_vocab.token_to_idx[BOS]], ctx=ctx)
decoder_state = decoder_init_state(encoder_state[0]) decoder_state = decoder_init_state(encoder_state[0])
output_tokens = [] output_tokens = []
...@@ -295,7 +296,7 @@ def train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens): ...@@ -295,7 +296,7 @@ def train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens):
# encoder_outputs尺寸: (max_seq_len, encoder_hidden_dim) # encoder_outputs尺寸: (max_seq_len, encoder_hidden_dim)
encoder_outputs = encoder_outputs.flatten() encoder_outputs = encoder_outputs.flatten()
# 码器的第一个输入为BOS字符。 # 码器的第一个输入为BOS字符。
decoder_input = nd.array([output_vocab.token_to_idx[BOS]], decoder_input = nd.array([output_vocab.token_to_idx[BOS]],
ctx=ctx) ctx=ctx)
decoder_state = decoder_init_state(encoder_state[0]) decoder_state = decoder_init_state(encoder_state[0])
...@@ -320,10 +321,14 @@ def train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens): ...@@ -320,10 +321,14 @@ def train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens):
h, remainder = divmod((cur_time - prev_time).seconds, 3600) h, remainder = divmod((cur_time - prev_time).seconds, 3600)
m, s = divmod(remainder, 60) m, s = divmod(remainder, 60)
time_str = 'Time %02d:%02d:%02d' % (h, m, s) time_str = 'Time %02d:%02d:%02d' % (h, m, s)
print_loss_avg = total_loss / epoch_period / len(data_iter) if epoch == 1:
print_loss_avg = total_loss / len(data_iter)
else:
print_loss_avg = total_loss / epoch_period / len(data_iter)
loss_str = 'Epoch %d, Loss %f, ' % (epoch, print_loss_avg) loss_str = 'Epoch %d, Loss %f, ' % (epoch, print_loss_avg)
print(loss_str + time_str) print(loss_str + time_str)
total_loss = 0.0 if epoch != 1:
total_loss = 0.0
prev_time = cur_time prev_time = cur_time
translate(encoder, decoder, decoder_init_state, eval_fr_ens, ctx, translate(encoder, decoder, decoder_init_state, eval_fr_ens, ctx,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册