Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
83a2fb1f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
83a2fb1f
编写于
3月 10, 2021
作者:
W
WangXi
提交者:
GitHub
3月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add collective async wait op (#31463)
上级
0205e9f8
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
370 addition
and
1 deletion
+370
-1
paddle/fluid/operators/collective/c_wait_comm_op.cc
paddle/fluid/operators/collective/c_wait_comm_op.cc
+91
-0
paddle/fluid/operators/collective/c_wait_compute_op.cc
paddle/fluid/operators/collective/c_wait_compute_op.cc
+95
-0
paddle/fluid/platform/collective_helper.cc
paddle/fluid/platform/collective_helper.cc
+28
-0
paddle/fluid/platform/collective_helper.h
paddle/fluid/platform/collective_helper.h
+2
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+2
-1
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/collective_allreduce_op_wait.py
...dle/fluid/tests/unittests/collective_allreduce_op_wait.py
+114
-0
python/paddle/fluid/tests/unittests/test_collective_wait.py
python/paddle/fluid/tests/unittests/test_collective_wait.py
+37
-0
未找到文件。
paddle/fluid/operators/collective/c_wait_comm_op.cc
0 → 100644
浏览文件 @
83a2fb1f
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
class
CWaitCommOp
:
public
framework
::
OperatorBase
{
public:
CWaitCommOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"wait_comm op can run on gpu place only for now."
));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int
ring_id
=
Attr
<
int
>
(
"ring_id"
);
auto
compute_stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
))
->
stream
();
auto
comm_stream
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
)
->
stream
();
auto
event
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
)
->
comm_event
();
// comm_stream-->event-->compute_stream
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipEventRecord
(
event
,
comm_stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamWaitEvent
(
compute_stream
,
event
,
0
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaEventRecord
(
event
,
comm_stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamWaitEvent
(
compute_stream
,
event
,
0
));
#endif
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
class
CWaitCommOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor) Dependency of the variable need to sync"
)
.
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Dependency of the variable need to sync"
)
.
AsDuplicable
();
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) ring id."
).
SetDefault
(
0
);
AddComment
(
R"DOC(
CWaitComm Operator
Compute stream wait Comm Stream with async event.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
c_wait_comm
,
ops
::
CWaitCommOp
,
ops
::
CWaitCommOpMaker
);
paddle/fluid/operators/collective/c_wait_compute_op.cc
0 → 100644
浏览文件 @
83a2fb1f
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
class
CWaitComputeOp
:
public
framework
::
OperatorBase
{
public:
CWaitComputeOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"wait_compute op can run on gpu place only for now."
));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int
ring_id
=
Attr
<
int
>
(
"ring_id"
);
auto
compute_stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
))
->
stream
();
auto
comm_stream
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
)
->
stream
();
auto
event
=
platform
::
NCCLCommContext
::
Instance
()
.
Get
(
ring_id
,
place
)
->
compute_event
();
// compute_stream-->event-->comm_stream
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipEventRecord
(
event
,
compute_stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamWaitEvent
(
comm_stream
,
event
,
0
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaEventRecord
(
event
,
compute_stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamWaitEvent
(
comm_stream
,
event
,
0
));
#endif
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
class
CWaitComputeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor) Dependency of the variable need to sync"
)
.
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Dependency of the variable need to sync"
)
.
AsDuplicable
();
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) ring id."
).
SetDefault
(
0
);
AddComment
(
R"DOC(
CWaitCompute Operator
Comm stream wait Compute Stream with async event.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
c_wait_compute
,
ops
::
CWaitComputeOp
,
ops
::
CWaitComputeOpMaker
);
paddle/fluid/platform/collective_helper.cc
浏览文件 @
83a2fb1f
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/collective_helper.h"
#include <utility>
#include <utility>
#include "paddle/fluid/platform/cuda_resource_pool.h"
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
...
@@ -43,12 +45,31 @@ class NCCLCommImpl : public NCCLComm {
...
@@ -43,12 +45,31 @@ class NCCLCommImpl : public NCCLComm {
}
}
CUDADeviceContext
*
dev_context
()
const
override
{
return
dev_ctx_
.
get
();
}
CUDADeviceContext
*
dev_context
()
const
override
{
return
dev_ctx_
.
get
();
}
gpuEvent_t
compute_event
()
const
override
{
return
compute_event_
.
get
();
}
gpuEvent_t
comm_event
()
const
override
{
return
comm_event_
.
get
();
}
void
set_compute_event
(
std
::
shared_ptr
<
platform
::
CudaEventObject
>&&
compute_event
)
{
compute_event_
=
std
::
move
(
compute_event
);
}
void
set_comm_event
(
std
::
shared_ptr
<
platform
::
CudaEventObject
>&&
comm_event
)
{
comm_event_
=
std
::
move
(
comm_event
);
}
private:
private:
int
ring_id_
;
int
ring_id_
;
int
nranks_
;
int
nranks_
;
int
rank_
;
int
rank_
;
ncclComm_t
comm_
;
ncclComm_t
comm_
;
std
::
unique_ptr
<
CUDADeviceContext
>
dev_ctx_
;
std
::
unique_ptr
<
CUDADeviceContext
>
dev_ctx_
;
// used for comm wait compute, compute_stream-->event-->comm_stream
std
::
shared_ptr
<
platform
::
CudaEventObject
>
compute_event_
;
// used for compute wait comm, comm_stream-->event-->compute_stream
std
::
shared_ptr
<
platform
::
CudaEventObject
>
comm_event_
;
};
};
NCCLComm
*
NCCLCommContext
::
CreateNCCLComm
(
ncclUniqueId
*
nccl_id
,
int
nranks
,
NCCLComm
*
NCCLCommContext
::
CreateNCCLComm
(
ncclUniqueId
*
nccl_id
,
int
nranks
,
...
@@ -124,12 +145,19 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank,
...
@@ -124,12 +145,19 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank,
std
::
unique_ptr
<
CUDADeviceContext
>
dev_ctx
(
std
::
unique_ptr
<
CUDADeviceContext
>
dev_ctx
(
new
CUDADeviceContext
(
CUDAPlace
(
dev_id
)));
new
CUDADeviceContext
(
CUDAPlace
(
dev_id
)));
std
::
shared_ptr
<
platform
::
CudaEventObject
>
compute_event
(
platform
::
CudaEventResourcePool
::
Instance
().
New
(
dev_id
));
std
::
shared_ptr
<
platform
::
CudaEventObject
>
comm_event
(
platform
::
CudaEventResourcePool
::
Instance
().
New
(
dev_id
));
NCCLCommImpl
*
c
=
new
NCCLCommImpl
;
NCCLCommImpl
*
c
=
new
NCCLCommImpl
;
c
->
set_ring_id
(
ring_id
);
c
->
set_ring_id
(
ring_id
);
c
->
set_nranks
(
nranks
);
c
->
set_nranks
(
nranks
);
c
->
set_rank
(
rank
);
c
->
set_rank
(
rank
);
c
->
set_comm
(
comm
);
c
->
set_comm
(
comm
);
c
->
set_dev_ctx
(
std
::
move
(
dev_ctx
));
c
->
set_dev_ctx
(
std
::
move
(
dev_ctx
));
c
->
set_compute_event
(
std
::
move
(
compute_event
));
c
->
set_comm_event
(
std
::
move
(
comm_event
));
comm_map_mutex_
.
lock
();
comm_map_mutex_
.
lock
();
if
(
comm_map_
.
count
(
ring_id
)
==
0
)
{
if
(
comm_map_
.
count
(
ring_id
)
==
0
)
{
...
...
paddle/fluid/platform/collective_helper.h
浏览文件 @
83a2fb1f
...
@@ -57,6 +57,8 @@ class NCCLComm {
...
@@ -57,6 +57,8 @@ class NCCLComm {
virtual
int
device_id
()
const
=
0
;
virtual
int
device_id
()
const
=
0
;
virtual
ncclComm_t
comm
()
const
=
0
;
virtual
ncclComm_t
comm
()
const
=
0
;
virtual
gpuStream_t
stream
()
const
=
0
;
virtual
gpuStream_t
stream
()
const
=
0
;
virtual
gpuEvent_t
compute_event
()
const
=
0
;
virtual
gpuEvent_t
comm_event
()
const
=
0
;
virtual
CUDADeviceContext
*
dev_context
()
const
=
0
;
virtual
CUDADeviceContext
*
dev_context
()
const
=
0
;
virtual
~
NCCLComm
()
=
default
;
virtual
~
NCCLComm
()
=
default
;
};
};
...
...
python/paddle/fluid/framework.py
浏览文件 @
83a2fb1f
...
@@ -2121,7 +2121,8 @@ class Operator(object):
...
@@ -2121,7 +2121,8 @@ class Operator(object):
'fl_listen_and_serv'
,
'ncclInit'
,
'select'
,
'checkpoint_notify'
,
'fl_listen_and_serv'
,
'ncclInit'
,
'select'
,
'checkpoint_notify'
,
'gen_bkcl_id'
,
'c_gen_bkcl_id'
,
'gen_nccl_id'
,
'c_gen_nccl_id'
,
'gen_bkcl_id'
,
'c_gen_bkcl_id'
,
'gen_nccl_id'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_sync_calc_stream'
,
'c_sync_comm_stream'
,
'c_comm_init'
,
'c_sync_calc_stream'
,
'c_sync_comm_stream'
,
'queue_generator'
,
'dequeue'
,
'enqueue'
,
'heter_listen_and_serv'
'queue_generator'
,
'dequeue'
,
'enqueue'
,
'heter_listen_and_serv'
,
'c_wait_comm'
,
'c_wait_compute'
}
}
def
__init__
(
self
,
def
__init__
(
self
,
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
83a2fb1f
...
@@ -84,6 +84,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
...
@@ -84,6 +84,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_allreduce_api
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_allreduce_api
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_broadcast_api
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_broadcast_api
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_allgather_api
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_allgather_api
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_wait
)
LIST
(
REMOVE_ITEM TEST_OPS test_memcpy_op
)
LIST
(
REMOVE_ITEM TEST_OPS test_memcpy_op
)
endif
()
endif
()
...
...
python/paddle/fluid/tests/unittests/collective_allreduce_op_wait.py
0 → 100644
浏览文件 @
83a2fb1f
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
numpy
as
np
import
argparse
import
os
import
sys
import
signal
import
time
import
socket
from
contextlib
import
closing
from
six
import
string_types
import
math
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.profiler
as
profiler
import
paddle.fluid.unique_name
as
nameGen
from
paddle.fluid
import
core
import
unittest
from
multiprocessing
import
Process
import
paddle.fluid.layers
as
layers
from
functools
import
reduce
from
test_collective_base
import
TestCollectiveRunnerBase
,
runtime_main
paddle
.
enable_static
()
class
TestCollectiveAllreduce
(
TestCollectiveRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
):
ring_id
=
0
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
layers
.
data
(
name
=
"tindata"
,
shape
=
[
10
,
1000
],
dtype
=
'float32'
)
toutdata
=
main_prog
.
current_block
().
create_var
(
name
=
"outofallreduce"
,
dtype
=
'float32'
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
stop_gradient
=
False
)
# tout = tin + tin - tin = tin
if
True
:
main_prog
.
global_block
().
append_op
(
type
=
"elementwise_add"
,
inputs
=
{
'X'
:
tindata
,
'Y'
:
tindata
,
},
outputs
=
{
'Out'
:
toutdata
},
)
main_prog
.
global_block
().
append_op
(
type
=
"elementwise_sub"
,
inputs
=
{
'X'
:
toutdata
,
'Y'
:
tindata
,
},
outputs
=
{
'Out'
:
toutdata
},
)
main_prog
.
global_block
().
append_op
(
type
=
'c_wait_compute'
,
inputs
=
{
'X'
:
toutdata
},
outputs
=
{
'Out'
:
toutdata
},
attrs
=
{
'ring_id'
:
ring_id
})
main_prog
.
global_block
().
append_op
(
type
=
"c_allreduce_sum"
,
inputs
=
{
'X'
:
toutdata
},
attrs
=
{
'ring_id'
:
ring_id
},
outputs
=
{
'Out'
:
toutdata
},
attr
=
{
'use_calc_stream'
:
False
})
main_prog
.
global_block
().
append_op
(
type
=
"c_wait_comm"
,
inputs
=
{
'X'
:
toutdata
},
outputs
=
{
'Out'
:
toutdata
},
attrs
=
{
'ring_id'
:
ring_id
})
# tout = tin + tout - tin = tout
if
True
:
main_prog
.
global_block
().
append_op
(
type
=
"elementwise_add"
,
inputs
=
{
'X'
:
tindata
,
'Y'
:
toutdata
,
},
outputs
=
{
'Out'
:
toutdata
},
)
main_prog
.
global_block
().
append_op
(
type
=
"elementwise_sub"
,
inputs
=
{
'X'
:
toutdata
,
'Y'
:
tindata
,
},
outputs
=
{
'Out'
:
toutdata
},
)
return
toutdata
if
__name__
==
"__main__"
:
runtime_main
(
TestCollectiveAllreduce
,
"allreduce"
,
0
)
python/paddle/fluid/tests/unittests/test_collective_wait.py
0 → 100644
浏览文件 @
83a2fb1f
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle
from
test_collective_base
import
TestDistBase
paddle
.
enable_static
()
class
TestCWaitOp
(
TestDistBase
):
def
_setup_config
(
self
):
pass
def
test_allreduce_wait
(
self
):
self
.
check_with_place
(
"collective_allreduce_op_wait.py"
,
"allreduce"
,
check_error_log
=
True
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录