Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Pytorch Widedeep
提交
fd3e4e34
P
Pytorch Widedeep
项目概览
Greenplum
/
Pytorch Widedeep
11 个月 前同步成功
通知
9
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Pytorch Widedeep
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
fd3e4e34
编写于
1月 03, 2022
作者:
J
jrzaurin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added tests for two minor changes in the TabPreprocessor and the Trainer
上级
cee8e1c7
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
39 addition
and
1 deletion
+39
-1
pytorch_widedeep/preprocessing/tab_preprocessor.py
pytorch_widedeep/preprocessing/tab_preprocessor.py
+3
-1
tests/test_data_utils/test_du_tabular.py
tests/test_data_utils/test_du_tabular.py
+16
-0
tests/test_model_functioning/test_fit_methods.py
tests/test_model_functioning/test_fit_methods.py
+20
-0
未找到文件。
pytorch_widedeep/preprocessing/tab_preprocessor.py
浏览文件 @
fd3e4e34
...
...
@@ -295,7 +295,9 @@ class TabPreprocessor(BasePreprocessor):
and
self
.
continuous_cols
is
not
None
and
len
(
np
.
intersect1d
(
self
.
cat_embed_cols
,
self
.
continuous_cols
))
>
0
):
overlapping_cols
=
list
(
np
.
intersect1d
(
cat_embed_cols
,
continuous_cols
))
overlapping_cols
=
list
(
np
.
intersect1d
(
self
.
cat_embed_cols
,
self
.
continuous_cols
)
)
raise
ValueError
(
"Currently passing columns as both categorical and continuum is not supported."
" Please, choose one or the other for the following columns: {}"
.
format
(
...
...
tests/test_data_utils/test_du_tabular.py
浏览文件 @
fd3e4e34
...
...
@@ -294,3 +294,19 @@ def test_embed_sz_rule_of_thumb(rule):
tab_preprocessor
.
embed_dim
[
col
]
==
embed_szs
[
col
]
for
col
in
embed_szs
.
keys
()
]
assert
all
(
out
)
###############################################################################
# Test Valuerror for repeated cols
###############################################################################
def
test_overlapping_cols_valueerror
():
embed_cols
=
[
"col1"
,
"col2"
]
cont_cols
=
[
"col1"
,
"col2"
]
with
pytest
.
raises
(
ValueError
):
tab_preprocessor
=
TabPreprocessor
(
# noqa: F841
cat_embed_cols
=
embed_cols
,
continuous_cols
=
cont_cols
)
tests/test_model_functioning/test_fit_methods.py
浏览文件 @
fd3e4e34
...
...
@@ -300,3 +300,23 @@ def test_custom_dataloader():
)
# simply checking that runs with DataLoaderImbalanced
assert
"train_loss"
in
trainer
.
history
.
keys
()
##############################################################################
# Test raise warning for multiclass classification
##############################################################################
def
test_multiclass_warning
():
wide
=
Wide
(
np
.
unique
(
X_wide
).
shape
[
0
],
1
)
deeptabular
=
TabMlp
(
column_idx
=
column_idx
,
cat_embed_input
=
embed_input
,
continuous_cols
=
colnames
[
-
5
:],
mlp_hidden_dims
=
[
32
,
16
],
mlp_dropout
=
[
0.5
,
0.5
],
)
model
=
WideDeep
(
wide
=
wide
,
deeptabular
=
deeptabular
)
with
pytest
.
raises
(
ValueError
):
trainer
=
Trainer
(
model
,
loss
=
"multiclass"
,
verbose
=
0
)
# noqa: F841
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录