From e2dfe0d17be21ce08e759dab4712bc37debd44c4 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 10 Feb 2021 18:03:55 -0800 Subject: [PATCH] Add flops profiler tutorial (#682) * work on flops profiler tutorial * update flops profiler tutorial * add flops profiler tutorial and fix names * work on flops profiler tutorial * update flops profiler tutorial * add flops profiler tutorial and fix names * fix tailing ws * fix names * remove multistep profiling and update docs * fix cases where functionals and submodules coexist in a parent module, update readme * fix typo * always invoke post hook function * fix module flops sum and update tests * update tutorial --- deepspeed/out2 | 7 + deepspeed/profiling/config.py | 17 +- deepspeed/profiling/constants.py | 16 +- deepspeed/profiling/flops_profiler/README.md | 559 ++++++++++++------ .../profiling/flops_profiler/profiler.py | 289 +++++---- deepspeed/runtime/engine.py | 41 +- docs/_config.yml | 1 + docs/_data/navigation.yml | 4 + docs/_pages/config-json.md | 290 +++++---- docs/_pages/features.md | 46 +- docs/_tutorials/flops-profiler.md | 447 ++++++++++++++ tests/unit/test_flops_profiler.py | 16 +- 12 files changed, 1248 insertions(+), 485 deletions(-) create mode 100644 deepspeed/out2 create mode 100644 docs/_tutorials/flops-profiler.md diff --git a/deepspeed/out2 b/deepspeed/out2 new file mode 100644 index 00000000..15ca670d --- /dev/null +++ b/deepspeed/out2 @@ -0,0 +1,7 @@ +============================= test session starts ============================== +platform linux -- Python 3.6.9, pytest-6.0.1, py-1.9.0, pluggy-0.13.1 +rootdir: /home/chengli1/projects/DeepSpeed +plugins: forked-1.3.0, hypothesis-5.41.3, xdist-2.1.0, cov-2.10.1 +collected 0 items + +============================ no tests ran in 0.01s ============================= diff --git a/deepspeed/profiling/config.py b/deepspeed/profiling/config.py index 017f9ec9..3302d616 100644 --- a/deepspeed/profiling/config.py +++ b/deepspeed/profiling/config.py @@ -15,8 +15,7 @@ class DeepSpeedFlopsProfilerConfig(object): super(DeepSpeedFlopsProfilerConfig, self).__init__() self.enabled = None - self.start_step = None - self.end_step = None + self.profile_step = None self.module_depth = None self.top_modules = None @@ -35,13 +34,9 @@ class DeepSpeedFlopsProfilerConfig(object): FLOPS_PROFILER_ENABLED, FLOPS_PROFILER_ENABLED_DEFAULT) - self.start_step = get_scalar_param(flops_profiler_dict, - FLOPS_PROFILER_START_STEP, - FLOPS_PROFILER_START_STEP_DEFAULT) - - self.end_step = get_scalar_param(flops_profiler_dict, - FLOPS_PROFILER_END_STEP, - FLOPS_PROFILER_END_STEP_DEFAULT) + self.profile_step = get_scalar_param(flops_profiler_dict, + FLOPS_PROFILER_PROFILE_STEP, + FLOPS_PROFILER_PROFILE_STEP_DEFAULT) self.module_depth = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_MODULE_DEPTH, @@ -50,3 +45,7 @@ class DeepSpeedFlopsProfilerConfig(object): self.top_modules = get_scalar_param(flops_profiler_dict, FLOPS_PROFILER_TOP_MODULES, FLOPS_PROFILER_TOP_MODULES_DEFAULT) + + self.detailed = get_scalar_param(flops_profiler_dict, + FLOPS_PROFILER_DETAILED, + FLOPS_PROFILER_DETAILED_DEFAULT) diff --git a/deepspeed/profiling/constants.py b/deepspeed/profiling/constants.py index f4812d32..964e528c 100644 --- a/deepspeed/profiling/constants.py +++ b/deepspeed/profiling/constants.py @@ -12,11 +12,11 @@ FLOPS_PROFILER_FORMAT = ''' flops profiler should be enabled as: "session_params": { "flops_profiler": { - "enalbe": [true|false], - "start_step": 5, - "end_step": 6, + "enabled": true, + "profile_step": 1, "module_depth": -1, "top_modules": 3, + "detailed": true, } } ''' @@ -26,14 +26,14 @@ FLOPS_PROFILER = "flops_profiler" FLOPS_PROFILER_ENABLED = "enabled" FLOPS_PROFILER_ENABLED_DEFAULT = False -FLOPS_PROFILER_START_STEP = "start_step" -FLOPS_PROFILER_START_STEP_DEFAULT = 5 - -FLOPS_PROFILER_END_STEP = "end_step" -FLOPS_PROFILER_END_STEP_DEFAULT = FLOPS_PROFILER_START_STEP_DEFAULT + 1 +FLOPS_PROFILER_PROFILE_STEP = "profile_step" +FLOPS_PROFILER_PROFILE_STEP_DEFAULT = 1 FLOPS_PROFILER_MODULE_DEPTH = "module_depth" FLOPS_PROFILER_MODULE_DEPTH_DEFAULT = -1 FLOPS_PROFILER_TOP_MODULES = "top_modules" FLOPS_PROFILER_TOP_MODULES_DEFAULT = 3 + +FLOPS_PROFILER_DETAILED = "detailed" +FLOPS_PROFILER_DETAILED_DEFAULT = True diff --git a/deepspeed/profiling/flops_profiler/README.md b/deepspeed/profiling/flops_profiler/README.md index f4584d21..179a0b13 100644 --- a/deepspeed/profiling/flops_profiler/README.md +++ b/deepspeed/profiling/flops_profiler/README.md @@ -1,250 +1,445 @@ -# flops-profiler +# DeepSpeed Flops Profiler -> Measures the time, number of estimated flops and parameters of each module in a PyTorch Model. +> Measures the parameters, latency, and floating point operations of your model. -The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how time, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated time, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input. If multiple forward passes are specified by the user to caputre (in the case where the model have different paths or for more accurate timing), the average profile of the multiple batches is taken. + - [Overview](#overview) + - [Supported Models](#supported-models) + - [Multi-GPU, Multi-node Runs](#multi-gpu-multi-node-runs) + - [Usage](#usage) -The flops estimation is partly inspired by [ptflops](https://github.com/sovrasov/flops-counter.pytorch) with the major difference being that flops-profiler captures `torch.nn.functional` invoked in a module to estimate the flops, thus allowing customized modules in the model (e.g. `ParallelTransformerLayerworks, ParallelSelfAttention, RowParallelLinear, etc.` in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). The flops-profiler also supports flops computation at module level (for RNNs). +## Overview -For models running on multi-node or multi-gpu, only the model parallelism affects the number of flops and parameters (e.g. `--model-parallel-size` in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)), i.e., model_parallel_size _ flops = total_flops, model_parallel_size _ parameters = total_parameters. The number of gpus or nodes does not affect the output profile. +The DeepSpeed flops profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. +It shows the parameters, latency, and number of floating point operations of the modules within the model to identify potential bottlenecks. +It also outputs the names of the top `k` modules in terms of aggregated time, flops, and number of parameters at depth `l` with `k` and `l` specified by the user. +The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package. -Below is an example output for LeNet5 with batch size 1024 on a V100 GPU: +The output profile is computed for each batch of input and printed to the `stdout`. For each module, the measured profile is annotated after the name and is listed in the order of `number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency of the module, percentage of the total latency, floating point operations per second (FLOPS)`. Note that the number of floating point operations is estimated as `2 * MACs` in the profiler (each MAC operation is counted as 2 floating point operations). + +Below is an example output for LeNet5 with batch size 1024: + +```shell +-------------------------- DeepSpeed Flops Profiler -------------------------- +Summary of forward pass: +Profile step: 1 +Number of parameters: 61.71 k +Number of multiply-accumulate operations (MACs): 439.56 M +Number of floating point operations ( = 2 * MACs): 879.12 M +Latency: 25.7 ms +Floating point operations per second(FLOPS): 34.2 GFLOPS + +----------------------------- Aggregated Profile ----------------------------- +Top 3 modules in MACs at depth 2 are {'Conv2d': '421.91 MMACs', 'Linear': '11.18 MMACs', 'AvgPool2d': '6.46 MMACs'} +Top 3 modules in params at depth 2 are {'Conv2d': '50.69 k', 'Linear': '11.01 k', 'Tanh': '0'} +Top 3 modules in latency at depth 2 are {'Conv2d': '11.37 ms', 'Linear': '5.27 ms', 'AvgPool2d': '5.02 ms'} + +------------------------------ Detailed Profile ------------------------------ +Each module profile is listed after its name in the follwing order: +number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). +Note: +1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. -``` LeNet5( - 61.71 k, 100.00% Params, 439.55 MMACs, 100.00% MACs, 25.62 ms, 100.00% time, 0.034 TFLOPS, + 61.71 k, 100.00% Params, 439.56 MMACs, 100.00% MACs, 25.7 ms, 100.00% latency, 34.2 GFLOPS, (feature_extractor): Sequential( - 50.69 k, 82.15% Params, 428.37 MMACs, 97.46% MACs, 18.41 ms, 71.85% time, 0.047 TFLOPS, - (0): Conv2d(156, 0.25% Params, 125.24 MMACs, 28.49% MACs, 10.56 ms, 41.21% time, 0.024 TFLOPS, 1, 6, kernel_size=(5, 5), stride=(1, 1)) - (1): Tanh(0, 0.00% Params, 0.0 MACs, 0.00% MACs, 2.25 ms, 8.79% time, 0.0 TFLOPS, ) - (2): AvgPool2d(0, 0.00% Params, 4.82 MMACs, 1.10% MACs, 2.47 ms, 9.63% time, 0.0039 TFLOPS, kernel_size=2, stride=2, padding=0) - (3): Conv2d(2.42 k, 3.92% Params, 247.4 MMACs, 56.28% MACs, 1.08 ms, 4.23% time, 0.46 TFLOPS, 6, 16, kernel_size=(5, 5), stride=(1, 1)) - (4): Tanh(0, 0.00% Params, 0.0 MACs, 0.00% MACs, 497.39 us, 1.94% time, 0.0 TFLOPS, ) - (5): AvgPool2d(0, 0.00% Params, 1.64 MMACs, 0.37% MACs, 758.24 us, 2.96% time, 0.0043 TFLOPS, kernel_size=2, stride=2, padding=0) - (6): Conv2d(48.12 k, 77.98% Params, 49.27 MMACs, 11.21% MACs, 606.35 us, 2.37% time, 0.16 TFLOPS, 16, 120, kernel_size=(5, 5), stride=(1, 1)) - (7): Tanh(0, 0.00% Params, 0.0 MACs, 0.00% MACs, 68.86 us, 0.27% time, 0.0 TFLOPS, ) + 50.69 k, 82.15% Params, 428.37 MMACs, 97.45% MACs, 20.12 ms, 78.27% latency, 42.59 GFLOPS, + (0): Conv2d(156, 0.25% Params, 125.24 MMACs, 28.49% MACs, 9.8 ms, 38.12% latency, 25.56 GFLOPS, 1, 6, kernel_size=(5, 5), stride=(1, 1)) + (1): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 2.85 ms, 11.08% latency, 0.0 FLOPS, ) + (2): AvgPool2d(0, 0.00% Params, 4.82 MMACs, 1.10% MACs, 4.01 ms, 15.59% latency, 2.4 GFLOPS, kernel_size=2, stride=2, padding=0) + (3): Conv2d(2.42 k, 3.92% Params, 247.4 MMACs, 56.28% MACs, 924.83 us, 3.60% latency, 535.02 GFLOPS, 6, 16, kernel_size=(5, 5), stride=(1, 1)) + (4): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 672.1 us, 2.62% latency, 0.0 FLOPS, ) + (5): AvgPool2d(0, 0.00% Params, 1.64 MMACs, 0.37% MACs, 1.01 ms, 3.95% latency, 3.23 GFLOPS, kernel_size=2, stride=2, padding=0) + (6): Conv2d(48.12 k, 77.98% Params, 49.27 MMACs, 11.21% MACs, 647.31 us, 2.52% latency, 152.25 GFLOPS, 16, 120, kernel_size=(5, 5), stride=(1, 1)) + (7): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 82.02 us, 0.32% latency, 0.0 FLOPS, ) ) (classifier): Sequential( - 11.01 k, 17.85% Params, 11.18 MMACs, 2.54% MACs, 7.03 ms, 27.43% time, 0.0032 TFLOPS, - (0): Linear(10.16 k, 16.47% Params, 10.32 MMACs, 2.35% MACs, 2.71 ms, 10.57% time, 0.0076 TFLOPS, in_features=120, out_features=84, bias=True) - (1): Tanh(0, 0.00% Params, 0.0 MACs, 0.00% MACs, 78.77 us, 0.31% time, 0.0 TFLOPS, ) - (2): Linear(850, 1.38% Params, 860.16 KMACs, 0.20% MACs, 4.17 ms, 16.27% time, 0.00041 TFLOPS, in_features=84, out_features=10, bias=True) + 11.01 k, 17.85% Params, 11.18 MMACs, 2.54% MACs, 5.41 ms, 21.06% latency, 4.13 GFLOPS, + (0): Linear(10.16 k, 16.47% Params, 10.32 MMACs, 2.35% MACs, 2.47 ms, 9.60% latency, 8.37 GFLOPS, in_features=120, out_features=84, bias=True) + (1): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 90.12 us, 0.35% latency, 0.0 FLOPS, ) + (2): Linear(850, 1.38% Params, 860.16 KMACs, 0.20% MACs, 2.8 ms, 10.91% latency, 613.62 MFLOPS, in_features=84, out_features=10, bias=True) ) ) -Top 3 modules in flops at depth 2 are {'Conv2d': '421.91 MMACs', 'Linear': '11.18 MMACs', 'AvgPool2d': '6.46 MMACs'} -Top 3 modules in params at depth 2 are {'Conv2d': '50.69 k', 'Linear': '11.01 k', 'Tanh': '0'} -Top 3 modules in time at depth 2 are {'Conv2d': '12.25 ms', 'Linear': '6.88 ms', 'AvgPool2d': '3.23 ms'} -Batch size: 1024 -Number of multiply-adds: 439.55 MMACs -Number of parameters: 61.71 k -Number of steps profiled: 10 -``` - -## Installation - -The profiler is an integral part of DeepSpeed and can be installed by - -``` -pip install deepspeed +------------------------------------------------------------------------------ ``` -Refer to the [installaiton of DeepSpeed](https://www.deepspeed.ai/getting-started/#installation) for more information. +## Supported Models -## Usage +The flops estimation is partly inspired by [ptflops](https://github.com/sovrasov/flops-counter.pytorch) with the major difference being that the DeepSpeed flops profiler captures ```torch.nn.functional``` invoked in a module to estimate the flops. Thus the DeepSpeed flops profiler allows for customized modules in the model, e.g., ```ParallelTransformerLayerworks, ParallelSelfAttention, RowParallelLinear, etc.``` in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). This is in contrast to tools that profile at ```torch.nn.module``` level, such as ptflops, which require users to write customized flops calculation functions for each customized module. Finally, the DeepSpeed flops profiler also supports flops computation at module level (for RNNs). -### With the DeepSpeed runtime +## Multi-GPU, Multi-node Runs -If using DeepSpeed for model training, no explict API calls are needed to use the flops-profiler. +For models running on multi-GPU or multi-node, only the model parallelism (e.g. ```--model-parallel-size``` in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)) affects the number of flops and parameters profiled, i.e., +`model_parallel_size * flops = total_flops` and `model_parallel_size * parameters = total_parameters`. The number of GPUs or nodes does not affect the output profile. -In DeepSpeed config file, specify: -```python - ds_config = { - ...# other deepspeed configs - "flops_profiler": { - "enabled": True, - "start_step": 2, - "end_step": 3, - "module_depth": -1, - "top_modules": 3, - }, - } -``` -- `"enabled": true` to enable the flops-profiler. -- `"start_step": 5` to start the profiler at step 5. Note that warm-up is necessary for getting accurate timing information. -- `"end_step": 6` to end the profiler at step 6. Note that `end_step > start_step`. -- `"module_depth": -1` to print aggregated module information at the maximum depth (innermost modules). Can be set to any positive number, caped by the maximum depth of the model. -- `"top_modules": 3`to set the number of top modules to print aggregated profile +## Usage -An example is given in [test_flops_profiler](tests/unit/test_flops_profiler.py). -### Without the DeepSpeed runtime +The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package. When using DeepSpeed for model training, the flops profiler can be configured in the deepspeed_config file without user code changes. To use the flops profiler outside of the DeepSpeed runtime, one can simply install DeepSpeed and import the flops_profiler package to use the APIs directly. Examples of each usage are given below. -The flops-profiler can be used as a standalone package outside of the deepspeed runtime. + - [Usage With the DeepSpeed Runtime](#usage-with-the-deepspeed-runtime) + - [Example: Megatron-LM](#example-megatron-lm) + - [Usage Outside the DeepSpeed Runtime](#usage-outside-the-deepspeed-runtime) + - [In Model Inference](#in-model-inference) + - [Example: AlexNet](#example-alexnet) + - [Example: Bert](#example-bert) + - [In Model Training Workflow](#in-model-training-workflow) + - [Example Training Workflow](#example-training-workflow) +### Usage With the DeepSpeed Runtime -#### Use the low-level APIs to profile the forward pass in the existing model training workflow +When using DeepSpeed for model training, the flops profiler can be configured in the `deepspeed_config` file. No explict API calls are needed to use the profiler. Refer to [flops profiler](https://www.deepspeed.ai/docs/config-json/#flops-profiler) for details. -- `start_profile` - starts profiling -- `get_total_flops` - returns the total number of flops -- `get_total_params` - returns the total number of params -- `get_total_duration` - returns the total duration of the model forward pass -- `get_total_steps` - returns the total number of steps (or input batches) profiled. -- `print_model_profile` - prints the profile annotated -- `print_model_aggregated_profile` - prints the aggregated profile for the top modules -- `end_profile` - ends profiling and cleans up, invoked at the end of the profiling and before any printing method. -`flops_to_string`, `params_to_string`, `duration_to_string` are utility functions to convert the metric number to string. +#### Example: Megatron-LM -Below is an example of this usage in a typical training workflow. +For information on running Megatron-LM with DeepSpeed, please refer to our tutorial [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM) -```python -from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler +The flops profiler can be enabled by adding the following field to the `deepspeed_config` file. -model = Model() -profiler = FlopsProfiler(model) - -start_step = 5 -end_step = 10 -assert (end_step > start_step), "should end profiling after start profiling" -print_profile = True -pring_aggregated_profile = True +```json +{ + "flops_profiler": { + "enabled": true, + "profile_step": 1, + "module_depth": -1, + "top_modules": 3, + "detailed": true, + } +} +``` -for step, batch in enumerate(data_loader): - # start profiling at training step "profile_step" - if step == start_step: - profiler.start_profile() - - # end profiling and print output at training step "profile_step" - if model == end_step: # if using multi nodes, check global_rank == 0 as well - flops = profiler.get_total_flops() - params = profiler.get_total_flops() - duration = profiler.get_total_duration() - steps = profiler.get_total_steps() - if print_profile: - profiler.print_model_profile() - if print_aggregated_profile: - profiler.print_model_aggregated_profile(module_depth=-1, top_modules=3) - profiler.end_profile() - print(flops, params, duration, step) +An example output of 4-layer Megatron-LM model (`hidden_size = 512, num_attention_heads = 16, batch_size = 8, seq_length = 1024`) is shown below. + +```shell +-------------------------- DeepSpeed Flops Profiler -------------------------- +Summary of forward pass: +Profile step: 1 +Number of parameters: 38.89 M +Number of multiply-accumulate operations (MACs): 314.61 G +Number of floating point operations ( = 2 * MACs): 629.21 G +Latency: 33.81 ms +Floating point operations per second(FLOPS): 18.61 TFLOPS + +----------------------------- Aggregated Profile ----------------------------- +Top 3 modules in MACs at depth 8 are {'ColumnParallelLinear': '60.13 GMACs', 'RowParallelLinear': '42.95 GMACs', 'FusedScaleMaskSoftmax': '536.87 MMACs'} +Top 3 modules in params at depth 8 are {'ColumnParallelLinear': '7.35 M', 'RowParallelLinear': '5.25 M', 'FusedScaleMaskSoftmax': '0'} +Top 3 modules in latency at depth 8 are {'ColumnParallelLinear': '659.23 us', 'RowParallelLinear': '587.94 us', 'FusedScaleMaskSoftmax': '370.98 us'} + +------------------------------ Detailed Profile ------------------------------ +Each module profile is listed after its name in the follwing order: +number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). +Note: +1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. + +DistributedDataParallel( + 38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.81 ms, 100.00% latency, 18.61 TFLOPS, + (module): FP16_Module( + 38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.77 ms, 99.89% latency, 18.63 TFLOPS, + (module): GPT2Model( + 38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.69 ms, 99.66% latency, 18.67 TFLOPS, + (language_model): TransformerLanguageModel( + 38.89 M, 100.00% Params, 103.62 GMACs, 32.94% MACs, 5.58 ms, 16.51% latency, 37.13 TFLOPS, + (embedding): Embedding( + 26.28 M, 67.57% Params, 0 MACs, 0.00% MACs, 545.98 us, 1.61% latency, 0.0 FLOPS, + (word_embeddings): VocabParallelEmbedding(25.76 M, 66.23% Params, 0 MACs, 0.00% MACs, 223.88 us, 0.66% latency, 0.0 FLOPS, ) + (position_embeddings): Embedding(524.29 k, 1.35% Params, 0 MACs, 0.00% MACs, 147.1 us, 0.44% latency, 0.0 FLOPS, 1024, 512) + (embedding_dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 79.39 us, 0.23% latency, 0.0 FLOPS, p=0.1, inplace=False) + ) + (transformer): ParallelTransformer( + 12.61 M, 32.43% Params, 103.62 GMACs, 32.94% MACs, 5.0 ms, 14.78% latency, 41.49 TFLOPS, + (layers): ModuleList( + 12.61 M, 32.42% Params, 103.62 GMACs, 32.94% MACs, 4.4 ms, 13.01% latency, 47.13 TFLOPS, + (0): ParallelTransformerLayer( + 3.15 M, 8.11% Params, 25.9 GMACs, 8.23% MACs, 1.36 ms, 4.02% latency, 38.09 TFLOPS, + (input_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 92.51 us, 0.27% latency, 0.0 FLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True) + (attention): ParallelSelfAttention( + 1.05 M, 2.70% Params, 8.72 GMACs, 2.77% MACs, 754.59 us, 2.23% latency, 23.12 TFLOPS, + (query_key_value): ColumnParallelLinear(787.97 k, 2.03% Params, 6.44 GMACs, 2.05% MACs, 182.87 us, 0.54% latency, 70.46 TFLOPS, ) + (scale_mask_softmax): FusedScaleMaskSoftmax(0, 0.00% Params, 134.22 MMACs, 0.04% MACs, 120.4 us, 0.36% latency, 2.23 TFLOPS, ) + (attention_dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 47.45 us, 0.14% latency, 0.0 FLOPS, p=0.1, inplace=False) + (dense): RowParallelLinear(262.66 k, 0.68% Params, 2.15 GMACs, 0.68% MACs, 81.78 us, 0.24% latency, 52.52 TFLOPS, ) + ) + (post_attention_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 57.22 us, 0.17% latency, 0.0 FLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True) + (mlp): ParallelMLP( + 2.1 M, 5.40% Params, 17.18 GMACs, 5.46% MACs, 224.83 us, 0.67% latency, 152.83 TFLOPS, + (dense_h_to_4h): ColumnParallelLinear(1.05 M, 2.70% Params, 8.59 GMACs, 2.73% MACs, 64.13 us, 0.19% latency, 267.87 TFLOPS, ) + (dense_4h_to_h): RowParallelLinear(1.05 M, 2.70% Params, 8.59 GMACs, 2.73% MACs, 90.36 us, 0.27% latency, 190.13 TFLOPS, ) + ) + ) + ... + (3): ParallelTransformerLayer(...) + (final_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 52.69 us, 0.16% latency, 0.0 TFLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True) + ) + ) + ) + ) +) +``` - # forward() method - loss = model(batch) +### Usage Outside the DeepSpeed Runtime - # runs backpropagation - loss.backward() +The flops profiler can be used as a standalone package outside of the DeepSpeed runtime. +One can simply install DeepSpeed and import the `flops_profiler` package to use the APIs directly. +Refer to [installation of DeepSpeed](https://www.deepspeed.ai/getting-started/#installation) for installing DeepSpeed. - # weight update - optimizer.step() -``` +#### In Model Inference -#### Use the high level-API and run the model inference for profiling purpose +To profile a trained model in inference, use the `get_model_profile` function. +Examples are given below. -Examples of this usage are given below. +##### Example: AlexNet -##### Classification model example: +The following example shows how to profile AlexNet using the DeepSpeed flops profiler. ```python -import argparse -import sys -import torch import torchvision.models as models +import torch from deepspeed.profiling.flops_profiler import get_model_profile -pt_models = { - 'resnet18': models.resnet18, - 'resnet50': models.resnet50, - 'alexnet': models.alexnet, - 'vgg16': models.vgg16, - 'squeezenet': models.squeezenet1_0, - 'densenet': models.densenet161, - 'inception': models.inception_v3 -} - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='flops-profiler example script') - parser.add_argument('--device', - type=int, - default=0, - help='Device to store the model.') - parser.add_argument('--model', - choices=list(pt_models.keys()), - type=str, - default='resnet18') - args = parser.parse_args() - - model = pt_models[args.model]() - - if torch.cuda.is_available(): - model.cuda(device=args.device) - +with torch.cuda.device(0): + model = models.alexnet() batch_size = 256 - macs, params, steps = get_model_profile(model, # the PyTorch model to be profiled + macs, params = get_model_profile(model=model, # model input_res=(batch_size, 3, 224, 224), # input shape or input to the input_constructor - input_constructor=None, # If specified, the constructor is applied to input_res and the constructor output is used as the input to the model - print_profile=True, # whether to print the model graph with the profile annotated. Defaults to True - print_aggregated_profile=True, # whether to print the aggregated profile for top modules. Defaults to True - module_depth=-1, # the depth into the nested modules. Defaults to -1 (the inner most modules) + input_constructor=None, # if specified, a constructor taking input_res is used as input to the model + print_profile=True, # prints the model graph with the measured profile attached to each module + detailed=True, # print the detailed profile + module_depth=-1, # depth into the nested modules with -1 being the inner most modules top_modules=3, # the number of top modules to print aggregated profile - warm_up=10, # the number of warm-up steps before measuring the time of each module. Defaults to 5 - num_steps=10, # the number of steps to profile. Defaults to 10 - as_strings=True, # whether to print the output as strings (e.g. 1k). Defaults to True - ignore_modules=None) # the list of modules to ignore during profiling. Defaults to None - - print("{:<30} {:<8}".format("Batch size: ", batch_size)) - print('{:<30} {:<8}'.format('Number of MACs: ', macs)) - print('{:<30} {:<8}'.format('Number of parameters: ', params)) - print('{:<30} {:<8}'.format('Number of steps profiled: ', steps)) - -# Output: -# Number of MACs: 466.48 GMACs -# Number of parameters: 11.69 M + warm_up=10, # the number of warm-ups before measuring the time of each module + as_string=True, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) + ignore_modules=None) # the list of modules to ignore in the profiling +``` +An example output: + +```shell +-------------------------- DeepSpeed Flops Profiler -------------------------- +Summary of forward pass: +Profile step: 10 +Number of parameters: 61.1 M +Number of multiply-accumulate operations (MACs): 183.18 G +Number of floating point operations ( = 2 * MACs): 366.36 G +Latency: 22.13 ms +Floating point operations per second(FLOPS): 16.56 TFLOPS + +----------------------------- Aggregated Profile ----------------------------- +Top 3 modules in MACs at depth 2 are {'Conv2d': '167.95 GMACs', 'Linear': '15.01 GMACs', 'ReLU': '126.26 MMACs'} +Top 3 modules in params at depth 2 are {'Linear': '58.63 M', 'Conv2d': '2.47 M', 'ReLU': '0'} +Top 3 modules in latency at depth 2 are {'Conv2d': '13.96 ms', 'Linear': '6.23 ms', 'ReLU': '730.75 us'} + +------------------------------ Detailed Profile ------------------------------ +Each module profile is listed after its name in the follwing order: +number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). +Note: +1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. + +AlexNet( + 61.1 M, 100.00% Params, 183.18 GMACs, 100.00% MACs, 22.13 ms, 100.00% latency, 16.56 TFLOPS, + (features): Sequential( + 2.47 M, 4.04% Params, 168.17 GMACs, 91.81% MACs, 15.17 ms, 68.57% latency, 22.17 TFLOPS, + (0): Conv2d(23.3 k, 0.04% Params, 18.04 GMACs, 9.85% MACs, 633.0 us, 2.86% latency, 57.0 TFLOPS, 3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)) + (1): ReLU(0, 0.00% Params, 49.56 MMACs, 0.03% MACs, 163.79 us, 0.74% latency, 605.17 GFLOPS, inplace=True) + (2): MaxPool2d(0, 0.00% Params, 49.56 MMACs, 0.03% MACs, 159.26 us, 0.72% latency, 622.38 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) + (3): Conv2d(307.39 k, 0.50% Params, 57.37 GMACs, 31.32% MACs, 6.15 ms, 27.81% latency, 18.64 TFLOPS, 64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) + (4): ReLU(0, 0.00% Params, 35.83 MMACs, 0.02% MACs, 185.01 us, 0.84% latency, 387.34 GFLOPS, inplace=True) + (5): MaxPool2d(0, 0.00% Params, 35.83 MMACs, 0.02% MACs, 134.23 us, 0.61% latency, 533.89 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) + (6): Conv2d(663.94 k, 1.09% Params, 28.72 GMACs, 15.68% MACs, 389.58 us, 1.76% latency, 147.47 TFLOPS, 192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (7): ReLU(0, 0.00% Params, 16.61 MMACs, 0.01% MACs, 76.53 us, 0.35% latency, 434.15 GFLOPS, inplace=True) + (8): Conv2d(884.99 k, 1.45% Params, 38.29 GMACs, 20.90% MACs, 6.38 ms, 28.82% latency, 12.01 TFLOPS, 384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (9): ReLU(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 104.43 us, 0.47% latency, 212.12 GFLOPS, inplace=True) + (10): Conv2d(590.08 k, 0.97% Params, 25.53 GMACs, 13.94% MACs, 405.79 us, 1.83% latency, 125.83 TFLOPS, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (11): ReLU(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 65.57 us, 0.30% latency, 337.85 GFLOPS, inplace=True) + (12): MaxPool2d(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 122.07 us, 0.55% latency, 181.46 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) + ) + (avgpool): AdaptiveAvgPool2d(0, 0.00% Params, 2.36 MMACs, 0.00% MACs, 259.4 us, 1.17% latency, 18.19 GFLOPS, output_size=(6, 6)) + (classifier): Sequential( + 58.63 M, 95.96% Params, 15.01 GMACs, 8.19% MACs, 6.54 ms, 29.54% latency, 4.59 TFLOPS, + (0): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 42.68 us, 0.19% latency, 0.0 FLOPS, p=0.5, inplace=False) + (1): Linear(37.75 M, 61.79% Params, 9.66 GMACs, 5.28% MACs, 301.36 us, 1.36% latency, 64.13 TFLOPS, in_features=9216, out_features=4096, bias=True) + (2): ReLU(0, 0.00% Params, 1.05 MMACs, 0.00% MACs, 79.39 us, 0.36% latency, 26.41 GFLOPS, inplace=True) + (3): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 39.58 us, 0.18% latency, 0.0 FLOPS, p=0.5, inplace=False) + (4): Linear(16.78 M, 27.46% Params, 4.29 GMACs, 2.34% MACs, 234.37 us, 1.06% latency, 36.65 TFLOPS, in_features=4096, out_features=4096, bias=True) + (5): ReLU(0, 0.00% Params, 1.05 MMACs, 0.00% MACs, 56.03 us, 0.25% latency, 37.43 GFLOPS, inplace=True) + (6): Linear(4.1 M, 6.71% Params, 1.05 GMACs, 0.57% MACs, 5.69 ms, 25.72% latency, 368.42 GFLOPS, in_features=4096, out_features=1000, bias=True) + ) +) +------------------------------------------------------------------------------ ``` -##### Bert model example: +##### Example: Bert ```python from functools import partial - import torch from transformers import BertForSequenceClassification, BertTokenizer - from deepspeed.profiling.flops_profiler import get_model_profile def bert_input_constructor(input_shape, tokenizer): - inp_seq = "" - for _ in range(input_shape[1] - 2): # there are two special tokens [CLS] and [SEP] - inp_seq += tokenizer.pad_token # let's use pad token to form a fake - # sequence for subsequent flops calculation - - inputs = tokenizer([inp_seq] * input_shape[0], + fake_seq = "" + for _ in range(input_shape[1] - 2): # ignore the two special tokens [CLS] and [SEP] + fake_seq += tokenizer.pad_token + inputs = tokenizer([fake_seq] * input_shape[0], padding=True, truncation=True, return_tensors="pt") labels = torch.tensor([1] * input_shape[0]) - # Batch size input_shape[0], sequence length input_shape[128] inputs = dict(inputs) inputs.update({"labels": labels}) return inputs -if __name__ == '__main__': - bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') +with torch.cuda.device(0): + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertForSequenceClassification.from_pretrained('bert-base-uncased') - macs, params, steps = get_model_profile( - model, - (2, 128), - input_constructor=partial(bert_input_constructor, - tokenizer=bert_tokenizer), - print_profile=True, - print_aggregated_profile=True, + batch_size = 4 + seq_len = 128 + enable_profile = True + if enable_profile: + macs, params = get_model_profile( + model, + (batch_size, seq_len), + input_constructor=partial(bert_input_constructor, + tokenizer=tokenizer), + print_profile=True, + detailed=True, + ) + else: + inputs = bert_input_constructor((batch_size, seq_len), tokenizer) + outputs = model(inputs) +``` + +An example output: + +``` +-------------------------- DeepSpeed Flops Profiler -------------------------- +Summary of forward pass: +Profile step: 1 +Number of parameters: 109.48 M +Number of multiply-accumulate operations (MACs): 43.5 G +Number of floating point operations ( = 2 * MACs): 87.0 G +Latency: 393.7 ms +Floating point operations per second(FLOPS): 220.97 GFLOPS + +----------------------------- Aggregated Profile ----------------------------- +Top 3 modules in MACs at depth 7 are {'Linear': '14.5 GMACs', 'Dropout': '0 MACs', 'LayerNorm': '0 MACs'} +Top 3 modules in params at depth 7 are {'Linear': '28.35 M', 'LayerNorm': '18.43 k', 'Dropout': '0'} +Top 3 modules in latency at depth 7 are {'Linear': '153.7 ms', 'LayerNorm': '4.74 ms', 'Dropout': '597.95 us'} + +------------------------------ Detailed Profile ------------------------------ +Each module profile is listed after its name in the follwing order: +number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). +Note: +1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. + +BertForSequenceClassification( + 109.48 M, 100.00% Params, 43.5 GMACs, 100.00% MACs, 393.7 ms, 100.00% latency, 220.97 GFLOPS, + (bert): BertModel( + 109.48 M, 100.00% Params, 43.5 GMACs, 100.00% MACs, 393.38 ms, 99.92% latency, 221.15 GFLOPS, + (embeddings): BertEmbeddings( + 23.84 M, 21.77% Params, 0 MACs, 0.00% MACs, 1.79 ms, 0.45% latency, 0.0 FLOPS, + (word_embeddings): Embedding(23.44 M, 21.41% Params, 0 MACs, 0.00% MACs, 485.18 us, 0.12% latency, 0.0 FLOPS, 30522, 768, padding_idx=0) + (position_embeddings): Embedding(393.22 k, 0.36% Params, 0 MACs, 0.00% MACs, 111.1 us, 0.03% latency, 0.0 FLOPS, 512, 768) + (token_type_embeddings): Embedding(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 215.53 us, 0.05% latency, 0.0 FLOPS, 2, 768) + (LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 386.95 us, 0.10% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True) + (dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 20.27 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False) + ) + (encoder): BertEncoder( + 85.05 M, 77.69% Params, 43.5 GMACs, 99.99% MACs, 391.03 ms, 99.32% latency, 222.47 GFLOPS, + (layer): ModuleList( + 85.05 M, 77.69% Params, 43.5 GMACs, 99.99% MACs, 390.82 ms, 99.27% latency, 222.59 GFLOPS, + (0): BertLayer( + 7.09 M, 6.47% Params, 3.62 GMACs, 8.33% MACs, 31.91 ms, 8.10% latency, 227.21 GFLOPS, + (attention): BertAttention( + 2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 16.39 ms, 4.16% latency, 147.47 GFLOPS, + (self): BertSelfAttention( + 1.77 M, 1.62% Params, 906.76 MMACs, 2.08% MACs, 15.07 ms, 3.83% latency, 120.36 GFLOPS, + (query): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 3.66 ms, 0.93% latency, 164.91 GFLOPS, in_features=768, out_features=768, bias=True) + (key): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 3.72 ms, 0.94% latency, 162.36 GFLOPS, in_features=768, out_features=768, bias=True) + (value): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 4.52 ms, 1.15% latency, 133.65 GFLOPS, in_features=768, out_features=768, bias=True) + (dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 24.08 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False) + ) + (output): BertSelfOutput( + 592.13 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 1.29 ms, 0.33% latency, 469.21 GFLOPS, + (dense): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 504.26 us, 0.13% latency, 1.2 TFLOPS, in_features=768, out_features=768, bias=True) + (LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 437.97 us, 0.11% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True) + (dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 21.93 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False) + ) + ) + (intermediate): BertIntermediate( + 2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 9.57 ms, 2.43% latency, 252.35 GFLOPS, + (dense): Linear(2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 8.75 ms, 2.22% latency, 276.11 GFLOPS, in_features=768, out_features=3072, bias=True) + ) + (output): BertOutput( + 2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 5.77 ms, 1.47% latency, 418.39 GFLOPS, + (dense): Linear(2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 5.13 ms, 1.30% latency, 471.15 GFLOPS, in_features=3072, out_features=768, bias=True) + (LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 310.9 us, 0.08% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True) + (dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 29.8 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False) + ) + ) + ... + (11): BertLayer(...) + ) ) - print("{:<30} {:<8}".format("Number of multiply-adds: ", macs)) - print("{:<30} {:<8}".format("Number of parameters: ", params)) - print("{:<30} {:<8}".format("Number of steps profiled: ", steps)) + (pooler): BertPooler( + 590.59 k, 0.54% Params, 2.36 MMACs, 0.01% MACs, 337.12 us, 0.09% latency, 14.0 GFLOPS, + (dense): Linear(590.59 k, 0.54% Params, 2.36 MMACs, 0.01% MACs, 173.57 us, 0.04% latency, 27.19 GFLOPS, in_features=768, out_features=768, bias=True) + (activation): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 46.01 us, 0.01% latency, 0.0 FLOPS, ) + ) + ) + (dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 19.55 us, 0.00% latency, 0.0 FLOPS, p=0.1, inplace=False) + (classifier): Linear(1.54 k, 0.00% Params, 6.14 KMACs, 0.00% MACs, 56.51 us, 0.01% latency, 217.47 MFLOPS, in_features=768, out_features=2, bias=True) +) +------------------------------------------------------------------------------ +``` + +#### In Model Training Workflow -# Output: -# Number of multiply-adds: 21.74 GMACs -# Number of parameters: 109.48 M +To profile model forward in a training workflow, use the `FlopsProfiler`class. +The `FlopsProfiler`class provides the follwing methods: + * `start_profile()` - starts profiling + * `get_total_flops(as_string=False)` - returns the total number of MACs in the model + * `get_total_params(as_string=False)` - returns the total number of parameters in the model + * `print_model_profile(profile_step=1, module_depth=-1, top_modules=3, detailed=True)` - prints the model profile + * `end_profile()` - ends profiling and cleans up. This should be invoked at the end of the profiling and AFTER `get_total_flops`, `get_total_params` or `print_model_profile`. + +##### Example Training Workflow + +Below is an example of this usage in a typical training workflow. Note that the flops profiler only captures the forward pass in a training step. The flops of a backward pass can be roughly estimated from that of the forward pass (~2x). + +```python +from deepspeed.profiling.flops_profiler import FlopsProfiler + +model = Model() +prof = FlopsProfiler(model) + +profile_step = 5 +print_profile= True + +for step, batch in enumerate(data_loader): + # start profiling at training step "profile_step" + if step == profile_step: + prof.start_profile() + + # forward() method + loss = model(batch) + + # end profiling and print output + if step == profile_step: # if using multi nodes, check global_rank == 0 as well + flops = prof.get_total_flops(as_string=True) + params = prof.get_total_params(as_string=True) + if print_profile: + prof.print_model_profile(profile_step=profile_step) + prof.end_profile() + + # runs backpropagation + loss.backward() + + # weight update + optimizer.step() ``` diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index 855057a4..ca10d76c 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -9,9 +9,9 @@ old_functions = {} class FlopsProfiler(object): - """Measures the time, number of estimated flops and parameters of each module in a PyTorch model. + """Measures the latency, number of estimated floating point operations and parameters of each module in a PyTorch model. - The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how time, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated time, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input. If multiple forward passes are specified by the user to caputre (in the case where the model have different paths or for more accurate timing), the average profile of the multiple batches is taken. + The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how latency, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated latency, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input. Args: object (torch.nn.Module): The PyTorch model to profile. @@ -42,21 +42,16 @@ class FlopsProfiler(object): # if computing the flops of the functionals in a module def pre_hook(module, input): - module_flop_count.clear() - if len(input) > 0: - # Can have multiple inputs, getting the first one - input = input[0] - module.__steps__ += 1 + module_flop_count.append([]) module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook) def post_hook(module, input, output): - module.__flops__ += sum([elem[1] for elem in module_flop_count]) - module_flop_count.clear() + if module_flop_count: + module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]]) + module_flop_count.pop() - has_children = len(module._modules.items()) != 0 - if not has_children: - module.__post_hook_handle__ = module.register_forward_hook(post_hook) + module.__post_hook_handle__ = module.register_forward_hook(post_hook) def start_time_hook(module, input): module.__start_time__ = time.time() @@ -77,8 +72,6 @@ class FlopsProfiler(object): Added attributes and handles are removed recursively on all the modules and the torch.nn.functionals are restored. """ def remove_profile_attrs(module): - if hasattr(module, "__steps__"): - del module.__steps__ if hasattr(module, "__flops__"): del module.__flops__ if hasattr(module, "__params__"): @@ -117,100 +110,91 @@ class FlopsProfiler(object): if p.requires_grad) module.__start_time__ = 0 module.__duration__ = 0 - module.__steps__ = 0 self.model.apply(add_or_reset_attrs) - def get_total_flops(self, in_str=False): + def get_total_flops(self, as_string=False): """Returns the total flops of the model. Args: - in_str (bool, optional): whether to output the flops in string. Defaults to False. + as_string (bool, optional): whether to output the flops as string. Defaults to False. """ - if self.get_total_steps() == 0: - return 0 - sum = 0 - for module in self.model.modules(): - sum += module.__flops__ - total_flops = sum / self.get_total_steps() - return flops_to_string(total_flops) if in_str else total_flops - - def get_total_duration(self, in_str=False): + total_flops = get_module_flops(self.model) + return macs_to_string(total_flops) if as_string else total_flops + + def get_total_duration(self, as_string=False): """Returns the total duration of the model forward pass. Args: - in_str (bool, optional): whether to output the duration in string. Defaults to False. + as_string (bool, optional): whether to output the duration as string. Defaults to False. """ - if self.get_total_steps() == 0: - return 0 - total_duration = self.model.__duration__ / self.get_total_steps() - return duration_to_string(total_duration) if in_str else total_duration + total_duration = self.model.__duration__ + return duration_to_string(total_duration) if as_string else total_duration - def get_total_params(self, in_str=False): + def get_total_params(self, as_string=False): """Returns the total parameters of the model. Args: - in_str (bool, optional): whether to output the parameters in string. Defaults to False. + as_string (bool, optional): whether to output the parameters as string. Defaults to False. """ return params_to_string( - self.model.__params__) if in_str else self.model.__params__ - - def get_total_steps(self): - """Returns the total number of steps (or input batches) profiled. - """ - def get_steps(module): - if module.__steps__ == 0: - sum = 0 - for m in module.children(): - sum += get_steps(m) - module.__steps__ = sum - return module.__steps__ - - total_steps = get_steps(self.model) - if total_steps == 0: - print("no step is profiled") - return total_steps + self.model.__params__) if as_string else self.model.__params__ - def print_model_profile(self): + def print_model_profile(self, + profile_step=1, + module_depth=-1, + top_modules=3, + detailed=True): """Prints the model graph with the measured profile attached to each module. """ + total_flops = self.get_total_flops() total_duration = self.get_total_duration() total_params = self.get_total_params() - total_steps = self.get_total_steps() - def accumulate_flops(module): - has_children = len(module._modules.items()) != 0 - if not has_children: - return module.__flops__ - else: - sum = 0 - for m in module.children(): - sum += m.accumulate_flops() - return sum + self.flops = total_flops + self.params = total_params + + print( + "\n-------------------------- DeepSpeed Flops Profiler --------------------------" + ) + print("Summary of forward pass:") + print('{:<30} {:<8}'.format('Profile step: ', profile_step)) + print('{:<30} {:<8}'.format('Number of parameters: ', + params_to_string(total_params))) + print('{:<30} {:<8}'.format('Number of multiply-accumulate operations (MACs): ', + num_to_string(total_flops))) + print('{:<30} {:<8}'.format( + 'Number of floating point operations ( = 2 * MACs): ', + num_to_string(2 * total_flops))) + print('{:<30} {:<8}'.format('Latency: ', duration_to_string(total_duration))) + print('{:<30} {:<8}'.format('Floating point operations per second(FLOPS): ', + flops_to_string(2 * total_flops / total_duration))) def flops_repr(module): params = module.__params__ - flops = 0 if total_steps == 0 else module.accumulate_flops() / total_steps + flops = get_module_flops(module) items = [ params_to_string(params), "{:.2%} Params".format(params / total_params), - flops_to_string(flops), + macs_to_string(flops), "{:.2%} MACs".format(0.0 if total_flops == 0 else flops / total_flops), ] - duration = 0 if total_steps == 0 else module.__duration__ / total_steps + duration = module.__duration__ + if duration == 0: # e.g. ModuleList + for m in module.children(): + duration += m.__duration__ + items.append(duration_to_string(duration)) - items.append("{:.2%} time".format(0.0 if total_duration == 0 else duration / - total_duration)) + items.append( + "{:.2%} latency".format(0.0 if total_duration == 0 else duration / + total_duration)) # flops = 2 * MACs - items.append(("{:.2} TFLOPS".format(0.0 if duration == 0 else 2 * flops / - duration / 10**12))) - items.append(str(module.__steps__)) + items.append(flops_to_string(0.0 if duration == 0 else 2 * flops / duration)) items.append(module.original_extra_repr()) return ", ".join(items) def add_extra_repr(module): - module.accumulate_flops = accumulate_flops.__get__(module) flops_extra_repr = flops_repr.__get__(module) if module.extra_repr != flops_extra_repr: module.original_extra_repr = module.extra_repr @@ -221,13 +205,33 @@ class FlopsProfiler(object): if hasattr(module, "original_extra_repr"): module.extra_repr = module.original_extra_repr del module.original_extra_repr - if hasattr(module, "accumulate_flops"): - del module.accumulate_flops self.model.apply(add_extra_repr) - print(self.model) + + print( + "\n----------------------------- Aggregated Profile -----------------------------" + ) + self.print_model_aggregated_profile(module_depth=module_depth, + top_modules=top_modules) + + if detailed: + print( + "\n------------------------------ Detailed Profile ------------------------------" + ) + print( + "Each module profile is listed after its name in the follwing order: \nnumber of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency)." + ) + print( + "Note: \n1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'.\n2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught.\n" + ) + print(self.model) + self.model.apply(del_extra_repr) + print( + "------------------------------------------------------------------------------" + ) + def print_model_aggregated_profile(self, module_depth=-1, top_modules=3): """Prints the names of the top top_modules modules in terms of aggregated time, flops, and parameters at depth module_depth. @@ -236,9 +240,6 @@ class FlopsProfiler(object): top_modules (int, optional): the number of top modules to show. Defaults to 3. """ info = {} - total_steps = self.get_total_steps() - if total_steps == 0: - return if not hasattr(self.model, "__flops__"): print( "no __flops__ attribute in the model, call this function after start_profile and before end_profile" @@ -271,7 +272,7 @@ class FlopsProfiler(object): num_items = min(top_modules, len(info[depth])) sort_flops = { - k: flops_to_string(v[0] / total_steps) + k: macs_to_string(v[0]) for k, v in sorted(info[depth].items(), key=lambda item: item[1][0], @@ -285,15 +286,15 @@ class FlopsProfiler(object): reverse=True)[:num_items] } sort_time = { - k: duration_to_string(v[2] / total_steps) + k: duration_to_string(v[2]) for k, v in sorted(info[depth].items(), key=lambda item: item[1][2], reverse=True)[:num_items] } - print(f"Top {num_items} modules in flops at depth {depth} are {sort_flops}") + print(f"Top {num_items} modules in MACs at depth {depth} are {sort_flops}") print(f"Top {num_items} modules in params at depth {depth} are {sort_params}") - print(f"Top {num_items} modules in time at depth {depth} are {sort_time}") + print(f"Top {num_items} modules in latency at depth {depth} are {sort_time}") def _prod(dims): @@ -461,7 +462,8 @@ def wrapFunc(func, funcFlopCompute): def newFunc(*args, **kwds): flops = funcFlopCompute(*args, **kwds) - module_flop_count.append((name, flops)) + if module_flop_count: + module_flop_count[-1].append((name, flops)) return oldFunc(*args, **kwds) return newFunc @@ -630,25 +632,61 @@ MODULE_HOOK_MAPPING = { } +def num_to_string(num, precision=2): + if num // 10**9 > 0: + return str(round(num / 10.0**9, precision)) + " G" + elif num // 10**6 > 0: + return str(round(num / 10.0**6, precision)) + " M" + elif num // 10**3 > 0: + return str(round(num / 10.0**3, precision)) + " K" + else: + return str(num) + + +def macs_to_string(macs, units=None, precision=2): + if units is None: + if macs // 10**9 > 0: + return str(round(macs / 10.0**9, precision)) + " GMACs" + elif macs // 10**6 > 0: + return str(round(macs / 10.0**6, precision)) + " MMACs" + elif macs // 10**3 > 0: + return str(round(macs / 10.0**3, precision)) + " KMACs" + else: + return str(macs) + " MACs" + else: + if units == "GMACs": + return str(round(macs / 10.0**9, precision)) + " " + units + elif units == "MMACs": + return str(round(macs / 10.0**6, precision)) + " " + units + elif units == "KMACs": + return str(round(macs / 10.0**3, precision)) + " " + units + else: + return str(macs) + " MACs" + + def flops_to_string(flops, units=None, precision=2): if units is None: + if flops // 10**12 > 0: + return str(round(flops / 10.0**12, precision)) + " TFLOPS" if flops // 10**9 > 0: - return str(round(flops / 10.0**9, precision)) + " GMACs" + return str(round(flops / 10.0**9, precision)) + " GFLOPS" elif flops // 10**6 > 0: - return str(round(flops / 10.0**6, precision)) + " MMACs" + return str(round(flops / 10.0**6, precision)) + " MFLOPS" elif flops // 10**3 > 0: - return str(round(flops / 10.0**3, precision)) + " KMACs" + return str(round(flops / 10.0**3, precision)) + " KFLOPS" else: - return str(flops) + " MACs" + return str(flops) + " FLOPS" else: - if units == "GMACs": + if units == "TFLOPS": + return str(round(flops / 10.0**12, precision)) + " " + units + if units == "GFLOPS": return str(round(flops / 10.0**9, precision)) + " " + units - elif units == "MMACs": + elif units == "MFLOPS": return str(round(flops / 10.0**6, precision)) + " " + units - elif units == "KMACs": + elif units == "KFLOPS": return str(round(flops / 10.0**3, precision)) + " " + units else: - return str(flops) + " MACs" + return str(flops) + " FLOPS" def params_to_string(params_num, units=None, precision=2): @@ -687,32 +725,40 @@ def duration_to_string(duration, units=None, precision=2): return str(round(duration, precision)) + " s" + # can not iterate over all submodules using self.model.modules() + # since modules() returns duplicate modules only once +def get_module_flops(module): + sum = module.__flops__ + # iterate over immediate children modules + for child in module.children(): + sum += get_module_flops(child) + return sum + + def get_model_profile( model, input_res, input_constructor=None, print_profile=True, - print_aggregated_profile=True, + detailed=True, module_depth=-1, top_modules=3, - warm_up=5, - num_steps=10, - as_strings=True, + warm_up=1, + as_string=True, ignore_modules=None, ): - """Returns the total flops, parameters, and profiled steps of a model. + """Returns the total MACs and parameters of a model. Args: model ([torch.nn.Module]): the PyTorch model to be profiled. input_res (list): input shape or input to the input_constructor input_constructor (func, optional): input constructor. If specified, the constructor is applied to input_res and the constructor output is used as the input to the model. Defaults to None. - print_profile (bool, optional): whether to print the model graph with the profile annotated. Defaults to True. - print_aggregated_profile (bool, optional): whether to print the aggregated profile for top modules. Defaults to True. + print_profile (bool, optional): whether to print the model profile. Defaults to True. + detailed (bool, optional): whether to print the detailed model profile. Defaults to True. module_depth (int, optional): the depth into the nested modules. Defaults to -1 (the inner most modules). top_modules (int, optional): the number of top modules to print in the aggregated profile. Defaults to 3. - warm_up (int, optional): the number of warm-up steps before measuring the time of each module. Defaults to 5. - num_steps (int, optional): the number of steps to profile. Defaults to 10. - as_strings (bool, optional): whether to print the output as strings. Defaults to True. + warm_up (int, optional): the number of warm-up steps before measuring the latency of each module. Defaults to 1. + as_string (bool, optional): whether to print the output as string. Defaults to True. ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None. """ assert type(input_res) is tuple @@ -738,32 +784,31 @@ def get_model_profile( prof.start_profile(ignore_list=ignore_modules) - for _ in range(num_steps): - if input_constructor: - input = input_constructor(input_res) - _ = model(**input) - else: - try: - batch = torch.ones(()).new_empty( - (*input_res, - ), - dtype=next(model.parameters()).dtype, - device=next(model.parameters()).device, - ) - except StopIteration: - batch = torch.ones(()).new_empty((*input_res, )) - _ = model(batch) + if input_constructor: + input = input_constructor(input_res) + _ = model(**input) + else: + try: + batch = torch.ones(()).new_empty( + (*input_res, + ), + dtype=next(model.parameters()).dtype, + device=next(model.parameters()).device, + ) + except StopIteration: + batch = torch.ones(()).new_empty((*input_res, )) + _ = model(batch) flops = prof.get_total_flops() params = prof.get_total_params() - steps = prof.get_total_steps() if print_profile: - prof.print_model_profile() - if print_aggregated_profile: - prof.print_model_aggregated_profile(module_depth=module_depth, - top_modules=top_modules) + prof.print_model_profile(profile_step=warm_up, + module_depth=module_depth, + top_modules=top_modules, + detailed=detailed) + prof.end_profile() - if as_strings: - return flops_to_string(flops), params_to_string(params), steps + if as_string: + return macs_to_string(flops), params_to_string(params) - return flops, params, steps + return flops, params diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 78c385d4..80b6013c 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -277,11 +277,8 @@ class DeepSpeedEngine(Module): def flops_profiler_enabled(self): return self._config.flops_profiler_config.enabled - def flops_profiler_start_step(self): - return self._config.flops_profiler_config.start_step - - def flops_profiler_end_step(self): - return self._config.flops_profiler_config.end_step + def flops_profiler_profile_step(self): + return self._config.flops_profiler_config.profile_step def flops_profiler_module_depth(self): return self._config.flops_profiler_config.module_depth @@ -289,6 +286,9 @@ class DeepSpeedEngine(Module): def flops_profiler_top_modules(self): return self._config.flops_profiler_config.top_modules + def flops_profiler_detailed(self): + return self._config.flops_profiler_config.detailed + def memory_breakdown(self): return self._config.memory_breakdown @@ -799,30 +799,11 @@ class DeepSpeedEngine(Module): **kwargs: variable length keyword arguments """ if self.flops_profiler_enabled( - ) and self.global_steps == self.flops_profiler_start_step( + ) and self.global_steps == self.flops_profiler_profile_step( ) and self.global_rank == 0: self.flops_profiler = FlopsProfiler(self.module) self.flops_profiler.start_profile(ignore_list=None) - if self.flops_profiler_enabled( - ) and self.global_steps == self.flops_profiler_end_step( - ) and self.global_rank == 0: - print('{:<30} {:<8}'.format( - 'Number of multiply-adds: ', - self.flops_profiler.get_total_flops(in_str=False))) - print('{:<30} {:<8}'.format( - 'Number of parameters: ', - self.flops_profiler.get_total_params(in_str=False))) - print('{:<30} {:<8}'.format('Number of steps profiled: ', - self.flops_profiler.get_total_steps())) - self.flops_profiler.print_model_profile() - self.flops_profiler.print_model_aggregated_profile( - module_depth=self.flops_profiler_module_depth(), - top_modules=self.flops_profiler_top_modules()) - self.flops_profiler.flops = self.flops_profiler.get_total_flops() - self.flops_profiler.params = self.flops_profiler.get_total_params() - self.flops_profiler.end_profile() - if self.module.training and self.progressive_layer_drop: kwargs.update(self.progressive_layer_drop.get_state()) @@ -838,6 +819,16 @@ class DeepSpeedEngine(Module): self.timers('forward').stop() self.timers('forward_microstep').stop() + if self.flops_profiler_enabled( + ) and self.global_steps == self.flops_profiler_profile_step( + ) and self.global_rank == 0: + self.flops_profiler.print_model_profile( + profile_step=self.global_steps, + module_depth=self.flops_profiler_module_depth(), + top_modules=self.flops_profiler_top_modules(), + detailed=self.flops_profiler_detailed()) + self.flops_profiler.end_profile() + return loss def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): diff --git a/docs/_config.yml b/docs/_config.yml index 4d64e8ca..20f0f73c 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -41,6 +41,7 @@ collections: - 1Cycle.md - lrrt.md - zero.md + - flops-profiler.md defaults: - scope: diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml index 5cfd3d2a..084d0560 100755 --- a/docs/_data/navigation.yml +++ b/docs/_data/navigation.yml @@ -45,6 +45,8 @@ lnav: url: /docs/config-json/#zero-optimizations-for-fp16-training - title: "Logging" url: /docs/config-json/#logging + - title: "Flops Profiler" + url: /docs/config-json/#flops-profiler - title: "Activation checkpointing" url: /docs/config-json/#activation-checkpointing - title: "Sparse Attention" @@ -84,5 +86,7 @@ lnav: url: /tutorials/pipeline/ - title: "Progressive Layer Dropping" url: /tutorials/progressive_layer_dropping/ + - title: "Flops Profiler" + url: /tutorials/flops-profiler/ - title: "Contributing" url: /contributing/ diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index f067ec94..5c100ab6 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -9,22 +9,22 @@ title: "DeepSpeed Configuration JSON" ***train\_batch\_size***: [integer] -| Value | Example | -| ------------------------------------------------------------ | ------- | -| The effective training batch size. This is the amount of data samples that leads to one step of model update. ***train\_batch\_size*** is aggregated by the batch size that a single GPU processes in one forward/backward pass (a.k.a., ***train\_step\_batch\_size***), the gradient accumulation steps (a.k.a., ***gradient\_accumulation\_steps***), and the number of GPUs. | `32` | +| Value | Example | +| --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | +| The effective training batch size. This is the amount of data samples that leads to one step of model update. ***train\_batch\_size*** is aggregated by the batch size that a single GPU processes in one forward/backward pass (a.k.a., ***train\_step\_batch\_size***), the gradient accumulation steps (a.k.a., ***gradient\_accumulation\_steps***), and the number of GPUs. | `32` | ***train\_micro\_batch\_size\_per\_gpu***: [integer] -| Description | Default | -| ------------------------------------------------------------ | ---------------------------- | +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------ | | Batch size to be processed by one GPU in one step (without gradient accumulation). When specified, ***gradient\_accumulation\_steps*** is automatically calculated using ***train\_batch\_size*** and number of GPUs. Should not be concurrently specified with ***gradient\_accumulation\_steps*** in the configuration JSON. | ***train\_batch\_size*** value | ***gradient\_accumulation\_steps***: [integer] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Number of training steps to accumulate gradients before averaging and applying them. This feature is sometimes useful to improve scalability since it results in less frequent communication of gradients between steps. Another impact of this feature is the ability to train with larger batch sizes per GPU. When specified, ***train\_step\_batch\_size*** is automatically calculated using ***train\_batch\_size*** and number of GPUs. Should not be concurrently specified with ***train\_step\_batch\_size*** in the configuration JSON. | `1` | +| Description | Default | +| -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | +| Number of training steps to accumulate gradients before averaging and applying them. This feature is sometimes useful to improve scalability since it results in less frequent communication of gradients between steps. Another impact of this feature is the ability to train with larger batch sizes per GPU. When specified, ***train\_step\_batch\_size*** is automatically calculated using ***train\_batch\_size*** and number of GPUs. Should not be concurrently specified with ***train\_step\_batch\_size*** in the configuration JSON. | `1` | @@ -32,10 +32,10 @@ title: "DeepSpeed Configuration JSON" ***optimizer***: [dictionary] -| Fields | Value | Example | -| ------ | ------------------------------------------------------------ | ------------------------------ | -| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, and **Lamb** optimizers and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | -| params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` | +| Fields | Value | Example | +| ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------- | +| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, and **Lamb** optimizers and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | +| params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` | Example of ***optimizer*** with Adam @@ -56,7 +56,7 @@ title: "DeepSpeed Configuration JSON" The Adam optimizer also supports the following two params keys/values in addition to the standard parameters from [torch.optim.Adam](https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam): | "params" key | Description | Default | -| ------------- | --------------------------------------------------------------------------- | --------| +| ------------- | --------------------------------------------------------------------------- | ------- | | torch\_adam | Use torch's implementation of adam instead of our fused adam implementation | false | | adam\_w\_mode | Apply L2 regularization (also known as AdamW) | true | @@ -83,10 +83,10 @@ The Adam optimizer also supports the following two params keys/values in additio ***scheduler***: [dictionary] -| Fields | Value | Example | -| ------ | ------------------------------------------------------------ | ------------------------------ | -| type | The scheduler name. See [here](https://deepspeed.readthedocs.io/en/latest/deepspeed.pt.html) for list of support schedulers. | `"WarmupLR"` | -| params | Dictionary of parameters to instantiate scheduler. The parameter names should match scheduler constructor signature. | `{"warmup_min_lr": 0, "warmup_max_lr": 0.001}` | +| Fields | Value | Example | +| ------ | ---------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------- | +| type | The scheduler name. See [here](https://deepspeed.readthedocs.io/en/latest/deepspeed.pt.html) for list of support schedulers. | `"WarmupLR"` | +| params | Dictionary of parameters to instantiate scheduler. The parameter names should match scheduler constructor signature. | `{"warmup_min_lr": 0, "warmup_max_lr": 0.001}` | Example of ***scheduler*** @@ -105,27 +105,27 @@ Example of ***scheduler*** ***fp32\_allreduce***: [boolean] -| Description | Default | -| ------------------------------------ | ------- | -| During gradient averaging perform allreduce with 32 bit values | `false` | +| Description | Default | +| -------------------------------------------------------------- | ------- | +| During gradient averaging perform allreduce with 32 bit values | `false` | ***prescale\_gradients***: [boolean] | Description | Default | | -------------------------------------- | ------- | -| Scale gradients before doing allreduce | `false` | +| Scale gradients before doing allreduce | `false` | ***gradient_predivide_factor***: [float] -| Description | Default | -| ---------------------------- | ------- | -| Before gradient averaging predivide gradients by a specified factor, can sometimes help with fp16 stability when scaling to large numbers of GPUs | `1.0` +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | +| Before gradient averaging predivide gradients by a specified factor, can sometimes help with fp16 stability when scaling to large numbers of GPUs | `1.0` | ***sparse\_gradients***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Enable sparse compression of [torch.nn.Embedding](https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding) gradients. | `false` | +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------------ | ------- | +| Enable sparse compression of [torch.nn.Embedding](https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding) gradients. | `false` | ### FP16 training options @@ -134,8 +134,8 @@ Example of ***scheduler*** ***fp16***: [dictionary] -| Description | Default | -| ------------------------------------------------------------ | ------- | +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | | Configuration for using mixed precision/FP16 training that leverages [NVIDIA's Apex package](https://nvidia.github.io/apex/). An example, including the available dictionary keys is illustrated below. NOTE: this does not use Apex's AMP mode that allows for more flexibility in mixed precision training modes, this mode is similar to AMP's O2 mode. Please see AMP support below if you want to use more complex mixed precision modes. If you want to use ZeRO (currently) you must use this mode. | None | ```json @@ -151,39 +151,39 @@ Example of ***scheduler*** ***fp16:enabled***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| ***enabled*** is a **fp16** parameter indicating whether or not FP16 training enabled. | `false` | +| Description | Default | +| -------------------------------------------------------------------------------------- | ------- | +| ***enabled*** is a **fp16** parameter indicating whether or not FP16 training enabled. | `false` | ***fp16:loss\_scale***: [float] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| ***loss\_scale*** is a ***fp16*** parameter representing the loss scaling value for FP16 training. The default value of 0.0 results in dynamic loss scaling, otherwise the value will be used for static fixed loss scaling. | `0.0` | +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | +| ***loss\_scale*** is a ***fp16*** parameter representing the loss scaling value for FP16 training. The default value of 0.0 results in dynamic loss scaling, otherwise the value will be used for static fixed loss scaling. | `0.0` | ***fp16:initial\_scale\_power***: [integer] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| ***initial\_loss\_scale\_power*** is a **fp16** parameter representing the power of the initial dynamic loss scale value. The actual loss scale is computed as 2***initial\_loss\_scale\_power***. | `32` | +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | +| ***initial\_loss\_scale\_power*** is a **fp16** parameter representing the power of the initial dynamic loss scale value. The actual loss scale is computed as 2***initial\_loss\_scale\_power***. | `32` | ***fp16:loss\_scale\_window***: [integer] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| ***loss\_scale\_window*** is a **fp16** parameter representing the window over which to raise/lower the dynamic loss scale value. | `1000` | +| Description | Default | +| --------------------------------------------------------------------------------------------------------------------------------- | ------- | +| ***loss\_scale\_window*** is a **fp16** parameter representing the window over which to raise/lower the dynamic loss scale value. | `1000` | ***fp16:hysteresis***: [integer] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| ***hysteresis*** is a **fp16** parameter representing the delay shift in dynamic loss scaling. | `2` | +| Description | Default | +| ---------------------------------------------------------------------------------------------- | ------- | +| ***hysteresis*** is a **fp16** parameter representing the delay shift in dynamic loss scaling. | `2` | ***fp16:min\_loss\_scale***: [integer] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| ***min\_loss\_scale*** is a **fp16** parameter representing the minimum dynamic loss scale value. | `1000` | +| Description | Default | +| -------------------------------------------------------------------------------------------------- | ------- | +| ***min\_loss\_scale*** is a **fp16** parameter representing the minimum dynamic loss scale value. | `1000` | ### Automatic mixed precision (AMP) training options @@ -192,8 +192,8 @@ Example of ***scheduler*** ***amp***: [dictionary] -| Description | Default | -| ------------------------------------------------------------ | ------- | +| Description | Default | +| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | | Configuration for using automatic mixed precision (AMP) training that leverages [NVIDIA's Apex AMP package](https://nvidia.github.io/apex/). An example, including the available dictionary keys is illustrated below. Is not compatible with `fp16` mode above or ZeRO. Any parameters outside of "enabled" will be passed to AMP's initialize call, see the API and descriptions here at the [apex.amp.initialize documentation](https://nvidia.github.io/apex/amp.html#apex.amp.initialize). | None | ```json @@ -207,14 +207,14 @@ Example of ***scheduler*** ***amp:enabled***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| ***enabled*** is an **amp** parameter indicating whether or not AMP training is enabled. | `false` | +| Description | Default | +| ---------------------------------------------------------------------------------------- | ------- | +| ***enabled*** is an **amp** parameter indicating whether or not AMP training is enabled. | `false` | ***amp params***: [various] -| Description | Default | -| ----------------------------------- | ------- | +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | | Any parameters outside of "enabled" will be passed to AMP's initialize call, see the API and descriptions here at the [apex.amp.initialize documentation](https://nvidia.github.io/apex/amp.html#apex.amp.initialize). | None | ### Gradient Clipping @@ -223,7 +223,7 @@ Example of ***scheduler*** | Description | Default | | ----------------------------------- | ------- | -| Enable gradient clipping with value | `0` | +| Enable gradient clipping with value | `0` | @@ -245,78 +245,120 @@ Enabling and configuring ZeRO memory optimizations ***zero\_optimization***: [dictionary] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Enable ZeRO memory optimization wrapper for FP16 Training. Currently compatible only with Adam optimizer. | `false` | +| Description | Default | +| --------------------------------------------------------------------------------------------------------- | ------- | +| Enable ZeRO memory optimization wrapper for FP16 Training. Currently compatible only with Adam optimizer. | `false` | ***stage***: [integer] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Chooses different stages of ZeRO Optimizer. Stage 0, 1, and 2 refer to disabled, optimizer state partitioning, and optimizer+gradient state partitiong, respectively. | `0` | +| Description | Default | +| --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | +| Chooses different stages of ZeRO Optimizer. Stage 0, 1, and 2 refer to disabled, optimizer state partitioning, and optimizer+gradient state partitiong, respectively. | `0` | ***allgather_partitions***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Chooses between allgather collective or a series of broadcast collectives to gather updated parameters from all the GPUs at the end of each step | `true` | +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------------------------------------ | ------- | +| Chooses between allgather collective or a series of broadcast collectives to gather updated parameters from all the GPUs at the end of each step | `true` | ***allgather_bucket_size***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes | `5e8` | +| Description | Default | +| ------------------------------------------------------------------------------------------------------------ | ------- | +| Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes | `5e8` | ***overlap_comm***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Attempts to overlap the reduction of the gradients with backward computation | `false` | +| Description | Default | +| ---------------------------------------------------------------------------- | ------- | +| Attempts to overlap the reduction of the gradients with backward computation | `false` | ***reduce_scatter***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Uses reduce or reduce scatter instead of allreduce to average gradients | `true` | +| Description | Default | +| ----------------------------------------------------------------------- | ------- | +| Uses reduce or reduce scatter instead of allreduce to average gradients | `true` | ***reduce_bucket_size***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Number of elements reduced/allreduced at a time. Limits the memory required for the allgather for large model sizes | `5e8` | +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------- | ------- | +| Number of elements reduced/allreduced at a time. Limits the memory required for the allgather for large model sizes | `5e8` | ***contiguous_gradients***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward pass. Only useful when running very large models. | `False` | +| Description | Default | +| --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | +| Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward pass. Only useful when running very large models. | `False` | ***cpu_offload***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Enable offloading of optimizer memory and computation to CPU. This frees up GPU memory for larger models or batch sizes. | `False` | +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------------ | ------- | +| Enable offloading of optimizer memory and computation to CPU. This frees up GPU memory for larger models or batch sizes. | `False` | ### Logging ***steps\_per\_print***: [integer] -| Description | Default | -| ----------- | ------- | -| Print train loss every N steps | `10` | +| Description | Default | +| ------------------------------ | ------- | +| Print train loss every N steps | `10` | ***wall\_clock\_breakdown***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Enable timing of the latency of forward/backward/update training phases | `false` | +| Description | Default | +| ----------------------------------------------------------------------- | ------- | +| Enable timing of the latency of forward/backward/update training phases | `false` | ***dump_state***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Print out state information of DeepSpeed object after initialization | `false` | +| Description | Default | +| -------------------------------------------------------------------- | ------- | +| Print out state information of DeepSpeed object after initialization | `false` | + +### Flops Profiler +```json +{ + "flops_profiler": { + "enabled": true, + "profile_step": 1, + "module_depth": -1, + "top_modules": 3, + "detailed": true, + } +} +``` +***enabled***: [boolean] + +| Description | Default | +| --------------------------- | ------- | +| Enables the flops profiler. | `false` | + +***profile\_step***: [integer] + +| Description | Default | +| --------------------------------------------------------------------------------------------------------------- | ------- | +| The global training step at which to profile. Note that warm up steps are needed for accurate time measurement. | `1` | + +***module\_depth***: [integer] + +| Description | Default | +| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | +| The depth of the model at which to print the aggregated module information. When set to `-1`, it prints information on the innermost modules (with the maximum depth). | `-1` | + +***top\_modules***: [integer] + +| Description | Default | +| ---------------------------------------------------------------------------- | ------- | +| Limits the aggregated profile output to the number of top modules specified. | `3` | + +***detailed***: [boolean] + +| Description | Default | +| -------------------------------------------- | ------- | +| Whether to print the detailed model profile. | `true` | ### Activation Checkpointing ```json @@ -331,61 +373,61 @@ Enabling and configuring ZeRO memory optimizations ``` ***partition\_activations***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Enables partition activation when used with model parallelism | `false` | +| Description | Default | +| ------------------------------------------------------------- | ------- | +| Enables partition activation when used with model parallelism | `false` | ***cpu\_checkpointing***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Offloads partitioned activations to CPU if partition_activations is enabled| `false` | +| Description | Default | +| --------------------------------------------------------------------------- | ------- | +| Offloads partitioned activations to CPU if partition_activations is enabled | `false` | ***contiguous\_memory\_optimization***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Copies partitioned activations so that they are contiguous in memory | `false` | +| Description | Default | +| -------------------------------------------------------------------- | ------- | +| Copies partitioned activations so that they are contiguous in memory | `false` | ***number_checkpoints***: [integer] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Total number of activation checkpoints used to allocate memory buffer for contiguous_memoty_optimization | `None` | +| Description | Default | +| -------------------------------------------------------------------------------------------------------- | ------- | +| Total number of activation checkpoints used to allocate memory buffer for contiguous_memoty_optimization | `None` | ***synchronize\_checkpoint\_boundary***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Inserts torch.cuda.synchronize() at each checkpoint boundary. | `false` | +| Description | Default | +| ------------------------------------------------------------- | ------- | +| Inserts torch.cuda.synchronize() at each checkpoint boundary. | `false` | ***profile***: [boolean] -| Description | Default | -| ------------------------------------------------------------ | ------- | -| Logs the forward and backward time for each checkpoint function | `false` | +| Description | Default | +| --------------------------------------------------------------- | ------- | +| Logs the forward and backward time for each checkpoint function | `false` | ### Sparse Attention ***sparse\_attention***: [dictionary] -| Fields | Value | Example | -| ------ | ------------------------------------------------------------ | ------------------------------ | -| mode | A string determining sparsity structure type. Deepspeed currently supports `"dense"`, `"fixed"`, `"bigbird"`, `"bslongformer"`, and `"variable"`. | `"fixed"` | -| block | An integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. | 16 | -| different\_layout\_per\_head | A boolean determining if each head should be assigned a different sparsity layout; this will be satisfied based on availability. | false | -| num\_local\_blocks | An integer determining the number of random blocks in each block row; only used in `"fixed"` mode. | 4 | -| num\_global\_blocks | An integer determining how many consecutive blocks in a local window is used as the representative of the window for global attention; used in `"fixed"` and `"bigbird"` modes. | 1 | -| attention | A string determining attention type. Attention can be `"unidirectional"`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty. Or it can be `"bidirectional"`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular; used in `"fixed"` and `"variable"` modes. | `"bidirectional"` | -| horizontal\_global\_attention | A boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `"bidirectional"`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks; used in `"fixed"` and `"variable"` modes. | false | -| num\_different\_global\_patterns | An integer determining number of different global attentions layouts. While global attention can be fixed by which block/s are representative of any local window, since there are multi-heads, each head can use a different global representative; used only in `"fixed"` mode. | 4 | -| num\_random\_blocks | An integer determining the number of random blocks in each block row; used in `"variable"` and `"bigbird"` modes. | 0 | -| local\_window\_blocks | A list of integers determining the number of blocks in each local attention window. It assumes first number determines # of blocks in the first local window, second the second window, ..., and the last number determines the number of blocks in the remaining local windows; only used in `"variable"` mode. | [4] | -| global\_block\_indices | A list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Notice that if global\_block\_end\_indices parameter is set, this parameter is used as starting index of each global window; used in `"variable"` and `"bslongformer"` modes. | [0] | -| global\_block\_end\_indices | A list of integers determining end indices of global window blocks. By default this is not used. But if it is set, it must have the same size of global\_block\_indices parameter, and combining this two parameters, for each index i, blocks from global\_block\_indices[i] to global\_block\_end\_indices[i], exclusive, are considered as global attention; used in `"variable"` and `"bslongformer"` modes. | None | -| num\_sliding\_window\_blocks | An integer determining the number of blocks in sliding local attention window; used in `"bigbird"` and `"bslongformer"` modes. | 3 | +| Fields | Value | Example | +| -------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------- | +| mode | A string determining sparsity structure type. Deepspeed currently supports `"dense"`, `"fixed"`, `"bigbird"`, `"bslongformer"`, and `"variable"`. | `"fixed"` | +| block | An integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. | 16 | +| different\_layout\_per\_head | A boolean determining if each head should be assigned a different sparsity layout; this will be satisfied based on availability. | false | +| num\_local\_blocks | An integer determining the number of random blocks in each block row; only used in `"fixed"` mode. | 4 | +| num\_global\_blocks | An integer determining how many consecutive blocks in a local window is used as the representative of the window for global attention; used in `"fixed"` and `"bigbird"` modes. | 1 | +| attention | A string determining attention type. Attention can be `"unidirectional"`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty. Or it can be `"bidirectional"`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular; used in `"fixed"` and `"variable"` modes. | `"bidirectional"` | +| horizontal\_global\_attention | A boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `"bidirectional"`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks; used in `"fixed"` and `"variable"` modes. | false | +| num\_different\_global\_patterns | An integer determining number of different global attentions layouts. While global attention can be fixed by which block/s are representative of any local window, since there are multi-heads, each head can use a different global representative; used only in `"fixed"` mode. | 4 | +| num\_random\_blocks | An integer determining the number of random blocks in each block row; used in `"variable"` and `"bigbird"` modes. | 0 | +| local\_window\_blocks | A list of integers determining the number of blocks in each local attention window. It assumes first number determines # of blocks in the first local window, second the second window, ..., and the last number determines the number of blocks in the remaining local windows; only used in `"variable"` mode. | [4] | +| global\_block\_indices | A list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Notice that if global\_block\_end\_indices parameter is set, this parameter is used as starting index of each global window; used in `"variable"` and `"bslongformer"` modes. | [0] | +| global\_block\_end\_indices | A list of integers determining end indices of global window blocks. By default this is not used. But if it is set, it must have the same size of global\_block\_indices parameter, and combining this two parameters, for each index i, blocks from global\_block\_indices[i] to global\_block\_end\_indices[i], exclusive, are considered as global attention; used in `"variable"` and `"bslongformer"` modes. | None | +| num\_sliding\_window\_blocks | An integer determining the number of blocks in sliding local attention window; used in `"bigbird"` and `"bslongformer"` modes. | 3 | Example of ***sparse\_attention*** diff --git a/docs/_pages/features.md b/docs/_pages/features.md index 3ad1c8e9..08f2bf22 100755 --- a/docs/_pages/features.md +++ b/docs/_pages/features.md @@ -113,7 +113,7 @@ to contiguous buffers preventing memory fragmentation. ## ZeRO-Offload -ZeRO-Offload pushes the boundary of the maximum model size that can be trained efficiently using minimal GPU resources, by exploiting computational and memory resources on both GPUs and their host CPUs. It allows training up to 13-billion-parameter models on a single NVIDIA V100 GPU, 10x larger than the state-of-the-art, while retaining high training throughput of over 30 teraflops per GPU. +ZeRO-Offload pushes the boundary of the maximum model size that can be trained efficiently using minimal GPU resources, by exploiting computational and memory resources on both GPUs and their host CPUs. It allows training up to 13-billion-parameter models on a single NVIDIA V100 GPU, 10x larger than the state-of-the-art, while retaining high training throughput of over 30 teraflops per GPU. For more details see the [ZeRO-Offload release blog]( https://www.microsoft.com/en-us/research/?p=689370&secret=iSlooB), and [tutorial](/tutorials/zero-offload/) on integration with DeepSpeed. @@ -133,7 +133,7 @@ micro-batch, specially when the number of micro-batches per effective batch is l During back propagation, DeepSpeed can overlap the communication required for averaging parameter gradients that have already been computed with the ongoing gradient computation. This computation-communication overlap allows DeepSpeed to achieve higher throughput even -at modest batch sizes. +at modest batch sizes. ## Training Features @@ -240,19 +240,53 @@ comes to data loading. Users simply provide a PyTorch dataset, and DeepSpeed dat can automatically handle batch creation appropriately. ## Performance Analysis and Debugging -For performance debugging, DeepSpeed can give you a detailed breakdown of the time spent -in different parts of the training by simply enabling it in the `deepspeed_config` -file. -Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details. + +DeepSpeed provides a set of tools for performance analysis and debugging. + +### Wall Clock Breakdown + +DeepSpeed provides a detailed breakdown of the time spent +in different parts of the training. +This can be enabled by setting the following in the `deepspeed_config` file. + ```json { "wall_clock_breakdown": true, +} +``` + +### Timing Activiation Checkpoint Functions + +When activiation checkpoingint is enabled, profiling the forward and backward time of each checkpoint function can be enabled in the `deepspeed_config` file. + +```json +{ "activation_checkpointing": { "profile": true } } + ``` + +### Flops Profiler + +The DeepSpeed flops profiler measures the time, flops and parameters of a PyTorch model and shows which modules or layers are the bottleneck. When used with the DeepSpeed runtime, the flops profiler can be configured in the `deepspeed_config` file as follows: + +```json +{ + "flops_profiler": { + "enabled": true, + "profile_step": 1, + "module_depth": -1, + "top_modules": 3, + "detailed": true, + } +} + +``` +The flops profiler can also be used as a standalone package. Please refer to the [Flops Profiler](/tutorials/flops-profiler) tutorial for more details. + ## Sparse Attention DeepSpeed offers sparse attention to support long sequences. Please refer to the [Sparse Attention](/tutorials/sparse-attention/) tutorial. diff --git a/docs/_tutorials/flops-profiler.md b/docs/_tutorials/flops-profiler.md new file mode 100644 index 00000000..00bc07c0 --- /dev/null +++ b/docs/_tutorials/flops-profiler.md @@ -0,0 +1,447 @@ +--- +title: "Flops Profiler" +excerpt: "Measure the parameters, latency, and floating point operations of your model" +--- + +In this tutorial, we introduce the DeepSpeed flops profiler and provide examples of its usage. + + - [Overview](#overview) + - [Supported Models](#supported-models) + - [Multi-GPU, Multi-node Runs](#multi-gpu-multi-node-runs) + - [Usage](#usage) + +## Overview + +The DeepSpeed flops profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. +It shows the parameters, latency, and number of floating point operations of the modules within the model to identify potential bottlenecks. +It also outputs the names of the top `k` modules in terms of aggregated time, flops, and number of parameters at depth `l` with `k` and `l` specified by the user. +The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package. + +The output profile is computed for each batch of input and printed to the `stdout`. For each module, the measured profile is annotated after the name and is listed in the order of `number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency of the module, percentage of the total latency, floating point operations per second (FLOPS)`. Note that the number of floating point operations is estimated as `2 * MACs` in the profiler (each MAC operation is counted as 2 floating point operations). + +Below is an example output for LeNet5 with batch size 1024: + +```shell +-------------------------- DeepSpeed Flops Profiler -------------------------- +Summary of forward pass: +Profile step: 1 +Number of parameters: 61.71 k +Number of multiply-accumulate operations (MACs): 439.56 M +Number of floating point operations ( = 2 * MACs): 879.12 M +Latency: 25.7 ms +Floating point operations per second(FLOPS): 34.2 GFLOPS + +----------------------------- Aggregated Profile ----------------------------- +Top 3 modules in MACs at depth 2 are {'Conv2d': '421.91 MMACs', 'Linear': '11.18 MMACs', 'AvgPool2d': '6.46 MMACs'} +Top 3 modules in params at depth 2 are {'Conv2d': '50.69 k', 'Linear': '11.01 k', 'Tanh': '0'} +Top 3 modules in latency at depth 2 are {'Conv2d': '11.37 ms', 'Linear': '5.27 ms', 'AvgPool2d': '5.02 ms'} + +------------------------------ Detailed Profile ------------------------------ +Each module profile is listed after its name in the follwing order: +number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). +Note: +1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. + +LeNet5( + 61.71 k, 100.00% Params, 439.56 MMACs, 100.00% MACs, 25.7 ms, 100.00% latency, 34.2 GFLOPS, + (feature_extractor): Sequential( + 50.69 k, 82.15% Params, 428.37 MMACs, 97.45% MACs, 20.12 ms, 78.27% latency, 42.59 GFLOPS, + (0): Conv2d(156, 0.25% Params, 125.24 MMACs, 28.49% MACs, 9.8 ms, 38.12% latency, 25.56 GFLOPS, 1, 6, kernel_size=(5, 5), stride=(1, 1)) + (1): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 2.85 ms, 11.08% latency, 0.0 FLOPS, ) + (2): AvgPool2d(0, 0.00% Params, 4.82 MMACs, 1.10% MACs, 4.01 ms, 15.59% latency, 2.4 GFLOPS, kernel_size=2, stride=2, padding=0) + (3): Conv2d(2.42 k, 3.92% Params, 247.4 MMACs, 56.28% MACs, 924.83 us, 3.60% latency, 535.02 GFLOPS, 6, 16, kernel_size=(5, 5), stride=(1, 1)) + (4): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 672.1 us, 2.62% latency, 0.0 FLOPS, ) + (5): AvgPool2d(0, 0.00% Params, 1.64 MMACs, 0.37% MACs, 1.01 ms, 3.95% latency, 3.23 GFLOPS, kernel_size=2, stride=2, padding=0) + (6): Conv2d(48.12 k, 77.98% Params, 49.27 MMACs, 11.21% MACs, 647.31 us, 2.52% latency, 152.25 GFLOPS, 16, 120, kernel_size=(5, 5), stride=(1, 1)) + (7): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 82.02 us, 0.32% latency, 0.0 FLOPS, ) + ) + (classifier): Sequential( + 11.01 k, 17.85% Params, 11.18 MMACs, 2.54% MACs, 5.41 ms, 21.06% latency, 4.13 GFLOPS, + (0): Linear(10.16 k, 16.47% Params, 10.32 MMACs, 2.35% MACs, 2.47 ms, 9.60% latency, 8.37 GFLOPS, in_features=120, out_features=84, bias=True) + (1): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 90.12 us, 0.35% latency, 0.0 FLOPS, ) + (2): Linear(850, 1.38% Params, 860.16 KMACs, 0.20% MACs, 2.8 ms, 10.91% latency, 613.62 MFLOPS, in_features=84, out_features=10, bias=True) + ) +) +------------------------------------------------------------------------------ +``` + +## Supported Models + +The flops estimation is partly inspired by [ptflops](https://github.com/sovrasov/flops-counter.pytorch) with the major difference being that the DeepSpeed flops profiler captures ```torch.nn.functional``` invoked in a module to estimate the flops. Thus the DeepSpeed flops profiler allows for customized modules in the model, e.g., ```ParallelTransformerLayerworks, ParallelSelfAttention, RowParallelLinear, etc.``` in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). This is in contrast to tools that profile at ```torch.nn.module``` level, such as ptflops, which require users to write customized flops calculation functions for each customized module. Finally, the DeepSpeed flops profiler also supports flops computation at module level (for RNNs). + +## Multi-GPU, Multi-node Runs + +For models running on multi-GPU or multi-node, only the model parallelism (e.g. ```--model-parallel-size``` in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)) affects the number of flops and parameters profiled, i.e., +`model_parallel_size * flops = total_flops` and `model_parallel_size * parameters = total_parameters`. The number of GPUs or nodes does not affect the output profile. + + +## Usage + +The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package. When using DeepSpeed for model training, the flops profiler can be configured in the deepspeed_config file without user code changes. To use the flops profiler outside of the DeepSpeed runtime, one can simply install DeepSpeed and import the flops_profiler package to use the APIs directly. Examples of each usage are given below. + + - [Usage With the DeepSpeed Runtime](#usage-with-the-deepspeed-runtime) + - [Example: Megatron-LM](#example-megatron-lm) + - [Usage Outside the DeepSpeed Runtime](#usage-outside-the-deepspeed-runtime) + - [In Model Inference](#in-model-inference) + - [Example: AlexNet](#example-alexnet) + - [Example: Bert](#example-bert) + - [In Model Training Workflow](#in-model-training-workflow) + - [Example Training Workflow](#example-training-workflow) +### Usage With the DeepSpeed Runtime + +When using DeepSpeed for model training, the flops profiler can be configured in the `deepspeed_config` file. No explict API calls are needed to use the profiler. Refer to [flops profiler](https://www.deepspeed.ai/docs/config-json/#flops-profiler) for details. + + +#### Example: Megatron-LM + +For information on running Megatron-LM with DeepSpeed, please refer to our tutorial [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM) + +The flops profiler can be enabled by adding the following field to the `deepspeed_config` file. + +```json +{ + "flops_profiler": { + "enabled": true, + "profile_step": 1, + "module_depth": -1, + "top_modules": 3, + "detailed": true, + } +} +``` + +An example output of 4-layer Megatron-LM model (`hidden_size = 512, num_attention_heads = 16, batch_size = 8, seq_length = 1024`) is shown below. + +```shell +-------------------------- DeepSpeed Flops Profiler -------------------------- +Summary of forward pass: +Profile step: 1 +Number of parameters: 38.89 M +Number of multiply-accumulate operations (MACs): 314.61 G +Number of floating point operations ( = 2 * MACs): 629.21 G +Latency: 33.81 ms +Floating point operations per second(FLOPS): 18.61 TFLOPS + +----------------------------- Aggregated Profile ----------------------------- +Top 3 modules in MACs at depth 8 are {'ColumnParallelLinear': '60.13 GMACs', 'RowParallelLinear': '42.95 GMACs', 'FusedScaleMaskSoftmax': '536.87 MMACs'} +Top 3 modules in params at depth 8 are {'ColumnParallelLinear': '7.35 M', 'RowParallelLinear': '5.25 M', 'FusedScaleMaskSoftmax': '0'} +Top 3 modules in latency at depth 8 are {'ColumnParallelLinear': '659.23 us', 'RowParallelLinear': '587.94 us', 'FusedScaleMaskSoftmax': '370.98 us'} + +------------------------------ Detailed Profile ------------------------------ +Each module profile is listed after its name in the follwing order: +number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). +Note: +1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. + +DistributedDataParallel( + 38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.81 ms, 100.00% latency, 18.61 TFLOPS, + (module): FP16_Module( + 38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.77 ms, 99.89% latency, 18.63 TFLOPS, + (module): GPT2Model( + 38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.69 ms, 99.66% latency, 18.67 TFLOPS, + (language_model): TransformerLanguageModel( + 38.89 M, 100.00% Params, 103.62 GMACs, 32.94% MACs, 5.58 ms, 16.51% latency, 37.13 TFLOPS, + (embedding): Embedding( + 26.28 M, 67.57% Params, 0 MACs, 0.00% MACs, 545.98 us, 1.61% latency, 0.0 FLOPS, + (word_embeddings): VocabParallelEmbedding(25.76 M, 66.23% Params, 0 MACs, 0.00% MACs, 223.88 us, 0.66% latency, 0.0 FLOPS, ) + (position_embeddings): Embedding(524.29 k, 1.35% Params, 0 MACs, 0.00% MACs, 147.1 us, 0.44% latency, 0.0 FLOPS, 1024, 512) + (embedding_dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 79.39 us, 0.23% latency, 0.0 FLOPS, p=0.1, inplace=False) + ) + (transformer): ParallelTransformer( + 12.61 M, 32.43% Params, 103.62 GMACs, 32.94% MACs, 5.0 ms, 14.78% latency, 41.49 TFLOPS, + (layers): ModuleList( + 12.61 M, 32.42% Params, 103.62 GMACs, 32.94% MACs, 4.4 ms, 13.01% latency, 47.13 TFLOPS, + (0): ParallelTransformerLayer( + 3.15 M, 8.11% Params, 25.9 GMACs, 8.23% MACs, 1.36 ms, 4.02% latency, 38.09 TFLOPS, + (input_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 92.51 us, 0.27% latency, 0.0 FLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True) + (attention): ParallelSelfAttention( + 1.05 M, 2.70% Params, 8.72 GMACs, 2.77% MACs, 754.59 us, 2.23% latency, 23.12 TFLOPS, + (query_key_value): ColumnParallelLinear(787.97 k, 2.03% Params, 6.44 GMACs, 2.05% MACs, 182.87 us, 0.54% latency, 70.46 TFLOPS, ) + (scale_mask_softmax): FusedScaleMaskSoftmax(0, 0.00% Params, 134.22 MMACs, 0.04% MACs, 120.4 us, 0.36% latency, 2.23 TFLOPS, ) + (attention_dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 47.45 us, 0.14% latency, 0.0 FLOPS, p=0.1, inplace=False) + (dense): RowParallelLinear(262.66 k, 0.68% Params, 2.15 GMACs, 0.68% MACs, 81.78 us, 0.24% latency, 52.52 TFLOPS, ) + ) + (post_attention_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 57.22 us, 0.17% latency, 0.0 FLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True) + (mlp): ParallelMLP( + 2.1 M, 5.40% Params, 17.18 GMACs, 5.46% MACs, 224.83 us, 0.67% latency, 152.83 TFLOPS, + (dense_h_to_4h): ColumnParallelLinear(1.05 M, 2.70% Params, 8.59 GMACs, 2.73% MACs, 64.13 us, 0.19% latency, 267.87 TFLOPS, ) + (dense_4h_to_h): RowParallelLinear(1.05 M, 2.70% Params, 8.59 GMACs, 2.73% MACs, 90.36 us, 0.27% latency, 190.13 TFLOPS, ) + ) + ) + ... + (3): ParallelTransformerLayer(...) + (final_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 52.69 us, 0.16% latency, 0.0 TFLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True) + ) + ) + ) + ) +) +``` + +### Usage Outside the DeepSpeed Runtime + +The flops profiler can be used as a standalone package outside of the DeepSpeed runtime. +One can simply install DeepSpeed and import the `flops_profiler` package to use the APIs directly. +Refer to [installation of DeepSpeed](https://www.deepspeed.ai/getting-started/#installation) for installing DeepSpeed. + +#### In Model Inference + +To profile a trained model in inference, use the `get_model_profile` function. +Examples are given below. + +##### Example: AlexNet + +The following example shows how to profile AlexNet using the DeepSpeed flops profiler. + +```python +import torchvision.models as models +import torch +from deepspeed.profiling.flops_profiler import get_model_profile + +with torch.cuda.device(0): + model = models.alexnet() + batch_size = 256 + macs, params = get_model_profile(model=model, # model + input_res=(batch_size, 3, 224, 224), # input shape or input to the input_constructor + input_constructor=None, # if specified, a constructor taking input_res is used as input to the model + print_profile=True, # prints the model graph with the measured profile attached to each module + detailed=True, # print the detailed profile + module_depth=-1, # depth into the nested modules with -1 being the inner most modules + top_modules=3, # the number of top modules to print aggregated profile + warm_up=10, # the number of warm-ups before measuring the time of each module + as_string=True, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) + ignore_modules=None) # the list of modules to ignore in the profiling +``` + +An example output: + +```shell +-------------------------- DeepSpeed Flops Profiler -------------------------- +Summary of forward pass: +Profile step: 10 +Number of parameters: 61.1 M +Number of multiply-accumulate operations (MACs): 183.18 G +Number of floating point operations ( = 2 * MACs): 366.36 G +Latency: 22.13 ms +Floating point operations per second(FLOPS): 16.56 TFLOPS + +----------------------------- Aggregated Profile ----------------------------- +Top 3 modules in MACs at depth 2 are {'Conv2d': '167.95 GMACs', 'Linear': '15.01 GMACs', 'ReLU': '126.26 MMACs'} +Top 3 modules in params at depth 2 are {'Linear': '58.63 M', 'Conv2d': '2.47 M', 'ReLU': '0'} +Top 3 modules in latency at depth 2 are {'Conv2d': '13.96 ms', 'Linear': '6.23 ms', 'ReLU': '730.75 us'} + +------------------------------ Detailed Profile ------------------------------ +Each module profile is listed after its name in the follwing order: +number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). +Note: +1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. + +AlexNet( + 61.1 M, 100.00% Params, 183.18 GMACs, 100.00% MACs, 22.13 ms, 100.00% latency, 16.56 TFLOPS, + (features): Sequential( + 2.47 M, 4.04% Params, 168.17 GMACs, 91.81% MACs, 15.17 ms, 68.57% latency, 22.17 TFLOPS, + (0): Conv2d(23.3 k, 0.04% Params, 18.04 GMACs, 9.85% MACs, 633.0 us, 2.86% latency, 57.0 TFLOPS, 3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)) + (1): ReLU(0, 0.00% Params, 49.56 MMACs, 0.03% MACs, 163.79 us, 0.74% latency, 605.17 GFLOPS, inplace=True) + (2): MaxPool2d(0, 0.00% Params, 49.56 MMACs, 0.03% MACs, 159.26 us, 0.72% latency, 622.38 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) + (3): Conv2d(307.39 k, 0.50% Params, 57.37 GMACs, 31.32% MACs, 6.15 ms, 27.81% latency, 18.64 TFLOPS, 64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) + (4): ReLU(0, 0.00% Params, 35.83 MMACs, 0.02% MACs, 185.01 us, 0.84% latency, 387.34 GFLOPS, inplace=True) + (5): MaxPool2d(0, 0.00% Params, 35.83 MMACs, 0.02% MACs, 134.23 us, 0.61% latency, 533.89 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) + (6): Conv2d(663.94 k, 1.09% Params, 28.72 GMACs, 15.68% MACs, 389.58 us, 1.76% latency, 147.47 TFLOPS, 192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (7): ReLU(0, 0.00% Params, 16.61 MMACs, 0.01% MACs, 76.53 us, 0.35% latency, 434.15 GFLOPS, inplace=True) + (8): Conv2d(884.99 k, 1.45% Params, 38.29 GMACs, 20.90% MACs, 6.38 ms, 28.82% latency, 12.01 TFLOPS, 384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (9): ReLU(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 104.43 us, 0.47% latency, 212.12 GFLOPS, inplace=True) + (10): Conv2d(590.08 k, 0.97% Params, 25.53 GMACs, 13.94% MACs, 405.79 us, 1.83% latency, 125.83 TFLOPS, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (11): ReLU(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 65.57 us, 0.30% latency, 337.85 GFLOPS, inplace=True) + (12): MaxPool2d(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 122.07 us, 0.55% latency, 181.46 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) + ) + (avgpool): AdaptiveAvgPool2d(0, 0.00% Params, 2.36 MMACs, 0.00% MACs, 259.4 us, 1.17% latency, 18.19 GFLOPS, output_size=(6, 6)) + (classifier): Sequential( + 58.63 M, 95.96% Params, 15.01 GMACs, 8.19% MACs, 6.54 ms, 29.54% latency, 4.59 TFLOPS, + (0): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 42.68 us, 0.19% latency, 0.0 FLOPS, p=0.5, inplace=False) + (1): Linear(37.75 M, 61.79% Params, 9.66 GMACs, 5.28% MACs, 301.36 us, 1.36% latency, 64.13 TFLOPS, in_features=9216, out_features=4096, bias=True) + (2): ReLU(0, 0.00% Params, 1.05 MMACs, 0.00% MACs, 79.39 us, 0.36% latency, 26.41 GFLOPS, inplace=True) + (3): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 39.58 us, 0.18% latency, 0.0 FLOPS, p=0.5, inplace=False) + (4): Linear(16.78 M, 27.46% Params, 4.29 GMACs, 2.34% MACs, 234.37 us, 1.06% latency, 36.65 TFLOPS, in_features=4096, out_features=4096, bias=True) + (5): ReLU(0, 0.00% Params, 1.05 MMACs, 0.00% MACs, 56.03 us, 0.25% latency, 37.43 GFLOPS, inplace=True) + (6): Linear(4.1 M, 6.71% Params, 1.05 GMACs, 0.57% MACs, 5.69 ms, 25.72% latency, 368.42 GFLOPS, in_features=4096, out_features=1000, bias=True) + ) +) +------------------------------------------------------------------------------ +``` + +##### Example: Bert + +```python +from functools import partial +import torch +from transformers import BertForSequenceClassification, BertTokenizer +from deepspeed.profiling.flops_profiler import get_model_profile + + +def bert_input_constructor(input_shape, tokenizer): + fake_seq = "" + for _ in range(input_shape[1] - 2): # ignore the two special tokens [CLS] and [SEP] + fake_seq += tokenizer.pad_token + inputs = tokenizer([fake_seq] * input_shape[0], + padding=True, + truncation=True, + return_tensors="pt") + labels = torch.tensor([1] * input_shape[0]) + inputs = dict(inputs) + inputs.update({"labels": labels}) + return inputs + + +with torch.cuda.device(0): + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = BertForSequenceClassification.from_pretrained('bert-base-uncased') + batch_size = 4 + seq_len = 128 + enable_profile = True + if enable_profile: + macs, params = get_model_profile( + model, + (batch_size, seq_len), + input_constructor=partial(bert_input_constructor, + tokenizer=tokenizer), + print_profile=True, + detailed=True, + ) + else: + inputs = bert_input_constructor((batch_size, seq_len), tokenizer) + outputs = model(inputs) +``` + +An example output: + +``` +-------------------------- DeepSpeed Flops Profiler -------------------------- +Summary of forward pass: +Profile step: 1 +Number of parameters: 109.48 M +Number of multiply-accumulate operations (MACs): 43.5 G +Number of floating point operations ( = 2 * MACs): 87.0 G +Latency: 393.7 ms +Floating point operations per second(FLOPS): 220.97 GFLOPS + +----------------------------- Aggregated Profile ----------------------------- +Top 3 modules in MACs at depth 7 are {'Linear': '14.5 GMACs', 'Dropout': '0 MACs', 'LayerNorm': '0 MACs'} +Top 3 modules in params at depth 7 are {'Linear': '28.35 M', 'LayerNorm': '18.43 k', 'Dropout': '0'} +Top 3 modules in latency at depth 7 are {'Linear': '153.7 ms', 'LayerNorm': '4.74 ms', 'Dropout': '597.95 us'} + +------------------------------ Detailed Profile ------------------------------ +Each module profile is listed after its name in the follwing order: +number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency). +Note: +1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'. +2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught. + +BertForSequenceClassification( + 109.48 M, 100.00% Params, 43.5 GMACs, 100.00% MACs, 393.7 ms, 100.00% latency, 220.97 GFLOPS, + (bert): BertModel( + 109.48 M, 100.00% Params, 43.5 GMACs, 100.00% MACs, 393.38 ms, 99.92% latency, 221.15 GFLOPS, + (embeddings): BertEmbeddings( + 23.84 M, 21.77% Params, 0 MACs, 0.00% MACs, 1.79 ms, 0.45% latency, 0.0 FLOPS, + (word_embeddings): Embedding(23.44 M, 21.41% Params, 0 MACs, 0.00% MACs, 485.18 us, 0.12% latency, 0.0 FLOPS, 30522, 768, padding_idx=0) + (position_embeddings): Embedding(393.22 k, 0.36% Params, 0 MACs, 0.00% MACs, 111.1 us, 0.03% latency, 0.0 FLOPS, 512, 768) + (token_type_embeddings): Embedding(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 215.53 us, 0.05% latency, 0.0 FLOPS, 2, 768) + (LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 386.95 us, 0.10% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True) + (dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 20.27 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False) + ) + (encoder): BertEncoder( + 85.05 M, 77.69% Params, 43.5 GMACs, 99.99% MACs, 391.03 ms, 99.32% latency, 222.47 GFLOPS, + (layer): ModuleList( + 85.05 M, 77.69% Params, 43.5 GMACs, 99.99% MACs, 390.82 ms, 99.27% latency, 222.59 GFLOPS, + (0): BertLayer( + 7.09 M, 6.47% Params, 3.62 GMACs, 8.33% MACs, 31.91 ms, 8.10% latency, 227.21 GFLOPS, + (attention): BertAttention( + 2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 16.39 ms, 4.16% latency, 147.47 GFLOPS, + (self): BertSelfAttention( + 1.77 M, 1.62% Params, 906.76 MMACs, 2.08% MACs, 15.07 ms, 3.83% latency, 120.36 GFLOPS, + (query): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 3.66 ms, 0.93% latency, 164.91 GFLOPS, in_features=768, out_features=768, bias=True) + (key): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 3.72 ms, 0.94% latency, 162.36 GFLOPS, in_features=768, out_features=768, bias=True) + (value): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 4.52 ms, 1.15% latency, 133.65 GFLOPS, in_features=768, out_features=768, bias=True) + (dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 24.08 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False) + ) + (output): BertSelfOutput( + 592.13 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 1.29 ms, 0.33% latency, 469.21 GFLOPS, + (dense): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 504.26 us, 0.13% latency, 1.2 TFLOPS, in_features=768, out_features=768, bias=True) + (LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 437.97 us, 0.11% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True) + (dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 21.93 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False) + ) + ) + (intermediate): BertIntermediate( + 2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 9.57 ms, 2.43% latency, 252.35 GFLOPS, + (dense): Linear(2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 8.75 ms, 2.22% latency, 276.11 GFLOPS, in_features=768, out_features=3072, bias=True) + ) + (output): BertOutput( + 2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 5.77 ms, 1.47% latency, 418.39 GFLOPS, + (dense): Linear(2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 5.13 ms, 1.30% latency, 471.15 GFLOPS, in_features=3072, out_features=768, bias=True) + (LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 310.9 us, 0.08% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True) + (dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 29.8 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False) + ) + ) + ... + (11): BertLayer(...) + ) + ) + (pooler): BertPooler( + 590.59 k, 0.54% Params, 2.36 MMACs, 0.01% MACs, 337.12 us, 0.09% latency, 14.0 GFLOPS, + (dense): Linear(590.59 k, 0.54% Params, 2.36 MMACs, 0.01% MACs, 173.57 us, 0.04% latency, 27.19 GFLOPS, in_features=768, out_features=768, bias=True) + (activation): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 46.01 us, 0.01% latency, 0.0 FLOPS, ) + ) + ) + (dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 19.55 us, 0.00% latency, 0.0 FLOPS, p=0.1, inplace=False) + (classifier): Linear(1.54 k, 0.00% Params, 6.14 KMACs, 0.00% MACs, 56.51 us, 0.01% latency, 217.47 MFLOPS, in_features=768, out_features=2, bias=True) +) +------------------------------------------------------------------------------ +``` + +#### In Model Training Workflow + +To profile model forward in a training workflow, use the `FlopsProfiler`class. +The `FlopsProfiler`class provides the follwing methods: + * `start_profile()` - starts profiling + * `get_total_flops(as_string=False)` - returns the total number of MACs in the model + * `get_total_params(as_string=False)` - returns the total number of parameters in the model + * `print_model_profile(profile_step=1, module_depth=-1, top_modules=3, detailed=True)` - prints the model profile + * `end_profile()` - ends profiling and cleans up. This should be invoked at the end of the profiling and AFTER `get_total_flops`, `get_total_params` or `print_model_profile`. + +##### Example Training Workflow + +Below is an example of this usage in a typical training workflow. Note that the flops profiler only captures the forward pass in a training step. The flops of a backward pass can be roughly estimated from that of the forward pass (~2x). + +```python +from deepspeed.profiling.flops_profiler import FlopsProfiler + +model = Model() +prof = FlopsProfiler(model) + +profile_step = 5 +print_profile= True + +for step, batch in enumerate(data_loader): + # start profiling at training step "profile_step" + if step == profile_step: + prof.start_profile() + + # forward() method + loss = model(batch) + + # end profiling and print output + if step == profile_step: # if using multi nodes, check global_rank == 0 as well + flops = prof.get_total_flops(as_string=True) + params = prof.get_total_params(as_string=True) + if print_profile: + prof.print_model_profile(profile_step=profile_step) + prof.end_profile() + + # runs backpropagation + loss.backward() + + # weight update + optimizer.step() +``` diff --git a/tests/unit/test_flops_profiler.py b/tests/unit/test_flops_profiler.py index fc741707..133610d0 100644 --- a/tests/unit/test_flops_profiler.py +++ b/tests/unit/test_flops_profiler.py @@ -24,8 +24,7 @@ def test_flops_profiler_in_ds_trainning(tmpdir): }, "flops_profiler": { "enabled": True, - "start_step": 2, - "end_step": 3, + "step": 1, "module_depth": -1, "top_modules": 3, }, @@ -100,18 +99,17 @@ def test_flops_profiler_in_inference(): mod = LeNet5(10) batch_size = 1024 input = torch.randn(batch_size, 1, 32, 32) - macs, params, steps = get_model_profile( + macs, params = get_model_profile( mod, tuple(input.shape), print_profile=True, - print_aggregated_profile=True, + detailed=True, module_depth=-1, top_modules=3, - warm_up=5, - num_steps=10, - as_strings=True, + warm_up=1, + as_string=True, ignore_modules=None, ) - print(macs, params, steps) - assert macs == "439.55 MMACs" + print(macs, params) + assert macs == "439.56 MMACs" assert params == "61.71 k" -- GitLab