机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 17b4dd70 (unverified)
Authored Oct 20, 2021 by 李季; committed by GitHub on Oct 20, 2021
Fix global gather and global scatter operators (#36517)
* fix global gather and global scatter operators
Parent: 6a572a19
Showing 2 changed files with 11 additions and 17 deletions (+11 −17):

paddle/fluid/operators/collective/global_scatter_op.cu.cc (+4 −4)
python/paddle/distributed/utils.py (+7 −13)
paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -47,7 +47,7 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
     if (platform::is_cpu_place(local_count->place())) {
       cpu_local_count_data = local_count->data<int64_t>();
     } else {
-      framework::TensorCopy(*local_count, platform::CPUPlace(),
-                            &cpu_local_count);
+      framework::TensorCopySync(*local_count, platform::CPUPlace(),
+                                &cpu_local_count);
       cpu_local_count_data = cpu_local_count.data<int64_t>();
     }
@@ -57,7 +57,7 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
       cpu_global_count_data = global_count->data<int64_t>();
       global_count_len = global_count->numel();
     } else {
-      framework::TensorCopy(*global_count, platform::CPUPlace(),
-                            &cpu_global_count);
+      framework::TensorCopySync(*global_count, platform::CPUPlace(),
+                                &cpu_global_count);
       cpu_global_count_data = cpu_global_count.data<int64_t>();
       global_count_len = cpu_global_count.numel();
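The substantive change in this file is framework::TensorCopy → framework::TensorCopySync. In Paddle, TensorCopy may only enqueue the device-to-host copy on a stream and return immediately, while TensorCopySync blocks until the copy has completed; since the kernel dereferences cpu_local_count_data / cpu_global_count_data on the host right after the copy, the asynchronous variant can read stale or in-flight data, which is presumably what this commit fixes. A minimal standalone CUDA C++ sketch of that hazard, with purely illustrative names and values (not Paddle code):

// async_copy_hazard.cu -- illustrative only; all names here are hypothetical.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

int main() {
  // Device-side int64 "count" buffer, standing in for local_count on the GPU.
  int64_t* dev_count = nullptr;
  cudaMalloc(&dev_count, sizeof(int64_t));
  const int64_t value = 7;
  cudaMemcpy(dev_count, &value, sizeof(int64_t), cudaMemcpyHostToDevice);

  // Pinned host buffer, so the device-to-host copy below is truly asynchronous.
  int64_t* host_count = nullptr;
  cudaMallocHost(&host_count, sizeof(int64_t));
  *host_count = -1;

  // Racy pattern (analogue of the old TensorCopy call): the copy is only
  // enqueued here, so reading *host_count now may still observe the stale -1.
  cudaMemcpyAsync(host_count, dev_count, sizeof(int64_t),
                  cudaMemcpyDeviceToHost, /*stream=*/0);

  // Safe pattern (what TensorCopySync enforces): wait for the copy to finish
  // before the host dereferences the data.
  cudaStreamSynchronize(0);
  printf("count = %lld\n", (long long)*host_count);  // reliably 7 after the sync

  cudaFreeHost(host_count);
  cudaFree(dev_count);
  return 0;
}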
python/paddle/distributed/utils.py
@@ -65,14 +65,11 @@ def global_scatter(x,
     to global_count.
     Args:
-        x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32 or int64.
+        x (Tensor): Tensor. The tensor data type should be float16, float32, float64, int32 or int64.
         local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be sent. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be sent. The tensor data type should be int64.
         global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be received. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be received. The tensor data type should be int64.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
         use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
@@ -161,19 +158,16 @@ def global_gather(x,
     to global_count.
     Args:
-        x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32 or int64.
+        x (Tensor): Tensor. Tensor whose data type should be float16, float32, float64, int32 or int64.
         local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be received. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be received. Tensor data type should be int64.
         global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
-            how many data needed to be sent. Every element in the list must be a Tensor whose
-            data type should be int64.
+            how many data needed to be sent. Tensor data type should be int64.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
         use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
     Returns:
-        None.
+        out (Tensor): The data received from all experts.
     Examples:
         .. code-block:: python
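The docstring edits tighten the argument descriptions: x, local_count, and global_count are each a single tensor rather than a list, the two count tensors hold n_expert * world_size int64 entries apiece, and global_gather now documents a real return value instead of None. As I read those descriptions, for global_scatter the input x supplies sum(local_count) rows and the output holds sum(global_count) rows. A small standalone C++ sketch of that bookkeeping, with illustrative counts (hypothetical, not Paddle code):

// count_bookkeeping.cc -- illustrative sketch of the docstring's count layout.
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int n_expert = 2, world_size = 2;
  // One rank's count tensors; each holds n_expert * world_size entries.
  // The concrete values below are made up for illustration.
  std::vector<int64_t> local_count  = {2, 1, 1, 1};  // rows this rank sends
  std::vector<int64_t> global_count = {1, 1, 2, 1};  // rows this rank receives

  const int64_t send_rows =
      std::accumulate(local_count.begin(), local_count.end(), int64_t{0});
  const int64_t recv_rows =
      std::accumulate(global_count.begin(), global_count.end(), int64_t{0});

  // x would need send_rows rows; global_scatter's output would have recv_rows.
  printf("entries per count tensor: %d, send %lld rows, receive %lld rows\n",
         n_expert * world_size, (long long)send_rows, (long long)recv_rows);
  return 0;
}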