提交 fdbbd0d8 编写于 作者: W wizardforcel

2020-12-25 16:16:02

上级 14a1358f
......@@ -124,17 +124,17 @@
![](img/60110071-3fb1-471e-9b9c-8fc73797e836.png)
因此,我们的元学习者经过几次更新后即可学习![](img/e58f947b-d9ce-46b7-8704-0923b21e27f7.png)和![](img/d766f53c-e3b3-42ac-9852-4174a782fd96.png)的最佳值。
因此,我们的元学习者经过几次更新后即可学习`i[t]``f[t]`的最佳值。
但是,这如何工作?
假设我们有一个由![](img/46fb0ba6-818f-4315-926b-6c5502f897fa.png)参数化的基础网络![](img/96e5a664-2281-469b-bc8c-3a93da8a4d8a.png)和由![](img/6b477c6b-dca3-4a40-8c43-bf844d4acf51.png)参数化的 LSTM 元学习器![](img/781fbd22-771e-4aab-ae43-36cb517b8b74.png)。 假设我们有一个数据集![](img/3a868eff-e1e0-4800-b315-c3d89407a051.png)。 我们将数据集分为![](img/8ea53b48-98b0-4abc-ab06-ee1020fbbb7c.png)和![](img/7a42e349-bf47-41c9-a4e6-7295534d7473.png)分别进行训练和测试。 首先,我们随机初始化元学习器参数![](img/7f6407d1-42b8-4442-80e9-f09b2e884040.png)
假设我们有一个由`Θ`参数化的基础网络,和`M`参数化的 LSTM 元学习器`R`。 假设我们有一个数据集`D`。 我们将数据集分为`D_train``D_test`分别进行训练和测试。 首先,我们随机初始化元学习器参数`φ`
对于某些`T`迭代次数,我们从![](img/8ea53b48-98b0-4abc-ab06-ee1020fbbb7c.png)中随机采样数据点,计算损失,然后相对于模型参数![](img/e3edc475-656e-4d41-bf0b-1a51a6078d03.png)计算损失的梯度。 现在,我们将此梯度,损失和元学习器参数![](img/4593a096-0a79-40bf-9e2a-2116c6d24062.png)输入到我们的元学习器。 我们的元学习器![](img/2a1edc09-a4d0-4fe1-8ec8-ecc85fc2356a.png)将返回单元状态![](img/28a759a2-3729-4b33-871c-d3be31207596.png),然后我们将`t`的基础网络![](img/96e5a664-2281-469b-bc8c-3a93da8a4d8a.png)参数![](img/f17b6173-5f5b-467f-b3ed-5255f61ba7d4.png)更新为![](img/ade55668-74fb-4aaa-b834-8f66ee022488.png)。 我们重复`N`次,如下图所示:
对于某些`T`迭代次数,我们从`D_train`中随机采样数据点,计算损失,然后相对于模型参数`Θ`计算损失的梯度。 现在,我们将此梯度,损失和元学习器参数`φ`输入到我们的元学习器。 我们的元学习器`R`将返回单元状态`c[t]`,然后我们将时间`t`的基础网络`M`的参数`Θ[t]`更新为`c[t]`。 我们重复`N`次,如下图所示:
![](img/c85a3a23-876b-46aa-a6c1-6a2b664e0c2a.png)
因此,在`T`次迭代之后,我们将获得一个最佳参数![](img/a864b10a-01ce-4acf-aa4a-2fbc72a0441e.png)。 但是,我们如何检查![](img/1e7893f9-7981-46e3-aac3-625c5d712452.png)的性能以及如何更新元学习器参数? 我们采用测试集,并使用参数![](img/0b771654-d8fb-4b8f-8363-29c51d27d162.png)计算测试集的损耗。 然后,我们根据元学习器参数![](img/6b477c6b-dca3-4a40-8c43-bf844d4acf51.png)计算损耗的梯度,然后更新![](img/6b477c6b-dca3-4a40-8c43-bf844d4acf51.png),如下所示:
因此,在`T`次迭代之后,我们将获得一个最佳参数`θ[T]`。 但是,我们如何检查`θ[T]`的性能以及如何更新元学习器参数? 我们采用测试集,并使用参数`θ[T]`计算测试集的损耗。 然后,我们根据元学习器参数`φ`计算损耗的梯度,然后更新`φ`,如下所示:
![](img/7e4c2e2a-1fe8-47fe-ba23-096cf093cdc4.png)
......
......@@ -19,7 +19,7 @@
![](img/62f3894f-d397-4119-8a9c-5caa9f73651c.png)
因此,我们分为三类: `{Lion, Eleph, Dog}`。 现在,我们需要为这三个类中的每一个创建一个原型表示。 我们如何构建这三个类的原型? 首先,我们将使用嵌入函数来学习每个数据点的嵌入。 嵌入函数![](img/588d745f-a8f6-4cb3-961a-74339bd63c85.png)可以是可用于提取特征的任何函数。 由于我们的输入是图像,因此我们可以使用卷积网络作为嵌入函数,该函数将从输入图像中提取特征:
因此,我们分为三类: `{Lion, Eleph, Dog}`。 现在,我们需要为这三个类中的每一个创建一个原型表示。 我们如何构建这三个类的原型? 首先,我们将使用嵌入函数来学习每个数据点的嵌入。 嵌入函数`f[φ]()`可以是可用于提取特征的任何函数。 由于我们的输入是图像,因此我们可以使用卷积网络作为嵌入函数,该函数将从输入图像中提取特征:
![](img/68661c3a-0389-4e09-85bb-27da96071e60.png)
......@@ -488,17 +488,17 @@ elif inverse_transform_type == "other":
```
到目前为止,我们已经看到我们可以计算协方差矩阵以及输入的嵌入。 下一步是什么? 我们如何计算类原型? 类原型![](img/91015c2c-0142-4171-a1a8-512bf53077cd.png)可以如下计算:
到目前为止,我们已经看到我们可以计算协方差矩阵以及输入的嵌入。 下一步是什么? 我们如何计算类原型? 类原型`p[c]`可以如下计算:
![](img/adf16b20-81b8-4cf0-a203-699da2707fd4.png)
在该方程式中,![](img/46370dcd-e39e-4828-9304-9552d23ecaf6.png)是逆协方差矩阵的对角线,![](img/3d239ea2-586f-48aa-bb1f-d180b2231f60.png)表示嵌入,上标`c`表示类别。
在该方程式中,`s[i]^c`是逆协方差矩阵的对角线,`x[i]^c`表示嵌入,上标`c`表示类别。
在为每个类计算原型之后,我们学习了查询点的嵌入。 令![](img/6e6803c0-fc44-4fe0-ba29-bb7be3d92971.png)为查询点的嵌入。 然后,我们计算查询点嵌入和类原型之间的距离,如下所示:
在为每个类计算原型之后,我们学习了查询点的嵌入。 令`x'`为查询点的嵌入。 然后,我们计算查询点嵌入和类原型之间的距离,如下所示:
![](img/d4cee939-9955-4ee3-bcaf-ce9ef23c0fb5.png)
最后,我们预测查询集的类别(![](img/f395204d-209f-4cae-aef7-68e28e316c5a.png)),该类别与类别原型的距离最小:
最后,我们预测查询集的类别(`y_hat`),该类别与类别原型的距离最小:
![](img/4768a510-e42b-4c87-beef-08ed98ee76ee.png)
......@@ -514,7 +514,7 @@ elif inverse_transform_type == "other":
![](img/242c7e0b-a5d4-4bb0-849b-135dc2ce6e66.png)
在该等式中,![](img/5be43069-8bde-464f-8338-1f5d61e8f10c.png)是逆协方差矩阵的对角线,![](img/0d4938f7-97d6-479e-8158-a3b469a73000.png)表示支持集的嵌入,上标`c`表示类别。
在该等式中,`s[i]^c`是逆协方差矩阵的对角线,`x[i]^c`表示支持集的嵌入,上标`c`表示类别。
6. 在计算支持集中每个类的原型之后,我们学习了查询集`Q`的嵌入。 假设`x'`是查询点的嵌入。
7. 我们计算查询点嵌入与类原型的距离,如下所示:
......@@ -531,7 +531,7 @@ elif inverse_transform_type == "other":
考虑一下我们的数据集包含一些未标记数据点的情况:我们如何计算这些未标记数据点的类原型?
假设我们有一个支持集![](img/799afc48-e234-444c-9fd3-d0fddf7cef2d.png),其中`x`是要素,`y`是标签,还有一个查询集![](img/dcd39ef8-0799-42f6-8633-9060b071cd55.png)。 伴随着这些,我们还有另外一个称为未标记集`R`的集合,在这里,我们只有未标记的例子![](img/29cb497c-f7e5-47a6-99a2-2dbc4cf34aa4.png)
假设我们有一个支持集`S = (x1, y1), (x2, y2), ..., (xk, yk)`,其中`x`是要素,`y`是标签,还有一个查询集`Q = (x1', y1'), (x2', y2'), ..., (xk', yk')`。 伴随着这些,我们还有另外一个称为未标记集`R`的集合,在这里,我们只有未标记的例子`R = (x_tilde1, y_tilde1), (x_tilde2, y_tilde2), ..., (x_tildek, y_tildek)`
那么,我们该如何处理这个未标记的集呢?
......
......@@ -20,31 +20,31 @@
# 一次性学习中的关系网络
关系网络由两个重要功能组成:以![](img/f9704c53-02aa-4996-8ec8-bb5e3510e5df.png)表示的嵌入函数和以![](img/619a6f8b-c314-49a6-8846-9625d234030e.png)表示的关系功能。 嵌入函数用于从输入中提取特征。 如果输入是图像,则可以使用卷积网络作为嵌入函数,这将为我们提供图像的特征向量/嵌入。 如果我们的输入是文本,那么我们可以使用 LSTM 网络获取文本的嵌入。
关系网络由两个重要功能组成:以`f[φ]`表示的嵌入函数和以`g[φ]`表示的关系功能。 嵌入函数用于从输入中提取特征。 如果输入是图像,则可以使用卷积网络作为嵌入函数,这将为我们提供图像的特征向量/嵌入。 如果我们的输入是文本,那么我们可以使用 LSTM 网络获取文本的嵌入。
众所周知,在一次学习中,每个班级只有一个示例。 例如,假设我们的支持集包含三个类,每个类一个示例。 如下图所示,我们有一个包含三个类别的支持集,`{Lion, Eleph, Dog}`
![](img/030c5e04-cf05-4394-baad-3c43bc5f77fa.png)
假设我们有一个查询图像![](img/a69c2a9f-d4fb-413c-8047-bc6fc139a536.png),如下图所示,我们希望预测该查询图像的类:
假设我们有一个查询图像`x[j]`,如下图所示,我们希望预测该查询图像的类:
![](img/87783c42-eb05-4a41-aecf-ab7bdfa56cb5.png)
首先,我们从支持集中获取每个图像![](img/1078376b-32e4-4899-9606-74eeb970e0f3.png),并将其传递给嵌入函数![](img/61752837-2d9c-4f9a-a604-d86fd605e9bd.png),以提取特征。 由于我们的支持集包含图像,因此我们可以使用卷积网络作为我们的嵌入函数来学习嵌入。 嵌入函数将为我们提供支持集中每个数据点的特征向量。 类似地,我们将把查询图像![](img/b838f304-5607-4be2-b374-90474eb18833.png)传递给嵌入函数![](img/ad6baf80-8d13-4f1d-9ba2-9a90a9cb8b4d.png)来学习其嵌入。
首先,我们从支持集中获取每个图像`x[i]`,并将其传递给嵌入函数`f[φ](x[i])`,以提取特征。 由于我们的支持集包含图像,因此我们可以使用卷积网络作为我们的嵌入函数来学习嵌入。 嵌入函数将为我们提供支持集中每个数据点的特征向量。 类似地,我们将把查询图像`x[j]`传递给嵌入函数`f[φ](x[j])`来学习其嵌入。
因此,一旦有了支持集![](img/61752837-2d9c-4f9a-a604-d86fd605e9bd.png)和查询集![](img/ad6baf80-8d13-4f1d-9ba2-9a90a9cb8b4d.png)的特征向量,就可以使用运算符![](img/be540eca-a2f6-40f1-ae07-28c2fbfb136c.png)组合它们。 ![](img/b0dd8d2d-ae04-40ba-9ded-b799bc692200.png)可以是任何组合运算符; 我们使用串联作为运算符,以合并支持和查询集的特征向量,即![](img/c7f2ec55-cb79-45f5-b62c-70f8e77e3cf4.png)
因此,一旦有了支持集`f[φ](x[i])`和查询集`f[φ](x[j])`的特征向量,就可以使用运算符`Z`组合它们。 `Z`可以是任何组合运算符; 我们使用串联作为运算符,以合并支持和查询集的特征向量,即`Z(f[φ](x[i]), f[φ](x[j]))`
如下图所示,我们将合并支持集![](img/61752837-2d9c-4f9a-a604-d86fd605e9bd.png)和查询集![](img/ad6baf80-8d13-4f1d-9ba2-9a90a9cb8b4d.png)的特征向量。 但是这样的组合有什么用呢? 这将帮助我们理解支持集中图像的特征向量与查询图像的特征向量之间的关系。 在我们的示例中,它将帮助我们理解狮子,大象和狗的图像的特征向量与查询图像的特征向量之间的关系:
如下图所示,我们将合并支持集`f[φ](x[i])`和查询集`f[φ](x[j])`的特征向量。 但是这样的组合有什么用呢? 这将帮助我们理解支持集中图像的特征向量与查询图像的特征向量之间的关系。 在我们的示例中,它将帮助我们理解狮子,大象和狗的图像的特征向量与查询图像的特征向量之间的关系:
![](img/3cee7837-9e4c-469d-931b-c64c706de99b.png)
但是我们如何衡量这种关联性呢? 这就是为什么我们使用关系函数![](img/e53a6f8b-0846-43ff-b600-67b405404468.png)的原因。 我们将这些组合的特征向量传递给关系函数,该函数将生成从 0 到 1 的关系得分,代表支持集![](img/f3df92cc-3b9f-4296-97ac-a4caecbebf5d.png)中的样本与查询集![](img/2b82f087-237b-4c95-8083-de38d0232742.png)中的样本之间的相似性。
但是我们如何衡量这种关联性呢? 这就是为什么我们使用关系函数`g[φ]`的原因。 我们将这些组合的特征向量传递给关系函数,该函数将生成从 0 到 1 的关系得分,代表支持集`x[i]`中的样本与查询集`x[j]`中的样本之间的相似性。
以下等式说明了我们如何计算关系网络中的关系得分:
![](img/4d723f33-f865-4a5b-b387-b29268e8a070.png)
在该等式中,![](img/b6b80ad2-1186-4cca-8391-cf256c231a44.png)表示表示在支持集中的每个类别和查询图像之间的相似性的关系分数。 由于我们在支持集中有 3 个类别,在查询集中有 1 个图像,因此我们将获得 3 个分数,表明支持集中的所有 3 个类别与查询图像的相似程度。
在该等式中,`r[ij]`表示表示在支持集中的每个类别和查询图像之间的相似性的关系分数。 由于我们在支持集中有 3 个类别,在查询集中有 1 个图像,因此我们将获得 3 个分数,表明支持集中的所有 3 个类别与查询图像的相似程度。
下图显示了在一次学习设置中关系网络的整体表示:
......@@ -62,7 +62,7 @@
![](img/fdf9f217-9695-462c-be8b-e6efdf3832f2.png)
我们可以像往常一样使用嵌入函数来提取查询图像的特征向量。 接下来,我们使用连接运算符![](img/6a904b6e-8b8d-4991-8097-5d179786d4ed.png)组合支持和查询集的特征向量。 我们执行级联,然后将级联的特征向量输入到关系函数并获得关系得分,该关系得分表示支持集和查询集中每个类之间的相似性。
我们可以像往常一样使用嵌入函数来提取查询图像的特征向量。 接下来,我们使用连接运算符`Z`组合支持和查询集的特征向量。 我们执行级联,然后将级联的特征向量输入到关系函数并获得关系得分,该关系得分表示支持集和查询集中每个类之间的相似性。
下图显示了关系网络在几次学习设置中的整体表示:
......@@ -70,9 +70,9 @@
# 零镜头学习中的关系网络
既然我们已经了解了如何在单发和少发学习任务中使用关系网络,我们将看到如何在零发学习设置中使用关系网络,在这种情况下,每个类别下都没有任何数据点。 但是,在零射击学习中,我们将具有元信息,该元信息是有关每个类的属性的信息,并将被编码到语义向量![](img/9fadf51c-0b77-4e48-908e-5fcc1eb85df1.png)中,其中下标`c`表示类。
既然我们已经了解了如何在单发和少发学习任务中使用关系网络,我们将看到如何在零发学习设置中使用关系网络,在这种情况下,每个类别下都没有任何数据点。 但是,在零射击学习中,我们将具有元信息,该元信息是有关每个类的属性的信息,并将被编码到语义向量`v[c]`中,其中下标`c`表示类。
我们没有使用单个嵌入函数来学习支持和查询集的嵌入,而是分别使用了两个不同的嵌入函数![](img/86aecac3-d041-4699-8b22-138ea9800f42.png)和![](img/60564a32-a8fc-477b-ab73-2f4cd0406154.png)。 首先,我们将使用![](img/86aecac3-d041-4699-8b22-138ea9800f42.png)学习语义向量![](img/9eddfb58-cbd7-4956-9c7f-ab8cee401f2f.png)的嵌入,并使用![](img/c3eb60fd-af00-4e27-a9e6-099a893e3da2.png)学习查询集![](img/2ef03adb-44aa-47b0-a212-270cef406cef.png)的嵌入。 现在,我们将使用串联操作![](img/7f57c532-4859-4c27-be61-2e5f592489e9.png)来串联这些嵌入:
我们没有使用单个嵌入函数来学习支持和查询集的嵌入,而是分别使用了两个不同的嵌入函数`f[φ1]``f[φ2]`。 首先,我们将使用`f[φ1]`学习语义向量`v[c]`的嵌入,并使用`f[φ2]`学习查询集`x[j]`的嵌入。 现在,我们将使用串联操作`Z`来串联这些嵌入:
![](img/d30a1e84-e7f0-486f-ba63-07bd509357ca.png)
......@@ -88,7 +88,7 @@
![](img/aa46d3e5-b4a1-4780-8b9f-f095b87b18f5.png)
其中![](img/8c4472bf-e13e-4bd0-993f-f61692d3fc5a.png)分别是我们嵌入函数![](img/2cacbc1d-22e1-4224-8bf5-9e4ff2408d9e.png)和关联函数![](img/ce50f7c9-71d5-4379-a9a8-886386051514.png)的参数。
其中`φ, φ`分别是我们嵌入函数`f`和关联函数`g`的参数。
# 使用 TensorFlow 建立关系网络
......@@ -261,17 +261,17 @@ Episode 900: loss 0.250
匹配网络是 Google 的 DeepMind 团队发布的另一种简单高效的一次性学习算法。 它甚至可以为数据集中未观察到的类生成标签。
假设我们有一个支持集`S`,其中包含`K`示例作为![](img/7de9a941-8801-4363-ac4e-f9ae0c13a716.png)。 给定查询点(一个新的看不见的示例)![](img/9385cadf-442c-46da-b498-a613372648fc.png)时,匹配网络通过将其与支持集进行比较来预测![](img/f3d60daa-f8f6-4be2-b246-984bea7eb4e2.png)的类别。
假设我们有一个支持集`S`,其中包含`K`示例作为`(x1, y1), (x2, y2), ..., (xk, yk)`。 给定查询点(一个新的看不见的示例)`x_hat`时,匹配网络通过将其与支持集进行比较来预测`x_hat`的类别。
我们可以将其定义为![](img/d9ccc420-177f-4115-a8e1-7d65897c692a.png),其中![](img/318eca5d-7d4f-4d93-828d-f0444520e952.png)是参数化神经网络,![](img/ed201434-5cfd-4a61-bdb5-ec02a9901af6.png)是查询点的预测类,![](img/d3aaea41-a25b-424c-a540-b245800d4613.png)和![](img/3755650d-21aa-4377-b997-5d69dbfa2ec9.png)是支持集。 ![](img/d9ccc420-177f-4115-a8e1-7d65897c692a.png)将返回![](img/d31b51a8-0ee8-494b-b91f-81873448c0c4.png)属于数据集中每个类别的概率。 然后,我们选择![](img/2af6be12-a8d4-4452-8df2-a5fec999a883.png)的类别作为可能性最高的类别。 但是,这到底如何工作? 如何计算此概率? 让我们现在看看。
我们可以将其定义为`P(y_hat | x_hat, S)`,其中`P`是参数化神经网络,`y_hat`是查询点的预测类,`x_hat``S`是支持集。 `P(y_hat | x_hat, S)`将返回`x_hat`属于数据集中每个类别的概率。 然后,我们选择`x_hat`的类别作为可能性最高的类别。 但是,这到底如何工作? 如何计算此概率? 让我们现在看看。
查询点![](img/00d00799-95bf-4a1f-9eeb-2a2036114945.png)的输出![](img/ed2b81c9-3cb5-4e32-9649-63bbac61f5ca.png)可以预测如下:
查询点`x_hat`的输出`y_hat`可以预测如下:
![](img/dae716c0-28ab-4ad4-b488-6e5b4915aead.png)
让我们破译这个方程式。 ![](img/1e6b17c3-d908-46c6-93bc-c0bb7f15fed2.png)和![](img/6d55955b-5e84-41f6-babe-707bb2368af4.png)是支持集的输入和标签。 ![](img/c99092bd-ca8f-4410-8c98-f3763277f841.png)是查询输入,即我们要预测标签的输入。 ![](img/cc8e4200-b6cf-48af-84c3-1b6037fe5fdc.png)是![](img/fc4e3de9-a4be-4fdc-bacc-8753f9ab8b5a.png)和![](img/baceddf4-d3f2-476d-a311-7c2c26b6abf9.png)之间的注意力机制。 但是,我们该如何进行关注呢? 在这里,我们使用一种简单的注意机制,即![](img/cf2bb81f-0cd2-41d9-928b-a5f04b45407c.png)和![](img/63c54228-3b78-4be1-9900-dcaea5aa67d8.png)之间(即![](img/fdf5cddd-d8de-4c1c-bc4a-06305163cedd.png))的余弦距离上的 softmax 函数
让我们破译这个方程式。 `x[i]``y[i]`是支持集的输入和标签。 `x_hat`是查询输入,即我们要预测标签的输入。 `a``x_hat``x[i]`之间的注意力机制。 但是,我们该如何进行关注呢? 在这里,我们使用一种简单的注意机制,即`x_hat``x[i]`之间的余弦距离上的 softmax 函数(即`a(·, ·) = softmax(cosine(·, ·))`
我们无法直接计算原始输入![](img/6e5b33d3-a4ef-4260-a63f-ba22cfb11ced.png)和![](img/dbd64ebc-d0a5-479d-aa27-697de19d518a.png)之间的余弦距离。 因此,首先,我们将学习它们的嵌入并计算嵌入之间的余弦距离。 我们使用两种不同的嵌入![](img/2e2245e2-5a20-4624-9c2a-0e581a88ab5d.png)和![](img/c794b2cc-cc8d-4917-ab1b-73f3134fc859.png)来分别学习查询输入![](img/f3f565ae-e636-4c48-b734-98717227b7b8.png)和支持集输入![](img/8a76ce9f-e0dc-4ea8-908e-0e1f0b4fe08f.png)的嵌入。 我们将在接下来的部分中详细了解![](img/3ca0ce74-1ec2-4adc-8763-e4764a603c8c.png)和![](img/b2e015e7-da70-4cb6-8e03-34c28cbc61ca.png)这两个嵌入函数。
我们无法直接计算原始输入`x_hat``x[i]`之间的余弦距离。 因此,首先,我们将学习它们的嵌入并计算嵌入之间的余弦距离。 我们使用两种不同的嵌入`f``g`来分别学习查询输入`x_hat`和支持集输入`x[i]`的嵌入。 我们将在接下来的部分中详细了解`f``g`这两个嵌入函数。
因此,我们可以如下重写注意力方程:
......@@ -281,21 +281,21 @@ Episode 900: loss 0.250
![](img/fd6aef37-dfe5-45f4-87a3-799cfd6e7c43.png)
因此,在计算注意力矩阵![](img/3d2735a2-46f8-4271-9ac1-b3b77cfa124f.png)之后,我们将注意力矩阵与支持集标签![](img/4eb03151-9951-4817-ae00-14c1d60f7229.png)相乘。 但是,如何将支持集标签与注意力矩阵相乘呢? 首先,我们将支持集标签转换为一个热编码值,然后将它们与我们的注意力矩阵相乘,结果,我们获得了![](img/685d4fc8-b1ca-41aa-a68c-60b5378803bf.png)属于支持集中每个类的概率。 然后,我们应用 argmax 并选择![](img/5d76bf68-0779-4160-8162-35ee38491efe.png)作为具有最大概率值的那个。
因此,在计算注意力矩阵`a(x_hat, x[i])`之后,我们将注意力矩阵与支持集标签`y[i]`相乘。 但是,如何将支持集标签与注意力矩阵相乘呢? 首先,我们将支持集标签转换为一个热编码值,然后将它们与我们的注意力矩阵相乘,结果,我们获得了`y_hat`属于支持集中每个类的概率。 然后,我们应用 argmax 并选择`y_hat`作为具有最大概率值的那个。
您是否还不清楚匹配网络? 看下图; 如您所见,我们的支持集中有 3 个类,即`{Lion, Eleph, Dog}`,还有一个新的查询图像![](img/6ed1cd67-f5a3-4556-a51f-293703675687.png)。 首先,将支持集提供给嵌入函数![](img/ea0c80de-3df9-49a0-8e3d-1ef56cf33ddd.png),将查询图像提供给嵌入函数![](img/c4130722-4028-4ab3-819d-53b515d033d4.png),然后学习它们的嵌入并计算它们之间的余弦距离; 然后,我们在这个余弦距离上施加 softmax 注意。 然后,将注意力矩阵与一键编码支持集标签相乘,得到概率,然后选择![](img/0c52e07c-051e-497e-85ef-db57678bb2b2.png)作为概率最高的那个。 如下图所示,查询集图像是一头大象,我们在索引 1 处的概率很高,因此我们将![](img/b9a8f80e-6f5f-4aee-a138-7b0d7af50c5c.png)的类别预测为 1(大象):
您是否还不清楚匹配网络? 看下图; 如您所见,我们的支持集中有 3 个类,即`{Lion, Eleph, Dog}`,还有一个新的查询图像`x_hat`。 首先,将支持集提供给嵌入函数`g`,将查询图像提供给嵌入函数`f`,然后学习它们的嵌入并计算它们之间的余弦距离; 然后,我们在这个余弦距离上施加 softmax 注意。 然后,将注意力矩阵与一键编码支持集标签相乘,得到概率,然后选择`y_hat`作为概率最高的那个。 如下图所示,查询集图像是一头大象,我们在索引 1 处的概率很高,因此我们将`y_hat`的类别预测为 1(大象):
![](img/92322e5f-a6ae-42d2-b3e4-6d4820b5dde8.png)
# 嵌入函数
我们了解到,我们使用两个嵌入函数![](img/722c4cab-8c9b-45aa-87b4-d2d171e23c57.png)和![](img/583e2b77-74e9-44aa-be8e-fa6b435d77c9.png)分别学习![](img/8298038a-6f99-42d3-af60-9278249d5b15.png)和![](img/aaf53641-2bb7-4441-a93c-e4599f57b944.png)的嵌入。 现在,我们将确切地看到这两个函数如何学习嵌入。
我们了解到,我们使用两个嵌入函数`f``g`分别学习`x_hat``y_hat`的嵌入。 现在,我们将确切地看到这两个函数如何学习嵌入。
# 支持集嵌入函数(`g`)
我们使用嵌入函数![](img/d9b6aae8-0813-4ee5-9c07-aac1e3e1ba7d.png)来学习支持集的嵌入。 我们使用双向 LSTM 作为我们的嵌入函数![](img/6e84bc5e-d86c-4b7b-9b3b-dca3232ae883.png)
我们使用嵌入函数`g`来学习支持集的嵌入。 我们使用双向 LSTM 作为我们的嵌入函数`g`
我们可以如下定义嵌入函数![](img/71b0b590-c21a-4b97-9a91-3884d1a9ce57.png)
我们可以如下定义嵌入函数`g`
```py
def g(X):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册