diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md index a33b5a9c9312c93247a1e1f3431061a5aad6c884..65bfaea6a1db6d8f9340ea0195ab6fa91d8bd91a 100644 --- a/doc/fluid/design/ir/draft.md +++ b/doc/fluid/design/ir/draft.md @@ -71,6 +71,44 @@ is a `Graph` and its output is also a `Graph`. For example, a `Pass` can simply print out the `Graph`. A `Pass` can also fuse some `Graph`'s `Node`s. +```cpp +class Pass { + public: + + virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0; + + // Get a reference to the attribute previously set. + template <typename AttrType> + AttrType &Get(const std::string &attr_name) const; + + // Set a pointer to the attribute. Pass takes ownership of the attribute. + template <typename AttrType> + void Set(const std::string &attr_name, AttrType *attr) ; + + // Set a pointer to the attribute. Pass doesn't take ownership. Caller + // should delete the attribute. + template <typename AttrType> + void SetNotOwned(const std::string &attr_name, AttrType *attr); +}; + +// In my_pass.cc +class MyPass : public Pass { + public: + std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override { + // do something. + return graph; + } +} +REGISTER_PASS(my_pass, MyPass); + + +// To use the pass. +auto my_pass = ir::PassRegistry::Instance().Get("my_pass"); +graph = my_pass->Apply(std::move(graph)); +// Note: to force link my_pass.cc, in the code: +USE_PASS(my_pass); +``` + #### Optimize `Optimize` contains a series of `Pass` with defined order. 
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index f3c1e7c5288b3b43c89983ae81b042dd64efdbe5..d822a1c9c4a491be670c15207d121aba7d941fe3 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -99,7 +99,7 @@ else() endif() -cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 9df7df1f42886d40210b16aa2ae5823e3310bfe7..5d652d37307d0a55ffee14930ae180dcd3e27841 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) - -cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) - cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc 
index d211f02689e93bec368ee29ae711799a5941bb85..ff90f31cdbcee53b4f57b2570831b3ba58a36f40 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -35,15 +35,15 @@ namespace framework { namespace details { void MultiDevSSAGraphBuilder::Init() const { - loss_var_name_ = Get("loss_var_name"); - places_ = Get>("places"); - local_scopes_ = Get>("local_scopes"); - strategy_ = Get("strategy"); + loss_var_name_ = Get("loss_var_name"); + places_ = Get>("places"); + local_scopes_ = Get>("local_scopes"); + strategy_ = Get("strategy"); #ifdef PADDLE_WITH_CUDA nccl_ctxs_ = &Get("nccl_ctxs"); #endif - for (auto &p : Get>("params")) { + for (auto &p : Get>("params")) { grad_names_.insert(GradVarName(p)); } balance_vars_.resize(places_.size(), 0); diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc deleted file mode 100644 index 2254a3b41eaf58e0badd72c71ee2c05ffc2599e0..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" -#include -#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" -#include "paddle/fluid/framework/details/ssa_graph_checker.h" -#include "paddle/fluid/framework/details/ssa_graph_printer.h" - -namespace paddle { -namespace framework { -namespace details { -std::unique_ptr ParallelExecutorPassManager::Create() { - std::unique_ptr res(new MultiDevSSAGraphBuilder); - res->SetNotOwned>("places", &places_); - res->SetNotOwned("loss_var_name", &loss_var_name_); - res->SetNotOwned>("params", &param_names_); - res->SetNotOwned>("local_scopes", &local_scopes_); - res->SetNotOwned("strategy", &strategy_); -#ifdef PADDLE_WITH_CUDA - res->SetNotOwned("nccl_ctxs", nccl_ctxs_); -#endif - - if (!strategy_.debug_graphviz_path_.empty()) { - ir::Pass *previous_pass = res.release(); - res.reset(new SSAGraghBuilderWithPrinter); - res->Set("previous_pass", previous_pass); - res->SetNotOwned("debug_graphviz_path", - &strategy_.debug_graphviz_path_); - res->Set("graph_printer", - new GraphvizSSAGraphPrinter); - } - - ir::Pass *previous_pass = res.release(); - res.reset(new SSAGraghBuilderWithChecker); - res->Set("previous_pass", previous_pass); - - return res; -} -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.h b/paddle/fluid/framework/details/ssa_graph_builder_factory.h deleted file mode 100644 index 1bfc3e71e8c993b554a50f3b7a82fa9d79639b4a..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include "paddle/fluid/framework/details/build_strategy.h" -#include "paddle/fluid/framework/details/ssa_graph_builder.h" -#include "paddle/fluid/platform/place.h" - -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/platform/nccl_helper.h" -#endif - -namespace paddle { -namespace framework { -class Scope; -namespace details { - -class ParallelExecutorPassManager { - public: - ParallelExecutorPassManager( - const std::vector& places, - const std::string& loss_var_name, - const std::unordered_set& param_names, - const std::vector& local_scopes, const BuildStrategy& strategy) - : places_(places), - loss_var_name_(loss_var_name), - param_names_(param_names), - local_scopes_(local_scopes), - strategy_(strategy) { -#ifdef PADDLE_WITH_CUDA - nccl_ctxs_ = nullptr; -#endif - } - -#ifdef PADDLE_WITH_CUDA - void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) { - nccl_ctxs_ = nccl_ctxs; - } -#endif - - std::unique_ptr Create(); - - private: - std::vector places_; - std::string loss_var_name_; - std::unordered_set param_names_; - std::vector local_scopes_; - BuildStrategy strategy_; - -#ifdef PADDLE_WITH_CUDA - platform::NCCLContextMap* nccl_ctxs_; -#endif -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index fb766fb41550782bb2f54795d2ca0d41864d9cf0..25891cf74da6c26c420cae6582b5e385a69a3d06 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ 
b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -26,9 +26,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { public: std::unique_ptr Apply( std::unique_ptr graph) const override { - auto new_graph = Get("previous_pass").Apply(std::move(graph)); - PADDLE_ENFORCE(IsValidGraph(new_graph.get())); - return new_graph; + PADDLE_ENFORCE(IsValidGraph(graph.get())); + return graph; } bool IsValidGraph(const ir::Graph* graph) const; diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index b7d20aa9838eb4a5f3d7a53837a00614ddf0e907..bd4498c0612051cc90ba67289e9b0237b9d75802 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -39,13 +39,11 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { public: std::unique_ptr Apply( std::unique_ptr graph) const override { - auto new_graph = Get("previous_pass").Apply(std::move(graph)); - std::unique_ptr fout( - new std::ofstream(Get("debug_graphviz_path"))); + new std::ofstream(Get("debug_graphviz_path"))); PADDLE_ENFORCE(fout->good()); - Get("graph_printer").Print(*new_graph, *fout); - return new_graph; + Get("graph_printer").Print(*graph, *fout); + return graph; } }; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 5ab7f9a1e2d85b17dc4a192b6600b29f58be5bcb..f254ef62df5ad04f201b049f28839fb0eda4ee33 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -42,6 +42,7 @@ class Pass { virtual std::unique_ptr Apply(std::unique_ptr graph) const = 0; + // Get a reference to the attribute previously set. template AttrType &Get(const std::string &attr_name) const { PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(), @@ -49,6 +50,7 @@ class Pass { return *boost::any_cast(attrs_.at(attr_name)); } + // Set a pointer to the attribute. Pass takes ownership of the attribute. 
template void Set(const std::string &attr_name, AttrType *attr) { PADDLE_ENFORCE(attrs_.count(attr_name) == 0); @@ -59,6 +61,8 @@ class Pass { }; } + // Set a pointer to the attribute. Pass doesn't take ownership. Caller + // should delete the attribute. template void SetNotOwned(const std::string &attr_name, AttrType *attr) { PADDLE_ENFORCE(attrs_.count(attr_name) == 0); @@ -127,6 +131,7 @@ struct PassRegistrar : public Registrar { __test_global_namespace_##uniq_name##__>::value, \ msg) +// Register a new pass that can be applied on the IR. #define REGISTER_PASS(pass_type, pass_class) \ STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ __reg_pass__##pass_type, \ diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index a23fd2a41a84419dc50f91eeeb79a591d92271fd..77bed5c9994167b98d42f49791682a8a55c29316 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -26,7 +26,8 @@ limitations under the License. 
*/ #endif #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" -#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" +#include "paddle/fluid/framework/details/ssa_graph_checker.h" +#include "paddle/fluid/framework/details/ssa_graph_printer.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -43,16 +44,6 @@ std::unique_ptr ApplyParallelExecutorPass( #else const BuildStrategy &strategy) { #endif - details::ParallelExecutorPassManager builder_factory( - places, loss_var_name, param_names, local_scopes, strategy); - if (use_cuda) { -#ifdef PADDLE_WITH_CUDA - builder_factory.SetNCCLContextMap(nccl_ctxs); -#else - PADDLE_THROW("Not compiled with CUDA."); -#endif - } - std::unique_ptr graph(new ir::Graph(main_program)); if (!strategy.debug_graphviz_path_.empty()) { auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); @@ -62,8 +53,37 @@ std::unique_ptr ApplyParallelExecutorPass( graph = viz_pass->Apply(std::move(graph)); } - auto builder = builder_factory.Create(); - graph = builder->Apply(std::move(graph)); + auto multi_device_pass = + ir::PassRegistry::Instance().Get("multi_device_pass"); + multi_device_pass->SetNotOwned>("places", + &places); + multi_device_pass->SetNotOwned("loss_var_name", + &loss_var_name); + multi_device_pass->SetNotOwned>( + "params", &param_names); + multi_device_pass->SetNotOwned>("local_scopes", + &local_scopes); + multi_device_pass->SetNotOwned("strategy", &strategy); + +#ifdef PADDLE_WITH_CUDA + platform::NCCLContextMap *nctx = use_cuda ? 
nccl_ctxs : nullptr; + multi_device_pass->SetNotOwned("nccl_ctxs", nctx); +#endif + graph = multi_device_pass->Apply(std::move(graph)); + + if (!strategy.debug_graphviz_path_.empty()) { + auto multi_device_print_pass = + ir::PassRegistry::Instance().Get("multi_device_print_pass"); + multi_device_print_pass->SetNotOwned( + "debug_graphviz_path", &strategy.debug_graphviz_path_); + multi_device_print_pass->Set( + "graph_printer", new details::GraphvizSSAGraphPrinter); + graph = multi_device_print_pass->Apply(std::move(graph)); + } + + auto multi_device_check_pass = + ir::PassRegistry::Instance().Get("multi_device_check_pass"); + graph = multi_device_check_pass->Apply(std::move(graph)); if (!strategy.debug_graphviz_path_.empty()) { auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");