diff --git a/docs/index.html b/docs/index.html index afc37122dd6e5611e262fb59b990efafe60594d8..8e17ac32124598e6efcbac03fd0e32acba97a6c4 100644 --- a/docs/index.html +++ b/docs/index.html @@ -125,6 +125,7 @@ and

Normalization Layers

Installation

pip install labml_nn
diff --git a/docs/normalization/batch_norm/index.html b/docs/normalization/batch_norm/index.html
index 2fed9a7e2c04abb568e56978a2ae0cf796734fc8..1888ef0f33212b11c7a3640bbe4cffec3f53c86d 100644
--- a/docs/normalization/batch_norm/index.html
+++ b/docs/normalization/batch_norm/index.html
@@ -156,18 +156,21 @@ a CNN classifier that uses batch normalization for the MNIST dataset.

Batch normalization layer $\text{BN}$ normalizes the input $X$ as follows:

When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

$$\text{BN}(X) = \gamma \frac{X - \underset{B, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{B, H, W}{Var}[X] + \epsilon}} + \beta$$

When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, where $B$ is the batch size and $C$ is the number of features. $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

$$\text{BN}(X) = \gamma \frac{X - \underset{B}{\mathbb{E}}[X]}{\sqrt{\underset{B}{Var}[X] + \epsilon}} + \beta$$

When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of sequence embeddings, where $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence. $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

$$\text{BN}(X) = \gamma \frac{X - \underset{B, L}{\mathbb{E}}[X]}{\sqrt{\underset{B, L}{Var}[X] + \epsilon}} + \beta$$
@@ -192,9 +195,9 @@ where $B$ is the batch size, $C$ is the number of features, and $L$ is the lengt

We’ve tried to use the same names for arguments as the PyTorch BatchNorm implementation.

-
129    def __init__(self, channels: int, *,
-130                 eps: float = 1e-5, momentum: float = 0.1,
-131                 affine: bool = True, track_running_stats: bool = True):
+
132    def __init__(self, channels: int, *,
+133                 eps: float = 1e-5, momentum: float = 0.1,
+134                 affine: bool = True, track_running_stats: bool = True):
@@ -205,14 +208,14 @@ where $B$ is the batch size, $C$ is the number of features, and $L$ is the lengt
-
141        super().__init__()
-142
-143        self.channels = channels
-144
-145        self.eps = eps
-146        self.momentum = momentum
-147        self.affine = affine
-148        self.track_running_stats = track_running_stats
+
144        super().__init__()
+145
+146        self.channels = channels
+147
+148        self.eps = eps
+149        self.momentum = momentum
+150        self.affine = affine
+151        self.track_running_stats = track_running_stats
@@ -223,9 +226,9 @@ where $B$ is the batch size, $C$ is the number of features, and $L$ is the lengt

Create parameters for $\gamma$ and $\beta$ for scale and shift

-
150        if self.affine:
-151            self.scale = nn.Parameter(torch.ones(channels))
-152            self.shift = nn.Parameter(torch.zeros(channels))
+
153        if self.affine:
+154            self.scale = nn.Parameter(torch.ones(channels))
+155            self.shift = nn.Parameter(torch.zeros(channels))
@@ -237,9 +240,9 @@ where $B$ is the batch size, $C$ is the number of features, and $L$ is the lengt mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$

-
155        if self.track_running_stats:
-156            self.register_buffer('exp_mean', torch.zeros(channels))
-157            self.register_buffer('exp_var', torch.ones(channels))
+
158        if self.track_running_stats:
+159            self.register_buffer('exp_mean', torch.zeros(channels))
+160            self.register_buffer('exp_var', torch.ones(channels))
@@ -253,7 +256,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$

[batch_size, channels, height, width]

-
159    def forward(self, x: torch.Tensor):
+
162    def forward(self, x: torch.Tensor):
@@ -264,7 +267,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$

Keep the original shape

-
167        x_shape = x.shape
+
170        x_shape = x.shape
@@ -275,7 +278,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$

Get the batch size

-
169        batch_size = x_shape[0]
+
172        batch_size = x_shape[0]
@@ -286,7 +289,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$

Sanity check to make sure the number of features is the same

-
171        assert self.channels == x.shape[1]
+
174        assert self.channels == x.shape[1]
@@ -297,7 +300,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$

Reshape into [batch_size, channels, n]

-
174        x = x.view(batch_size, self.channels, -1)
+
177        x = x.view(batch_size, self.channels, -1)
@@ -309,7 +312,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$

if we are in training mode or if we have not tracked exponential moving averages

-
178        if self.training or not self.track_running_stats:
+
181        if self.training or not self.track_running_stats:
@@ -321,7 +324,7 @@ if we are in training mode or if we have not tracked exponential moving averages i.e. the means for each feature $\mathbb{E}[x^{(k)}]$

-
181            mean = x.mean(dim=[0, 2])
+
184            mean = x.mean(dim=[0, 2])
@@ -333,7 +336,7 @@ i.e. the means for each feature $\mathbb{E}[x^{(k)}]$

i.e. the mean of the squares for each feature $\mathbb{E}[(x^{(k)})^2]$

-
184            mean_x2 = (x ** 2).mean(dim=[0, 2])
+
187            mean_x2 = (x ** 2).mean(dim=[0, 2])
@@ -344,7 +347,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$

Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$

-
186            var = mean_x2 - mean ** 2
+
189            var = mean_x2 - mean ** 2
@@ -355,9 +358,9 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$

Update exponential moving averages

-
189            if self.training and self.track_running_stats:
-190                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
-191                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
+
192            if self.training and self.track_running_stats:
+193                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
+194                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
@@ -368,9 +371,9 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$

Use exponential moving averages as estimates

-
193        else:
-194            mean = self.exp_mean
-195            var = self.exp_var
+
196        else:
+197            mean = self.exp_mean
+198            var = self.exp_var
@@ -382,7 +385,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$

-
198        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
+
201        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
@@ -394,8 +397,8 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$

-
200        if self.affine:
-201            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
+
203        if self.affine:
+204            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
@@ -406,31 +409,49 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$

Reshape to original and return

-
204        return x_norm.view(x_shape)
+
207        return x_norm.view(x_shape)
-
+
+

Simple test

+
+
+
210def _test():
+
+
+
+
+ + +
+
+
214    from labml.logger import inspect
+215
+216    x = torch.zeros([2, 3, 2, 4])
+217    inspect(x.shape)
+218    bn = BatchNorm(3)
+219
+220    x = bn(x)
+221    inspect(x.shape)
+222    inspect(bn.exp_var.shape)
+
+
+
+
+
-
207def _test():
-208    from labml.logger import inspect
-209
-210    x = torch.zeros([2, 3, 2, 4])
-211    inspect(x.shape)
-212    bn = BatchNorm(3)
-213
-214    x = bn(x)
-215    inspect(x.shape)
-216    inspect(bn.exp_var.shape)
-217
-218
-219if __name__ == '__main__':
-220    _test()
+
226if __name__ == '__main__':
+227    _test()
diff --git a/docs/normalization/batch_norm/readme.html b/docs/normalization/batch_norm/readme.html new file mode 100644 index 0000000000000000000000000000000000000000..de16cd155c0ad36c8515292fdb6ff37e8df45b92 --- /dev/null +++ b/docs/normalization/batch_norm/readme.html @@ -0,0 +1,180 @@ + + + + + + + + + + + + + + + + + + + + + + + Batch Normalization + + + + + + + + +
+
+
+
+


+


+
+
+
+
+ +

Batch Normalization

+

This is a PyTorch implementation of Batch Normalization from the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

+

Internal Covariate Shift

+

The paper defines Internal Covariate Shift as the change in the distribution of network activations due to the change in network parameters during training. For example, let’s say there are two layers $l_1$ and $l_2$. At the beginning of training, the outputs of $l_1$ (the inputs to $l_2$) could follow the distribution $\mathcal{N}(0.5, 1)$. Then, after some training steps, they could move to $\mathcal{N}(0.6, 1.5)$. This is internal covariate shift.

+

Internal covariate shift will adversely affect training speed because the later layers ($l_2$ in the above example) have to adapt to this shifted distribution.

+

By stabilizing the distribution, batch normalization minimizes the internal covariate shift.

+

Normalization

+

It is known that whitening improves training speed and convergence. Whitening is linearly transforming inputs to have zero mean, unit variance, and be uncorrelated.
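As a concrete illustration (not from the paper or this repository), here is a minimal ZCA-style whitening sketch; the data shape, the random mixing matrix, and the eigenvalue clamp are arbitrary choices for the example:

```python
# Minimal ZCA whitening sketch: center the batch, then rotate/scale it with the
# inverse square root of the feature covariance so the features end up
# uncorrelated with unit variance.
import torch

x = torch.randn(512, 16) @ torch.randn(16, 16)         # [batch, features], correlated features

x_centered = x - x.mean(dim=0)                          # zero mean per feature
cov = (x_centered.T @ x_centered) / (x.shape[0] - 1)    # feature covariance matrix
eigvals, eigvecs = torch.linalg.eigh(cov)               # eigen-decomposition of a symmetric matrix
w = eigvecs @ torch.diag(eigvals.clamp(min=1e-5).rsqrt()) @ eigvecs.T
x_white = x_centered @ w                                # whitened batch

cov_white = (x_white.T @ x_white) / (x.shape[0] - 1)
print(torch.allclose(cov_white, torch.eye(16), atol=1e-3))  # covariance is now ≈ identity
```

The eigen-decomposition, and backpropagating through it during training, is what makes full whitening costly; this motivates the simplifications the paper introduces below.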

+

Normalizing outside gradient computation doesn’t work

+

Normalizing outside the gradient computation using pre-computed (detached) means and variances doesn’t work. For instance (ignoring the variance), let $$\hat{x} = x - \mathbb{E}[x]$$ where $x = u + b$, $b$ is a trained bias, and $\mathbb{E}[x]$ is outside the gradient computation (a pre-computed constant).

+

Note that $\hat{x}$ is not affected by $b$. Therefore, $b$ will increase or decrease based on $\frac{\partial{\mathcal{L}}}{\partial x}$ and keep on growing indefinitely with each training update. The paper notes that similar explosions happen with variances.
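A toy sketch of this failure mode (purely illustrative; the loss, shapes, and learning rate are made up for the example):

```python
# With the mean detached from the graph, the normalized output never changes,
# but the gradient w.r.t. the bias does not vanish, so b drifts every update.
import torch

u = torch.randn(256)                       # fixed layer output
b = torch.zeros(1, requires_grad=True)     # trained bias
opt = torch.optim.SGD([b], lr=0.01)

for step in range(3):
    x = u + b
    x_hat = x - x.mean().detach()          # E[x] treated as a pre-computed constant
    loss = x_hat.sum()                     # any loss that depends only on x_hat
    opt.zero_grad()
    loss.backward()
    opt.step()
    # b keeps moving even though x_hat (and hence the loss) is unchanged
    print(f'step {step}: b = {b.item():+.3f}, loss = {loss.item():.6f}')
```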

+

Batch Normalization

+

Whitening is computationally expensive because you need to de-correlate the inputs, and the gradients must flow through the full whitening calculation.

+

The paper introduces a simplified version, which they call Batch Normalization. The first simplification is to normalize each feature independently to have zero mean and unit variance: $$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$ where $x = (x^{(1)} \dots x^{(d)})$ is the $d$-dimensional input.

+

The second simplification is to use estimates of the mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$ from the mini-batch for normalization, instead of calculating them across the whole dataset.

+

Normalizing each feature to zero mean and unit variance could limit what the layer can represent. As an example, the paper illustrates that if the inputs to a sigmoid are normalized, most of them will fall within the $[-1, 1]$ range, where the sigmoid is roughly linear. To overcome this, each feature is scaled and shifted by two trained parameters $\gamma^{(k)}$ and $\beta^{(k)}$: $$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$ where $y^{(k)}$ is the output of the batch normalization layer.
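A minimal sketch of these two steps on a single mini-batch (the shapes and $\epsilon$ value are illustrative, not from this repository):

```python
# Normalize each feature over the mini-batch, then scale and shift it with the
# learned per-feature parameters gamma and beta.
import torch

x = torch.randn(32, 8)                      # [batch_size, features]
gamma = torch.ones(8, requires_grad=True)   # gamma^(k), learned scale
beta = torch.zeros(8, requires_grad=True)   # beta^(k), learned shift
eps = 1e-5

mean = x.mean(dim=0)                        # E[x^(k)] over the mini-batch
var = x.var(dim=0, unbiased=False)          # Var[x^(k)] over the mini-batch
x_hat = (x - mean) / torch.sqrt(var + eps)  # zero mean, unit variance per feature
y = gamma * x_hat + beta                    # y^(k) = gamma^(k) x_hat^(k) + beta^(k)
```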

+

Note that when applying batch normalization after a linear transform like $Wu + b$, the bias parameter $b$ gets cancelled by the normalization. So you can and should omit the bias parameter in linear transforms right before batch normalization.
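For example, a hypothetical block (not taken from this repository) would disable the bias of the linear layer feeding into batch normalization:

```python
import torch.nn as nn

block = nn.Sequential(
    nn.Linear(256, 128, bias=False),  # any bias here would be cancelled by the mean subtraction
    nn.BatchNorm1d(128),              # beta plays the role of the bias instead
    nn.ReLU(),
)
```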

+

Batch normalization also makes backpropagation invariant to the scale of the weights. Empirically, it improves generalization, so it has regularization effects too.

+

Inference

+

We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to perform the normalization. So during inference, you either need to go through the whole (or a part of the) dataset and find the mean and variance, or you can use estimates calculated during training. The usual practice is to calculate an exponential moving average of the mean and variance during the training phase and use that for inference.
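A condensed sketch of that practice (a stripped-down, function-level version of the module above; the momentum value and shapes are illustrative):

```python
# Track exponential moving averages of the batch statistics during training
# and fall back to them at inference time.
import torch

momentum, eps = 0.1, 1e-5
exp_mean, exp_var = torch.zeros(8), torch.ones(8)    # running estimates per feature

def batch_norm(x: torch.Tensor, training: bool) -> torch.Tensor:
    global exp_mean, exp_var
    if training:
        mean = x.mean(dim=0)
        var = x.var(dim=0, unbiased=False)
        # update the running estimates
        exp_mean = (1 - momentum) * exp_mean + momentum * mean
        exp_var = (1 - momentum) * exp_var + momentum * var
    else:
        mean, var = exp_mean, exp_var                # use the running estimates
    return (x - mean) / torch.sqrt(var + eps)

batch_norm(torch.randn(32, 8), training=True)        # a training step updates the averages
batch_norm(torch.randn(1, 8), training=False)        # inference works even with batch size 1
```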

+

Here’s the training code and a notebook for training a CNN classifier that uses batch normalization for the MNIST dataset.

+

Open In Colab · View Run

+
+
+ +
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/normalization/index.html b/docs/normalization/index.html index a528757d731cc0e931df163002ce1cb7d29c6394..6f4b3143294ce68d669abcd412c26e6236c68618 100644 --- a/docs/normalization/index.html +++ b/docs/normalization/index.html @@ -74,10 +74,10 @@

Normalization Layers

TODO

diff --git a/docs/normalization/layer_norm/index.html b/docs/normalization/layer_norm/index.html index 49c0b303186e09cc0b86f73776eed61f3bd29ab9..b174a6f90b5ade6d84a406f399a937882e9f88e5 100644 --- a/docs/normalization/layer_norm/index.html +++ b/docs/normalization/layer_norm/index.html @@ -88,7 +88,7 @@ large NLP models are usually trained with small batch sizes. on a wider range of settings. Layer normalization transforms the inputs to have zero mean and unit variance across the features. -*Note that batch normalization, fixes the zero mean and unit variance for each vector. +Note that batch normalization fixes the zero mean and unit variance for each element. Layer normalization does it for each batch across all elements.

Layer normalization is generally used for NLP tasks.

We have used layer normalization in most of the transformer implementations. @@ -109,6 +109,29 @@ Layer normalization does it for each batch across all elements.

#

Layer Normalization

+

Layer normalization $\text{LN}$ normalizes the input $X$ as follows:

+

When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, where $B$ is the batch size and $C$ is the number of features. $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}} + \beta$$

When input $X \in \mathbb{R}^{L \times B \times C}$ is a batch of sequences of embeddings, where $B$ is the batch size, $C$ is the number of channels, and $L$ is the length of the sequence. $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}} + \beta$$

When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. This is not a widely used scenario. $\gamma \in \mathbb{R}^{C \times H \times W}$ and $\beta \in \mathbb{R}^{C \times H \times W}$.

$$\text{LN}(X) = \gamma \frac{X - \underset{C, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{C, H, W}{Var}[X] + \epsilon}} + \beta$$
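For reference, the built-in `torch.nn.LayerNorm` handles these cases the same way, normalizing over the trailing `normalized_shape` dimensions; the shapes below are only illustrative:

```python
import torch
import torch.nn as nn

seq = torch.randn(10, 32, 512)      # [L, B, C]: a batch of sequences of embeddings
ln_seq = nn.LayerNorm(512)          # gamma and beta have shape [C]
print(ln_seq(seq).shape)            # torch.Size([10, 32, 512])

img = torch.randn(32, 3, 28, 28)    # [B, C, H, W]: a batch of image representations
ln_img = nn.LayerNorm([3, 28, 28])  # gamma and beta have shape [C, H, W]
print(ln_img(img).shape)            # torch.Size([32, 3, 28, 28])
```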

43class LayerNorm(Module):
@@ -120,18 +143,18 @@ Layer normalization does it for each batch across all elements.

#

We’ve tried to use the same names for arguments as the PyTorch LayerNorm implementation.

-
48    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
-49                 eps: float = 1e-5,
-50                 elementwise_affine: bool = True):
+
72    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
+73                 eps: float = 1e-5,
+74                 elementwise_affine: bool = True):
@@ -142,11 +165,11 @@ Layer normalization does it for each batch across all elements.

-
60        super().__init__()
-61
-62        self.normalized_shape = normalized_shape
-63        self.eps = eps
-64        self.elementwise_affine = elementwise_affine
+
84        super().__init__()
+85
+86        self.normalized_shape = normalized_shape
+87        self.eps = eps
+88        self.elementwise_affine = elementwise_affine
@@ -157,9 +180,9 @@ Layer normalization does it for each batch across all elements.

Create parameters for $\gamma$ and $\beta$ for gain and bias

-
66        if self.elementwise_affine:
-67            self.gain = nn.Parameter(torch.ones(normalized_shape))
-68            self.bias = nn.Parameter(torch.zeros(normalized_shape))
+
90        if self.elementwise_affine:
+91            self.gain = nn.Parameter(torch.ones(normalized_shape))
+92            self.bias = nn.Parameter(torch.zeros(normalized_shape))
@@ -173,7 +196,7 @@ Layer normalization does it for each batch across all elements.

[seq_len, batch_size, features]

-
70    def forward(self, x: torch.Tensor):
+
94    def forward(self, x: torch.Tensor):
@@ -181,10 +204,10 @@ Layer normalization does it for each batch across all elements.

-

Keep the original shape

+

Sanity check to make sure the shapes match

-
78        x_shape = x.shape
+
102        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
@@ -192,10 +215,10 @@ Layer normalization does it for each batch across all elements.

-

Sanity check to make sure the shapes match

+

The dimensions to calculate the mean and variance on

-
80        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
+
105        dims = [-(i + 1) for i in range(len(self.normalized_shape))]
@@ -203,10 +226,11 @@ Layer normalization does it for each batch across all elements.

-

Reshape into [M, S[0], S[1], ..., S[n]]

+

Calculate the mean of all elements; i.e. the means for each element $\mathbb{E}[X]$

-
83        x = x.view(-1, *self.normalized_shape)
+
109        mean = x.mean(dim=dims, keepdims=True)
@@ -214,11 +238,11 @@ Layer normalization does it for each batch across all elements.

-

Calculate the mean across first dimension; -i.e. the means for each element $\mathbb{E}[X}]$

+

Calculate the squared mean of all elements; i.e. the means for each element $\mathbb{E}[X^2]$

-
87        mean = x.mean(dim=0)
+
112        mean_x2 = (x ** 2).mean(dim=dims, keepdims=True)
@@ -226,11 +250,10 @@ i.e. the means for each element $\mathbb{E}[X}]$

-

Calculate the squared mean across first dimension; -i.e. the means for each element $\mathbb{E}[X^2]$

+

Variance of all elements $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$

-
90        mean_x2 = (x ** 2).mean(dim=0)
+
114        var = mean_x2 - mean ** 2
@@ -238,10 +261,11 @@ i.e. the means for each element $\mathbb{E}[X^2]$

-

Variance for each element $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$

+

Normalize $$\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{Var[X] + \epsilon}}$$

-
92        var = mean_x2 - mean ** 2
+
117        x_norm = (x - mean) / torch.sqrt(var + self.eps)
@@ -249,11 +273,12 @@ i.e. the means for each element $\mathbb{E}[X^2]$

-

Normalize +

Scale and shift

-
95        x_norm = (x - mean) / torch.sqrt(var + self.eps)
+
119        if self.elementwise_affine:
+120            x_norm = self.gain * x_norm + self.bias
@@ -261,23 +286,21 @@ i.e. the means for each element $\mathbb{E}[X^2]$

-

Scale and shift -

+
-
97        if self.elementwise_affine:
-98            x_norm = self.gain * x_norm + self.bias
+
123        return x_norm
-
+
-

Reshape to original and return

+

Simple test

-
101        return x_norm.view(x_shape)
+
126def _test():
@@ -288,20 +311,27 @@ i.e. the means for each element $\mathbb{E}[X^2]$

-
104def _test():
-105    from labml.logger import inspect
-106
-107    x = torch.zeros([2, 3, 2, 4])
-108    inspect(x.shape)
-109    ln = LayerNorm(x.shape[2:])
-110
-111    x = ln(x)
-112    inspect(x.shape)
-113    inspect(ln.gain.shape)
-114
-115
-116if __name__ == '__main__':
-117    _test()
+
130    from labml.logger import inspect
+131
+132    x = torch.zeros([2, 3, 2, 4])
+133    inspect(x.shape)
+134    ln = LayerNorm(x.shape[2:])
+135
+136    x = ln(x)
+137    inspect(x.shape)
+138    inspect(ln.gain.shape)
+
+
+
+
+ + +
+
+
142if __name__ == '__main__':
+143    _test()
diff --git a/docs/normalization/layer_norm/readme.html b/docs/normalization/layer_norm/readme.html new file mode 100644 index 0000000000000000000000000000000000000000..d279c77f20d61c3da3e47faa50c192950f287b8e --- /dev/null +++ b/docs/normalization/layer_norm/readme.html @@ -0,0 +1,134 @@ + + + + + + + + + + + + + + + + + + + + + + + Layer Normalization + + + + + + + + +
+
+
+
+


+


+
+
+
+
+ +

Layer Normalization

+

This is a PyTorch implementation of Layer Normalization.

+

Limitations of Batch Normalization

+
  • You need to maintain running means.
  • Tricky for RNNs. Do you need different normalizations for each step?
  • Doesn’t work with small batch sizes; large NLP models are usually trained with small batch sizes.
  • Need to compute means and variances across devices in distributed training.
+

Layer Normalization

+

Layer normalization is a simpler normalization method that works on a wider range of settings. Layer normalization transforms the inputs to have zero mean and unit variance across the features. Note that batch normalization fixes the zero mean and unit variance for each feature across the batch, while layer normalization does it for each sample across all the features.
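A quick sketch of the difference, using the stock PyTorch layers on an illustrative shape:

```python
# Batch norm computes statistics over the batch dimension for each feature;
# layer norm computes them over the feature dimension for each sample.
import torch
import torch.nn as nn

x = torch.randn(32, 8)        # [batch, features]

bn = nn.BatchNorm1d(8)        # statistics over dim 0 (the batch)
ln = nn.LayerNorm(8)          # statistics over dim 1 (the features)

print(bn(x).mean(dim=0))      # ≈ 0 for every feature
print(ln(x).mean(dim=1))      # ≈ 0 for every sample
```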

+

Layer normalization is generally used for NLP tasks.

+

We have used layer normalization in most of the transformer implementations.

+
+
+ +
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/sitemap.xml b/docs/sitemap.xml index 048381dab00e43ab9d3e2337c154c4c14b9172ab..56e8e87cce17ccbda65d304d2918b74fc5fb0466 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -50,7 +50,7 @@ https://nn.labml.ai/activations/swish.html - 2021-01-25T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 @@ -83,6 +83,13 @@ + + https://nn.labml.ai/normalization/layer_norm/index.html + 2021-02-02T16:30:00+00:00 + 1.00 + + + https://nn.labml.ai/normalization/index.html 2021-02-01T16:30:00+00:00 @@ -183,7 +190,7 @@ https://nn.labml.ai/optimizers/mnist_experiment.html - 2020-12-10T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 @@ -225,7 +232,7 @@ https://nn.labml.ai/transformers/knn/train_model.html - 2021-01-25T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 @@ -253,7 +260,7 @@ https://nn.labml.ai/transformers/models.html - 2021-02-01T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 @@ -267,14 +274,14 @@ https://nn.labml.ai/transformers/gpt/index.html - 2021-02-01T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 https://nn.labml.ai/transformers/feed_forward.html - 2021-01-30T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 @@ -295,7 +302,7 @@ https://nn.labml.ai/transformers/feedback/index.html - 2021-02-01T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 @@ -309,7 +316,7 @@ https://nn.labml.ai/transformers/feedback/experiment.html - 2021-01-29T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 @@ -330,14 +337,14 @@ https://nn.labml.ai/transformers/glu_variants/experiment.html - 2021-01-26T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 https://nn.labml.ai/transformers/glu_variants/simple.html - 2021-01-26T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 @@ -358,7 +365,7 @@ https://nn.labml.ai/transformers/switch/index.html - 2021-02-01T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 @@ -372,28 +379,28 @@ https://nn.labml.ai/transformers/switch/experiment.html - 2021-01-25T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 https://nn.labml.ai/transformers/positional_encoding.html - 2021-01-07T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 https://nn.labml.ai/transformers/label_smoothing_loss.html - 2020-12-10T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 https://nn.labml.ai/transformers/mha.html - 2021-02-01T16:30:00+00:00 + 2021-02-02T16:30:00+00:00 1.00 diff --git a/labml_nn/__init__.py b/labml_nn/__init__.py index c6e44cfd1077eb2338f129678c091357cd589c66..f46c8fc08371304e445f4eda0ce602369cdab6c6 100644 --- a/labml_nn/__init__.py +++ b/labml_nn/__init__.py @@ -60,6 +60,7 @@ and #### ✨ [Normalization Layers](https://nn.labml.ai/normalization/index.html) * [Batch Normalization](https://nn.labml.ai/normalization/batch_norm/index.html) +* [Layer Normalization](https://nn.labml.ai/normalization/layer_norm/index.html) ### Installation diff --git a/labml_nn/normalization/__init__.py b/labml_nn/normalization/__init__.py index 986d9c754d29aa8673481929e85c302226409040..ac254aacd92712774823fa01f2c92ec00908526e 100644 --- a/labml_nn/normalization/__init__.py +++ b/labml_nn/normalization/__init__.py @@ -8,10 +8,10 @@ summary: > # Normalization Layers * [Batch Normalization](batch_norm/index.html) +* [Layer Normalization](layer_norm/index.html) *TODO* -* Layer Normalization * Instance Normalization * Group Normalization """ \ No newline at end of file diff --git a/labml_nn/normalization/batch_norm/__init__.py b/labml_nn/normalization/batch_norm/__init__.py index f64a1475659a8055d74e2d18acd610e790afd224..0e914a0c9c4589e65f9fbdf0ce2aa52e87899646 100644 --- 
a/labml_nn/normalization/batch_norm/__init__.py +++ b/labml_nn/normalization/batch_norm/__init__.py @@ -109,18 +109,21 @@ class BatchNorm(Module): When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. + $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$. $$\text{BN}(X) = \gamma \frac{X - \underset{B, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{B, H, W}{Var}[X] + \epsilon}} + \beta$$ - When input $X \in \mathbb{R}^{B \times C}$ is a batch of vector embeddings, + When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, where $B$ is the batch size and $C$ is the number of features. + $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$. $$\text{BN}(X) = \gamma \frac{X - \underset{B}{\mathbb{E}}[X]}{\sqrt{\underset{B}{Var}[X] + \epsilon}} + \beta$$ - When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of sequence embeddings, + When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of a sequence embeddings, where $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence. + $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$. $$\text{BN}(X) = \gamma \frac{X - \underset{B, L}{\mathbb{E}}[X]}{\sqrt{\underset{B, L}{Var}[X] + \epsilon}} + \beta$$ @@ -205,6 +208,9 @@ class BatchNorm(Module): def _test(): + """ + Simple test + """ from labml.logger import inspect x = torch.zeros([2, 3, 2, 4]) @@ -216,5 +222,6 @@ def _test(): inspect(bn.exp_var.shape) +# if __name__ == '__main__': _test() diff --git a/labml_nn/normalization/batch_norm/readme.md b/labml_nn/normalization/batch_norm/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..573c079b0003138c1f0226adf0fcbfbc88cf4e5d --- /dev/null +++ b/labml_nn/normalization/batch_norm/readme.md @@ -0,0 +1,88 @@ +# [Batch Normalization](https://nn.labml.ai/normalization/batch_norm/index.html) + +This is a [PyTorch](https://pytorch.org) implementation of Batch Normalization from paper + [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167). + +### Internal Covariate Shift + +The paper defines *Internal Covariate Shift* as the change in the +distribution of network activations due to the change in +network parameters during training. +For example, let's say there are two layers $l_1$ and $l_2$. +During the beginning of the training $l_1$ outputs (inputs to $l_2$) +could be in distribution $\mathcal{N}(0.5, 1)$. +Then, after some training steps, it could move to $\mathcal{N}(0.5, 1)$. +This is *internal covariate shift*. + +Internal covariate shift will adversely affect training speed because the later layers +($l_2$ in the above example) has to adapt to this shifted distribution. + +By stabilizing the distribution batch normalization minimizes the internal covariate shift. + +## Normalization + +It is known that whitening improves training speed and convergence. +*Whitening* is linearly transforming inputs to have zero mean, unit variance, +and be uncorrelated. + +### Normalizing outside gradient computation doesn't work + +Normalizing outside the gradient computation using pre-computed (detached) +means and variances doesn't work. For instance. (ignoring variance), let +$$\hat{x} = x - \mathbb{E}[x]$$ +where $x = u + b$ and $b$ is a trained bias. +and $\mathbb{E}[x]$ is outside gradient computation (pre-computed constant). 
+ +Note that $\hat{x}$ has no effect of $b$. +Therefore, +$b$ will increase or decrease based +$\frac{\partial{\mathcal{L}}}{\partial x}$, +and keep on growing indefinitely in each training update. +The paper notes that similar explosions happen with variances. + +### Batch Normalization + +Whitening is computationally expensive because you need to de-correlate and +the gradients must flow through the full whitening calculation. + +The paper introduces simplified version which they call *Batch Normalization*. +First simplification is that it normalizes each feature independently to have +zero mean and unit variance: +$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$ +where $x = (x^{(1)} ... x^{(d)})$ is the $d$-dimensional input. + +The second simplification is to use estimates of mean $\mathbb{E}[x^{(k)}]$ +and variance $Var[x^{(k)}]$ from the mini-batch +for normalization; instead of calculating the mean and variance across whole dataset. + +Normalizing each feature to zero mean and unit variance could affect what the layer +can represent. +As an example paper illustrates that, if the inputs to a sigmoid are normalized +most of it will be within $[-1, 1]$ range where the sigmoid is linear. +To overcome this each feature is scaled and shifted by two trained parameters +$\gamma^{(k)}$ and $\beta^{(k)}$. +$$y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$ +where $y^{(k)}$ is the output of the batch normalization layer. + +Note that when applying batch normalization after a linear transform +like $Wu + b$ the bias parameter $b$ gets cancelled due to normalization. +So you can and should omit bias parameter in linear transforms right before the +batch normalization. + +Batch normalization also makes the back propagation invariant to the scale of the weights. +And empirically it improves generalization, so it has regularization effects too. + +## Inference + +We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to +perform the normalization. +So during inference, you either need to go through the whole (or part of) dataset +and find the mean and variance, or you can use an estimate calculated during training. +The usual practice is to calculate an exponential moving average of +mean and variance during the training phase and use that for inference. + +Here's [the training code](https://nn.labml.ai/normalization/layer_norm/mnist.html) and a notebook for training +a CNN classifier that use batch normalization for MNIST dataset. + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/batch_norm/mnist.ipynb) +[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=011254fe647011ebbb8e0242ac1c0002) diff --git a/labml_nn/normalization/layer_norm/__init__.py b/labml_nn/normalization/layer_norm/__init__.py index b711e81add018854fca88a131192dafbe47f69b0..6d913197ec125c7a139ebf9f5888be11349cc266 100644 --- a/labml_nn/normalization/layer_norm/__init__.py +++ b/labml_nn/normalization/layer_norm/__init__.py @@ -24,7 +24,7 @@ Layer normalization is a simpler normalization method that works on a wider range of settings. Layer normalization transformers the inputs to have zero mean and unit variance across the features. -*Note that batch normalization, fixes the zero mean and unit variance for each vector. 
+*Note that batch normalization fixes the zero mean and unit variance for each element.* Layer normalization does it for each batch across all elements. Layer normalization is generally used for NLP tasks. @@ -41,18 +41,42 @@ from labml_helpers.module import Module class LayerNorm(Module): - """ + r""" ## Layer Normalization + + Layer normalization $\text{LN}$ normalizes the input $X$ as follows: + + When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, + where $B$ is the batch size and $C$ is the number of features. + $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$. + $$\text{LN}(X) = \gamma + \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}} + + \beta$$ + + When input $X \in \mathbb{R}^{L \times B \times C}$ is a batch of a sequence of embeddings, + where $B$ is the batch size, $C$ is the number of channels, $L$ is the length of the sequence. + $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$. + $$\text{LN}(X) = \gamma + \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}} + + \beta$$ + + When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, + where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. + This is not a widely used scenario. + $\gamma \in \mathbb{R}^{C \times H \times W}$ and $\beta \in \mathbb{R}^{C \times H \times W}$. + $$\text{LN}(X) = \gamma + \frac{X - \underset{C, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{C, H, W}{Var}[X] + \epsilon}} + + \beta$$ """ def __init__(self, normalized_shape: Union[int, List[int], Size], *, eps: float = 1e-5, elementwise_affine: bool = True): """ - * `normalized_shape` $S$ is shape of the elements (except the batch). + * `normalized_shape` $S$ is the shape of the elements (except the batch). The input should then be $X \in \mathbb{R}^{* \times S[0] \times S[1] \times ... \times S[n]}$ - * `eps` is $\epsilon$, used in $\sqrt{Var[X}] + \epsilon}$ for numerical stability + * `eps` is $\epsilon$, used in $\sqrt{Var[X] + \epsilon}$ for numerical stability * `elementwise_affine` is whether to scale and shift the normalized value We've tried to use the same names for arguments as PyTorch `LayerNorm` implementation. @@ -74,34 +98,35 @@ class LayerNorm(Module): For example, in an NLP task this will be `[seq_len, batch_size, features]` """ - # Keep the original shape - x_shape = x.shape # Sanity check to make sure the shapes match assert self.normalized_shape == x.shape[-len(self.normalized_shape):] - # Reshape into `[M, S[0], S[1], ..., S[n]]` - x = x.view(-1, *self.normalized_shape) + # The dimensions to calculate the mean and variance on + dims = [-(i + 1) for i in range(len(self.normalized_shape))] - # Calculate the mean across first dimension; - # i.e. the means for each element $\mathbb{E}[X}]$ - mean = x.mean(dim=0) - # Calculate the squared mean across first dimension; + # Calculate the mean of all elements; + # i.e. the means for each element $\mathbb{E}[X]$ + mean = x.mean(dim=dims, keepdims=True) + # Calculate the squared mean of all elements; # i.e. 
the means for each element $\mathbb{E}[X^2]$ - mean_x2 = (x ** 2).mean(dim=0) - # Variance for each element $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$ + mean_x2 = (x ** 2).mean(dim=dims, keepdims=True) + # Variance of all element $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$ var = mean_x2 - mean ** 2 - # Normalize $$\hat{X} = \frac{X} - \mathbb{E}[X]}{\sqrt{Var[X] + \epsilon}}$$ + # Normalize $$\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{Var[X] + \epsilon}}$$ x_norm = (x - mean) / torch.sqrt(var + self.eps) # Scale and shift $$\text{LN}(x) = \gamma \hat{X} + \beta$$ if self.elementwise_affine: x_norm = self.gain * x_norm + self.bias - # Reshape to original and return - return x_norm.view(x_shape) + # + return x_norm def _test(): + """ + Simple test + """ from labml.logger import inspect x = torch.zeros([2, 3, 2, 4]) @@ -113,5 +138,6 @@ def _test(): inspect(ln.gain.shape) +# if __name__ == '__main__': _test() diff --git a/labml_nn/normalization/layer_norm/readme.md b/labml_nn/normalization/layer_norm/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..2be7ecd58a30ed73100030f81a6b6064c6a3d924 --- /dev/null +++ b/labml_nn/normalization/layer_norm/readme.md @@ -0,0 +1,26 @@ +# [Layer Normalization](https://nn.labml.ai/normalization/layer_norm/index.html) + +This is a [PyTorch](https://pytorch.org) implementation of +[Layer Normalization](https://arxiv.org/abs/1607.06450). + +### Limitations of [Batch Normalization](https://nn.labml.ai/normalization/batch_norm/index.html) + +* You need to maintain running means. +* Tricky for RNNs. Do you need different normalizations for each step? +* Doesn't work with small batch sizes; +large NLP models are usually trained with small batch sizes. +* Need to compute means and variances across devices in distributed training + +## Layer Normalization + +Layer normalization is a simpler normalization method that works +on a wider range of settings. +Layer normalization transformers the inputs to have zero mean and unit variance +across the features. +*Note that batch normalization fixes the zero mean and unit variance for each element.* +Layer normalization does it for each batch across all elements. + +Layer normalization is generally used for NLP tasks. + +We have used layer normalization in most of the +[transformer implementations](https://nn.labml.ai/transformers/gpt/index.html). \ No newline at end of file diff --git a/readme.md b/readme.md index 6fa6ad7d7ebf4673f0369c4341abf27e1091ba12..7bd188a401f719a3c9e97251ab168dfeacc74a7f 100644 --- a/readme.md +++ b/readme.md @@ -66,6 +66,7 @@ and #### ✨ [Normalization Layers](https://nn.labml.ai/normalization/index.html) * [Batch Normalization](https://nn.labml.ai/normalization/batch_norm/index.html) +* [Layer Normalization](https://nn.labml.ai/normalization/layer_norm/index.html) ### Installation