diff --git a/examples/main_adult.py b/examples/main_adult.py index 8dc438b92bdba078f50b47efd1c86ca7f654d995..49fd950c7ff313031bba9017d99405a521d8e54f 100644 --- a/examples/main_adult.py +++ b/examples/main_adult.py @@ -9,6 +9,14 @@ from pytorch_widedeep.utils.deep_utils import DeepProcessor from pytorch_widedeep.models.wide import Wide from pytorch_widedeep.models.deep_dense import DeepDense +from pytorch_widedeep.models.wide_deep import WideDeep + +from pytorch_widedeep.initializers import * +from pytorch_widedeep.optimizers import * +from pytorch_widedeep.lr_schedulers import * +from pytorch_widedeep.callbacks import * +from pytorch_widedeep.metrics import * + # use_cuda = torch.cuda.is_available() import pdb @@ -29,61 +37,46 @@ if __name__ == '__main__': cat_embed_cols = [('education',10), ('relationship',8), ('workclass',10), ('occupation',10),('native_country',10)] continuous_cols = ["age","hours_per_week"] + target = 'income_label' + target = df[target].values prepare_wide = WideProcessor(wide_cols=wide_cols, crossed_cols=crossed_cols) X_wide = prepare_wide.fit_transform(df) - prepare_deep = DeepProcessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols) X_deep = prepare_deep.fit_transform(df) - wide = Wide(X_wide.shape[1], 1) - pred_wide = wide(torch.tensor(X_wide[:10])) - - deep = DeepDense( + wide = Wide( + wide_dim=X_wide.shape[1], + output_dim=1) + deepdense = DeepDense( hidden_layers=[32,16], dropout=[0.5], deep_column_idx=prepare_deep.deep_column_idx, embed_input=prepare_deep.embeddings_input, continuous_cols=continuous_cols, - batchnorm=True, output_dim=1) - pred_deep = deep(torch.tensor(X_deep[:10])) + model = WideDeep(wide=wide, deepdense=deepdense) + + initializers = {'wide': Normal, 'deepdense':Normal} + optimizers = {'wide': Adam, 'deepdense':RAdam(lr=0.001)} + schedulers = {'wide': StepLR(step_size=5), 'deepdense':StepLR(step_size=5)} + + callbacks = [EarlyStopping, ModelCheckpoint(filepath='../model_weights/wd_out')] + metrics = [BinaryAccuracy] + + model.compile( + method='logistic', + initializers=initializers, + optimizers=optimizers, + lr_schedulers=schedulers, + callbacks=callbacks, + metrics=metrics) + + model.fit( + X_wide=X_wide, + X_deep=X_deep, + target=target, + n_epochs=10, + batch_size=256, + val_split=0.2) pdb.set_trace() - - - # wd_dataset = prepare_data(df, - # target=target, - # wide_cols=wide_cols, - # crossed_cols=crossed_cols, - # cat_embed_cols=cat_embed_cols, - # continuous_cols=continuous_cols) - - # model = WideDeep( - # output_dim=1, - # wide_dim=wd_dataset.wide.shape[1], - # cat_embed_input = wd_dataset.cat_embed_input, - # continuous_cols=wd_dataset.continuous_cols, - # deep_column_idx=wd_dataset.deep_column_idx) - - # initializers = {'wide': Normal, 'deepdense':Normal} - # optimizers = {'wide': Adam, 'deepdense':RAdam(lr=0.001)} - # schedulers = {'wide': StepLR(step_size=5), 'deepdense':StepLR(step_size=5)} - - # callbacks = [EarlyStopping, ModelCheckpoint(filepath='../model_weights/wd_out.pt')] - # metrics = [BinaryAccuracy] - - # model.compile( - # method='logistic', - # initializers=initializers, - # optimizers=optimizers, - # lr_schedulers=schedulers, - # callbacks=callbacks, - # metrics=metrics) - - # model.fit( - # X_wide=wd_dataset.wide, - # X_deep=wd_dataset.deepdense, - # target=wd_dataset.target, - # n_epochs=5, - # batch_size=256, - # val_split=0.2) \ No newline at end of file