diff --git a/labml_nn/scaling/zero3/finetune_neox.py b/labml_nn/scaling/zero3/finetune_neox.py index d83312cdefbfbb71635824d550f12fb25ea318e5..acd7aeb5f35db31522b48421a1ff8ac0fe0d1ec4 100644 --- a/labml_nn/scaling/zero3/finetune_neox.py +++ b/labml_nn/scaling/zero3/finetune_neox.py @@ -80,8 +80,9 @@ def main(rank: int, world_size: int, init_method: str = 'tcp://localhost:23456') torch.cuda.set_device(device) # Create the experiment - experiment.create(name='zero3_neox', writers={'screen', 'labml'}) - experiment.distributed(rank, world_size) + experiment.create(name='zero3_neox', writers={'screen', 'labml'}, + distributed_world_size=world_size, + distributed_rank=rank) # Create configurations conf = Configs()