3.md 35.2 KB
Newer Older
W
wizardforcel 已提交
1 2
# 基于模型的方法

W
wizardforcel 已提交
3
在上一章中,我们讨论了两种基于优化的方法。 我们试图用元学习机制来训练模型,这与人类所见相似。 当然,除了学习新事物的能力外,人类在执行任何任务时还可以访问大量内存。 通过回忆过去的经历和经验,这使我们能够更轻松地学习新任务。 遵循相同的思想过程,设计了基于模型的架构,并添加了外部存储器以快速概括一次学习任务。 在这些方法中,使用存储在外部存储器中的信息,模型仅需几个训练步骤即可收敛。
W
wizardforcel 已提交
4 5 6 7 8 9 10 11 12 13

本章将涵盖以下主题:

*   了解神经图灵机
*   记忆增强神经网络
*   元网络
*   编码练习

# 技术要求

W
wizardforcel 已提交
14
您将需要 Python,Anaconda,Jupyter 笔记本,PyTorch 和 Matplotlib 库在本章中学习和执行项目。
W
wizardforcel 已提交
15

W
wizardforcel 已提交
16
您可以在本书的 [GitHub 存储库](https://github.com/PacktPublishing/Hands-On-One-shot-Learning-with-Python)中找到本章的代码文件。
W
wizardforcel 已提交
17 18 19 20 21 22 23 24

# 了解神经图灵机

在 AI 的早期,该领域主要由一种象征性的处理方法主导。 换句话说,它依靠使用符号和结构以及操纵它们的规则来处理信息。 直到 1980 年代,人工智能领域才采用了另一种方法-连接主义。 连接主义最有前景的建模技术是神经网络。 但是,他们经常遭到两个严厉的批评:

*   神经网络仅接受固定大小的输入,这在输入长度可变的现实生活中不会有太大帮助。
*   神经网络无法将值绑定到我们已知的两个信息系统(人脑和计算机)大量使用的数据结构中的特定位置。 简单来说,在神经网络中,我们无法将特定的权重设置到特定的位置。

W
wizardforcel 已提交
25
第一个问题可以通过在各种任务上实现最先进性能的 RNN 来解决。 通过查看**神经图灵机****NTM**)可以解决第二个问题。 在本节中,我们将讨论 NTM 的总体架构,这是理解**记忆增强神经网络****MANN**)的基础,这些神经网络修改了 NMT 的架构并使之适用于一个 镜头的学习任务。
W
wizardforcel 已提交
26

W
wizardforcel 已提交
27
# NTM 的架构
W
wizardforcel 已提交
28

W
wizardforcel 已提交
29
在过去的 50 年中,现代计算机发生了很大的变化。 但是,它们仍然由三个系统组成-内存,控制流和算术/逻辑运算。 来自生物学和计算神经科学领域的研究提供了广泛的证据,表明记忆对于快速有效地存储和检索信息至关重要。 从中汲取灵感,NTM 基本上由神经网络组成,该神经网络由控制器和称为存储库(或存储矩阵)的二维矩阵组成。 在每个时间步长,神经网络都会接收一些输入,并生成与该输入相对应的输出。 在这样做的过程中,它还访问内部存储库并对其执行读取和/或写入操作。 从传统的图灵机中汲取灵感,NMT 使用术语**头部**来指定内存位置。 下图显示了总体架构:
W
wizardforcel 已提交
30 31 32

![](img/29d7adbd-0103-4c43-b5b0-c2a65a121d48.png)

W
wizardforcel 已提交
33
整体架构看起来不错; 但是,这有一个问题。 如果通过在内存矩阵中指定行索引和列索引来访问内存位置,则无法获取该索引的梯度。 此操作不可反向传播,并且会使用标准的反向传播和基于梯度下降的优化技术来限制 NMT 的训练。 为了解决此问题,NTM 的控制器使用*模糊*读写操作与内存进行交互,这些操作与内存的所有元素进行不同程度的交互。 更准确地说,控制器以差分方式在所有存储位置上生成权重,这有助于使用基于标准梯度的优化方法从头到尾训练网络。
W
wizardforcel 已提交
34 35 36

在下一节中,我们将讨论如何产生这些权重以及如何执行读写操作。

W
wizardforcel 已提交
37
# 建模
W
wizardforcel 已提交
38

W
wizardforcel 已提交
39
在时间步`t``M[t]`)的存储矩阵具有`R`行和`C`列。 有一种注意力机制,用于指定注意头应该读取/写入的内存位置。 控制器生成的注意力向量是长度`R`的向量,称为**权重向量**`w[t]`),其中 向量`w[t](i)`的条目是存储库第`i`行的权重。 权重向量已标准化,这意味着它满足以下条件:*
W
wizardforcel 已提交
40 41 42

![](img/eed93f02-4abc-465e-b009-2912f63ccb39.png)

W
wizardforcel 已提交
43
# 读取
W
wizardforcel 已提交
44

W
wizardforcel 已提交
45
读取头将返回长度为`C`的向量, `r[t]`,它是存储器行`M[t](i)`由权重向量缩放:
W
wizardforcel 已提交
46 47 48

![](img/7b4dda93-17d9-4c36-ab06-7af57479b100.png)

W
wizardforcel 已提交
49
# 写入
W
wizardforcel 已提交
50

W
wizardforcel 已提交
51
写入是两个步骤的结合:擦除和添加。 为了擦除旧数据,写头使用附加长度`C`擦除向量`e[t]`以及权重向量。 以下方程式定义了擦除行的中间步骤:
W
wizardforcel 已提交
52 53 54

![](img/cb6720e7-9a40-4e30-b015-966dd1074fb0.png)

W
wizardforcel 已提交
55
最后,写入头使用长度`C`的向量, `a[t]`以及`M_erased[t]`根据前面的方程式和权重向量更新存储矩阵的行:
W
wizardforcel 已提交
56 57 58

![](img/6d9c4431-8618-4953-a442-b2320bd662fb.png)

W
wizardforcel 已提交
59
# 寻址
W
wizardforcel 已提交
60 61 62

读取和写入操作的关键是权重向量,该权重向量指示要从中读取/写入的行。 控制器分四个阶段生成此权重向量。 每个阶段都会产生一个中间向量,该向量将传递到下一个阶段:

W
wizardforcel 已提交
63
*   第一步是基于内容的寻址,其目的是基于每一行与给定的长度为`C`的给定关键字向量`k[t]`的相似度来生成权重向量。 更精确地说,控制器发出向量`k[t]`,并使用余弦相似性度量将其与`M[t]`的每一行进行比较。 如下:
W
wizardforcel 已提交
64 65 66 67 68 69 70

![](img/d338c2ca-2226-4543-abc4-5ab988fc0a28.png)

内容权重向量尚未规范化,因此可以通过以下操作进行规范化:

![](img/e9fa7fd5-b836-4dc6-a320-74200df1183a.png)

W
wizardforcel 已提交
71
*   第二阶段是基于位置的寻址,其重点是从特定存储位置读取/写入数据,而不是在阶段 1 中完成的位置值。其后,标量参数`g[t] ∈ (0, 1)`称为插值 门,将内容权重向量`w[t]^c`与前一个时间步的权重向量`w[t-1]`混合,以产生门控权重`w[t]^g`。 这使系统能够学习何时使用(或忽略)基于内容的寻址:
W
wizardforcel 已提交
72 73 74

![](img/906b96f2-1a9f-4b09-ad23-f551d40f82d5.png)

W
wizardforcel 已提交
75
*   在第三阶段,插值后,头部发出归一化的移位加权`s[t]`,以执行`R`模的移位运算(即,向上或向下移动行)。 这由以下操作定义:
W
wizardforcel 已提交
76 77 78

![](img/8ad09342-cba4-4cfe-8b3e-3536756a491d.png)

W
wizardforcel 已提交
79
*   第四个也是最后一个阶段,锐化,用于防止偏移的权重`w_tilde[t]`模糊。 这是使用标量`γ >= 1`并应用以下操作完成的:
W
wizardforcel 已提交
80 81 82

![](img/f6a0bbfd-bfe5-4d2d-8d0e-015fc55c1642.png)

W
wizardforcel 已提交
83
所有操作(包括读取,写入和寻址的四个阶段)都是差分的,因此可以使用反向传播和任何基于梯度下降的优化器从头到尾训练整个 NMT 模型。 控制器是一个神经网络,可以是前馈网络,也可以是循环神经网络,例如**长短期记忆****LSTM**)。 它已显示在各种算法任务(例如复制任务)上均具有良好的性能,这些任务将在本章的稍后部分实现。
W
wizardforcel 已提交
84

W
wizardforcel 已提交
85
既然我们了解了 NTM 的架构和工作原理,我们就可以开始研究 MANN,这是对 NMT 的修改,并且经过修改可以在一次学习中表现出色。
W
wizardforcel 已提交
86 87 88

# 记忆增强神经网络

W
wizardforcel 已提交
89
MANN 的目标是在一次学习任务中表现出色。 正如我们之前阅读的,NMT 控制器同时使用基于内容的寻址和基于位置的寻址。 另一方面,MANN 控制器仅使用基于内容的寻址。 有两个原因。 原因之一是一次学习任务不需要基于位置的寻址。 在此任务中,对于给定的输入,控制器可能只需要执行两个操作,并且这两个操作都与内容有关,而与位置无关。 当输入与先前看到的输入非常相似时,将采取一种措施,在这种情况下,我们可以更新内存的当前内容。 当当前输入与以前看到的输入不相似时,将采取另一种操作,在这种情况下,我们不想覆盖最近的信息。 相反,我们写到使用最少的内存位置。 在这种情况下,该存储模块称为**最久未使用的访问****LRUA**)模块。
W
wizardforcel 已提交
90

W
wizardforcel 已提交
91
# 读取
W
wizardforcel 已提交
92

W
wizardforcel 已提交
93
MANN 的读取操作与 NTM 的读取操作非常相似,唯一的区别是此处的权重向量仅使用基于内容的寻址(NMT 寻址的阶段 -1)。 更准确地说,控制器使用标准化的读取权重向量`w[t]^r`,将其与`M[t]`的行一起使用以生成读取向量,`r[t]`
W
wizardforcel 已提交
94 95 96

![](img/dc63f6b2-8a4b-4774-8ed6-0d720f31bb7b.png)

W
wizardforcel 已提交
97
读取权重向量`w[t]^r`由控制器产生,该控制器由以下操作定义:
W
wizardforcel 已提交
98 99 100

![](img/b947ff46-f23a-45ed-a252-2ceaa0ca82b6.png)

W
wizardforcel 已提交
101
在此,运算`K()`是余弦相似度,类似于为 NMT 定义的余弦相似度。
W
wizardforcel 已提交
102

W
wizardforcel 已提交
103
# 写入
W
wizardforcel 已提交
104 105 106 107 108

为了写入存储器,控制器在写入最近读取的存储器行和写入最近读取的存储器行之间进行插值。

![](img/940cbf9a-4794-43e5-917a-171886d6871a.png)

W
wizardforcel 已提交
109
通过对 Omniglot 数据集进行一次一次性分类任务,MANN 已显示出令人鼓舞的结果。 由于其基本的模型 NTM,它们表现良好。 NTM 能够快速编码,存储和检索数据。 它们还能够存储长期和短期权重。 可以使用 MANN 的方法添加 NTM,以跟踪*最久未使用的存储位置*,以执行基于内容的寻址,以读取和写入*最久未使用的*位置。 它使 MANN 成为少量学习的理想人选。
W
wizardforcel 已提交
110

W
wizardforcel 已提交
111
在下一部分中,我们将学习另一种基于模型的架构,该架构由四个架构的网络组成,并为一次学习领域做出了重大贡献。
W
wizardforcel 已提交
112 113 114

# 了解元网络

W
wizardforcel 已提交
115
顾名思义,元网络是基于模型的元学习方法的一种形式。 在通常的深度学习方法中,神经网络的权重是通过随机梯度下降来更新的,这需要大量的时间来训练。 众所周知,随机梯度下降法意味着我们将考虑每个训练数据点进行权重更新,因此,如果我们的批量大小为 1,这将导致模型优化非常缓慢-换句话说,**较慢**的权重更新。
W
wizardforcel 已提交
116

W
wizardforcel 已提交
117
元网络建议通过训练与原始神经网络并行的神经网络来预测目标任务的参数,从而解决权重缓慢的问题。 生成的权重称为**快速权重**。 如果您还记得的话,LSTM 元学习器(请参阅第 4 章,“基于优化的方法”)也是基于类似的基础来预测使用 LSTM 单元的任务的参数更新 。
W
wizardforcel 已提交
118 119 120

与其他元学习方法类似,元网络包含两个级别:

W
wizardforcel 已提交
121 122
*   **元学习器**:元学习器获得有关不同任务的一般知识。 在元网络的情况下,这是一个嵌入函数,用于比较两个不同数据点的特征。
*   **基础学习器**:基础学习器尝试学习目标任务(任务目标网络可以是简单的分类器)。
W
wizardforcel 已提交
123

W
wizardforcel 已提交
124
元级学习器的目标是获得有关不同任务的一般知识。 然后可以将知识转移到基础级学习器,以在单个任务的上下文中提供概括。
W
wizardforcel 已提交
125

W
wizardforcel 已提交
126
如所讨论的,元网络学习权重的两种形式:慢权重和快权重。 要为元学习器(嵌入函数)和基础学习器(分类模型)两者学习这些权重,我们需要两个不同的网络。 这使得元网络成为迄今为止我们在本书中讨论过的最复杂的网络之一。 简而言之,元网络由四种类型的神经网络组成,它们各自的参数要训练。 在下一节中,我们将遍历元网络中存在的每个网络,并了解其架构。
W
wizardforcel 已提交
127 128 129 130 131

# 元网络算法

要开始学习元网络,我们首先需要定义以下术语:

W
wizardforcel 已提交
132 133
*   **支持集**:训练集中的采样输入数据点(`x``y`)。
*   **测试集**:来自训练集的采样数据点(`x`*和*)。
W
wizardforcel 已提交
134 135 136 137 138 139
*   **嵌入函数**`f[θ]`):作为元学习器的一部分,*嵌入函数*与连体网络非常相似。 经过训练可以预测两个输入是否属于同一类。
*   **基本学习器模型**`g[φ]`):基本学习器模型尝试完成实际的学习任务(例如,分类模型)。
*   `θ⁺`:嵌入函数的快速权重,(`f[θ]`)。
*   `φ⁺`:基本学习器模型的快速权重(`g[φ]`)。
*   `F[w]`:一种 LSTM 架构,用于学习嵌入函数的快速权重`θ``f[θ]`)的慢速网络。
*   `G[v]`:通过`v`学习快速权重`φ`参数化的神经网络,用于基础学习器`g[φ]`,来自其损失梯度。
W
wizardforcel 已提交
140

W
wizardforcel 已提交
141
下图说明了元网络架构:
W
wizardforcel 已提交
142 143 144

![](img/7ee8a46f-a88d-4310-a261-da0df68aa4f7.png)

W
wizardforcel 已提交
145
如图所示,元学习器基础学习器由较慢的权重(`θ, φ`)组成。 为了学习快速权重(`θ⁺, φ⁺`),元网络使用两个不同的网络:
W
wizardforcel 已提交
146

W
wizardforcel 已提交
147 148
*   LSTM 网络(`F[w]`),学习嵌入函数的(元学习器)快速权重-即`θ⁺`
*   神经网络(`G[v]`),以学习基本学习器的快速权重,即`φ⁺`
W
wizardforcel 已提交
149

W
wizardforcel 已提交
150
现在我们已经了解了快速权重和慢速权重的概念和架构,让我们尝试观察整个元网络架构:
W
wizardforcel 已提交
151 152 153

![](img/d98e350d-6d91-4bff-b79d-6f33824b0e59.png)

W
wizardforcel 已提交
154
如上图所示,元网络由基本学习器和配备*外部存储器*的元学习器(嵌入函数)组成。 我们还可以看到快速参数化箭头同时进入元学习器和基础学习器。 这些是元权重的输出,其中包括用于学习快速权重的模型。
W
wizardforcel 已提交
155

W
wizardforcel 已提交
156
现在让我们简单介绍一下训练。 随着训练输入数据的到来,它会同时通过元学习器和基础学习器。 在元学习器中,它用于连续学习(更新参数),而在基础学习器中,对输入进行预处理之后,它将*元信息**梯度*)传递给元 -学习器。 之后,元学习器使用*元信息**梯度*)将参数化更新快速返回给基础学习器,以通过使用慢速权重和快速权重的集成来进行优化(如图所示) 在下图中)。 元网络的基本关键思想是通过处理*元信息*以快速的方式学习权重以进行快速概括。
W
wizardforcel 已提交
157

W
wizardforcel 已提交
158
在 MetaNet 中,学习器的损失梯度是任务的*元信息*。 MetaNet 还有一个重要的问题:它如何使用快速权重和慢速权重进行预测?
W
wizardforcel 已提交
159

W
wizardforcel 已提交
160
在 MetaNet 中,将慢速权重和快速权重组合在一起,以在神经网络中进行预测,如下图所示。 在这里,`⊕`表示元素方式的和:
W
wizardforcel 已提交
161 162 163 164 165 166 167

![](img/e3036523-f829-413a-85ad-fe9838ae6684.png)

在下一节中,我们将逐步介绍算法,提取元信息以及最终模型预测。

# 算法

W
wizardforcel 已提交
168
元网络也遵循与匹配网络相似的训练过程。 在此,训练数据分为两种类型:支持集`S = (x'[i], y'[i])`和测试集`U = (x[i], y[i])`
W
wizardforcel 已提交
169

W
wizardforcel 已提交
170
请记住,目前,我们有四个网络(`f[θ], g[φ], F[w], G[v])`)和四组要学习的模型参数`(θ, ϕ, w, v)`。 让我们看一下算法的所有步骤:
W
wizardforcel 已提交
171 172 173 174 175

![](img/6a8a41af-c6a4-471a-8998-920cf7d1d0c7.png)

以下是算法的步骤:

W
wizardforcel 已提交
176
1.  支持集`S`的样本的`K`个随机对。
W
wizardforcel 已提交
177

W
wizardforcel 已提交
178
对于`t = 1, ..., K`
W
wizardforcel 已提交
179

W
wizardforcel 已提交
180 181
*   通过嵌入函数`f[θ]`正向传播数据点。
*   计算交叉熵损失(`L_emb`)。
W
wizardforcel 已提交
182

W
wizardforcel 已提交
183
2.  通过 LSTM 网络向前传递数据以计算`θ⁺``θ⁺ = F[w](▽L_emb)`
W
wizardforcel 已提交
184
3.  接下来,遍历支持集`S`中的示例,并为每个示例计算快速权重。 同时,使用获得的嵌入内容更新外部存储器。
W
wizardforcel 已提交
185

W
wizardforcel 已提交
186
对于`i = 1, ..., K`
W
wizardforcel 已提交
187

W
wizardforcel 已提交
188 189 190 191 192 193
*   将基本学习器`g[φ](x[i])`(例如,分类模型)向前传递,并计算损失函数`L_task[i]`(例如,交叉熵)。
*   计算基本学习器梯度`▽L_task[i]`,并使用它们来计算示例级快速权重`φ⁺[i] = G[v](▽L_task[i])`
*   将基础学习器`φ⁺[i]`的计算得出的快速权重存储在存储器的`M`部分的`i`位置处。
*   在嵌入网络中使用`⊕`合并快速权重和缓慢权重。
*   将支持样本转发通过嵌入网络并获得嵌入`r'[i] = f[θ,θ⁺](x[i])`
*`r'[i]`存储在内存`R``k`部分的`i`位置。
W
wizardforcel 已提交
194

W
wizardforcel 已提交
195
4.  最后,是时候使用测试集`U = (x[i], y[i])`构建训练损失了。 从`L_train = 0`开始。
W
wizardforcel 已提交
196

W
wizardforcel 已提交
197
对于`j = 1, ..., L`
W
wizardforcel 已提交
198

W
wizardforcel 已提交
199 200 201 202 203
*   将测试样本转发通过嵌入网络,并获得测试嵌入`r'[i] = f[θ,θ⁺](x[j])`
*   计算支持集的存储嵌入`R`和获得的嵌入`r[j]`之间的相似度。 您可以使用`a[j] = cos(R, r[j])`来执行此操作。 在此,`R`是指存储在外部存储器中的数据。
*   现在,通过使用支持集样本(`M`)的快速权重来计算基础学习器的快速权重(`φ⁺`)。 您可以使用`φ⁺[j] = softmax(a[j])^T M`来执行此操作。 在此,`M`是指存储在外部存储器中的数据。
*   使用最新的`φ⁺`将测试样本向前传递通过基本学习器,并计算损失函数`L_task[i]`
*   使用公式更新训练损失:
W
wizardforcel 已提交
204

W
wizardforcel 已提交
205 206 207
    ![](img/14423e78-c644-4cec-b7dc-f5879e000a7a.png)

5.  使用`L_train`更新所有参数`(θ, ϕ, w, v)`
W
wizardforcel 已提交
208

W
wizardforcel 已提交
209
在选择嵌入网络时,元网络使用 LSTM 架构。 如我们所见,匹配网络和 LSTM 元学习器也遵循相同的策略,分别用于提取数据和元信息的上下文嵌入。 这是因为 LSTM 架构倾向于记住历史,这使得元学习器的目标能够跨任务提取重要信息。
W
wizardforcel 已提交
210

W
wizardforcel 已提交
211
例如,假设我们正在训练我们的网络以完成多种任务,例如猫的品种分类和狗的品种分类。 当我们使用 LSTM 元学习器进行训练时,它会学习例如狗品种分类中体重更新的策略,并使用这些学习到的信息以较少的步骤和较少的数据来优化其用于猫品种分类的操作。 使用元网络在 Omniglot 数据集上达到了 95.92% 的准确率,而人类的准确率仅为 95.5%,因此,元网络被认为是最新模型之一。
W
wizardforcel 已提交
212 213 214 215 216

# 编码练习

在本节中,我们将首先介绍 NTM 的实现,然后再使用 Omniglot 数据集介绍 MAAN。 所以,让我们开始吧!

W
wizardforcel 已提交
217
本练习不包括代码的某些部分。 如果希望获得可运行的代码,请在[这个页面](https://github.com/PacktPublishing/Hands-On-One-shot-Learning-with-Python)中查看本书的 GitHub 存储库。
W
wizardforcel 已提交
218

W
wizardforcel 已提交
219
# NTM 的实现
W
wizardforcel 已提交
220 221 222 223 224 225 226 227 228 229

如上所述,NTM 由两个重要组成部分组成:

*   神经网络,也称为控制器
*   称为记忆的二维矩阵

在本教程中,我们将实现两者的简化版本,并尝试展示复制任务。

任务目标如下:

W
wizardforcel 已提交
230 231
*   NTM 模型显示为`T`时间步长的随机`k`维向量。
*   网络的工作是在每个时间步从零向量输出这些`T[k]`维随机向量。
W
wizardforcel 已提交
232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251

执行以下步骤来实现 NTM:

1.  首先,导入所有必需的库:

```py
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from time import time
import torchvision.utils as vutils
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
%matplotlib inline
```

2.  然后,实现`Controller`。 作为控制器的一部分,我们将实现以下三个组件:

W
wizardforcel 已提交
252
    *   两层前馈网络
W
wizardforcel 已提交
253
    *   使用 Xavier 方法进行权重初始化
W
wizardforcel 已提交
254
    *   Sigmoid 非线性
W
wizardforcel 已提交
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280

```py
class Controller(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super(Controller, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, output_size)
        self.intialize_parameters()

    def intialize_parameters(self):
        # Initialize the weights of linear layers
        nn.init.xavier_uniform_(self.layer1.weight, gain=1.4) 
        nn.init.normal_(self.layer1.bias, std=0.01)
        nn.init.xavier_uniform_(self.layer2.weight, gain=1.4)
        nn.init.normal_(self.layer2.bias, std=0.01)

    def forward(self, x, last_read):
        # Forward pass operation, depending on last_read operation
        x = torch.cat((x, last_read), dim=1)
        x = torch.sigmoid(self.layer1(x))
        x = torch.sigmoid(self.layer2(x))
        return x
```

我们也可以有一个 LSTM 控制器,但是由于简单起见,我们构建了一个两层的全连接控制器。

W
wizardforcel 已提交
281
3.  接下来,实现`Memory`模块。 `Memory`是一个二维矩阵,具有`M`行和`N`列:
W
wizardforcel 已提交
282 283 284

![](img/3a4d2e1c-a7df-4316-b1cd-5a67640f0bdf.png)

W
wizardforcel 已提交
285
`address()`函数执行存储器寻址,它由四个功能组成:
W
wizardforcel 已提交
286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400

*   `similarity`
*   `interpolate`
*   `shift`
*   `sharpen`

```py
class Memory(nn.Module):
    def __init__(self, M, N, controller_out):
        super(Memory, self).__init__()
        self.N = N
        self.M = M
        self.read_lengths = self.N + 1 + 1 + 3 + 1
        self.write_lengths = self.N + 1 + 1 + 3 + 1 + self.N + 
            self.N
        self.w_last = [] # define to keep track of weight_vector 
        at each time step.
        self.reset_memory()

    def address(self, k, beta, g, s, gamma, memory, w_last):
        # Content focus
        wc = self._similarity(k, beta, memory) # CB1 to CB3 
        equations
        # Location focus
        wg = self._interpolate(wc, g, w_last) # CS1 equation
        w_hat = self._shift(wg, s) # CS2 and CS3 equation
        w = self._sharpen(w_hat, gamma) # S1 equation
        return w

    # Implementing Similarity on basis of CB1 followed by CB2 
    and CB3 Equation
    def _similarity(self, k, beta, memory):
        w = F.cosine_similarity(memory, k, -1, 1e-16) # CB1 
        Equation
        w = F.softmax(beta * w, dim=-1) # CB2 and CB3 Equation
        return w # return CB3 equation obtained weights

    # Implementing CS1 Equation. It decides whether to use 
    the weights we obtained
    # at the previous time step w_last or use the weight 
    obtained through similarity(content focus)
    def _interpolate(self, wc, g, w_last):
        return g * wc + (1 - g) * w_last
# .... Rest Code is available at Github......
```

4.  接下来,执行`read`操作。 在这里,我们将定义`ReadHead`,它可以根据`read`操作访问和更新内存:

![](img/006a56d0-58c8-429e-a160-c9f33966e585.png)

```py
class ReadHead(Memory):
    # Reading based on R2 equation
    def read(self, memory, w):
        return torch.matmul(w, memory)

    # Use Memory class we formed above to create a ReadHead 
    operation
    def forward(self, x, memory):
        param = self.fc_read(x) # gather parameters
        # initialize necessary parameters k, beta, g, shift, 
        and gamma
        k, beta, g, s, gamma = torch.split(param, 
            [self.N, 1, 1, 3, 1], dim=1)
        k = torch.tanh(k)
        beta = F.softplus(beta)
        g = torch.sigmoid(g)
        s = F.softmax(s, dim=1)
        gamma = 1 + F.softplus(gamma) 
        # obtain current weight address vectors from Memory
        w = self.address(k, beta, g, s, gamma, memory, 
            self.w_last[-1])
        # append in w_last function
        self.w_last.append(w)
        mem = self.read(memory, w) 
        return mem, w
```

5.`read`操作类似,在这里我们将执行`write`操作:

```py
class WriteHead(Memory):
    def write(self, memory, w, e, a):
        # Implement write function based on E1 and A1 Equation
        w, e, a = torch.squeeze(w), torch.squeeze(e), 
            torch.squeeze(a)
        erase = torch.ger(w, e)
        m_tilde = memory * (1 - erase) # E1 equation
        add = torch.ger(w, a)
        memory_update = m_tilde + add # A1 equation
        return memory_update

    def forward(self, x, memory):
        param = self.fc_write(x) # gather parameters
        # initialize necessary parameters k, beta, g, shift, 
        and gamma
        k, beta, g, s, gamma, a, e = torch.split(param, 
            [self.N, 1, 1, 3, 1, self.N, self.N], dim=1)
        k = torch.tanh(k)
        beta = F.softplus(beta)
        g = torch.sigmoid(g)
        s = F.softmax(s, dim=-1)
        gamma = 1 + F.softplus(gamma)
        a = torch.tanh(a)
        e = torch.sigmoid(e)
        # obtain current weight address vectors from Memory
        w = self.address(k, beta, g, s, gamma, memory, 
            self.w_last[-1])
        # append in w_last function
        self.w_last.append(w)
        # obtain current mem location based on R2 equation
        mem = self.write(memory, w, e, a)
        return mem, w
```

W
wizardforcel 已提交
401
`ReadHead``WriteHead`都使用全连接层来生成用于内容编址的参数(`k``beta``g``s``gamma`)。
W
wizardforcel 已提交
402 403 404

6.  实现神经图灵机结构,其中包括:

W
wizardforcel 已提交
405
*   全连接控制器
W
wizardforcel 已提交
406
*   读写头部
W
wizardforcel 已提交
407
*   记忆体参数
W
wizardforcel 已提交
408
*   在无法训练的记忆上运行的工具函数
W
wizardforcel 已提交
409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431

```py
class NTM(nn.Module):
    def forward(self, X=None):
        if X is None:
            X = torch.zeros(1, self.num_inputs)
        controller_out = self.controller(X, self.last_read) 
        self._read_write(controller_out)
        # use updated last_read to get sequence
        out = torch.cat((X, self.last_read), -1)
        out = torch.sigmoid(self.fc_out(out))

        return out

    def _read_write(self, controller_out):
        # Read Operation
        read, w = self.read_head(controller_out, self.memory)
        self.last_read = read
        # Write Operation
        mem, w = self.write_head(controller_out, self.memory)
        self.memory = mem
```

W
wizardforcel 已提交
432
`forward`函数中,`X`可以是`None`。 这是因为,在复制任务中,针对一个特定序列,训练分两步进行:
W
wizardforcel 已提交
433

W
wizardforcel 已提交
434
1.  在第一步中,网络显示为`t`时间步长的`k`维输入。
W
wizardforcel 已提交
435
2.  在第二步(预测步骤)中,网络采用`k`维零向量来产生预测,该预测对每个时间步执行输入的复制。
W
wizardforcel 已提交
436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491

7.  在这里,我们正在为复制任务生成向量的随机序列。 它由 NTM 模型复制:

```py
class BinaySeqDataset(Dataset):

    def __init__(self, sequence_length, token_size, 
    training_samples):
        self.seq_len = sequence_length
        self.seq_width = token_size
        self.dataset_dim = training_samples

    def _generate_seq(self):
        # A special token is appened at beginning and end of each
        # sequence which marks sequence boundaries.
        seq = np.random.binomial(1, 0.5, (self.seq_len, self.seq_width))
        seq = torch.from_numpy(seq)
        # Add start and end token
        inp = torch.zeros(self.seq_len + 2, self.seq_width)
        inp[1:self.seq_len + 1, :self.seq_width] = seq.clone()
        inp[0, 0] = 1.0
        inp[self.seq_len + 1, self.seq_width - 1] = 1.0
        outp = seq.data.clone()

        return inp.float(), outp.float()

    def __len__(self):
        return self.dataset_dim

    def __getitem__(self, idx):
        inp, out = self._generate_seq()
        return inp, out
```

8.  我们还将实现梯度剪切,因为剪切梯度通常是一个好主意,以使网络在数值上稳定:

```py
def clip_grads(net, min_grad=-10,max_grad=10):
    parameters = list(filter(lambda p: p.grad is not None, net.parameters()))
    for p in parameters:
        p.grad.data.clamp_(min_grad,max_grad)
```

9.  初始化训练参数:

```py
memory_capacity=64
memory_vector_size=128
controller_output_dim=256
controller_hidden_dim=512
learning_rate=1e-2

sequence_length, token_size, training_samples = 2, 10, 99
min_grad, max_grad = -10, 10
```

W
wizardforcel 已提交
492
0.  然后,初始化训练模型:
W
wizardforcel 已提交
493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538

```py
# Initialize the dataset
dataset = BinaySeqDataset(sequence_length, token_size, training_samples)
dataloader = DataLoader(dataset, batch_size=1,shuffle=True, num_workers=4)
model = NTM() # Initialize NTM
criterion = torch.nn.BCELoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
losses = []
# Train the Model
for e, (X, Y) in enumerate(dataloader):
    tmp = time()
    model.initalize_state()
    optimizer.zero_grad()
    inp_seq_len = sequence_length + 2
    out_seq_len = sequence_length
    X.requires_grad = True
    # Forward Pass: Feed the Sequence
    for t in range(0, inp_seq_len):
        model(X[:, t])
    # Predictions: Obtain the already feeded sequence
    y_pred = torch.zeros(Y.size())
    for i in range(0, out_seq_len):
        y_pred[:, i] = model() # Here, X is passed as None
    loss = criterion(y_pred, Y)
    loss.backward()
    clip_grads(model)
    optimizer.step()
    losses += [loss.item()]
    if (e%10==0):
        print("iteration: {}, Loss:{} ".format(e, loss.item()))
    if e == 5000:
        break
```

运行此单元格后,您将看到以下输出:

```py
iteration: 0, Loss:0.7466866970062256 
iteration: 10, Loss:0.7099956274032593 
iteration: 20, Loss:0.6183871626853943 
iteration: 30, Loss:0.6750341653823853 
iteration: 40, Loss:0.7050653696060181 
iteration: 50, Loss:0.7188648581504822
```

W
wizardforcel 已提交
539
1.  定义一个`plot_signal`函数并绘制训练损失`losses`
W
wizardforcel 已提交
540 541 542 543 544 545 546 547 548 549

```py
def plot_signal(grid_image, fig_size=(500,100)):
    plt.figure(figsize=fig_size)
    plt.imshow(grid_image.data.permute(2, 1, 0))

plt.plot(losses)
plt.show()
```

W
wizardforcel 已提交
550
2.  测试 NTM 模型的复制任务:
W
wizardforcel 已提交
551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580

```py
X, Y = dataset._generate_seq()
X, Y = X.unsqueeze(0), Y.unsqueeze(0)# Add the batch dimension

model.initalize_state()

for t in range(0, inp_seq_len):
    model(X[:, t])

y_pred = torch.zeros(Y.size())
for i in range(0, out_seq_len):
    y_pred[:, i] = model()

grid_img_truth = vutils.make_grid(Y, normalize=True, scale_each=True)
grid_img_pred = vutils.make_grid(y_pred, normalize=True, scale_each=True)

plt.figure(figsize=(200,200))
plt.imshow(grid_img_truth.data.permute(2, 1, 0))

plt.figure(figsize=(200,200))
plt.imshow(grid_img_pred.data.permute(2, 1, 0))
```

运行前面的代码将给出以下输出:

![](img/0dc925d0-48d7-4b5c-958f-fe2f911fdb8d.png)

在这里,我们创建了一个 300 个时间步长的随机信号,并观察了模型复制该信号的程度。 在此步骤中,您观察了复制任务输出。 这两个信号应该非常接近。 如果不是,我们建议您更多地训练模型。

W
wizardforcel 已提交
581
# MAAN 的实现
W
wizardforcel 已提交
582 583 584

正如我们在上一节中展示的那样,NTM 的控制器能够使用基于内容的寻址,基于位置的寻址或同时使用这两种,而 MANN 则使用纯基于内容的内存写入器来工作。

W
wizardforcel 已提交
585
MANN 还使用一种称为**最久未访问**的新寻址方案。 该模式背后的思想是,最久未使用的内存位置由读取操作确定,而读取操作由基于内容的寻址执行。 因此,我们基本上执行基于内容的寻址,以读取和写入最久未使用的位置。
W
wizardforcel 已提交
586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629

在本教程中,我们将实现`read``write`操作。

1.  首先,导入所需的所有库:

```py
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import copy
```

2.  实现类似于 NTM 的`Memory`模块,并对 MANN 进行一些更改:

```py
class Memory(nn.Module):

 def __init__(self, M, N, controller_out):
     super(Memory, self).__init__()
     self.N = N
     self.M = M
     self.read_lengths = self.N + 1 + 1 + 3 + 1
     self.write_lengths = self.N + 1 + 1 + 3 + 1 + self.N + self.N
     self.w_last = [] # define to keep track of weight_vector at 
     each time step
     self.reset_memory()

 def address(self, k, beta, g, s, gamma, memory, w_last):
     # Content focus
     w_r = self._similarity(k, beta, memory)
     return w_r

 # Implementing Similarity
 def _similarity(self, k, beta, memory):
     w = F.cosine_similarity(memory, k, -1, 1e-16) 
     w = F.softmax(w, dim=-1)
     return w # return w_r^t for reading purpose
```

3.  定义`ReadHead`,使其可以根据`read`操作访问和更新内存:

![](img/295304f9-1168-4b66-8274-4550618d7e6e.png)

W
wizardforcel 已提交
630
`ReadHead`函数定义如下:
W
wizardforcel 已提交
631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714

```py
class ReadHead(Memory):
    def read(self, memory, w):
        # Calculate Memory Update
        return torch.matmul(w, memory)

    def forward(self, x, memory):
        param = self.fc_read(x) # gather parameters
        # initialize necessary parameters k, beta, g, shift, and 
        gamma
        k, g, s, gamma = torch.split(param, [self.N, 1, 1, 3, 1], 
            dim=1)
        k = torch.tanh(k)
        g = F.sigmoid(g)
        s = F.softmax(s, dim=1)
        gamma = 1 + F.softplus(gamma)
        # obtain current weight address vectors from Memory
        w_r = self.address(k, g, s, gamma, memory, self.w_last[-1])
        # append in w_last function to keep track content based 
        locations
        self.w_last.append(w_r)
        # obtain current mem location based on above equations
        mem = self.read(memory, w_r)
        w_read = copy.deepcopy(w_r)
        return mem, w_r
```

4.`read`操作类似,在这里我们将执行`write`操作:

![](img/1560e10f-2d4c-4532-a273-9d87c961a89e.png)

`write`操作的实现如下:

```py
class WriteHead(Memory):

    def usage_weight_vector(self, prev_w_u, w_read, w_write, 
    gamma):
        w_u = gamma * prev_w_u + torch.sum(w_read, dim=1) + 
            torch.sum(w_write, dim=1)
        return w_u # Equation F2

    def least_used(self, w_u, memory_size=3, n_reads=4):
        _, indices = torch.topk(-1*w_u,k=n_reads)
        wlu_t = torch.sum(F.one_hot(indices, 
            memory_size).type(torch.FloatTensor),dim=1,
            keepdim=True)
        return indices, wlu_t

    def mann_write(self, memory, w_write, a, gamma, prev_w_u, 
    w_read, k):
        w_u = self.usage_weight_vector(prev_w_u, w_read, w_write, 
            gamma)
        w_least_used_weight_t = self.least_used(w_u)
        # Implement write step as per (F3) Equation
        w_write = torch.sigmoid(a)*w_read + 
            (1-torch.sigmoid(a))*w_least_used_weight_t
        memory_update = memory + w_write*k # Memory Update 
        as per Equation (F4)

    def forward(self, x, memory):
        param = self.fc_write(x) # gather parameters
        k, beta, g, s, gamma, a, e = torch.split(param, 
            [self.N, 1, 1, 3, 1, self.N, self.N], dim=1)
        k = F.tanh(k)
        beta = F.softplus(beta)
        g = F.sigmoid(g)
        s = F.softmax(s, dim=-1)
        gamma = 1 + F.softplus(gamma)
        a = F.tanh(a)
        # obtain current weight address vectors from Memory
        w_write = self.address(k, beta, g, s, gamma, memory, 
            self.w_last[-1])
        # append in w_last function to keep track content 
        based locations
        self.w_last.append(w_write)
        # obtain current mem location based on F2-F4 equations
        mem = self.write(memory, w_write, a, gamma, prev_w_u, 
            w_read, k)
        w_write = copy.deepcopy(w)
        return mem, w
```

W
wizardforcel 已提交
715
`ReadHead``WriteHead`都使用全连接层来生成用于内容编址的参数(`k``beta``g``s``gamma`)。
W
wizardforcel 已提交
716

W
wizardforcel 已提交
717
请注意,此练习只是为了展示 MANN 如何受到 NTM 的启发。 如果您想在真实的数据集上探索前面的练习,请参考[GitHub 存储库](https://github.com/PacktPublishing/Hands-On-One-shot-Learning-with-Python/tree/master/Chapter03)
W
wizardforcel 已提交
718

W
wizardforcel 已提交
719
# 总结
W
wizardforcel 已提交
720

W
wizardforcel 已提交
721
在本章中,我们探索了用于单次学习的不同形式的基于模型的架构。 我们观察到的最常见的事情是使用外部存储器,这对学习神经网络不同阶段的表示形式有什么帮助。 NMT 方法在一次学习任务中表现良好,但是由于手工设计的内存寻址功能,它们的能力仍然有限,因为它们必须具有差异性。 探索更复杂的功能来处理内存可能很有趣。 在元网络中,我们看到了如何定义一个新网络以实现对原始网络的快速学习,以及如何在元学习器级别上存储有关表示的信息如何在基础级别上微调参数。 尽管基于模型的架构是实现一次学习的好方法,但它们具有外部存储器的先决条件,因此与其他方法相比,实现基于模型的架构的成本昂贵。
W
wizardforcel 已提交
722 723 724 725 726 727 728

在下一章中,我们将介绍基于优化的方法,例如与模型无关的元学习和 LSTM 元学习。 内存为我们提供了一种方式来存储我们所学到的信息,因此优化策略使我们能够更快地学习事物。 在后面的章节中,我们将探索一些不同形式的优化策略,这些策略可以用来学习目标。

# 问题

1.  什么是神经图灵机,它们如何帮助学习?
2.  记忆矩阵如何帮助模型更快地学习?
W
wizardforcel 已提交
729
3.  元学习器和基础学习器之间的分裂如何帮助架构学习一次学习?
W
wizardforcel 已提交
730 731 732 733 734

# 进一步阅读

基于模型的方法是您需要学习的更复杂的主题之一,因此,如果您想更深入地研究所涉及的概念,则可以阅读以下论文:

W
wizardforcel 已提交
735 736 737
*   [《神经图灵机》](https://arxiv.org/pdf/1410.5401.pdf)
*   [《记忆增强神经网络》](http://proceedings.mlr.press/v48/santoro16.pdf)
*   [《元网络》](https://arxiv.org/pdf/1703.00837.pdf)