Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
2c0e4781
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2c0e4781
编写于
9月 16, 2017
作者:
C
Cao Ying
提交者:
GitHub
9月 16, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #252 from ranqiu92/mt_with_external_memory
update README.md and train.py
上级
c754dbec
f456031b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
8 addition
and
3 deletion
+8
-3
mt_with_external_memory/README.md
mt_with_external_memory/README.md
+7
-2
mt_with_external_memory/train.py
mt_with_external_memory/train.py
+1
-1
未找到文件。
mt_with_external_memory/README.md
浏览文件 @
2c0e4781
...
...
@@ -142,6 +142,7 @@ class ExternalMemory(object):
name
,
mem_slot_size
,
boot_layer
,
initial_weight
,
readonly
=
False
,
enable_interpolation
=
True
):
""" Initialization.
...
...
@@ -154,6 +155,8 @@ class ExternalMemory(object):
sequence layer has sequence length indicating the number
of memory slots, and size as memory slot size.
:type boot_layer: LayerOutput
:param initial_weight: Initializer for addressing weights.
:type initial_weight: LayerOutput
:param readonly: If true, the memory is read-only, and write function cannot
be called. Default is false.
:type readonly: bool
...
...
@@ -205,7 +208,7 @@ class ExternalMemory(object):
-
`_content_addressing`
: 通过基于内容的寻址,计算得到读写操作的寻址强度。
-
`_interpolation`
: 通过插值寻址(当前寻址强度和上一时间步寻址强度的线性加权),更新当前寻址强度。
-
`_get_addressing_weight`
: 调用上述两个寻址操作,获得对存储
导员
的读写操作的最终寻址强度。
-
`_get_addressing_weight`
: 调用上述两个寻址操作,获得对存储
单元
的读写操作的最终寻址强度。
对外接口包含:
...
...
@@ -214,6 +217,7 @@ class ExternalMemory(object):
-
输入参数
`name`
: 外部记忆单元名,不同实例的相同命名将共享同一外部记忆单元。
-
输入参数
`mem_slot_size`
: 单个记忆槽(向量)的维度。
-
输入参数
`boot_layer`
: 用于内存槽初始化的层。需为序列类型,序列长度表明记忆槽的数量。
-
输入参数
`initial_weight`
: 用于初始化寻址强度。
-
输入参数
`readonly`
: 是否打开只读模式(例如打开只读模式,该实例可用于注意力机制)。打开只读模式,
`write`
方法不可被调用。
-
输入参数
`enable_interpolation`
: 是否允许插值寻址(例如当用于注意力机制时,需要关闭插值寻址)。
-
`write`
: 写操作。
...
...
@@ -230,7 +234,6 @@ class ExternalMemory(object):
self
.
external_memory
=
paddle
.
layer
.
memory
(
name
=
self
.
name
,
size
=
self
.
mem_slot_size
,
is_seq
=
True
,
boot_layer
=
boot_layer
)
```
-
`ExternalMemory`
类的寻址逻辑通过
`_content_addressing`
和
`_interpolation`
两个私有方法实现。读和写操作通过
`read`
和
`write`
两个函数实现,包括上述的寻址操作。并且读和写的寻址独立进行,不同于
\[
[
2
](
#参考文献
)
\]
中的二者共享同一个寻址强度,目的是为了使得该类更通用。
...
...
@@ -349,6 +352,7 @@ def memory_enhanced_seq2seq(encoder_input, decoder_input, decoder_target,
name="unbounded_memory",
mem_slot_size=size * 2,
boot_layer=unbounded_memory_init,
initial_weight=unbounded_memory_weight_init,
readonly=True,
enable_interpolation=False)
```
...
...
@@ -359,6 +363,7 @@ def memory_enhanced_seq2seq(encoder_input, decoder_input, decoder_target,
name="bounded_memory",
mem_slot_size=size,
boot_layer=bounded_memory_init,
initial_weight=bounded_memory_weight_init,
readonly=False,
enable_interpolation=True)
```
...
...
mt_with_external_memory/train.py
浏览文件 @
2c0e4781
...
...
@@ -135,7 +135,7 @@ def train():
sys
.
stdout
.
flush
()
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
result
=
trainer
.
test
(
reader
=
test_batch_reader
,
feeding
=
feeding
)
print
"Pass: %d, TestCost: %f, %s"
%
(
event
.
pass_id
,
even
t
.
cost
,
print
"Pass: %d, TestCost: %f, %s"
%
(
event
.
pass_id
,
resul
t
.
cost
,
result
.
metrics
)
with
gzip
.
open
(
"checkpoints/params.pass-%d.tar.gz"
%
event
.
pass_id
,
'w'
)
as
f
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录