diff --git a/deepspeed/comm/utils.py b/deepspeed/comm/utils.py index 27a4d2c4a588fda1e88f9e82b16ae5315ec323b4..78bfdf4462e4f57458a1137cb24a271e09fc1b6f 100644 --- a/deepspeed/comm/utils.py +++ b/deepspeed/comm/utils.py @@ -68,6 +68,9 @@ def get_tensor_position(func): # most colls if 'tensor' in sig_params: arg = 'tensor' + # all_reduce_coalesced coll + elif 'tensors' in sig_params: + arg = 'tensors' # reduce scatter coll elif 'input_list' in sig_params: arg = 'input_list' @@ -87,6 +90,8 @@ def get_tensor_kwarg(func, kwargs): if 'tensor' in func_args: arg = func_args['tensor'] + elif 'tensors' in func_args: + arg = func_args['tensors'] elif 'input_list' in func_args: arg = func_args['input_list'] elif 'input_tensor_list' in func_args: diff --git a/deepspeed/utils/comms_logging.py b/deepspeed/utils/comms_logging.py index 8e6558cfb9dd9dcdde4d6a1947023e6975199368..69c80f78c25dbe0d1c955dd26950c97f0d213b3a 100644 --- a/deepspeed/utils/comms_logging.py +++ b/deepspeed/utils/comms_logging.py @@ -12,6 +12,12 @@ def get_caller_func(frame=3): return sys._getframe(frame).f_code.co_name +def print_rank_0(message): + import deepspeed.comm as dist + if dist.get_rank() == 0: + print(message) + + # Helper function to pretty-print message sizes def convert_size(size_bytes): if size_bytes == 0: @@ -38,7 +44,7 @@ def calc_bw_log(comm_op, size, duration): size *= n tput = (size / duration) busbw = (size / duration) * ((n - 1) / n) - elif comm_op == "all_reduce": + elif comm_op == "all_reduce" or comm_op == "all_reduce_coalesced": tput = (size * 2 / duration) busbw = (size / duration) * (2 * (n - 1) / n) elif comm_op == "send" or comm_op == "recv" or comm_op == "isend" or comm_op == "irecv" or comm_op == "broadcast" or comm_op == "reduce" or comm_op == "gather" or comm_op == "scatter" or comm_op == "barrier":