"""
    External neural memory class.
"""
import paddle.v2 as paddle


class ExternalMemory(object):
    """External neural memory class.

    A simplified Neural Turing Machines (NTM) with only content-based
    addressing (including content addressing and interpolation, but excluding
    convolutional shift and sharpening). It serves as an external differential
    memory bank, with differential write/read head controllers to store
    and read information dynamically. Simple feedforward networks are
    used as the write/read head controllers.

    The ExternalMemory class could be utilized by many neural network structures
    to easily expand their memory bandwidth and accomplish a long-term memory
    handling. Besides, some existing mechanism can be realized directly with
    the ExternalMemory class, e.g. the attention mechanism in Seq2Seq (i.e. an
    unbounded external memory).

    Besides, the ExternalMemory class must be used together with
    paddle.layer.recurrent_group (within its step function). It can never be
    used in a standalone manner.

    For more details, please refer to
    `Neural Turing Machines <https://arxiv.org/abs/1410.5401>`_.

    :param name: Memory name.
    :type name: basestring
    :param mem_slot_size: Size of memory slot/vector.
    :type mem_slot_size: int
    :param boot_layer: Boot layer for initializing the external memory. The
                       sequence layer has sequence length indicating the number
                       of memory slots, and size as memory slot size.
    :type boot_layer: LayerOutput
    :param initial_weight: Initializer for addressing weights.
    :type initial_weight: LayerOutput
    :param readonly: If true, the memory is read-only, and write function cannot
                     be called. Default is false.
    :type readonly: bool
    :param enable_interpolation: If set true, the read/write addressing weights
                                 will be interpolated with the weights in the
                                 last step, with the affine coefficients being
                                 a learnable gate function.
    :type enable_interpolation: bool
    """

    def __init__(self,
                 name,
                 mem_slot_size,
                 boot_layer,
                 initial_weight,
                 readonly=False,
                 enable_interpolation=True):
        self.name = name
        self.mem_slot_size = mem_slot_size
        self.readonly = readonly
        self.enable_interpolation = enable_interpolation
        # recurrent state layer that carries the memory content across the
        # steps of the enclosing recurrent_group
        self.external_memory = paddle.layer.memory(
            name=self.name, size=self.mem_slot_size, boot_layer=boot_layer)
        self.initial_weight = initial_weight
        # set memory to constant when readonly=True
        if self.readonly:
            self.updated_external_memory = paddle.layer.mixed(
                name=self.name,
                input=[
                    paddle.layer.identity_projection(input=self.external_memory)
                ],
                size=self.mem_slot_size)

    def _content_addressing(self, key_vector):
        """Get write/read head's addressing weights via content-based addressing.

        :param key_vector: Key vector emitted by the head controller.
        :type key_vector: LayerOutput
        :return: Per-slot addressing weights (sequence softmax over slots).
        :rtype: LayerOutput
        """
        # content-based addressing: a=tanh(W*M + U*key)
        key_projection = paddle.layer.fc(input=key_vector,
                                         size=self.mem_slot_size,
                                         act=paddle.activation.Linear(),
                                         bias_attr=False)
        # broadcast the single key projection over every memory slot
        key_proj_expanded = paddle.layer.expand(
            input=key_projection, expand_as=self.external_memory)
        memory_projection = paddle.layer.fc(input=self.external_memory,
                                            size=self.mem_slot_size,
                                            act=paddle.activation.Linear(),
                                            bias_attr=False)
        merged_projection = paddle.layer.addto(
            input=[key_proj_expanded, memory_projection],
            act=paddle.activation.Tanh())
        # softmax addressing weight: w=softmax(v^T a)
        addressing_weight = paddle.layer.fc(
            input=merged_projection,
            size=1,
            act=paddle.activation.SequenceSoftmax(),
            bias_attr=False)
        return addressing_weight

    def _interpolation(self, head_name, key_vector, addressing_weight):
        """Interpolate between previous and current addressing weights.

        :param head_name: Head identifier used to name the recurrent weight
                          state ("write_head" or "read_head").
        :type head_name: basestring
        :param key_vector: Key vector emitted by the head controller.
        :type key_vector: LayerOutput
        :param addressing_weight: Current content-based addressing weights.
        :type addressing_weight: LayerOutput
        :return: Interpolated addressing weights.
        :rtype: LayerOutput
        """
        # prepare interpolation scalar gate: g=sigmoid(W*key)
        gate = paddle.layer.fc(input=key_vector,
                               size=1,
                               act=paddle.activation.Sigmoid(),
                               bias_attr=False)
        # interpolation: w_t = g*w_t+(1-g)*w_{t-1}
        last_addressing_weight = paddle.layer.memory(
            name=self.name + "_addressing_weight_" + head_name,
            size=1,
            boot_layer=self.initial_weight)
        # naming the interpolation layer the same as the memory layer above
        # closes the recurrent loop for the addressing weights
        interpolated_weight = paddle.layer.interpolation(
            name=self.name + "_addressing_weight_" + head_name,
            input=[last_addressing_weight, addressing_weight],
            weight=paddle.layer.expand(
                input=gate, expand_as=addressing_weight))
        return interpolated_weight

    def _get_addressing_weight(self, head_name, key_vector):
        """Get final addressing weights for read/write heads, including content
        addressing and interpolation.

        :param head_name: Head identifier ("write_head" or "read_head").
        :type head_name: basestring
        :param key_vector: Key vector emitted by the head controller.
        :type key_vector: LayerOutput
        :return: Final addressing weights.
        :rtype: LayerOutput
        """
        # current content-based addressing
        addressing_weight = self._content_addressing(key_vector)
        # interpolation with previous addressing weight
        if self.enable_interpolation:
            return self._interpolation(head_name, key_vector, addressing_weight)
        else:
            return addressing_weight

    def write(self, write_key):
        """Write onto the external memory.
        It cannot be called if "readonly" set True.

        :param write_key: Key vector for write heads to generate writing
                          content and addressing signals.
        :type write_key: LayerOutput
        :raises ValueError: If the memory was constructed with readonly=True.
        """
        # check readonly
        if self.readonly:
            raise ValueError("ExternalMemory with readonly=True cannot write.")
        # get addressing weight for write head
        write_weight = self._get_addressing_weight("write_head", write_key)
        # prepare add_vector and erase_vector
        erase_vector = paddle.layer.fc(input=write_key,
                                       size=self.mem_slot_size,
                                       act=paddle.activation.Sigmoid(),
                                       bias_attr=False)
        add_vector = paddle.layer.fc(input=write_key,
                                     size=self.mem_slot_size,
                                     act=paddle.activation.Sigmoid(),
                                     bias_attr=False)
        erase_vector_expand = paddle.layer.expand(
            input=erase_vector, expand_as=self.external_memory)
        add_vector_expand = paddle.layer.expand(
            input=add_vector, expand_as=self.external_memory)
        # prepare scaled add part and erase part
        scaled_erase_vector_expand = paddle.layer.scaling(
            weight=write_weight, input=erase_vector_expand)
        # erase part enters the addto with scale=-1.0, i.e. it is subtracted
        erase_memory_part = paddle.layer.mixed(
            input=paddle.layer.dotmul_operator(
                a=self.external_memory,
                b=scaled_erase_vector_expand,
                scale=-1.0))
        add_memory_part = paddle.layer.scaling(
            weight=write_weight, input=add_vector_expand)
        # update external memory: M_t = M_{t-1} - M_{t-1}.*(w*e) + w*a
        self.updated_external_memory = paddle.layer.addto(
            input=[self.external_memory, add_memory_part, erase_memory_part],
            name=self.name)

    def read(self, read_key):
        """Read from the external memory.

        :param read_key: Key vector for read head to generate addressing
                         signals.
        :type read_key: LayerOutput
        :return: Content (vector) read from external memory.
        :rtype: LayerOutput
        """
        # get addressing weight for read head
        read_weight = self._get_addressing_weight("read_head", read_key)
        # NOTE(review): updated_external_memory is only set by __init__
        # (readonly=True) or a prior write() call in this step — confirm
        # callers invoke write() before read() when readonly=False.
        # read content from external memory: weighted sum over memory slots
        scaled = paddle.layer.scaling(
            weight=read_weight, input=self.updated_external_memory)
        return paddle.layer.pooling(
            input=scaled, pooling_type=paddle.pooling.Sum())