Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Pytorch Widedeep
提交
42cfe5cc
P
Pytorch Widedeep
项目概览
Greenplum
/
Pytorch Widedeep
大约 1 年 前同步成功
通知
9
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Pytorch Widedeep
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
42cfe5cc
编写于
7月 14, 2023
作者:
J
Javier
提交者:
GitHub
7月 14, 2023
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #177 from jrzaurin/fix_restore_best_weights
Fix #175 early stopping and model checkpoint restoring weights.
上级
26e19859
8406813c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
188 addition
and
39 deletion
+188
-39
examples/scripts/bio_imbalanced_loader.py
examples/scripts/bio_imbalanced_loader.py
+1
-1
pytorch_widedeep/callbacks.py
pytorch_widedeep/callbacks.py
+19
-21
pytorch_widedeep/dataloaders.py
pytorch_widedeep/dataloaders.py
+1
-0
pytorch_widedeep/training/_base_trainer.py
pytorch_widedeep/training/_base_trainer.py
+28
-10
pytorch_widedeep/training/_finetune.py
pytorch_widedeep/training/_finetune.py
+2
-1
tests/test_model_functioning/test_callbacks.py
tests/test_model_functioning/test_callbacks.py
+137
-6
未找到文件。
examples/scripts/bio_imbalanced_loader.py
浏览文件 @
42cfe5cc
...
...
@@ -86,7 +86,7 @@ trainer.fit(
n_epochs
=
1
,
batch_size
=
32
,
custom_dataloader
=
DataLoaderImbalanced
,
oversample_mul
=
5
,
**
{
"oversample_mul"
:
5
}
,
)
print
(
"Training time[s]: {}"
.
format
(
...
...
pytorch_widedeep/callbacks.py
浏览文件 @
42cfe5cc
...
...
@@ -4,9 +4,9 @@ Code here is mostly based on the code from the torchsample and Keras packages
CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
"""
import
os
import
copy
import
datetime
import
warnings
import
copy
import
numpy
as
np
import
torch
...
...
@@ -349,6 +349,10 @@ class ModelCheckpoint(Callback):
monitor: str, default="loss"
quantity to monitor. Typically _'val_loss'_ or metric name
(e.g. _'val_acc'_)
min_delta: float, default=0.
minimum change in the monitored quantity to qualify as an
improvement, i.e. an absolute change of less than min_delta, will
count as no improvement.
verbose:int, default=0
verbosity mode
save_best_only: bool, default=False,
...
...
@@ -397,6 +401,7 @@ class ModelCheckpoint(Callback):
self
,
filepath
:
Optional
[
str
]
=
None
,
monitor
:
str
=
"val_loss"
,
min_delta
:
float
=
0.0
,
verbose
:
int
=
0
,
save_best_only
:
bool
=
False
,
mode
:
str
=
"auto"
,
...
...
@@ -407,6 +412,7 @@ class ModelCheckpoint(Callback):
self
.
filepath
=
filepath
self
.
monitor
=
monitor
self
.
min_delta
=
min_delta
self
.
verbose
=
verbose
self
.
save_best_only
=
save_best_only
self
.
mode
=
mode
...
...
@@ -450,6 +456,11 @@ class ModelCheckpoint(Callback):
self
.
monitor_op
=
np
.
less
self
.
best
=
np
.
Inf
if
self
.
monitor_op
==
np
.
greater
:
self
.
min_delta
*=
1
else
:
self
.
min_delta
*=
-
1
def
on_epoch_end
(
# noqa: C901
self
,
epoch
:
int
,
logs
:
Optional
[
Dict
]
=
None
,
metric
:
Optional
[
float
]
=
None
):
...
...
@@ -468,33 +479,20 @@ class ModelCheckpoint(Callback):
RuntimeWarning
,
)
else
:
if
self
.
monitor_op
(
current
,
self
.
best
):
if
self
.
monitor_op
(
current
-
self
.
min_delta
,
self
.
best
):
if
self
.
verbose
>
0
:
if
self
.
filepath
:
print
(
"
\n
Epoch %05d: %s improved from %0.5f to %0.5f,"
" saving model to %s"
%
(
epoch
+
1
,
self
.
monitor
,
self
.
best
,
current
,
filepath
,
)
f
"
\n
Epoch
{
epoch
+
1
}
:
{
self
.
monitor
}
improved from
{
self
.
best
:.
5
f
}
to
{
current
:.
5
f
}
"
f
"Saving model to
{
filepath
}
"
)
else
:
print
(
"
\n
Epoch %05d: %s improved from %0.5f to %0.5f"
%
(
epoch
+
1
,
self
.
monitor
,
self
.
best
,
current
,
)
f
"
\n
Epoch
{
epoch
+
1
}
:
{
self
.
monitor
}
improved from
{
self
.
best
:.
5
f
}
to
{
current
:.
5
f
}
"
)
self
.
best
=
current
self
.
best_epoch
=
epoch
self
.
best_state_dict
=
self
.
model
.
state_dict
(
)
self
.
best_state_dict
=
copy
.
deepcopy
(
self
.
model
.
state_dict
()
)
if
self
.
filepath
:
torch
.
save
(
self
.
best_state_dict
,
filepath
)
if
self
.
max_save
>
0
:
...
...
@@ -508,8 +506,8 @@ class ModelCheckpoint(Callback):
else
:
if
self
.
verbose
>
0
:
print
(
"
\n
Epoch %05d: %s did not improve from %0.5f
"
%
(
epoch
+
1
,
self
.
monitor
,
self
.
best
)
f
"
\n
Epoch
{
epoch
+
1
}
:
{
self
.
monitor
}
did not improve from
{
self
.
best
:.
5
f
}
"
f
" considering a 'min_delta' improvement of
{
self
.
min_delta
:.
5
f
}
"
)
if
not
self
.
save_best_only
and
self
.
filepath
:
if
self
.
verbose
>
0
:
...
...
pytorch_widedeep/dataloaders.py
浏览文件 @
42cfe5cc
...
...
@@ -85,6 +85,7 @@ class DataLoaderImbalanced(DataLoader):
self
.
with_lds
=
dataset
.
with_lds
if
"oversample_mul"
in
kwargs
:
oversample_mul
=
kwargs
[
"oversample_mul"
]
del
kwargs
[
"oversample_mul"
]
else
:
oversample_mul
=
1
weights
,
minor_cls_cnt
,
num_clss
=
get_class_weights
(
dataset
)
...
...
pytorch_widedeep/training/_base_trainer.py
浏览文件 @
42cfe5cc
import
os
import
sys
import
warnings
from
abc
import
ABC
,
abstractmethod
import
numpy
as
np
...
...
@@ -130,17 +131,34 @@ class BaseTrainer(ABC):
):
raise
NotImplementedError
(
"Trainer.save method not implemented"
)
def
_restore_best_weights
(
self
):
already_restored
=
any
(
[
(
callback
.
__class__
.
__name__
==
"EarlyStopping"
and
callback
.
restore_best_weights
)
for
callback
in
self
.
callback_container
.
callbacks
]
)
def
_restore_best_weights
(
self
):
# noqa: C901
early_stopping_min_delta
=
None
model_checkpoint_min_delta
=
None
already_restored
=
False
for
callback
in
self
.
callback_container
.
callbacks
:
if
(
callback
.
__class__
.
__name__
==
"EarlyStopping"
and
callback
.
restore_best_weights
):
early_stopping_min_delta
=
callback
.
min_delta
already_restored
=
True
if
callback
.
__class__
.
__name__
==
"ModelCheckpoint"
:
model_checkpoint_min_delta
=
callback
.
min_delta
if
(
early_stopping_min_delta
is
not
None
and
model_checkpoint_min_delta
is
not
None
)
and
(
early_stopping_min_delta
!=
model_checkpoint_min_delta
):
warnings
.
warn
(
"'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. "
"This implies a different definition of 'improvement' for these two callbacks"
,
UserWarning
,
)
if
already_restored
:
# already restored via EarlyStopping
pass
else
:
for
callback
in
self
.
callback_container
.
callbacks
:
...
...
pytorch_widedeep/training/_finetune.py
浏览文件 @
42cfe5cc
...
...
@@ -317,6 +317,7 @@ class FineTune:
up, down: Tuple, int
number of steps increasing/decreasing the learning rate during the cycle
"""
up
=
round
((
steps
*
n_epochs
)
*
0.1
)
# up = round((steps * n_epochs) * 0.1)
up
=
max
([
round
((
steps
*
n_epochs
)
*
0.1
),
1
])
down
=
(
steps
*
n_epochs
)
-
up
return
up
,
down
tests/test_model_functioning/test_callbacks.py
浏览文件 @
42cfe5cc
...
...
@@ -481,18 +481,21 @@ def test_early_stopping_get_state():
assert
no_trainer
and
no_model
def
test_early_stopping_restore_state
():
# min_delta is large, so the early stopping condition will never be met except for the first epoch.
# ##############################################################################
# Test the restore weights functionalities after bug fixed
# ##############################################################################
def
test_early_stopping_restore_weights_with_metric
():
# min_delta is large, so the early stopping condition will be met in the first epoch.
early_stopping
=
EarlyStopping
(
restore_best_weights
=
True
,
min_delta
=
1000
,
patience
=
1000
)
trainer
_tt
=
Trainer
(
trainer
=
Trainer
(
model
,
objective
=
"regression"
,
callbacks
=
[
early_stopping
],
verbose
=
0
,
)
trainer
_tt
.
fit
(
trainer
.
fit
(
X_train
=
{
"X_wide"
:
X_wide
,
"X_tab"
:
X_tab
,
"target"
:
target
},
X_val
=
{
"X_wide"
:
X_wide_val
,
"X_tab"
:
X_tab_val
,
"target"
:
target_val
},
target
=
target
,
...
...
@@ -501,8 +504,136 @@ def test_early_stopping_restore_state():
)
assert
early_stopping
.
wait
>
0
# so early stopping is not triggered, but is over-fitting.
pred_val
=
trainer
_tt
.
predict
(
X_test
=
{
"X_wide"
:
X_wide_val
,
"X_tab"
:
X_tab_val
})
restored_metric
=
trainer
_tt
.
loss_fn
(
pred_val
=
trainer
.
predict
(
X_test
=
{
"X_wide"
:
X_wide_val
,
"X_tab"
:
X_tab_val
})
restored_metric
=
trainer
.
loss_fn
(
torch
.
tensor
(
pred_val
),
torch
.
tensor
(
target_val
)
).
item
()
assert
np
.
allclose
(
restored_metric
,
early_stopping
.
best
)
def
test_early_stopping_restore_weights_with_state
():
# Long, perhaps too long, test to check early_stopping restore weights
# functionality
# this is repetitive, but for now I want this unit test "self-contained"
# We first define a model and train it, with early stopping that should
# set the weights back to those after the 1st epoch. We also use
# ModelCheckpoint and save all iterations
wide
=
Wide
(
np
.
unique
(
X_wide
).
shape
[
0
],
1
)
deeptabular
=
TabMlp
(
column_idx
=
column_idx
,
cat_embed_input
=
embed_input
,
continuous_cols
=
colnames
[
-
5
:],
mlp_hidden_dims
=
[
16
,
8
],
)
model
=
WideDeep
(
wide
=
wide
,
deeptabular
=
deeptabular
)
fpath
=
"tests/test_model_functioning/modelcheckpoint/weights_out"
model_checkpoint
=
ModelCheckpoint
(
filepath
=
fpath
,
save_best_only
=
False
,
max_save
=
10
,
min_delta
=
1000
,
# irrelevant here
)
early_stopping
=
EarlyStopping
(
patience
=
3
,
min_delta
=
1000
,
restore_best_weights
=
True
)
trainer
=
Trainer
(
model
,
objective
=
"binary"
,
callbacks
=
[
early_stopping
,
model_checkpoint
],
verbose
=
0
,
)
trainer
.
fit
(
X_train
=
{
"X_wide"
:
X_wide
,
"X_tab"
:
X_tab
,
"target"
:
target
},
X_val
=
{
"X_wide"
:
X_wide_val
,
"X_tab"
:
X_tab_val
,
"target"
:
target_val
},
target
=
target
,
n_epochs
=
5
,
batch_size
=
16
,
)
# We now define a brand new model
new_wide
=
Wide
(
np
.
unique
(
X_wide
).
shape
[
0
],
1
)
new_deeptabular
=
TabMlp
(
column_idx
=
column_idx
,
cat_embed_input
=
embed_input
,
continuous_cols
=
colnames
[
-
5
:],
mlp_hidden_dims
=
[
16
,
8
],
)
new_model
=
WideDeep
(
wide
=
new_wide
,
deeptabular
=
new_deeptabular
)
# In general, the best epoch is equal to the (stopped_epoch - patience) + 1
full_best_epoch_path
=
"_"
.
join
(
[
model_checkpoint
.
filepath
,
str
((
early_stopping
.
stopped_epoch
-
early_stopping
.
patience
)
+
1
)
+
".p"
,
]
)
# we load the weights for the best epoch and these should match those of
# the original model if early_stopping worked
new_model
.
load_state_dict
(
torch
.
load
(
full_best_epoch_path
))
new_model
.
to
(
next
(
model
.
parameters
()).
device
)
shutil
.
rmtree
(
"tests/test_model_functioning/modelcheckpoint/"
)
assert
torch
.
allclose
(
new_model
.
state_dict
()[
"deeptabular.0.encoder.mlp.dense_layer_1.1.weight"
],
model
.
state_dict
()[
"deeptabular.0.encoder.mlp.dense_layer_1.1.weight"
],
)
def
test_model_checkpoint_restore_weights
():
wide
=
Wide
(
np
.
unique
(
X_wide
).
shape
[
0
],
1
)
deeptabular
=
TabMlp
(
column_idx
=
column_idx
,
cat_embed_input
=
embed_input
,
continuous_cols
=
colnames
[
-
5
:],
mlp_hidden_dims
=
[
16
,
8
],
)
model
=
WideDeep
(
wide
=
wide
,
deeptabular
=
deeptabular
)
fpath
=
"tests/test_model_functioning/modelcheckpoint/weights_out"
model_checkpoint
=
ModelCheckpoint
(
filepath
=
fpath
,
save_best_only
=
True
,
min_delta
=
1000
,
# irrelevant here
)
trainer
=
Trainer
(
model
,
objective
=
"binary"
,
callbacks
=
[
model_checkpoint
],
verbose
=
0
,
)
trainer
.
fit
(
X_train
=
{
"X_wide"
:
X_wide
,
"X_tab"
:
X_tab
,
"target"
:
target
},
X_val
=
{
"X_wide"
:
X_wide_val
,
"X_tab"
:
X_tab_val
,
"target"
:
target_val
},
target
=
target
,
n_epochs
=
5
,
batch_size
=
16
,
)
new_wide
=
Wide
(
np
.
unique
(
X_wide
).
shape
[
0
],
1
)
new_deeptabular
=
TabMlp
(
column_idx
=
column_idx
,
cat_embed_input
=
embed_input
,
continuous_cols
=
colnames
[
-
5
:],
mlp_hidden_dims
=
[
16
,
8
],
)
new_model
=
WideDeep
(
wide
=
new_wide
,
deeptabular
=
new_deeptabular
)
full_best_epoch_path
=
"_"
.
join
(
[
model_checkpoint
.
filepath
,
str
(
model_checkpoint
.
best_epoch
+
1
)
+
".p"
]
)
new_model
.
load_state_dict
(
torch
.
load
(
full_best_epoch_path
))
new_model
.
to
(
next
(
model
.
parameters
()).
device
)
shutil
.
rmtree
(
"tests/test_model_functioning/modelcheckpoint/"
)
assert
torch
.
allclose
(
new_model
.
state_dict
()[
"deeptabular.0.encoder.mlp.dense_layer_1.1.weight"
],
model
.
state_dict
()[
"deeptabular.0.encoder.mlp.dense_layer_1.1.weight"
],
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录