Unverified commit 17b4dd70, authored by 李季, committed by GitHub

Fix global gather and global scatter operators (#36517)

* fix global gather and global scatter operators
Parent: 6a572a19
@@ -47,8 +47,8 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
     if (platform::is_cpu_place(local_count->place())) {
       cpu_local_count_data = local_count->data<int64_t>();
     } else {
-      framework::TensorCopy(*local_count, platform::CPUPlace(),
-                            &cpu_local_count);
+      framework::TensorCopySync(*local_count, platform::CPUPlace(),
+                                &cpu_local_count);
       cpu_local_count_data = cpu_local_count.data<int64_t>();
     }
     auto global_count_len = 0;
@@ -57,8 +57,8 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
       cpu_global_count_data = global_count->data<int64_t>();
       global_count_len = global_count->numel();
     } else {
-      framework::TensorCopy(*global_count, platform::CPUPlace(),
-                            &cpu_global_count);
+      framework::TensorCopySync(*global_count, platform::CPUPlace(),
+                                &cpu_global_count);
       cpu_global_count_data = cpu_global_count.data<int64_t>();
       global_count_len = cpu_global_count.numel();
     }
...
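Why the kernel change matters: framework::TensorCopy enqueues the device-to-host copy asynchronously on a stream, while the kernel dereferences cpu_local_count_data and cpu_global_count_data on the host immediately after the call, so it could read the count buffers before the transfer had finished. framework::TensorCopySync blocks until the copy completes. The same rule at the Python level, as a minimal sketch (assuming a CUDA build with one visible GPU; this snippet is illustrative and not part of the commit):

    import paddle

    paddle.set_device("gpu")
    counts = paddle.to_tensor([2, 1, 1, 1], dtype="int64")  # resides on the GPU
    # Tensor.numpy() performs a synchronous device-to-host copy: it blocks until
    # the data is actually on the host, which is the guarantee TensorCopySync
    # restores inside the CUDA kernels above.
    host_counts = counts.numpy()
    print(host_counts)  # [2 1 1 1]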
@@ -65,14 +65,11 @@ def global_scatter(x,
         to global_count.
     Args:
-        x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32 or int64.
+        x (Tensor): Tensor. The tensor data type should be float16, float32, float64, int32 or int64.
         local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be sent. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be sent. The tensor data type should be int64.
         global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be received. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be received. The tensor data type should be int64.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
         use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
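For orientation, a hedged usage sketch of the documented signature (assumptions: two GPUs, world_size = 2, n_expert = 2; the symmetric count values are illustrative and not taken from the commit). It would be launched with python -m paddle.distributed.launch --gpus 0,1 demo.py:

    import paddle
    import paddle.distributed.utils as dist_utils
    from paddle.distributed import init_parallel_env

    init_parallel_env()
    # Two rows of input on every rank; d_model = 2.
    local_input_buf = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
    # n_expert * world_size = 4 count entries. sum(local_count) must match the
    # number of rows in x; symmetric counts keep send/receive sizes consistent
    # across ranks.
    local_count = paddle.to_tensor([1, 0, 1, 0], dtype="int64")
    global_count = paddle.to_tensor([1, 0, 1, 0], dtype="int64")
    out = dist_utils.global_scatter(local_input_buf, local_count, global_count)
    print(out.shape)  # [sum(global_count), d_model] rows received by this rank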
@@ -161,19 +158,16 @@ def global_gather(x,
         to global_count.
     Args:
-        x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32 or int64.
+        x (Tensor): Tensor. Tensor whose data type should be float16, float32, float64, int32 or int64.
         local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be received. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be received. Tensor data type should be int64.
         global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be sent. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be sent. Tensor data type should be int64.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
         use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
     Returns:
-        None.
+        out (Tensor): The data received from all experts.
     Examples:
         .. code-block:: python
...
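And a companion sketch for global_gather under the same assumptions (again illustrative, not from the commit); note that, per the Returns fix above, it now yields the received tensor rather than None. Per the docstring, the roles of the counts are swapped relative to global_scatter: local_count is what this rank receives, global_count what it sends:

    import paddle
    import paddle.distributed.utils as dist_utils
    from paddle.distributed import init_parallel_env

    init_parallel_env()
    local_input_buf = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
    local_count = paddle.to_tensor([1, 0, 1, 0], dtype="int64")   # rows received
    global_count = paddle.to_tensor([1, 0, 1, 0], dtype="int64")  # rows sent
    out = dist_utils.global_gather(local_input_buf, local_count, global_count)
    print(out.shape)  # the gathered data, no longer None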