提交 50f04dca 编写于 作者: D Dong Zhihong

"add init allreduce test"

上级 f6106ffa
......@@ -80,8 +80,8 @@ function(op_library TARGET)
if ("${TARGET}" STREQUAL "nccl_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclInit);\n")
file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n")
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(ncclInit);\n")
endif()
# reduce_op contains several operators
......@@ -148,7 +148,6 @@ foreach(src ${GENERAL_OPS})
endforeach()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
message(STATUS "operators_list: ${OP_LIBRARY}")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
......
......@@ -23,48 +23,12 @@
#include <vector>
#include "paddle/platform/device_context.h"
#include "paddle/platform/dynload/nccl.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace platform {
// WaitGroup is a mutex-protected counter paired with a condition variable.
// Add() raises the counter, Done() lowers it, and Wait() blocks the caller
// until the counter reads zero.
class WaitGroup {
 public:
  // Increases the counter by n. n must be non-negative.
  inline void Add(int n) {
    std::lock_guard<std::mutex> lk(mu_);
    PADDLE_ENFORCE(n >= 0, "add wait must >=0.");
    counter_ += n;
  }

  // Decreases the counter by n and wakes every waiter once it reaches zero.
  // n must be non-negative and must not exceed the current counter value.
  inline void Done(int n) {
    std::lock_guard<std::mutex> lk(mu_);
    // A negative n would silently *increase* the counter and could leave
    // Wait() blocked, so reject it the same way Add() does.
    PADDLE_ENFORCE(n >= 0, "done wait must >=0.");
    PADDLE_ENFORCE(n <= counter_, " wait group done unmatch to add.");
    counter_ -= n;
    if (counter_ == 0) {
      cv_.notify_all();
    }
  }

  // Convenience overloads for a single unit.
  inline void Add() { Add(1); }
  inline void Done() { Done(1); }

  // Blocks until the counter is zero; returns immediately if it already is.
  inline void Wait() {
    // unique_lock (not lock_guard) because the condition variable must be
    // able to release and re-acquire the mutex while waiting.
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [&] { return counter_ == 0; });
  }

  // Returns a snapshot of the counter taken under the mutex.
  inline int GetCount() {
    std::lock_guard<std::mutex> lk(mu_);
    return counter_;
  }

 private:
  int counter_ = 0;             // read and written only while holding mu_
  std::mutex mu_;               // guards counter_
  std::condition_variable cv_;  // notified when counter_ drops to zero
};
struct Communicator {
std::vector<ncclComm_t> comms_;
std::unordered_map<int, int> comm_id_map_;
......@@ -76,12 +40,13 @@ struct Communicator {
for (size_t i = 0; i < gpus.size(); ++i) {
comm_id_map_[gpus[i]] = i;
}
PADDLE_ENFORCE(ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
PADDLE_ENFORCE(
dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
}
~Communicator() {
for (size_t i = 0; i < comms_.size(); ++i) {
PADDLE_ENFORCE(ncclCommDestroy(comms_[i]));
PADDLE_ENFORCE(dynload::ncclCommDestroy(comms_[i]));
}
}
......
......@@ -21,8 +21,9 @@ class NCCLInitOp : public framework::OperatorWithKernel {
protected:
// Verifies that the "Communicator" output has been set on the op.
// The condition is enforced once; the earlier duplicate of this check
// carried a stale message that referred to "Input(X) of AllReduce op".
void InferShape(framework::InferShapeContext *ctx) const override {
  PADDLE_ENFORCE(ctx->HasOutput("Communicator"),
                 "Output(Communicator) of ncclInit op should not be NULL");
}
};
......@@ -123,7 +124,7 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output of AllReduce op");
AddAttr<std::string>("reduction",
"{'ncclmin', 'ncclmax', 'ncclprod', 'ncclsum'}.");
AddAttr<std::vector<int>>("gpus", "gpu id lists");
// AddAttr<std::vector<int>>("gpus", "gpu id lists");
AddComment(R"DOC(
AllReduce the input tensors.
)DOC");
......
......@@ -39,7 +39,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
int idx = comm->GetCommId(device_id);
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE(ncclAllReduce(
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type, ncclSum,
comm->comms_[idx], stream));
......@@ -76,9 +76,9 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
if (root == device_id) {
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
}
PADDLE_ENFORCE(ncclReduce(ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
NCCLTypeWrapper<T>::type, ncclSum, root,
comm->comms_[idx], stream));
PADDLE_ENFORCE(platform::dynload::ncclReduce(
ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
NCCLTypeWrapper<T>::type, ncclSum, root, comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
}
}
......@@ -105,17 +105,17 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
if (idx == root) {
auto ins = ctx.MultiInput<Tensor>("X");
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE(ncclBcast((void*)ins[i]->data<T>(), ins[i]->numel(),
NCCLTypeWrapper<T>::type, root,
comm->comms_[idx], stream));
PADDLE_ENFORCE(platform::dynload::ncclBcast(
(void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
root, comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
}
} else {
auto outs = ctx.MultiOutput<Tensor>("Out");
for (size_t i = 0; i < outs.size(); ++i) {
PADDLE_ENFORCE(ncclBcast(outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel(), NCCLTypeWrapper<T>::type,
root, comm->comms_[idx], stream));
PADDLE_ENFORCE(platform::dynload::ncclBcast(
outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
}
}
......
import unittest, os
import numpy as np
import paddle.v2 as paddle
from paddle.v2.framework.op import Operator
import paddle.v2.framework.core as core
from op_test import OpTest, create_op, set_input
# gpu_list = os.environ["NV_LIST"]
# Comma-separated device ids used by every test case in this module.
gpu_list = "0,1,2,3"

# Leave quietly unless the build reports GPU support and a device list is
# configured.
if not (core.is_compile_gpu() and gpu_list):
    exit(0)

# Module-level scope and CPU device context shared by the test cases.
g_scope = core.Scope()
g_ctx = core.DeviceContext.create(core.CPUPlace())
class TestNCCLInit(OpTest):
    """Checks the ncclInit operator through the OpTest harness."""

    def setUp(self):
        # Operator under test and the device ids passed as its "gpus" attr.
        self.op_type = "ncclInit"
        self.gpus = list(map(int, gpu_list.split(",")))
        self.attrs = {"gpus": self.gpus}

        # The "Communicator" output is looked up through the object returned
        # by the module-level scope.
        self.scope = g_scope.var("Communicator")
        communicator = self.scope.var("Communicator")
        self.outputs = {"Communicator": communicator}

    def test_check_output(self):
        self.check_output()
class TestNCCLAllReduce(unittest.TestCase):
    """Runs one ncclAllReduce op per GPU and compares each result with a
    CPU-computed reference sum."""

    def setUp(self):
        def allreduce(tensors, gpus):
            """CPU reference: return one copy of the element-wise sum of
            `tensors` per device, leaving the input arrays untouched."""
            assert len(tensors) == len(gpus), "not match of tensor and device"
            # Accumulate into a fresh array. The inputs are fed to the
            # operators afterwards, so they must not be modified or aliased
            # by the expected outputs.
            total = tensors[0].copy()
            for tensor in tensors[1:]:
                total += tensor
            return [total.copy() for _ in tensors]

        self.op_type = "ncclAllReduce"
        self.gpus = [int(g) for g in gpu_list.split(",")]

        self.g_scope = core.Scope()
        self.g_ctx = core.DeviceContext.create(core.CPUPlace())
        self.scopes = []
        self.ops = []
        self.places = []

        # One random (32, 32) input per device plus the expected reduction.
        self.input_data = [np.random.random((32, 32)) for _ in self.gpus]
        self.output_data = allreduce(self.input_data, self.gpus)

        # ncclInit's InferShape checks for an output named "Communicator",
        # so the operator is created with that output name.
        nccl_init = Operator(
            "ncclInit", Communicator="Communicator", gpus=self.gpus)
        nccl_init.run(self.g_scope, self.g_ctx)

        for i in range(len(self.gpus)):
            # insert kid scope
            scope = self.g_scope.new_scope()
            place = core.GPUPlace(self.gpus[i])
            inputs = {
                "X": self.input_data[i],
                "Communicator": scope.find_var("Communicator")
            }
            outputs = {"Out": self.output_data[i]}
            # The "gpus" attribute is disabled on the all-reduce op, so no
            # attributes are passed.
            # NOTE(review): the op also declares a "reduction" attribute --
            # confirm whether it has to be supplied here.
            attrs = {}

            op = create_op(scope, self.op_type, inputs, outputs, attrs)
            set_input(scope, op, inputs, place)

            self.scopes.append(scope)
            self.ops.append(op)
            self.places.append(place)

    def test_output(self):
        per_device = zip(self.scopes, self.places, self.ops)
        for idx, (scope, place, op) in enumerate(per_device):
            ctx = core.DeviceContext.create(place)
            op.run(scope, ctx)

            for out_name, out_dup in Operator.get_op_outputs(op.type()):
                actual = np.array(scope.find_var(out_name).get_tensor())
                expect = self.output_data[idx]
                # Compare values element-wise; assertTrue(actual, expect)
                # would treat `expect` as the failure message instead.
                self.assertTrue(np.allclose(actual, expect), "has diff")
# if __name__ == "__main__":
# unittest.main()
# usage : export NV_LIST=0,1,2,3 python *.py
# os.environ["NV_LIST"] = ["0,1,2,3"]
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册