Submitted by Zhen Zhang
* remove norm(), avoid memcpy after allgather
  1) Remove the norm computation in debug printing.
  2) Change _all_gather to be a sync op in fetch_sub_module. Reason: the async version is not async at all, because each all_gather calls torch.cuda.synchronize() to guarantee that the previous communication op has completed.
  3) Add a new function _allgather_params_split_launch. The existing _allgather_params performs an explicit memcpy after the all-gather op; avoiding this explicit memory copy on the Python side improves performance.
  Known issue: `torch.distributed.all_gather` performs an implicit memcpy at the end of each `ncclAllgather`.
* WIP: wrap ncclAllgather as a customized op in DS. A micro benchmark shows the improvement for all-gathering a transformer layer with 9834560 elements in half precision is about 1.1 ms on an aws-p4d instance.
* WIP: integrate into partition_parameters. Performance improvement of 5.1B BERT on aws-p4d: fwd: 300 ms -> 200 ms, bwd: 680 ms -> 610 ms.
* Fix format
* Clean up dead code, modify unit test
* Remove the customized C++ extension; revert back to using the torch distributed API
* Change torch.ones to torch.empty
* Fix typo
* Warn if the allgather input is not a CUDA tensor
* Fix formatting
* Fix: move ds_tensor to the CUDA device (it is strange that ds_tensor had not already been moved to CUDA)
* Remove try clause on the path for fetching params

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
c0eeb69d
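The core idea described by the commit's _allgather_params_split_launch change is to all-gather each parameter directly into per-rank views of a preallocated flat buffer, so no explicit Python-side memcpy is needed after the collective. Below is a minimal sketch of that idea, not DeepSpeed's actual implementation; the function name `allgather_into_flat` and its signature are hypothetical.

```python
# Sketch only: gather a parameter shard from all ranks into one flat CUDA
# buffer using views, avoiding an explicit copy after the all-gather.
import torch
import torch.distributed as dist


def allgather_into_flat(param_shard: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather `param_shard` from every rank into a single flat buffer."""
    if not param_shard.is_cuda:
        # Mirrors the commit's "warn if not cuda tensor for allgather".
        print("warning: allgather input is not a CUDA tensor")
        param_shard = param_shard.cuda()

    world_size = dist.get_world_size(group=group)
    shard_numel = param_shard.numel()

    # torch.empty skips the pointless initialization that torch.ones would do
    # (the commit's "change torch.ones to torch.empty").
    flat = torch.empty(world_size * shard_numel,
                       dtype=param_shard.dtype,
                       device=param_shard.device)

    # Each output slot is a view into the flat buffer, so the gathered data
    # already lands in its final location; no explicit copy afterwards.
    out_list = list(torch.chunk(flat, world_size))
    dist.all_gather(out_list, param_shard.contiguous(), group=group)
    return flat
```

Even with the output list built from views, the commit notes as a known issue that `torch.distributed.all_gather` still performs an implicit memcpy at the end of each `ncclAllgather`, which is why a wrapped ncclAllgather custom op was prototyped before the change was reverted to the plain torch distributed API.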