提交 8b5c739a 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #249 from ranqiu92/mt_with_external_memory2

fix a bug of external_memory.
...@@ -35,6 +35,8 @@ class ExternalMemory(object): ...@@ -35,6 +35,8 @@ class ExternalMemory(object):
sequence layer has sequence length indicating the number sequence layer has sequence length indicating the number
of memory slots, and size as memory slot size. of memory slots, and size as memory slot size.
:type boot_layer: LayerOutput :type boot_layer: LayerOutput
:param initial_weight: Initializer for addressing weights.
:type initial_weight: LayerOutput
:param readonly: If true, the memory is read-only, and write function cannot :param readonly: If true, the memory is read-only, and write function cannot
be called. Default is false. be called. Default is false.
:type readonly: bool :type readonly: bool
...@@ -49,6 +51,7 @@ class ExternalMemory(object): ...@@ -49,6 +51,7 @@ class ExternalMemory(object):
name, name,
mem_slot_size, mem_slot_size,
boot_layer, boot_layer,
initial_weight,
readonly=False, readonly=False,
enable_interpolation=True): enable_interpolation=True):
self.name = name self.name = name
...@@ -57,11 +60,7 @@ class ExternalMemory(object): ...@@ -57,11 +60,7 @@ class ExternalMemory(object):
self.enable_interpolation = enable_interpolation self.enable_interpolation = enable_interpolation
self.external_memory = paddle.layer.memory( self.external_memory = paddle.layer.memory(
name=self.name, size=self.mem_slot_size, boot_layer=boot_layer) name=self.name, size=self.mem_slot_size, boot_layer=boot_layer)
# prepare a constant (zero) intializer for addressing weights self.initial_weight = initial_weight
self.zero_addressing_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=boot_layer, size=1),
slope=0.0,
intercept=0.0)
# set memory to constant when readonly=True # set memory to constant when readonly=True
if self.readonly: if self.readonly:
self.updated_external_memory = paddle.layer.mixed( self.updated_external_memory = paddle.layer.mixed(
...@@ -111,7 +110,7 @@ class ExternalMemory(object): ...@@ -111,7 +110,7 @@ class ExternalMemory(object):
last_addressing_weight = paddle.layer.memory( last_addressing_weight = paddle.layer.memory(
name=self.name + "_addressing_weight_" + head_name, name=self.name + "_addressing_weight_" + head_name,
size=1, size=1,
boot_layer=self.zero_addressing_init) boot_layer=self.initial_weight)
interpolated_weight = paddle.layer.interpolation( interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name, name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight], input=[addressing_weight, addressing_weight],
......
...@@ -125,7 +125,15 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size, ...@@ -125,7 +125,15 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
bounded_memory_perturbation bounded_memory_perturbation
], ],
act=paddle.activation.Linear()) act=paddle.activation.Linear())
bounded_memory_weight_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=bounded_memory_init, size=1),
slope=0.0,
intercept=0.0)
unbounded_memory_init = source_context unbounded_memory_init = source_context
unbounded_memory_weight_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=unbounded_memory_init, size=1),
slope=0.0,
intercept=0.0)
# prepare step function for reccurent group # prepare step function for reccurent group
def recurrent_decoder_step(cur_embedding): def recurrent_decoder_step(cur_embedding):
...@@ -136,12 +144,14 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size, ...@@ -136,12 +144,14 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
name="bounded_memory", name="bounded_memory",
mem_slot_size=size, mem_slot_size=size,
boot_layer=bounded_memory_init, boot_layer=bounded_memory_init,
initial_weight=bounded_memory_weight_init,
readonly=False, readonly=False,
enable_interpolation=True) enable_interpolation=True)
unbounded_memory = ExternalMemory( unbounded_memory = ExternalMemory(
name="unbounded_memory", name="unbounded_memory",
mem_slot_size=size * 2, mem_slot_size=size * 2,
boot_layer=unbounded_memory_init, boot_layer=unbounded_memory_init,
initial_weight=unbounded_memory_weight_init,
readonly=True, readonly=True,
enable_interpolation=False) enable_interpolation=False)
# write bounded memory # write bounded memory
...@@ -154,7 +164,7 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size, ...@@ -154,7 +164,7 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
size=size, size=size,
act=paddle.activation.Tanh(), act=paddle.activation.Tanh(),
bias_attr=False) bias_attr=False)
# read unbounded memory (i.e. attention mechanism) # read unbounded memory (i.e. attention mechanism)
context = unbounded_memory.read(key_for_unbounded_memory) context = unbounded_memory.read(key_for_unbounded_memory)
# gated recurrent unit # gated recurrent unit
gru_inputs = paddle.layer.fc( gru_inputs = paddle.layer.fc(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册