提交 8b5c739a 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #249 from ranqiu92/mt_with_external_memory2

fix a bug of external_memory.
......@@ -35,6 +35,8 @@ class ExternalMemory(object):
sequence layer has sequence length indicating the number
of memory slots, and size as memory slot size.
:type boot_layer: LayerOutput
:param initial_weight: Initializer for addressing weights.
:type initial_weight: LayerOutput
:param readonly: If true, the memory is read-only, and write function cannot
be called. Default is false.
:type readonly: bool
......@@ -49,6 +51,7 @@ class ExternalMemory(object):
name,
mem_slot_size,
boot_layer,
initial_weight,
readonly=False,
enable_interpolation=True):
self.name = name
......@@ -57,11 +60,7 @@ class ExternalMemory(object):
self.enable_interpolation = enable_interpolation
self.external_memory = paddle.layer.memory(
name=self.name, size=self.mem_slot_size, boot_layer=boot_layer)
# prepare a constant (zero) intializer for addressing weights
self.zero_addressing_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=boot_layer, size=1),
slope=0.0,
intercept=0.0)
self.initial_weight = initial_weight
# set memory to constant when readonly=True
if self.readonly:
self.updated_external_memory = paddle.layer.mixed(
......@@ -111,7 +110,7 @@ class ExternalMemory(object):
last_addressing_weight = paddle.layer.memory(
name=self.name + "_addressing_weight_" + head_name,
size=1,
boot_layer=self.zero_addressing_init)
boot_layer=self.initial_weight)
interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight],
......
......@@ -125,7 +125,15 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
bounded_memory_perturbation
],
act=paddle.activation.Linear())
bounded_memory_weight_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=bounded_memory_init, size=1),
slope=0.0,
intercept=0.0)
unbounded_memory_init = source_context
unbounded_memory_weight_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=unbounded_memory_init, size=1),
slope=0.0,
intercept=0.0)
# prepare step function for reccurent group
def recurrent_decoder_step(cur_embedding):
......@@ -136,12 +144,14 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
name="bounded_memory",
mem_slot_size=size,
boot_layer=bounded_memory_init,
initial_weight=bounded_memory_weight_init,
readonly=False,
enable_interpolation=True)
unbounded_memory = ExternalMemory(
name="unbounded_memory",
mem_slot_size=size * 2,
boot_layer=unbounded_memory_init,
initial_weight=unbounded_memory_weight_init,
readonly=True,
enable_interpolation=False)
# write bounded memory
......@@ -154,7 +164,7 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
size=size,
act=paddle.activation.Tanh(),
bias_attr=False)
# read unbounded memory (i.e. attention mechanism)
# read unbounded memory (i.e. attention mechanism)
context = unbounded_memory.read(key_for_unbounded_memory)
# gated recurrent unit
gru_inputs = paddle.layer.fc(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册