diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h
index b3292422522b64a38a50f39f04e6f0d2e15492dd..bc3e373488c9899e6e6d46d090b083332ff40666 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.h
+++ b/paddle/fluid/framework/details/broadcast_op_handle.h
@@ -29,9 +29,7 @@ namespace framework {
 namespace details {
 
 struct BroadcastOpHandle : public OpHandleBase {
-  const std::vector<Scope *> &local_scopes_;
-  const std::vector<platform::Place> &places_;
-
+ public:
   BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places);
 
@@ -41,6 +39,10 @@ struct BroadcastOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
+
+ private:
+  const std::vector<Scope *> &local_scopes_;
+  const std::vector<platform::Place> &places_;
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc
index bcd61335be0f7fe64563ee65daaf9de0760c9b1a..efc70515820d18fe61696fd697b0af0a0fef3834 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc
@@ -90,7 +90,7 @@ struct TestBroadcastOpHandle {
     op_handle_->AddInput(dummy_var_handle);
 
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
-      op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
+      op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
       vars_.emplace_back(out_var_handle);
       op_handle_->AddOutput(out_var_handle);
diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc
index ff6d91c1dafb0ab4cabb1646cc333e19a89eb812..7ff0efe09387b7e5d7cfe0dfe5e129ca9914d90b 100644
--- a/paddle/fluid/framework/details/computation_op_handle.cc
+++ b/paddle/fluid/framework/details/computation_op_handle.cc
@@ -28,8 +28,8 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
 void ComputationOpHandle::RunImpl() {
   auto *cur_ctx = dev_ctxes_[place_];
   for (auto *in : inputs_) {
-    bool need_wait =
-        in->generated_op_ && in->generated_op_->dev_ctxes_[place_] != cur_ctx;
+    bool need_wait = in->generated_op_ &&
+                     in->generated_op_->DeviceContext(place_) != cur_ctx;
     if (need_wait) {
       in->generated_op_->Wait(cur_ctx);
     }
diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h
index d6d2d731ca80a0fbc0a2a34027b5b7c3c1977c07..c363b973d9abbae6bea76c2458fbe82a37a342ca 100644
--- a/paddle/fluid/framework/details/computation_op_handle.h
+++ b/paddle/fluid/framework/details/computation_op_handle.h
@@ -14,6 +14,9 @@
 
 #pragma once
 
+#include <memory>
+#include <string>
+
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -24,10 +27,7 @@ namespace paddle {
 namespace framework {
 namespace details {
 struct ComputationOpHandle : public OpHandleBase {
-  std::unique_ptr<OperatorBase> op_;
-  Scope *scope_;
-  platform::Place place_;
-
+ public:
   ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
                       platform::Place place);
 
@@ -35,6 +35,11 @@ struct ComputationOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
+
+ private:
+  std::unique_ptr<OperatorBase> op_;
+  Scope *scope_;
+  platform::Place place_;
 };
 }  // namespace details
 }  // namespace framework
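The `ComputationOpHandle::RunImpl` hunk above is the consumer side of the new accessor: an op waits on an input's producer only when that producer ran on a different device context, since work queued on the same context is already ordered. Below is a minimal, self-contained sketch of that pattern; `Ctx`, `Var`, and `Op` are simplified stand-ins, not Paddle's actual types.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Ctx {  // stand-in for platform::DeviceContext
  std::string name;
};

struct Op;

struct Var {  // stand-in for VarHandleBase
  Op *generated_op = nullptr;
};

struct Op {  // stand-in for OpHandleBase with the new accessors
  std::map<int, Ctx *> dev_ctxes;  // keyed by a place id
  std::vector<Var *> inputs;

  Ctx *DeviceContext(int place) { return dev_ctxes[place]; }
  void SetDeviceContext(int place, Ctx *ctx) { dev_ctxes[place] = ctx; }
  void Wait(Ctx *waiter) {
    std::cout << waiter->name << " waits for this op's context\n";
  }

  void Run(int place) {
    Ctx *cur_ctx = DeviceContext(place);
    for (Var *in : inputs) {
      // Wait only if the producer ran on a different context; work queued
      // on the same context is already ordered.
      bool need_wait = in->generated_op &&
                       in->generated_op->DeviceContext(place) != cur_ctx;
      if (need_wait) {
        in->generated_op->Wait(cur_ctx);
      }
    }
    std::cout << "run on " << cur_ctx->name << "\n";
  }
};

int main() {
  Ctx gpu0{"gpu0"}, gpu1{"gpu1"};
  Op producer, consumer;
  producer.SetDeviceContext(0, &gpu1);  // produced on another context
  consumer.SetDeviceContext(0, &gpu0);
  Var v;
  v.generated_op = &producer;
  consumer.inputs.push_back(&v);
  consumer.Run(0);  // prints a wait for gpu1, then "run on gpu0"
}
```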
diff --git a/paddle/fluid/framework/details/fetch_op_handle.h b/paddle/fluid/framework/details/fetch_op_handle.h
index 904b2d669f8b156b99197afb0155380d1170a68b..b49f3df338dc11310a4a0c27c8aaae3602373fcc 100644
--- a/paddle/fluid/framework/details/fetch_op_handle.h
+++ b/paddle/fluid/framework/details/fetch_op_handle.h
@@ -14,6 +14,9 @@
 
 #pragma once
 
+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/scope.h"
@@ -24,11 +27,7 @@ namespace framework {
 namespace details {
 
 struct FetchOpHandle : public OpHandleBase {
-  FeedFetchList *data_;
-  size_t offset_;
-  std::vector<Scope *> *local_scopes_;
-  std::vector<LoDTensor> tensors_;
-
+ public:
   FetchOpHandle(FeedFetchList *data, size_t offset,
                 std::vector<Scope *> *local_scopes);
 
@@ -42,6 +41,12 @@ struct FetchOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
+
+ private:
+  FeedFetchList *data_;
+  size_t offset_;
+  std::vector<Scope *> *local_scopes_;
+  std::vector<LoDTensor> tensors_;
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h
index 6c0231f642c05e6b558b7e2518a15e08c816fe4b..d11ef8556aa8840949ca8dc7aa176413f70b9f22 100644
--- a/paddle/fluid/framework/details/gather_op_handle.h
+++ b/paddle/fluid/framework/details/gather_op_handle.h
@@ -29,9 +29,7 @@ namespace framework {
 namespace details {
 
 struct GatherOpHandle : public OpHandleBase {
-  const std::vector<Scope *> &local_scopes_;
-  const std::vector<platform::Place> &places_;
-
+ public:
   GatherOpHandle(const std::vector<Scope *> &local_scopes,
                  const std::vector<platform::Place> &places);
 
@@ -41,6 +39,10 @@ struct GatherOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
+
+ private:
+  const std::vector<Scope *> &local_scopes_;
+  const std::vector<platform::Place> &places_;
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc
index 2da8c89d2df73215b748f102d9bbfc5b742cf97f..9481579f6c6f8272ab7b78a15d57c09a4d3245a4 100644
--- a/paddle/fluid/framework/details/gather_op_handle_test.cc
+++ b/paddle/fluid/framework/details/gather_op_handle_test.cc
@@ -78,7 +78,7 @@ struct TestGatherOpHandle {
     op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_));
     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
-      op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
+      op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
       vars_.emplace_back(in_var_handle);
       op_handle_->AddInput(in_var_handle);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index d2b6a35a5d5c260b023c68ec4684da95a5b79e81..002952436e58eecfcecf5c9fa40c01b795170681 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -60,7 +60,8 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
                                                 const platform::Place &p,
                                                 const size_t &i) const {
   auto *op_handle = result->ops_.back().get();
-  op_handle->dev_ctxes_[p] = platform::DeviceContextPool::Instance().Get(p);
+  op_handle->SetDeviceContext(p,
+                              platform::DeviceContextPool::Instance().Get(p));
 
   auto var_names = op.InputArgumentNames();
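Taken together, the test and graph-builder hunks above converge on one caller-side idiom: every site that used to write `op_handle->dev_ctxes_[p] = ...` now injects the context through `SetDeviceContext`, typically fetching it from a pool keyed by place. Here is a runnable sketch of that wiring; `CtxPool` is an illustrative stand-in for `platform::DeviceContextPool`, not the real API.

```cpp
#include <cassert>
#include <map>
#include <string>

struct Ctx {  // stand-in for platform::DeviceContext
  std::string name;
};

// Illustrative stand-in for platform::DeviceContextPool::Instance().Get(p):
// one lazily created context per place id.
class CtxPool {
 public:
  static CtxPool &Instance() {
    static CtxPool pool;
    return pool;
  }
  Ctx *Get(int place) {
    auto it = ctxs_.find(place);
    if (it == ctxs_.end()) {
      it = ctxs_.emplace(place, Ctx{"ctx@" + std::to_string(place)}).first;
    }
    return &it->second;
  }

 private:
  std::map<int, Ctx> ctxs_;
};

class OpHandle {  // only the part relevant to the diff
 public:
  void SetDeviceContext(int place, Ctx *ctx) { dev_ctxes_[place] = ctx; }
  const Ctx *DeviceContext(int place) { return dev_ctxes_[place]; }

 private:
  std::map<int, Ctx *> dev_ctxes_;  // no longer reachable from outside
};

int main() {
  OpHandle op;
  int place = 0;
  op.SetDeviceContext(place, CtxPool::Instance().Get(place));
  assert(op.DeviceContext(place) == CtxPool::Instance().Get(place));
}
```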
diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
index ad14a3c5cb4625fa121cad2daed389c441e78771..a0c321843e3fc5abcbd1ef2ce2e153250269aa7d 100644
--- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
+++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
@@ -27,10 +27,6 @@ namespace framework {
 namespace details {
 
 struct NCCLAllReduceOpHandle : public OpHandleBase {
-  const std::vector<Scope *> &local_scopes_;
-  const std::vector<platform::Place> &places_;
-  const platform::NCCLContextMap &nccl_ctxs_;
-
   NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
                         const std::vector<platform::Place> &places,
                         const platform::NCCLContextMap &ctxs);
@@ -43,6 +39,11 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
+
+ private:
+  const std::vector<Scope *> &local_scopes_;
+  const std::vector<platform::Place> &places_;
+  const platform::NCCLContextMap &nccl_ctxs_;
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h
index a9a6c8d39cf8741f7d9c91579a650ad742cec381..00f213f3ed294adcce7c540e3ff346de8e2be7fb 100644
--- a/paddle/fluid/framework/details/op_handle_base.h
+++ b/paddle/fluid/framework/details/op_handle_base.h
@@ -27,28 +27,15 @@ namespace details {
 constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
 
 class OpHandleBase {
- private:
-  DISABLE_COPY_AND_ASSIGN(OpHandleBase);
-
  public:
-  std::vector<VarHandleBase *> inputs_;
-  std::vector<VarHandleBase *> outputs_;
-  std::unordered_map<platform::Place, platform::DeviceContext *,
-                     platform::PlaceHash>
-      dev_ctxes_;
-
-#ifdef PADDLE_WITH_CUDA
-  std::unordered_map<int, cudaEvent_t> events_;
-#endif
-
   OpHandleBase() {}
 
+  virtual ~OpHandleBase();
+
   std::string DebugString() const;
 
   virtual std::string Name() const = 0;
 
-  virtual ~OpHandleBase();
-
   void Run(bool use_event);
 
   virtual void Wait(platform::DeviceContext *waited_dev);
@@ -61,6 +48,18 @@ class OpHandleBase {
   // will likely block other computations.
   virtual bool IsMultiDeviceTransfer() { return false; }
 
+  const platform::DeviceContext *DeviceContext(platform::Place place) {
+    return dev_ctxes_[place];
+  }
+
+  void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
+    dev_ctxes_[place] = ctx_;
+  }
+
+  const std::vector<VarHandleBase *> &Inputs() const { return inputs_; }
+
+  const std::vector<VarHandleBase *> &Outputs() const { return outputs_; }
+
  protected:
   void RunAndRecordEvent(const std::function<void()> &callback);
 
@@ -68,6 +67,18 @@
                          const std::function<void()> &callback);
 
   virtual void RunImpl() = 0;
+
+  std::vector<VarHandleBase *> inputs_;
+  std::vector<VarHandleBase *> outputs_;
+  std::unordered_map<platform::Place, platform::DeviceContext *,
+                     platform::PlaceHash>
+      dev_ctxes_;
+
+#ifdef PADDLE_WITH_CUDA
+  std::unordered_map<int, cudaEvent_t> events_;
+#endif
+
+  DISABLE_COPY_AND_ASSIGN(OpHandleBase);
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
index ab7353a4fc56bebfe04696efd838dc4559218058..d93d599d46f130cf98f39f15697ce994a31e20c3 100644
--- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
+++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
@@ -14,6 +14,8 @@
 
 #pragma once
 
+#include <string>
+
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
@@ -23,10 +25,6 @@ namespace framework {
 namespace details {
 
 struct ScaleLossGradOpHandle : public OpHandleBase {
-  float coeff_;
-  Scope *scope_;
-  platform::Place place_;
-
   ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place,
                         platform::DeviceContext *context);
 
@@ -36,6 +34,11 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
+
+ private:
+  float coeff_;
+  Scope *scope_;
+  platform::Place place_;
 };
 
 }  // namespace details
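The `op_handle_base.h` hunks above are the heart of the change: `inputs_`, `outputs_`, and `dev_ctxes_` move from public data to protected state behind const accessors, and `DISABLE_COPY_AND_ASSIGN` moves to the bottom of the class. Below is a compilable sketch of that encapsulation pattern; `std::string` stands in for `VarHandleBase *`, and hand-written deleted copy operations stand in for the macro.

```cpp
#include <iostream>
#include <string>
#include <vector>

class OpHandleBase {
 public:
  OpHandleBase() = default;
  virtual ~OpHandleBase() = default;

  // External code gets read-only views...
  const std::vector<std::string> &Inputs() const { return inputs_; }
  const std::vector<std::string> &Outputs() const { return outputs_; }

  void AddInput(const std::string &v) { inputs_.push_back(v); }
  void AddOutput(const std::string &v) { outputs_.push_back(v); }

  void Run() { RunImpl(); }

 protected:
  virtual void RunImpl() = 0;

  // ...while subclasses keep direct access to the state.
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;

  // Equivalent of DISABLE_COPY_AND_ASSIGN(OpHandleBase), now at the bottom.
  OpHandleBase(const OpHandleBase &) = delete;
  OpHandleBase &operator=(const OpHandleBase &) = delete;
};

class PrintOpHandle : public OpHandleBase {
 protected:
  void RunImpl() override {
    for (const auto &in : inputs_) std::cout << "consume " << in << "\n";
  }
};

int main() {
  PrintOpHandle op;
  op.AddInput("x");
  op.AddOutput("y");
  op.Run();                                        // subclass reads inputs_
  std::cout << "outputs: " << op.Outputs().size() << "\n";  // const view
}
```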
diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h
index 173f9d726145aeb9e85cc0fb9056eb57bf484098..2f78811fad50642b5e45776c41910df6f4cc48f6 100644
--- a/paddle/fluid/framework/details/send_op_handle.h
+++ b/paddle/fluid/framework/details/send_op_handle.h
@@ -28,10 +28,6 @@ namespace framework {
 namespace details {
 
 struct SendOpHandle : public OpHandleBase {
-  std::unique_ptr<OperatorBase> op_;
-  const Scope* local_scope_;
-  const platform::Place& place_;
-
   SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
                const platform::Place& place);
 
@@ -43,6 +39,11 @@ struct SendOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
+
+ private:
+  std::unique_ptr<OperatorBase> op_;
+  const Scope* local_scope_;
+  const platform::Place& place_;
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc
index 25e8c77bb489546092b2a93e052da7dd0dd5edf4..6a567527550883add08031e50aa8de2b204cf13d 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.cc
+++ b/paddle/fluid/framework/details/ssa_graph_builder.cc
@@ -117,12 +117,12 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
     std::string op_name = "op_" + std::to_string(op_id++);
     sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
          << std::endl;
-    for (auto in : op->inputs_) {
+    for (auto in : op->Inputs()) {
      std::string var_name = "var_" + std::to_string(vars[in]);
       sout << var_name << " -> " << op_name << std::endl;
     }
 
-    for (auto out : op->outputs_) {
+    for (auto out : op->Outputs()) {
       std::string var_name = "var_" + std::to_string(vars[out]);
       sout << op_name << " -> " << var_name << std::endl;
     }
@@ -133,7 +133,7 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
 
 void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
   for (auto &op : graph->ops_) {
-    if (!op->outputs_.empty()) {
+    if (!op->Outputs().empty()) {
       continue;
     }
     auto *dummy_leaf = new DummyVarHandle();
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 3d2bd633afff1d453d00faeca3b3dcf77f8dd5d7..14e75e7b7b582d994b83d6c74ad9947135f6c449 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -53,7 +53,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   };
 
   auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
-    pending_ops.insert({&op_instance, op_instance.inputs_.size()});
+    pending_ops.insert({&op_instance, op_instance.Inputs().size()});
   };
 
   // Transform SSAGraph to pending_ops & pending_vars
@@ -69,7 +69,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   }
 
   for (auto &op : graph_->ops_) {
-    if (op->inputs_.empty()) {  // Special case, Op has no input.
+    if (op->Inputs().empty()) {  // Special case, Op has no input.
       ready_ops.insert(op.get());
     } else {
       InsertPendingOp(*op);
@@ -99,7 +99,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     fetch_ops.emplace_back(op);
 
     for (auto &p : places_) {
-      op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
+      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
     }
 
     for (auto *var : vars) {
@@ -180,7 +180,7 @@ void ThreadedSSAGraphExecutor::RunOp(
       op->Run(use_event_);
       VLOG(10) << op << " " << op->Name() << " Done ";
       running_ops_--;
-      ready_var_q->Extend(op->outputs_);
+      ready_var_q->Extend(op->Outputs());
       VLOG(10) << op << " " << op->Name() << "Signal posted";
     } catch (platform::EnforceNotMet ex) {
       exception_.reset(new platform::EnforceNotMet(ex));
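The executor hunks lean on the same read-only views for dependency counting: each op's pending count starts at `Inputs().size()`, ops with no inputs are immediately ready, and a finished op publishes its `Outputs()` so consumers can be counted down toward readiness. Here is a single-threaded sketch of that scheduling loop; `Node` is an illustrative stand-in for `OpHandleBase`, not Paddle's class.

```cpp
#include <iostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

// Each op starts with a pending count equal to its input count; finishing an
// op publishes its outputs, and a consumer becomes ready at count zero.
struct Node {
  std::string name;
  std::vector<int> inputs;   // var ids this op consumes
  std::vector<int> outputs;  // var ids this op produces
};

int main() {
  // a -> (v0) -> b -> (v1) -> c
  std::vector<Node> ops = {{"a", {}, {0}}, {"b", {0}, {1}}, {"c", {1}, {}}};

  std::unordered_map<int, std::vector<Node *>> consumers;
  std::unordered_map<Node *, size_t> pending;  // op -> unready input count
  std::queue<Node *> ready;

  for (auto &op : ops) {
    if (op.inputs.empty()) {
      ready.push(&op);  // special case: op has no input
    } else {
      pending[&op] = op.inputs.size();
      for (int v : op.inputs) consumers[v].push_back(&op);
    }
  }

  while (!ready.empty()) {
    Node *op = ready.front();
    ready.pop();
    std::cout << "run " << op->name << "\n";
    for (int v : op->outputs) {  // analogous to ready_var_q->Extend(Outputs())
      for (Node *c : consumers[v]) {
        if (--pending[c] == 0) ready.push(c);
      }
    }
  }
}
```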