diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 32835fb18ece8742180368889ad7bde3d5f21734..2a68d9e6525790fe574b82d4db2ae682635c3a8d 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -136,7 +136,7 @@ class RewardModel(pl.LightningModule): strategy = self.trainer.strategy if isinstance(strategy, DeepSpeedStrategy): cfg = strategy.config["zero_optimization"] - return cfg.get("offload_optimizer") or cfg.get("offload_param") + return bool(cfg.get("offload_optimizer") or cfg.get("offload_param")) return False def single_forward(