diff --git a/train_rm.py b/train_rm.py index 31a4a47674b90020edfb66d93f778d4c76a759f9..b0ef71aa64c5a5baac30265cf5b09bb1d9819e0a 100644 --- a/train_rm.py +++ b/train_rm.py @@ -255,10 +255,8 @@ if __name__ == "__main__": print(f"{str(shape[0]).ljust(5)} {n}") if "deepspeed" in args.strategy: - del trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] - del trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] - # trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 - # trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 # must set shuffle=True, persistent_workers=False (because worker is in another thread) data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)