Unverified · Commit e26e51ba authored by xiayanming, committed by GitHub

[fix bug] communication op support rccl (#41763)

Parent 419d8eb2
@@ -14,7 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/alltoall_op.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto out = ctx.Output<framework::LoDTensor>("Out");
@@ -43,7 +43,7 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     int nranks = comm->nranks();
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
......
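The same two substitutions repeat through every file this commit touches: the NCCL-only preprocessor guard widens to also match ROCm builds, and the CUDA-specific `cudaStream_t` becomes the vendor-neutral `gpuStream_t`. A minimal sketch of how such an alias can be wired up; the macro and type names mirror the diff, but the scaffolding below is illustrative, not Paddle's actual header:

```cpp
// Sketch of a portable stream alias for CUDA and ROCm builds (assumed
// layout; Paddle defines gpuStream_t in its own platform headers).
#if defined(PADDLE_WITH_NCCL)
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;  // CUDA build: NCCL runs on cudaStream_t
#elif defined(PADDLE_WITH_RCCL)
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;   // ROCm build: RCCL runs on hipStream_t
#endif
```

With the alias in place, kernels like the one above compile unchanged on either stack, which is exactly what the `cudaStream_t` → `gpuStream_t` substitutions in these hunks buy.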
@@ -14,6 +14,9 @@ limitations under the License. */
 #if defined(PADDLE_WITH_NCCL)
 #include <nccl.h>
 #endif
+#if defined(PADDLE_WITH_RCCL)
+#include <rccl.h>
+#endif
 #include <stdint.h>
 #include <ostream>
 #include <string>
@@ -24,7 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/threadpool.h"
 // #include "paddle/fluid/operators/distributed/distributed.h"
 // #include "paddle/fluid/operators/distributed/request_handler_impl.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -51,7 +54,7 @@ class CCommInitMultiTrainerOp : public framework::OperatorBase {
     auto var = scope.FindVar(Input("X"));
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::InvalidArgument("Input X must be provided."));
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
     int ntrainers = Attr<int>("ntrainers");
......
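RCCL deliberately mirrors NCCL's API, which is why widening the guard and including `<rccl.h>` is enough here: types like `ncclUniqueId` and calls like `ncclCommInitRank` keep their NCCL names on ROCm. A hedged sketch of the rank-initialization handshake this operator builds on (simplified; the real operator reads the id out of the scope, as the hunk shows, and error-checks every call):

```cpp
#if defined(PADDLE_WITH_NCCL)
#include <nccl.h>
#elif defined(PADDLE_WITH_RCCL)
#include <rccl.h>  // RCCL keeps NCCL's symbol names on ROCm
#endif

// Illustrative helper, not Paddle's API: every trainer receives the same
// ncclUniqueId out of band, then joins the communicator collectively.
void JoinCommunicator(ncclUniqueId id, int ntrainers, int trainer_id,
                      ncclComm_t* comm) {
  ncclCommInitRank(comm, ntrainers, id, trainer_id);
}
```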
@@ -14,7 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/global_gather_op.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto local_count = ctx.Input<framework::LoDTensor>("local_count");
@@ -79,7 +79,7 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
             ring_id));
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
......
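The `use_calc_stream` attribute decides which stream the collective is queued on, and declaring it as `gpuStream_t` keeps that choice portable across CUDA and ROCm. A sketch of the full idiom, with the `else` branch the hunk truncates restored as I understand it (the communicator-owned `comm->stream()` accessor is an assumption about Paddle's `NCCLComm` here):

```cpp
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
  // Reuse the computation stream: the collective is implicitly ordered
  // with the surrounding kernels, so no extra synchronization is needed.
  auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
  stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
  // Otherwise run on the communicator's dedicated stream (assumed
  // accessor), letting communication overlap with computation.
  stream = comm->stream();
}
```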
@@ -14,7 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/global_scatter_op.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto local_count = ctx.Input<framework::LoDTensor>("local_count");
@@ -78,7 +78,7 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
......
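All three CUDA kernels in this commit also sit behind `NCCL_VERSION_CODE >= 2703`: the point-to-point primitives they are built from, `ncclSend`/`ncclRecv`, first shipped in NCCL 2.7.3, whose version code is major*1000 + minor*100 + patch = 2703. A self-contained sketch of the grouped send/recv loop an alltoall-style op reduces to; the function and its buffer parameters are illustrative, and Paddle actually routes these calls through `platform::dynload` with error checks:

```cpp
#include <nccl.h>  // <rccl.h> under PADDLE_WITH_RCCL; same symbols

// Illustrative alltoall built from point-to-point primitives: rank r
// sends slice p of send_buf to peer p and receives p's slice in return.
template <typename T>
void AllToAllSketch(const T* send_buf, T* recv_buf, size_t slice_numel,
                    ncclDataType_t dtype, int nranks, ncclComm_t comm,
                    cudaStream_t stream) {
#if NCCL_VERSION_CODE >= 2703
  ncclGroupStart();  // batch the sends/recvs so they progress together
  for (int peer = 0; peer < nranks; ++peer) {
    ncclSend(send_buf + peer * slice_numel, slice_numel, dtype, peer, comm,
             stream);
    ncclRecv(recv_buf + peer * slice_numel, slice_numel, dtype, peer, comm,
             stream);
  }
  ncclGroupEnd();
#endif
}
```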