From ddbbe006c4845739658989be8ae070eae5ed44f3 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 10 Mar 2023 15:12:25 +0800 Subject: [PATCH] opt reward model --- src/rlhf/reward.py | 6 ------ src/trainer.py | 4 +++- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index df77244..665188b 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -77,12 +77,6 @@ class RewardModel(pl.LightningModule): path = Path(path) assert path.exists() self.load_state_dict(torch.load(str(path))) - - def finetune_parameters(self): - return [ - *self.to_pred.parameters(), - *self.rwkv.parameters() - ] def configure_optimizers(self): # 论文中的参数: diff --git a/src/trainer.py b/src/trainer.py index e514d65..41cacbf 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -117,7 +117,9 @@ class train_callback(pl.Callback): def on_train_epoch_start(self, trainer, pl_module): args = self.args dataset = trainer.train_dataloader.dataset.datasets - assert "MyDataset" in str(dataset) or "S2SDataset" in str(dataset) + assert "MyDataset" in str(dataset) \ + or "S2SDataset" in str(dataset) \ + or "RMDataset" in str(dataset) dataset.global_rank = trainer.global_rank dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) dataset.world_size = trainer.world_size -- GitLab