From 142305ef661e6b2c1b3da9943cf68702b331e069 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 10 Mar 2023 18:21:45 +0800 Subject: [PATCH] bug fixed --- train_rm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/train_rm.py b/train_rm.py index 31a4a47..b0ef71a 100644 --- a/train_rm.py +++ b/train_rm.py @@ -255,10 +255,8 @@ if __name__ == "__main__": print(f"{str(shape[0]).ljust(5)} {n}") if "deepspeed" in args.strategy: - del trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] - del trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] - # trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 - # trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 # must set shuffle=True, persistent_workers=False (because worker is in another thread) data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) -- GitLab