提交 ddbbe006 编写于 作者: U u010280923

Optimize reward model

上级 e1170613
......@@ -77,12 +77,6 @@ class RewardModel(pl.LightningModule):
path = Path(path)
assert path.exists()
self.load_state_dict(torch.load(str(path)))
def finetune_parameters(self):
    """Return the parameters to optimize during fine-tuning.

    Collects the prediction-head (``to_pred``) parameters followed by the
    RWKV backbone parameters into a single flat list, in that order.
    """
    params = list(self.to_pred.parameters())
    params.extend(self.rwkv.parameters())
    return params
def configure_optimizers(self):
# 论文中的参数:
......
......@@ -117,7 +117,9 @@ class train_callback(pl.Callback):
def on_train_epoch_start(self, trainer, pl_module):
args = self.args
dataset = trainer.train_dataloader.dataset.datasets
assert "MyDataset" in str(dataset) or "S2SDataset" in str(dataset)
assert "MyDataset" in str(dataset) \
or "S2SDataset" in str(dataset) \
or "RMDataset" in str(dataset)
dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册