diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index dbd375aa31bfbdcb109b6302acf23b3bb3b6befe..627370cd2df7317b4d32aa967565aaf9cf0c7a08 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method) -cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index c026e6c100a303b43650f08cd12d7260258c8f7e..c106761f72e689ff53867ecad8e36b6038173d0e 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -7,6 +7,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) +cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) @@ -28,6 +29,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) + +cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 629aa00cb817c4b1446e7b750ca62a7c6b1db670..8036f756b6d6506684c109ab881d546f38176a10 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase { void RunImpl() override; private: - const std::vector<Scope *> &local_scopes_; - const std::vector<platform::Place> &places_; + std::vector<Scope *> local_scopes_; + std::vector<platform::Place> places_; #ifdef PADDLE_WITH_CUDA const platform::NCCLContextMap *nccl_ctxs_; #endif diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 91bdfe6134ffbd1404336c9d6d1222a505084b2b..64e83acb4dc1995800c4ca3caf81668b24a7c9fe 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -14,6 +14,8 @@ #pragma once +#include <string> + namespace paddle { namespace framework { namespace details { @@ -29,6 +31,8 @@ struct BuildStrategy { ReduceStrategy reduce_{ReduceStrategy::kAllReduce}; GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice}; + + std::string debug_graphviz_path_{""}; }; } // namespace details diff --git a/paddle/fluid/framework/details/graph_builder_factory.cc b/paddle/fluid/framework/details/graph_builder_factory.cc new file mode 100644 index 0000000000000000000000000000000000000000..a04b9bb63c06b40ff5c30c9792cdfad5d64d404c --- /dev/null +++ b/paddle/fluid/framework/details/graph_builder_factory.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/graph_builder_factory.h" +#include <fstream> +#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/ssa_graph_printer.h" + +namespace paddle { +namespace framework { +namespace details { +std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() { + std::unique_ptr<SSAGraphBuilder> res( +#ifdef PADDLE_WITH_CUDA + new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_, + local_scopes_, nccl_ctxs_, strategy_) +#else + new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_, + local_scopes_, strategy_) +#endif + ); // NOLINT + + if (!strategy_.debug_graphviz_path_.empty()) { + std::unique_ptr<std::ostream> fout( + new std::ofstream(strategy_.debug_graphviz_path_)); + PADDLE_ENFORCE(fout->good()); + std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer( + new GraphvizSSAGraphPrinter()); + res.reset(new SSAGraghBuilderWithPrinter( + std::move(fout), std::move(graphviz_printer), std::move(res))); + } + return res; +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/graph_builder_factory.h b/paddle/fluid/framework/details/graph_builder_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..857ab12d684e19788597e144fc0c46571d06aafc --- /dev/null +++ b/paddle/fluid/framework/details/graph_builder_factory.h @@ -0,0 +1,67 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include <memory> +#include <string> +#include <vector> +#include "paddle/fluid/framework/details/build_strategy.h" +#include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/platform/place.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace framework { +class Scope; +namespace details { + +class SSAGraphBuilderFactory { + public: + SSAGraphBuilderFactory(const std::vector<platform::Place>& places, + const std::string& loss_var_name, + const std::unordered_set<std::string>& param_names, + const std::vector<Scope*>& local_scopes, + const BuildStrategy& strategy) + : places_(places), + loss_var_name_(loss_var_name), + param_names_(param_names), + local_scopes_(local_scopes), + strategy_(strategy) {} + +#ifdef PADDLE_WITH_CUDA + void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) { + nccl_ctxs_ = nccl_ctxs; + } +#endif + + std::unique_ptr<SSAGraphBuilder> Create(); + + private: + std::vector<platform::Place> places_; + std::string loss_var_name_; + std::unordered_set<std::string> param_names_; + std::vector<Scope*> local_scopes_; + BuildStrategy strategy_; + +#ifdef PADDLE_WITH_CUDA + platform::NCCLContextMap* nccl_ctxs_; +#endif +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 17baacd13eecac8f410631fe9e94788da4fff848..0c4d369e889cf2cca7722dac14a5268fdacabeb4 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -30,10 +30,6 @@ #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #endif -DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot", - "the ssa graph path only print with GLOG_v=10," - "default /tmp/graph.dot"); - namespace paddle { namespace framework { namespace details { @@ -277,11 +273,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( */ AddOutputToLeafOps(&result); - if (VLOG_IS_ON(10)) { - std::ofstream fout(FLAGS_ssa_graph_path); - PrintGraphviz(*graph, fout); - } - return std::unique_ptr<SSAGraph>(graph); } diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h index a0c321843e3fc5abcbd1ef2ce2e153250269aa7d..8e98d894b828b4162059b30f5c6a74cfc06f402e 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h @@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { void RunImpl() override; private: - const std::vector<Scope *> &local_scopes_; - const std::vector<platform::Place> &places_; + std::vector<Scope *> local_scopes_; + std::vector<platform::Place> places_; const platform::NCCLContextMap &nccl_ctxs_; }; diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index c652a2f4eb0f9b73cb19ebbd9d0809210b280ad3..4d14334cdfe06e2e805c2577458d6689e6324cc7 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -32,8 +32,8 @@ namespace framework { namespace details { struct ReduceOpHandle : public OpHandleBase { - const std::vector<Scope *> &local_scopes_; - const std::vector<platform::Place> &places_; + std::vector<Scope *> local_scopes_; + std::vector<platform::Place> places_; #ifdef PADDLE_WITH_CUDA const platform::NCCLContextMap *nccl_ctxs_; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 6a567527550883add08031e50aa8de2b204cf13d..211113c7979ee95d896c0a57879f7b3ad13b36ef 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -73,64 +73,6 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, op_handle->AddOutput(var); } -template <typename Callback> -void IterAllVar(const SSAGraph &graph, Callback callback) { - for (auto &each : graph.vars_) { - for (auto &pair1 : each) { - for (auto &pair2 : pair1.second) { - callback(*pair2); - } - } - } - - for (auto &var : graph.dep_vars_) { - callback(*var); - } -} - -void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) { - size_t var_id = 0; - std::unordered_map<const VarHandleBase *, size_t> vars; - - sout << "digraph G {\n"; - - IterAllVar(graph, [&](const VarHandleBase &var) { - auto *var_ptr = &var; - auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr); - auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr); - - size_t cur_var_id = var_id++; - vars[var_ptr] = cur_var_id; - - if (var_handle_ptr) { - sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_ - << "\\n" - << var_handle_ptr->place_ << "\\n" - << var_handle_ptr->version_ << "\"]" << std::endl; - } else if (dummy_ptr) { - sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl; - } - }); - - size_t op_id = 0; - for (auto &op : graph.ops_) { - std::string op_name = "op_" + std::to_string(op_id++); - sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" - << std::endl; - for (auto in : op->Inputs()) { - std::string var_name = "var_" + std::to_string(vars[in]); - sout << var_name << " -> " << op_name << std::endl; - } - - for (auto out : op->Outputs()) { - std::string var_name = "var_" + std::to_string(vars[out]); - sout << op_name << " -> " << var_name << std::endl; - } - } - - sout << "}\n"; -} - void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { for (auto &op : graph->ops_) { if (!op->Outputs().empty()) { diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 64e5d93081eb76c56898bbeb530e37364619fdbb..5fc12a44b51fae26e5a8f5fdba952d3879e82d0f 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -55,8 +55,6 @@ class SSAGraphBuilder { const platform::Place &place, size_t place_offset); static void AddOutputToLeafOps(SSAGraph *graph); - - static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout); }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc new file mode 100644 index 0000000000000000000000000000000000000000..22a40ca4b25cdd8ed9856b6c71bffc79561edcac --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_printer.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/ssa_graph_printer.h" +#include <string> +#include "paddle/fluid/framework/details/ssa_graph.h" + +namespace paddle { +namespace framework { +namespace details { + +template <typename Callback> +static inline void IterAllVar(const SSAGraph &graph, Callback callback) { + for (auto &each : graph.vars_) { + for (auto &pair1 : each) { + for (auto &pair2 : pair1.second) { + callback(*pair2); + } + } + } + + for (auto &var : graph.dep_vars_) { + callback(*var); + } +} + +void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, + std::ostream &sout) const { + size_t var_id = 0; + std::unordered_map<const VarHandleBase *, size_t> vars; + + sout << "digraph G {\n"; + + IterAllVar(graph, [&](const VarHandleBase &var) { + auto *var_ptr = &var; + auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr); + auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr); + + size_t cur_var_id = var_id++; + vars[var_ptr] = cur_var_id; + + if (var_handle_ptr) { + sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_ + << "\\n" + << var_handle_ptr->place_ << "\\n" + << var_handle_ptr->version_ << "\"]" << std::endl; + } else if (dummy_ptr) { + sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl; + } + }); + + size_t op_id = 0; + for (auto &op : graph.ops_) { + std::string op_name = "op_" + std::to_string(op_id++); + sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" + << std::endl; + for (auto in : op->Inputs()) { + std::string var_name = "var_" + std::to_string(vars[in]); + sout << var_name << " -> " << op_name << std::endl; + } + + for (auto out : op->Outputs()) { + std::string var_name = "var_" + std::to_string(vars[out]); + sout << op_name << " -> " << var_name << std::endl; + } + } + + sout << "}\n"; +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h new file mode 100644 index 0000000000000000000000000000000000000000..5287be3b6a05ec7067ca433ba976b0314d05fe02 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -0,0 +1,67 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include <iosfwd> +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + +namespace paddle { +namespace framework { +namespace details { +class SSAGraph; +class SSAGraphPrinter { + public: + virtual ~SSAGraphPrinter() {} + virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0; +}; + +class GraphvizSSAGraphPrinter : public SSAGraphPrinter { + public: + void Print(const SSAGraph& graph, std::ostream& sout) const override; +}; + +class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { + public: + SSAGraghBuilderWithPrinter(std::ostream& sout, + std::unique_ptr<SSAGraphPrinter>&& printer, + std::unique_ptr<SSAGraphBuilder>&& builder) + : printer_(std::move(printer)), + builder_(std::move(builder)), + stream_ref_(sout) {} + + SSAGraghBuilderWithPrinter(std::unique_ptr<std::ostream>&& sout, + std::unique_ptr<SSAGraphPrinter>&& printer, + std::unique_ptr<SSAGraphBuilder>&& builder) + : printer_(std::move(printer)), + builder_(std::move(builder)), + stream_ptr_(std::move(sout)), + stream_ref_(*stream_ptr_) {} + + std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override { + auto graph = builder_->Build(program); + printer_->Print(*graph, stream_ref_); + return graph; + } + + private: + std::unique_ptr<SSAGraphPrinter> printer_; + std::unique_ptr<SSAGraphBuilder> builder_; + std::unique_ptr<std::ostream> stream_ptr_; + std::ostream& stream_ref_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 003304b85af00165d54efbb199be01b2c5106768..ce56f55e4195a0625cd0754152285b80e4282183 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -22,7 +22,7 @@ limitations under the License. */ #include "paddle/fluid/platform/nccl_helper.h" #endif -#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/graph_builder_factory.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -102,22 +102,19 @@ ParallelExecutor::ParallelExecutor( var_infos.back().persistable_ = var->Persistable(); } -// Step 3. Convert main_program to SSA form and dependency graph. Also, insert -// ncclOp -#ifdef PADDLE_WITH_CUDA - details::MultiDevSSAGraphBuilder builder( + // Step 3. Convert main_program to SSA form and dependency graph. Also, insert + // ncclOp + + details::SSAGraphBuilderFactory builder_factory( member_->places_, loss_var_name, params, member_->local_scopes_, - member_->nccl_ctxs_.get(), build_strategy); -#else - details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, - params, member_->local_scopes_, - build_strategy); + build_strategy); +#ifdef PADDLE_WITH_CUDA + builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); #endif - auto graph = builder.Build(main_program); - member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, places, std::move(graph))); + exec_strategy, member_->local_scopes_, places, + builder_factory.Create()->Build(main_program))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 03cf417b62f96fd6812b3eac497ffdf9a484f5eb..669d1bdaa3ec194be817cdc5e1f8484770c70c68 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -553,6 +553,12 @@ All parameter, weight, gradient are variables in Paddle. [](BuildStrategy &self, BuildStrategy::GradientScaleStrategy strategy) { self.gradient_scale_ = strategy; + }) + .def_property( + "debug_graphviz_path", + [](const BuildStrategy &self) { return self.debug_graphviz_path_; }, + [](BuildStrategy &self, const std::string &path) { + self.debug_graphviz_path_ = path; }); pe.def(py::init<const std::vector<platform::Place> &,