提交 ee46d363 编写于 作者: R ranqiu

initial weight boot

上级 848bb8ab
......@@ -35,6 +35,8 @@ class ExternalMemory(object):
sequence layer has sequence length indicating the number
of memory slots, and size as memory slot size.
:type boot_layer: LayerOutput
:param initial_weight: Initializer for addressing weights.
:type initial_weight: LayerOutput
:param readonly: If true, the memory is read-only, and write function cannot
be called. Default is false.
:type readonly: bool
......@@ -49,6 +51,7 @@ class ExternalMemory(object):
name,
mem_slot_size,
boot_layer,
initial_weight,
readonly=False,
enable_interpolation=True):
self.name = name
......@@ -57,11 +60,7 @@ class ExternalMemory(object):
self.enable_interpolation = enable_interpolation
self.external_memory = paddle.layer.memory(
name=self.name, size=self.mem_slot_size, boot_layer=boot_layer)
# prepare a constant (zero) intializer for addressing weights
self.zero_addressing_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=boot_layer, size=1),
slope=0.0,
intercept=0.0)
self.initial_weight = initial_weight
# set memory to constant when readonly=True
if self.readonly:
self.updated_external_memory = paddle.layer.mixed(
......@@ -111,7 +110,7 @@ class ExternalMemory(object):
last_addressing_weight = paddle.layer.memory(
name=self.name + "_addressing_weight_" + head_name,
size=1,
boot_layer=self.zero_addressing_init)
boot_layer=self.initial_weight)
interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight],
......
......@@ -125,7 +125,15 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
bounded_memory_perturbation
],
act=paddle.activation.Linear())
bounded_memory_weight_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=bounded_memory_init, size=1),
slope=0.0,
intercept=0.0)
unbounded_memory_init = source_context
unbounded_memory_weight_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=unbounded_memory_init, size=1),
slope=0.0,
intercept=0.0)
# prepare step function for reccurent group
def recurrent_decoder_step(cur_embedding):
......@@ -136,12 +144,14 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
name="bounded_memory",
mem_slot_size=size,
boot_layer=bounded_memory_init,
initial_weight = bounded_memory_weight_init,
readonly=False,
enable_interpolation=True)
unbounded_memory = ExternalMemory(
name="unbounded_memory",
mem_slot_size=size * 2,
boot_layer=unbounded_memory_init,
initial_weight = unbounded_memory_weight_init,
readonly=True,
enable_interpolation=False)
# write bounded memory
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册