    ZeRO3, improved parameter all-gather operation (#1188) · c0eeb69d
    Committed by Zhen Zhang
    * remove norm(), avoid memcpy after allgather
    
    1) Removing the norm computation in debug printing
    2) Changing _all_gather to be a sync op in fetch_sub_module
        Reason: the async version is not async at all, because each
        all_gather calls torch.cuda.synchronize() to guarantee that the
        previous communication op has completed (see the sketch after
        this list)
    3) Adding a new function _allgather_params_split_launch
        the existing _allgather_params does an explicit memcpy after the
        all-gather op. Avoiding that explicit memory copy on the Python
        side improves performance (see the sketch after the known issue
        below).
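
    A minimal illustration of item 2 (names here are illustrative, not the
    actual fetch_sub_module code): with async_op=True the collective returns
    a work handle, but issuing torch.cuda.synchronize() right after every
    call, as the old path did, makes the launch effectively synchronous, so
    a plain sync all_gather is simpler.

```python
# Sketch only: why async_op=True buys nothing if every call is immediately
# followed by a full device synchronize.
import torch
import torch.distributed as dist

def gather_sync(output_list, shard):
    # Plain synchronous launch; no work handle to track.
    dist.all_gather(output_list, shard)

def gather_async_then_sync(output_list, shard):
    # "Async" launch that is immediately serialized again.
    handle = dist.all_gather(output_list, shard, async_op=True)
    torch.cuda.synchronize()  # defeats any communication/compute overlap
    return handle
```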
    
    Known issue:
        `torch.distributed.all_gather` performs an implicit memcpy
        at the end of each `ncclAllGather`.
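
    A minimal sketch of the split-launch idea behind
    _allgather_params_split_launch (the ds_tensor/ds_numel/ds_shape
    attributes are assumed from ZeRO-3 partitioning; this is not the exact
    implementation): gather each rank's shard directly into per-rank views
    of one flat buffer, so no extra Python-side copy is needed afterwards.
    The implicit copy inside torch.distributed.all_gather noted above is
    not removed by this.

```python
# Sketch only: all-gather a partitioned parameter without an explicit
# Python-side memcpy by gathering straight into views of one flat buffer.
import torch
import torch.distributed as dist

def allgather_param_split_launch(param, group=None):
    world_size = dist.get_world_size(group)
    shard = param.ds_tensor  # local partition (ZeRO-3 attribute, assumed)
    flat = torch.empty(shard.numel() * world_size,
                       dtype=shard.dtype,
                       device=shard.device)
    # Contiguous per-rank views of the flat buffer; all_gather writes each
    # rank's shard into its slice, so no explicit copy follows in Python.
    views = list(torch.chunk(flat, world_size))
    dist.all_gather(views, shard.contiguous(), group=group)
    # Drop padding (if any) and restore the original parameter shape.
    return flat.narrow(0, 0, param.ds_numel).view(param.ds_shape)
```

    Because the flat buffer already backs the gathered data, the result can
    be viewed as the full parameter without another allocation.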
    
    * WIP: wrapped ncclAllGather as a customized op in DeepSpeed
    
    A micro benchmark shows that all-gathering a transformer layer with
    9,834,560 elements in half precision improves by about 1.1 ms on an
    AWS p4d instance.
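
    A sketch of how such a micro benchmark could be timed (hypothetical,
    not the actual script): repeatedly all-gather a half-precision tensor
    of 9,834,560 elements and average CUDA-event timings per call.

```python
# Sketch only: rough per-call timing of an all-gather using CUDA events.
# Assumes torch.distributed is already initialized with the NCCL backend.
import torch
import torch.distributed as dist

def time_allgather(numel=9834560, iters=20):
    world_size = dist.get_world_size()
    shard = torch.randn(numel // world_size, dtype=torch.half, device='cuda')
    outputs = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(outputs, shard)  # warm-up
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        dist.all_gather(outputs, shard)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per all-gather
```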
    
    * WIP: integrated into partition_parameters
    
    Performance improvement of a 5.1B-parameter BERT model on AWS p4d:
    fwd: 300ms -> 200ms
    bwd: 680ms -> 610ms
    
    * Fix format
    
    * cleaned dead code, modified unit test
    
    * removed the customized C++ extension
    
    reverted back to using the torch.distributed API
    
    * changed torch.ones to torch.empty
    
    * typo
    
    * warn if the tensor passed to all-gather is not a CUDA tensor
    
    * fix formatting
    
    * fix: move ds_tensor to the CUDA device
    
    though it is strange that ds_tensor had not already been moved to CUDA
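
    A minimal sketch covering the two device-related bullets above (warn on
    a non-CUDA all-gather input and move the shard first); placement and
    names are illustrative, not the actual stage3.py code.

```python
# Sketch only: warn if the shard is not a CUDA tensor and move it to the
# current CUDA device before launching the all-gather.
import warnings
import torch

def ensure_cuda_shard(ds_tensor):
    if not ds_tensor.is_cuda:
        warnings.warn("all-gather input is not a CUDA tensor; moving it "
                      "to the current CUDA device")
        ds_tensor = ds_tensor.to(
            torch.device("cuda", torch.cuda.current_device()))
    return ds_tensor
```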
    
    * remove try clause on the path for fetching params
    
    Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
    Co-authored-by: Jeff Rasley <jerasley@microsoft.com>