From 6e8dbe4b204369786530d8a6331b2e825cf4b88c Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 10 Mar 2023 16:49:57 +0800 Subject: [PATCH] opt reward model --- src/rlhf/reward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 32835fb..2a68d9e 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -136,7 +136,7 @@ class RewardModel(pl.LightningModule): strategy = self.trainer.strategy if isinstance(strategy, DeepSpeedStrategy): cfg = strategy.config["zero_optimization"] - return cfg.get("offload_optimizer") or cfg.get("offload_param") + return bool(cfg.get("offload_optimizer") or cfg.get("offload_param")) return False def single_forward( -- GitLab