diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
index 64765b549e5c1fb44f89e452204149ff6230c985..bec984c6b57e19dd890c0a8f3321d69242bd67e5 100644
--- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -47,8 +47,8 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
     if (platform::is_cpu_place(local_count->place())) {
       cpu_local_count_data = local_count->data<int64_t>();
     } else {
-      framework::TensorCopy(*local_count, platform::CPUPlace(),
-                            &cpu_local_count);
+      framework::TensorCopySync(*local_count, platform::CPUPlace(),
+                                &cpu_local_count);
       cpu_local_count_data = cpu_local_count.data<int64_t>();
     }
     auto global_count_len = 0;
@@ -57,8 +57,8 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
       cpu_global_count_data = global_count->data<int64_t>();
       global_count_len = global_count->numel();
     } else {
-      framework::TensorCopy(*global_count, platform::CPUPlace(),
-                            &cpu_global_count);
+      framework::TensorCopySync(*global_count, platform::CPUPlace(),
+                                &cpu_global_count);
       cpu_global_count_data = cpu_global_count.data<int64_t>();
       global_count_len = cpu_global_count.numel();
     }
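Note on the C++ change above: `framework::TensorCopy` enqueues the device-to-host copy asynchronously and returns immediately, so the reads of `cpu_local_count_data` / `cpu_global_count_data` that follow can race with the in-flight transfer. `framework::TensorCopySync` blocks until the copy has completed. A rough Python-level analogue of the same requirement, as a sketch only (the tensor values are made up):

```python
import paddle

# The kernel must see the per-expert counts fully materialized on the
# host before it can size its send/recv calls. In dygraph Python,
# Tensor.numpy() performs a synchronous device-to-host copy, which is
# the guarantee TensorCopySync provides on the C++ side.
local_count_gpu = paddle.to_tensor([2, 1, 1, 1], dtype="int64")
local_count_cpu = local_count_gpu.numpy()        # blocks until the copy finishes
total_rows_to_send = int(local_count_cpu.sum())  # safe: host data is complete
```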
diff --git a/python/paddle/distributed/utils.py b/python/paddle/distributed/utils.py
index 63585e167e8e3203fb1692e26702268e9edab3db..31d5748ce392e73792d1ebab6203c635954db79d 100644
--- a/python/paddle/distributed/utils.py
+++ b/python/paddle/distributed/utils.py
@@ -65,14 +65,11 @@ def global_scatter(x,
     to global_count.
 
     Args:
-        x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32 or int64.
+        x (Tensor): Tensor. The tensor data type should be float16, float32, float64, int32 or int64.
         local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be sent. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be sent. The tensor data type should be int64.
         global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be received. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be received. The tensor data type should be int64.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
         use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
 
@@ -161,19 +158,16 @@ def global_gather(x,
     to global_count.
 
     Args:
-        x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32 or int64.
+        x (Tensor): Tensor. Tensor whose data type should be float16, float32, float64, int32 or int64.
         local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be received. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be received. Tensor data type should be int64.
         global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be sent. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be sent. Tensor data type should be int64.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
         use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
 
     Returns:
-        None.
+        out (Tensor): The data received from all experts.
 
     Examples:
         .. code-block:: python
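A minimal two-rank usage sketch of the API documented above. It assumes two GPUs launched via `python -m paddle.distributed.launch --gpus=0,1` and a Paddle build that ships these MoE collectives in `paddle.distributed.utils`; the count values are illustrative only:

```python
import paddle
from paddle.distributed import init_parallel_env
from paddle.distributed.utils import global_scatter, global_gather

init_parallel_env()

# Two ranks, two experts per rank: local_count and global_count carry
# n_expert * world_size = 4 int64 entries each (values illustrative).
if paddle.distributed.get_rank() == 0:
    local_count = paddle.to_tensor([2, 1, 1, 1], dtype="int64")
    global_count = paddle.to_tensor([2, 1, 1, 1], dtype="int64")
else:
    local_count = paddle.to_tensor([1, 1, 2, 1], dtype="int64")
    global_count = paddle.to_tensor([1, 1, 2, 1], dtype="int64")

# Each rank contributes sum(local_count) rows of input features.
x = paddle.rand([int(local_count.sum()), 8], dtype="float32")

# Scatter rows out to the experts, then pull the results back;
# global_gather reverses global_scatter with the same count tensors.
scattered = global_scatter(x, local_count, global_count)
gathered = global_gather(scattered, local_count, global_count)
```

With `use_calc_stream=True` (the default) the collective runs on the calculation stream, so its output can be consumed by subsequent ops without an explicit synchronization.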