diff --git a/examples/adult_script.py b/examples/adult_script.py index cdcb77501956c16cb94069d299c9df1b5342a782..46f72d577bfd0d7377e32ae2a451a40d6bb772f2 100644 --- a/examples/adult_script.py +++ b/examples/adult_script.py @@ -95,3 +95,13 @@ if __name__ == "__main__": batch_size=64, val_split=0.2, ) + # # to save/load the model + # torch.save(model, "model_weights/model.t") + # model = torch.load("model_weights/model.t") + + # # or via state dictionaries + # torch.save(model.state_dict(), "model_weights/model_dict.t") + # model = WideDeep(wide=wide, deepdense=deepdense) + # model.load_state_dict(torch.load("model_weights/model_dict.t")) + # # + import pdb; pdb.set_trace() # breakpoint dde47114 //