Commit c36b38c7
Authored Mar 05, 2021 by jrzaurin
lr scheduler step as a Callback and adjusted a few tests
Parent: 6def0a0b

Showing 6 changed files with 77 additions and 57 deletions (+77 -57)
examples/adult_census.py                             +3  -3
pytorch_widedeep/callbacks.py                        +47 -0
pytorch_widedeep/models/tabnet/tab_net.py            +5  -5
pytorch_widedeep/models/wide_deep.py                 +19 -2
pytorch_widedeep/training/trainer.py                 +2  -46
tests/test_model_functioning/test_miscellaneous.py   +1  -1
examples/adult_census.py

@@ -81,8 +81,8 @@ if __name__ == "__main__":
     wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01)
     deep_opt = RAdam(model.deeptabular.parameters())
-    wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)
-    deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)
+    wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=2)
+    deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
     optimizers = {"wide": wide_opt, "deeptabular": deep_opt}
     schedulers = {"wide": wide_sch, "deeptabular": deep_sch}

@@ -108,7 +108,7 @@ if __name__ == "__main__":
         X_wide=X_wide,
         X_tab=X_tab,
         target=target,
-        n_epochs=4,
+        n_epochs=10,
         batch_size=64,
         val_split=0.2,
     )
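For context, a minimal sketch of how these dicts are consumed. The `optimizers` and `lr_schedulers` keyword names and the public `Trainer` import are assumptions based on the library's API around this commit; the other variables come from the example above.

    # Wiring sketch (assumptions noted above); reuses the objects defined in
    # examples/adult_census.py.
    from pytorch_widedeep import Trainer  # assumed public import path

    trainer = Trainer(
        model,
        objective="binary",
        optimizers={"wide": wide_opt, "deeptabular": deep_opt},
        lr_schedulers={"wide": wide_sch, "deeptabular": deep_sch},
    )
    trainer.fit(
        X_wide=X_wide, X_tab=X_tab, target=target,
        n_epochs=10, batch_size=64, val_split=0.2,
    )

With per-component schedulers registered this way, the new callback introduced in the next file decides when each one steps.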
pytorch_widedeep/callbacks.py

@@ -134,6 +134,53 @@ class History(Callback):
             self.trainer.history.setdefault(k, []).append(v)


+class LRShedulerCallback(Callback):
+    r"""Callback for the learning rate schedulers to take a step
+
+    This callback runs by default within :obj:`Trainer`, therefore, should not
+    be passed to the ``Trainer``. Is included here just for completion.
+    """
+
+    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
+        if self.trainer.lr_scheduler is not None:
+            if self._multiple_scheduler():
+                for (
+                    model_name,
+                    scheduler,
+                ) in self.trainer.lr_scheduler._schedulers.items():
+                    if self._is_cyclic(model_name):
+                        scheduler.step()
+            elif self.trainer.cyclic_lr:
+                self.trainer.lr_scheduler.step()
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+        if self.trainer.lr_scheduler is not None:
+            if self._multiple_scheduler():
+                for (
+                    model_name,
+                    scheduler,
+                ) in self.trainer.lr_scheduler._schedulers.items():
+                    if not self._is_cyclic(model_name):
+                        scheduler.step()
+            elif not self.trainer.cyclic_lr:
+                self.trainer.lr_scheduler.step()
+
+    def _multiple_scheduler(self):
+        return self.trainer.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
+
+    def _is_cyclic(self, model_name: str):
+        return (
+            self._has_scheduler(model_name)
+            and "cycl"
+            in self.trainer.lr_scheduler._schedulers[
+                model_name
+            ].__class__.__name__.lower()
+        )
+
+    def _has_scheduler(self, model_name: str):
+        return model_name in self.trainer.lr_scheduler._schedulers
+
+
 class LRHistory(Callback):
     def __init__(self, n_epochs: int):
         r"""Saves the learning rates during training to a ``lr_history`` attribute.
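The callback decides per scheduler whether to step per batch (cyclic schedulers such as CyclicLR or OneCycleLR) or per epoch (everything else), using a substring check on the class name. A self-contained illustration of that check with plain PyTorch schedulers (not library code):

    import torch

    lin = torch.nn.Linear(2, 1)
    opt = torch.optim.SGD(lin.parameters(), lr=0.1)

    schedulers = [
        torch.optim.lr_scheduler.CyclicLR(
            opt, base_lr=0.01, max_lr=0.1, cycle_momentum=False
        ),
        torch.optim.lr_scheduler.StepLR(opt, step_size=3),
    ]
    for sch in schedulers:
        # the same test LRShedulerCallback._is_cyclic applies per scheduler
        cyclic = "cycl" in sch.__class__.__name__.lower()
        print(sch.__class__.__name__, "steps", "on_batch_end" if cyclic else "on_epoch_end")
    # CyclicLR steps on_batch_end
    # StepLR steps on_epoch_end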
pytorch_widedeep/models/tabnet/tab_net.py

@@ -367,17 +367,17 @@ class EmbeddingsAndContinuous(nn.Module):
             }
         )
         self.embedding_dropout = nn.Dropout(embed_dropout)
-        emb_inp_dim = np.sum([embed[2] for embed in self.embed_input])
+        emb_out_dim = np.sum([embed[2] for embed in self.embed_input])

         # Continuous
         if self.continuous_cols is not None:
-            cont_inp_dim = len(self.continuous_cols)
+            cont_out_dim = len(self.continuous_cols)
             if self.batchnorm_cont:
-                self.norm = nn.BatchNorm1d(cont_inp_dim)
+                self.norm = nn.BatchNorm1d(cont_out_dim)
         else:
-            cont_inp_dim = 0
+            cont_out_dim = 0

-        self.output_dim = emb_inp_dim + cont_inp_dim
+        self.output_dim = emb_out_dim + cont_out_dim

     def forward(self, X):
         embed = [
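The rename from `*_inp_dim` to `*_out_dim` is purely cosmetic: these values describe the module's output width, which is unchanged. Assuming `embed_input` holds (column name, number of categories, embedding dim) tuples, as elsewhere in the library, the arithmetic is:

    import numpy as np

    # Hypothetical embed_input; the third element is the embedding output size.
    embed_input = [("workclass", 9, 16), ("education", 16, 16)]
    continuous_cols = ["age", "hours_per_week"]

    emb_out_dim = np.sum([embed[2] for embed in embed_input])  # 32
    cont_out_dim = len(continuous_cols)                        # 2
    print(emb_out_dim + cont_out_dim)                          # output_dim = 34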
pytorch_widedeep/models/wide_deep.py

@@ -164,10 +164,19 @@ class WideDeep(nn.Module):
         if self.deeptabular is not None:
             self.is_tabnet = deeptabular.__class__.__name__ == "TabNet"
         else:
             self.is_tabnet = False

         if self.deephead is None:
             if head_hidden_dims is not None:
-                self._build_deephead()
+                self._build_deephead(
+                    head_hidden_dims,
+                    head_activation,
+                    head_dropout,
+                    head_batchnorm,
+                    head_batchnorm_last,
+                    head_linear_first,
+                )
             else:
                 self._add_pred_layer()

@@ -178,7 +187,15 @@ class WideDeep(nn.Module):
         else:
             return self._forward_deep(X, wide_out)

-    def _build_deephead(self):
+    def _build_deephead(
+        self,
+        head_hidden_dims,
+        head_activation,
+        head_dropout,
+        head_batchnorm,
+        head_batchnorm_last,
+        head_linear_first,
+    ):
         deep_dim = 0
         if self.deeptabular is not None:
             deep_dim += self.deeptabular.output_dim
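`_build_deephead` now receives the head hyperparameters explicitly instead of reading them from instance attributes, keeping the constructor's locals flowing in one direction. A constructor-level sketch of the path that triggers it; `wide` and `deeptabular` are assumed to be already-built component models, and the `head_*` names are taken from the diff:

    from pytorch_widedeep.models import WideDeep

    # head_hidden_dims is not None, so __init__ calls _build_deephead(...)
    model = WideDeep(
        wide=wide,
        deeptabular=deeptabular,
        head_hidden_dims=[128, 64],
        head_activation="relu",
        head_dropout=0.1,
    )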
pytorch_widedeep/training/trainer.py

@@ -12,7 +12,7 @@ from pytorch_widedeep.losses import MSLELoss, RMSELoss, FocalLoss, RMSLELoss
 from pytorch_widedeep.models import WideDeep
 from pytorch_widedeep.metrics import Metric, MetricCallback, MultipleMetrics
 from pytorch_widedeep.wdtypes import *  # noqa: F403
-from pytorch_widedeep.callbacks import History, Callback, CallbackContainer
+from pytorch_widedeep.callbacks import History, LRShedulerCallback, Callback, CallbackContainer
 from pytorch_widedeep.initializers import Initializer, MultipleInitializer
 from pytorch_widedeep.training._finetune import FineTune
 from pytorch_widedeep.utils.general_utils import Alias

@@ -554,8 +554,6 @@ class Trainer:
                 )
             else:
                 t.set_postfix(loss=train_loss)
-            if self.lr_scheduler:
-                self._lr_scheduler_step(step_location="on_batch_end")
             self.callback_container.on_batch_end(batch=batch_idx)
         epoch_logs["train_loss"] = train_loss

         if score is not None:

@@ -582,8 +580,6 @@ class Trainer:
             for k, v in score.items():
                 log_k = "_".join(["val", k])
                 epoch_logs[log_k] = v
-        if self.lr_scheduler:
-            self._lr_scheduler_step(step_location="on_epoch_end")
         self.callback_container.on_epoch_end(epoch, epoch_logs)
         if self.early_stop:
             self.callback_container.on_train_end(epoch_logs)

@@ -936,46 +932,6 @@ class Trainer:
                 self.model.deepimage, "deepimage", loader, n_epochs, max_lr
             )

-    def _lr_scheduler_step(self, step_location: str):  # noqa: C901
-        r"""
-        Function to execute the learning rate schedulers steps.
-        If the lr_scheduler is Cyclic (i.e. CyclicLR or OneCycleLR), the step
-        must happen after training each bach durig training. On the other
-        hand, if the scheduler is not Cyclic, is expected to be called after
-        validation. (Consider coding this function as callback)
-
-        Parameters
-        ----------
-        step_location: Str
-            Indicates where to run the lr_scheduler step
-        """
-        if (
-            self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
-            and self.cyclic_lr
-        ):
-            if step_location == "on_batch_end":
-                for model_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
-                    if "cycl" in scheduler.__class__.__name__.lower():
-                        scheduler.step()  # type: ignore
-            elif step_location == "on_epoch_end":
-                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
-                    if "cycl" not in scheduler.__class__.__name__.lower():
-                        scheduler.step()  # type: ignore
-        elif self.cyclic_lr:
-            if step_location == "on_batch_end":
-                self.lr_scheduler.step()  # type: ignore
-            else:
-                pass
-        elif self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler":
-            if step_location == "on_epoch_end":
-                self.lr_scheduler.step()  # type: ignore
-            else:
-                pass
-        elif step_location == "on_epoch_end":
-            self.lr_scheduler.step()  # type: ignore
-        else:
-            pass
-
     def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
         self.model.train()
         X = {k: v.cuda() for k, v in data.items()} if use_cuda else data

@@ -1192,7 +1148,7 @@ class Trainer:
         return None

     def _set_callbacks_and_metrics(self, callbacks, metrics):
-        self.callbacks: List = [History()]
+        self.callbacks: List = [History(), LRShedulerCallback()]
         if callbacks is not None:
             for callback in callbacks:
                 if isinstance(callback, type):
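The net effect of this file's changes: the training loop no longer special-cases schedulers, it only fires the generic callback hooks, and `LRShedulerCallback` (registered by default in `_set_callbacks_and_metrics`) does the stepping. A self-contained sketch of that shape, with simplified names rather than the library's exact classes:

    import torch

    class SchedulerCallback:
        """Steps cyclic schedulers per batch, all others per epoch."""
        def __init__(self, scheduler, cyclic):
            self.scheduler, self.cyclic = scheduler, cyclic
        def on_batch_end(self):
            if self.cyclic:
                self.scheduler.step()
        def on_epoch_end(self):
            if not self.cyclic:
                self.scheduler.step()

    lin = torch.nn.Linear(1, 1)
    opt = torch.optim.SGD(lin.parameters(), lr=0.1)
    cb = SchedulerCallback(torch.optim.lr_scheduler.StepLR(opt, step_size=1), cyclic=False)

    for epoch in range(2):
        for batch in range(3):
            cb.on_batch_end()   # no-op for a non-cyclic scheduler
        cb.on_epoch_end()       # StepLR decays the lr once per epoch
        print(epoch, opt.param_groups[0]["lr"])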
tests/test_model_functioning/test_miscellaneous.py

@@ -110,7 +110,7 @@ def test_non_instantiated_callbacks():
     model = WideDeep(wide=wide, deeptabular=tabmlp)
     callbacks = [EarlyStopping]
     trainer = Trainer(model, objective="binary", callbacks=callbacks)

-    assert trainer.callbacks[1].__class__.__name__ == "EarlyStopping"
+    assert trainer.callbacks[2].__class__.__name__ == "EarlyStopping"


###############################################################################
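The assertion index moves from 1 to 2 because the `Trainer` now registers two default callbacks ahead of any user-supplied ones: `trainer.callbacks` starts as `[History(), LRShedulerCallback()]` (see the `_set_callbacks_and_metrics` change above), so the instantiated `EarlyStopping` lands at index 2.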