From ef257e6d96e5b99710a9d63e11a6642163f4e018 Mon Sep 17 00:00:00 2001 From: Dong Zhihong Date: Tue, 24 Oct 2017 11:11:22 -0700 Subject: [PATCH] write nccl c++ test case --- paddle/operators/CMakeLists.txt | 4 + paddle/operators/nccl/CMakeLists.txt | 1 - paddle/operators/nccl/nccl_gpu_common.h | 2 - paddle/operators/nccl/nccl_gpu_common_test.cc | 33 ----- paddle/operators/nccl_op.cc | 27 ++-- paddle/operators/nccl_op.cu | 1 - paddle/operators/nccl_op.h | 4 +- paddle/operators/nccl_op_test.cc | 71 ++++++++++ paddle/operators/nccl_op_test.cu | 71 ++++++++++ paddle/pybind/pybind.cc | 13 +- .../v2/framework/tests/test_multigpu.py | 8 ++ .../framework/tests/test_nccl_allreduce_op.py | 122 +++++++++--------- .../v2/framework/tests/test_nccl_init_op.py | 36 ++++++ .../v2/framework/tests/test_nccl_reduce_op.py | 19 +++ 14 files changed, 298 insertions(+), 114 deletions(-) delete mode 100644 paddle/operators/nccl/nccl_gpu_common_test.cc create mode 100644 paddle/operators/nccl_op_test.cc create mode 100644 paddle/operators/nccl_op_test.cu create mode 100644 python/paddle/v2/framework/tests/test_multigpu.py create mode 100644 python/paddle/v2/framework/tests/test_nccl_init_op.py diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 5da637dd7d..0f2122b4b0 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -154,3 +154,7 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array) + +if(WITH_GPU) + nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context) +endif() diff --git a/paddle/operators/nccl/CMakeLists.txt b/paddle/operators/nccl/CMakeLists.txt index 21cc1d9ee9..ce0ddd89bf 100644 --- a/paddle/operators/nccl/CMakeLists.txt +++ b/paddle/operators/nccl/CMakeLists.txt @@ -1,4 +1,3 @@ if(WITH_GPU) nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator ) - nv_test(nccl_gpu_common_test SRCS nccl_gpu_common_test.cc DEPS nccl_common) endif() diff --git a/paddle/operators/nccl/nccl_gpu_common.h b/paddle/operators/nccl/nccl_gpu_common.h index 648693508d..f492f96aa8 100644 --- a/paddle/operators/nccl/nccl_gpu_common.h +++ b/paddle/operators/nccl/nccl_gpu_common.h @@ -53,7 +53,5 @@ struct Communicator { // DISABLE_COPY_AND_ASSIGN(Communicator); }; -Communicator* NewCommunicator(const std::vector& gpus); - } // namespace platform } // namespace paddle diff --git a/paddle/operators/nccl/nccl_gpu_common_test.cc b/paddle/operators/nccl/nccl_gpu_common_test.cc deleted file mode 100644 index 6f6a4ac886..0000000000 --- a/paddle/operators/nccl/nccl_gpu_common_test.cc +++ /dev/null @@ -1,33 +0,0 @@ -#include "paddle/operators/nccl/nccl_gpu_common.h" - -#include - -#include -#include -#include - -namespace paddle { -namespace platform { - -TEST(WaitGroup, wait) { - WaitGroup wg; - auto run_thread = [&wg](int idx) { - wg.Add(1); - std::this_thread::sleep_for(std::chrono::seconds(1)); - wg.Done(); - }; - - std::vector ths; - constexpr const int TNUM = 5; - for (int i = 0; i < TNUM; ++i) { - ths.emplace_back(std::thread(run_thread, i)); - } - wg.Wait(); - - for (int i = 0; i < TNUM; ++i) { - ths[i].join(); - } -} - -} // namespace platform -} // namespace paddle diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc index ee6ed0ae85..6213f23613 100644 --- a/paddle/operators/nccl_op.cc +++ b/paddle/operators/nccl_op.cc @@ -21,9 +21,14 @@ class NCCLInitOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE( - ctx->HasOutput("Communicator"), - " Output(Communicator) of ncclInit op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasOutput("Communicator"), + " Output(Communicator) of ncclInitOp should not be NULL"); + } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return static_cast(ctx.Attr("data_type")); } }; @@ -32,9 +37,11 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { NCCLInitOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr>("gpus", "gpu id lists"); AddOutput("Communicator", "Create Communicator for communicating between gpus"); + AddAttr>("gpus", "gpu id lists"); + AddAttr("data_type", "output data type") + .SetDefault(framework::DataType::FP32); AddComment(R"DOC( create communicator. )DOC"); @@ -58,10 +65,10 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputsDim("X"); - std::string reduction = ctx->Attrs().Get("reduction"); - PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" || - reduction == "ncclMin" || reduction == "ncclMax"), - "invalid reduction."); + // std::string reduction = ctx->Attrs().Get("reduction"); + // PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" || + // reduction == "ncclMin" || reduction == "ncclMax"), + // "invalid reduction."); ctx->SetOutputsDim("Out", x_dims); ctx->ShareLoD("X", /*->*/ "Out"); @@ -122,8 +129,8 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The input of AllReduce op"); AddInput("Communicator", "Communicator for communicating between gpus"); AddOutput("Out", "The output of AllReduce op"); - AddAttr("reduction", - "{'ncclmin', 'ncclmax', 'ncclprod', 'ncclsum'}."); + // AddAttr("reduction", + // "{'ncclmin', 'ncclmax', 'ncclprod', 'ncclsum'}."); // AddAttr>("gpus", "gpu id lists"); AddComment(R"DOC( AllReduce the input tensors. diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu index ee19a69afc..00a115feeb 100644 --- a/paddle/operators/nccl_op.cu +++ b/paddle/operators/nccl_op.cu @@ -26,7 +26,6 @@ class NCCLAllReduceKernel : public framework::OpKernel { auto ins = ctx.MultiInput("X"); auto outs = ctx.MultiOutput("Out"); - std::string reduction = ctx.Attr("reduction"); auto* comm = ctx.Input("Communicator"); diff --git a/paddle/operators/nccl_op.h b/paddle/operators/nccl_op.h index 09606c4acd..a438e4eaa2 100644 --- a/paddle/operators/nccl_op.h +++ b/paddle/operators/nccl_op.h @@ -40,9 +40,9 @@ template class NCCLInitKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* gpus = ctx.Input>("gpus"); + std::vector gpus = ctx.Attr>("gpus"); auto* comm = ctx.Output("Communicator"); - comm->InitAll(*gpus); + comm->InitAll(gpus); } }; diff --git a/paddle/operators/nccl_op_test.cc b/paddle/operators/nccl_op_test.cc new file mode 100644 index 0000000000..9c319a3387 --- /dev/null +++ b/paddle/operators/nccl_op_test.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#include "paddle/operators/nccl_op.h" + +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/gpu_info.h" + +#include +#include +#include + +static std::vector gpu_list; + +using f = paddle::framework; +using ops = paddle::operators; + +void AddOp(const std::string &type, const f::VariableNameMap &inputs, + const f::VariableNameMap &outputs, f::AttributeMap attrs, + paddle::framework::BlockDescBind *block) { + for (auto kv : outputs) { + for (auto v : kv.second) { + auto var = block->Var(v); + var->SetDataType(paddle::framework::DataType::FP32); + } + } + + auto op = block->AppendOp(); + op->SetType(type); + for (auto &kv : inputs) { + op->SetInput(kv.first, kv.second); + } + for (auto &kv : outputs) { + op->SetOutput(kv.first, kv.second); + } + op->SetAttrMap(attrs); +} + +TEST(NCCL, ncclInitOp) { + f::ProgramDescBind program; + f::BlockDescBind *block = program.Block(0); +} + +int main(int argc, char **argv) { + static constexpr int gpu_count = paddle::platform::GetCUDADeviceCount(); + for (int i = 0; i < gpu_count; ++i) { + gpu_list.emplace_back(i); + } + if (dev_count <= 1) { + LOG(WARNING) + << "Cannot test multi-gpu nccl, because the CUDA device count is " + << dev_count; + return 0; + } + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/paddle/operators/nccl_op_test.cu b/paddle/operators/nccl_op_test.cu new file mode 100644 index 0000000000..9c319a3387 --- /dev/null +++ b/paddle/operators/nccl_op_test.cu @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#include "paddle/operators/nccl_op.h" + +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/gpu_info.h" + +#include +#include +#include + +static std::vector gpu_list; + +using f = paddle::framework; +using ops = paddle::operators; + +void AddOp(const std::string &type, const f::VariableNameMap &inputs, + const f::VariableNameMap &outputs, f::AttributeMap attrs, + paddle::framework::BlockDescBind *block) { + for (auto kv : outputs) { + for (auto v : kv.second) { + auto var = block->Var(v); + var->SetDataType(paddle::framework::DataType::FP32); + } + } + + auto op = block->AppendOp(); + op->SetType(type); + for (auto &kv : inputs) { + op->SetInput(kv.first, kv.second); + } + for (auto &kv : outputs) { + op->SetOutput(kv.first, kv.second); + } + op->SetAttrMap(attrs); +} + +TEST(NCCL, ncclInitOp) { + f::ProgramDescBind program; + f::BlockDescBind *block = program.Block(0); +} + +int main(int argc, char **argv) { + static constexpr int gpu_count = paddle::platform::GetCUDADeviceCount(); + for (int i = 0; i < gpu_count; ++i) { + gpu_list.emplace_back(i); + } + if (dev_count <= 1) { + LOG(WARNING) + << "Cannot test multi-gpu nccl, because the CUDA device count is " + << dev_count; + return 0; + } + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index b6e44fdbad..e1e382b2bb 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/framework/tensor_array.h" #include "paddle/operators/cond_op.h" #include "paddle/operators/dynamic_recurrent_op.h" +#include "paddle/operators/nccl/nccl_gpu_common.h" #include "paddle/operators/net_op.h" #include "paddle/operators/recurrent_op.h" #include "paddle/platform/enforce.h" @@ -203,6 +204,13 @@ All parameter, weight, gradient are variables in Paddle. return self.GetMutable(); }, py::return_value_policy::reference) +#ifdef PADDLE_WITH_CUDA + .def("get_communicator", + [](Variable &self) -> platform::Communicator * { + return self.GetMutable(); + }, + py::return_value_policy::reference) +#endif .def("get_net", [](Variable &self) -> operators::NetOp * { return self.GetMutable(); @@ -258,8 +266,11 @@ All parameter, weight, gradient are variables in Paddle. return new paddle::platform::CUDADeviceContext(place); #endif }); - // clang-format on +// clang-format on +#ifdef PADDLE_WITH_CUDA + py::class_(m, "Communicator").def(py::init<>()); +#endif py::class_(m, "GPUPlace") .def(py::init()) .def("__str__", string::to_string); diff --git a/python/paddle/v2/framework/tests/test_multigpu.py b/python/paddle/v2/framework/tests/test_multigpu.py new file mode 100644 index 0000000000..b75d274d88 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_multigpu.py @@ -0,0 +1,8 @@ +import unittest, os +import numpy as np +import paddle.v2 as paddle +from paddle.v2.framework.op import Operator +import paddle.v2.framework.core as core +from op_test import OpTest, create_op, set_input + +gpu_list = "0,1,2,3" diff --git a/python/paddle/v2/framework/tests/test_nccl_allreduce_op.py b/python/paddle/v2/framework/tests/test_nccl_allreduce_op.py index 0e6927a24d..06e079eda8 100644 --- a/python/paddle/v2/framework/tests/test_nccl_allreduce_op.py +++ b/python/paddle/v2/framework/tests/test_nccl_allreduce_op.py @@ -1,4 +1,5 @@ import unittest, os +from threading import Thread import numpy as np import paddle.v2 as paddle from paddle.v2.framework.op import Operator @@ -13,94 +14,87 @@ if not core.is_compile_gpu() or not gpu_list: g_scope = core.Scope() g_ctx = core.DeviceContext.create(core.CPUPlace()) +gpus = [int(g) for g in gpu_list.split(",")] -class TestNCCLInit(OpTest): - def setUp(self): - self.op_type = "ncclInit" - self.gpus = [int(g) for g in gpu_list.split(",")] - - self.attrs = {"gpus": self.gpus} - self.scope = g_scope.var("Communicator") - self.outputs = {"Communicator": self.scope.var("Communicator")} +# ground truth +def allreduce(tensors, gpus): + num_device = len(gpus) + assert (len(tensors) == num_device), "not match of tensor and device" + Out = tensors + for i in range(1, len(tensors)): + Out[0] += Out[i] - def test_check_output(self): - self.check_output() + for i in range(1, len(tensors)): + Out[i] = Out[0] + return Out -class TestNCCLAllReduce(unittest.TestCase): - def setUp(self): - # cpu allreduce for check - def allreduce(tensors, gpus): - num_device = len(gpus) - assert ( - len(tensors) == num_device), "not match of tensor and device" - Out = tensors - for i in range(1, len(tensors)): - Out[0] += Out[i] - for i in range(1, len(tensors)): - Out[i] = Out[0] - - return Out - - self.op_type = "ncclAllReduce" +input_data = [ + np.random.random((32, 32)).astype("float32") for i in range(len(gpus)) +] +output_data = allreduce(input_data, gpus) - self.gpus = [int(g) for g in gpu_list.split(",")] +# output_vars = [g_scope.var("Out_"+str(i)).get_tensor() +# for i in range(len(gpus))] - self.g_scope = core.Scope() - self.g_ctx = core.DeviceContext.create(core.CPUPlace()) - self.scopes = [] - self.ops = [] - self.places = [] - self.input_data = [] +def thread_allreduce_op(thread_id, gpu_id): + i = gpu_id + scope = g_scope.new_scope() + place = core.GPUPlace(gpus[i]) + inputs = { + "X": input_data[i], + "Communicator": scope.find_var("Communicator") + } + outputs = {"Out": output_data[i]} - for i in range(len(self.gpus)): - self.input_data.append(np.random.random((32, 32))) - self.output_data = allreduce(self.input_data, self.gpus) + op = create_op(scope, "ncclAllReduce", inputs, outputs, attrs={}) + place = core.GPUPlace(gpus[i]) + set_input(scope, op, inputs, place) - nccl_init = Operator("ncclInit", Out="Communicator", gpus=self.gpus) - nccl_init.run(self.g_scope, self.g_ctx) + ctx = core.DeviceContext.create(place) - for i in range(len(self.gpus)): - # insert kid scope - scope = self.g_scope.new_scope() - place = core.GPUPlace(self.gpus[i]) + print "thread_id : ", thread_id, "gpu_id : ", gpu_id, " invoke allreduce" + op.run(scope, ctx) + print "thread_id : ", thread_id, "gpu_id : ", gpu_id, " allreduce Done." - inputs = { - "X": self.input_data[i], - "Communicator": scope.find_var("Communicator") - } - outputs = {"Out": self.output_data[i]} - # attrs = {"gpus": self.gpus} - op = create_op(scope, self.op_type, inputs, outputs, attrs) - set_input(scope, op, inputs, place) +class TestNCCLAllReduce(unittest.TestCase): + def setUp(self): + self.op_type = "ncclAllReduce" - self.scopes.append(scope) - self.ops.append(op) - self.places.append(place) + nccl_init = create_op( + g_scope, + op_type="ncclInit", + inputs={}, + outputs={ + "Communicator": g_scope.var("Communicator").get_communicator() + }, + attrs={"gpus": gpus}) + nccl_init.run(g_scope, g_ctx) def test_output(self): - idx = 0 - for scope, place, op in zip(self.scopes, self.places, self.ops): - ctx = core.DeviceContext.create(place) - op.run(scope, ctx) + ops = [] + for i in range(len(gpus)): + th = Thread( + target=thread_allreduce_op, args=( + i, + gpus[i], )) + th.start() + ops.append(ops) + for th in ops: + th.join() + idx = 0 for out_name, out_dup in Operator.get_op_outputs(self.op.type()): actual = np.array(scope.find_var(out_name).get_tensor()) - expect = self.output_data[idx] + expect = output_data[idx] idx += 1 self.assertTrue(actual, expect), "has diff" -# if __name__ == "__main__": -# unittest.main() -# usage : export NV_LIST=0,1,2,3 python *.py - -# os.environ["NV_LIST"] = ["0,1,2,3"] - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/v2/framework/tests/test_nccl_init_op.py b/python/paddle/v2/framework/tests/test_nccl_init_op.py new file mode 100644 index 0000000000..8aed14c15d --- /dev/null +++ b/python/paddle/v2/framework/tests/test_nccl_init_op.py @@ -0,0 +1,36 @@ +import unittest, os +import numpy as np +import paddle.v2 as paddle +from paddle.v2.framework.op import Operator +import paddle.v2.framework.core as core +from op_test import OpTest, create_op, set_input + +gpu_list = "0,1,2,3" + +if not core.is_compile_gpu() or not gpu_list: + exit(0) + +g_scope = core.Scope() +g_ctx = core.DeviceContext.create(core.CPUPlace()) + + +class TestNCCLInit(unittest.TestCase): + def test_init(self): + self.op_type = "ncclInit" + self.gpus = [int(g) for g in gpu_list.split(",")] + + self.inputs = {} + self.attrs = {"gpus": self.gpus} + g_scope.var("Communicator").get_communicator() + self.outputs = {"Communicator": g_scope.find_var("Communicator")} + nccl_init = create_op( + g_scope, + op_type=self.op_type, + inputs=self.inputs, + outputs=self.outputs, + attrs=self.attrs) + nccl_init.run(g_scope, g_ctx) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_nccl_reduce_op.py b/python/paddle/v2/framework/tests/test_nccl_reduce_op.py index 675ad5766c..0cee1923a6 100644 --- a/python/paddle/v2/framework/tests/test_nccl_reduce_op.py +++ b/python/paddle/v2/framework/tests/test_nccl_reduce_op.py @@ -4,3 +4,22 @@ import paddle.v2 as paddle from paddle.v2.framework.op import Operator import paddle.v2.framework.core as core from op_test import OpTest, create_op, set_input + +gpu_list = "0,1,2,3" +g_scope = core.Scope() +g_ctx = core.DeviceContext.create(core.CPUPlace()) + +if not core.is_compile_gpu() or not gpu_list: + exit(0) + + +class TestNCCLReduce(OpTest): + def setUp(self): + self.op_type = "ncclReduce" + self.gpus = [int(g) for g in gpu_list.split(",")] + + self.scope = g_scope.var("Communicator").get_communicator() + self.outputs = {"Communicator": self.scope.var("Communicator")} + + def test_check_output(self): + self.check_output() -- GitLab