提交 142305ef 编写于 作者: U u010280923

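Bug fix: for DeepSpeed strategies, set the ZeRO `allgather_bucket_size` and `reduce_bucket_size` from `args.ds_bucket_mb` instead of deleting those keys from the strategy config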

上级 c0450690
......@@ -255,10 +255,8 @@ if __name__ == "__main__":
print(f"{str(shape[0]).ljust(5)} {n}")
if "deepspeed" in args.strategy:
del trainer.strategy.config["zero_optimization"]["allgather_bucket_size"]
del trainer.strategy.config["zero_optimization"]["reduce_bucket_size"]
# trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
# trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
# must set shuffle=True, persistent_workers=False (because worker is in another thread)
data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册