From 01c6618de904e1d49660486cd65f8810cc9665a3 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Sun, 8 Apr 2018 09:38:26 +0800 Subject: [PATCH] first wip commit --- .../fluid/framework/details/send_op_handle.cc | 78 +++++++++++++++++++ .../fluid/framework/details/send_op_handle.h | 50 ++++++++++++ paddle/fluid/operators/detail/grpc_client.cc | 3 +- 3 files changed, 129 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/framework/details/send_op_handle.cc create mode 100644 paddle/fluid/framework/details/send_op_handle.h diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc new file mode 100644 index 00000000000..bd2a0a9c298 --- /dev/null +++ b/paddle/fluid/framework/details/send_op_handle.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/send_op_handle.h" + +namespace paddle { +namespace framework { +namespace details { + +SendOpHandle::SendOpHandle(const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap &ctxs) + : local_scopes_(local_scopes), places_(places) {} + +void SendOpHandle::RunImpl() { + if (inputs_.size() == 1) { + return; // No need to all reduce when GPU count = 1; + } else { + // Wait input done + for (auto *in : inputs_) { + auto &p = static_cast(in)->place_; + in->generated_op_->Wait(dev_ctxes_[p]); + } + + auto &var_name = static_cast(this->inputs_[0])->name_; + int dtype = -1; + size_t numel = 0; + + std::vector> all_reduce_calls; + + for (size_t i = 0; i < local_scopes_.size(); ++i) { + auto &p = places_[i]; + auto *s = local_scopes_[i]; + int dev_id = boost::get(p).device; + + auto &lod_tensor = s->FindVar(var_name)->Get(); + void *buffer = const_cast(lod_tensor.data()); + + if (dtype == -1) { + dtype = platform::ToNCCLDataType(lod_tensor.type()); + } + + if (numel == 0) { + numel = static_cast(lod_tensor.numel()); + } + + auto &nccl_ctx = nccl_ctxs_.at(dev_id); + auto stream = nccl_ctx.stream(); + auto comm = nccl_ctx.comm_; + all_reduce_calls.emplace_back([=] { + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + buffer, buffer, numel, static_cast(dtype), ncclSum, + comm, stream)); + }); + } + + platform::NCCLGroupGuard guard; + for (auto &call : all_reduce_calls) { + call(); + } + } +} + +std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; } +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h new file mode 100644 index 00000000000..515f1a10a8d --- /dev/null +++ b/paddle/fluid/framework/details/send_op_handle.h @@ -0,0 +1,50 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/nccl_helper.h" + +namespace paddle { +namespace framework { +namespace details { + +struct SendOpHandle : public OpHandleBase { + const std::vector &local_scopes_; + const std::vector &places_; + const platform::NCCLContextMap &nccl_ctxs_; + + SendOpHandle(const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap &ctxs); + + std::string Name() const override; + + // Delay and buffer nccl_all_reduce together can significantly increase + // performance. Disable this feature by returning false. + bool IsMultiDeviceTransfer() override { return true; }; + + protected: + void RunImpl() override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ef987d07f08..3cf286575e4 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -65,9 +65,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, } void ProcGetResponse(const VarHandle& var_h, - // const sendrecv::VariableMessage& ret_msg) { const ::grpc::ByteBuffer& ret_msg) { - framework::Variable* outvar = NULL; + framework::Variable* outvar = nullptr; DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar); } -- GitLab