---
title: "DeepSpeed Configuration JSON"
---

### Batch Size Related Parameters

**Note:** <i>**train_batch_size**</i> must be equal to <i>**train_micro_batch_size_per_gpu**</i> * <i>**gradient_accumulation_steps**</i> * number of GPUs. For simplicity, you can choose to specify only two of the three parameters; DeepSpeed will infer the remaining one automatically.
{: .notice--warning}

<i>**train_batch_size**</i>: [integer]

| Value                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        | Example |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| The effective training batch size: the number of data samples that contribute to one step of model update. <i>**train_batch_size**</i> is the product of the batch size that a single GPU processes in one forward/backward pass (a.k.a. <i>**train_micro_batch_size_per_gpu**</i>), the number of gradient accumulation steps (a.k.a. <i>**gradient_accumulation_steps**</i>), and the number of GPUs. Can be omitted if both <i>**train_micro_batch_size_per_gpu**</i> and <i>**gradient_accumulation_steps**</i> are provided. | `32`    |


<i>**train_micro_batch_size_per_gpu**</i>: [integer]

| Description                                                                                                                                                                                    | Default                           |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------- |
| Batch size to be processed by one GPU in one step (without gradient accumulation). Can be omitted if both <i>**train_batch_size**</i> and <i>**gradient_accumulation_steps**</i> are provided. | <i>**train_batch_size**</i> value |

<i>**gradient_accumulation_steps**</i>: [integer]

| Description                                                                                                                                                                                                                                                                                                                                                                                                                     | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Number of training steps over which to accumulate gradients before averaging and applying them. Accumulating gradients can improve scalability because gradients are communicated between GPUs less frequently. It also makes it possible to train with a larger effective batch size per GPU. Can be omitted if both <i>**train_batch_size**</i> and <i>**train_micro_batch_size_per_gpu**</i> are provided. | `1`     |
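
A minimal sketch of how the three batch parameters relate, assuming only the identity stated in the note above. The helper `infer_batch_params` is hypothetical and not part of DeepSpeed's API; it just shows how the omitted value follows from the other two.

```python
def infer_batch_params(world_size, train_batch_size=None,
                       micro_batch_per_gpu=None, grad_accum_steps=None):
    """Hypothetical helper: fill in the one batch parameter that was omitted."""
    given = [train_batch_size, micro_batch_per_gpu, grad_accum_steps]
    if sum(v is not None for v in given) < 2:
        raise ValueError("specify at least two of the three batch parameters")

    if train_batch_size is None:
        train_batch_size = micro_batch_per_gpu * grad_accum_steps * world_size
    elif micro_batch_per_gpu is None:
        micro_batch_per_gpu = train_batch_size // (grad_accum_steps * world_size)
    elif grad_accum_steps is None:
        grad_accum_steps = train_batch_size // (micro_batch_per_gpu * world_size)

    # train_batch_size == micro batch per GPU * accumulation steps * number of GPUs
    if train_batch_size != micro_batch_per_gpu * grad_accum_steps * world_size:
        raise ValueError("inconsistent batch parameters")
    return train_batch_size, micro_batch_per_gpu, grad_accum_steps


# 8 GPUs, 4 samples per GPU per pass, 1 accumulation step
print(infer_batch_params(8, micro_batch_per_gpu=4, grad_accum_steps=1))
```

Because the inference uses integer division, an indivisible combination (for example `train_batch_size=30` on 8 GPUs) fails the final consistency check instead of being silently rounded.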



### Optimizer Parameters

<i>**optimizer**</i>: [dictionary]

| Fields | Value                                                                                                                                                                                                                                                                                                        | Example                      |
| ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------- |
| type   | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, and **OneBitLamb** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"`                     |
| params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)).                                                                                                       | `{"lr": 0.001, "eps": 1e-8}` |

Example of <i>**optimizer**</i> with Adam

```json
"optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.8,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  }
```
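
The JSON snippets on this page are keys of a single top-level configuration object. As a sketch, assuming you prefer to build the configuration in Python, the Adam example above can be embedded in a complete config like this (the batch values are illustrative and borrowed from the batch-size section):

```python
import json

ds_config = {
    "train_batch_size": 32,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.001,
            "betas": [0.8, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7,
        },
    },
}

# Serialize to JSON text, e.g. to write out as the DeepSpeed config file.
config_text = json.dumps(ds_config, indent=2)
print(config_text)
```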
The Adam optimizer also supports the following two params keys/values in addition to the standard parameters from [torch.optim.Adam](https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam):

| "params" key  | Description                                                                 | Default |
| ------------- | --------------------------------------------------------------------------- | ------- |
| torch\_adam   | Use torch's implementation of Adam instead of our fused Adam implementation     | false   |
| adam\_w\_mode | Apply decoupled weight decay (also known as AdamW) instead of L2 regularization | true    |
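
To illustrate what `adam_w_mode` changes, here is a scalar sketch of one Adam step. The function `adam_step` is hypothetical and simplified, not DeepSpeed's fused kernel; it only shows where weight decay enters in each mode.

```python
import math


def adam_step(param, grad, m, v, step, lr=1e-3, betas=(0.9, 0.999),
              eps=1e-8, weight_decay=0.0, adam_w_mode=True):
    """One scalar Adam step; returns (new_param, new_m, new_v)."""
    beta1, beta2 = betas
    if not adam_w_mode and weight_decay:
        # Classic L2 regularization: the decay term is folded into the gradient.
        grad = grad + weight_decay * param
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** step)          # bias-corrected first moment
    v_hat = v / (1 - beta2 ** step)          # bias-corrected second moment
    update = m_hat / (math.sqrt(v_hat) + eps)
    if adam_w_mode and weight_decay:
        # Decoupled weight decay (AdamW): applied to the weight directly.
        update = update + weight_decay * param
    return param - lr * update, m, v


# Zero gradient, so any change to the weight comes from weight decay alone.
p_l2, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, step=1, lr=0.1, weight_decay=0.1, adam_w_mode=False)
p_w, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, step=1, lr=0.1, weight_decay=0.1, adam_w_mode=True)
print(p_l2, p_w)  # roughly 0.9 vs 0.99
```

With L2 regularization, the decay term passes through Adam's adaptive normalization, so the weight moves by roughly `lr` regardless of `weight_decay`; with decoupled decay it shrinks by exactly `lr * weight_decay * param`.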

Another example of <i>**optimizer**</i>, with 1-bit Adam-specific parameters, is as follows.

```json
"optimizer": {
    "type": "OneBitAdam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.8,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7,
      "freeze_step": 400,
      "cuda_aware": false,
      "comm_backend_name": "nccl"
    }
  }
```

The 1-bit Adam optimizer supports the following three params keys/values in addition to the standard Adam parameters (learn more in our [tutorial](/tutorials/onebit-adam/)):

| "params" key        | Description                                                                        | Default |
| ------------------- | ---------------------------------------------------------------------------------- | ------- |
| freeze\_step        | Number of warm-up steps before 1-bit compression is applied to the communication | 100000  |
| cuda\_aware         | Whether the underlying MPI library supports CUDA-aware communication              | false   |
| comm\_backend\_name | Which communication backend implementation to use                                 | "nccl"  |
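
A simplified sketch of the idea behind `freeze_step`: during warm-up, values are exchanged at full precision; afterwards, each tensor is sent as signs plus a single shared scale, and the compression error is fed back into the next step. The function names are hypothetical, the mean-absolute-value scale is one simple choice (real implementations may compute the scale differently), and the sketch runs on plain lists rather than inside a GPU collective.

```python
def onebit_compress(values, error):
    """Compress to sign * scale with error feedback; returns (compressed, new_error)."""
    # Add back the residual left over from the previous compression.
    compensated = [v + e for v, e in zip(values, error)]
    # One shared magnitude per tensor; each element then needs only a sign bit.
    scale = sum(abs(c) for c in compensated) / len(compensated)
    compressed = [scale if c >= 0 else -scale for c in compensated]
    # Remember what was lost so it can be compensated at the next step.
    new_error = [c - q for c, q in zip(compensated, compressed)]
    return compressed, new_error


def exchange(step, freeze_step, momentum, error):
    """Full-precision exchange during warm-up, 1-bit compression afterwards."""
    if step < freeze_step:
        return list(momentum), error
    return onebit_compress(momentum, error)


compressed, err = exchange(step=500, freeze_step=400,
                           momentum=[0.5, -1.5, 1.0], error=[0.0, 0.0, 0.0])
print(compressed, err)  # [1.0, -1.0, 1.0] [-0.5, -0.5, 0.0]
```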

Another example of <i>**optimizer**</i> with 1-bit LAMB:

```json
"optimizer": {
    "type": "OneBitLamb",
    "params": {
      "lr": 11e-3,
      "weight_decay": 0.01,
      "bias_correction": false,
      "max_coeff": 0.3,
      "min_coeff": 0.01,
      "freeze_step": 1000,
      "cuda_aware": false,
      "comm_backend_name": "nccl",
      "coeff_beta": 0.9,
      "factor_max": 4.0,
      "factor_min": 0.5,
      "factor_threshold": 0.1
    }
  }
```

The 1-bit LAMB optimizer supports the following params keys/values in addition to the standard LAMB parameters (learn more in our [tutorial](/tutorials/onebit-lamb/)):

| "params" key        | Description                                                                               | Default |
| ------------------- | ----------------------------------------------------------------------------------------- | ------- |
| max\_coeff          | Upper bound of the scaling coefficient for the original LAMB algorithm and 1-bit LAMB's warm-up stage | 10.0    |
| min\_coeff          | Lower bound of the scaling coefficient for the original LAMB algorithm and 1-bit LAMB's warm-up stage | 0.01    |
| freeze\_step        | Number of warm-up steps before 1-bit compression is applied to the communication                      | 100000  |
| cuda\_aware         | Whether the underlying MPI library supports CUDA-aware communication                                   | false   |
| comm\_backend\_name | Which communication backend implementation to use                                                      | "nccl"  |
| coeff\_beta         | Coefficient used for computing running averages of the LAMB coefficient                                | 0.9     |
| factor\_max         | Maximum value of the scaling factor applied to the frozen LAMB coefficient during the compression stage | 4.0     |
| factor\_min         | Minimum value of the scaling factor applied to the frozen LAMB coefficient during the compression stage | 0.5     |
| factor\_threshold   | Threshold on how much the scaling factor can fluctuate between steps                                   | 0.1     |
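
For intuition about `max_coeff`, `min_coeff`, and `coeff_beta`: in LAMB, each layer's update is scaled by a trust ratio, typically the norm of the layer's weights divided by the norm of its update. The sketch below clamps that ratio to `[min_coeff, max_coeff]` and keeps a running average weighted by `coeff_beta`. The function names are hypothetical, returning `1.0` when a norm is zero is an assumed convention, and the compression-stage `factor_*` keys are not modeled.

```python
import math


def lamb_coefficient(weights, update, min_coeff=0.01, max_coeff=10.0):
    """Trust ratio ||w|| / ||update||, clamped to [min_coeff, max_coeff]."""
    w_norm = math.sqrt(sum(w * w for w in weights))
    u_norm = math.sqrt(sum(u * u for u in update))
    if w_norm == 0.0 or u_norm == 0.0:
        return 1.0  # assumed fallback when either norm is zero
    return min(max(w_norm / u_norm, min_coeff), max_coeff)


def running_average(previous, current, coeff_beta=0.9):
    """Exponential moving average of the coefficient across steps."""
    return coeff_beta * previous + (1.0 - coeff_beta) * current


# ||w|| = 5 and ||update|| = 0.5 give a raw ratio of 10, clamped to max_coeff = 0.3
print(lamb_coefficient([3.0, 4.0], [0.5, 0.0], min_coeff=0.01, max_coeff=0.3))
```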

### Scheduler Parameters


DeepSpeed calls the `step()` method of the scheduler at every training step when `model_engine.step()` is executed.

<i>**scheduler**</i>: [dictionary]

| Fields | Value                                                                                                                      | Example                                        |
| ------ | -------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------- |
| type   | The scheduler name. See [here](https://deepspeed.readthedocs.io/en/latest/schedulers.html) for the list of supported schedulers. | `"WarmupLR"`                                   |
| params | Dictionary of parameters to instantiate scheduler. The parameter names should match scheduler constructor signature.       | `{"warmup_min_lr": 0, "warmup_max_lr": 0.001}` |

Example of <i>**scheduler**</i>

```json
 "scheduler": {
      "type": "WarmupLR",
      "params": {
          "warmup_min_lr": 0,
          "warmup_max_lr": 0.001,
          "warmup_num_steps": 1000
      }
  }
```
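
As a sketch of what the example above configures: the learning rate rises from `warmup_min_lr` to `warmup_max_lr` over the first `warmup_num_steps` scheduler steps and then stays at `warmup_max_lr`. This assumes linear interpolation; DeepSpeed's `WarmupLR` may use a different warm-up curve, so check the scheduler documentation linked above for the exact behavior. The function `warmup_lr` is hypothetical.

```python
def warmup_lr(step, warmup_min_lr=0.0, warmup_max_lr=0.001, warmup_num_steps=1000):
    """Learning rate after `step` scheduler steps, assuming a linear warm-up."""
    if step >= warmup_num_steps:
        return warmup_max_lr
    fraction = step / warmup_num_steps
    return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * fraction


# `step` is the number of times the scheduler has been stepped so far.
for step in (0, 250, 500, 1000, 5000):
    print(step, warmup_lr(step))
```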

### Communication options

<i>**fp32_allreduce**</i>: [boolean]

| Description                                                    | Default |
| -------------------------------------------------------------- | ------- |
| During gradient averaging, perform allreduce with 32-bit values | `false` |

<i>**prescale_gradients**</i>: [boolean]

| Description                            | Default |
| -------------------------------------- | ------- |
| Scale gradients before doing allreduce | `false` |

<i>**gradient_predivide_factor**</i>: [float]

| Description                                                                                                                                       | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Before gradient averaging, pre-divide gradients by the specified factor. This can sometimes help with FP16 stability when scaling to large numbers of GPUs. | `1.0`   |
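
A small simulation of why pre-dividing can help, assuming gradients are averaged by an fp16 sum-allreduce and then multiplied by `gradient_predivide_factor / world_size` afterwards, so the result is still the mean. The allreduce is modeled as a sequential sum rounded to half precision after each addition (using the `struct` module's `e` format); the helper names are hypothetical.

```python
import math
import struct


def to_fp16(x):
    """Round to the nearest IEEE half-precision value, returning inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def fp16_allreduce_sum(values):
    """Model an allreduce as a running sum that is rounded to fp16 after every add."""
    total = 0.0
    for v in values:
        total = to_fp16(total + to_fp16(v))
    return total


def average_gradients(local_grads, predivide_factor=1.0):
    world_size = len(local_grads)
    # Pre-divide each worker's gradient so the intermediate sum stays small.
    total = fp16_allreduce_sum([g / predivide_factor for g in local_grads])
    # Post-scale so the final value is the mean over all workers.
    return to_fp16(total * predivide_factor / world_size)


grads = [1024.0] * 128                          # same gradient on 128 GPUs
print(average_gradients(grads))                 # the sum passes 65504 (fp16 max) and overflows
print(average_gradients(grads, predivide_factor=128.0))
```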

<i>**sparse_gradients**</i>: [boolean]

| Description                                                                                                              | Default |
| ------------------------------------------------------------------------------------------------------------------------ | ------- |
| Enable sparse compression of [torch.nn.Embedding](https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding) gradients. | `false` |

### FP16 training options

**Note:** this mode cannot be combined with the `amp` mode described below.
{: .notice--warning}

<i>**fp16**</i>: [dictionary]

| Description                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                | Default |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Configuration for mixed precision/FP16 training that leverages [NVIDIA's Apex package](https://nvidia.github.io/apex/). An example, including the available dictionary keys, is shown below. NOTE: this mode does not use Apex's AMP mode, which allows more flexibility in mixed precision training; instead, it is similar to AMP's O2 mode. See AMP support below if you want to use more complex mixed precision modes. ZeRO currently requires this mode. | None    |

```json
"fp16": {
    "enabled": true,
    "loss_scale": 0,
    "initial_scale_power": 32,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
}
```

<i>**fp16:enabled**</i>: [boolean]

| Description                                                                                 | Default |
| ------------------------------------------------------------------------------------------- | ------- |
| <i>**enabled**</i> is a **fp16** parameter indicating whether or not FP16 training is enabled. | `false` |

<i>**fp16:loss_scale**</i>: [float]

| Description                                                                                                                                                                                                                           | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| <i>**loss_scale**</i> is a <i>**fp16**</i> parameter representing the loss scaling value for FP16 training. The default value of 0.0 results in dynamic loss scaling, otherwise the value will be used for static fixed loss scaling. | `0.0`   |

<i>**fp16:initial_scale_power**</i>: [integer]

| Description                                                                                                                                                                                             | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| <i>**initial_scale_power**</i> is a **fp16** parameter representing the power of the initial dynamic loss scale value. The actual loss scale is computed as 2<sup><i>**initial_scale_power**</i></sup>. | `32`    |

<i>**fp16:loss_scale_window**</i>: [integer]

| Description                                                                                                                          | Default |
| ------------------------------------------------------------------------------------------------------------------------------------ | ------- |
| <i>**loss_scale_window**</i> is a **fp16** parameter representing the window over which to raise/lower the dynamic loss scale value. | `1000`  |

<i>**fp16:hysteresis**</i>: [integer]

| Description                                                                                         | Default |
| --------------------------------------------------------------------------------------------------- | ------- |
| <i>**hysteresis**</i> is a **fp16** parameter representing the delay shift in dynamic loss scaling. | `2`     |

<i>**fp16:min_loss_scale**</i>: [integer]

| Description                                                                                           | Default |
| ----------------------------------------------------------------------------------------------------- | ------- |
| <i>**min_loss_scale**</i> is a **fp16** parameter representing the minimum dynamic loss scale value. | `1`     |
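
To see how these keys interact, here is a minimal sketch of dynamic loss scaling (the mode used when `loss_scale` is `0`): the scale starts at `2 ** initial_scale_power`, is halved on overflow once `hysteresis` overflows have been absorbed, never drops below `min_loss_scale`, and is doubled after `loss_scale_window` consecutive overflow-free steps. The class is hypothetical and simplified; it does not skip optimizer steps, and DeepSpeed's own loss scaler may differ in details such as when the hysteresis counter is reset.

```python
class DynamicLossScaler:
    """Hypothetical, simplified dynamic loss scaler driven by the fp16 keys above."""

    def __init__(self, initial_scale_power=32, loss_scale_window=1000,
                 hysteresis=2, min_loss_scale=1, scale_factor=2.0):
        self.scale = 2.0 ** initial_scale_power
        self.window = loss_scale_window
        self.hysteresis = hysteresis
        self.min_scale = min_loss_scale
        self.factor = scale_factor        # assumed: halve/double by a factor of 2
        self._remaining = hysteresis      # overflows tolerated before shrinking
        self._good_steps = 0              # consecutive overflow-free steps

    def update(self, overflow):
        """Adjust the scale after a step; `overflow` means inf/NaN gradients were seen."""
        if overflow:
            self._good_steps = 0
            if self._remaining <= 1:
                # Hysteresis used up: shrink, but never below min_loss_scale.
                self.scale = max(self.scale / self.factor, self.min_scale)
            else:
                self._remaining -= 1
        else:
            self._good_steps += 1
            if self._good_steps % self.window == 0:
                # A full window without overflow: try a larger scale again.
                self.scale *= self.factor
                self._remaining = self.hysteresis


scaler = DynamicLossScaler(initial_scale_power=4, loss_scale_window=3, hysteresis=2)
for overflow in (True, True, False, False, False):
    scaler.update(overflow)
    print(overflow, scaler.scale)
```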

### Automatic mixed precision (AMP) training options

**Note:** this mode cannot be combined with the `fp16` mode described above. In addition, this mode is not currently compatible with ZeRO.
{: .notice--warning}

<i>**amp**</i>: [dictionary]

| Description                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     | Default |
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Configuration for automatic mixed precision (AMP) training that leverages [NVIDIA's Apex AMP package](https://nvidia.github.io/apex/). An example, including the available dictionary keys, is shown below. Not compatible with the `fp16` mode above or with ZeRO. Any parameters other than "enabled" are passed to AMP's initialize call; see the [apex.amp.initialize documentation](https://nvidia.github.io/apex/amp.html#apex.amp.initialize) for the API and parameter descriptions. | None    |

```json
"amp": {
    "enabled": true,
    ...
    "opt_level": "O1",
    ...
}
```

<i>**amp:enabled**</i>: [boolean]

| Description                                                                                   | Default |
| --------------------------------------------------------------------------------------------- | ------- |
| <i>**enabled**</i> is an **amp** parameter indicating whether or not AMP training is enabled. | `false` |

<i>**amp params**</i>: [various]

| Description                                                                                                                                                                                                            | Default |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Any parameters other than "enabled" are passed to AMP's initialize call; see the [apex.amp.initialize documentation](https://nvidia.github.io/apex/amp.html#apex.amp.initialize) for the API and parameter descriptions. | None    |

### Gradient Clipping

<i>**gradient_clipping**</i>: [float]

| Description                         | Default |
| ----------------------------------- | ------- |
| Enable gradient clipping, using this value as the maximum gradient norm | `1.0`   |
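
A minimal sketch of norm-based clipping, assuming `gradient_clipping` acts as the maximum global L2 norm across all gradients (the same idea as PyTorch's `torch.nn.utils.clip_grad_norm_`). The function `clip_by_global_norm` is hypothetical and works on plain Python lists.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients together so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm <= max_norm:
        return [list(tensor) for tensor in grads], total_norm
    # One shared factor preserves the direction of the full gradient vector.
    scale = max_norm / total_norm
    return [[g * scale for g in tensor] for tensor in grads], total_norm


clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm, clipped)  # norm is 5.0; gradients become roughly [[0.6], [0.8]]
```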



### ZeRO Optimizations for FP16 Training

Enabling and configuring ZeRO memory optimizations
```json
  "zero_optimization": {
    "stage": [0|1|2|3],
    "allgather_partitions": [true|false],
    "allgather_bucket_size": 5e8,
    "overlap_comm": false,
    "reduce_scatter": [true|false],
    "reduce_bucket_size": 5e8,
    "contiguous_gradients" : [true|false],
    "offload_param": {
      ...
    },
    "offload_optimizer": {
      ...
    },
    "stage3_max_live_parameters" : 1e9,
    "stage3_max_reuse_distance" : 1e9,
    "stage3_prefetch_bucket_size" : 5e8,
    "stage3_param_persistence_threshold" : 1e6,
    "sub_group_size" : 1e12,
    "elastic_checkpoint" : [true|false],
    "stage3_gather_fp16_weights_on_model_save": [true|false],
    "ignore_unused_parameters": [true|false]
    }
```

<i>**zero_optimization**</i>: [dictionary]

| Description                                                                                               | Default |
| --------------------------------------------------------------------------------------------------------- | ------- |
| Enable ZeRO memory optimization wrapper for FP16 Training. Currently compatible only with Adam optimizer. | `false` |

<i>**stage**</i>: [integer]

| Description                                                                                                                                                                                                               | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Chooses the stage of the ZeRO optimizer. Stages 0, 1, 2, and 3 refer to disabled, optimizer state partitioning, optimizer+gradient state partitioning, and optimizer+gradient+parameter partitioning, respectively. | `0`     |
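
A back-of-the-envelope sketch of what each stage saves, following the accounting used in the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for optimizer state (fp32 weights, momentum, and variance), with the partitioned parts divided across the data-parallel GPUs. It counts model states only (no activations, temporary buffers, or fragmentation), and the function `zero_model_state_bytes` is hypothetical.

```python
def zero_model_state_bytes(num_params, num_gpus, stage, optimizer_bytes_per_param=12):
    """Approximate per-GPU bytes for model states at a given ZeRO stage."""
    params = 2 * num_params                          # fp16 weights
    grads = 2 * num_params                           # fp16 gradients
    optim = optimizer_bytes_per_param * num_params   # fp32 weights + Adam moments
    if stage == 0:
        return params + grads + optim
    if stage == 1:   # partition optimizer state
        return params + grads + optim / num_gpus
    if stage == 2:   # partition optimizer state and gradients
        return params + (grads + optim) / num_gpus
    if stage == 3:   # partition optimizer state, gradients, and parameters
        return (params + grads + optim) / num_gpus
    raise ValueError("stage must be 0, 1, 2, or 3")


# A 7.5B-parameter model trained on 64 GPUs
for stage in range(4):
    print(stage, round(zero_model_state_bytes(7.5e9, 64, stage) / 1e9, 1), "GB")
```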

<i>**allgather_partitions**</i>: [boolean]

| Description                                                                                                                                      | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------ | ------- |
| Chooses between allgather collective or a series of broadcast collectives to gather updated parameters from all the GPUs at the end of each step | `true`  |

***allgather_bucket_size***: [integer]

| Description                                                                                                  | Default |
| ------------------------------------------------------------------------------------------------------------ | ------- |
| Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes | `5e8`   |

<i>**overlap_comm**</i>: [boolean]

| Description                                                                  | Default |
| ---------------------------------------------------------------------------- | ------- |
| Attempts to overlap the reduction of the gradients with backward computation | `false` |

<i>**reduce_scatter**</i>: [boolean]

| Description                                                             | Default |
| ----------------------------------------------------------------------- | ------- |
| Uses reduce or reduce scatter instead of allreduce to average gradients | `true`  |

***reduce_bucket_size***: [integer]

| Description                                                                                                         | Default |
| ------------------------------------------------------------------------------------------------------------------- | ------- |
| Number of elements reduced/allreduced at a time. Limits the memory required for the reduction for large model sizes | `5e8`   |
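
The bucket sizes bound the temporary buffers used by these collectives. As a sketch, assuming a simple greedy strategy, tensors can be grouped into buckets of at most `reduce_bucket_size` (or `allgather_bucket_size`) elements, with an oversized tensor getting a bucket of its own. For scale: the default `5e8` elements in fp16 (2 bytes each) is a buffer of about 1 GB. The helper `make_buckets` is hypothetical and is not DeepSpeed's bucketing code.

```python
def make_buckets(tensor_numels, bucket_size):
    """Greedily group tensor sizes so each bucket holds at most bucket_size elements."""
    buckets, current, current_size = [], [], 0
    for n in tensor_numels:
        # Close the current bucket if adding this tensor would exceed the limit.
        if current and current_size + n > bucket_size:
            buckets.append(current)
            current, current_size = [], 0
        current.append(n)
        current_size += n
    if current:
        buckets.append(current)
    return buckets


print(make_buckets([3, 4, 2, 5], bucket_size=6))
```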

<i>**contiguous_gradients**</i>: [boolean]

| Description                                                                                                                                                     | Default |
| --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during the backward pass. Only useful when running very large models. | `false` |

<i>**grad_hooks**</i>: [boolean]

| Description                                                                                                                                | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------ | ------- |
| For use with ZeRO stage 1: enable backward hooks to reduce gradients during the backward pass, or wait until the end of the backward pass. | `true`  |

***offload_param***: [dictionary]

| Description                                                                                                                                                                                   | Default |
| --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Enable offloading of model parameters to CPU or NVMe. This frees up GPU memory for larger models or batch sizes. Valid only with stage 3. See [here](#parameter-offloading) for more details. | `False` |

***offload_optimizer***: [dictionary]

| Description                                                                                                                                                                                                                    | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------- |
| Enable offloading of optimizer state to CPU or NVMe, and optimizer computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid only with stage 3. See [here](#optimizer-offloading) for more details. | `False` |

***stage3_max_live_parameters***: [integer]

| Description                                                                                                                         | Default |
| ----------------------------------------------------------------------------------------------------------------------------------- | ------- |
| The maximum number of parameters resident per GPU before releasing. Smaller values use less memory, but perform more communication. | `1e9`   |

***stage3_max_reuse_distance***: [integer]

| Description                                                                                                                                          | Default |
| ---------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Do not release a parameter if it will be reused within this threshold of parameters. Smaller values use less memory, but perform more communication. | `1e9`   |

***stage3_prefetch_bucket_size***: [integer]

| Description                                                                                                                            | Default |
| -------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| The size of the fixed buffer for prefetching parameters. Smaller values use less memory, but can increase stalls due to communication. | `5e8`   |


***stage3_param_persistence_threshold***: [integer]

| Description                                                                                                                                                          | Default |
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). | `1e6`   |


***stage3_gather_fp16_weights_on_model_save***: [boolean]

| Description                                                                                                                                                                                                                                                                   | Default |
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Consolidate the weights before saving the model with `save_fp16_model()`. Because the weights are partitioned across GPUs, they aren't part of `state_dict`, so when this option is enabled this function automatically gathers the weights and then saves the fp16 model weights. | `False` |
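For reference, here is a stage 3 `zero_optimization` block that spells out the defaults listed above and turns on weight consolidation for `save_fp16_model()`. This is a sketch. It assumes these keys sit inside `zero_optimization` next to `stage`. The numeric values are the documented defaults, not tuned recommendations.

```json
{
  "zero_optimization": {
    "stage": 3,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e6,
    "stage3_gather_fp16_weights_on_model_save": true
  }
}
```

According to the tables above, lowering the first four values reduces memory use at the cost of more communication. Lowering `stage3_param_persistence_threshold` in particular can add many small, latency-bound messages.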


***cpu_offload***: [boolean]

**Deprecated:** **cpu_offload** is disabled and will be removed in the future; please use `offload_optimizer` instead.
{: .notice--warning}

| Description                                                                                                                                       | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Enable offloading of optimizer memory and computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid only with stage 2. | `False` |
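Because `cpu_offload` is deprecated, a stage 2 configuration that previously set `"cpu_offload": true` can express the same intent through `offload_optimizer`. This is a minimal sketch. It assumes the remaining `offload_optimizer` fields keep the defaults documented under "Optimizer offloading".

```json
{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu"
    }
  }
}
```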


### Parameter offloading
Enable and configure ZeRO offloading of model parameters to CPU/NVMe. Available only with ZeRO stage 3.
```json
  "offload_param": {
    "device": "[none|cpu|nvme]",
    "nvme_path": "/local_nvme",
    "pin_memory": [true|false],
    "buffer_count": 5,
    "buffer_size": 1e8,
    "max_in_cpu": 1e9
  }
```
***device***: [string]

| Description                                                                        | Default |
| ---------------------------------------------------------------------------------- | ------- |
| Device memory to which model parameters are offloaded. Supported options are `cpu` and `nvme`. | `cpu`   |

***nvme_path***: [string]

| Description                                               | Default       |
| --------------------------------------------------------- | ------------- |
| Filesystem path for NVMe device for parameter offloading. | `/local_nvme` |

***pin_memory***: [boolean]

| Description                                                                                          | Default |
| ---------------------------------------------------------------------------------------------------- | ------- |
| Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead. | `false` |

***buffer_count***: [integer]

| Description                                                        | Default |
| ------------------------------------------------------------------ | ------- |
| Number of buffers in buffer pool for parameter offloading to NVMe. | 5       |


***buffer_size***: [integer]

| Description                                                      | Default |
| ---------------------------------------------------------------- | ------- |
| Size of buffers in buffer pool for parameter offloading to NVMe. | 1e8     |

***max_in_cpu***: [integer]

| Description                                                                                | Default |
| ------------------------------------------------------------------------------------------ | ------- |
| Number of parameter elements to maintain in CPU memory when offloading to NVMe is enabled. | 1e9     |
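Putting these fields together, the following sketch shows a stage 3 configuration that offloads parameters to NVMe. `/local_nvme` is only the documented default path and is assumed to point at a mounted NVMe filesystem on every node. The buffer and `max_in_cpu` values are the defaults above, written out for illustration.

```json
{
  "zero_optimization": {
    "stage": 3,
    "offload_param": {
      "device": "nvme",
      "nvme_path": "/local_nvme",
      "pin_memory": true,
      "buffer_count": 5,
      "buffer_size": 1e8,
      "max_in_cpu": 1e9
    }
  }
}
```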

### Optimizer offloading
Enable and configure ZeRO offloading of optimizer computation to CPU and of optimizer state to CPU/NVMe. CPU offloading is available with ZeRO stage 2 or 3. NVMe offloading is available only with ZeRO stage 3.
```json
  "offload_optimizer": {
    "device": "[none|cpu|nvme]",
    "nvme_path": "/local_nvme",
    "pin_memory": [true|false],
    "buffer_count": 4,
    "fast_init": false
  }
```
***device***: [string]

| Description                                                                                                                                            | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------------ | ------- |
| Device memory to which optimizer state is offloaded. Supported options are `cpu` and `nvme`. Optimizer computation is offloaded to CPU regardless of the device option. | `cpu`   |

***nvme_path***: [string]

| Description                                                     | Default       |
| --------------------------------------------------------------- | ------------- |
| Filesystem path for NVMe device for optimizer state offloading. | `/local_nvme` |

***pin_memory***: [boolean]

| Description                                                                                          | Default |
| ---------------------------------------------------------------------------------------------------- | ------- |
| Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead. | `false` |

***buffer_count***: [integer]

| Description                                                                                                                                                                                                                                              | Default |
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Number of buffers in buffer pool for optimizer state offloading to NVMe. This should be at least the number of states maintained per parameter by the optimizer. For example, Adam optimizer has 4 states (parameter, gradient, momentum, and variance). | 4       |

***fast_init***: [boolean]

| Description                                                   | Default |
| ------------------------------------------------------------- | ------- |
| Enable fast optimizer initialization when offloading to NVMe. | `false` |
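As a sketch, the following configuration offloads Adam optimizer state to NVMe under stage 3. NVMe offloading requires stage 3. `buffer_count` is set to 4 to match the four states per parameter that Adam maintains (parameter, gradient, momentum, and variance), as described in the `buffer_count` entry. The `optimizer` block and its learning rate are illustrative placeholders. The NVMe path is the documented default and is assumed to exist.

```json
{
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 1e-4
    }
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "nvme",
      "nvme_path": "/local_nvme",
      "pin_memory": true,
      "buffer_count": 4,
      "fast_init": false
    }
  }
}
```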


### Asynchronous I/O
Configuring the asynchronous I/O module for offloading parameter and optimizer states to persistent (NVMe) storage. This module uses Linux native asynchronous I/O (libaio).
```json
  "aio": {
    "block_size": 1048576,
    "queue_depth": 8,
    "thread_count": 1,
    "single_submit": false,
    "overlap_events": true
  }
```
***block_size***: [integer]

| Description              | Default |
| ------------------------ | ------- |
| I/O block size in bytes. | 1048576 |

***queue_depth***: [integer]

| Description      | Default |
| ---------------- | ------- |
| I/O queue depth. | 8       |

***thread_count***: [integer]

| Description                                                               | Default |
| ------------------------------------------------------------------------- | ------- |
| Intra-request parallelism for each read/write submitted by a user thread. | 1       |

***single_submit***: [boolean]

| Description                                                                                            | Default |
| ------------------------------------------------------------------------------------------------------ | ------- |
| Submit requests to storage device as multiple individual requests as opposed to one block of requests. | `false` |

***overlap_events***: [boolean]

| Description                                                                                                    | Default |
| -------------------------------------------------------------------------------------------------------------- | ------- |
| Submit requests to storage device in an overlapped fashion without waiting for completion of earlier requests. | `true`  |
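The `aio` settings only come into play when something is offloaded to NVMe. The following sketch pairs the default `aio` block with NVMe parameter offloading. It assumes `aio` is a top-level key alongside `zero_optimization`, as in the snippet above. The default `block_size` of 1048576 bytes is 1 MiB (1024 × 1024). The values shown are the defaults, not tuned settings.

```json
{
  "zero_optimization": {
    "stage": 3,
    "offload_param": {
      "device": "nvme",
      "nvme_path": "/local_nvme"
    }
  },
  "aio": {
    "block_size": 1048576,
    "queue_depth": 8,
    "thread_count": 1,
    "single_submit": false,
    "overlap_events": true
  }
}
```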

***ignore_unused_parameters***: [boolean]

| Description                                                                                                                            | Default |
| -------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected. This is set to `True` by default, which means unused parameters are ignored and training continues. Currently only used in stage 2. | `True` |
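In a static network, an unused parameter often signals a bug, so you may prefer training to stop with an error. The following is a sketch. It assumes the flag lives in the `zero_optimization` block next to `stage`, and uses stage 2 because that is currently the only stage that consults it.

```json
{
  "zero_optimization": {
    "stage": 2,
    "ignore_unused_parameters": false
  }
}
```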

### Logging

<i>**steps_per_print**</i>: [integer]

| Description                    | Default |
| ------------------------------ | ------- |
| Print train loss every N steps | `10`    |

<i>**wall_clock_breakdown**</i>: [boolean]

| Description                                                             | Default |
| ----------------------------------------------------------------------- | ------- |
| Enable timing of the latency of forward/backward/update training phases | `false` |

<i>**dump_state**</i>: [boolean]

| Description                                                          | Default |
| -------------------------------------------------------------------- | ------- |
| Print out state information of DeepSpeed object after initialization | `false` |
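These logging options are top-level keys in the configuration. The following sketch prints the training loss every 100 steps and times the forward/backward/update phases. The interval of 100 is an arbitrary illustrative value.

```json
{
  "steps_per_print": 100,
  "wall_clock_breakdown": true,
  "dump_state": false
}
```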

### Flops Profiler
```json
{
  "flops_profiler": {
    "enabled": false,
    "profile_step": 1,
    "module_depth": -1,
    "top_modules": 1,
    "detailed": true,
    "output_file": null
  }
}
```
<i>**enabled**</i>: [boolean]

| Description                                                              | Default |
| ------------------------------------------------------------------------ | ------- |
| Enables the flops profiler. This also enables `wall_clock_breakdown`. | `false` |

<i>**profile_step**</i>: [integer]

| Description                                                                                                     | Default |
| --------------------------------------------------------------------------------------------------------------- | ------- |
| The global training step at which to profile. Note that warm-up steps are needed for accurate time measurement. | `1`     |

<i>**module_depth**</i>: [integer]

| Description                                                                                                                                                                           | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| The depth of the model at which to print the aggregated module information. When set to `-1`, it prints information from the top module to the innermost modules (the maximum depth). | `-1`    |

<i>**top_modules**</i>: [integer]

| Description                                                                  | Default |
| ---------------------------------------------------------------------------- | ------- |
| Limits the aggregated profile output to the number of top modules specified. | `1`     |

<i>**detailed**</i>: [boolean]

| Description                                  | Default |
| -------------------------------------------- | ------- |
| Whether to print the detailed model profile. | `true`  |

<i>**output_file**</i>: [string]

| Description                                                       | Default |
| ----------------------------------------------------------------- | ------- |
| Path to the output file. If `null`, the profiler prints to stdout. | `null`  |
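Warm-up steps are needed for accurate timing, so it can help to profile a later step than the default. The following is a sketch. The step number 10, the limit of 3 top modules, and the file name `flops_profile.txt` are illustrative choices, not recommendations. Enabling the profiler also turns on `wall_clock_breakdown`.

```json
{
  "flops_profiler": {
    "enabled": true,
    "profile_step": 10,
    "module_depth": -1,
    "top_modules": 3,
    "detailed": true,
    "output_file": "flops_profile.txt"
  }
}
```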


### Activation Checkpointing
```json
  "activation_checkpointing": {
    "partition_activations": false,
    "cpu_checkpointing": false,
    "contiguous_memory_optimization": false,
    "number_checkpoints": null,
    "synchronize_checkpoint_boundary": false,
    "profile": false
  }
```
<i>**partition_activations**</i>: [boolean]

| Description                                                   | Default |
| ------------------------------------------------------------- | ------- |
| Enables partitioning of activations when used with model parallelism | `false` |

<i>**cpu_checkpointing**</i>: [boolean]

| Description                                                                 | Default |
| --------------------------------------------------------------------------- | ------- |
| Offloads partitioned activations to CPU if `partition_activations` is enabled | `false` |


<i>**contiguous_memory_optimization**</i>: [boolean]

| Description                                                          | Default |
| -------------------------------------------------------------------- | ------- |
| Copies partitioned activations so that they are contiguous in memory | `false` |

<i>**number_checkpoints**</i>: [integer]

| Description                                                                                              | Default |
| -------------------------------------------------------------------------------------------------------- | ------- |
| Total number of activation checkpoints, used to allocate the memory buffer for `contiguous_memory_optimization` | `None`  |

<i>**synchronize_checkpoint_boundary**</i>: [boolean]

| Description                                                   | Default |
| ------------------------------------------------------------- | ------- |
| Inserts `torch.cuda.synchronize()` at each checkpoint boundary. | `false` |


<i>**profile**</i>: [boolean]

| Description                                                     | Default |
| --------------------------------------------------------------- | ------- |
| Logs the forward and backward time for each checkpoint function | `false` |
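Several of these flags build on one another. `cpu_checkpointing` only offloads when `partition_activations` is enabled, and `number_checkpoints` sizes the buffer used by `contiguous_memory_optimization`. The following sketch turns them on together. It assumes a hypothetical model-parallel model with 24 checkpointed layers; set `number_checkpoints` to match your own model.

```json
{
  "activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": true,
    "contiguous_memory_optimization": true,
    "number_checkpoints": 24,
    "synchronize_checkpoint_boundary": false,
    "profile": true
  }
}
```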

### Sparse Attention

<i>**sparse_attention**</i>: [dictionary]

| Fields | Value | Example |
| ------ | ----- | ------- |
| mode | A string determining the sparsity structure type. DeepSpeed currently supports `"dense"`, `"fixed"`, `"bigbird"`, `"bslongformer"`, and `"variable"`. | `"fixed"` |
| block | An integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, and this parameter defines the size of those blocks, `Block X Block`. | 16 |
| different\_layout\_per\_head | A boolean determining whether each head should be assigned a different sparsity layout; this will be satisfied based on availability. | false |
| num\_local\_blocks | An integer determining the number of blocks in each local attention window; only used in `"fixed"` mode. | 4 |
| num\_global\_blocks | An integer determining how many consecutive blocks in a local window are used as the representative of the window for global attention; used in `"fixed"` and `"bigbird"` modes. | 1 |
| attention | A string determining the attention type. Attention can be `"unidirectional"`, as in autoregressive models, where tokens attend only to tokens that appear before them in the context; in that case the upper triangular part of the attention matrix is empty. Or it can be `"bidirectional"`, as in BERT, where tokens can attend to any other tokens before or after them; in that case the upper triangular part of the attention matrix mirrors the lower triangular part. Used in `"fixed"` and `"variable"` modes. | `"bidirectional"` |
| horizontal\_global\_attention | A boolean determining whether blocks that are global representatives of a local window also attend to all other blocks. This is valid only if the attention type is `"bidirectional"`. In terms of the attention matrix, global attention then includes not only the vertical blocks but also the horizontal blocks; used in `"fixed"` and `"variable"` modes. | false |
| num\_different\_global\_patterns | An integer determining the number of different global attention layouts. Global attention is fixed by which block(s) represent each local window, but with multiple heads, each head can use a different global representative; used only in `"fixed"` mode. | 4 |
| num\_random\_blocks | An integer determining the number of random blocks in each block row; used in `"variable"` and `"bigbird"` modes. | 0 |
| local\_window\_blocks | A list of integers determining the number of blocks in each local attention window. The first number is the number of blocks in the first local window, the second number is the number in the second window, and so on; the last number determines the number of blocks in all remaining local windows. Only used in `"variable"` mode. | [4] |
| global\_block\_indices | A list of integers determining which blocks are considered global attention blocks: all other token blocks attend to them, and they attend to all other token blocks. If global\_block\_end\_indices is set, this parameter is used as the starting index of each global window; used in `"variable"` and `"bslongformer"` modes. | [0] |
| global\_block\_end\_indices | A list of integers determining the end indices of global window blocks. By default this is not used. If it is set, it must have the same size as global\_block\_indices; combining the two parameters, for each index i, blocks from global\_block\_indices[i] to global\_block\_end\_indices[i] (exclusive) are considered global attention. Used in `"variable"` and `"bslongformer"` modes. | None |
| num\_sliding\_window\_blocks | An integer determining the number of blocks in the sliding local attention window; used in `"bigbird"` and `"bslongformer"` modes. | 3 |

  Example of <i>**sparse_attention**</i>

```json
  "sparse_attention": {
    "mode": "fixed",
    "block": 16,
    "different_layout_per_head": true,
    "num_local_blocks": 4,
    "num_global_blocks": 1,
    "attention": "bidirectional",
    "horizontal_global_attention": false,
    "num_different_global_patterns": 4,
    "num_random_blocks": 0,
    "local_window_blocks": [4],
    "global_block_indices": [0],
    "global_block_end_indices": null,
    "num_sliding_window_blocks": 3
  }
```
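The `"variable"` mode is where the list-valued fields matter. The following sketch uses only fields that the table lists for `"variable"` mode, plus `mode`, `block`, and `different_layout_per_head`. The values are illustrative. The first local window spans 4 blocks and every remaining local window spans 2 blocks. Because `global_block_end_indices` is set (with the same length as `global_block_indices`), blocks 0 and 1 and blocks 8 and 9 act as global blocks, since end indices are exclusive.

```json
{
  "sparse_attention": {
    "mode": "variable",
    "block": 16,
    "different_layout_per_head": true,
    "num_random_blocks": 0,
    "local_window_blocks": [4, 2],
    "global_block_indices": [0, 8],
    "global_block_end_indices": [2, 10],
    "attention": "bidirectional",
    "horizontal_global_attention": false
  }
}
```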