Commit 1690395b (Greenplum / Pytorch Widedeep)
Authored on May 16, 2022 by Javier Rodriguez Zaurin

Re-distributed the code. Added encoder and decoder self supervision for non-attention based models

Parent: cdd674ed

Showing 21 changed files with 1325 additions and 295 deletions (+1325 / -295)
examples/scripts/adult_census_self_supervised.py                                   +97  -47
pytorch_widedeep/losses.py                                                         +26   -0
pytorch_widedeep/models/tabular/mlp/tab_mlp.py                                     +44   -0
pytorch_widedeep/models/tabular/resnet/tab_resnet.py                               +64   -0
pytorch_widedeep/models/tabular/self_supervised/__init__.py                         +0   -0
pytorch_widedeep/models/tabular/self_supervised/_augmentations.py                   +0   -0
pytorch_widedeep/models/tabular/self_supervised/_denoise_mlps.py                    +0   -0
pytorch_widedeep/models/tabular/self_supervised/_random_obfuscator.py              +13   -0
pytorch_widedeep/models/tabular/self_supervised/contrastive_denoising_model.py     +16  -13
pytorch_widedeep/models/tabular/self_supervised/encoder_decoder_model.py           +57   -0
pytorch_widedeep/models/tabular/tabnet/_layers.py                                   +6   -3
pytorch_widedeep/models/tabular/tabnet/tab_net.py                                  +80   -2
pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py   +75  -24
pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py        +159   -0
pytorch_widedeep/self_supervised_training/contrastive_denoising_trainer.py        +210   -0
pytorch_widedeep/self_supervised_training/encoder_decoder_trainer.py              +240   -0
pytorch_widedeep/training/_base_bayesian_trainer.py                               +216   -0
pytorch_widedeep/training/_base_trainer.py                                           +1 -201
pytorch_widedeep/training/bayesian_trainer.py                                        +7   -3
pytorch_widedeep/training/trainer.py                                                 +1   -1
pytorch_widedeep/wdtypes.py                                                         +13   -1
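Before the per-file diffs, here is a minimal, self-contained sketch (plain PyTorch, illustrative only and not part of the commit) of the encoder/decoder self-supervision objective this commit adds for non-attention models: obfuscate the input embeddings at random, encode, decode back to the embedding space, and penalise the reconstruction error on the masked entries, normalised by the per-feature variance. The toy encoder and decoder stand in for TabMlp and the new TabMlpDecoder; all dimensions are made up.

import torch
from torch import nn

p = 0.2                                  # masking probability (RandomObfuscator.p)
x_embed = torch.randn(32, 16)            # stand-in for encoder._get_embeddings(X)

mask = torch.bernoulli(p * torch.ones(x_embed.shape))
masked_x = (1 - mask) * x_embed          # what RandomObfuscator.forward returns

encoder = nn.Sequential(nn.Linear(16, 8), nn.ReLU())   # stand-in for TabMlp
decoder = nn.Linear(8, 16)                              # stand-in for TabMlpDecoder
x_embed_rec = decoder(encoder(masked_x))

# masked, variance-normalised reconstruction error, as in the new EncoderDecoderLoss
errors = (x_embed_rec - x_embed) * mask
var = torch.std(x_embed, dim=0) ** 2
features_loss = torch.matmul(errors ** 2, 1 / var)
loss = torch.mean(features_loss / (mask.sum(dim=1) + 1e-9))
loss.backward()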
examples/scripts/adult_census_self_supervised.py

from itertools import product

import numpy as np
import torch
import pandas as pd

from pytorch_widedeep.models import (
    SAINT,
    TabPerceiver,
    FTTransformer,
    TabFastFormer,
    TabTransformer,
    SelfAttentionMLP,
    ContextAttentionMLP,
)
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
-from pytorch_widedeep.self_supervised_training.self_supervised_trainer import (
-    SelfSupervisedTrainer,
+from pytorch_widedeep.self_supervised_training.contrastive_denoising_trainer import (
+    ContrastiveDenoisingTrainer,
)

use_cuda = torch.cuda.is_available()

...
@@ -55,43 +58,54 @@ if __name__ == "__main__":
    target = "income_label"
    target = df[target].values

-    tab_preprocessor = TabPreprocessor(
+    transformer_models = [
+        "tab_transformer",
+        "saint",
+        "tab_fastformer",
+        "ft_transformer",
+    ]
+    with_cls_token = [True, False]
+    for w_cls_tok, transf_model in product(with_cls_token, transformer_models):
+        processor = TabPreprocessor(
            cat_embed_cols=cat_embed_cols,
            continuous_cols=continuous_cols,
            with_attention=True,
-            with_cls_token=True,
+            with_cls_token=w_cls_tok,
        )
-    X_tab = tab_preprocessor.fit_transform(df)
+        X_tab = processor.fit_transform(df)

-    tab_transformer = TabTransformer(
-        column_idx=tab_preprocessor.column_idx,
-        cat_embed_input=tab_preprocessor.cat_embed_input,
+        if transf_model == "tab_transformer":
+            model = TabTransformer(
+                column_idx=processor.column_idx,
+                cat_embed_input=processor.cat_embed_input,
                continuous_cols=continuous_cols,
                embed_continuous=True,
                n_blocks=4,
            )

-    saint = SAINT(
-        column_idx=tab_preprocessor.column_idx,
-        cat_embed_input=tab_preprocessor.cat_embed_input,
+        if transf_model == "saint":
+            model = SAINT(
+                column_idx=processor.column_idx,
+                cat_embed_input=processor.cat_embed_input,
                continuous_cols=continuous_cols,
                cont_norm_layer="batchnorm",
                n_blocks=4,
            )

-    tab_fastformer = TabFastFormer(
-        column_idx=tab_preprocessor.column_idx,
-        cat_embed_input=tab_preprocessor.cat_embed_input,
+        if transf_model == "tab_fastformer":
+            model = TabFastFormer(
+                column_idx=processor.column_idx,
+                cat_embed_input=processor.cat_embed_input,
                continuous_cols=continuous_cols,
                n_blocks=4,
                n_heads=4,
                share_qv_weights=False,
                share_weights=False,
            )

-    ft_transformer = FTTransformer(
-        column_idx=tab_preprocessor.column_idx,
-        cat_embed_input=tab_preprocessor.cat_embed_input,
+        if transf_model == "ft_transformer":
+            model = FTTransformer(
+                column_idx=processor.column_idx,
+                cat_embed_input=processor.cat_embed_input,
                continuous_cols=continuous_cols,
                input_dim=32,
                kv_compression_factor=0.5,
...
@@ -99,9 +113,45 @@ if __name__ == "__main__":
                n_heads=4,
            )

-    for transformer_model in [tab_transformer, saint, tab_fastformer, ft_transformer]:
-        ss_trainer = SelfSupervisedTrainer(
-            model=transformer_model,
-            preprocessor=tab_preprocessor,
+        ss_trainer = ContrastiveDenoisingTrainer(
+            base_model=model,
+            preprocessor=processor,
        )
        ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)

+    mlp_attn_model = ["context_attention", "self_attention"]
+    for w_cls_tok, attn_model in product(with_cls_token, mlp_attn_model):
+        processor = TabPreprocessor(
+            cat_embed_cols=cat_embed_cols,
+            continuous_cols=continuous_cols,
+            with_attention=True,
+            with_cls_token=w_cls_tok,
+        )
+        X_tab = processor.fit_transform(df)
+        if attn_model == "context_attention":
+            model = ContextAttentionMLP(
+                column_idx=processor.column_idx,
+                cat_embed_input=processor.cat_embed_input,
+                continuous_cols=continuous_cols,
+                input_dim=16,
+                attn_dropout=0.2,
+                n_blocks=3,
+            )
+        if attn_model == "self_attention":
+            model = SelfAttentionMLP(
+                column_idx=processor.column_idx,
+                cat_embed_input=processor.cat_embed_input,
+                continuous_cols=continuous_cols,
+                input_dim=16,
+                attn_dropout=0.2,
+                n_blocks=3,
+            )
+        ss_trainer = ContrastiveDenoisingTrainer(
+            base_model=model,
+            preprocessor=processor,
+        )
+        ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)
pytorch_widedeep/losses.py

...
@@ -901,3 +901,29 @@ class DenoisingLoss(nn.Module):
            loss_cont += F.mse_loss(x_, x, reduction=self.reduction)

        return loss_cont


class EncoderDecoderLoss(object):
    def __init__(self, eps=1e-9):
        super(EncoderDecoderLoss, self).__init__()
        self.eps = eps

    def forward(x_true, x_pred, mask):
        errors = x_pred - x_true

        reconstruction_errors = torch.mul(errors, mask) ** 2

        x_true_means = torch.mean(x_true, dim=0)
        x_true_means[x_true_means == 0] = 1

        x_true_stds = torch.std(x_true, dim=0) ** 2
        x_true_stds[x_true_stds == 0] = x_true_means[x_true_stds == 0]

        features_loss = torch.matmul(reconstruction_errors, 1 / x_true_stds)
        nb_reconstructed_variables = torch.sum(mask, dim=1)
        features_loss_norm = features_loss / (nb_reconstructed_variables + eps)

        loss = torch.mean(features_loss_norm)

        return loss
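In equation form (our notation, not taken from the source), the loss added above is, for a batch of B rows with reconstruction mask m and per-feature batch variance sigma_j^2 (zeros replaced as in the code):

\mathcal{L} \;=\; \frac{1}{B}\sum_{i=1}^{B}\frac{\sum_{j}\big(m_{ij}\,(\hat{x}_{ij}-x_{ij})\big)^{2}\,/\,\sigma_{j}^{2}}{\sum_{j} m_{ij}+\varepsilon},\qquad \varepsilon = 10^{-9}

That is, a masked, variance-normalised reconstruction error averaged over the batch, essentially the objective used for TabNet-style self-supervised pretraining, scaled so that high-variance features do not dominate.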
pytorch_widedeep/models/tabular/mlp/tab_mlp.py

from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
...
@@ -152,3 +154,45 @@ class TabMlp(BaseTabularModelWithoutAttention):
    @property
    def output_dim(self):
        return self.mlp_hidden_dims[-1]


# This is a companion Decoder for the TabMlp. We prefer not to refer to the
# 'TabMlp' model as 'TabMlpEncoder' (despite the fact that is indeed an
# encoder) for two reasons: 1. For convenience accross the library and 2.
# Because decoders are only going to be used in one of our implementations
# of Self Supervised pretraining, and we prefer to keep the names of
# the 'general' DL models as they are (e.g. TabMlp) as opposed as carry
# the 'Encoder' description (e.g. TabMlpEncoder) throughout the library
class TabMlpDecoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        mlp_hidden_dims: List[int] = [100, 200],
        mlp_activation: str = "relu",
        mlp_dropout: Union[float, List[float]] = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = False,
    ):
        super(TabMlpDecoder, self).__init__()

        self.embed_dim = embed_dim

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        self.decoder = MLP(
            mlp_hidden_dims + [self.embed_dim],
            mlp_activation,
            mlp_dropout,
            mlp_batchnorm,
            mlp_batchnorm_last,
            mlp_linear_first,
        )

    def forward(self, X: Tensor) -> Tensor:
        return self.decoder(X)
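For reference, a short illustrative sketch (not part of the commit, and assuming the package at this commit is importable): the decoder maps the encoder output back to the embedding space by appending embed_dim to mlp_hidden_dims, so its input width is mlp_hidden_dims[0] and its output width is embed_dim.

import torch
from pytorch_widedeep.models.tabular.mlp.tab_mlp import TabMlpDecoder

# hypothetical dimensions: the encoder produced 100-dim representations and the
# original (categorical + continuous) embedding space was 32-dim
decoder = TabMlpDecoder(embed_dim=32, mlp_hidden_dims=[100, 200])
out = decoder(torch.randn(8, 100))
assert out.shape == (8, 32)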
pytorch_widedeep/models/tabular/resnet/tab_resnet.py

from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular.resnet._layers import DenseResnet
...
@@ -204,3 +206,65 @@ class TabResnet(BaseTabularModelWithoutAttention):
            if self.mlp_hidden_dims is not None
            else self.blocks_dims[-1]
        )


class TabResnetDecoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        blocks_dims: List[int] = [100, 100, 200],
        blocks_dropout: float = 0.1,
        simplify_blocks: bool = False,
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = False,
    ):
        super(TabResnetDecoder, self).__init__()

        if len(blocks_dims) < 2:
            raise ValueError(
                "'blocks' must contain at least two elements, e.g. [256, 128]"
            )

        self.embed_dim = embed_dim

        self.blocks_dims = blocks_dims
        self.blocks_dropout = blocks_dropout
        self.simplify_blocks = simplify_blocks

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        if self.mlp_hidden_dims is not None:
            self.mlp = MLP(
                mlp_hidden_dims,
                mlp_activation,
                mlp_dropout,
                mlp_batchnorm,
                mlp_batchnorm_last,
                mlp_linear_first,
            )
        else:
            self.mlp = None

        if self.mlp is not None:
            self.decoder = DenseResnet(
                mlp_hidden_dims[-1], blocks_dims, blocks_dropout, self.simplify_blocks
            )
        else:
            self.decoder = DenseResnet(
                blocks_dims[0], blocks_dims, blocks_dropout, self.simplify_blocks
            )

        self.reconstruction_layer = nn.Linear(blocks_dims[-1], embed_dim, bias=False)

    def forward(self, X: Tensor) -> Tensor:
        x = self.mlp(X) if self.mlp is not None else X
        return self.reconstruction_layer(self.decoder(x))
pytorch_widedeep/models/tabular/self_supervised/__init__.py
0 → 100644

pytorch_widedeep/self_supervised_training/_augmentations.py → pytorch_widedeep/models/tabular/self_supervised/_augmentations.py
File moved

pytorch_widedeep/self_supervised_training/_denoise_mlps.py → pytorch_widedeep/models/tabular/self_supervised/_denoise_mlps.py
File moved
pytorch_widedeep/models/tabular/self_supervised/_random_obfuscator.py
0 → 100644

import torch
from torch import nn


class RandomObfuscator(nn.Module):
    def __init__(self, p):
        super(RandomObfuscator, self).__init__()
        self.p = p

    def forward(self, x):
        mask = torch.bernoulli(self.p * torch.ones(x.shape)).to(x.device)
        masked_input = torch.mul(1 - mask, x)
        return masked_input, mask
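A minimal usage sketch of the module above (illustrative, not part of the commit; it assumes the package at this commit is importable at this path): each entry of the input is zeroed out independently with probability p, and the Bernoulli mask is returned alongside the masked input so the reconstruction loss can be restricted to the obfuscated entries.

import torch
from pytorch_widedeep.models.tabular.self_supervised._random_obfuscator import (
    RandomObfuscator,
)

masker = RandomObfuscator(p=0.2)   # each entry is masked with probability 0.2
x = torch.randn(4, 10)             # e.g. a batch of 4 rows with 10 embedded features
masked_x, mask = masker(x)
assert torch.equal(masked_x, (1 - mask) * x)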
pytorch_widedeep/self_supervised_training/self_supervised_model.py → pytorch_widedeep/models/tabular/self_supervised/contrastive_denoising_model.py

...
@@ -2,22 +2,22 @@ from torch import Tensor, nn
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
-from pytorch_widedeep.self_supervised_training._denoise_mlps import (
+from pytorch_widedeep.models.tabular.self_supervised._denoise_mlps import (
    CatSingleMlp,
    ContSingleMlp,
    CatFeaturesMlp,
    ContFeaturesMlp,
)
-from pytorch_widedeep.self_supervised_training._augmentations import (
+from pytorch_widedeep.models.tabular.self_supervised._augmentations import (
    mix_up,
    cut_mix,
)


-class SelfSupervisedModel(nn.Module):
+class ContrastiveDenoisingModel(nn.Module):
    def __init__(
        self,
-        model: nn.Module,
+        model: ModelWithAttention,
        encoding_dict: Dict[str, Dict[str, int]],
        loss_type: Literal["contrastive", "denoising", "both"],
        projection_head1_dims: Optional[List],
...
@@ -27,7 +27,7 @@ class SelfSupervisedModel(nn.Module):
        cont_mlp_type: Literal["single", "multiple"],
        denoise_mlps_activation: str,
    ):
-        super(SelfSupervisedModel, self).__init__()
+        super(ContrastiveDenoisingModel, self).__init__()

        self.model = model
        self.loss_type = loss_type
...
@@ -55,19 +55,22 @@ class SelfSupervisedModel(nn.Module):
        Optional[Tuple[Tensor, Tensor]],
    ]:
-        # "uncorrupted branch"
+        # "uncorrupted" branch
        embed = self.model._get_embeddings(X)
        if self.model.with_cls_token:
            embed[:, 0, :] = 0.0
        encoded = self.model.encoder(embed)

-        # cut_mix and mix_up branch
+        # cut_mixed and mixed_up branch
        if self.training:
            cut_mixed = cut_mix(X)
            cut_mixed_embed = self.model._get_embeddings(cut_mixed)
            if self.model.with_cls_token:
                cut_mixed_embed[:, 0, :] = 0.0
            cut_mixed_embed_mixed_up = mix_up(cut_mixed_embed)
            encoded_ = self.model.encoder(cut_mixed_embed_mixed_up)
        else:
            encoded_ = encoded.clone()

        # projections for constrastive loss
        if self.loss_type in ["contrastive", "both"]:
...
pytorch_widedeep/models/tabular/self_supervised/encoder_decoder_model.py
0 → 100644

from torch import Tensor, nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.self_supervised._random_obfuscator import (
    RandomObfuscator,
)


class EncoderDecoderModel(nn.Module):
    def __init__(
        self,
        encoder: ModelWithoutAttention,
        decoder: nn.Module,
        masked_prob: float,
    ):
        super(EncoderDecoderModel, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.masker = RandomObfuscator(p=masked_prob)

    def forward(self, X: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        if self.encoder.is_tabnet:
            return self._forward_tabnet(X)
        else:
            return self._forward(X)

    def _forward(self, X: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        x_embed = self.encoder._get_embeddings(X)

        if self.training:
            masked_x, mask = self.masker(x_embed)
            x_enc = self.encoder(X)
            x_embed_rec = self.decoder(x_enc)
        else:
            x_embed_rec = self.decoder(self.encoder(X))
            mask = torch.ones(x_embed.shape).to(X.device)

        return x_embed, x_embed_rec, mask

    def _forward_tabnet(self, X: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        x_embed = self.encoder._get_embeddings(X)

        if self.training:
            masked_x, mask = self.masker(x_embed)
            prior = 1 - mask
            steps_out, _ = self.encoder(masked_x, prior=prior)
            x_embed_rec = self.decoder(steps_out)
        else:
            steps_out, _ = self.encoder(x_embed)
            x_embed_rec = self.decoder(steps_out)
            mask = torch.ones(x_embed.shape).to(X.device)

        return x_embed_rec, x_embed, mask
pytorch_widedeep/models/tabular/tabnet/_layers.py

...
@@ -294,9 +294,12 @@ class TabNetEncoder(nn.Module):
            self.feat_transformers.append(feat_transformer)
            self.attn_transformers.append(attn_transformer)

-    def forward(self, X: Tensor) -> Tuple[List[Tensor], Tensor]:
+    def forward(
+        self, X: Tensor, prior: Optional[Tensor] = None
+    ) -> Tuple[List[Tensor], Tensor]:
        x = self.initial_bn(X)

+        if prior is None:
            # P[n_step = 0] is initialized as all ones, 1^(B×D)
            prior = torch.ones(x.shape).to(x.device)
...
pytorch_widedeep/models/tabular/tabnet/tab_net.py

...
@@ -4,6 +4,7 @@ from torch import nn
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.tabnet._layers import (
    TabNetEncoder,
    FeatTransformer,
    initialize_non_glu,
)
from pytorch_widedeep.models.tabular._base_tabular_model import (
...
@@ -189,9 +190,11 @@ class TabNet(BaseTabularModelWithoutAttention):
            mask_type,
        )

-    def forward(self, X: Tensor) -> Tuple[Tensor, Tensor]:
+    def forward(
+        self, X: Tensor, prior: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
        x = self._get_embeddings(X)
-        steps_output, M_loss = self.encoder(x)
+        steps_output, M_loss = self.encoder(x, prior)
        res = torch.sum(torch.stack(steps_output, dim=0), dim=0)
        return (res, M_loss)
...
@@ -223,3 +226,78 @@ class TabNetPredLayer(nn.Module):
    def forward(self, tabnet_tuple: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        res, M_loss = tabnet_tuple[0], tabnet_tuple[1]
        return self.pred_layer(res), M_loss


class TabNetDecoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        n_steps: int = 3,
        step_dim: int = 8,
        attn_dim: int = 8,
        dropout: float = 0.0,
        n_glu_step_dependent: int = 2,
        n_glu_shared: int = 2,
        ghost_bn: bool = True,
        virtual_batch_size: int = 128,
        momentum: float = 0.02,
        gamma: float = 1.3,
        epsilon: float = 1e-15,
        mask_type: str = "sparsemax",
    ):
        super(TabNetDecoder, self).__init__()

        self.n_steps = n_steps
        self.step_dim = step_dim
        self.attn_dim = attn_dim
        self.dropout = dropout
        self.n_glu_step_dependent = n_glu_step_dependent
        self.n_glu_shared = n_glu_shared
        self.ghost_bn = ghost_bn
        self.virtual_batch_size = virtual_batch_size
        self.momentum = momentum
        self.gamma = gamma
        self.epsilon = epsilon
        self.mask_type = mask_type

        shared_layers = nn.ModuleList()
        for i in range(n_glu_shared):
            if i == 0:
                shared_layers.append(
                    nn.Linear(embed_dim, 2 * (step_dim + attn_dim), bias=False)
                )
            else:
                shared_layers.append(
                    nn.Linear(step_dim + attn_dim, 2 * (step_dim + attn_dim), bias=False)
                )

        self.feat_transformers = nn.ModuleList()
        for step in range(n_steps):
            transformer = FeatTransformer(
                embed_dim,
                embed_dim,
                dropout,
                shared_layers,
                n_glu_step_dependent,
                ghost_bn,
                virtual_batch_size,
                momentum=momentum,
            )
            self.feat_transformers.append(transformer)

        self.reconstruction_layer = nn.Linear(step_dim, embed_dim, bias=False)
        initialize_non_glu(self.reconstruction_layer, step_dim, embed_dim)

    def forward(self, X):
        out = 0.0
        for i, x in enumerate(X):
            x = self.feat_transformers[step_nb](x)
            out = torch.add(out, x)
        out = self.reconstruction_layer(out)
        return out

    def forward_masks(self, X: Tensor) -> Tuple[Tensor, Dict[int, Tensor]]:
        x = self._get_embeddings(X)
        return self.encoder.forward_masks(x)
pytorch_widedeep/self_supervised_training/_base_self_supervised_trainer.py → pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py

...
@@ -2,10 +2,12 @@ import os
import sys
from abc import ABC, abstractmethod

import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_widedeep.losses import InfoNCELoss, DenoisingLoss
-from pytorch_widedeep.wdtypes import *  # noqa: F403
+from pytorch_widedeep.wdtypes import *  # noqa: F403; noqa: F403
from pytorch_widedeep.callbacks import (
    History,
    Callback,
...
@@ -13,15 +15,15 @@ from pytorch_widedeep.callbacks import (
    LRShedulerCallback,
)
from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor
-from pytorch_widedeep.self_supervised_training.self_supervised_model import (
-    SelfSupervisedModel,
+from pytorch_widedeep.models.tabular.self_supervised.contrastive_denoising_model import (
+    ContrastiveDenoisingModel,
)


-class BaseSelfSupervisedTrainer(ABC):
+class BaseContrastiveDenoisingTrainer(ABC):
    def __init__(
        self,
-        model,
+        base_model: ModelWithAttention,
        preprocessor: TabPreprocessor,
        optimizer: Optional[Optimizer],
        lr_scheduler: Optional[LRScheduler],
...
@@ -38,8 +40,15 @@ class BaseSelfSupervisedTrainer(ABC):
        **kwargs,
    ):
-        self.ss_model = SelfSupervisedModel(
-            model,
+        self._check_model_is_supported(base_model)
+        self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)
+
+        self.early_stop = False
+        self.verbose = verbose
+        self.seed = seed
+
+        self.model = ContrastiveDenoisingModel(
+            base_model,
            preprocessor.label_encoder.encoding_dict,
            loss_type,
            projection_head1_dims,
...
@@ -49,31 +58,38 @@ class BaseSelfSupervisedTrainer(ABC):
            cont_mlp_type,
            denoise_mlps_activation,
        )
-        self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)
-        self.early_stop = False
-        self.ss_model.to(self.device)
+        self.model.to(self.device)

        self.loss_type = loss_type
        self._set_loss_fn(**kwargs)
-        self.verbose = verbose
-        self.seed = seed

        self.optimizer = (
            optimizer
            if optimizer is not None
-            else torch.optim.AdamW(self.ss_model.parameters())
-        )
-        self.lr_scheduler = self._set_lr_scheduler_running_params(lr_scheduler, **kwargs)
+            else torch.optim.AdamW(self.model.parameters())
+        )
+        self.lr_scheduler = self._set_lr_scheduler_running_params(lr_scheduler)
        self._set_callbacks(callbacks)

-    @abstractmethod
-    def pretrain(self):
-        pass
+    @abstractmethod
+    def pretrain(
+        self,
+        X_tab: np.ndarray,
+        X_val: Optional[np.ndarray],
+        val_split: Optional[float],
+        validation_freq: int,
+        n_epochs: int,
+        batch_size: int,
+    ):
+        raise NotImplementedError("Trainer.pretrain method not implemented")

+    @abstractmethod
+    def save(
+        self,
+        path: str,
+        save_state_dict: bool,
+        model_filename: str,
+    ):
+        raise NotImplementedError("Trainer.save method not implemented")

    def _set_loss_fn(self, **kwargs):
...
@@ -102,7 +118,29 @@ class BaseSelfSupervisedTrainer(ABC):
        return contrastive_loss + denoising_loss

+    def _set_reduce_on_plateau_criterion(
+        self, lr_scheduler, reducelronplateau_criterion
+    ):
+        self.reducelronplateau = False
+
+        if isinstance(lr_scheduler, ReduceLROnPlateau):
+            self.reducelronplateau = True
+
+        if self.reducelronplateau and not reducelronplateau_criterion:
+            UserWarning(
+                "The learning rate scheduler is of type ReduceLROnPlateau. The step method in this"
+                " scheduler requires a 'metrics' param that can be either the validation loss or the"
+                " validation metric. Please, when instantiating the Trainer, specify which quantity"
+                " will be tracked using reducelronplateau_criterion = 'loss' (default) or"
+                " reducelronplateau_criterion = 'metric'"
+            )
+        else:
+            self.reducelronplateau_criterion = "loss"

    def _set_lr_scheduler_running_params(self, lr_scheduler, **kwargs):
        reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None)
+        self._set_reduce_on_plateau_criterion(lr_scheduler, reducelronplateau_criterion)

        if lr_scheduler is not None:
            self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
        else:
...
@@ -116,7 +154,7 @@ class BaseSelfSupervisedTrainer(ABC):
                callback = callback()
                self.callbacks.append(callback)
        self.callback_container = CallbackContainer(self.callbacks)
-        self.callback_container.set_model(self.ss_model)
+        self.callback_container.set_model(self.model)
        self.callback_container.set_trainer(self)

    def _restore_best_weights(self):
...
@@ -139,7 +177,7 @@ class BaseSelfSupervisedTrainer(ABC):
                            print(
                                f"Model weights restored to best epoch: {callback.best_epoch + 1}"
                            )
-                        self.ss_model.load_state_dict(callback.best_state_dict)
+                        self.model.load_state_dict(callback.best_state_dict)
                    else:
                        if self.verbose:
                            print(
...
@@ -160,3 +198,16 @@ class BaseSelfSupervisedTrainer(ABC):
        device = kwargs.get("device", default_device)
        num_workers = kwargs.get("num_workers", default_num_workers)
        return device, num_workers

+    @staticmethod
+    def _check_model_is_supported(model: ModelWithAttention):
+        if model.__class__.__name__ == "TabPerceiver":
+            raise ValueError(
+                "Self-Supervised pretraining is not supported for the 'TabPerceiver'"
+            )
+        if model.__class__.__name__ == "TabTransformer" and not model.embed_continuous:
+            raise ValueError(
+                "Self-Supervised pretraining is only supported if both categorical and "
+                "continuum columns are embedded. Please set 'embed_continuous = True'"
+            )
pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
0 → 100644

import os
import sys
from abc import ABC, abstractmethod

import numpy as np
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_widedeep.losses import EncoderDecoderLoss
from pytorch_widedeep.wdtypes import *  # noqa: F403; noqa: F403
from pytorch_widedeep.callbacks import (
    History,
    Callback,
    CallbackContainer,
    LRShedulerCallback,
)
from pytorch_widedeep.self_supervised_training.self_supervised_models import (
    EncoderDecoderModel,
)


class BaseEncoderDecoderTrainer(ABC):
    def __init__(
        self,
        encoder: ModelWithoutAttention,
        decoder: nn.Module,
        masked_prob: float,
        optimizer: Optional[Optimizer],
        lr_scheduler: Optional[LRScheduler],
        callbacks: Optional[List[Callback]],
        verbose: int,
        seed: int,
        **kwargs,
    ):
        # self._check_model_is_supported(encoder)
        self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)

        self.early_stop = False
        self.verbose = verbose
        self.seed = seed

        self.model = EncoderDecoderModel(
            encoder,
            decoder,
            masked_prob,
        )
        self.model.to(self.device)

        self.loss_fn = EncoderDecoderLoss()

        self.optimizer = (
            optimizer
            if optimizer is not None
            else torch.optim.AdamW(self.model.parameters())
        )

        self.lr_scheduler = self._set_lr_scheduler_running_params(lr_scheduler)
        self._set_callbacks(callbacks)

    @abstractmethod
    def pretrain(
        self,
        X_tab: np.ndarray,
        X_val: Optional[np.ndarray],
        val_split: Optional[float],
        validation_freq: int,
        n_epochs: int,
        batch_size: int,
    ):
        raise NotImplementedError("Trainer.pretrain method not implemented")

    @abstractmethod
    def save(
        self,
        path: str,
        save_state_dict: bool,
        model_filename: str,
    ):
        raise NotImplementedError("Trainer.save method not implemented")

    def _set_reduce_on_plateau_criterion(
        self, lr_scheduler, reducelronplateau_criterion
    ):
        self.reducelronplateau = False

        if isinstance(lr_scheduler, ReduceLROnPlateau):
            self.reducelronplateau = True

        if self.reducelronplateau and not reducelronplateau_criterion:
            UserWarning(
                "The learning rate scheduler is of type ReduceLROnPlateau. The step method in this"
                " scheduler requires a 'metrics' param that can be either the validation loss or the"
                " validation metric. Please, when instantiating the Trainer, specify which quantity"
                " will be tracked using reducelronplateau_criterion = 'loss' (default) or"
                " reducelronplateau_criterion = 'metric'"
            )
        else:
            self.reducelronplateau_criterion = "loss"

    def _set_lr_scheduler_running_params(self, lr_scheduler, **kwargs):
        reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None)

        self._set_reduce_on_plateau_criterion(lr_scheduler, reducelronplateau_criterion)

        if lr_scheduler is not None:
            self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
        else:
            self.cyclic_lr = False

    def _set_callbacks(self, callbacks):
        self.callbacks: List = [History(), LRShedulerCallback()]
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)
        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self.model)
        self.callback_container.set_trainer(self)

    def _restore_best_weights(self):
        already_restored = any(
            [
                (
                    callback.__class__.__name__ == "EarlyStopping"
                    and callback.restore_best_weights
                )
                for callback in self.callback_container.callbacks
            ]
        )
        if already_restored:
            pass
        else:
            for callback in self.callback_container.callbacks:
                if callback.__class__.__name__ == "ModelCheckpoint":
                    if callback.save_best_only:
                        if self.verbose:
                            print(
                                f"Model weights restored to best epoch: {callback.best_epoch + 1}"
                            )
                        self.model.load_state_dict(callback.best_state_dict)
                    else:
                        if self.verbose:
                            print(
                                "Model weights after training corresponds to the those of the "
                                "final epoch which might not be the best performing weights. Use "
                                "the 'ModelCheckpoint' Callback to restore the best epoch weights."
                            )

    @staticmethod
    def _set_device_and_num_workers(**kwargs):
        default_num_workers = (
            0
            if sys.platform == "darwin" and sys.version_info.minor > 7
            else os.cpu_count()
        )
        default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        device = kwargs.get("device", default_device)
        num_workers = kwargs.get("num_workers", default_num_workers)
        return device, num_workers
pytorch_widedeep/self_supervised_training/self_supervised_trainer.py → pytorch_widedeep/self_supervised_training/contrastive_denoising_trainer.py

import json
from pathlib import Path

import numpy as np
import torch
from tqdm import trange
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.training._trainer_utils import (
    save_epoch_logs,
    print_loss_and_metric,
)
-from pytorch_widedeep.self_supervised_training._base_self_supervised_trainer import (
-    BaseSelfSupervisedTrainer,
+from pytorch_widedeep.self_supervised_training._base_contrastive_denoising_trainer import (
+    BaseContrastiveDenoisingTrainer,
)


-class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
+class ContrastiveDenoisingTrainer(BaseContrastiveDenoisingTrainer):
    def __init__(
        self,
-        model,
-        preprocessor,
+        base_model: ModelWithAttention,
+        preprocessor: TabPreprocessor,
        optimizer: Optional[Optimizer] = None,
        lr_scheduler: Optional[LRScheduler] = None,
        callbacks: Optional[List[Callback]] = None,
...
@@ -34,7 +39,7 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
        **kwargs,
    ):
        super().__init__(
-            model=model,
+            base_model=base_model,
            preprocessor=preprocessor,
            loss_type=loss_type,
            optimizer=optimizer,
...
@@ -54,18 +59,28 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
    def pretrain(
        self,
        X_tab: np.ndarray,
+        X_val: Optional[np.ndarray] = None,
+        val_split: Optional[float] = None,
+        validation_freq: int = 1,
        n_epochs: int = 1,
        batch_size: int = 32,
    ):
        self.batch_size = batch_size

-        pretrain_loader = DataLoader(
-            dataset=TensorDataset(torch.from_numpy(X_tab)),
+        train_set, eval_set = self._train_eval_split(X_tab, X_val, val_split)
+        train_loader = DataLoader(
+            dataset=train_set, batch_size=batch_size, num_workers=self.num_workers
        )
+        train_steps = len(train_loader)
+        if eval_set is not None:
+            eval_loader = DataLoader(
+                dataset=eval_set,
+                batch_size=batch_size,
+                num_workers=self.num_workers,
+                shuffle=False,
+            )
-        train_steps = len(pretrain_loader)
+            eval_steps = len(eval_loader)

        self.callback_container.on_train_begin(
            {
...
@@ -80,13 +95,31 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
            self.train_running_loss = 0.0
            with trange(train_steps, disable=self.verbose != 1) as t:
-                for batch_idx, X in zip(t, pretrain_loader):
+                for batch_idx, X in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
-                    train_loss = self._pretrain_step(X[0], batch_idx)
+                    train_loss = self._train_step(X[0], batch_idx)
                    self.callback_container.on_batch_end(batch=batch_idx)
                    print_loss_and_metric(t, train_loss)

            epoch_logs = save_epoch_logs(epoch_logs, train_loss, None, "train")

+            if eval_set is not None and epoch % validation_freq == (
+                validation_freq - 1
+            ):
+                self.callback_container.on_eval_begin()
+                self.valid_running_loss = 0.0
+                with trange(eval_steps, disable=self.verbose != 1) as v:
+                    for batch_idx, X in zip(v, eval_loader):
+                        v.set_description("valid")
+                        val_loss = self._eval_step(X[0], batch_idx)
+                        print_loss_and_metric(v, val_loss)
+                epoch_logs = save_epoch_logs(epoch_logs, val_loss, None, "val")
+            else:
+                if self.reducelronplateau:
+                    raise NotImplementedError(
+                        "ReduceLROnPlateau scheduler can be used only with validation data."
+                    )
            self.callback_container.on_epoch_end(epoch, epoch_logs)

            if self.early_stop:
...
@@ -95,14 +128,41 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
        self.callback_container.on_train_end(epoch_logs)
        self._restore_best_weights()
-        self.ss_model.train()
+        self.model.train()

+    def save(
+        self,
+        path: str,
+        save_state_dict: bool = False,
+        model_filename: str = "cd_model.pt",
+    ):
+        save_dir = Path(path)
+        history_dir = save_dir / "history"
+        history_dir.mkdir(exist_ok=True, parents=True)

-    def _pretrain_step(self, X_tab: Tensor, batch_idx: int):
+        # the trainer is run with the History Callback by default
+        with open(history_dir / "train_eval_history.json", "w") as teh:
+            json.dump(self.history, teh)  # type: ignore[attr-defined]

+        has_lr_history = any(
+            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
+        )
+        if self.lr_scheduler is not None and has_lr_history:
+            with open(history_dir / "lr_history.json", "w") as lrh:
+                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]

+        model_path = save_dir / model_filename
+        if save_state_dict:
+            torch.save(self.model.state_dict(), model_path)
+        else:
+            torch.save(self.model, model_path)

+    def _train_step(self, X_tab: Tensor, batch_idx: int):
        X = X_tab.to(self.device)

        self.optimizer.zero_grad()
-        g_projs, cat_x_and_x_, cont_x_and_x_ = self.ss_model(X)
+        g_projs, cat_x_and_x_, cont_x_and_x_ = self.model(X)
        loss = self._compute_loss(g_projs, cat_x_and_x_, cont_x_and_x_)
        loss.backward()
        self.optimizer.step()
...
@@ -111,3 +171,40 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
        avg_loss = self.train_running_loss / (batch_idx + 1)

        return avg_loss

+    def _eval_step(self, X_tab: Tensor, batch_idx: int):
+        self.model.eval()
+        with torch.no_grad():
+            X = X_tab.to(self.device)
+            g_projs, cat_x_and_x_, cont_x_and_x_ = self.model(X)
+            loss = self._compute_loss(g_projs, cat_x_and_x_, cont_x_and_x_)
+            self.valid_running_loss += loss.item()
+            avg_loss = self.valid_running_loss / (batch_idx + 1)
+        return avg_loss

+    def _train_eval_split(
+        self,
+        X: np.ndarray,
+        X_val: Optional[np.ndarray] = None,
+        val_split: Optional[float] = None,
+    ):
+        if X_val is not None:
+            train_set = TensorDataset(torch.from_numpy(X))
+            eval_set = TensorDataset(torch.from_numpy(X_val))
+        elif val_split is not None:
+            X_tr, X_val = train_test_split(
+                X, test_size=val_split, random_state=self.seed
+            )
+            train_set = TensorDataset(torch.from_numpy(X_tr))
+            eval_set = TensorDataset(torch.from_numpy(X_val))
+        else:
+            train_set = TensorDataset(torch.from_numpy(X))
+            eval_set = None
+
+        return train_set, eval_set
pytorch_widedeep/self_supervised_training/encoder_decoder_trainer.py
0 → 100644

import json
from pathlib import Path

import numpy as np
import torch
from tqdm import trange
from torch import nn
from scipy.sparse import csc_matrix
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.training._trainer_utils import (
    save_epoch_logs,
    print_loss_and_metric,
)
from pytorch_widedeep.self_supervised_training._base_encoder_decoder_trainer import (
    BaseEncoderDecoderTrainer,
)


class EncoderDecoderTrainer(BaseEncoderDecoderTrainer):
    def __init__(
        self,
        encoder: ModelWithoutAttention,
        decoder: nn.Module,
        masked_prob: float,
        optimizer: Optional[Optimizer],
        lr_scheduler: Optional[LRScheduler],
        callbacks: Optional[List[Callback]],
        verbose: int,
        seed: int,
        **kwargs,
    ):
        super().__init__(
            encoder=encoder,
            decoder=decoder,
            masked_prob=masked_prob,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            callbacks=callbacks,
            verbose=verbose,
            seed=seed,
            **kwargs,
        )

    def pretrain(
        self,
        X_tab: np.ndarray,
        X_val: Optional[np.ndarray] = None,
        val_split: Optional[float] = None,
        validation_freq: int = 1,
        n_epochs: int = 1,
        batch_size: int = 32,
    ):
        self.batch_size = batch_size

        train_set, eval_set = self._train_eval_split(X_tab, X_val, val_split)
        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, num_workers=self.num_workers
        )
        train_steps = len(train_loader)
        if eval_set is not None:
            eval_loader = DataLoader(
                dataset=eval_set,
                batch_size=batch_size,
                num_workers=self.num_workers,
                shuffle=False,
            )
            eval_steps = len(eval_loader)

        self.callback_container.on_train_begin(
            {
                "batch_size": batch_size,
                "train_steps": train_steps,
                "n_epochs": n_epochs,
            }
        )
        for epoch in range(n_epochs):
            epoch_logs: Dict[str, float] = {}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)

            self.train_running_loss = 0.0
            with trange(train_steps, disable=self.verbose != 1) as t:
                for batch_idx, X in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
                    train_loss = self._train_step(X[0], batch_idx)
                    self.callback_container.on_batch_end(batch=batch_idx)
                    print_loss_and_metric(t, train_loss)

            epoch_logs = save_epoch_logs(epoch_logs, train_loss, None, "train")

            if eval_set is not None and epoch % validation_freq == (
                validation_freq - 1
            ):
                self.callback_container.on_eval_begin()
                self.valid_running_loss = 0.0
                with trange(eval_steps, disable=self.verbose != 1) as v:
                    for batch_idx, X in zip(v, eval_loader):
                        v.set_description("valid")
                        val_loss = self._eval_step(X[0], batch_idx)
                        print_loss_and_metric(v, val_loss)
                epoch_logs = save_epoch_logs(epoch_logs, val_loss, None, "val")
            else:
                if self.reducelronplateau:
                    raise NotImplementedError(
                        "ReduceLROnPlateau scheduler can be used only with validation data."
                    )

            self.callback_container.on_epoch_end(epoch, epoch_logs)

            if self.early_stop:
                self.callback_container.on_train_end(epoch_logs)
                break

        self.callback_container.on_train_end(epoch_logs)
        self._restore_best_weights()
        self.model.train()

    def save(
        self,
        path: str,
        save_state_dict: bool = False,
        model_filename: str = "ed_model.pt",
    ):
        save_dir = Path(path)
        history_dir = save_dir / "history"
        history_dir.mkdir(exist_ok=True, parents=True)

        # the trainer is run with the History Callback by default
        with open(history_dir / "train_eval_history.json", "w") as teh:
            json.dump(self.history, teh)  # type: ignore[attr-defined]

        has_lr_history = any(
            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
        )
        if self.lr_scheduler is not None and has_lr_history:
            with open(history_dir / "lr_history.json", "w") as lrh:
                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]

        model_path = save_dir / model_filename
        if save_state_dict:
            torch.save(self.model.state_dict(), model_path)
        else:
            torch.save(self.model, model_path)

    def explain(self, X_tab: np.ndarray, save_step_masks: bool = False):
        # TO DO: Adjust this to Self Supervised (e.g. no need of data["deeptabular"])
        loader = DataLoader(
            dataset=self._train_eval_split(X_tab),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

        self.model.eval()
        tabnet_backbone = list(self.encoder.children())[0]

        m_explain_l = []
        for batch_nb, data in enumerate(loader):
            X = data["deeptabular"].to(self.device)
            M_explain, masks = tabnet_backbone.forward_masks(X)
            m_explain_l.append(
                csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
            )
            if save_step_masks:
                for key, value in masks.items():
                    masks[key] = csc_matrix.dot(
                        value.cpu().detach().numpy(), self.reducing_matrix
                    )
                if batch_nb == 0:
                    m_explain_step = masks
                else:
                    for key, value in masks.items():
                        m_explain_step[key] = np.vstack([m_explain_step[key], value])

        m_explain_agg = np.vstack(m_explain_l)
        m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis]

        res = (
            (m_explain_agg_norm, m_explain_step)
            if save_step_masks
            else np.vstack(m_explain_agg_norm)
        )

        return res

    def _train_step(self, X_tab: Tensor, batch_idx: int):
        X = X_tab.to(self.device)

        self.optimizer.zero_grad()
        x_embed, x_embed_rec, mask = self.model(X)
        loss = self.loss_fn(x_embed, x_embed_rec, mask)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        return avg_loss

    def _eval_step(self, X_tab: Tensor, batch_idx: int):
        self.model.eval()
        with torch.no_grad():
            X = X_tab.to(self.device)
            x_embed, x_embed_rec, mask = self.model(X)
            loss = self.loss_fn(x_embed, x_embed_rec, mask)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        return avg_loss

    def _train_eval_split(
        self,
        X: np.ndarray,
        X_val: Optional[np.ndarray] = None,
        val_split: Optional[float] = None,
    ):
        if X_val is not None:
            train_set = TensorDataset(torch.from_numpy(X))
            eval_set = TensorDataset(torch.from_numpy(X_val))
        elif val_split is not None:
            X_tr, X_val = train_test_split(
                X, test_size=val_split, random_state=self.seed
            )
            train_set = TensorDataset(torch.from_numpy(X_tr))
            eval_set = TensorDataset(torch.from_numpy(X_val))
        else:
            train_set = TensorDataset(torch.from_numpy(X))
            eval_set = None

        return train_set, eval_set
pytorch_widedeep/training/_base_bayesian_trainer.py
0 → 100644

import os
import sys
from abc import ABC, abstractmethod

import numpy as np
import torch
from torchmetrics import Metric as TorchMetric
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_widedeep.metrics import Metric, MultipleMetrics
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import (
    History,
    Callback,
    MetricCallback,
    CallbackContainer,
    LRShedulerCallback,
)
from pytorch_widedeep.training._trainer_utils import bayesian_alias_to_loss
from pytorch_widedeep.bayesian_models._base_bayesian_model import (
    BaseBayesianModel,
)


# There are some nuances in the Bayesian Trainer that make it hard to build an
# overall BaseTrainer. We could still perhaps code a very basic Trainer and
# then pass it to a BaseTrainer and BaseBayesianTrainer. However in this
# particular case we prefer code repetition as we believe is a simpler
# solution
class BaseBayesianTrainer(ABC):
    def __init__(
        self,
        model: BaseBayesianModel,
        objective: str,
        custom_loss_function: Optional[Module],
        optimizer: Optimizer,
        lr_scheduler: LRScheduler,
        callbacks: Optional[List[Callback]],
        metrics: Optional[Union[List[Metric], List[TorchMetric]]],
        verbose: int,
        seed: int,
        **kwargs,
    ):
        if objective not in ["binary", "multiclass", "regression"]:
            raise ValueError(
                "If 'custom_loss_function' is not None, 'objective' must be 'binary' "
                "'multiclass' or 'regression', consistent with the loss function"
            )

        self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)

        self.early_stop = False
        self.model = model
        self.model.to(self.device)

        self.verbose = verbose
        self.seed = seed
        self.objective = objective

        self.loss_fn = self._set_loss_fn(objective, custom_loss_function, **kwargs)
        self.optimizer = (
            optimizer
            if optimizer is not None
            else torch.optim.AdamW(self.model.parameters())
        )
        self.lr_scheduler = lr_scheduler
        self._set_lr_scheduler_running_params(lr_scheduler, **kwargs)
        self._set_callbacks_and_metrics(callbacks, metrics)

    @abstractmethod
    def fit(
        self,
        X_tab: np.ndarray,
        target: np.ndarray,
        X_tab_val: Optional[np.ndarray],
        target_val: Optional[np.ndarray],
        val_split: Optional[float],
        n_epochs: int,
        val_freq: int,
        batch_size: int,
        n_train_samples: int,
        n_val_samples: int,
    ):
        raise NotImplementedError("Trainer.fit method not implemented")

    @abstractmethod
    def predict(
        self,
        X_tab: np.ndarray,
        n_samples: int,
        return_samples: bool,
        batch_size: int,
    ) -> np.ndarray:
        raise NotImplementedError("Trainer.predict method not implemented")

    @abstractmethod
    def predict_proba(
        self,
        X_tab: np.ndarray,
        n_samples: int,
        return_samples: bool,
        batch_size: int,
    ) -> np.ndarray:
        raise NotImplementedError("Trainer.predict_proba method not implemented")

    @abstractmethod
    def save(
        self,
        path: str,
        save_state_dict: bool,
        model_filename: str,
    ):
        raise NotImplementedError("Trainer.save method not implemented")

    def _restore_best_weights(self):
        already_restored = any(
            [
                (
                    callback.__class__.__name__ == "EarlyStopping"
                    and callback.restore_best_weights
                )
                for callback in self.callback_container.callbacks
            ]
        )
        if already_restored:
            pass
        else:
            for callback in self.callback_container.callbacks:
                if callback.__class__.__name__ == "ModelCheckpoint":
                    if callback.save_best_only:
                        if self.verbose:
                            print(
                                f"Model weights restored to best epoch: {callback.best_epoch + 1}"
                            )
                        self.model.load_state_dict(callback.best_state_dict)
                    else:
                        if self.verbose:
                            print(
                                "Model weights after training corresponds to the those of the "
                                "final epoch which might not be the best performing weights. Use "
                                "the 'ModelCheckpoint' Callback to restore the best epoch weights."
                            )

    def _set_loss_fn(self, objective, custom_loss_function, **kwargs):
        if custom_loss_function is not None:
            return custom_loss_function

        class_weight = (
            torch.tensor(kwargs["class_weight"]).to(self.device)
            if "class_weight" in kwargs
            else None
        )

        if self.objective != "regression":
            return bayesian_alias_to_loss(objective, weight=class_weight)
        else:
            return bayesian_alias_to_loss(objective)

    def _set_reduce_on_plateau_criterion(
        self, lr_scheduler, reducelronplateau_criterion
    ):
        self.reducelronplateau = False

        if isinstance(lr_scheduler, ReduceLROnPlateau):
            self.reducelronplateau = True

        if self.reducelronplateau and not reducelronplateau_criterion:
            UserWarning(
                "The learning rate scheduler is of type ReduceLROnPlateau. The step method in this"
                " scheduler requires a 'metrics' param that can be either the validation loss or the"
                " validation metric. Please, when instantiating the Trainer, specify which quantity"
                " will be tracked using reducelronplateau_criterion = 'loss' (default) or"
                " reducelronplateau_criterion = 'metric'"
            )
        else:
            self.reducelronplateau_criterion = "loss"

    def _set_lr_scheduler_running_params(self, lr_scheduler, **kwargs):
        reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None)
        self._set_reduce_on_plateau_criterion(lr_scheduler, reducelronplateau_criterion)

        if lr_scheduler is not None:
            self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
        else:
            self.cyclic_lr = False

    def _set_callbacks_and_metrics(self, callbacks, metrics):
        self.callbacks: List = [History(), LRShedulerCallback()]
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)
        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None
        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self.model)
        self.callback_container.set_trainer(self)

    @staticmethod
    def _set_device_and_num_workers(**kwargs):
        default_num_workers = (
            0
            if sys.platform == "darwin" and sys.version_info.minor > 7
            else os.cpu_count()
        )
        default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        device = kwargs.get("device", default_device)
        num_workers = kwargs.get("num_workers", default_num_workers)
        return device, num_workers
pytorch_widedeep/training/_base_trainers.py → pytorch_widedeep/training/_base_trainer.py

...
@@ -17,10 +17,7 @@ from pytorch_widedeep.callbacks import (
    LRShedulerCallback,
)
from pytorch_widedeep.initializers import Initializer, MultipleInitializer
-from pytorch_widedeep.training._trainer_utils import (
-    alias_to_loss,
-    bayesian_alias_to_loss,
-)
+from pytorch_widedeep.training._trainer_utils import alias_to_loss
from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix
from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
...
@@ -28,9 +25,6 @@ from pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod
from pytorch_widedeep.training._multiple_lr_scheduler import (
    MultipleLRScheduler,
)
-from pytorch_widedeep.bayesian_models._base_bayesian_model import (
-    BaseBayesianModel,
-)


class BaseTrainer(ABC):
...
@@ -346,197 +340,3 @@ class BaseTrainer(ABC):
        device = kwargs.get("device", default_device)
        num_workers = kwargs.get("num_workers", default_num_workers)
        return device, num_workers
-
-
-# There are some nuances in the Bayesian Trainer that make it hard to build an
-# overall BaseTrainer. We could still perhaps code a very basic Trainer and
-# then pass it to a BaseTrainer and BaseBayesianTrainer. However in this
-# particular case we prefer code repetition as we believe is a simpler
-# solution
-class BaseBayesianTrainer(ABC):
-    ...
-    # (the remaining deleted lines are the BaseBayesianTrainer class: __init__,
-    # the abstract fit/predict/predict_proba/save methods, _restore_best_weights,
-    # _set_loss_fn, _set_reduce_on_plateau_criterion, _set_lr_scheduler_running_params,
-    # _set_callbacks_and_metrics and _set_device_and_num_workers, moved unchanged
-    # to pytorch_widedeep/training/_base_bayesian_trainer.py shown above)
pytorch_widedeep/training/bayesian_trainer.py

...
@@ -12,12 +12,14 @@ from pytorch_widedeep.metrics import Metric
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.utils.general_utils import Alias
-from pytorch_widedeep.training._base_trainers import BaseBayesianTrainer
from pytorch_widedeep.training._trainer_utils import (
    save_epoch_logs,
    print_loss_and_metric,
    tabular_train_val_split,
)
+from pytorch_widedeep.training._base_bayesian_trainer import (
+    BaseBayesianTrainer,
+)
from pytorch_widedeep.bayesian_models._base_bayesian_model import (
    BaseBayesianModel,
)
...
@@ -128,7 +130,7 @@ class BayesianTrainer(BaseBayesianTrainer):
        target_val: Optional[np.ndarray] = None,
        val_split: Optional[float] = None,
        n_epochs: int = 1,
-        val_freq: int = 1,
+        validation_freq: int = 1,
        batch_size: int = 32,
        n_train_samples: int = 2,
        n_val_samples: int = 2,
...
@@ -202,7 +204,9 @@ class BayesianTrainer(BaseBayesianTrainer):
            epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train")

            on_epoch_end_metric = None
-            if eval_set is not None and epoch % val_freq == (val_freq - 1):
+            if eval_set is not None and epoch % validation_freq == (
+                validation_freq - 1
+            ):
                self.callback_container.on_eval_begin()
                self.valid_running_loss = 0.0
                with trange(eval_steps, disable=self.verbose != 1) as v:
...
pytorch_widedeep/training/trainer.py

...
@@ -21,7 +21,7 @@ from pytorch_widedeep.initializers import Initializer
from pytorch_widedeep.training._finetune import FineTune
from pytorch_widedeep.utils.general_utils import Alias
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
-from pytorch_widedeep.training._base_trainers import BaseTrainer
+from pytorch_widedeep.training._base_trainer import BaseTrainer
from pytorch_widedeep.training._trainer_utils import (
    save_epoch_logs,
    wd_train_val_split,
...
pytorch_widedeep/wdtypes.py

...
@@ -71,7 +71,7 @@ from torchvision.transforms import (
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data.dataloader import DataLoader

-from pytorch_widedeep.models import WideDeep
+from pytorch_widedeep.models import *
from pytorch_widedeep.models.tabular.tabnet.sparsemax import (
    Entmax15,
    Sparsemax,
...
@@ -122,3 +122,15 @@ Transforms = Union[
LRScheduler = _LRScheduler
ModelParams = Generator[Tensor, Tensor, Tensor]
NormLayers = Union[torch.nn.Identity, torch.nn.LayerNorm, torch.nn.BatchNorm1d]

ModelWithAttention = Union[
    TabTransformer,
    SAINT,
    FTTransformer,
    TabFastFormer,
    TabPerceiver,
    ContextAttentionMLP,
    SelfAttentionMLP,
]
ModelWithoutAttention = Union[TabMlp, TabResnet, TabNet]