Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Pytorch Widedeep
提交
a4b99e57
P
Pytorch Widedeep
项目概览
Greenplum
/
Pytorch Widedeep
11 个月 前同步成功
通知
9
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Pytorch Widedeep
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a4b99e57
编写于
10月 13, 2019
作者:
J
jrzaurin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
tmp file to check that things work
上级
b667b36a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
38 addition
and
45 deletion
+38
-45
examples/main_adult.py
examples/main_adult.py
+38
-45
未找到文件。
examples/main_adult.py
浏览文件 @
a4b99e57
...
...
@@ -9,6 +9,14 @@ from pytorch_widedeep.utils.deep_utils import DeepProcessor
from
pytorch_widedeep.models.wide
import
Wide
from
pytorch_widedeep.models.deep_dense
import
DeepDense
from
pytorch_widedeep.models.wide_deep
import
WideDeep
from
pytorch_widedeep.initializers
import
*
from
pytorch_widedeep.optimizers
import
*
from
pytorch_widedeep.lr_schedulers
import
*
from
pytorch_widedeep.callbacks
import
*
from
pytorch_widedeep.metrics
import
*
# use_cuda = torch.cuda.is_available()
import
pdb
...
...
@@ -29,61 +37,46 @@ if __name__ == '__main__':
cat_embed_cols
=
[(
'education'
,
10
),
(
'relationship'
,
8
),
(
'workclass'
,
10
),
(
'occupation'
,
10
),(
'native_country'
,
10
)]
continuous_cols
=
[
"age"
,
"hours_per_week"
]
target
=
'income_label'
target
=
df
[
target
].
values
prepare_wide
=
WideProcessor
(
wide_cols
=
wide_cols
,
crossed_cols
=
crossed_cols
)
X_wide
=
prepare_wide
.
fit_transform
(
df
)
prepare_deep
=
DeepProcessor
(
embed_cols
=
cat_embed_cols
,
continuous_cols
=
continuous_cols
)
X_deep
=
prepare_deep
.
fit_transform
(
df
)
wide
=
Wide
(
X_wide
.
shape
[
1
],
1
)
pred_wide
=
wide
(
torch
.
tensor
(
X_wide
[:
10
]))
deep
=
DeepDense
(
wide
=
Wide
(
wide_dim
=
X_wide
.
shape
[
1
],
output_dim
=
1
)
deep
dense
=
DeepDense
(
hidden_layers
=
[
32
,
16
],
dropout
=
[
0.5
],
deep_column_idx
=
prepare_deep
.
deep_column_idx
,
embed_input
=
prepare_deep
.
embeddings_input
,
continuous_cols
=
continuous_cols
,
batchnorm
=
True
,
output_dim
=
1
)
pred_deep
=
deep
(
torch
.
tensor
(
X_deep
[:
10
]))
model
=
WideDeep
(
wide
=
wide
,
deepdense
=
deepdense
)
initializers
=
{
'wide'
:
Normal
,
'deepdense'
:
Normal
}
optimizers
=
{
'wide'
:
Adam
,
'deepdense'
:
RAdam
(
lr
=
0.001
)}
schedulers
=
{
'wide'
:
StepLR
(
step_size
=
5
),
'deepdense'
:
StepLR
(
step_size
=
5
)}
callbacks
=
[
EarlyStopping
,
ModelCheckpoint
(
filepath
=
'../model_weights/wd_out'
)]
metrics
=
[
BinaryAccuracy
]
model
.
compile
(
method
=
'logistic'
,
initializers
=
initializers
,
optimizers
=
optimizers
,
lr_schedulers
=
schedulers
,
callbacks
=
callbacks
,
metrics
=
metrics
)
model
.
fit
(
X_wide
=
X_wide
,
X_deep
=
X_deep
,
target
=
target
,
n_epochs
=
10
,
batch_size
=
256
,
val_split
=
0.2
)
pdb
.
set_trace
()
# wd_dataset = prepare_data(df,
# target=target,
# wide_cols=wide_cols,
# crossed_cols=crossed_cols,
# cat_embed_cols=cat_embed_cols,
# continuous_cols=continuous_cols)
# model = WideDeep(
# output_dim=1,
# wide_dim=wd_dataset.wide.shape[1],
# cat_embed_input = wd_dataset.cat_embed_input,
# continuous_cols=wd_dataset.continuous_cols,
# deep_column_idx=wd_dataset.deep_column_idx)
# initializers = {'wide': Normal, 'deepdense':Normal}
# optimizers = {'wide': Adam, 'deepdense':RAdam(lr=0.001)}
# schedulers = {'wide': StepLR(step_size=5), 'deepdense':StepLR(step_size=5)}
# callbacks = [EarlyStopping, ModelCheckpoint(filepath='../model_weights/wd_out.pt')]
# metrics = [BinaryAccuracy]
# model.compile(
# method='logistic',
# initializers=initializers,
# optimizers=optimizers,
# lr_schedulers=schedulers,
# callbacks=callbacks,
# metrics=metrics)
# model.fit(
# X_wide=wd_dataset.wide,
# X_deep=wd_dataset.deepdense,
# target=wd_dataset.target,
# n_epochs=5,
# batch_size=256,
# val_split=0.2)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录