Unverified Commit e26e51ba authored by xiayanming, committed by GitHub

[fix bug] communication op support rccl (#41763)

Parent 419d8eb2
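The same two-part pattern repeats in every file below: each `#if defined(PADDLE_WITH_NCCL)` guard becomes `#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)` so the ROCm build takes the same code path, and each raw `cudaStream_t` handle becomes the backend-neutral `gpuStream_t`. As a minimal sketch of why `gpuStream_t` compiles on both stacks (the header layout and exact typedef spelling here are assumptions for illustration, not part of this diff), the alias typically looks like:

// Hedged sketch: Paddle keeps an equivalent alias in its platform
// headers; the spelling below is assumed, not quoted from this commit.
#if defined(PADDLE_WITH_NCCL)         // CUDA build: CUDA runtime + NCCL
#include <cuda_runtime.h>
#include <nccl.h>
using gpuStream_t = cudaStream_t;     // gpuStream_t is a CUDA stream
#elif defined(PADDLE_WITH_RCCL)       // ROCm build: HIP runtime + RCCL
#include <hip/hip_runtime.h>
#include <rccl.h>
using gpuStream_t = hipStream_t;      // gpuStream_t is a HIP stream
#endif

With such an alias in place, kernel code like `gpuStream_t stream = nullptr;` needs no per-backend branching at the call site; only the preprocessor guards change.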
@@ -14,7 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/alltoall_op.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto out = ctx.Output<framework::LoDTensor>("Out");
@@ -43,7 +43,7 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     int nranks = comm->nranks();
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
......
@@ -14,6 +14,9 @@ limitations under the License. */
 #if defined(PADDLE_WITH_NCCL)
 #include <nccl.h>
 #endif
+#if defined(PADDLE_WITH_RCCL)
+#include <rccl.h>
+#endif
 #include <stdint.h>
 #include <ostream>
 #include <string>
@@ -24,7 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/threadpool.h"
 // #include "paddle/fluid/operators/distributed/distributed.h"
 // #include "paddle/fluid/operators/distributed/request_handler_impl.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -51,7 +54,7 @@ class CCommInitMultiTrainerOp : public framework::OperatorBase {
     auto var = scope.FindVar(Input("X"));
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::InvalidArgument("Input X must be provided."));
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
     int ntrainers = Attr<int>("ntrainers");
......
@@ -14,7 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/global_gather_op.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto local_count = ctx.Input<framework::LoDTensor>("local_count");
@@ -79,7 +79,7 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
                           ring_id));
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
......
@@ -14,7 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/global_scatter_op.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto local_count = ctx.Input<framework::LoDTensor>("local_count");
@@ -78,7 +78,7 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
......