diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index df7724448c03794cff15cd049850639bd667c881..665188b6171a899e13932363eb2a32154d9defba 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -77,12 +77,6 @@ class RewardModel(pl.LightningModule):
         path = Path(path)
         assert path.exists()
         self.load_state_dict(torch.load(str(path)))
-
-    def finetune_parameters(self):
-        return [
-            *self.to_pred.parameters(),
-            *self.rwkv.parameters()
-        ]
 
     def configure_optimizers(self):
         # Parameters from the paper:
diff --git a/src/trainer.py b/src/trainer.py
index e514d6504beaf58c3dba686f23ffa4a0b95b6318..41cacbf01d21a86d164ed4249d7745269d8f1e4c 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -117,7 +117,9 @@ class train_callback(pl.Callback):
     def on_train_epoch_start(self, trainer, pl_module):
         args = self.args
         dataset = trainer.train_dataloader.dataset.datasets
-        assert "MyDataset" in str(dataset) or "S2SDataset" in str(dataset)
+        assert "MyDataset" in str(dataset) \
+            or "S2SDataset" in str(dataset) \
+            or "RMDataset" in str(dataset)
         dataset.global_rank = trainer.global_rank
         dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
         dataset.world_size = trainer.world_size
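
For context on the trainer hunk: the substring test works because `str()` of a plain Python object renders as `<module.ClassName object at 0x...>`, so the dataset's class name appears in the rendered string. Below is a minimal sketch of an equivalent check, assuming only the three dataset names visible in the diff; `KNOWN_DATASETS` and `_assert_known_dataset` are illustrative names, not part of the repo, which performs this check inline in `on_train_epoch_start`.

    # Sketch only: mirrors the widened assertion from the trainer diff.
    # KNOWN_DATASETS / _assert_known_dataset are hypothetical names,
    # introduced here for illustration.
    KNOWN_DATASETS = ("MyDataset", "S2SDataset", "RMDataset")

    def _assert_known_dataset(dataset) -> None:
        # str(dataset) on a plain object includes the class name,
        # e.g. "<dataset.RMDataset object at 0x7f...>", so a substring
        # test is equivalent to the assertion in the diff.
        rendered = str(dataset)
        assert any(name in rendered for name in KNOWN_DATASETS), \
            f"unsupported dataset type: {type(dataset).__name__}"

Collecting the names in one tuple and testing with `any()` keeps the check in a single place, so adding a future dataset type (as `RMDataset` is added here) is a one-token change rather than another chained `or` clause.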