Commit 5388e807 authored by: V Varuna Jayasiri

layer norm

Parent d3790d70
......@@ -75,7 +75,9 @@
import torch
from torch import nn
from labml_helpers.module import Module
......@@ -86,7 +88,7 @@
class Swish(nn.Module):
class Swish(Module):
......@@ -97,9 +99,9 @@
    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
......@@ -110,8 +112,8 @@
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.sigmoid(x)
......
......@@ -118,7 +118,7 @@ We decided to write a simpler implementation to make it easier for readers who are n
    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
    def __init__(self, src_embed: Module, encoder: Encoder, generator: Module):
        super().__init__()
......
import torch
from torch import nn
from labml_helpers.module import Module
class Swish(nn.Module):
class Swish(Module):
    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
......
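The `Swish` class being ported above computes $x \cdot \sigma(x)$. For a quick, self-contained sanity check that does not depend on `labml_helpers`, the same activation can be compared against `nn.SiLU`, which recent PyTorch releases (1.7+) provide; this is only an illustrative sketch, not part of the commit:

import torch
from torch import nn

# Re-declare the Swish from the diff above, using plain nn.Module so the
# snippet runs without labml_helpers installed.
class Swish(nn.Module):
    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.sigmoid(x)

x = torch.randn(4, 8)
# Swish(x) = x * sigmoid(x); PyTorch >= 1.7 exposes the same function as nn.SiLU
assert torch.allclose(Swish()(x), nn.SiLU()(x))
assert torch.allclose(Swish()(x), x * torch.sigmoid(x))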
......@@ -98,8 +98,10 @@ a CNN classifier that uses batch normalization for the MNIST dataset.
import torch
from torch import nn
from labml_helpers.module import Module
class BatchNorm(nn.Module):
class BatchNorm(Module):
r"""
## Batch Normalization Layer
......@@ -157,7 +159,7 @@ class BatchNorm(nn.Module):
    def forward(self, x: torch.Tensor):
        """
        `x` is a tensor of shape `[batch_size, channels, *]`.
        `*` could be any (even *) dimensions.
        `*` could be any number of (even 0) dimensions.
        For example, in an image (2D) convolution this will be
        `[batch_size, channels, height, width]`
        """
......@@ -200,3 +202,19 @@ class BatchNorm(nn.Module):
        # Reshape to original and return
        return x_norm.view(x_shape)
def _test():
    from labml.logger import inspect
    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    bn = BatchNorm(3)
    x = bn(x)
    inspect(x.shape)
    inspect(bn.exp_var.shape)

if __name__ == '__main__':
    _test()
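The `forward()` docstring above states that `x` has shape `[batch_size, channels, *]` with any number of trailing dimensions. A minimal usage sketch on an image-shaped batch; it assumes the class is importable as `labml_nn.normalization.batch_norm.BatchNorm` (adjust the import to wherever this file actually lives):

import torch
from labml_nn.normalization.batch_norm import BatchNorm  # assumed import path

# An image-like batch: [batch_size=8, channels=3, height=16, width=16]
x = torch.randn(8, 3, 16, 16)
bn = BatchNorm(3)
bn.train()
y = bn(x)

# Shape is preserved, and in training mode each channel is normalized with
# the mini-batch statistics, so per-channel mean is ~0 and variance ~1.
assert y.shape == x.shape
print(y.mean(dim=[0, 2, 3]))
print(y.var(dim=[0, 2, 3], unbiased=False))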
......@@ -32,89 +32,86 @@ Layer normalization is generally used for NLP tasks.
We have used layer normalization in most of the
[transformer implementations](../../transformers/gpt/index.html).
"""
from typing import Union, List
import torch
from torch import nn
from torch import nn, Size
from labml_helpers.module import Module
class LayerNorm(nn.Module):
class LayerNorm(Module):
"""
## Layer Normalization
"""
    def __init__(self, channels: int, *,
                 eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True):
    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
"""
* `channels` is the number of features in the input
* `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
* `momentum` is the momentum in taking the exponential moving average
* `affine` is whether to scale and shift the normalized value
* `track_running_stats` is whether to calculate the moving averages or mean and variance
* `normalized_shape` $S$ is shape of the elements (except the batch).
The input should then be
$X \in \mathbb{R}^{* \times S[0] \times S[1] \times ... \times S[n]}$
* `eps` is $\epsilon$, used in $\sqrt{Var[X}] + \epsilon}$ for numerical stability
* `elementwise_affine` is whether to scale and shift the normalized value
We've tried to use the same names for arguments as PyTorch `BatchNorm` implementation.
We've tried to use the same names for arguments as PyTorch `LayerNorm` implementation.
"""
        super().__init__()
        self.channels = channels
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # Create parameters for $\gamma$ and $\beta$ for scale and shift
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
        # Create buffers to store exponential moving averages of
        # mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$
        if self.track_running_stats:
            self.register_buffer('exp_mean', torch.zeros(channels))
            self.register_buffer('exp_var', torch.ones(channels))
        self.elementwise_affine = elementwise_affine
        # Create parameters for $\gamma$ and $\beta$ for gain and bias
        if self.elementwise_affine:
            self.gain = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
    def forward(self, x: torch.Tensor):
        """
        `x` is a tensor of shape `[batch_size, channels, *]`.
        `*` could be any (even *) dimensions.
        For example, in an image (2D) convolution this will be
        `[batch_size, channels, height, width]`
        `x` is a tensor of shape `[*, S[0], S[1], ..., S[n]]`.
        `*` could be any number of dimensions.
        For example, in an NLP task this will be
        `[seq_len, batch_size, features]`
        """
        # Keep the original shape
        x_shape = x.shape
        # Get the batch size
        batch_size = x_shape[0]
        # Sanity check to make sure the number of features is the same
        assert self.channels == x.shape[1]
        # Reshape into `[batch_size, channels, n]`
        x = x.view(batch_size, self.channels, -1)
        # We will calculate the mini-batch mean and variance
        # if we are in training mode or if we have not tracked exponential moving averages
        if self.training or not self.track_running_stats:
            # Calculate the mean across the first and last dimensions;
            # i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
            mean = x.mean(dim=[0, 2])
            # Calculate the squared mean across the first and last dimensions;
            # i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
            mean_x2 = (x ** 2).mean(dim=[0, 2])
            # Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
            var = mean_x2 - mean ** 2
            # Update exponential moving averages
            if self.training and self.track_running_stats:
                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
        # Use exponential moving averages as estimates
        else:
            mean = self.exp_mean
            var = self.exp_var
        # Normalize $$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}$$
        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
        # Scale and shift $$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
        # Sanity check to make sure the shapes match
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
        # Reshape into `[M, S[0], S[1], ..., S[n]]`
        x = x.view(-1, *self.normalized_shape)
        # Calculate the mean across the first dimension;
        # i.e. the means for each element $\mathbb{E}[X]$
        mean = x.mean(dim=0)
        # Calculate the squared mean across the first dimension;
        # i.e. the means for each element $\mathbb{E}[X^2]$
        mean_x2 = (x ** 2).mean(dim=0)
        # Variance for each element $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$
        var = mean_x2 - mean ** 2
        # Normalize $$\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{Var[X] + \epsilon}}$$
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        # Scale and shift $$\text{LN}(x) = \gamma \hat{X} + \beta$$
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias
        # Reshape to original and return
        return x_norm.view(x_shape)
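The comments above derive the variance from the first two moments, $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$, instead of a second pass over centred values. A short self-contained check (made-up data, not from the commit) that this identity agrees with PyTorch's population variance:

import torch

x = torch.randn(6, 5)
mean = x.mean(dim=0)
mean_x2 = (x ** 2).mean(dim=0)
var = mean_x2 - mean ** 2
# torch.var with unbiased=False is the population (biased) variance used for normalization
assert torch.allclose(var, x.var(dim=0, unbiased=False), atol=1e-6)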
def _test():
    from labml.logger import inspect
    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    ln = LayerNorm(x.shape[2:])
    x = ln(x)
    inspect(x.shape)
    inspect(ln.gain.shape)

if __name__ == '__main__':
    _test()
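As the `__init__` docstring notes, `normalized_shape` is the trailing shape of each element, and `_test()` passes `x.shape[2:]`, which is a `torch.Size`. A minimal usage sketch on an NLP-shaped tensor follows; it assumes the class is importable as `labml_nn.normalization.layer_norm.LayerNorm` (adjust the path if it differs). Note that this revision takes the statistics across the flattened leading dimension (`dim=0`), so its output will not in general match `torch.nn.LayerNorm`, which normalizes each sample over its trailing dimensions.

import torch
from labml_nn.normalization.layer_norm import LayerNorm  # assumed import path

# NLP-style input: [seq_len=10, batch_size=4, features=32]
x = torch.randn(10, 4, 32)
# Pass a torch.Size (as _test() does); in this revision forward() compares
# normalized_shape to x.shape[-n:] with `==`, so a plain int or list would
# not get past that check.
ln = LayerNorm(x.shape[-1:])
y = ln(x)

assert y.shape == x.shape             # shape is preserved
print(ln.gain.shape, ln.bias.shape)   # both torch.Size([32])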
......@@ -8,6 +8,7 @@ summary: This is a simple MNIST example with a CNN model to test the optimizers.
"""
import torch.nn as nn
import torch.utils.data
from labml_helpers.module import Module
from labml import experiment, tracker
from labml.configs import option
......@@ -19,7 +20,7 @@ from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_
from labml_nn.optimizers.configs import OptimizerConfigs
class Model(nn.Module):
class Model(Module):
"""
## The model
"""
......
......@@ -40,7 +40,7 @@ class AutoregressiveModel(Module):
    ## Auto regressive model
    """
    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
    def __init__(self, src_embed: Module, encoder: Encoder, generator: Module):
        super().__init__()
        # Token embedding module
        self.src_embed = src_embed
......
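The remaining hunks only swap `nn.Module` for `labml_helpers`' `Module` in type hints. Judging from how `Module` is used in this diff (`register_buffer`, `nn.Parameter` attributes), it behaves as an `nn.Module` subclass, so the stricter hints stay compatible with existing `nn.Module`-typed code. A hypothetical illustration (the `TokenEmbedding` class below is not part of this commit):

import torch
from torch import nn
from labml_helpers.module import Module  # import path as used in the hunks above

class TokenEmbedding(Module):
    def __init__(self, n_tokens: int, d_model: int):
        super().__init__()
        self.emb = nn.Embedding(n_tokens, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.emb(x)

emb = TokenEmbedding(100, 16)
print(isinstance(emb, nn.Module))              # expected True: still satisfies nn.Module hints
print(emb(torch.randint(0, 100, (7,))).shape)  # torch.Size([7, 16])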