From 77e4f4cbff0fcb9d5253b37ae57272367382205a Mon Sep 17 00:00:00 2001 From: u010280923 Date: Mon, 13 Mar 2023 14:16:38 +0800 Subject: [PATCH] debug --- src/rlhf/reward.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 2c74992..5eb8363 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -123,10 +123,21 @@ class RewardModel(pl.LightningModule): {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, ] + + optim_names = [ + {"params": lr_1x}, + {"params": lr_2x}, + {"params": lr_3x}, + ] + else: optim_groups = [ {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, ] + + optim_names = [ + {"params": [n for n, p in self.named_parameters()]}, + ] if self.deepspeed_offload: return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) -- GitLab