提交 a4b99e57 编写于 作者: J jrzaurin

Temporary file to check that things work

上级 b667b36a
@@ -9,6 +9,14 @@ from pytorch_widedeep.utils.deep_utils import DeepProcessor
from pytorch_widedeep.models.wide import Wide
from pytorch_widedeep.models.deep_dense import DeepDense
from pytorch_widedeep.models.wide_deep import WideDeep
from pytorch_widedeep.initializers import *
from pytorch_widedeep.optimizers import *
from pytorch_widedeep.lr_schedulers import *
from pytorch_widedeep.callbacks import *
from pytorch_widedeep.metrics import *
# use_cuda = torch.cuda.is_available()
import pdb
@@ -29,61 +37,46 @@ if __name__ == '__main__':
cat_embed_cols = [('education',10), ('relationship',8), ('workclass',10),
('occupation',10),('native_country',10)]
continuous_cols = ["age","hours_per_week"]
target = 'income_label'
target = df[target].values
prepare_wide = WideProcessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = prepare_wide.fit_transform(df)
prepare_deep = DeepProcessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)
X_deep = prepare_deep.fit_transform(df)
wide = Wide(X_wide.shape[1], 1)
pred_wide = wide(torch.tensor(X_wide[:10]))
deep = DeepDense(
wide = Wide(
wide_dim=X_wide.shape[1],
output_dim=1)
deepdense = DeepDense(
hidden_layers=[32,16],
dropout=[0.5],
deep_column_idx=prepare_deep.deep_column_idx,
embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols,
batchnorm=True,
output_dim=1)
pred_deep = deep(torch.tensor(X_deep[:10]))
model = WideDeep(wide=wide, deepdense=deepdense)
initializers = {'wide': Normal, 'deepdense':Normal}
optimizers = {'wide': Adam, 'deepdense':RAdam(lr=0.001)}
schedulers = {'wide': StepLR(step_size=5), 'deepdense':StepLR(step_size=5)}
callbacks = [EarlyStopping, ModelCheckpoint(filepath='../model_weights/wd_out')]
metrics = [BinaryAccuracy]
model.compile(
method='logistic',
initializers=initializers,
optimizers=optimizers,
lr_schedulers=schedulers,
callbacks=callbacks,
metrics=metrics)
model.fit(
X_wide=X_wide,
X_deep=X_deep,
target=target,
n_epochs=10,
batch_size=256,
val_split=0.2)
pdb.set_trace()
# wd_dataset = prepare_data(df,
# target=target,
# wide_cols=wide_cols,
# crossed_cols=crossed_cols,
# cat_embed_cols=cat_embed_cols,
# continuous_cols=continuous_cols)
# model = WideDeep(
# output_dim=1,
# wide_dim=wd_dataset.wide.shape[1],
# cat_embed_input = wd_dataset.cat_embed_input,
# continuous_cols=wd_dataset.continuous_cols,
# deep_column_idx=wd_dataset.deep_column_idx)
# initializers = {'wide': Normal, 'deepdense':Normal}
# optimizers = {'wide': Adam, 'deepdense':RAdam(lr=0.001)}
# schedulers = {'wide': StepLR(step_size=5), 'deepdense':StepLR(step_size=5)}
# callbacks = [EarlyStopping, ModelCheckpoint(filepath='../model_weights/wd_out.pt')]
# metrics = [BinaryAccuracy]
# model.compile(
# method='logistic',
# initializers=initializers,
# optimizers=optimizers,
# lr_schedulers=schedulers,
# callbacks=callbacks,
# metrics=metrics)
# model.fit(
# X_wide=wd_dataset.wide,
# X_deep=wd_dataset.deepdense,
# target=wd_dataset.target,
# n_epochs=5,
# batch_size=256,
# val_split=0.2)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册