Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDocCN
d2l-zh
提交
343c0a13
D
d2l-zh
项目概览
OpenDocCN
/
d2l-zh
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
d2l-zh
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
343c0a13
编写于
2月 10, 2018
作者:
A
Aston Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix error
上级
e2f969de
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
10 addition
and
5 deletion
+10
-5
chapter_natural-language-processing/nmt.md
chapter_natural-language-processing/nmt.md
+10
-5
未找到文件。
chapter_natural-language-processing/nmt.md
浏览文件 @
343c0a13
...
@@ -102,6 +102,7 @@ for i in range(len(input_seqs)):
...
@@ -102,6 +102,7 @@ for i in range(len(input_seqs)):
Y[i] = nd.array(output_vocab.to_indices(output_seqs[i]), ctx=ctx)
Y[i] = nd.array(output_vocab.to_indices(output_seqs[i]), ctx=ctx)
dataset = gluon.data.ArrayDataset(X, Y)
dataset = gluon.data.ArrayDataset(X, Y)
```
```
### 编码器、含注意力机制的解码器和解码器初始状态
### 编码器、含注意力机制的解码器和解码器初始状态
...
@@ -166,7 +167,7 @@ class Decoder(Block):
...
@@ -166,7 +167,7 @@ class Decoder(Block):
single_layer_state = [state[0][-1].expand_dims(0)]
single_layer_state = [state[0][-1].expand_dims(0)]
encoder_outputs = encoder_outputs.reshape((self.max_seq_len, 1,
encoder_outputs = encoder_outputs.reshape((self.max_seq_len, 1,
self.encoder_hidden_dim))
self.encoder_hidden_dim))
#
hidden
尺寸: [(1, 1, decoder_hidden_dim)]
#
single_layer_state
尺寸: [(1, 1, decoder_hidden_dim)]
# hidden_broadcast尺寸: (max_seq_len, 1, decoder_hidden_dim)
# hidden_broadcast尺寸: (max_seq_len, 1, decoder_hidden_dim)
hidden_broadcast = nd.broadcast_axis(single_layer_state[0], axis=0,
hidden_broadcast = nd.broadcast_axis(single_layer_state[0], axis=0,
size=self.max_seq_len)
size=self.max_seq_len)
...
@@ -243,7 +244,7 @@ def translate(encoder, decoder, decoder_init_state, fr_ens, ctx, max_seq_len):
...
@@ -243,7 +244,7 @@ def translate(encoder, decoder, decoder_init_state, fr_ens, ctx, max_seq_len):
encoder_outputs, encoder_state = encoder(inputs.expand_dims(0),
encoder_outputs, encoder_state = encoder(inputs.expand_dims(0),
encoder_state)
encoder_state)
encoder_outputs = encoder_outputs.flatten()
encoder_outputs = encoder_outputs.flatten()
#
编
码器的第一个输入为BOS字符。
#
解
码器的第一个输入为BOS字符。
decoder_input = nd.array([output_vocab.token_to_idx[BOS]], ctx=ctx)
decoder_input = nd.array([output_vocab.token_to_idx[BOS]], ctx=ctx)
decoder_state = decoder_init_state(encoder_state[0])
decoder_state = decoder_init_state(encoder_state[0])
output_tokens = []
output_tokens = []
...
@@ -295,7 +296,7 @@ def train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens):
...
@@ -295,7 +296,7 @@ def train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens):
# encoder_outputs尺寸: (max_seq_len, encoder_hidden_dim)
# encoder_outputs尺寸: (max_seq_len, encoder_hidden_dim)
encoder_outputs = encoder_outputs.flatten()
encoder_outputs = encoder_outputs.flatten()
#
编
码器的第一个输入为BOS字符。
#
解
码器的第一个输入为BOS字符。
decoder_input = nd.array([output_vocab.token_to_idx[BOS]],
decoder_input = nd.array([output_vocab.token_to_idx[BOS]],
ctx=ctx)
ctx=ctx)
decoder_state = decoder_init_state(encoder_state[0])
decoder_state = decoder_init_state(encoder_state[0])
...
@@ -320,10 +321,14 @@ def train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens):
...
@@ -320,10 +321,14 @@ def train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens):
h, remainder = divmod((cur_time - prev_time).seconds, 3600)
h, remainder = divmod((cur_time - prev_time).seconds, 3600)
m, s = divmod(remainder, 60)
m, s = divmod(remainder, 60)
time_str = 'Time %02d:%02d:%02d' % (h, m, s)
time_str = 'Time %02d:%02d:%02d' % (h, m, s)
print_loss_avg = total_loss / epoch_period / len(data_iter)
if epoch == 1:
print_loss_avg = total_loss / len(data_iter)
else:
print_loss_avg = total_loss / epoch_period / len(data_iter)
loss_str = 'Epoch %d, Loss %f, ' % (epoch, print_loss_avg)
loss_str = 'Epoch %d, Loss %f, ' % (epoch, print_loss_avg)
print(loss_str + time_str)
print(loss_str + time_str)
total_loss = 0.0
if epoch != 1:
total_loss = 0.0
prev_time = cur_time
prev_time = cur_time
translate(encoder, decoder, decoder_init_state, eval_fr_ens, ctx,
translate(encoder, decoder, decoder_init_state, eval_fr_ens, ctx,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录