diff --git a/train_rm.py b/train_rm.py index b0ef71aa64c5a5baac30265cf5b09bb1d9819e0a..31a4a47674b90020edfb66d93f778d4c76a759f9 100644 --- a/train_rm.py +++ b/train_rm.py @@ -255,8 +255,10 @@ if __name__ == "__main__": print(f"{str(shape[0]).ljust(5)} {n}") if "deepspeed" in args.strategy: - trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 - trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + del trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] + del trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] + # trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + # trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 # must set shuffle=True, persistent_workers=False (because worker is in another thread) data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)