提交 48dea84b 编写于 作者: D dongzhihong

"nccl multigpu init"

上级 f2f839af
#pragma once
#include <nccl.h>
#include "paddle/platform/device_context.h"
namespace paddle {
namespace platform {
class NCCLManager {
public:
static NCCLManager* Get() {
static NCCLManager m;
return &m;
}
NCCLManager() { _comms.resize(_gpu_worlds.size()); }
~NCCLManager() {}
private:
// clang-format off
std::vector<ncclComm_t> _comms;
std::vector<int> _gpu_worlds;
// clang-format on
};
class NCCLContext : public DeviceContext {
public:
explicit NCCLContext(GPUPlace place);
virtual ~NCCLContext();
private:
// clang-format off
std::vector<int> _gpu_ids;
std::vector<cudaStream_t> _streams;
int root_gpu;
// clang-format on
};
}
}
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace operators {
// AllreduceOp
class NCCLAllreduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
// allreduce do nothing in infershape
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
template <typename T>
class NCCLAllreduceOp : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *ctx = static_cast<NCCLContext *>(context.device_context());
// auto *comm = ;
// auto *src = ;
// ncclAllReduce(src, dest, )
}
};
// BcastSendOp
template <typename T>
class NCCLBroadcastSendOp final : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
// BcastRecvOp
template <typename T>
class NCCLBroadcastRecvOp final : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
}
}
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace operators {}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册