Commit d1443104 authored by Dong Zhihong

"nccl add interface"

Parent 51abb6c3
......@@ -106,6 +106,7 @@ function(op_library TARGET)
endfunction()
add_subdirectory(math)
add_subdirectory(nccl)
set(DEPS_OPS
recurrent_op
......
if(WITH_GPU)
nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator)
nv_library(nccl_op SRCS nccl_ops.cc DEPS nccl_common)
else()
cc_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator)
endif()
cc_test(nccl_gpu_common_test SRCS nccl_gpu_common_test.cc DEPS nccl_common)
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/gpu_info.h"
namespace paddle {
namespace platform {
NCCLManager::NCCLManager() {}
NCCLManager::~NCCLManager() {
for (auto& p : comm_table) {
auto* comm = p.second;
auto& gpus_ = comm->gpus_;
for (int i = 0; i < gpus_.size(); ++i) {
int gid = gpus_[i];
platform::SetDeviceId(gid);
// mapping gid to idx
int idx = gid % gpus_.size();
// wait finish
NCCL_CHECK(
cudaStreamWaitEvent(*comm->streams_[idx], comm->events_[idx], 0));
NCCL_CHECK(cudaEventDestroy(comm->events_[idx]));
NCCL_CHECK(ncclCommDestroy(comm->comms_[idx]));
}
delete comm;
}
}
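Note that cudaStreamWaitEvent, cudaEventDestroy and cudaEventCreateWithFlags return cudaError_t rather than ncclResult_t, so wrapping them in NCCL_CHECK mixes the two error domains. A minimal companion macro for the CUDA runtime side could look like the sketch below; the name CUDA_CHECK is an assumption for illustration, not part of this commit:

// Sketch of a CUDA-side counterpart to NCCL_CHECK (hypothetical name).
#define CUDA_CHECK(condition)                                             \
  do {                                                                    \
    cudaError_t ret = (condition);                                        \
    PADDLE_ENFORCE(ret == cudaSuccess, "Error invoking CUDA: ", __FILE__, \
                   __LINE__, cudaGetErrorString(ret));                    \
  } while (0)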
Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) const {
std::string key;
for (auto& id : gpus) {
key += std::to_string(id);
}
std::sort(key.begin(), key.end());
std::mutex mu;
std::lock_guard<std::mutex> lk(mu);
auto* comm = comm_table[key];
if (comm == nullptr) {
comm = new Communicator(gpus.size());
NCCL_CHECK(ncclCommInitAll(comm->comms_.data(), gpus.size(), gpus.data()));
for (size_t i = 0; i < gpus.size(); ++i) {
platform::SetDeviceId(gpus[i]);
// block wait
NCCL_CHECK(cudaEventCreateWithFlags(
&comm->events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
}
comm_table[key] = comm;
}
return comm;
}
} // namespace platform
} // namespace paddle
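GetCommunicator builds its cache key by concatenating the decimal GPU ids and then sorting the characters of the resulting string, so id sets such as {1, 12} and {11, 2} collide; it also locks a mutex that is local to the call, which serializes nothing, and it mutates comm_table even though the method is declared const. A possible shape of the fix, sketched under the assumption that NCCLManager gains a member mutex (say std::mutex mu_) and that the method drops the const qualifier:

// Hypothetical helper: a collision-free key for a set of GPU ids.
std::string MakeCommKey(std::vector<int> gpus) {
  std::sort(gpus.begin(), gpus.end());  // sort the ids, not the characters
  std::string key;
  for (int id : gpus) {
    key += std::to_string(id);
    key += ',';  // separator keeps {1, 12} and {11, 2} distinct
  }
  return key;
}

// Inside GetCommunicator, guard the shared table with the member mutex:
//   std::lock_guard<std::mutex> lk(mu_);
//   auto& comm = comm_table[MakeCommKey(gpus)];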
#pragma once
#include <nccl.h>
#include <algorithm>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <condition_variable>
#include <vector>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/platform/device_context.h"
namespace paddle {
namespace platform {
#define NCCL_CHECK(condition) \
do { \
ncclResult_t ret = (condition); \
PADDLE_ENFORCE(ret == ncclSuccess, "Error invoking NCCL: ", __FILE__, \
__LINE__, ncclGetErrorString(ret)); \
} while (0)
class WaitGroup {
public:
inline void Add(int n) {
std::unique_lock<std::mutex> lk(mu_);
PADDLE_ENFORCE(n >= 0, "WaitGroup::Add requires n >= 0.");
counter_ += n;
}
inline void Done(int n) {
std::unique_lock<std::mutex> lk(mu_);
PADDLE_ENFORCE(n <= counter_, "WaitGroup::Done exceeds the matching Add count.");
counter_ -= n;
if (counter_ == 0) {
cv_.notify_all();
}
}
inline void Add() { Add(1); }
inline void Done() { Done(1); }
inline void Wait() {
std::unique_lock<std::mutex> lk(mu_);
cv_.wait(lk, [&] { return counter_ == 0; });
}
inline int GetCount() {
std::unique_lock<std::mutex> lk(mu_);
return counter_;
}
private:
int counter_ = 0;
std::mutex mu_;
std::condition_variable cv_;
};
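A minimal usage sketch of the intended Add/Done/Wait pairing (the unit test further down exercises the same pattern):

// Each unit of work is announced with Add and retired with Done;
// Wait blocks until the counter drops back to zero.
paddle::platform::WaitGroup wg;
wg.Add(4);   // four pending workers
// ... each worker calls wg.Done() when it finishes ...
wg.Wait();   // returns once all four Done calls have arrived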
// class NCCLContext : public DeviceContext {
// public:
......@@ -23,8 +68,26 @@ namespace platform {
// std::vector<cudaStream_t> streams_;
// };
// TODO(dzh) : make resources managed unified with framework
struct Communicator {
std::vector<ncclComm_t> comms_;
std::vector<cudaStream_t*> streams_;
std::vector<cudaEvent_t> events_;
std::vector<int> gpus_;
WaitGroup wg_;
int root_gpu = -1;
// cudaEvent_t root_monitor;
explicit Communicator(const std::vector<int>& gpus) : gpus_(gpus) {
comms_.resize(gpus.size());
streams_.resize(gpus.size());
events_.resize(gpus.size());
}
// Communicator(int num_device): comms_.resize(num_device) {}
inline int get_root_gpu() const { return root_gpu; }
class Communicator;
inline void set_root_gpu(int id) { root_gpu = id; }
};
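nccl_gpu_common.cc constructs the communicator with new Communicator(gpus.size()), but the only constructor declared above takes const std::vector<int>& and is explicit, so that call cannot compile. A sketch of the size-based overload hinted at by the commented-out line, given here as an assumption rather than as part of this commit:

// Hypothetical size-based constructor matching new Communicator(gpus.size()).
explicit Communicator(int num_device) {
  comms_.resize(num_device);
  streams_.resize(num_device);
  events_.resize(num_device);
}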
class NCCLManager {
public:
......@@ -33,27 +96,20 @@ class NCCLManager {
return &m;
}
NCCLManager() {
}
~NCCLManager() {}
NCCLManager();
~NCCLManager();
// for each card only have one communicator
Communicator* GetCommunicator() const;
Communicator* GetCommunicator(const std::vector<int>& gpus) const;
private:
struct Communicator {
std::vector<ncclComm_t> comms_;
std::vector<cudaStream_t*> streams_; // do not own
std::vector<cudaEvent_t> events_;
int root_gpu;
};
// the gpu id list available. Note that only support
// whole world communication.
std::vector<int> _gpu_worlds;
// // the gpu id list available. Note that only support
// // whole world communication.
// std::vector<int> _gpu_worlds;
// communicator list
std::unordered_map<std::string /* key*/, Communicator*> comms_;
std::unordered_map<std::string /* key*/, Communicator*> comm_table;
};
} // namespace platform
......
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include <gtest/gtest.h>
#include <chrono>
#include <thread>
#include <vector>
TEST(WaitGroup, wait) {
WaitGroup wg;
auto run_thread = [&wg](int idx) {
wg.Add(1);
std::this_thread::sleep_for(std::chrono::seconds(1));
wg.Done();
};
std::vector<std::thread> ths;
constexpr const int TNUM = 5;
for (int i = 0; i < TNUM; ++i) {
ths.emplace_back(std::thread(run_thread, i));
}
wg.Wait();
for (auto &th : ths) {
th.join();
}
}
......@@ -11,25 +11,20 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
protected:
// allreduce do nothing in infershape
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
" Input(X) of AllReduce op input should not be NULL");
PADDLE_ENFORCE_NOT_NULL(
ctx.InputVar("X"),
" Input(X) of AllReduce op input should not be NULL");
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
PADDLE_ENFORCE(ins.size() == outs.size(), "Input(X) and Output(Out) must have same size");
for(size_t i=0; i < ins.size(); ++i) {
PADDLE_ENFORCE(ins.size() == outs.size(),
"Input(X) and Output(Out) must have same size");
for (size_t i = 0; i < ins.size(); ++i) {
outs[i]->Resize(ins[i]->dims());
}
std::string reduction = ctx.Attr<std::string>("reduction");
PADDLE_ENFORCE( (reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"), "invalid reduction!");
}
};
template <typename T>
class NCCLAllreduceOp : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *ctx = static_cast<NCCLContext *>(context.device_context());
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"),
"invalid reduction!");
}
};
......@@ -41,8 +36,9 @@ class NCCLBcastSendOp final : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
" Input(X) of BcastSend op input should not be NULL");
PADDLE_ENFORCE_NOT_NULL(
ctx.InputVar("X"),
" Input(X) of BcastSend op input should not be NULL");
}
};
......@@ -54,18 +50,21 @@ class NCCLBcastRecvOp final : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
" Input(X) of BcastRecv op input should not be NULL");
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Out"),
" Input(X) of BcastRecv op input should not be NULL");
}
};
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
NCCLAllReduceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of AllReduce op");
AddOutput("Out", "The output of AllReduce op");
AddAttr<std::string>("reduction: {'min', 'max', 'prod', 'sum'}.");
AddAttr<std::string>("reduction",
"{'ncclmin', 'ncclmax', 'ncclprod', 'ncclsum'}.");
AddAttr<std::vector<int>>("gpus", "gpu id lists");
AddComment(R"DOC(
AllReduce the input tensors.
)DOC");
......@@ -73,8 +72,9 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
};
class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
NCCLAllReduceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op");
AddComment(R"DOC(
BcastSend the tensors.
......@@ -83,8 +83,9 @@ class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
};
class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
NCCLAllReduceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "The output of BcastRecv op");
AddComment(R"DOC(
BcastRecv the tensors.
......@@ -92,5 +93,5 @@ class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
}
}
} // operators
} // paddle
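Both Bcast op makers above declare their constructor under the name NCCLAllReduceOpMaker, a copy-paste slip that cannot compile because a constructor must carry its own class name; the constructors are also not marked public. A minimal corrected sketch for the send side (the receive side follows the same pattern with Out instead of X):

class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  NCCLBcastSendOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input of BcastSend op");
    AddComment(R"DOC(
BcastSend the tensors.
)DOC");
  }
};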
......@@ -7,29 +7,27 @@
namespace paddle {
namespace operators {
template<typename Type>
template <typename Type>
class NCCLTypeWrapper;
template<>
template <>
class NCCLTypeWrapper<float> {
static const ncclDataType_t type = ncclFloat;
};
template<>
template <>
class NCCLTypeWrapper<double> {
static const ncclDataType_t type = ncclDouble;
};
template<typename T>
template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel {
public:
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<Tensor>("X");
auto outs = ctx.MultiOutput<Tensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
std::vector<int> gpus = ctx.Attr<std::vector<int>>("gpus");
ncclRedOp_t op_type;
if (reduction == "ncclSum") {
op_type = ncclSum;
......@@ -37,24 +35,40 @@ public:
op_type = ncclProd;
} else if (reduction == "ncclMin") {
op_type = ncclMin;
} else (reduction == "ncclMax") {
op_type = ncclMax;
}
} else
(reduction == "ncclMax") { op_type = ncclMax; }
auto dev_ctx =
static_cast<const platform::CUDADeviceContext &>(ctx.device_context());
NCCLManager* m = NCCLManager::Get();
auto* comm = m->GetCommunicator(gpus);
comm->wg_.Add(1);
auto dev_ctx = ctx.device_context();
auto* stream = &dev_ctx.stream();
for( size_t i=0; i < ins.size(); ++i) {
ncclAllReduce(ins[i]->data<T>(),
outs[i]->mutable_data<T>(),
outs[i]->numel() * sizeof(T),
NCCLTypeWrapper<T>::type,
op_type,
comm,
stream);
// device id
int gid = ctx.GetPlace().GetDeviceId();
int idx = gid % gpus.size();
comm->streams_[idx] = stream;
for (size_t i = 0; i < ins.size(); ++i) {
NCCL_CHECK(ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
outs[i]->numel() * sizeof(T),
NCCLTypeWrapper<T>::type, op_type,
&comm->comms_[idx], comm->streams_[idx]));
NCCL_CHECK(cudaEventRecord(comm->events_[idx], *comm->streams_[idx]));
// wait finish
NCCL_CHECK(
cudaStreamWaitEvent(*comm->streams_[idx], comm->events_[idx], 0));
}
}
};
comm->wg_.Done();
comm->wg_.Wait();
}
};
}
}
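The reduction dispatch above drops the if keyword from its last branch, ncclAllReduce is handed a byte count where it expects an element count, and it receives &comm->comms_[idx] (an ncclComm_t*) where the ncclComm_t handle itself is expected. A corrected sketch of that inner part of Compute, reusing the names declared above:

ncclRedOp_t op_type;
if (reduction == "ncclSum") {
  op_type = ncclSum;
} else if (reduction == "ncclProd") {
  op_type = ncclProd;
} else if (reduction == "ncclMin") {
  op_type = ncclMin;
} else if (reduction == "ncclMax") {
  op_type = ncclMax;
} else {
  PADDLE_ENFORCE(false, "invalid reduction");
}

for (size_t i = 0; i < ins.size(); ++i) {
  // count is the number of elements, not bytes; the communicator handle and
  // the stream are passed by value.
  NCCL_CHECK(ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
                           outs[i]->numel(), NCCLTypeWrapper<T>::type, op_type,
                           comm->comms_[idx], *comm->streams_[idx]));
}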
......@@ -35,6 +35,7 @@ struct GPUPlace {
GPUPlace() : GPUPlace(0) {}
explicit GPUPlace(int d) : device(d) {}
inline int GetDeviceId() const { return device; }
// needed for variant equality comparison
inline bool operator==(const GPUPlace &o) const { return device == o.device; }
inline bool operator!=(const GPUPlace &o) const { return !(*this == o); }
......
......@@ -3,7 +3,7 @@ import numpy as np
import paddle.v2 as paddle
from paddle.v2.framework.op import Operator
import paddle.v2.framework.core as core
from op_test import OpTest, create_op
from op_test import OpTest, create_op, set_input
gpu_list = os.environ["NV_LIST"]
......@@ -11,7 +11,63 @@ if not core.is_compile_gpu() or not gpu_list:
exit(0)
def allreduce(tensors, num_device):
assert (len(tensors) == num_device), "not match of tensor and device"
Out = tensors
for i in range(1, len(tensors)):
Out[0] += Out[i]
for i in range(1, len(tensors)):
Out[i] = Out[0]
return Out
class TestNCCLAllReduce(unittest.TestCase):
def setUp(self):
self.op_type = "ncclAllReduce"
self.scope = core.Scope()
self.gpus = [int(g) for g in gpu_list.split(",")]
self.scopes = []
self.ops = []
self.places = []
self.input_data = []
for i in range(len(self.gpus)):
self.input_data.append(np.random.random((32, 32)))
self.output_data = allreduce(self.input_data, len(self.gpus))
for i in range(len(self.gpus)):
scope = core.Scope()
place = core.GPUPlace(self.gpus[i])
inputs = {"X": self.input_data[i]}
outputs = {"Out": self.output_data[i]}
attrs = {"gpus": self.gpus}
op = create_op(scope, self.op_type, inputs, outputs, attrs)
set_input(scope, op, inputs, place)
self.scopes.append(scope)
self.ops.append(op)
self.places.append(place)
def test_output(self):
idx = 0
for scope, place, op in zip(self.scopes, self.places, self.ops):
ctx = core.DeviceContext.create(place)
op.run(scope, ctx)
for out_name, out_dup in Operator.get_op_outputs(op.type()):
actual = np.array(scope.find_var(out_name).get_tensor())
expect = self.output_data[idx]
idx += 1
self.assertTrue(np.allclose(actual, expect), "has diff")
if __name__ == "__main__":
# usage : export NV_LIST=0,1,2,3 python *.py
os.environ["NV_LIST"] = ["0,1,2,3"]
unittest.main()