Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Pytorch Widedeep
提交
75c37e12
P
Pytorch Widedeep
项目概览
Greenplum
/
Pytorch Widedeep
11 个月 前同步成功
通知
9
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Pytorch Widedeep
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
75c37e12
编写于
11月 23, 2021
作者:
J
jrzaurin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fixing style and small adjustment of the CategoricalEmbeddings class
上级
3532a980
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
23 addition
and
17 deletion
+23
-17
pytorch_widedeep/callbacks.py
pytorch_widedeep/callbacks.py
+1
-1
pytorch_widedeep/datasets/__init__.py
pytorch_widedeep/datasets/__init__.py
+1
-1
pytorch_widedeep/datasets/_base.py
pytorch_widedeep/datasets/_base.py
+1
-0
pytorch_widedeep/models/transformers/_embeddings_layers.py
pytorch_widedeep/models/transformers/_embeddings_layers.py
+17
-13
tests/test_datasets/test_datasets.py
tests/test_datasets/test_datasets.py
+3
-2
未找到文件。
pytorch_widedeep/callbacks.py
浏览文件 @
75c37e12
...
...
@@ -498,7 +498,7 @@ class ModelCheckpoint(Callback):
)
)
if
self
.
wb
is
not
None
:
self
.
wb
.
run
.
summary
[
"best"
]
=
current
self
.
wb
.
run
.
summary
[
"best"
]
=
current
# type: ignore[attr-defined]
self
.
best
=
current
self
.
best_epoch
=
epoch
self
.
best_state_dict
=
self
.
model
.
state_dict
()
...
...
pytorch_widedeep/datasets/__init__.py
浏览文件 @
75c37e12
from
._base
import
load_
bio_kdd04
,
load_adult
from
._base
import
load_
adult
,
load_bio_kdd04
__all__
=
[
"load_bio_kdd04"
,
"load_adult"
]
pytorch_widedeep/datasets/_base.py
浏览文件 @
75c37e12
from
importlib
import
resources
import
pandas
as
pd
...
...
pytorch_widedeep/models/transformers/_embeddings_layers.py
浏览文件 @
75c37e12
...
...
@@ -4,6 +4,7 @@ https://github.com/awslabs/autogluon/tree/master/tabular/src/autogluon/tabular/m
"""
import
math
import
warnings
import
torch
from
torch
import
nn
...
...
@@ -19,9 +20,9 @@ class FullEmbeddingDropout(nn.Module):
def
forward
(
self
,
X
:
Tensor
)
->
Tensor
:
if
self
.
training
:
mask
=
X
.
new
().
resize_
((
X
.
size
(
1
),
1
)).
bernoulli_
(
1
-
self
.
dropout
).
expand_as
(
X
)
/
(
1
-
self
.
dropout
)
mask
=
X
.
new
().
resize_
((
X
.
size
(
1
),
1
)).
bernoulli_
(
1
-
self
.
dropout
)
.
expand_as
(
X
)
/
(
1
-
self
.
dropout
)
return
mask
*
X
else
:
return
X
...
...
@@ -128,13 +129,16 @@ class CategoricalEmbeddings(nn.Module):
self
.
categorical_cols
=
[
ei
[
0
]
for
ei
in
embed_input
]
self
.
cat_idx
=
[
self
.
column_idx
[
col
]
for
col
in
self
.
categorical_cols
]
self
.
bias
=
(
nn
.
Parameter
(
torch
.
Tensor
(
len
(
self
.
categorical_cols
),
embed_dim
))
if
use_bias
else
None
)
if
self
.
bias
is
not
None
:
if
use_bias
is
not
None
:
self
.
bias
=
nn
.
Parameter
(
torch
.
Tensor
(
len
(
self
.
categorical_cols
),
embed_dim
)
)
nn
.
init
.
kaiming_uniform_
(
self
.
bias
,
a
=
math
.
sqrt
(
5
))
if
shared_embed
:
warnings
.
warn
(
"The current implementation of 'SharedEmbeddings' does not use bias"
,
UserWarning
,
)
# Categorical: val + 1 because 0 is reserved for padding/unseen cateogories.
if
self
.
shared_embed
:
...
...
@@ -170,11 +174,11 @@ class CategoricalEmbeddings(nn.Module):
x
=
torch
.
cat
(
cat_embed
,
1
)
else
:
x
=
self
.
embed
(
X
[:,
self
.
cat_idx
].
long
())
if
self
.
bias
is
not
None
:
x
=
x
+
self
.
bias
.
unsqueeze
(
0
)
x
=
self
.
dropout
(
x
)
if
self
.
bias
is
not
None
:
x
=
x
+
self
.
bias
.
unsqueeze
(
0
)
return
self
.
dropout
(
x
)
return
x
class
CatAndContEmbeddings
(
nn
.
Module
):
...
...
tests/test_datasets/test_datasets.py
浏览文件 @
75c37e12
from
pytorch_widedeep.datasets
import
load_bio_kdd04
,
load_adult
import
pandas
as
pd
import
numpy
as
np
import
pandas
as
pd
import
pytest
from
pytorch_widedeep.datasets
import
load_adult
,
load_bio_kdd04
@
pytest
.
mark
.
parametrize
(
"as_frame"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录