From c0450690a148b40cb5e757f5191814decf82fe6f Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 10 Mar 2023 18:12:23 +0800 Subject: [PATCH] bug fixed --- train_rm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train_rm.py b/train_rm.py index b0ef71a..31a4a47 100644 --- a/train_rm.py +++ b/train_rm.py @@ -255,8 +255,10 @@ if __name__ == "__main__": print(f"{str(shape[0]).ljust(5)} {n}") if "deepspeed" in args.strategy: - trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 - trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + del trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] + del trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] + # trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + # trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 # must set shuffle=True, persistent_workers=False (because worker is in another thread) data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) -- GitLab