diff --git a/forward_demo.py b/forward_demo.py index 9de47db04df4bea05d8b19028dd9e05c6a9c4d32..28571d88ac4455d3004118539199b68f46c668ff 100644 --- a/forward_demo.py +++ b/forward_demo.py @@ -263,6 +263,8 @@ if __name__ == "__main__": trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + model.deepspeed_offload() + seq = torch.randint(0, 50277, (1, 100)) model(seq)