From 5200c657a7899bde418afecf90f0536c1702e089 Mon Sep 17 00:00:00 2001 From: Dong Zhihong Date: Wed, 25 Oct 2017 09:05:03 -0700 Subject: [PATCH] "move Tensor to LoDTensor" --- paddle/operators/nccl_op.cc | 7 + paddle/operators/nccl_op.cu | 20 ++- paddle/operators/nccl_op.h | 50 -------- paddle/operators/nccl_op_test.cu | 214 +++++++++++++++++++++++-------- 4 files changed, 186 insertions(+), 105 deletions(-) delete mode 100644 paddle/operators/nccl_op.h diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc index ec7a89d5ff4..85f589f4aae 100644 --- a/paddle/operators/nccl_op.cc +++ b/paddle/operators/nccl_op.cc @@ -74,8 +74,15 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel { // reduction == "ncclMin" || reduction == "ncclMax"), // "invalid reduction."); + // auto in_dim = x_dims[0]; ctx->SetOutputsDim("Out", x_dims); ctx->ShareLoD("X", /*->*/ "Out"); + size_t N = x_dims.size(); + auto out_dims = ctx->GetOutputsDim("Out"); + for (size_t i = 0; i < N; ++i) { + VLOG(1) << " inference (X) " << framework::product(x_dims[i]) << " (Out)" + << framework::product(out_dims[i]); + } } }; diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu index 4fbdf1ce02d..c507d325f28 100644 --- a/paddle/operators/nccl_op.cu +++ b/paddle/operators/nccl_op.cu @@ -12,6 +12,7 @@ limitations under the License. */ #define EIGEN_USE_GPU #include +#include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/nccl/nccl_gpu_common.h" @@ -20,6 +21,7 @@ namespace operators { using framework::Tensor; using platform::Communicator; +using framework::LoDTensor; template class NCCLTypeWrapper; @@ -43,8 +45,8 @@ class NCCLAllReduceKernel : public framework::OpKernel { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - auto ins = ctx.MultiInput("X"); - auto outs = ctx.MultiOutput("Out"); + auto ins = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput("Out"); auto* comm = ctx.Input("Communicator"); @@ -56,12 +58,24 @@ class NCCLAllReduceKernel : public framework::OpKernel { boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(device_id); + size_t N = ins.size(); + for (size_t i = 0; i < N; ++i) { + VLOG(1) << " inference (X) " << framework::product(ins[i]->dims()) + << " (Out)" << framework::product(outs[i]->dims()); + } + for (size_t i = 0; i < ins.size(); ++i) { + VLOG(1) << " invoke allreduce. send " << ins[i]->numel() << " recv " + << outs[i]->numel(); + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( ins[i]->data(), outs[i]->mutable_data(ctx.GetPlace()), - outs[i]->numel() * sizeof(T), NCCLTypeWrapper::type, ncclSum, + outs[i]->numel(), NCCLTypeWrapper::type, ncclSum, comm->comms_[idx], stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + + VLOG(1) << " finished allreduce. send " << ins[i]->numel() << " recv " + << outs[i]->numel(); } } }; diff --git a/paddle/operators/nccl_op.h b/paddle/operators/nccl_op.h deleted file mode 100644 index a438e4eaa23..00000000000 --- a/paddle/operators/nccl_op.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#pragma once -#include "paddle/framework/op_registry.h" -#include "paddle/operators/nccl/nccl_gpu_common.h" - -#include - -namespace paddle { -namespace operators { - -using framework::Tensor; -using platform::Communicator; - -template -class NCCLTypeWrapper; - -template <> -class NCCLTypeWrapper { - public: - static const ncclDataType_t type = ncclFloat; -}; - -template <> -class NCCLTypeWrapper { - public: - static const ncclDataType_t type = ncclDouble; -}; - -template -class NCCLInitKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - std::vector gpus = ctx.Attr>("gpus"); - auto* comm = ctx.Output("Communicator"); - comm->InitAll(gpus); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/nccl_op_test.cu b/paddle/operators/nccl_op_test.cu index 334884d657a..0509e6ddabf 100644 --- a/paddle/operators/nccl_op_test.cu +++ b/paddle/operators/nccl_op_test.cu @@ -12,101 +12,211 @@ See the License for the specific language governing permissions and limitations under the License. */ +#define EIGEN_USE_GPU + #include #include #include -#include +#include +#include +#include #include #include "paddle/framework/block_desc.h" #include "paddle/framework/op_desc.h" -#include "paddle/framework/op_registry.h" #include "paddle/framework/program_desc.h" #include "paddle/framework/var_desc.h" +#include "paddle/operators/nccl/nccl_gpu_common.h" #include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" #include "paddle/platform/gpu_info.h" #include "paddle/platform/place.h" -USE_CPU_ONLY_OP(ncclInit); +#include "paddle/framework/op_registry.h" + +USE_NO_KERNEL_OP(ncclInit); USE_GPU_ONLY_OP(ncclAllReduce); USE_GPU_ONLY_OP(ncclReduce); USE_GPU_ONLY_OP(ncclBcastSend); USE_GPU_ONLY_OP(ncclBcastRecv); +namespace f = paddle::framework; +namespace p = paddle::platform; + static std::vector gpu_list; -namespace f = paddle::framework; -namespace ops = paddle::operators; - -void AddOp(const std::string &type, const f::VariableNameMap &inputs, - const f::VariableNameMap &outputs, f::AttributeMap attrs, - paddle::framework::BlockDescBind *block) { - for (auto kv : outputs) { - for (auto v : kv.second) { - auto var = block->Var(v); - var->SetDataType(paddle::framework::DataType::FP32); - } +// ncclInitOp with desc +// TEST(NCCL, ncclInitOp) { +// f::ProgramDescBind program; +// f::BlockDescBind *block = program.Block(0); +// f::OpDescBind *op_desc = block->AppendOp(); + +// op_desc->SetType("ncclInit"); +// op_desc->SetOutput("Communicator", {"x1"}); +// op_desc->SetAttr("gpus", {gpu_list}); +// f::Scope g_scope; +// p::DeviceContext *ctx = +// new p::CPUDeviceContext(p::CPUPlace()); + +// auto *var = g_scope.Var("x1"); +// var->GetMutable(); + +// auto op = f::OpRegistry::CreateOp(*op_desc); +// VLOG(1) << "invoke NCCLInitOp."; +// op->Run(g_scope, *ctx); +// VLOG(1) << "NCCLInitOp finished."; +// } + +// test data amount +static const f::DDim kDims = {100, 100}; +static std::vector dev_ctxs; + +void CreateContext() { + for (size_t i = 0; i < gpu_list.size(); ++i) { + p::GPUPlace place(i); + VLOG(1) << "create devicecontext : " << i; + dev_ctxs.emplace_back(new p::CUDADeviceContext(place)); } +} - auto op = block->AppendOp(); - op->SetType(type); - for (auto &kv : inputs) { - op->SetInput(kv.first, kv.second); - } - for (auto &kv : outputs) { - op->SetOutput(kv.first, kv.second); +void DestroyContext() { + for (size_t i = 0; i < gpu_list.size(); ++i) { + delete dev_ctxs[i]; } - op->SetAttrMap(attrs); } -// ncclInitOp with desc -TEST(NCCL, ncclInitOp) { +// global scope +static f::Scope g_scope; +std::mutex mu; + +template +void DeviceProgram(int gpu_id, const f::OpDescBind &op_desc, f::Scope *scope) { + std::unique_lock lk(mu); f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); - f::OpDescBind *op_desc = block->AppendOp(); - - op_desc->SetType("ncclInit"); - op_desc->SetOutput("Communicator", {"x1"}); - op_desc->SetAttr("gpus", {gpu_list}); - f::Scope g_scope; - paddle::platform::DeviceContext *ctx = - new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace()); - - auto *var = g_scope.Var("x1"); - var->GetMutable(); - - auto op = f::OpRegistry::CreateOp(*op_desc); - VLOG(1) << "invoke NCCLInitOp."; - op->Run(g_scope, *ctx); - VLOG(1) << "NCCLInitOp finished."; + f::OpDescBind *op1 = block->AppendOp(); + *op1 = op_desc; + + p::GPUPlace place(gpu_id); + // p::DeviceContext *ctx = + // new p::CUDADeviceContext(place); + p::DeviceContext *ctx = dev_ctxs.at(gpu_id); + VLOG(1) << "device context : " << dev_ctxs.size() << " gpu_id " << gpu_id; + + // f::Scope &local_scope = g_scope.NewScope(); + + auto *send_tensor = scope->Var("st")->GetMutable(); + auto *recv_tensor = scope->Var("rt")->GetMutable(); + send_tensor->Resize(kDims); + send_tensor->mutable_data(kDims, place); + // recv_tensor->mutable_data(kDims, place); + + std::vector send_vector(f::product(kDims), gpu_id); + send_tensor->CopyFromVector(send_vector, *ctx); + lk.unlock(); + PADDLE_ENFORCE(send_tensor->numel() == f::product(kDims), + "Tensor numel not match!"); + ctx->Wait(); + + VLOG(1) << send_tensor->numel() << " element in send tensor"; + + auto op = f::OpRegistry::CreateOp(*op1); + VLOG(1) << "Device : " << gpu_id << " invoke " << op_desc.Type(); + op->Run(*scope, *ctx); + VLOG(1) << "Device : " << gpu_id << " finished " << op_desc.Type(); } // ncclAllReduceOp with desc -TEST(NCCL, ncclInitOp) { +TEST(NCCL, ncclAllReduceOp) { f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); - f::OpDescBind *op_desc = block->AppendOp(); + f::OpDescBind *op1 = block->AppendOp(); - op_desc->SetType("ncclAllReduce"); + p::DeviceContext *ctx = new p::CPUDeviceContext(p::CPUPlace()); - op_desc->SetOutput("Communicator", {"x1"}); - op_desc->SetAttr("gpus", {gpu_list}); - f::Scope g_scope; - paddle::platform::DeviceContext *ctx = - new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace()); + CreateContext(); - auto *var = g_scope.Var("x1"); - var->GetMutable(); + op1->SetType("ncclInit"); + op1->SetOutput("Communicator", {"comm"}); + op1->SetAttr("gpus", {gpu_list}); - auto op = f::OpRegistry::CreateOp(*op_desc); + auto *var = g_scope.Var("comm"); + var->GetMutable(); + + auto op = f::OpRegistry::CreateOp(*op1); VLOG(1) << "invoke NCCLInitOp."; op->Run(g_scope, *ctx); VLOG(1) << "NCCLInitOp finished."; + delete ctx; + + f::OpDescBind *op2 = new f::OpDescBind; + op2->SetType("ncclAllReduce"); + op2->SetInput("X", {"st"}); + op2->SetInput("Communicator", {"comm"}); + op2->SetOutput("Out", {"rt"}); + + std::vector ths; + for (size_t i = 0; i < gpu_list.size(); ++i) { + std::thread th(DeviceProgram, gpu_list[i], *op2, + &g_scope.NewScope()); + // std::thread th([=](){ + // VLOG(1) << "thread id created : " << i; + // return 1;}); + ths.emplace_back(std::move(th)); + } + + for (size_t i = 0; i < gpu_list.size(); ++i) { + VLOG(1) << " thread joined! " << i; + ths[i].join(); + } + VLOG(1) << " main thread joined!"; + + delete op2; + g_scope.~Scope(); + DestroyContext(); + VLOG(1) << " destory contexts"; } +// ncclBcastOp with desc +// TEST(NCCL, ncclBcastOp) { +// f::ProgramDescBind program; +// f::BlockDescBind *block = program.Block(0); +// f::OpDescBind *op1= block->AppendOp(); + +// p::DeviceContext *ctx = +// new p::CPUDeviceContext(p::CPUPlace()); + +// op1->SetType("ncclInit"); +// op1->SetOutput("Communicator", {"comm"}); +// op1->SetAttr("gpus", {gpu_list}); + +// auto *var = g_scope.Var("comm"); +// var->GetMutable(); + +// auto op = f::OpRegistry::CreateOp(*op1); +// VLOG(1) << "invoke NCCLInitOp."; +// op->Run(g_scope, *ctx); +// VLOG(1) << "NCCLInitOp finished."; + +// f::OpDescBind *op2 = new f::OpDescBind; +// op2->SetType("ncclBcastSend"); +// op2->SetInput("X", {"st"}); +// op2->SetInput("Communicator", {"comm"}); +// op2->SetOutput("Out", {"rt"}); + +// std::vector ths; +// for (size_t i=0; i < gpu_list.size(); ++i) { +// std::thread th(DeviceProgram, gpu_list[i], *op2); +// ths.emplace_back(std::move(th)); +// } + +// for (size_t i=0; i < gpu_list.size(); ++i) { +// ths[i].join(); +// } +// } + int main(int argc, char **argv) { - static int dev_count = paddle::platform::GetCUDADeviceCount(); + const int dev_count = p::GetCUDADeviceCount(); if (dev_count <= 1) { LOG(WARNING) << "Cannot test multi-gpu nccl, because the CUDA device count is " -- GitLab