diff --git a/train.py b/train.py index ccaeab99610f1bea46fed4479461f424aed90798..690e145495a109dc301e4d1cd7d62ce12988c0bc 100644 --- a/train.py +++ b/train.py @@ -31,7 +31,7 @@ if __name__ == "__main__": pretrained_dict = torch.load("model_data/ssd_weights.pth") pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model_dict.update(pretrained_dict) - model.load_state_dict(pretrained_dict) + model.load_state_dict(model_dict) print('Finished!') net = model