未验证 提交 17b4dd70 编写于 作者: 李季 提交者: GitHub

Fix global gather and global scatter operators (#36517)

* fix global gather and global scatter operators
上级 6a572a19
......@@ -47,7 +47,7 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
if (platform::is_cpu_place(local_count->place())) {
cpu_local_count_data = local_count->data<int64_t>();
} else {
framework::TensorCopy(*local_count, platform::CPUPlace(),
framework::TensorCopySync(*local_count, platform::CPUPlace(),
&cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
}
......@@ -57,7 +57,7 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
cpu_global_count_data = global_count->data<int64_t>();
global_count_len = global_count->numel();
} else {
framework::TensorCopy(*global_count, platform::CPUPlace(),
framework::TensorCopySync(*global_count, platform::CPUPlace(),
&cpu_global_count);
cpu_global_count_data = cpu_global_count.data<int64_t>();
global_count_len = cpu_global_count.numel();
......
......@@ -65,14 +65,11 @@ def global_scatter(x,
to global_count.
Args:
x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
x (Tensor): Tensor. The tensor data type should be float16, float32, float64, int32 or int64.
local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
how many data needed to be sent. Every element in the list must be a Tensor whose
data type should be int64.
how many data needed to be sent. The tensor data type should be int64.
global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
how many data needed to be received. Every element in the list must be a Tensor whose
data type should be int64.
how many data needed to be received. The tensor data type should be int64.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
......@@ -161,19 +158,16 @@ def global_gather(x,
to global_count.
Args:
x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
x (Tensor): Tensor. Tensor whose data type should be float16, float32, float64, int32 or int64.
local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
how many data needed to be received. Every element in the list must be a Tensor whose
data type should be int64.
how many data needed to be received. Tensor data type should be int64.
global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
how many data needed to be sent. Every element in the list must be a Tensor whose
data type should be int64.
how many data needed to be sent. Tensor data type should be int64.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
Returns:
None.
out (Tensor): The data received from all experts.
Examples:
.. code-block:: python
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册