diff --git a/src/model.py b/src/model.py index cae259471164a6600e63adf4ae187a8f9d1dd13e..ee5169d9f4530b0f6ffd9ca13244616ed8a85384 100644 --- a/src/model.py +++ b/src/model.py @@ -385,7 +385,7 @@ class RWKV(pl.LightningModule): def load(self, path): path = Path(path) assert path.exists() - self.load_state_dict(torch.load(str(path)), map_location="cpu") + self.load_state_dict(torch.load(str(path), map_location="cpu")) def configure_optimizers(self): args = self.args