Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Pytorch Widedeep
提交
fadede26
P
Pytorch Widedeep
项目概览
Greenplum
/
Pytorch Widedeep
10 个月 前同步成功
通知
9
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Pytorch Widedeep
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
fadede26
编写于
10月 07, 2021
作者:
J
jrzaurin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fixed issue #53 related to the use of some transformer models without categorical columns
上级
6540cd3c
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
37 addition
and
19 deletion
+37
-19
README.md
README.md
+1
-0
VERSION
VERSION
+1
-1
pypi_README.md
pypi_README.md
+1
-0
pytorch_widedeep/models/transformers/ft_transformer.py
pytorch_widedeep/models/transformers/ft_transformer.py
+1
-6
pytorch_widedeep/models/transformers/saint.py
pytorch_widedeep/models/transformers/saint.py
+1
-6
pytorch_widedeep/models/transformers/tab_fastformer.py
pytorch_widedeep/models/transformers/tab_fastformer.py
+0
-5
pytorch_widedeep/version.py
pytorch_widedeep/version.py
+1
-1
tests/test_model_components/test_mc_transformers.py
tests/test_model_components/test_mc_transformers.py
+31
-0
未找到文件。
README.md
浏览文件 @
fadede26
...
...
@@ -11,6 +11,7 @@
[
![Code style: black
](
https://img.shields.io/badge/code%20style-black-000000.svg
)
](https://github.com/psf/black)
[
![Maintenance
](
https://img.shields.io/badge/Maintained%3F-yes-green.svg
)
](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[
![contributions welcome
](
https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat
)
](https://github.com/jrzaurin/pytorch-widedeep/issues)
[
![Slack
](
https://img.shields.io/badge/slack-chat-green.svg?logo=slack
)
](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
# pytorch-widedeep
...
...
VERSION
浏览文件 @
fadede26
1.0.9
\ No newline at end of file
1.0.10
\ No newline at end of file
pypi_README.md
浏览文件 @
fadede26
...
...
@@ -6,6 +6,7 @@
[
![Code style: black
](
https://img.shields.io/badge/code%20style-black-000000.svg
)
](https://github.com/psf/black)
[
![Maintenance
](
https://img.shields.io/badge/Maintained%3F-yes-green.svg
)
](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[
![contributions welcome
](
https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat
)
](https://github.com/jrzaurin/pytorch-widedeep/issues)
[
![Slack
](
https://img.shields.io/badge/slack-chat-green.svg?logo=slack
)
](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
# pytorch-widedeep
...
...
pytorch_widedeep/models/transformers/ft_transformer.py
浏览文件 @
fadede26
...
...
@@ -134,7 +134,7 @@ class FTTransformer(nn.Module):
def
__init__
(
self
,
column_idx
:
Dict
[
str
,
int
],
embed_input
:
List
[
Tuple
[
str
,
int
]]
,
embed_input
:
Optional
[
List
[
Tuple
[
str
,
int
]]]
=
None
,
embed_dropout
:
float
=
0.1
,
full_embed_dropout
:
bool
=
False
,
shared_embed
:
bool
=
False
,
...
...
@@ -194,11 +194,6 @@ class FTTransformer(nn.Module):
self
.
n_cont
=
len
(
continuous_cols
)
if
continuous_cols
is
not
None
else
0
self
.
n_feats
=
self
.
n_cat
+
self
.
n_cont
if
self
.
n_cont
and
not
self
.
n_cat
and
not
self
.
embed_continuous
:
raise
ValueError
(
"If only continuous features are used 'embed_continuous' must be set to 'True'"
)
self
.
cat_and_cont_embed
=
CatAndContEmbeddings
(
input_dim
,
column_idx
,
...
...
pytorch_widedeep/models/transformers/saint.py
浏览文件 @
fadede26
...
...
@@ -120,7 +120,7 @@ class SAINT(nn.Module):
def
__init__
(
self
,
column_idx
:
Dict
[
str
,
int
],
embed_input
:
List
[
Tuple
[
str
,
int
]]
,
embed_input
:
Optional
[
List
[
Tuple
[
str
,
int
]]]
=
None
,
embed_dropout
:
float
=
0.1
,
full_embed_dropout
:
bool
=
False
,
shared_embed
:
bool
=
False
,
...
...
@@ -173,11 +173,6 @@ class SAINT(nn.Module):
self
.
n_cont
=
len
(
continuous_cols
)
if
continuous_cols
is
not
None
else
0
self
.
n_feats
=
self
.
n_cat
+
self
.
n_cont
if
self
.
n_cont
and
not
self
.
n_cat
and
not
self
.
embed_continuous
:
raise
ValueError
(
"If only continuous features are used 'embed_continuous' must be set to 'True'"
)
self
.
cat_and_cont_embed
=
CatAndContEmbeddings
(
input_dim
,
column_idx
,
...
...
pytorch_widedeep/models/transformers/tab_fastformer.py
浏览文件 @
fadede26
...
...
@@ -182,11 +182,6 @@ class TabFastFormer(nn.Module):
self
.
n_cont
=
len
(
continuous_cols
)
if
continuous_cols
is
not
None
else
0
self
.
n_feats
=
self
.
n_cat
+
self
.
n_cont
if
self
.
n_cont
and
not
self
.
n_cat
and
not
self
.
embed_continuous
:
raise
ValueError
(
"If only continuous features are used 'embed_continuous' must be set to 'True'"
)
self
.
cat_and_cont_embed
=
CatAndContEmbeddings
(
input_dim
,
column_idx
,
...
...
pytorch_widedeep/version.py
浏览文件 @
fadede26
__version__
=
"1.0.
9
"
__version__
=
"1.0.
10
"
tests/test_model_components/test_mc_transformers.py
浏览文件 @
fadede26
...
...
@@ -449,3 +449,34 @@ def test_ft_transformer_mlp(mlp_first_h, shoud_work):
else
:
with
pytest
.
raises
(
AssertionError
):
model
=
_build_model
(
"fttransformer"
,
params
)
# noqa: F841
###############################################################################
# Test transformers with only continuous cols
###############################################################################
X_tab_only_cont
=
torch
.
from_numpy
(
np
.
vstack
([
np
.
random
.
rand
(
10
)
for
_
in
range
(
4
)]).
transpose
()
)
colnames_only_cont
=
list
(
string
.
ascii_lowercase
)[:
4
]
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"fttransformer"
,
"saint"
,
"tabfastformer"
,
],
)
def
test_transformers_only_cont
(
model_name
):
params
=
{
"column_idx"
:
{
k
:
v
for
v
,
k
in
enumerate
(
colnames_only_cont
)},
"continuous_cols"
:
colnames_only_cont
,
}
model
=
_build_model
(
model_name
,
params
)
out
=
model
(
X_tab_only_cont
)
assert
out
.
size
(
0
)
==
10
and
out
.
size
(
1
)
==
model
.
output_dim
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录