Greenplum / Annotated Deep Learning Paper Implementations

Commit 774a72e0
Authored Jan 13, 2021 by Varuna Jayasiri
Parent: a11e3eff

    gpt

10 changed files with 324 additions and 20 deletions (+324 −20)
Changed files:

  labml_nn/experiments/nlp_autoregression.py        +5   −1
  labml_nn/optimizers/__init__.py                   +2   −0
  labml_nn/optimizers/adam.py                       +6   −2
  labml_nn/optimizers/adam_warmup.py                +2   −2
  labml_nn/optimizers/adam_warmup_cosine_decay.py   +95  −0
  labml_nn/optimizers/configs.py                    +10  −0
  labml_nn/optimizers/radam.py                      +6   −3
  labml_nn/transformers/configs.py                  +18  −7
  labml_nn/transformers/gpt/__init__.py             +167 −0
  labml_nn/transformers/models.py                   +13  −5
labml_nn/experiments/nlp_autoregression.py  (+5 −1)

@@ -47,6 +47,10 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
     loss_func = CrossEntropyLoss()
     accuracy = Accuracy()
     d_model: int = 512
+    grad_norm_clip: float = 1.0
+
+    train_loader: DataLoader = 'shuffled_train_loader'
+    valid_loader: DataLoader = 'shuffled_valid_loader'
 
     def init(self):
         tracker.set_scalar("accuracy.*", True)

@@ -70,7 +74,7 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         if self.mode.is_train:
             loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
             self.optimizer.step()
             if batch_idx.is_last:
                 tracker.add('model', self.model)
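The new `grad_norm_clip` setting replaces the hard-coded `max_norm=1.`. As a reminder of what this call does, here is a minimal sketch in plain PyTorch; the model and data are hypothetical, only the `clip_grad_norm_` call mirrors the training step above:

```python
import torch
import torch.nn as nn

# A tiny model with gradients to clip
model = nn.Linear(8, 2)
loss = model(torch.randn(4, 8)).sum()
loss.backward()

# Rescales all gradients in place so their combined L2 norm is at most
# `max_norm` (1.0 here, matching the new `grad_norm_clip` default)
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(total_norm))  # the global gradient norm before clipping
```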
labml_nn/optimizers/__init__.py  (+2 −0)

@@ -209,3 +209,5 @@ class WeightDecay:
         if group['weight_decay'] != 0:
             # Add the weight decay to the gradient and return the modified gradient
             return grad.add(param.data, alpha=group['weight_decay'])
+        else:
+            return grad
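The new `else` branch makes the no-decay path explicit. For reference, a minimal sketch of the L2-style decay this method applies when weight decay is not decoupled; the tensors and the `weight_decay` value are illustrative:

```python
import torch

# Folding L2 regularization into the gradient: g <- g + wd * theta
param = torch.randn(3)
grad = torch.randn(3)
weight_decay = 0.1

grad_with_decay = grad.add(param, alpha=weight_decay)  # g + wd * theta
assert torch.allclose(grad_with_decay, grad + weight_decay * param)
```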
labml_nn/optimizers/adam.py  (+6 −2)

@@ -41,6 +41,7 @@ import math
 import math
 from typing import Dict, Any, Tuple, Optional
 
 import torch
+from labml import tracker
 from torch import nn
 from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay

@@ -168,12 +169,15 @@ class Adam(GenericAdaptiveOptimizer):
         # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
         bias_correction2 = 1 - beta2 ** state['step']
 
+        # Get learning rate
+        lr = self.get_lr(state, group)
+
         # Whether to optimize the computation
         if self.optimized_update:
             # $\sqrt{v_t} + \hat{\epsilon}$
             denominator = v.sqrt().add_(group['eps'])
             # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
-            step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) / bias_correction1
+            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
             # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
             #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
             param.data.addcdiv_(m, denominator, value=-step_size)

@@ -182,7 +186,7 @@ class Adam(GenericAdaptiveOptimizer):
             # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
             denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
             # $\frac{\alpha}{1-\beta_1^t}$
-            step_size = self.get_lr(state, group) / bias_correction1
+            step_size = lr / bias_correction1
             # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot
             #  \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
             param.data.addcdiv_(m, denominator, value=-step_size)
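This refactor only hoists the learning-rate lookup into `lr`; the two update branches remain the same computation. For reference, the identity the `optimized_update` branch relies on, which holds exactly except for how $\epsilon$ is treated (the code uses a fixed $\hat{\epsilon}$ rather than rescaling it each step):

```latex
% With \hat{m}_t = m_t / (1 - \beta_1^t) and \hat{v}_t = v_t / (1 - \beta_2^t):
\theta_t
  = \theta_{t-1} - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
  = \theta_{t-1}
    - \underbrace{\frac{\alpha \sqrt{1 - \beta_2^t}}{1 - \beta_1^t}}_{\texttt{step\_size}}
      \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}},
\qquad \hat{\epsilon} = \epsilon \sqrt{1 - \beta_2^t}.
```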
labml_nn/optimizers/adam_warmup.py  (+2 −2)

@@ -6,7 +6,7 @@ summary: A simple PyTorch implementation/tutorial of Adam optimizer with warm-up
 # Adam Optimizer with Warmup
 
-This extends [AMSGrad optimizer](adam.html) and adds a warmup stage.
+This extends [AMSGrad optimizer](amsgrad.html) and adds a warmup stage.
 """
 
 from typing import Dict

@@ -19,7 +19,7 @@ class AdamWarmup(AMSGrad):
     """
     ## Adam Optimizer with Warmup
 
-    This class extends from Adam optimizer defined in [`adam.py`](adam.html).
+    This class extends from AMSGrad optimizer defined in [`amsgrad.py`](amsgrad.html).
     """
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                  weight_decay: WeightDecay = WeightDecay(),
labml_nn/optimizers/adam_warmup_cosine_decay.py  (new file, +95 −0)

"""
---
title: Adam optimizer with warm-up and cosine decay
summary: A PyTorch implementation/tutorial of Adam optimizer with warm-up and cosine decay for GPT.
---

# Adam Optimizer with Warmup and Cosine Decay

This extends [AMSGrad optimizer](amsgrad.html) and adds a warmup stage and cosine decay.
"""
import math
from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad


class AdamWarmupCosineDecay(AMSGrad):
    """
    ## Adam Optimizer with Warmup and Cosine Decay

    This class extends from AMSGrad optimizer defined in [`amsgrad.py`](amsgrad.html).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False, warmup=0, total_steps=1e10, defaults=None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag whether to optimize the bias correction of the second moment
          by doing it after adding $\epsilon$
        * `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
        * `warmup` is the number of warmup steps
        * `total_steps` is the total number of steps. Cosine decay reaches 0 at this,
          but the learning rate stays at 10% of `lr` because we take $\alpha \cdot \max(0.1, decay)$
        * `defaults` is a dictionary of defaults for group values.
          This is useful when you want to extend the class `AdamWarmupCosineDecay`.
        """
        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup, total_steps=total_steps))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)

    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
        """
        ### Get learning-rate

        $$\alpha \min \bigg(1, \frac{t}{w}\bigg)$$
        during warmup, where $w$ is the number of warmup steps,
        followed by a cosine decay down to $0.1 \alpha$ over the remaining steps.
        """
        # If we are in the warmup stage
        if group['warmup'] > state['step']:
            # A linearly increasing learning rate from $0$ to $\alpha$
            return 1e-8 + state['step'] * group['lr'] / group['warmup']
        else:
            # Otherwise, cosine decay from $\alpha$ down to $0.1 \alpha$
            progress = (state['step'] - group['warmup']) / max(1, group['total_steps'] - group['warmup'])
            return group['lr'] * max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))


def _test_lr():
    """
    ### Plot learning rate for different warmups and total steps

    ![Plot of learning rate](noam_lr.png)
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from torch import nn

    model = nn.Linear(10, 10)
    opt = AdamWarmupCosineDecay(model.parameters(), warmup=5000, lr=1e-4, total_steps=4e6)
    steps = 20_000
    plt.plot(np.arange(1, steps), [opt.get_lr({'step': i}, opt.defaults) for i in range(1, steps)])
    plt.legend(["5000:4e6", "5000:2e6", "5000:1e6"])
    plt.title("Learning Rate")
    plt.show()

    steps = int(6e6)
    step_size = 1000
    plt.plot(np.arange(1, steps, step_size), [opt.get_lr({'step': i}, opt.defaults) for i in range(1, steps, step_size)])
    plt.legend(["5000:4e6", "5000:2e6", "5000:1e6"])
    plt.title("Learning Rate")
    plt.show()


if __name__ == '__main__':
    _test_lr()
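To make the schedule concrete, here is a small worked example of `get_lr`'s arithmetic; the `lr`, `warmup`, and `total_steps` values below are illustrative, not from the commit:

```python
import math

# The same schedule as `AdamWarmupCosineDecay.get_lr`, inlined for inspection
lr, warmup, total_steps = 6e-4, 1000, 10_000

def schedule(step):
    if warmup > step:
        # linear warmup from ~0 up to lr
        return 1e-8 + step * lr / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    # cosine decay with a floor at 10% of lr
    return lr * max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

print(schedule(500))     # ~3e-4: halfway through warmup
print(schedule(5500))    # ~3e-4: halfway through decay, cos(pi/2) = 0
print(schedule(10_000))  # 6e-5: the floor, max(0.1, 0) * lr
```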
labml_nn/optimizers/configs.py  (+10 −0)

@@ -31,6 +31,7 @@ class OptimizerConfigs(BaseConfigs):
     momentum: float = 0.5
     amsgrad: bool = False
     warmup: int = 2_000
+    total_steps: int = int(1e10)
     degenerated_to_sgd: bool = True
     rectify: bool = True
     d_model: int

@@ -103,3 +104,12 @@ def _noam_optimizer(c: OptimizerConfigs):
                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                 warmup=c.warmup, d_model=c.d_model)
+
+
+@option(OptimizerConfigs.optimizer, 'AdamWarmupCosineDecay')
+def _noam_optimizer(c: OptimizerConfigs):
+    from labml_nn.optimizers.adam_warmup_cosine_decay import AdamWarmupCosineDecay
+    return AdamWarmupCosineDecay(c.parameters,
+                                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
+                                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
+                                 warmup=c.warmup, total_steps=c.total_steps)
labml_nn/optimizers/radam.py  (+6 −3)

@@ -229,6 +229,9 @@ class RAdam(AMSGrad):
         r = self.calc_rectification_term(beta2, state['step'])
 
+        # Get learning rate
+        lr = self.get_lr(state, group)
+
         # If $r_t$ is intractable
         if r is not None:
             # Whether to optimize the computation by combining scalar computations

@@ -236,7 +239,7 @@ class RAdam(AMSGrad):
                 # Denominator $\sqrt{v_t} + \hat{\epsilon}$
                 denominator = v.sqrt().add_(group['eps'])
                 # Step size $\alpha \sqrt{r_t} \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
-                step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) * r / bias_correction1
+                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1
                 # Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha \sqrt{r_t} \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
                 #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
                 param.data.addcdiv_(m, denominator, value=-step_size)

@@ -245,7 +248,7 @@ class RAdam(AMSGrad):
                 # Denominator $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
                 denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 # Step size $\frac{\alpha \sqrt{r_t}}{1-\beta_1^t}$
-                step_size = self.get_lr(state, group) * r / bias_correction1
+                step_size = lr * r / bias_correction1
                 # Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha \sqrt{r_t} \cdot
                 #  \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
                 param.data.addcdiv_(m, denominator, value=-step_size)

@@ -253,7 +256,7 @@ class RAdam(AMSGrad):
         # If $r_t$ is intractable do a SGD with momentum
         elif self.degenerated_to_sgd:
             # Step size $\frac{\alpha}{1-\beta_1^t}$
-            step_size = self.get_lr(state, group) / bias_correction1
+            step_size = lr / bias_correction1
             # Update parameters
             # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t$
             param.data.add_(m, alpha=-step_size)
labml_nn/transformers/configs.py  (+18 −7)

@@ -9,9 +9,9 @@ summary: These are configurable components that can be re-used quite easily.
 import copy
 
+import torch.nn as nn
+
 from labml.configs import BaseConfigs, option, calculate
 from labml_helpers.module import Module
 from .mha import MultiHeadAttention
 from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, FeedForward, \
     TransformerLayer, Encoder, Decoder, Generator, EncoderDecoder

@@ -30,6 +30,7 @@ class TransformerConfigs(BaseConfigs):
     decoder_attn: MultiHeadAttention = 'mha'
     decoder_mem_attn: MultiHeadAttention = 'mha'
     feed_forward: FeedForward
+    feed_forward_activation: nn.Module = 'ReLU'
 
     encoder_layer: TransformerLayer = 'normal'
     decoder_layer: TransformerLayer = 'normal'

@@ -45,12 +46,22 @@ class TransformerConfigs(BaseConfigs):
     encoder_decoder: EncoderDecoder
 
 
+@option(TransformerConfigs.feed_forward_activation, 'ReLU')
+def _feed_forward_activation_relu():
+    return nn.ReLU()
+
+
+@option(TransformerConfigs.feed_forward_activation, 'GELU')
+def _feed_forward_activation_relu():
+    return nn.GELU()
+
+
 @option(TransformerConfigs.feed_forward, 'default')
 def _feed_forward(c: TransformerConfigs):
-    return FeedForward(c.d_model, c.d_ff, c.dropout)
+    return FeedForward(c.d_model, c.d_ff, c.dropout, c.feed_forward_activation)
 
 
-### MHA
+# ## MHA
 def _mha(c: TransformerConfigs):
     return MultiHeadAttention(c.n_heads, c.d_model)

@@ -60,7 +71,7 @@ calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
 calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
 
 
-### Relative MHA
+# ## Relative MHA
 def _relative_mha(c: TransformerConfigs):
     from .relative_mha import RelativeMultiHeadAttention
     return RelativeMultiHeadAttention(c.n_heads, c.d_model)

@@ -100,7 +111,7 @@ def _generator(c: TransformerConfigs):
     return Generator(c.n_tgt_vocab, c.d_model)
 
 
-### Positional Embeddings
+# ## Positional Embeddings
 @option(TransformerConfigs.src_embed, 'fixed_pos')
 def _src_embed_with_positional(c: TransformerConfigs):
     return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)

@@ -111,7 +122,7 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
     return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
 
 
-### Learned Positional Embeddings
+# ## Learned Positional Embeddings
 @option(TransformerConfigs.src_embed, 'learned_pos')
 def _src_embed_with_learned_positional(c: TransformerConfigs):
     return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)

@@ -122,7 +133,7 @@ def _tgt_embed_with_learned_positional(c: TransformerConfigs):
     return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
 
 
-### No Positional Embeddings
+# ## No Positional Embeddings
 @option(TransformerConfigs.src_embed, 'no_pos')
 def _src_embed_without_positional(c: TransformerConfigs):
     return nn.Embedding(c.n_src_vocab, c.d_model)
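A minimal sketch of how the new `feed_forward_activation` option might be selected; the vocabulary sizes are illustrative, and in practice the config would be populated through `experiment.configs` rather than set by hand:

```python
from labml_nn.transformers import TransformerConfigs

conf = TransformerConfigs()
conf.n_src_vocab = 1000
conf.n_tgt_vocab = 1000
# Pick the 'GELU' option registered above instead of the 'ReLU' default;
# the default `feed_forward` calculator then builds FeedForward(..., nn.GELU())
conf.feed_forward_activation = 'GELU'
```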
labml_nn/transformers/gpt/__init__.py  (new file, +167 −0)

import torch
from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from torch import nn

from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.optimizers.configs import OptimizerConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask


class GPT(Module):
    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

        # The mask will be initialized on the first call
        self.mask = None

    def __call__(self, x: torch.Tensor):
        # Create a subsequent (autoregressive) mask if it is not initialized,
        # or if its size does not match the current sequence length
        if self.mask is None or self.mask.size(0) != len(x):
            self.mask = subsequent_mask(len(x)).to(x.device)
        # Embed the tokens, run through each transformer layer, and generate logits
        x = self.src_embed(x)
        x = self.encoder(x, self.mask)
        x = self.generator(x)

        return x, None


class Configs(NLPAutoRegressionConfigs):
    model: GPT
    transformer: TransformerConfigs
    weight_decay: float = 0.1
    warmup_steps: int = 512 * 128 * 500

    optimizer = 'transformer_optimizer'


@option(Configs.transformer, 'GPT')
def _transformer_configs(c: Configs):
    conf = TransformerConfigs()
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    conf.feed_forward_activation = 'GELU'

    return conf


def _init_weights(module):
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


@option(Configs.model)
def _model(c: Configs):
    m = GPT(c.transformer.encoder,
            c.transformer.src_embed,
            c.transformer.generator).to(c.device)
    m.apply(_init_weights)

    return m


@option(NLPAutoRegressionConfigs.optimizer)
def transformer_optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()

    decay = set()
    no_decay = set()
    whitelist_weight_modules = (nn.Linear,)
    blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)
    for mn, m in c.model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f'{mn}.{pn}' if mn else pn  # full param name

            if fpn.find('positional_encodings') != -1:
                no_decay.add(fpn)
            elif fpn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif fpn.endswith('weight'):
                if isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

    # validate that we considered every parameter
    param_dict = {pn: p for pn, p in c.model.named_parameters()}
    inter_params = decay & no_decay
    if inter_params:
        raise ValueError("Repeated parameters", inter_params)
    missing_params = set(param_dict.keys()) - (decay | no_decay)
    if missing_params:
        raise ValueError('Missing parameters', missing_params)

    # create the pytorch optimizer object
    opt_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": c.weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]

    optimizer.parameters = opt_groups
    optimizer.optimizer = 'AdamWarmupCosineDecay'
    optimizer.d_model = c.d_model
    optimizer.weight_decay = c.weight_decay
    optimizer.learning_rate = 6e-4
    optimizer.betas = (0.9, 0.95)
    optimizer.eps = 1e-8
    optimizer.weight_decouple = False
    optimizer.total_steps = c.epochs * len(c.text.train)
    optimizer.warmup = c.warmup_steps // (c.batch_size * c.seq_len)

    return optimizer


def main():
    # Create experiment
    experiment.create(name="gpt")
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'prompt_separator': '',
                        'prompt': 'It is ',
                        'text': 'tiny_shakespeare',

                        'seq_len': 128,
                        'epochs': 32,
                        'batch_size': 128,
                        'inner_iterations': 10,

                        # Transformer configurations
                        'transformer.d_model': 512,
                        'transformer.d_ff': 2048,
                        'transformer.n_heads': 8,
                        'transformer.n_layers': 6})

    # This is needed to initialize models
    conf.n_tokens = conf.text.n_tokens

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()


if __name__ == '__main__':
    main()
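The decay/no-decay split in `transformer_optimizer` is a common pattern: only `nn.Linear` weights get weight decay, while biases, `nn.LayerNorm`, `nn.Embedding`, and positional-encoding parameters do not. Here is a standalone sketch of the same idea using a stock `torch.optim.AdamW` instead of `OptimizerConfigs`; the tiny model is hypothetical and unrelated to the commit:

```python
import torch.nn as nn
import torch.optim

# A toy model with one parameter of each kind (never actually called)
model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8), nn.LayerNorm(8))

decay, no_decay = [], []
for name, p in model.named_parameters():
    # For nn.Sequential the first name component is the child index
    module = model[int(name.split('.')[0])]
    # Biases, LayerNorm and Embedding weights are excluded from weight decay
    if name.endswith('bias') or isinstance(module, (nn.LayerNorm, nn.Embedding)):
        no_decay.append(p)
    else:
        decay.append(p)

optimizer = torch.optim.AdamW([
    {'params': decay, 'weight_decay': 0.1},
    {'params': no_decay, 'weight_decay': 0.0},
], lr=6e-4, betas=(0.9, 0.95))
```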
labml_nn/transformers/models.py  (+13 −5)

@@ -12,9 +12,8 @@ import math
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from labml_helpers.module import Module
 from labml_nn.utils import clone_module_list
 from .mha import MultiHeadAttention
 from .positional_encoding import get_positional_encoding

@@ -24,6 +23,7 @@ class EmbeddingsWithPositionalEncoding(Module):
     """
     ## Embed tokens and add [fixed positional encoding](positional_encoding.html)
     """
+
     def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
         super().__init__()
         self.linear = nn.Embedding(n_vocab, d_model)

@@ -39,6 +39,7 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
     """
     ## Embed tokens and add parameterized positional encodings
     """
+
     def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
         super().__init__()
         self.linear = nn.Embedding(n_vocab, d_model)

@@ -54,15 +55,17 @@ class FeedForward(Module):
     """
     ## Position-wise feed-forward network with hidden layer
     """
-    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
+
+    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation=nn.ReLU()):
         super().__init__()
         self.layer1 = nn.Linear(d_model, d_ff)
         self.layer2 = nn.Linear(d_ff, d_model)
         self.dropout = nn.Dropout(dropout)
+        self.activation = activation
 
     def __call__(self, x: torch.Tensor):
         x = self.layer1(x)
-        x = F.relu(x)
+        x = self.activation(x)
         x = self.dropout(x)
         return self.layer2(x)

@@ -82,6 +85,7 @@ class TransformerLayer(Module):
     We found a detailed discussion about this in paper
     [On Layer Normalization in the Transformer Architecture](https://arxiv.org/abs/2002.04745).
     """
+
     def __init__(self, *,
                  d_model: int,
                  self_attn: MultiHeadAttention,

@@ -141,6 +145,7 @@ class Encoder(Module):
     """
     ## Transformer Encoder
     """
+
     def __init__(self, layer: TransformerLayer, n_layers: int):
         super().__init__()
         # Make copies of the transformer layer

@@ -159,6 +164,7 @@ class Decoder(Module):
     """
     ## Transformer Decoder
     """
+
     def __init__(self, layer: TransformerLayer, n_layers: int):
         super().__init__()
         # Make copies of the transformer layer

@@ -180,18 +186,20 @@ class Generator(Module):
     This predicts the tokens and gives the log softmax of those.
     You don't need this if you are using `nn.CrossEntropyLoss`.
     """
+
     def __init__(self, n_vocab: int, d_model: int):
         super().__init__()
         self.projection = nn.Linear(d_model, n_vocab)
 
     def __call__(self, x):
-        return F.log_softmax(self.projection(x), dim=-1)
+        return self.projection(x)
 
 
 class EncoderDecoder(Module):
     """
     ## Combined Encoder-Decoder
     """
+
     def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
         super().__init__()
         self.encoder = encoder
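The `Generator` change works because `nn.CrossEntropyLoss` expects raw logits and applies log-softmax internally, so applying it in the model was redundant. A minimal sketch of the equivalence, with illustrative shapes (batch of 4, vocabulary of 10):

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 10)            # what Generator's projection now returns
targets = torch.randint(0, 10, (4,))

loss = nn.CrossEntropyLoss()(logits, targets)

# Equivalent to NLL loss on explicit log-probabilities, i.e. the old behavior
loss2 = nn.NLLLoss()(torch.log_softmax(logits, dim=-1), targets)
assert torch.allclose(loss, loss2)
```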