Commit 333045d7 authored by Dong Zhihong

"move nccl to another directory"

Parent fdfc8f9b
......@@ -76,6 +76,14 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif()
# nccl_op contains several operators
if ("${TARGET}" STREQUAL "nccl_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_OP(ncclInit);\n")
# file(APPEND ${pybind_file} "USE_OP(ncclInit);\n")
endif()
# reduce_op contains several operators
if ("${TARGET}" STREQUAL "reduce_op")
set(pybind_flag 1)
......@@ -116,7 +124,9 @@ set(DEPS_OPS
softmax_with_cross_entropy_op
sum_op
pool_op
pool_with_index_op)
pool_with_index_op
nccl_op
)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
......@@ -127,6 +137,9 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
endif()
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......@@ -134,6 +147,7 @@ foreach(src ${GENERAL_OPS})
endforeach()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
message(STATUS "operators_list: ${OP_LIBRARY}")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
......
if(WITH_GPU)
nv_library(nccl_common SRCS nccl_gpu_common DEPS device_context operator)
nv_library(nccl_op SRCS nccl_ops.cc DEPS nccl_common)
else()
cc_library(nccl_common SRCS nccl_gpu_common DEPS device_context operator)
nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator)
nv_test(nccl_gpu_common_test SRCS nccl_gpu_common_test.cc DEPS nccl_common)
endif()
cc_test(nccl_gpu_common_test SRCS nccl_gpu_common_test.cc DEPS nccl_common)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/gpu_info.h"
namespace paddle {
namespace platform {
NCCLManager::NCCLManager() {}
NCCLManager::~NCCLManager() {
for (auto& p : comm_table) {
auto& comm = p.second;
auto& gpus_ = comm->gpus_;
for (size_t i = 0; i < gpus_.size(); ++i) {
int gid = gpus_[i];
platform::SetDeviceId(gid);
// mapping gid to idx
int idx = gid % gpus_.size();
// wait finish
PADDLE_ENFORCE(
cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
PADDLE_ENFORCE(cudaEventDestroy(comm->events_[idx]));
PADDLE_ENFORCE(ncclCommDestroy(comm->comms_[idx]));
}
comm.reset(nullptr);
}
}
Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) {
std::string key;
for (auto& id : gpus) {
key += std::to_string(id);
}
std::sort(key.begin(), key.end());
std::mutex mu;
std::lock_guard<std::mutex> lk(mu);
auto it = comm_table.find(key);
if (it->second == nullptr) {
auto* comm = new Communicator(gpus);
PADDLE_ENFORCE(
ncclCommInitAll(comm->comms_.data(), gpus.size(), gpus.data()));
for (size_t i = 0; i < gpus.size(); ++i) {
platform::SetDeviceId(gpus[i]);
// block wait
PADDLE_ENFORCE(cudaEventCreateWithFlags(
&comm->events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
}
comm_table[key].reset(comm);
}
return comm_table[key].get();
}
} // namespace platform
namespace platform {} // namespace platform
} // namespace paddle
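
The removed GetCommunicator keys its cache by concatenating the GPU ids and then sorting the characters of the resulting string, which can conflate multi-digit ids, and it guards the lookup with a mutex that is local to the call. A minimal sketch of a stable key, assuming the ids themselves are sorted before joining (illustrative only, hypothetical MakeCommKey helper, not part of this commit):

#include <algorithm>
#include <string>
#include <vector>

// Illustrative helper: sort the GPU ids (not the characters of the joined
// string) and join with a separator, so {0, 1} and {1, 0} share one cache
// entry and "1,12" cannot collide with "11,2".
std::string MakeCommKey(std::vector<int> gpus) {
  std::sort(gpus.begin(), gpus.end());
  std::string key;
  for (int id : gpus) {
    key += std::to_string(id);
    key += ',';
  }
  return key;
}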
......@@ -65,65 +65,30 @@ class WaitGroup {
std::condition_variable cv_;
};
// TODO(dzh): unify resource management with the framework
struct Communicator {
std::vector<ncclComm_t> comms_;
std::vector<cudaStream_t> streams_;
std::vector<cudaEvent_t> events_;
std::vector<int> gpus_;
WaitGroup wg_;
int root_gpu = -1;
// cudaEvent_t root_monitor;
explicit Communicator(const std::vector<int>& gpus) : gpus_(gpus) {
std::unordered_map<int, int> comm_id_map_;
int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
void InitAll(const std::vector<int>& gpus) {
comms_.resize(gpus.size());
streams_.resize(gpus.size());
events_.resize(gpus.size());
for (size_t i = 0; i < gpus.size(); ++i) {
comm_id_map_[gpus[i]] = i;
}
PADDLE_ENFORCE(ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
}
~Communicator() {
for (size_t i = 0; i < gpus_.size(); ++i) {
int gid = gpus_[i];
platform::SetDeviceId(gid);
int idx = gid % gpus_.size();
// wait finish
PADDLE_ENFORCE(
cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
PADDLE_ENFORCE(cudaEventDestroy(comm->events_[idx]));
PADDLE_ENFORCE(ncclCommDestroy(comm->comms_[idx]));
for (size_t i = 0; i < comms_.size(); ++i) {
PADDLE_ENFORCE(ncclCommDestroy(comms_[i]));
}
}
inline int get_root_gpu() const { return root_gpu; }
inline void set_root_gpu(int id) { root_gpu = id; }
// DISABLE_COPY_AND_ASSIGN(Communicator);
};
class NCCLManager {
public:
static NCCLManager* Get() {
static NCCLManager m;
return &m;
}
NCCLManager();
~NCCLManager();
// each card has only one communicator
Communicator* GetCommunicator(const std::vector<int>& gpus);
private:
// // the gpu id list available. Note that only support
// // whole world communication.
// std::vector<int> _gpu_worlds;
// communicator list
std::unordered_map<std::string /* key*/, std::unique_ptr<Communicator>>
comm_table;
};
Communicator* NewCommunicator(const std::vector<int>& gpus);
} // namespace platform
} // namespace paddle
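
A minimal usage sketch of the Communicator interface introduced above, assuming it is default-constructible and that NCCL and at least two GPUs are available (hypothetical function name, not part of this commit):

#include <vector>

#include "paddle/operators/nccl/nccl_gpu_common.h"

// Hypothetical usage: build one communicator per listed device, then look up
// the slot that belongs to a given device id.
void InitTwoCardCommunicator() {
  paddle::platform::Communicator comm;  // assumes a default constructor
  std::vector<int> gpus = {0, 1};
  comm.InitAll(gpus);           // wraps a single ncclCommInitAll call
  int idx = comm.GetCommId(1);  // index of device 1 inside comms_
  (void)idx;
}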
......@@ -9,7 +9,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/nccl/nccl_ops.h"
#include "paddle/operators/nccl_op.h"
namespace paddle {
namespace operators {
......@@ -85,31 +85,36 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
// // BcastSendOp
// class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
// public:
// NCCLAllReduceOpMaker(framework::OpProto *proto,
// framework::OpAttrChecker *op_checker)
// : OpProtoAndCheckerMaker(proto, op_checker) {
// AddInput("X", "The input of BcastSend op");
// AddComment(R"DOC(
// BcastSend the tensors.
// )DOC");
// }
// };
// BcastOp
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLBcastOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of Bcast op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddInput("root", "root gpu of Bcast");
AddComment(R"DOC(
Bcast the tensors.
)DOC");
}
};
// // BcastRecvOp
// class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
// public:
// NCCLAllReduceOpMaker(framework::OpProto *proto,
// framework::OpAttrChecker *op_checker)
// : OpProtoAndCheckerMaker(proto, op_checker) {
// AddOutput("Out", "The output of BcastRecv op");
// AddComment(R"DOC(
// BcastRecv the tensors.
// )DOC");
// }
// };
// ReduceOp
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLReduceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of Reduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddInput("root", "root gpu of Reduce");
AddOutput("Out", "The output of Reduce op");
AddComment(R"DOC(
Reduce the tensors.
)DOC");
}
};
} // namespace operators
} // namespace paddle
......@@ -117,3 +122,5 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
ops::NCCLAllReduceOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclInit, ops::NCCLInitOp, ops::NCCLInitOpMaker);
REGISTER_OP_CPU_KERNEL(ncclInit, ops::NCCLInitKernel<float>);
......@@ -10,7 +10,57 @@ See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/nccl/nccl_ops.h"
#include "paddle/operators/nccl_op.h"
namespace paddle {
namespace operators {
template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto ins = ctx.MultiInput<Tensor>("X");
auto outs = ctx.MultiOutput<Tensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t op_type;
if (reduction == "ncclSum") {
op_type = ncclSum;
} else if (reduction == "ncclProd") {
op_type = ncclProd;
} else if (reduction == "ncclMin") {
op_type = ncclMin;
} else if (reduction == "ncclMax") {
op_type = ncclMax;
} else {
PADDLE_ENFORCE(false, "reduction error.");
}
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id
int device_id =
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(device_id);
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE(ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type, op_type,
comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(ncclAllReduce, ops::NCCLAllReduceKernel<float>);
\ No newline at end of file
REGISTER_OP_GPU_KERNEL(ncclAllReduce, ops::NCCLAllReduceKernel<float>);
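
The kernel above translates the reduction attribute string into an ncclRedOp_t with an if/else chain. The same mapping as a free helper, shown only for clarity (hypothetical ToNcclRedOp name, plain C++ exception in place of PADDLE_ENFORCE; not part of the commit):

#include <stdexcept>
#include <string>

#include <nccl.h>

// Hypothetical helper mirroring the if/else chain in NCCLAllReduceKernel.
ncclRedOp_t ToNcclRedOp(const std::string& reduction) {
  if (reduction == "ncclSum") return ncclSum;
  if (reduction == "ncclProd") return ncclProd;
  if (reduction == "ncclMin") return ncclMin;
  if (reduction == "ncclMax") return ncclMax;
  throw std::invalid_argument("unsupported reduction: " + reduction);
}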
......@@ -19,6 +19,7 @@ namespace paddle {
namespace operators {
using framework::Tensor;
using platform::Communicator;
template <typename Type>
class NCCLTypeWrapper;
......@@ -35,67 +36,13 @@ class NCCLTypeWrapper<double> {
static const ncclDataType_t type = ncclDouble;
};
class NCCLInitOp : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto gpus = ctx.Input<std::vector<int>>("gpus");
auto* comm = ctx.Output<Communicator>("Communicator");
comm->mutable_data<Communicator>(CPUPlace());
comm = NCCLManager::GetCommunicator(gpus);
}
};
template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel<T> {
class NCCLInitKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<Tensor>("X");
auto outs = ctx.MultiOutput<Tensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
std::vector<int> gpus = ctx.Attr<std::vector<int>>("gpus");
ncclRedOp_t op_type;
if (reduction == "ncclSum") {
op_type = ncclSum;
} else if (reduction == "ncclProd") {
op_type = ncclProd;
} else if (reduction == "ncclMin") {
op_type = ncclMin;
} else if (reduction == "ncclMax") {
op_type = ncclMax;
}
auto* comm = ctx.Input<Communicator>("Communicator");
auto dev_ctx =
static_cast<const platform::CUDADeviceContext>(ctx.device_context());
// platform::NCCLManager* m = platform::NCCLManager::Get();
// auto* comm = m->GetCommunicator(gpus);
// comm->wg_.Add(1);
auto stream = dev_ctx.stream();
// device id
int gid = static_cast<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = gid % gpus.size();
comm->streams_[idx] = stream;
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE(
ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type,
op_type, comm->comms_[idx], comm->streams_[idx]));
PADDLE_ENFORCE(cudaEventRecord(comm->events_[idx], comm->streams_[idx]));
// // wait finish
// PADDLE_ENFORCE(
// cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
}
// comm->wg_.Done();
// comm->wg_.Wait();
auto* gpus = ctx.Input<std::vector<int>>("gpus");
auto* comm = ctx.Output<Communicator>("Communicator");
comm->InitAll(*gpus);
}
};
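
NCCLInitKernel reads the gpus input and hands it to Communicator::InitAll, which, per the header above, ends in a single ncclCommInitAll call. Stripped of the framework, the underlying NCCL initialization looks roughly like this (raw-NCCL sketch, assumes two visible GPUs; not part of the commit):

#include <cstdio>
#include <vector>

#include <nccl.h>

int main() {
  std::vector<int> gpus = {0, 1};
  std::vector<ncclComm_t> comms(gpus.size());
  // One call creates a communicator per listed device.
  ncclResult_t ret = ncclCommInitAll(comms.data(),
                                     static_cast<int>(gpus.size()),
                                     gpus.data());
  if (ret != ncclSuccess) {
    std::fprintf(stderr, "ncclCommInitAll failed: %s\n",
                 ncclGetErrorString(ret));
    return 1;
  }
  for (ncclComm_t c : comms) {
    ncclCommDestroy(c);
  }
  return 0;
}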
......
......@@ -5,13 +5,15 @@ from paddle.v2.framework.op import Operator
import paddle.v2.framework.core as core
from op_test import OpTest, create_op, set_input
gpu_list = os.environ["NV_LIST"]
# gpu_list = os.environ["NV_LIST"]
gpu_list = "0,1,2,3"
if not core.is_compile_gpu() or not gpu_list:
exit(0)
def allreduce(tensors, num_device):
def allreduce(tensors, gpus):
num_device = len(gpus)
assert (len(tensors) == num_device), "not match of tensor and device"
Out = tensors
for i in range(1, len(tensors)):
......@@ -24,23 +26,32 @@ def allreduce(tensors, num_device):
class TestNCCLAllReduce(unittest.TestCase):
def __init__(self):
self.op_type = "nnclAllReduce"
def setUp(self):
self.gpus = [int(g) for g in gpu_list]
self.op_type = "ncclAllReduce"
self.gpus = [int(g) for g in gpu_list.split(",")]
self.g_scope = core.Scope()
self.g_ctx = core.DeviceContext.create(core.CPUPlace())
self.scopes = []
self.ops = []
self.places = []
self.input_data = []
for i in range(len(self.gpus)):
input_data.append(np.random.random((32, 32)))
self.output_data = allreduce(input_data)
self.input_data.append(np.random.random((32, 32)))
self.output_data = allreduce(self.input_data, self.gpus)
nccl_init = Operator("ncclInit", Out="Communicator", gpus=self.gpus)
nccl_init.run(self.g_scope, self.g_ctx)
for i in range(len(self.gpus)):
scope = core.Scope()
# insert kid scope
scope = self.g_scope.new_scope()
place = core.GPUPlace(self.gpus[i])
inputs = {"X": self.input_data[i]}
outputs = {"Out": self.output_data[i]}
attrs = {"gpus": self.gpus}
......@@ -66,8 +77,11 @@ class TestNCCLAllReduce(unittest.TestCase):
self.assertTrue(actual, expect), "has diff"
if __name__ == "__main__":
# usage : export NV_LIST=0,1,2,3 python *.py
# if __name__ == "__main__":
# unittest.main()
# usage : export NV_LIST=0,1,2,3 python *.py
# os.environ["NV_LIST"] = ["0,1,2,3"]
os.environ["NV_LIST"] = ["0,1,2,3"]
if __name__ == "__main__":
unittest.main()