Greenplum / Annotated Deep Learning Paper Implementations · Commits · 302785a4
Commit 302785a4
Authored on Sep 14, 2020 by Varuna Jayasiri

capsnet annotations

Parent: 010f0c56
Showing 7 changed files with 212 additions and 111 deletions (+212 −111)
.gitignore  +2 −0
.labml.yaml  +1 −0
Makefile  +1 −1
labml_nn/__init__.py  +1 −0
labml_nn/capsule_networks/__init__.py  +101 −109
labml_nn/capsule_networks/mnist.py  +105 −0
setup.py  +1 −1
.gitignore

@@ -8,3 +8,5 @@ build/
 .idea/*
 !.idea/dictionaries
 html/
+labml
+labml_helpers
0 → 100644
浏览文件 @
302785a4
web_api
:
https://api.lab-ml.com/api/v1/track?labml_token=903c84fba8ca49ca9f215922833e08cf&channel=app-updates-test
Makefile

@@ -22,7 +22,7 @@ uninstall: ## Uninstall
 	pip uninstall labml_nn

 docs: ## Render annotated HTML
-	python ../../pylit/pylit.py -t ../../pylit/template_docs.html -d html -w labml_nn
+	python ../../pylit/pylit.py --remove_empty_sections -s ../../pylit/pylit_docs.css -t ../../pylit/template_docs.html -d html -w labml_nn

 pages: ## Copy to lab-ml site
 	@cd ../lab-ml.github.io; git pull
labml_nn/__init__.py

@@ -4,6 +4,7 @@
 * [Transformers](transformers/index.html)
 * [Recurrent Highway Networks](recurrent_highway_networks/index.html)
 * [LSTM](lstm/index.html)
+* [Capsule Networks](capsule_networks/index.html)

 If you have any suggestions for other new implementations,
 please create a [Github Issue](https://github.com/lab-ml/labml_nn/issues).
labml_nn/capsule_networks/__init__.py

 """
-This is an implementation of paper
+This is an implementation of [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829).
+
+Unlike in other implementations of models, we've included a sample, because
+it is difficult to understand some of the concepts with just the modules.
+[This is the annotated code for a model that uses capsules to classify the MNIST dataset](mnist.html)
+
+This file holds the implementations of the core modules of Capsule Networks.
 """
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
-from labml import experiment, tracker
-from labml.configs import option
-from labml.utils.pytorch import get_device
-from labml_helpers.datasets.mnist import MNISTConfigs
-from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
-from labml_helpers.train_valid import TrainValidConfigs, BatchStep
 class Squash(Module):
     """
-    This is the **squashing** function from the paper.
+    ## Squash
+
+    This is the **squashing** function from the paper, given by equation $(1)$.
+
+    $$\mathbf{v}_j = \frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}
+    \frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$$
+
+    $\frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$
+    normalizes the length of all the capsules, whilst
+    $\frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}$
+    shrinks the capsules that have a length smaller than one.
     """

     def __init__(self, epsilon=1e-8):
@@ -25,42 +35,103 @@ class Squash(Module):
         self.epsilon = epsilon

     def __call__(self, s: torch.Tensor):
-        # shape: batch, caps, features
+        """
+        The shape of `s` is `[batch_size, n_capsules, n_features]`
+        """
+        # ${\lVert \mathbf{s}_j \rVert}^2$
         s2 = (s ** 2).sum(dim=-1, keepdims=True)
+
+        # We add an epsilon when calculating $\lVert \mathbf{s}_j \rVert$ to make sure it doesn't become zero.
+        # If this becomes zero it starts giving out `nan` values and training fails.
+        # $$\mathbf{v}_j = \frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}
+        # \frac{\mathbf{s}_j}{\sqrt{{\lVert \mathbf{s}_j \rVert}^2 + \epsilon}}$$
         return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))
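As a quick sanity check of the squash (a sketch, not part of this commit; it assumes PyTorch and this `labml_nn` package are importable), the output keeps the input shape while every capsule length lands strictly inside $(0, 1)$:

    import torch

    from labml_nn.capsule_networks import Squash

    squash = Squash()
    # A batch of 4 samples, 10 capsules, 16 features per capsule
    s = torch.randn(4, 10, 16)
    v = squash(s)

    # Shape is preserved, and every capsule is shrunk to a length below one
    assert v.shape == s.shape
    norms = (v ** 2).sum(dim=-1).sqrt()
    assert bool((norms < 1.0).all())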
 class Router(Module):
     """
-    The routing mechanism
+    ## Routing Algorithm
+
+    This is the routing mechanism described in the paper.
+    You can use multiple routing layers in your models.
+
+    This combines calculating $\mathbf{s}_j$ for this layer and
+    the routing algorithm described in *Procedure 1*.
     """

     def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
+        """
+        `in_caps` is the number of capsules, and `in_d` is the number of features per capsule from the layer below.
+        `out_caps` and `out_d` are the same for this layer.
+
+        `iterations` is the number of routing iterations, symbolized by $r$ in the paper.
+        """
         super().__init__()
         self.in_caps = in_caps
         self.out_caps = out_caps
         self.iterations = iterations
-        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d))
         self.softmax = nn.Softmax(dim=1)
         self.squash = Squash()
+
+        # This is the weight matrix $\mathbf{W}_{ij}$. It maps each capsule in the
+        # lower layer to each capsule in this layer
+        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)
     def __call__(self, u: torch.Tensor):
-        # batch, in_caps, in_d
+        """
+        The shape of `u` is `[batch_size, n_capsules, n_features]`.
+        These are the capsules from the lower layer.
+        """
+        # $$\hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \mathbf{u}_i$$
+        # Here $j$ is used to index capsules in this layer, whilst $i$ is
+        # used to index capsules in the layer below (previous).
         u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
+
+        # Initial logits $b_{ij}$ are the log prior probabilities that capsule $i$
+        # should be coupled with $j$.
+        # We initialize these at zero
         b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)
+
         v = None
+
+        # Iterate
         for i in range(self.iterations):
+            # routing softmax $$c_{ij} = \frac{\exp({b_{ij}})}{\sum_k\exp({b_{ik}})}$$
             c = self.softmax(b)
+            # $$\mathbf{s}_j = \sum_i{c_{ij} \hat{\mathbf{u}}_{j|i}}$$
             s = torch.einsum('bij,bijm->bjm', c, u_hat)
+            # $$\mathbf{v}_j = squash(\mathbf{s}_j)$$
             v = self.squash(s)
+            # $$a_{ij} = \mathbf{v}_j \cdot \hat{\mathbf{u}}_{j|i}$$
             a = torch.einsum('bjm,bijm->bij', v, u_hat)
+            # $$b_{ij} \gets b_{ij} + \mathbf{v}_j \cdot \hat{\mathbf{u}}_{j|i}$$
             b = b + a

         return v
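Shapes are easy to lose track of in the routing loop. A minimal shape check (a sketch, not part of this commit) using the same dimensions the MNIST model below passes to `Router`:

    import torch

    from labml_nn.capsule_networks import Router

    # 32 * 6 * 6 = 1152 primary capsules with 8 features each are routed
    # to 10 digit capsules with 16 features each, over 3 iterations
    router = Router(in_caps=32 * 6 * 6, out_caps=10, in_d=8, out_d=16, iterations=3)
    u = torch.randn(2, 32 * 6 * 6, 8)
    v = router(u)

    # One output capsule per digit: [batch_size, out_caps, out_d]
    assert v.shape == (2, 10, 16)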
 class MarginLoss(Module):
     """
+    ## Margin loss for class existence
+
+    A separate margin loss is used for each output capsule and the total loss is the sum of them.
+    The length of each output capsule is the probability that the class is present in the input.
+
+    Loss for each output capsule or class $k$ is,
+    $$L_k = T_k \max(0, m^{+} - \lVert\mathbf{v}_k\rVert)^2 +
+    \lambda (1 - T_k) \max(0, \lVert\mathbf{v}_k\rVert - m^{-})^2$$
+
+    $T_k$ is $1$ if the class $k$ is present and $0$ otherwise.
+    The first component of the loss is $0$ when the class is not present,
+    and the second component is $0$ if the class is present.
+    The $\max(0, x)$ is used to avoid predictions going to extremes.
+    $m^{+}$ is set to be $0.9$ and $m^{-}$ to be $0.1$ in the paper.
+
+    The $\lambda$ down-weighting is used to stop the length of all capsules from
+    falling during the initial phase of training.
+    """

     def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
         super().__init__()
@@ -70,104 +141,25 @@ class MarginLoss(Module):
         self.n_labels = n_labels

     def __call__(self, v: torch.Tensor, labels: torch.Tensor):
+        """
+        `v`, $\mathbf{v}_j$ are the squashed output capsules.
+        This has shape `[batch_size, n_labels, n_features]`; that is, there is a capsule for each label.
+
+        `labels` are the labels, and has shape `[batch_size]`.
+        """
+        # $$\lVert \mathbf{v}_j \rVert$$
         v_norm = torch.sqrt((v ** 2).sum(dim=-1))
-        # $$L$$
+
+        # `labels` is one-hot encoded labels of shape `[batch_size, n_labels]`
         labels = torch.eye(self.n_labels, device=labels.device)[labels]
+
+        # $$L_k = T_k \max(0, m^{+} - \lVert\mathbf{v}_k\rVert)^2 +
+        # \lambda (1 - T_k) \max(0, \lVert\mathbf{v}_k\rVert - m^{-})^2$$
+        # `loss` has shape `[batch_size, n_labels]`. We have parallelized the computation
+        # of $L_k$ for all $k$.
         loss = labels * F.relu(self.m_positive - v_norm) + \
                self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)
-        loss = loss.sum(dim=-1).mean()
-        return loss
-
-
-class MNISTCapsuleNetworkModel(Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
-        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
-        self.squash = Squash()
-
-        # self.digit_capsules = DigitCaps()
-        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
-        self.reconstruct = nn.Sequential(nn.Linear(16 * 10, 512),
-                                         nn.ReLU(),
-                                         nn.Linear(512, 1024),
-                                         nn.ReLU(),
-                                         nn.Linear(1024, 784),
-                                         nn.Sigmoid())
-        self.mse_loss = nn.MSELoss()
-
-    def forward(self, data):
-        x = F.relu(self.conv1(data))
-        caps = self.conv2(x).view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
-        caps = self.squash(caps)
-        caps = self.digit_capsules(caps)
-
-        with torch.no_grad():
-            pred = (caps ** 2).sum(-1).argmax(-1)
-            masked = torch.eye(10, device=x.device)[pred]
-
-        reconstructions = self.reconstruct((caps * masked[:, :, None]).view(x.shape[0], -1))
-        reconstructions = reconstructions.view(-1, 1, 28, 28)
-
-        return caps, reconstructions, pred
-
-
-class CapsuleNetworkBatchStep(BatchStep):
-    def __init__(self, *, model, optimizer):
-        super().__init__(model=model, optimizer=optimizer, loss_func=None, accuracy_func=None)
-        self.reconstruction_loss = nn.MSELoss()
-        self.margin_loss = MarginLoss(n_labels=10)
-
-    def calculate_loss(self, batch: any, state: any):
-        device = get_device(self.model)
-        data, target = batch
-        data, target = data.to(device), target.to(device)
-
-        stats = {'samples': len(data)}
-
-        caps, reconstructions, pred = self.model(data)
-
-        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
-        stats['correct'] = pred.eq(target).sum().item()
-        stats['loss'] = loss.detach().item() * stats['samples']
-        tracker.add("loss.", loss)
-
-        return loss, stats, None
-
-
-class Configs(MNISTConfigs, TrainValidConfigs):
-    batch_step = 'capsule_network_batch_step'
-    device: torch.device = DeviceConfigs()
-    epochs: int = 10
-    loss_func = None
-    accuracy_func = None
-
-
-@option(Configs.model)
-def model(c: Configs):
-    return MNISTCapsuleNetworkModel().to(c.device)
-
-
-@option(Configs.batch_step)
-def capsule_network_batch_step(c: TrainValidConfigs):
-    return CapsuleNetworkBatchStep(model=c.model, optimizer=c.optimizer)
-
-
-def main():
-    conf = Configs()
-    experiment.create(name='mnist_latest', writers={})
-    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
-                              'device.cuda_device': 1},
-                       'run')
-    experiment.add_pytorch_models(dict(model=conf.model))
-    with experiment.start():
-        conf.run()
-
-
-if __name__ == '__main__':
-    main()
+
+        # $$\sum_k L_k$$
+        return loss.sum(dim=-1).mean()
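A toy case makes the two hinge terms concrete (a sketch, not part of this commit): when the correct capsule is longer than $m^{+}$ and the others are shorter than $m^{-}$, both terms vanish and the loss is zero.

    import torch

    from labml_nn.capsule_networks import MarginLoss

    loss_fn = MarginLoss(n_labels=3)

    # One sample: the capsule for label 0 has length 1, the rest have length 0
    v = torch.zeros(1, 3, 16)
    v[0, 0, 0] = 1.0
    labels = torch.tensor([0])

    # relu(0.9 - 1) = 0 for the present class, relu(0 - 0.1) = 0 for the rest
    assert loss_fn(v, labels).item() == 0.0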
labml_nn/capsule_networks/mnist.py (new file, 0 → 100644)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from labml import experiment, tracker
from labml.configs import option
from labml.utils.pytorch import get_device
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, BatchStep
from labml_nn.capsule_networks import Squash, Router, MarginLoss


class MNISTCapsuleNetworkModel(Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
        self.squash = Squash()

        # self.digit_capsules = DigitCaps()
        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
        self.reconstruct = nn.Sequential(nn.Linear(16 * 10, 512),
                                         nn.ReLU(),
                                         nn.Linear(512, 1024),
                                         nn.ReLU(),
                                         nn.Linear(1024, 784),
                                         nn.Sigmoid())
        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        x = F.relu(self.conv1(data))
        caps = self.conv2(x).view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
        caps = self.squash(caps)
        caps = self.digit_capsules(caps)

        with torch.no_grad():
            pred = (caps ** 2).sum(-1).argmax(-1)
            masked = torch.eye(10, device=x.device)[pred]

        reconstructions = self.reconstruct((caps * masked[:, :, None]).view(x.shape[0], -1))
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return caps, reconstructions, pred
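The hard-coded `32 * 6 * 6` follows from the convolution arithmetic: $28 - 9 + 1 = 20$ after `conv1`, and $\lfloor(20 - 9)/2\rfloor + 1 = 6$ after the stride-2 `conv2`, so there are $32 \times 8$ channels on a $6 \times 6$ grid, reshaped into $32 \times 6 \times 6 = 1152$ primary capsules of $8$ features. A shape check (a sketch, not part of this commit; it assumes `labml` and `labml_helpers` are installed so the module imports cleanly):

    import torch

    from labml_nn.capsule_networks.mnist import MNISTCapsuleNetworkModel

    model = MNISTCapsuleNetworkModel()
    caps, reconstructions, pred = model(torch.randn(2, 1, 28, 28))

    assert caps.shape == (2, 10, 16)                # one 16-d capsule per digit
    assert reconstructions.shape == (2, 1, 28, 28)  # decoder output
    assert pred.shape == (2,)                       # predicted digit per sample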
class CapsuleNetworkBatchStep(BatchStep):
    def __init__(self, *, model, optimizer):
        super().__init__(model=model, optimizer=optimizer, loss_func=None, accuracy_func=None)
        self.reconstruction_loss = nn.MSELoss()
        self.margin_loss = MarginLoss(n_labels=10)

    def calculate_loss(self, batch: any, state: any):
        device = get_device(self.model)
        data, target = batch
        data, target = data.to(device), target.to(device)

        stats = {'samples': len(data)}

        caps, reconstructions, pred = self.model(data)

        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
        stats['correct'] = pred.eq(target).sum().item()
        stats['loss'] = loss.detach().item() * stats['samples']
        tracker.add("loss.", loss)

        return loss, stats, None


class Configs(MNISTConfigs, TrainValidConfigs):
    batch_step = 'capsule_network_batch_step'
    device: torch.device = DeviceConfigs()
    epochs: int = 10
    model = 'capsule_network_model'
    loss_func = None
    accuracy_func = None


@option(Configs.model)
def capsule_network_model(c: Configs):
    return MNISTCapsuleNetworkModel().to(c.device)


@option(Configs.batch_step)
def capsule_network_batch_step(c: TrainValidConfigs):
    return CapsuleNetworkBatchStep(model=c.model, optimizer=c.optimizer)


def main():
    conf = Configs()
    experiment.create(name='mnist_latest')
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'device.cuda_device': 1},
                       'run')
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
setup.py

@@ -5,7 +5,7 @@ with open("readme.rst", "r") as f:
 setuptools.setup(
     name='labml_nn',
-    version='0.4.1',
+    version='0.4.2',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="A collection of PyTorch implementations of neural network architectures and layers.",
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录