Commit 9e9c4d28
Authored Apr 06, 2022 by Alexander Shirkov

[bugfix] unable to handle categorical column names with dots

Parent: efa8793e
Showing 2 changed files with 61 additions and 5 deletions (+61, -5):

pytorch_widedeep/models/tabular/embeddings_layers.py   +24  -5
tests/test_model_functioning/test_miscellaneous.py     +37  -0
未找到文件。
pytorch_widedeep/models/tabular/embeddings_layers.py
@@ -128,10 +128,19 @@ class DiffSizeCatEmbeddings(nn.Module):
         self.embed_input = embed_input
         self.use_bias = use_bias
 
+        self.embed_layers_names = None
+        if self.embed_input is not None:
+            self.embed_layers_names = {
+                e[0]: e[0].replace(".", "_") for e in self.embed_input
+            }
+
         # Categorical: val + 1 because 0 is reserved for padding/unseen cateogories.
-        self.embed_layers = nn.ModuleDict(
-            {
-                "emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
-                for col, val, dim in self.embed_input
-            }
-        )
+        self.embed_layers = nn.ModuleDict(
+            {
+                "emb_layer_"
+                + self.embed_layers_names[col]: nn.Embedding(
+                    val + 1, dim, padding_idx=0
+                )
+                for col, val, dim in self.embed_input
+            }
+        )
@@ -152,7 +161,9 @@ class DiffSizeCatEmbeddings(nn.Module):
 
     def forward(self, X: Tensor) -> Tensor:
         embed = [
-            self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long())
+            self.embed_layers["emb_layer_" + self.embed_layers_names[col]](
+                X[:, self.column_idx[col]].long()
+            )
             + (
                 self.biases["bias_" + col].unsqueeze(0)
                 if self.use_bias
@@ -186,6 +197,12 @@ class SameSizeCatEmbeddings(nn.Module):
         self.shared_embed = shared_embed
         self.with_cls_token = "cls_token" in column_idx
 
+        self.embed_layers_names = None
+        if self.embed_input is not None:
+            self.embed_layers_names = {
+                e[0]: e[0].replace(".", "_") for e in self.embed_input
+            }
+
         categorical_cols = [ei[0] for ei in embed_input]
         self.cat_idx = [self.column_idx[col] for col in categorical_cols]
 
@@ -211,7 +228,7 @@ class SameSizeCatEmbeddings(nn.Module):
             self.embed: Union[nn.ModuleDict, nn.Embedding] = nn.ModuleDict(
                 {
                     "emb_layer_"
-                    + col: SharedEmbeddings(
+                    + self.embed_layers_names[col]: SharedEmbeddings(
                         val if col == "cls_token" else val + 1,
                         embed_dim,
                         embed_dropout,
@@ -233,9 +250,11 @@ class SameSizeCatEmbeddings(nn.Module):
     def forward(self, X: Tensor) -> Tensor:
         if self.shared_embed:
             cat_embed = [
-                self.embed["emb_layer_" + col](  # type: ignore[index]
-                    X[:, self.column_idx[col]].long()
-                ).unsqueeze(1)
+                self.embed[
+                    "emb_layer_" + self.embed_layers_names[col]
+                ](  # type: ignore[index]
+                    X[:, self.column_idx[col]].long()
+                ).unsqueeze(1)
                 for col, _ in self.embed_input
             ]
             x = torch.cat(cat_embed, 1)
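A detail worth noting (my inference from the diff above, not something the commit states): for column names without dots the new mapping is the identity, so the generated "emb_layer_*" keys, and therefore the model's parameter names, stay the same for existing configurations; only dotted names get rewritten. A tiny sketch using the same expression the diff introduces, with made-up input:

# Same dict comprehension as self.embed_layers_names in the diff.
embed_input = [("education", 16, 8), ("col.1", 5, 8)]
embed_layers_names = {e[0]: e[0].replace(".", "_") for e in embed_input}

assert embed_layers_names["education"] == "education"  # dot-free names are untouched
assert embed_layers_names["col.1"] == "col_1"          # dots become underscores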
tests/test_model_functioning/test_miscellaneous.py
@@ -405,6 +405,43 @@ def test_get_embeddings_deprecation_warning():
     )
 
 
+###############################################################################
+# test test_handle_columns_with_dots
+###############################################################################
+def test_handle_columns_with_dots():
+    data = df.copy()
+    data = data.rename(columns={"col1": "col.1", "a": "a.1"})
+
+    embed_cols = [("col.1", 5), ("col2", 5)]
+    continuous_cols = ["col3", "col4"]
+
+    tab_preprocessor = TabPreprocessor(
+        cat_embed_cols=embed_cols, continuous_cols=continuous_cols
+    )
+    X_tab = tab_preprocessor.fit_transform(data)
+    target = data.target.values
+
+    tabmlp = TabMlp(
+        mlp_hidden_dims=[32, 16],
+        mlp_dropout=[0.5, 0.5],
+        column_idx={k: v for v, k in enumerate(data.columns)},
+        cat_embed_input=tab_preprocessor.cat_embed_input,
+        continuous_cols=tab_preprocessor.continuous_cols,
+    )
+
+    model = WideDeep(deeptabular=tabmlp)
+    trainer = Trainer(model, objective="binary", verbose=0)
+    trainer.fit(
+        X_tab=X_tab,
+        target=target,
+        batch_size=16,
+    )
+    preds = trainer.predict(X_tab=X_tab, batch_size=16)
+
+    assert preds.shape[0] == 32 and "train_loss" in trainer.history
+
+
 ###############################################################################
 # test Label Distribution Smoothing
 ###############################################################################
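To run only the new test locally (a usage note, not part of the commit), standard pytest node selection should work, e.g. pytest tests/test_model_functioning/test_miscellaneous.py -k test_handle_columns_with_dots.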