From 17b4dd70a95b9eeec52237c8aa1c6b122b5e93a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AD=A3?= <2042519524@qq.com> Date: Wed, 20 Oct 2021 16:13:22 +0800 Subject: [PATCH] Fix global gather and global scatter operators (#36517) * fix global gather and global scatter operators --- .../collective/global_scatter_op.cu.cc | 8 ++++---- python/paddle/distributed/utils.py | 20 +++++++------------ 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc index 64765b549e..bec984c6b5 100644 --- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc +++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc @@ -47,8 +47,8 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel { if (platform::is_cpu_place(local_count->place())) { cpu_local_count_data = local_count->data(); } else { - framework::TensorCopy(*local_count, platform::CPUPlace(), - &cpu_local_count); + framework::TensorCopySync(*local_count, platform::CPUPlace(), + &cpu_local_count); cpu_local_count_data = cpu_local_count.data(); } auto global_count_len = 0; @@ -57,8 +57,8 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel { cpu_global_count_data = global_count->data(); global_count_len = global_count->numel(); } else { - framework::TensorCopy(*global_count, platform::CPUPlace(), - &cpu_global_count); + framework::TensorCopySync(*global_count, platform::CPUPlace(), + &cpu_global_count); cpu_global_count_data = cpu_global_count.data(); global_count_len = cpu_global_count.numel(); } diff --git a/python/paddle/distributed/utils.py b/python/paddle/distributed/utils.py index 63585e167e..31d5748ce3 100644 --- a/python/paddle/distributed/utils.py +++ b/python/paddle/distributed/utils.py @@ -65,14 +65,11 @@ def global_scatter(x, to global_count. Args: - x (Tensor): Tensor. Every element in the list must be a Tensor whose data type - should be float16, float32, float64, int32 or int64. + x (Tensor): Tensor. The tensor data type should be float16, float32, float64, int32 or int64. local_count (Tensor): Tensor which have n_expert * world_size elements that indicates - how many data needed to be sent. Every element in the list must be a Tensor whose - data type should be int64. + how many data needed to be sent. The tensor data type should be int64. global_count (Tensor): Tensor which have n_expert * world_size elements that indicates - how many data needed to be received. Every element in the list must be a Tensor whose - data type should be int64. + how many data needed to be received. The tensor data type should be int64. group (Group, optional): The group instance return by new_group or None for global default group. Default: None. use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True. @@ -161,19 +158,16 @@ def global_gather(x, to global_count. Args: - x (Tensor): Tensor. Every element in the list must be a Tensor whose data type - should be float16, float32, float64, int32 or int64. + x (Tensor): Tensor. Tensor whose data type should be float16, float32, float64, int32 or int64. local_count (Tensor): Tensor which have n_expert * world_size elements that indicates - how many data needed to be received. Every element in the list must be a Tensor whose - data type should be int64. + how many data needed to be received. Tensor data type should be int64. global_count (Tensor): Tensor which have n_expert * world_size elements that indicates - how many data needed to be sent. Every element in the list must be a Tensor whose - data type should be int64. + how many data needed to be sent. Tensor data type should be int64. group (Group, optional): The group instance return by new_group or None for global default group. Default: None. use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True. Returns: - None. + out (Tensor): The data received from all experts. Examples: .. code-block:: python -- GitLab