Greenplum / Annotated Deep Learning Paper Implementations
Commit 21b61874
Authored Aug 21, 2021 by Varuna Jayasiri
docs fix
Parent: b6607524

Showing 3 changed files with 63 additions and 63 deletions (+63 −63):

docs/sitemap.xml (+1 −1)
docs/uncertainty/evidence/index.html (+31 −31)
labml_nn/uncertainty/evidence/__init__.py (+31 −31)
docs/sitemap.xml

@@ -204,7 +204,7 @@
     <url>
         <loc>https://nn.labml.ai/normalization/batch_norm/mnist.html</loc>
-        <lastmod>2021-08-20T16:30:00+00:00</lastmod>
+        <lastmod>2021-08-21T16:30:00+00:00</lastmod>
         <priority>1.00</priority>
     </url>
docs/uncertainty/evidence/index.html

@@ -85,14 +85,14 @@ and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$.
 Paper uses term evidence as a measure of the amount of support
 collected from data in favor of a sample to be classified into a certain class.
 </p>
 <p>
 This corresponds to a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet distribution</a>
-with parameters $\color{cyan}{\alpha_k} = e_k + 1$, and
-$\color{cyan}{\alpha_0} = S = \sum_{k=1}^K \color{cyan}{\alpha_k}$ is known as the Dirichlet strength.
-Dirichlet distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$
+with parameters $\color{orange}{\alpha_k} = e_k + 1$, and
+$\color{orange}{\alpha_0} = S = \sum_{k=1}^K \color{orange}{\alpha_k}$ is known as the Dirichlet strength.
+Dirichlet distribution $D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}})$
 is a distribution over categorical distribution; i.e. you can sample class probabilities
 from a Dirichlet distribution.
-The expected probability for class $k$ is $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$.
-</p>
+The expected probability for class $k$ is $\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$.
+</p>
 <p>
 We get the model to output evidences
-<script type="math/tex; mode=display">\mathbf{e} = \color{cyan}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)</script>
+<script type="math/tex; mode=display">\mathbf{e} = \color{orange}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)</script>
 for a given input $\mathbf{x}$.
 We use a function such as <a href="https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html">ReLU</a> or a
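The hunk above describes the full evidence pipeline: a ReLU keeps evidence non-negative, $\alpha_k = e_k + 1$, the strength is $S = \sum_k \alpha_k$, expected probability is $\hat{p}_k = \alpha_k / S$, and uncertainty is $u = K/S$. A framework-free sketch of that arithmetic (plain Python rather than the repo's PyTorch; the model-output values are made-up):

```python
def evidence_to_dirichlet(evidence):
    """Turn non-negative per-class evidence e_k into Dirichlet summaries."""
    alpha = [e + 1.0 for e in evidence]        # alpha_k = e_k + 1
    strength = sum(alpha)                      # Dirichlet strength S
    p_hat = [a / strength for a in alpha]      # expected probability per class
    uncertainty = len(alpha) / strength        # u = K / S
    return alpha, strength, p_hat, uncertainty

raw_outputs = [2.0, -1.0, 0.5]                 # hypothetical model outputs
evidence = [max(0.0, x) for x in raw_outputs]  # ReLU keeps evidence >= 0
alpha, strength, p_hat, u = evidence_to_dirichlet(evidence)
```

With zero total evidence every class gets $\alpha_k = 1$ and $u = 1$, i.e. maximum uncertainty.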
@@ -116,7 +116,7 @@ We use a function such as
 </div>
 <p><a id="MaximumLikelihoodLoss"></a></p>
 <h2>Type II Maximum Likelihood Loss</h2>
-<p>
-The distribution
-D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})
-is a prior on the likelihood
+<p>
+The distribution
+$D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}})$
+is a prior on the likelihood
 $Multi(\mathbf{y} \vert p)$,
 and the negative log marginal likelihood is calculated by integrating over class probabilities
 $\mathbf{p}$.
 </p>
@@ -127,11 +127,11 @@ $Multi(\mathbf{y} \vert p)$,
 &= -\log \Bigg(
 \int
 \prod_{k=1}^K p_k^{y_k}
-\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
-\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
+\frac{1}{B(\color{orange}{\mathbf{\alpha}})}
+\prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1}
 d\mathbf{p}
 \Bigg )
 \\
-&= \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)
+&= \sum_{k=1}^K y_k \bigg( \log S - \log \color{orange}{\alpha_k} \bigg)
 \end{align}
 </script>
 </p>
 </div>
@@ -158,7 +158,7 @@ $Multi(\mathbf{y} \vert p)$,
 <div class='section-link'>
 <a href='#section-3'>#</a>
 </div>
-<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
+<p>$\color{orange}{\alpha_k} = e_k + 1$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">90</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
@@ -169,7 +169,7 @@ $Multi(\mathbf{y} \vert p)$,
 <div class='section-link'>
 <a href='#section-4'>#</a>
 </div>
-<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
+<p>$S = \sum_{k=1}^K \color{orange}{\alpha_k}$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">92</span><span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
@@ -180,7 +180,7 @@ $Multi(\mathbf{y} \vert p)$,
 <div class='section-link'>
 <a href='#section-5'>#</a>
 </div>
-<p>Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)$</p>
+<p>Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{orange}{\alpha_k} \bigg)$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">95</span><span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">target</span> <span class="o">*</span> <span class="p">(</span><span class="n">strength</span><span class="o">.</span><span class="n">log</span><span class="p">()[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">-</span> <span class="n">alpha</span><span class="o">.</span><span class="n">log</span><span class="p">()))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
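Read together, the three annotated code lines in the hunks above implement $\mathcal{L}(\Theta) = \sum_k y_k (\log S - \log \alpha_k)$. A per-sample, plain-Python sketch of the same computation (the repo's version is batched PyTorch; the evidence and target values here are invented):

```python
import math

def max_likelihood_loss(evidence, target):
    """Type II maximum likelihood loss for a single sample."""
    alpha = [e + 1.0 for e in evidence]               # alpha_k = e_k + 1
    strength = sum(alpha)                             # S = sum_k alpha_k
    # sum_k y_k * (log S - log alpha_k)
    return sum(y * (math.log(strength) - math.log(a))
               for y, a in zip(target, alpha))

# One-hot target for class 0 with some evidence collected for it
loss = max_likelihood_loss([4.0, 0.0, 0.0], [1.0, 0.0, 0.0])
```

More evidence for the labelled class drives the loss toward zero, which is the behaviour the loss is designed to reward.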
@@ -217,11 +217,11 @@ and sums it over all possible outcomes based on probability distribution.</p>
 &= -\log \Bigg(
 \int
 \Big[ \sum_{k=1}^K -y_k \log p_k \Big]
-\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
-\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
+\frac{1}{B(\color{orange}{\mathbf{\alpha}})}
+\prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1}
 d\mathbf{p}
 \Bigg )
 \\
-&= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)
+&= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{orange}{\alpha_k} ) \bigg)
 \end{align}
 </script>
 </p>
 <p>
 where $\psi(\cdot)$ is the $digamma$ function.
 </p>
@@ -249,7 +249,7 @@ and sums it over all possible outcomes based on probability distribution.</p>
 <div class='section-link'>
 <a href='#section-9'>#</a>
 </div>
-<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
+<p>$\color{orange}{\alpha_k} = e_k + 1$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">136</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
@@ -260,7 +260,7 @@ and sums it over all possible outcomes based on probability distribution.</p>
 <div class='section-link'>
 <a href='#section-10'>#</a>
 </div>
-<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
+<p>$S = \sum_{k=1}^K \color{orange}{\alpha_k}$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">138</span><span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
@@ -271,7 +271,7 @@ and sums it over all possible outcomes based on probability distribution.</p>
 <div class='section-link'>
 <a href='#section-11'>#</a>
 </div>
-<p>Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)$</p>
+<p>Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{orange}{\alpha_k} ) \bigg)$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">141</span><span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">target</span> <span class="o">*</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">strength</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">alpha</span><span class="p">)))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
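The cross-entropy Bayes-risk hunks above are the same shape as the maximum-likelihood loss but with $\log$ replaced by the digamma function $\psi$. A plain-Python sketch (the `digamma` below is a crude numerical approximation via a central difference of `math.lgamma`, used only for illustration; the repo itself uses `torch.digamma`; input values are invented):

```python
import math

def digamma(x, h=1e-5):
    """Crude numerical psi(x) from a central difference of log-gamma."""
    return (math.lgamma(x + h) - math.lgamma(x - h)) / (2 * h)

def cross_entropy_bayes_risk(evidence, target):
    """loss = sum_k y_k * (psi(S) - psi(alpha_k)), with alpha_k = e_k + 1."""
    alpha = [e + 1.0 for e in evidence]
    strength = sum(alpha)
    return sum(y * (digamma(strength) - digamma(a))
               for y, a in zip(target, alpha))

loss = cross_entropy_bayes_risk([4.0, 0.0, 0.0], [1.0, 0.0, 0.0])
```

By the recurrence $\psi(x+1) = \psi(x) + 1/x$, this example reduces to $\psi(7) - \psi(5) = 1/5 + 1/6$, which is a handy sanity check.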
@@ -305,19 +305,19 @@ and sums it over all possible outcomes based on probability distribution.</p>
 &= -\log \Bigg(
 \int
 \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big]
-\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
-\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
+\frac{1}{B(\color{orange}{\mathbf{\alpha}})}
+\prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1}
 d\mathbf{p}
 \Bigg )
 \\
 &= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 -2 y_k p_k + p_k^2 \Big]
 \\
 &= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
 \end{align}
 </script>
 </p>
 <p>
 Where
-<script type="math/tex; mode=display">\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}</script>
+<script type="math/tex; mode=display">\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}</script>
 is the expected probability when sampled from the Dirichlet distribution
 and
 <script type="math/tex; mode=display">\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)</script>
 where
-<script type="math/tex; mode=display">\text{Var}(p_k) = \frac{\color{cyan}{\alpha_k}(S - \color{cyan}{\alpha_k})}{S^2 (S + 1)}
+<script type="math/tex; mode=display">\text{Var}(p_k) = \frac{\color{orange}{\alpha_k}(S - \color{orange}{\alpha_k})}{S^2 (S + 1)}
 = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}
 </script>
 is the variance.
 </p>
 <p>
 This gives,
@@ -355,7 +355,7 @@ the second part is the variance.</p>
 <div class='section-link'>
 <a href='#section-15'>#</a>
 </div>
-<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
+<p>$\color{orange}{\alpha_k} = e_k + 1$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">197</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
@@ -366,7 +366,7 @@ the second part is the variance.</p>
 <div class='section-link'>
 <a href='#section-16'>#</a>
 </div>
-<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
+<p>$S = \sum_{k=1}^K \color{orange}{\alpha_k}$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">199</span><span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
@@ -377,7 +377,7 @@ the second part is the variance.</p>
 <div class='section-link'>
 <a href='#section-17'>#</a>
 </div>
-<p>$\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$</p>
+<p>$\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">201</span><span class="n">p</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">/</span> <span class="n">strength</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
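Putting the squared-error Bayes-risk pieces together: for one sample the loss is $\sum_k (y_k - \hat{p}_k)^2 + \hat{p}_k(1 - \hat{p}_k)/(S + 1)$, i.e. squared error plus the Dirichlet variance. A plain-Python sketch with invented evidence and a one-hot target:

```python
def squared_error_bayes_risk(evidence, target):
    """sum_k (y_k - p_hat_k)^2 + p_hat_k * (1 - p_hat_k) / (S + 1)."""
    alpha = [e + 1.0 for e in evidence]          # alpha_k = e_k + 1
    s = sum(alpha)                               # Dirichlet strength S
    p_hat = [a / s for a in alpha]               # expected probabilities
    err = sum((y - p) ** 2 for y, p in zip(target, p_hat))
    var = sum(p * (1 - p) / (s + 1) for p in p_hat)
    return err + var

loss = squared_error_bayes_risk([4.0, 0.0, 0.0], [1.0, 0.0, 0.0])
```

The variance term shrinks as $S$ grows, so confident correct predictions are cheap while confident wrong ones are not.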
@@ -435,7 +435,7 @@ the second part is the variance.</p>
 <p><a id="KLDivergenceLoss"></a></p>
 <h2>KL Divergence Regularization Loss</h2>
 <p>
 This tries to shrink the total evidence to zero if the sample cannot be correctly classified.
 </p>
 <p>
-First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$ the
+First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{orange}{\alpha_k}$ the
 Dirichlet parameters after removing the correct evidence.
 </p>
 <p>
 <script type="math/tex; mode=display">
 \begin{align}
@@ -474,7 +474,7 @@ $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$</p>
 <div class='section-link'>
 <a href='#section-24'>#</a>
 </div>
-<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
+<p>$\color{orange}{\alpha_k} = e_k + 1$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">244</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
@@ -497,7 +497,7 @@ $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$</p>
 <a href='#section-26'>#</a>
 </div>
 <p>
 Remove non-misleading evidence
-<script type="math/tex; mode=display">\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}</script>
+<script type="math/tex; mode=display">\tilde{\alpha}_k = y_k + (1 - y_k) \color{orange}{\alpha_k}</script>
 </p>
 </div>
 <div class='code'>
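Only part of the KL term is visible in these hunks, but the regularizer is the KL divergence between $D(\mathbf{p} \vert \tilde{\boldsymbol{\alpha}})$ and the uniform Dirichlet $D(\mathbf{p} \vert \mathbf{1})$, which has a standard closed form. A plain-Python sketch (the closed form and the numerical `digamma` are my own reconstruction, not code lifted from the repo; inputs are invented):

```python
import math

def digamma(x, h=1e-5):
    """Crude numerical psi(x) from a central difference of log-gamma."""
    return (math.lgamma(x + h) - math.lgamma(x - h)) / (2 * h)

def kl_divergence_loss(evidence, target):
    """KL( D(p | alpha_tilde) || D(p | [1, ..., 1]) ) where
    alpha_tilde_k = y_k + (1 - y_k) * alpha_k removes non-misleading evidence."""
    alpha = [e + 1.0 for e in evidence]
    n_classes = len(alpha)
    alpha_tilde = [y + (1 - y) * a for y, a in zip(target, alpha)]
    s_tilde = sum(alpha_tilde)
    # Standard closed form for KL between two Dirichlet distributions,
    # specialized to a uniform Dirichlet (all parameters equal to 1)
    return (math.lgamma(s_tilde) - math.lgamma(n_classes)
            - sum(math.lgamma(a) for a in alpha_tilde)
            + sum((a - 1) * (digamma(a) - digamma(s_tilde))
                  for a in alpha_tilde))

# All evidence sits on the correct class, so alpha_tilde is all ones
no_misleading = kl_divergence_loss([5.0, 0.0, 0.0], [1.0, 0.0, 0.0])
```

When all evidence supports the labelled class the term is zero; evidence on wrong classes makes it positive, so training shrinks exactly the misleading evidence.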
@@ -637,7 +637,7 @@ $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$</p>
 <div class='section-link'>
 <a href='#section-37'>#</a>
 </div>
-<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
+<p>$\color{orange}{\alpha_k} = e_k + 1$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">296</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
@@ -648,7 +648,7 @@ $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$</p>
 <div class='section-link'>
 <a href='#section-38'>#</a>
 </div>
-<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
+<p>$S = \sum_{k=1}^K \color{orange}{\alpha_k}$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">298</span><span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
@@ -659,7 +659,7 @@ $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$</p>
 <div class='section-link'>
 <a href='#section-39'>#</a>
 </div>
-<p>$\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$</p>
+<p>$\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">301</span><span class="n">expected_probability</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">/</span> <span class="n">strength</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
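The statistics-tracking hunks above compute the expected probability of the greedily selected (highest probability) class; together with the uncertainty $u = K/S$ from the intro, they summarize one prediction. A plain-Python per-sample sketch with invented evidence:

```python
def evidence_statistics(evidence):
    """Expected probability of the highest-probability class and u = K/S."""
    alpha = [e + 1.0 for e in evidence]           # alpha_k = e_k + 1
    strength = sum(alpha)                         # S
    expected_probability = [a / strength for a in alpha]
    greedy_p = max(expected_probability)          # greedy (highest) class
    uncertainty = len(alpha) / strength           # u = K / S
    return greedy_p, uncertainty

greedy_p, u = evidence_statistics([8.0, 1.0, 1.0])
```

With zero evidence on every class, $S = K$ and the uncertainty is exactly 1.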
labml_nn/uncertainty/evidence/__init__.py

@@ -29,15 +29,15 @@ Paper uses term evidence as a measure of the amount of support
 collected from data in favor of a sample to be classified into a certain class.
 This corresponds to a [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution)
-with parameters $\color{cyan}{\alpha_k} = e_k + 1$, and
-$\color{cyan}{\alpha_0} = S = \sum_{k=1}^K \color{cyan}{\alpha_k}$ is known as the Dirichlet strength.
-Dirichlet distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$
+with parameters $\color{orange}{\alpha_k} = e_k + 1$, and
+$\color{orange}{\alpha_0} = S = \sum_{k=1}^K \color{orange}{\alpha_k}$ is known as the Dirichlet strength.
+Dirichlet distribution $D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}})$
 is a distribution over categorical distribution; i.e. you can sample class probabilities
 from a Dirichlet distribution.
-The expected probability for class $k$ is $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$.
+The expected probability for class $k$ is $\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$.
 We get the model to output evidences
-$$\mathbf{e} = \color{cyan}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)$$
+$$\mathbf{e} = \color{orange}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)$$
 for a given input $\mathbf{x}$.
 We use a function such as
 [ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) or a
@@ -62,7 +62,7 @@ class MaximumLikelihoodLoss(Module):
    <a id="MaximumLikelihoodLoss"></a>
    ## Type II Maximum Likelihood Loss
-   The distribution
-   D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})
-   is a prior on the likelihood
+   The distribution
+   $D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}})$
+   is a prior on the likelihood
    $Multi(\mathbf{y} \vert p)$,
    and the negative log marginal likelihood is calculated by integrating over class probabilities
    $\mathbf{p}$.
@@ -74,11 +74,11 @@ class MaximumLikelihoodLoss(Module):
    &= -\log \Bigg(
    \int
    \prod_{k=1}^K p_k^{y_k}
-   \frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
-   \prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
+   \frac{1}{B(\color{orange}{\mathbf{\alpha}})}
+   \prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1}
    d\mathbf{p}
    \Bigg )
    \\
-   &= \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)
+   &= \sum_{k=1}^K y_k \bigg( \log S - \log \color{orange}{\alpha_k} \bigg)
    \end{align}
    """
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
@@ -86,12 +86,12 @@ class MaximumLikelihoodLoss(Module):
        * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
        * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
        """
-       # $\color{cyan}{\alpha_k} = e_k + 1$
+       # $\color{orange}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
-       # $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
+       # $S = \sum_{k=1}^K \color{orange}{\alpha_k}$
        strength = alpha.sum(dim=-1)
-       # Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)$
+       # Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{orange}{\alpha_k} \bigg)$
        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
        # Mean loss over the batch
@@ -117,11 +117,11 @@ class CrossEntropyBayesRisk(Module):
    &= -\log \Bigg(
    \int
    \Big[ \sum_{k=1}^K -y_k \log p_k \Big]
-   \frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
-   \prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
+   \frac{1}{B(\color{orange}{\mathbf{\alpha}})}
+   \prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1}
    d\mathbf{p}
    \Bigg )
    \\
-   &= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)
+   &= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{orange}{\alpha_k} ) \bigg)
    \end{align}
    where $\psi(\cdot)$ is the $digamma$ function.
@@ -132,12 +132,12 @@ class CrossEntropyBayesRisk(Module):
        * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
        * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
        """
-       # $\color{cyan}{\alpha_k} = e_k + 1$
+       # $\color{orange}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
-       # $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
+       # $S = \sum_{k=1}^K \color{orange}{\alpha_k}$
        strength = alpha.sum(dim=-1)
-       # Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)$
+       # Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{orange}{\alpha_k} ) \bigg)$
        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
        # Mean loss over the batch
@@ -159,19 +159,19 @@ class SquaredErrorBayesRisk(Module):
    &= -\log \Bigg(
    \int
    \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big]
-   \frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
-   \prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
+   \frac{1}{B(\color{orange}{\mathbf{\alpha}})}
+   \prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1}
    d\mathbf{p}
    \Bigg )
    \\
    &= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 -2 y_k p_k + p_k^2 \Big]
    \\
    &= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
    \end{align}
-   Where $$\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$$
+   Where $$\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$$
    is the expected probability when sampled from the Dirichlet distribution
    and $$\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$$
    where
-   $$\text{Var}(p_k) = \frac{\color{cyan}{\alpha_k}(S - \color{cyan}{\alpha_k})}{S^2 (S + 1)}
+   $$\text{Var}(p_k) = \frac{\color{orange}{\alpha_k}(S - \color{orange}{\alpha_k})}{S^2 (S + 1)}
    = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$$
    is the variance.
@@ -193,11 +193,11 @@ class SquaredErrorBayesRisk(Module):
        * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
        * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
        """
-       # $\color{cyan}{\alpha_k} = e_k + 1$
+       # $\color{orange}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
-       # $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
+       # $S = \sum_{k=1}^K \color{orange}{\alpha_k}$
        strength = alpha.sum(dim=-1)
-       # $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
+       # $\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$
        p = alpha / strength[:, None]
        # Error $(y_k -\hat{p}_k)^2$
@@ -219,7 +219,7 @@ class KLDivergenceLoss(Module):
    This tries to shrink the total evidence to zero if the sample cannot be correctly classified.
-   First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$ the
+   First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{orange}{\alpha_k}$ the
    Dirichlet parameters after removing the correct evidence.
    \begin{align}
@@ -240,12 +240,12 @@ class KLDivergenceLoss(Module):
        * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
        * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
        """
-       # $\color{cyan}{\alpha_k} = e_k + 1$
+       # $\color{orange}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
        # Number of classes
        n_classes = evidence.shape[-1]
        # Remove non-misleading evidence
-       # $$\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$$
+       # $$\tilde{\alpha}_k = y_k + (1 - y_k) \color{orange}{\alpha_k}$$
        alpha_tilde = target + (1 - target) * alpha
        # $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
        strength_tilde = alpha_tilde.sum(dim=-1)
@@ -292,12 +292,12 @@ class TrackStatistics(Module):
        # Track accuracy
        tracker.add('accuracy.', match.sum() / match.shape[0])
-       # $\color{cyan}{\alpha_k} = e_k + 1$
+       # $\color{orange}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
-       # $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
+       # $S = \sum_{k=1}^K \color{orange}{\alpha_k}$
        strength = alpha.sum(dim=-1)
-       # $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
+       # $\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$
        expected_probability = alpha / strength[:, None]
        # Expected probability of the selected (greedy highest probability) class
        expected_probability, _ = expected_probability.max(dim=-1)