diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 6522a7a69f1652605e4acce49cfbc65fa1332ad5..9d2dc290282ecf341248b50efe6be447843ee040 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(details)
 # ddim lib
 proto_library(framework_proto SRCS framework.proto)
@@ -87,7 +88,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
 cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
 framework_proto backward glog lod_rank_table feed_fetch_method)
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
- framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool concat)
+ framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool var_handle)

 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5074715e2ef4dcc7e3da57d216cc3057e4a388bd
--- /dev/null
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -0,0 +1 @@
+cc_library(var_handle SRCS var_handle.cc DEPS place)
diff --git a/paddle/fluid/framework/details/var_handle.cc b/paddle/fluid/framework/details/var_handle.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6f00abd9473a84a77ed1a39015e2ae079e00be79
--- /dev/null
+++ b/paddle/fluid/framework/details/var_handle.cc
@@ -0,0 +1,32 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/var_handle.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+VarHandleBase::~VarHandleBase() {}
+
+std::string VarHandle::DebugString() const {
+  std::stringstream ss;
+  ss << name_ << ":" << place_;
+  return ss.str();
+}
+
+std::string DummyVarHandle::DebugString() const { return "dummy"; }
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h
new file mode 100644
index 0000000000000000000000000000000000000000..613ff901b151df058d88c81a0fa3afd31ebbdb10
--- /dev/null
+++ b/paddle/fluid/framework/details/var_handle.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <sstream>
+#include <string>
+#include <unordered_set>
+
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace framework {
+
+struct OpHandleBase;
+
+namespace details {
+
+// VarHandleBase is the var node in the dependency graph.
+// A variable can only be generated by a single operator, i.e.
+// this is a single-assignment graph.
+struct VarHandleBase {
+  virtual ~VarHandleBase();
+  virtual std::string DebugString() const = 0;
+
+  // The operator that generates this variable. nullptr if the variable
+  // is a root node.
+  OpHandleBase *generated_op_;
+
+  // Operators which depend on this variable being ready.
+  std::unordered_set<OpHandleBase *> pending_ops_;
+};
+
+// VarHandle is a single version of a runtime Variable.
+// A Variable at runtime maps to many VarHandles in the graph.
+// Each assignment generates a new var handle with a newer version.
+//
+// NOTE: runtime variables have a place.
+struct VarHandle : public VarHandleBase {
+  std::string DebugString() const override;
+
+  // The version field is currently unused; it is stored only to make
+  // debugging easier.
+  size_t version_;
+  std::string name_;
+  platform::Place place_;
+};
+
+// Dummy Variable. It is used to represent dependencies between operators.
+struct DummyVarHandle : public VarHandleBase {
+  std::string DebugString() const override;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index a5221d03d614096315f281102ee43f360035f426..2b094eba1e1a28828cf07dd6329b3d91a12e80f9 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "lod_tensor.h"
 #include "lod_tensor_array.h"
 #include "op_registry.h"
+#include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/platform/nccl_helper.h"
@@ -25,35 +26,11 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-struct OpHandle;
+using details::DummyVarHandle;
+using details::VarHandle;
+using details::VarHandleBase;

-struct VarHandleBase {
-  virtual ~VarHandleBase() {}
-  virtual std::string DebugString() const = 0;
-
-  OpHandle *generated_op_;
-  std::unordered_set<OpHandle *> pending_ops_;
-};
-
-struct VarHandle : public VarHandleBase {
-  std::string DebugString() const override {
-    std::stringstream ss;
-    ss << name_ << ":" << place_;
-    return ss.str();
-  }
-
-  // version field currently is not used, however, just store the version to
-  // debug easily.
- size_t version_; - std::string name_; - platform::Place place_; -}; - -struct DummyVarHandle : public VarHandleBase { - std::string DebugString() const override { return "dummy"; } -}; - -struct OpHandle { +struct OpHandleBase { std::vector inputs_; std::vector outputs_; std::unordered_map *local_scopes_; @@ -216,51 +193,13 @@ class ParallelExecutorPrivate { std::vector local_scopes_; Scope *global_scope_; -#ifdef PADDLE_WITH_CUDA - struct NCCLContext { - std::unique_ptr ctx_; - ncclComm_t comm; - - explicit NCCLContext(int dev_id) { - ctx_.reset(new platform::CUDADeviceContext(platform::CUDAPlace(dev_id))); - } - - cudaStream_t stream() const { return ctx_->stream(); } - - int device_id() const { - return boost::get(ctx_->GetPlace()).device; - } - - static void InitNCCLContext(std::unordered_map &contexts, - const std::vector &places) { - std::vector comms; - std::vector devs; - comms.resize(contexts.size()); - devs.reserve(contexts.size()); - - for (auto &p : places) { - devs.push_back(boost::get(p).device); - } - - PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( - &comms[0], static_cast(contexts.size()), &devs[0])); - - int i = 0; - for (auto &dev_id : devs) { - contexts.at(dev_id).comm = comms[i++]; - } - } - }; - - std::unordered_map communication_streams_; + std::unordered_map communication_streams_; - NCCLContext &GetNCCLCtx(platform::Place p) { + platform::NCCLContext &GetNCCLCtx(platform::Place p) { int dev_id = boost::get(p).device; return communication_streams_.at(dev_id); } -#endif - platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) { if (platform::is_cpu_place(place) || local_scopes_.size() == 1) { return const_cast( @@ -282,27 +221,95 @@ class ParallelExecutorPrivate { vars_; std::unordered_set> dep_vars_; - std::vector> ops_; + std::vector> ops_; // Use a simpler thread pool, might be faster. std::unique_ptr pool_; std::unique_ptr exception_; -}; -struct NCCLAllReduceOpHandle : public OpHandle { - ParallelExecutorPrivate *member_; + VarHandle *GetVarHandle(const std::string &each_var_name, + const platform::Place &place) { + auto &var_holders = vars_[place]; + auto &var_holder = var_holders[each_var_name]; + VarHandle *var = nullptr; + if (var_holder.empty()) { + auto &init_var = var_holder[0]; + init_var.place_ = place; + init_var.name_ = each_var_name; + init_var.generated_op_ = nullptr; + init_var.version_ = 0; + var = &init_var; + } else { + var = &var_holder.rbegin()->second; + } + return var; + } - explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member) - : member_(member) {} + void RunOp( + bool use_event, + std::unordered_map> &pending_vars, + OpHandleBase *op) { + std::vector *> *ready_buffer = + new std::vector *>(); + for (auto *var : op->outputs_) { + ready_buffer->emplace_back(&pending_vars[var]); + } + + auto op_run = [ready_buffer, op, this, use_event] { + try { + VLOG(10) << op->DebugString(); + op->Run(use_event); + for (auto *ready : *ready_buffer) { + ready->store(true, std::memory_order_release); + } + delete ready_buffer; + } catch (platform::EnforceNotMet ex) { + exception_.reset(new platform::EnforceNotMet(ex)); + } catch (...) 
{ + LOG(FATAL) << "Unknown exception catched"; + } + }; + if (pool_) { + pool_->enqueue(op_run); + } else { + op_run(); + } + } + + void GenerateVar(OpHandleBase *op_handle, const std::string &each_var_name, + const platform::Place &place) { + auto &vars = vars_[place][each_var_name]; + size_t version = vars.size(); + auto &var = vars[version]; + var.version_ = version; + var.generated_op_ = op_handle; + var.name_ = each_var_name; + var.place_ = place; + op_handle->outputs_.emplace_back(&var); + } +}; // namespace framework + +struct NCCLAllReduceOpHandle : public OpHandleBase { + const std::vector &local_scopes_; + const std::vector &places_; + const std::unordered_map &communication_ctxs_; + + explicit NCCLAllReduceOpHandle( + const std::vector &local_scopes, + const std::vector &places, + const std::unordered_map &ctxs) + : local_scopes_(local_scopes), + places_(places), + communication_ctxs_(ctxs) {} void Wait(platform::DeviceContext *waited_dev) override { - OpHandle::Wait(waited_dev); + OpHandleBase::Wait(waited_dev); } protected: void RunImpl() override { - if (this->inputs_.size() == 1) { + if (inputs_.size() == 1) { return; // No need to all reduce when GPU count = 1; } else { // Wait input done @@ -317,9 +324,9 @@ struct NCCLAllReduceOpHandle : public OpHandle { platform::NCCLGroupGuard guard; - for (size_t i = 0; i < member_->local_scopes_.size(); ++i) { - auto &p = member_->places_[i]; - auto *s = member_->local_scopes_[i]; + for (size_t i = 0; i < local_scopes_.size(); ++i) { + auto &p = places_[i]; + auto *s = local_scopes_[i]; int dev_id = boost::get(p).device; auto &lod_tensor = s->FindVar(var_name)->Get(); @@ -336,16 +343,16 @@ struct NCCLAllReduceOpHandle : public OpHandle { if (numel == 0) { numel = static_cast(lod_tensor.numel()); } - auto &nccl_ctx = member_->communication_streams_.at(dev_id); + auto &nccl_ctx = communication_ctxs_.at(dev_id); PADDLE_ENFORCE(platform::dynload::ncclAllReduce( buffer, buffer, numel, static_cast(dtype), ncclSum, - nccl_ctx.comm, nccl_ctx.stream())); + nccl_ctx.comm_, nccl_ctx.stream())); } } } }; -struct ComputationOpHandle : public OpHandle { +struct ComputationOpHandle : public OpHandleBase { std::unique_ptr op_; Scope *scope_; platform::Place place_; @@ -443,14 +450,14 @@ void ParallelExecutor::ConstructDependencyGraph( auto var_names = op->InputArgumentNames(); for (auto &each_var_name : var_names) { - VarHandle *var = GetVarHandle(each_var_name, p); + VarHandle *var = member_->GetVarHandle(each_var_name, p); op_handle->inputs_.emplace_back(var); var->pending_ops_.emplace(op_handle); } var_names = op->OutputArgumentNames(); for (auto &each_var_name : var_names) { - GenerateVar(op_handle, each_var_name, p); + member_->GenerateVar(op_handle, each_var_name, p); } if (is_forwarding) { @@ -468,7 +475,7 @@ void ParallelExecutor::ConstructDependencyGraph( // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); - GenerateVar(op_handle, loss_var_name + "@GRAD", p); + member_->GenerateVar(op_handle, loss_var_name + "@GRAD", p); change_forward = true; } } @@ -483,7 +490,9 @@ void ParallelExecutor::ConstructDependencyGraph( for (auto &og : var_names) { if (grads.count(og) != 0) { // is param grad // Insert NCCL AllReduce Op - member_->ops_.emplace_back(new NCCLAllReduceOpHandle(member_)); + member_->ops_.emplace_back(new NCCLAllReduceOpHandle( + member_->local_scopes_, member_->places_, + member_->communication_streams_)); auto *op_handle = member_->ops_.back().get(); for (size_t i = 0; i < 
member_->places_.size(); ++i) { @@ -562,37 +571,6 @@ void ParallelExecutor::PolishGraphToSupportDataHazards() const { } } -void ParallelExecutor::GenerateVar(OpHandle *op_handle, - const std::string &each_var_name, - const platform::Place &place) const { - auto &vars = member_->vars_[place][each_var_name]; - size_t version = vars.size(); - auto &var = vars[version]; - var.version_ = version; - var.generated_op_ = op_handle; - var.name_ = each_var_name; - var.place_ = place; - op_handle->outputs_.emplace_back(&var); -} - -VarHandle *ParallelExecutor::GetVarHandle(const std::string &each_var_name, - const platform::Place &place) const { - auto &var_holders = member_->vars_[place]; - auto &var_holder = var_holders[each_var_name]; - VarHandle *var = nullptr; - if (var_holder.empty()) { - auto &init_var = var_holder[0]; - init_var.place_ = place; - init_var.name_ = each_var_name; - init_var.generated_op_ = nullptr; - init_var.version_ = 0; - var = &init_var; - } else { - var = &var_holder.rbegin()->second; - } - return var; -} - void ParallelExecutor::BCastParamsToGPUs( const ProgramDesc &startup_program) const { #ifdef PADDLE_WITH_CUDA @@ -621,8 +599,8 @@ void ParallelExecutor::BCastParamsToGPUs( } auto &nccl_ctx = member_->GetNCCLCtx(place); - platform::dynload::ncclBcast(buffer, numel, data_type, 0, nccl_ctx.comm, - nccl_ctx.stream()); + platform::dynload::ncclBcast(buffer, numel, data_type, 0, + nccl_ctx.comm_, nccl_ctx.stream()); } } @@ -640,12 +618,12 @@ void ParallelExecutor::BuildNCCLCommunicator() const { for (auto &place : member_->places_) { int dev_id = boost::get(place).device; - member_->communication_streams_.emplace( - dev_id, ParallelExecutorPrivate::NCCLContext(dev_id)); + member_->communication_streams_.emplace(dev_id, + platform::NCCLContext(dev_id)); } - ParallelExecutorPrivate::NCCLContext::InitNCCLContext( - member_->communication_streams_, member_->places_); + platform::NCCLContext::InitNCCLContext(member_->communication_streams_, + member_->places_); #endif } @@ -656,7 +634,7 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, // Version --> VarHandle member_->exception_.reset(); std::unordered_map> pending_vars; - std::unordered_map pending_ops; + std::unordered_map pending_ops; std::vector dummy_vars; for (auto &place_pair : member_->vars_) { @@ -672,7 +650,7 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, pending_vars[var.get()] = var->generated_op_ == nullptr; } - std::vector to_run; + std::vector to_run; for (auto &op : member_->ops_) { if (op->inputs_.empty()) { // Special case, Op has no input. 
@@ -722,7 +700,7 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, } for (auto *op : to_run) { - RunOp(use_event, pending_vars, op); + member_->RunOp(use_event, pending_vars, op); } while (!pending_vars.empty()) { @@ -750,7 +728,7 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, } for (auto *op : to_run) { pending_ops.erase(op); - RunOp(use_event, pending_vars, op); + member_->RunOp(use_event, pending_vars, op); } } @@ -762,35 +740,5 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, fetched_data; } -void ParallelExecutor::RunOp( - bool use_event, - std::unordered_map> &pending_vars, - OpHandle *op) const { - std::vector *> *ready_buffer = - new std::vector *>(); - for (auto *var : op->outputs_) { - ready_buffer->emplace_back(&pending_vars[var]); - } - - auto op_run = [ready_buffer, op, this, use_event] { - try { - VLOG(10) << op->DebugString(); - op->Run(use_event); - for (auto *ready : *ready_buffer) { - ready->store(true, std::memory_order_release); - } - delete ready_buffer; - } catch (platform::EnforceNotMet ex) { - member_->exception_.reset(new platform::EnforceNotMet(ex)); - } catch (...) { - LOG(FATAL) << "Unknown exception catched"; - } - }; - if (member_->pool_) { - member_->pool_->enqueue(op_run); - } else { - op_run(); - } -} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index c206e726a71d1c8729ee65213f83118d4cde7d1a..466b5f5f62d4d8ccbebb64b578911c67dfcd0709 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -29,9 +29,6 @@ namespace paddle { namespace framework { class ParallelExecutorPrivate; -class VarHandle; -class OpHandle; -class VarHandleBase; class ParallelExecutor { public: @@ -50,23 +47,12 @@ class ParallelExecutor { void BCastParamsToGPUs(const ProgramDesc& startup_program) const; - VarHandle* GetVarHandle(const std::string& each_var_name, - const platform::Place& place) const; - - void GenerateVar(OpHandle* op_handle, const std::string& each_var_name, - const platform::Place& place) const; - void ConstructDependencyGraph(const std::unordered_set& params, const ProgramDesc& main_program, const std::string& loss_var_name) const; void BuildNCCLCommunicator() const; - void RunOp( - bool use_event, - std::unordered_map>& pending_vars, - OpHandle* op) const; - void PolishGraphToSupportDataHazards() const; }; diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index cceceda8ad83824eba6b50e061f4cd9e03d1b354..3db846b0247bd68edc8041e9912ab34492d871ed 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -47,11 +47,45 @@ class NCCLGroupGuard { } private: - static std::mutex& mutex() { + static std::mutex &mutex() { static std::mutex mtx; return mtx; } }; +struct NCCLContext { + std::unique_ptr ctx_; + ncclComm_t comm_; + + explicit NCCLContext(int dev_id) + : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))) {} + + cudaStream_t stream() const { return ctx_->stream(); } + + int device_id() const { + return boost::get(ctx_->GetPlace()).device; + } + + static void InitNCCLContext(std::unordered_map &contexts, + const std::vector &places) { + std::vector comms; + std::vector devs; + comms.resize(contexts.size()); + devs.reserve(contexts.size()); + + for (auto &p : places) { + devs.push_back(boost::get(p).device); + } + + PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( + &comms[0], static_cast(contexts.size()), 
&devs[0])); + + int i = 0; + for (auto &dev_id : devs) { + contexts.at(dev_id).comm_ = comms[i++]; + } + } +}; + } // namespace platform } // namespace paddle
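
Reviewer note: the new details::VarHandle header encodes a single-assignment (SSA-like) dependency graph. Each write to a runtime Variable on a given place produces a fresh handle whose version_ equals the number of handles that already exist for that name, while readers bind to the newest handle (and a never-written variable gets version 0 as a graph root). The snippet below is a minimal, self-contained sketch of that bookkeeping only; it is not the actual GetVarHandle/GenerateVar code, and the Handle/OpHandle types and the int device id standing in for platform::Place are illustrative assumptions.

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Illustrative stand-ins for the real OpHandleBase/VarHandle/platform::Place.
struct OpHandle;

struct Handle {
  size_t version;
  std::string name;
  int place;               // simplified stand-in for platform::Place
  OpHandle *generated_op;  // nullptr => root node of the graph
};

// All versions of one variable name on one place, oldest first.
using VarMap = std::map<std::string, std::vector<Handle>>;

// Reader side: bind to the newest version; a never-written variable gets
// version 0 with no generating op (a graph root).
Handle *GetLatest(VarMap &vars, const std::string &name, int place) {
  auto &versions = vars[name];
  if (versions.empty()) {
    versions.push_back(Handle{0, name, place, nullptr});
  }
  return &versions.back();
}

// Writer side: every assignment appends a handle with the next version.
Handle *AppendVersion(VarMap &vars, const std::string &name, int place,
                      OpHandle *op) {
  auto &versions = vars[name];
  versions.push_back(Handle{versions.size(), name, place, op});
  return &versions.back();
}

int main() {
  VarMap vars;
  GetLatest(vars, "w", 0);               // first read -> creates version 0
  AppendVersion(vars, "w", 0, nullptr);  // a write -> version 1
  assert(GetLatest(vars, "w", 0)->version == 1);  // readers see the latest
  return 0;
}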
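A second note, on the scheduling loop this patch moves into ParallelExecutorPrivate::RunOp and ParallelExecutor::Run: it is a dataflow scheduler in which every VarHandleBase owns an atomic<bool> "ready" flag, an op becomes runnable once all of its input flags are set, and completing an op stores true into the flags of its outputs with release ordering so a consumer's load observes finished producers. Below is a minimal single-threaded sketch of that loop under simplified, hypothetical Op/Var structs; it deliberately omits the thread pool, CUDA events, and exception plumbing of the real code.

#include <atomic>
#include <cstdio>
#include <unordered_map>
#include <vector>

// Hypothetical, simplified graph nodes (illustration only).
struct Var;
struct Op {
  const char *name;
  std::vector<Var *> inputs;
  std::vector<Var *> outputs;
};
struct Var {
  Op *generated_op = nullptr;  // nullptr => a root, ready from the start
};

int main() {
  // Tiny graph: root --opA--> a --opB--> b
  Var root, a, b;
  Op opA{"opA", {&root}, {&a}};
  Op opB{"opB", {&a}, {&b}};
  a.generated_op = &opA;
  b.generated_op = &opB;
  std::vector<Op *> pending_ops{&opA, &opB};

  // One ready flag per variable; roots start out ready.
  std::unordered_map<Var *, std::atomic<bool>> ready;
  for (Var *v : {&root, &a, &b}) {
    ready[v].store(v->generated_op == nullptr);
  }

  // Keep launching ops whose inputs are all ready; "running" an op here
  // just prints its name and marks its outputs ready.
  while (!pending_ops.empty()) {
    std::vector<Op *> still_pending;
    for (Op *op : pending_ops) {
      bool runnable = true;
      for (Var *in : op->inputs) {
        runnable &= ready[in].load(std::memory_order_acquire);
      }
      if (!runnable) {
        still_pending.push_back(op);
        continue;
      }
      std::printf("run %s\n", op->name);
      for (Var *out : op->outputs) {
        ready[out].store(true, std::memory_order_release);
      }
    }
    pending_ops.swap(still_pending);
  }
  return 0;
}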