    Various ZeRO Stage3 Optimizations + Improvements (including bfloat16 support) (#1453) · 4912e0ad
    Committed by Justin Chiu
    * Changes for bfloat16 Zero2
    
    * ZeRO stage3 optimizations, with some bug fixes
    
    optimizations for stage3:
    - prefetching improvements
    - batching allgather calls to amortize fixed overhead and improve
      bandwidth utilization (see the sketch after this list)
    - batching reduce_scatter calls to amortize fixed overhead and
      improve bandwidth utilization
    - using *_base variants of allgather and reduce_scatter to reduce
      memory allocations and data movement
    - finer-grained synchronization for communication, so that waits
      block on less work
    - precomputation of the fetch schedule: using a fetch queue rather
      than deciding what to (pre)fetch at each iteration
    - limiting queued coalesced communication ops to reduce memory
      pressure on the PyTorch CUDA caching allocator (not an elegant
      solution)
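
    A minimal sketch of the batching + *_base idea above, assuming a
    ds_tensor attribute holds each parameter's local partition (the
    attribute name and function are illustrative, not DeepSpeed's
    actual code): flatten several partitions into one buffer and issue
    a single _all_gather_base call instead of one allgather per
    parameter.

        import torch
        import torch.distributed as dist

        def batched_allgather(params):
            world_size = dist.get_world_size()
            # one flat buffer for all partitions, so a single collective
            # amortizes the fixed launch cost across many parameters
            flat_input = torch.cat([p.ds_tensor.view(-1) for p in params])
            flat_output = torch.empty(world_size * flat_input.numel(),
                                      dtype=flat_input.dtype,
                                      device=flat_input.device)
            # *_base variant: writes into one preallocated output tensor,
            # avoiding per-rank output lists and the copies they imply
            handle = dist._all_gather_base(flat_output, flat_input,
                                           async_op=True)
            return flat_output, handle  # per-param unflattening elided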
    
    optimizations for stage3-offload:
    - made some host-device tensor copies async to improve performance
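
    A sketch of the async host-device copy pattern (illustrative, not
    the PR's code; requires a CUDA device): pinned host memory plus
    non_blocking=True on a side stream lets the copy overlap compute.

        import torch

        copy_stream = torch.cuda.Stream()
        # pinned (page-locked) memory is required for truly async copies
        host_buf = torch.empty(1 << 20, pin_memory=True)
        dev_buf = torch.empty(1 << 20, device="cuda")

        with torch.cuda.stream(copy_stream):
            dev_buf.copy_(host_buf, non_blocking=True)

        # consumers must wait for the copy before reading dev_buf
        torch.cuda.current_stream().wait_stream(copy_stream)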
    
    bug fixes and qol improvements:
    - fix init context method when parent modules modify child weights
    - speed up model initialization by moving the model to GPU before
      weight initialization (see the sketch after this list)
    - fix unit test imports so that unit tests can be run from any
      directory
    - change performance logging to include memory consumption
    - add logging with model size once the model is partitioned
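
    A minimal sketch of the on-device initialization speedup (the model
    and init function are made up for illustration): moving the module
    to GPU first runs the init kernels on-device and avoids copying the
    freshly initialized weights from host to device afterwards.

        import torch
        import torch.nn as nn

        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

        # move first, then initialize on the GPU
        model = nn.Sequential(nn.Linear(1024, 1024),
                              nn.Linear(1024, 1024)).cuda()
        model.apply(init_weights)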
    
    new features:
    - bfloat16 support for ZeRO 3
    
    * fix import in unit test
    
    * ran yapf
    
    * improvements to cache flush warn log
    
    * backwards compatibility with older versions of pytorch
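
    One common shape for such compatibility shims, assuming the *_base
    collectives are what is missing on older PyTorch (the PR's actual
    shims may differ): feature-detect and fall back to the list-based
    API.

        import torch
        import torch.distributed as dist

        if hasattr(dist, "_all_gather_base"):
            def allgather_fn(output, input, group=None, async_op=False):
                return dist._all_gather_base(output, input, group=group,
                                             async_op=async_op)
        else:
            def allgather_fn(output, input, group=None, async_op=False):
                # older PyTorch: emulate via list-based all_gather,
                # chunking the flat output into equal per-rank views
                chunks = list(torch.chunk(output,
                                          dist.get_world_size(group),
                                          dim=0))
                return dist.all_gather(chunks, input, group=group,
                                       async_op=async_op)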
    
    * handle edge case where the reduced tensor is smaller than the world size
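
    A sketch of one way to handle that case (an assumed approach, not
    necessarily the PR's exact fix): zero-pad the flat tensor up to a
    multiple of the world size before the reduce-scatter.

        import torch
        import torch.distributed as dist

        def reduce_scatter_padded(flat: torch.Tensor) -> torch.Tensor:
            world_size = dist.get_world_size()
            # numel must be divisible by world_size; a tensor with fewer
            # elements than ranks must be padded before scattering
            padded_numel = -(-flat.numel() // world_size) * world_size
            padded = torch.zeros(padded_numel, dtype=flat.dtype,
                                 device=flat.device)
            padded[:flat.numel()] = flat
            out = torch.empty(padded_numel // world_size,
                              dtype=flat.dtype, device=flat.device)
            dist._reduce_scatter_base(out, padded)
            return out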
    
    * moved event synchronization to allgather handle wait() call
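
    An illustrative handle shape (assumed; DeepSpeed's AllGatherHandle
    differs in detail): record a CUDA event when the allgather is
    launched and synchronize only inside wait(), so the launch path
    never blocks the host.

        import torch

        class AllGatherHandle:
            def __init__(self, work):
                self.work = work            # work object from async_op=True
                self.event = torch.cuda.Event()
                self.event.record()         # mark launch; no blocking here

            def wait(self):
                self.work.wait()            # complete the collective
                self.event.synchronize()    # deferred host sync happens here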
    
    * removed unnecessary barrier call
    
    * formatting fix after resolving merge conflict
    
    * skip nvme prefetch when trace not complete
    
    * opportunistically avoid memory allocation in coalesced allgather
    
    * fix indentation after merge
    
    * fixes to account for parameter offload
    
    * accounting for torch.cuda.memory_stats not being available
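
    A sketch of the availability guard (the particular counter is an
    assumption for illustration): torch.cuda.memory_stats only exists
    on newer PyTorch, so probe for it before calling.

        import torch

        def alloc_retries() -> int:
            if hasattr(torch.cuda, "memory_stats"):
                # newer PyTorch exposes allocator counters here
                return torch.cuda.memory_stats().get("num_alloc_retries", 0)
            return 0  # older PyTorch: counter unavailable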
    
    * moved partition_all_params to optimizer step
    
    * allgather params before item() gets called
    
    * fix param status checks
    
    needed after moving partition_all_parameters call to optimizer step
    
    * fix grad accumulation with optimizer offload
    
    * grad norm computation fix for optimizer offload
    
    * change post-divide in reduce-scatter to pre-divide
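
    A sketch of the pre-divide (illustrative function): dividing each
    rank's gradient by the world size before the reduce-scatter keeps
    the partial sums bounded, reducing overflow risk for fp16/bf16
    gradients.

        import torch
        import torch.distributed as dist

        def reduce_scatter_grad(grad_flat: torch.Tensor,
                                out: torch.Tensor) -> None:
            grad_flat.div_(dist.get_world_size())  # pre-divide, not post
            dist._reduce_scatter_base(out, grad_flat)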
    
    * fix gradient race condition w/ optimizer offload
    
    * improve inf/nan gradient tracking
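
    The usual cheap inf/nan check (a sketch of the standard pattern,
    not necessarily the exact code here): a single sum propagates both
    inf and nan, so one reduction flags either condition.

        import torch

        def has_inf_or_nan(t: torch.Tensor) -> bool:
            s = t.float().sum()  # inf/nan propagate through the sum
            return bool(torch.isinf(s)) or bool(torch.isnan(s))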
    
    * don't prefetch when not in training mode
    
    * format fix after merging
    
    * fix prefetching issue when using NVME offload
    
    * improved defragmentation for fp16 parameters
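
    A minimal sketch of the defragmentation idea (illustrative): copy
    scattered fp16 tensors into one contiguous buffer and re-point each
    tensor at a narrow view of it, so later collectives and copies
    operate on a single allocation.

        import torch

        def defragment(tensors):
            total = sum(t.numel() for t in tensors)
            flat = torch.empty(total, dtype=tensors[0].dtype,
                               device=tensors[0].device)
            offset = 0
            for t in tensors:
                n = t.numel()
                flat.narrow(0, offset, n).copy_(t.reshape(-1))
                # re-point the tensor's storage at the contiguous buffer
                t.data = flat.narrow(0, offset, n).view_as(t)
                offset += n
            return flat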
    
    * relative imports for bf16 tests
    
    * changes for bwd compatibility with pytorch 1.2
    
    * remove buffered_reduce_fallback
    
    * removed unused parameter offset bookkeeping
    
    * fixed tracking for multiple param groups
    
    * unbroke bfloat16 config after merge conflict
    
    * using base allgather params when only 1 param
    
    * cleanup/fixes for fp16 partition defragmentation
    
    * switch to CRLF
    
    * convert to same new-line style as master
    
    * align new line with master
    
    * Fix merge issues
    
    * switch to CRLF
    
    * fix to LF line endings
    
    * minor merge fixes
    
    * remove extra bfloat16_enabled definition
    
    * asserting params in flight for AllGatherHandle
    
    * remove get_cuda_mem_allocated_str
    
    * Format fixes
    
    * fix bfloat16 zero stage check (broken after merge commit)
    
    * +self.communication_data_type, -self.allreduce_always_fp32; delete dead code
    
    * Add self.reduce_scatter
    
    * Format fix
    
    * Fix merge issues
    
    * iterate over params_to_fetch rather than make another iterator
    
    * add some TODOs
    
    * remove unnecessary division by micro_step_id
    
    * rename config keys "bfloat16" -> "bf16"
    
    * rename stage3_gather_fp16_weights_on_model_save -> stage3_gather_16bit_weights_on_model_save
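
    A sketch of honoring both spellings after the key rename (the
    helper name is made up for illustration):

        def get_bf16_section(ds_config: dict) -> dict:
            # prefer the new "bf16" key, fall back to legacy "bfloat16"
            if "bf16" in ds_config:
                return ds_config["bf16"]
            return ds_config.get("bfloat16", {})

        assert get_bf16_section({"bfloat16": {"enabled": True}})["enabled"]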
    
    * add unit test to check backwards compatibility for gather_16bit_weights
    
    * added test to confirm bf16 key bwd compatibility
    
    * Format fixes

    Co-authored-by: Rana Ali Amjad <raamjad@amazon.com>
    Co-authored-by: Justin Chiu <justchiu@amazon.com>
    Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
    Co-authored-by: Jeff Rasley <jerasley@microsoft.com>