diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 0b38f1f40f1406ba51ee034fe7d4dd8d4246e121..003ca23085c075bf26b05c71ef4c8c5d22963e61 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -21,12 +21,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/graph_printer.h"
 #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
-#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
 
 DECLARE_bool(use_mkldnn);
 
@@ -48,212 +48,195 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
  public:
   explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
       : ir::PassBuilder(), strategy_(strategy) {
-    // Add a graph viz pass to record a graph.
-    if (!strategy_.debug_graphviz_path_.empty()) {
-      VLOG(1) << "Add graph_viz_pass";
-      auto viz_pass = AppendPass("graph_viz_pass");
-      const std::string graph_path = string::Sprintf(
-          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
-      viz_pass->Set("graph_viz_path", new std::string(graph_path));
-    }
+    ResolveOptionConfliction();
-    // Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
-    VLOG(1) << "Add record_skip_memory_opt_vars_pass";
+    AppendPrintGraphPass("graph_viz_pass", "_original_graph");
+    // Note(zcd): record_skip_memory_opt_vars_pass should
+    // be the first pass.
     AppendPass("record_skip_memory_opt_vars_pass");
+    AppendPassWithCheck(strategy_.enable_sequential_execution_,
+                        "sequential_execution_pass");
+    AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");
+
+    AppendOpFusePasses();
+    AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
+    // TODO(dev-paddle): memory optimize pass should be placed last.
+    AppendMemoryOptimizePasses();
+    AppendMultiDevPass();
+    AppendMultiGraphOptPasses();
+
+    AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
+    // runtime_context_cache pass should be the last pass to enable the attr of
+    // all original and fused operators. But no operators can be enabled this
+    // attr if putting it after MultiDevPass.
+    AppendPassWithCheck(strategy_.cache_runtime_context_,
+                        "runtime_context_cache_pass");
+    AppendPassWithCheck(strategy_.remove_unnecessary_lock_,
+                        "modify_op_lock_and_record_event_pass");
+    // Note: This pass is used to check whether the multi_device_graph is right.
+    AppendPass("multi_devices_check_pass");
-#ifdef PADDLE_WITH_MKLDNN
-    if (FLAGS_use_mkldnn) {
-      VLOG(1) << "Add mkldnn_placement_pass";
-      AppendPass("mkldnn_placement_pass");
-    } else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
-      LOG(WARNING)
-          << "mkldnn_enabled_op_types specify the operator type list to "
-             "use MKLDNN acceleration. It is null in default, means "
-             "that all the operators supported by MKLDNN will be "
-             "accelerated. And it should not be set when "
-             "FLAGS_use_mkldnn=false.";
-    }
-#else
-    PADDLE_ENFORCE(!FLAGS_use_mkldnn,
-                   "Please compile with MKLDNN first to use MKLDNN");
-#endif
+    SetCollectiveContext();
+  }
-    if (strategy_.enable_sequential_execution_) {
-      VLOG(1) << "Add sequential_execution_pass";
-      AppendPass("sequential_execution_pass");
+  void ResolveOptionConfliction() {
+    // Specifies the restrictions between different pass.
+    if (strategy_.enable_parallel_graph_) {
+      VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
+          << "Currently, fuse_all_optimizer_ops doesn't works under "
+             "parallel_graph.";
+      strategy_.fuse_all_optimizer_ops_ = false;
     }
-
-    // Add op fusion.
-    if (strategy.sync_batch_norm_) {
-      AppendPass("sync_batch_norm_pass");
+    if (strategy_.is_distribution_) {
+      VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
+          << "Currently, fuse_all_optimizer_ops only works under "
+             "Non-distributed mode.";
+      strategy_.fuse_all_optimizer_ops_ = false;
     }
-
-    // Add op fusion.
-    if (strategy.fuse_relu_depthwise_conv_) {
-      VLOG(1) << "Add fuse_relu_depthwise_conv_pass";
-      AppendPass("fuse_relu_depthwise_conv_pass");
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
+      VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
+          << "Currently, fuse_all_optimizer_ops only works under AllReduce "
+             "mode.";
+      strategy_.fuse_all_optimizer_ops_ = false;
+      VLOG_IF(3, strategy_.fuse_all_reduce_ops_)
+          << "fuse_all_optimizer_ops only work in Reducer mode.";
+      strategy_.fuse_all_reduce_ops_ = false;
     }
+  }
-    // TODO(zjl): refactor MemoryOptimizePass to fit
-    // new strategy, which does not need to set
-    // var.persistable = True
-    if (strategy_.use_legacy_memory_optimize_strategy_) {
-      if (strategy_.enable_inplace_) {
-        VLOG(5) << "Add inplace_pass";
-        AppendPass("inplace_pass");
-      }
-    }
+  void AppendMultiGraphOptPasses() {
+    // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
+    // first, if the number is zero, fuse_all_reduce_ops will do nothing.
+    AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
+                        "fuse_all_reduce_op_pass");
+    AppendPrintGraphPass("multi_devices_print_pass", "_multi_devices_graph");
-    if (strategy_.fuse_elewise_add_act_ops_) {
-      VLOG(1) << "Add fuse_elewise_add_act_pass";
-      AppendPass("fuse_elewise_add_act_pass");
-    }
+    // experimental shows that the program will be faster if append
+    // all_reduce_deps_pass here.
+    bool append_all_reduce_deps_pass =
+        !strategy_.enable_parallel_graph_ &&
+        (SeqOnlyAllReduceOps(strategy_) ||
+         strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce);
+    AppendPassWithCheck(append_all_reduce_deps_pass, "all_reduce_deps_pass");
+    bool append_backward_optimizer_op_deps_pass =
+        strategy_.num_trainers_ > 1 && !strategy_.async_mode_ &&
+        !strategy_.is_distribution_ &&
+        strategy_.enable_backward_optimizer_op_deps_;
+    AppendPassWithCheck(append_backward_optimizer_op_deps_pass,
+                        "backward_optimizer_op_deps_pass");
+  }
+
+  void AppendOpFusePasses() {
+    AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
+                        "fuse_relu_depthwise_conv_pass");
+    AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
+                        "fuse_elewise_add_act_pass");
     // for single card training, fuse_all_reduce_ops is unnecessary.
     // coalesce_grad_tensor_pass should be before of MultiDevPass.
-    if (strategy_.fuse_all_reduce_ops_) {
-      VLOG(1) << "Add coalesce_grad_tensor_pass";
-      AppendPass("coalesce_grad_tensor_pass");
-    }
-
+    AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
+                        "coalesce_grad_tensor_pass");
     // Fuse all the optimization operators.
-    if (strategy_.is_distribution_) {
-      VLOG(3) << "Currently, fuse_all_optimizer_ops only works under "
-                 "Non-distributed mode.";
-      strategy_.fuse_all_optimizer_ops_ = false;
-    }
-    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
-        strategy_.is_distribution_) {
-      VLOG(3) << "Currently, fuse_all_optimizer_ops only works under AllReduce "
-                 "mode.";
-      strategy_.fuse_all_optimizer_ops_ = false;
-    }
+    // NOTE: fuse_all_xx_ops will count the number of xx operator first,
+    // if the number is zero, fuse_all_reduce_ops will do nothing.
+    // Currently, only one type of optimization algorithm can be fused.
     if (strategy_.fuse_all_optimizer_ops_) {
-      // NOTE: fuse_all_xx_ops will count the number of xx operator first,
-      // if the number is zero, fuse_all_reduce_ops will do nothing.
-      // Currently, only one type of optimization algorithm can be fused.
-      VLOG(1) << "Add fuse_adam_op_pass";
       AppendPass("fuse_adam_op_pass");
-      VLOG(1) << "Add fuse_sgd_op_pass";
       AppendPass("fuse_sgd_op_pass");
-      VLOG(1) << "Add fuse_momentum_op_pass";
       AppendPass("fuse_momentum_op_pass");
     }
+  }
-    // Add a graph viz pass to record a graph.
-    if (!strategy.debug_graphviz_path_.empty()) {
-      auto viz_pass = AppendPass("graph_viz_pass");
-      const std::string graph_path = string::Sprintf(
-          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph");
-      viz_pass->Set("graph_viz_path", new std::string(graph_path));
-    }
-
-    CollectiveContext *context = CollectiveContext::GetInstance();
-    context->endpoints_ = strategy_.trainers_endpoints_;
-    context->trainer_id_ = strategy_.trainer_id_;
-    PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
-    if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
-      PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
-                         strategy_.trainers_endpoints_.size(),
-                     "trainer_id_ < endpoints_ size");
+  void AppendMemoryOptimizePasses() {  // Append Memory Optimize Pass
+    // TODO(zjl): refactor MemoryOptimizePass to fit
+    // new strategy, which does not need to set
+    // var.persistable = True
+    if (strategy_.use_legacy_memory_optimize_strategy_) {
+      AppendPassWithCheck(strategy_.enable_inplace_, "inplace_pass");
     }
-    VLOG(1) << "CollectiveContext:" << context->String();
-
     // NOTE(dzh): memory optimize should be a runtime pass.
     // However, after multi_devices_pass, VarHandle, OpHandle is
     // the de-fact IR, any reuse on Graph is meaningless.
     // A side-effect of that, memory optimize cannot forsee the fetched vars
     // , so fetchlist should be set persistable before call the Run interface.
     if (strategy_.use_legacy_memory_optimize_strategy_) {
-      if (strategy_.memory_optimize_) {
-        VLOG(5) << "Add memory_optimize_pass";
-        AppendPass("memory_optimize_pass");
-      }
-    }
-
-    // runtime_context_cache pass should be the last pass to enable the attr of
-    // all original and fused operators. But no operators can be enabled this
-    // attr if putting it after MultiDevPass.
-    if (strategy_.cache_runtime_context_) {
-      VLOG(1) << "Add runtime_context_cache_pass";
-      AppendPass("runtime_context_cache_pass");
-    }
-
-    AppendMultiDevPass(strategy_);
-
-    if (strategy_.fuse_all_reduce_ops_) {
-      // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
-      // first, if the number is zero, fuse_all_reduce_ops will do nothing.
-      VLOG(1) << "Add fuse_all_reduce_op_pass";
-      AppendPass("fuse_all_reduce_op_pass");
-    }
-
-    // Add a graph print pass to record a graph with device info.
-    if (!strategy_.debug_graphviz_path_.empty()) {
-      VLOG(1) << "Add multi_devices_print_pass";
-      auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
-      const std::string graph_path =
-          string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
-                          "_multi_devices_graph");
-      multi_devices_print_pass->Set(ir::kGraphvizPath,
-                                    new std::string(graph_path));
-      multi_devices_print_pass->Set(
-          "graph_printer", new ir::GraphvizSSAGraphPrinter);
-    }
-
-    // experimental shows that the program will be faster if append
-    // all_reduce_deps_pass here.
-    if (!strategy_.enable_parallel_graph_ &&
-        (SeqOnlyAllReduceOps(strategy_) ||
-         strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
-      VLOG(1) << "Add all_reduce_deps_pass";
-      AppendPass("all_reduce_deps_pass");
-    }
-
-    if (strategy_.num_trainers_ > 1 && !strategy_.async_mode_ &&
-        !strategy_.is_distribution_ &&
-        strategy_.enable_backward_optimizer_op_deps_) {
-      VLOG(1) << "Add backward_op_deps_pass";
-      AppendPass("backward_optimizer_op_deps_pass");
+      AppendPassWithCheck(strategy_.memory_optimize_, "memory_optimize_pass");
     }
+  }
-    if (strategy_.remove_unnecessary_lock_) {
-      VLOG(1) << "Add modify_op_lock_and_record_event_pass";
-      AppendPass("modify_op_lock_and_record_event_pass");
+  void SetCollectiveContext() const {
+    CollectiveContext *context = CollectiveContext::GetInstance();
+    context->endpoints_ = strategy_.trainers_endpoints_;
+    context->trainer_id_ = strategy_.trainer_id_;
+    PADDLE_ENFORCE_GE(strategy_.trainer_id_, 0, "trainer_id_ >= 0");
+    if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
+      PADDLE_ENFORCE_LT(static_cast<size_t>(strategy_.trainer_id_),
+                        strategy_.trainers_endpoints_.size(),
+                        "trainer_id_ < endpoints_ size");
     }
-
-    // Verify that the graph is correct for multi-device executor.
-    VLOG(1) << "Add multi_devices_check_pass";
-    AppendPass("multi_devices_check_pass");
+    VLOG(1) << "CollectiveContext:" << context->String();
   }
 
   // Convert graph to run on multi-devices.
-  void AppendMultiDevPass(const BuildStrategy &strategy) {
+  void AppendMultiDevPass() {
     ir::Pass *multi_devices_pass = nullptr;
-
     if (strategy_.async_mode_) {
-      VLOG(1) << "Add async_multi_devices_pass";
       multi_devices_pass = AppendPass("async_multi_devices_pass").get();
     } else if (strategy_.is_distribution_) {
-      VLOG(1)
-          << "Add dist_multi_devices_pass, multi device parameter server mode";
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
     } else {
-      if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
-        VLOG(1) << "Add all_reduce_mode_multi_devices_pass";
-        multi_devices_pass =
-            AppendPass("all_reduce_mode_multi_devices_pass").get();
-      } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
-        VLOG(1) << "Add reduce_mode_multi_devices_pass";
-        multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
-      } else {
-        PADDLE_THROW("Unknown reduce strategy.");
+      switch (strategy_.reduce_) {
+        case BuildStrategy::ReduceStrategy::kAllReduce:
+          multi_devices_pass =
+              AppendPass("all_reduce_mode_multi_devices_pass").get();
+          break;
+        case BuildStrategy::ReduceStrategy::kReduce:
+          multi_devices_pass =
+              AppendPass("reduce_mode_multi_devices_pass").get();
+          break;
+        default:
+          PADDLE_THROW("Unknown reduce strategy.");
       }
     }
     multi_devices_pass->SetNotOwned("strategy", &strategy_);
   }
 
+  void AppendPrintGraphPass(const std::string &pass_name,
+                            const std::string &debug_file_suffix) {
+    if (!strategy_.debug_graphviz_path_.empty()) {
+      auto viz_pass = AppendPass(pass_name);
+      const std::string graph_path = string::Sprintf(
+          "%s%s", strategy_.debug_graphviz_path_.c_str(), debug_file_suffix);
+      viz_pass->Set(ir::kGraphvizPath,
+                    new std::string(graph_path));
+    }
+  }
+
+  void AppendPassWithCheck(bool append_pass, const std::string &pass_name) {
+    if (append_pass) {
+      AppendPass(pass_name);
+    }
+  }
+
+  void AppendPassToSetMkldnnAttr(const std::string &pass_name) {
+#ifdef PADDLE_WITH_MKLDNN
+    if (FLAGS_use_mkldnn) {
+      AppendPass(pass_name);
+    } else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
+      LOG(WARNING)
+          << "mkldnn_enabled_op_types specify the operator type list to "
+             "use MKLDNN acceleration. It is null in default, means "
+             "that all the operators supported by MKLDNN will be "
+             "accelerated. And it should not be set when "
+             "FLAGS_use_mkldnn=false.";
+    }
+#else
+    PADDLE_ENFORCE(!FLAGS_use_mkldnn,
+                   "Please compile with MKLDNN first to use MKLDNN");
+#endif
+  }
+
  private:
   BuildStrategy strategy_;
 };
@@ -307,26 +290,20 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
       pass->Erase(kNCCLCtxs);
       pass->SetNotOwned(kNCCLCtxs, nctx);
 #endif
-    } else if (pass->Type() == "coalesce_grad_tensor_pass" ||
-               pass->Type() == "fuse_adam_op_pass" ||
-               pass->Type() == "fuse_sgd_op_pass" ||
-               pass->Type() == "fuse_momentum_op_pass" ||
-               pass->Type() == "fuse_all_reduce_op_pass") {
+    } else if (pass->Type() == "fuse_all_reduce_op_pass") {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes, &local_scopes);
-      if (pass->Type() == "fuse_all_reduce_op_pass") {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-        platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
-        pass->Erase(kNCCLCtxs);
-        pass->SetNotOwned(kNCCLCtxs, nctx);
-        pass->Erase(kUseHierarchicalAllReduce);
-        pass->Set(kUseHierarchicalAllReduce,
-                  new bool(use_hierarchical_allreduce_));
+      platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
+      pass->Erase(kNCCLCtxs);
+      pass->SetNotOwned(kNCCLCtxs, nctx);
+      pass->Erase(kUseHierarchicalAllReduce);
+      pass->Set(kUseHierarchicalAllReduce,
+                new bool(use_hierarchical_allreduce_));
 #endif
-      }
     } else if (pass->Type() == "coalesce_grad_tensor_pass") {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index 547a9d72787507842cdc10d6dbc0cbcae4fe652b..8b767222324d525ee5b8f38e37c55fa2d653190d 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -88,7 +88,7 @@ struct BuildStrategy {
   bool fuse_elewise_add_act_ops_{false};
   // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
   // should not be sparse types
-  bool fuse_all_optimizer_ops_{false};
+  bool fuse_all_optimizer_ops_{true};
   bool fuse_all_reduce_ops_{false};
   // fuse_relu_depthwise_conv can fuse the `relu ->
   // depthwise_conv`
diff --git a/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc b/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
index 3eb4ef9fb3afd88599bb63d1ab507dc9a8c39093..8acfc5ecf04de73c4c95479a526e8060e7f844e3 100644
--- a/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
+++ b/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
@@ -483,6 +483,4 @@ class CoalesceGradTensorPass : public ir::Pass {
 }  // namespace paddle
 
 REGISTER_PASS(coalesce_grad_tensor_pass,
-              paddle::framework::ir::CoalesceGradTensorPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+              paddle::framework::ir::CoalesceGradTensorPass);
diff --git a/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc b/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
index 504ff04cfed267ad4fba795672b2809042fe52a3..88366238d312ba5bff8abb789654146bc575ad6a 100644
--- a/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
+++ b/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
@@ -204,6 +204,4 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass);
diff --git a/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc b/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
index 3ac92d176274461fd548b0f6b7b3e1c632cdaa76..b038bc92deffd697ca356f27992dc61ffa85b956 100644
--- a/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
+++ b/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
@@ -87,6 +87,4 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass);
diff --git a/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc b/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
index 077e393c105dadf0e87d64f520fe3a65b88c6972..3824ceec72b2b9fb4053fe52c8e34a7b8b02596b 100644
--- a/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
+++ b/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
@@ -65,6 +65,4 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass);
diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h b/paddle/fluid/framework/ir/graph_printer.h
similarity index 95%
rename from paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h
rename to paddle/fluid/framework/ir/graph_printer.h
index 8562856e3d5fc923d453c8c646269c3d7559b6ce..76b07f0d6530907e7b20253d6a2a744fd2e11362 100644
--- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h
+++ b/paddle/fluid/framework/ir/graph_printer.h
@@ -26,7 +26,7 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-constexpr char kGraphvizPath[] = "debug_graphviz_path";
+constexpr char kGraphvizPath[] = "graph_viz_path";
 
 class SSAGraphPrinter {
  public:
diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc
index f4df4cfeba66889f3bf547d989d27aa76587e6be..1da3c9fe69791c94de5d97bffb22145277f614df 100644
--- a/paddle/fluid/framework/ir/graph_viz_pass.cc
+++ b/paddle/fluid/framework/ir/graph_viz_pass.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 #include
 #include
 #include
+#include "paddle/fluid/framework/ir/graph_printer.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/inference/analysis/dot.h"
 #include "paddle/fluid/string/printf.h"
@@ -25,8 +26,6 @@ namespace framework {
 namespace ir {
 using inference::analysis::Dot;
 namespace {
-const char kGraphVizPath[] = "graph_viz_path";
-
 std::string FormatName(const Node* node) {
   if (!node->IsOp() || !node->Op() ||
       !node->Op()->HasAttr(OpProtoAndCheckerMaker::OpNamescopeAttrName())) {
@@ -39,7 +38,7 @@ std::string FormatName(const Node* node) {
 }  // namespace
 
 void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
-  const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
+  const std::string& graph_viz_path = Get<std::string>(kGraphvizPath);
   VLOG(3) << "draw IR graph viz to " << graph_viz_path;
   std::unique_ptr<std::ofstream> fout(new std::ofstream(graph_viz_path));
   PADDLE_ENFORCE(fout->good());
@@ -132,4 +131,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
 }  // namespace paddle
 
 REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
-    .RequirePassAttr(paddle::framework::ir::kGraphVizPath);
+    .RequirePassAttr(paddle::framework::ir::kGraphvizPath);
diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc
index a6c2b28215affcfb30f66452a633eea266088906..efd549e79d0ef2ff31a3d1253201f1c2656adf84 100644
--- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc
+++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h" #include #include #include #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/graph_printer.h" namespace paddle { namespace framework { @@ -29,7 +29,12 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass { std::unique_ptr fout( new std::ofstream(Get(kGraphvizPath))); PADDLE_ENFORCE(fout->good()); - Get("graph_printer").Print(*graph, *fout); + if (Has("graph_printer")) { + Get("graph_printer").Print(*graph, *fout); + } else { + GraphvizSSAGraphPrinter printer; + printer.Print(*graph, *fout); + } } }; diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 5b5dee92d619af09078c78dccb92f15a71f84017..b4cfda919ce346c60ef9f4e24de705b51488e4dd 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -24,6 +24,7 @@ namespace framework { namespace ir { Graph* Pass::Apply(Graph* graph) const { + CheckPrevPass(); PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty."); for (const std::string& attr : required_pass_attrs_) { PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), @@ -41,6 +42,10 @@ Graph* Pass::Apply(Graph* graph) const { PADDLE_ENFORCE(VarDescIsConsistency(*graph), "The VarDescs of persistable variable are not consistency."); applied_ = true; + if (!graph->Has(kPassRecorder)) { + graph->Set(kPassRecorder, new PassRecorder); + } + graph->Get(kPassRecorder).insert(Type()); return graph; } diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 6cbe9a8212775512431860591526b52665ec4037..cf6b8d1338e20a67d332c2ddec562f662d8ff0a9 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include #include +#include #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/program_desc.h" @@ -31,6 +32,9 @@ namespace ir { template struct PassRegistrar; +typedef std::unordered_set PassRecorder; +constexpr char kPassRecorder[] = "pass_recorder"; + class Pass { public: Pass() = default; @@ -104,6 +108,10 @@ class Pass { LOG(FATAL) << "Calling virtual Pass not implemented."; } + // Some Pass must be placed before this Pass, and some + // Pass must be placed after this Pass. + virtual void CheckPrevPass() const {} + private: template friend struct PassRegistrar; diff --git a/paddle/fluid/framework/ir/pass_builder.cc b/paddle/fluid/framework/ir/pass_builder.cc index e0719867b34d13666672b22070ce14dbaf80d85d..457de41c8f6a84cf81798c71b2366fb1d989b9de 100644 --- a/paddle/fluid/framework/ir/pass_builder.cc +++ b/paddle/fluid/framework/ir/pass_builder.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
 limitations under the License. */
 
 #include "paddle/fluid/framework/ir/pass_builder.h"
+#include
+#include
 
 namespace paddle {
 namespace framework {
 namespace ir {
 
 std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
+  VLOG(3) << "Append " << pass_type;
   auto pass = ir::PassRegistry::Instance().Get(pass_type);
   passes_.emplace_back(pass.release());
   return passes_.back();
diff --git a/paddle/fluid/framework/ir/sync_batch_norm_pass.cc b/paddle/fluid/framework/ir/sync_batch_norm_pass.cc
index 25207ffc1e9f540131f5b7c0336d308831aec19f..2077304b9693b41448720a72cd47804b1fe2d60d 100644
--- a/paddle/fluid/framework/ir/sync_batch_norm_pass.cc
+++ b/paddle/fluid/framework/ir/sync_batch_norm_pass.cc
@@ -26,7 +26,7 @@ class SyncBatchNormPass : public Pass {
   void ApplyImpl(ir::Graph *graph) const override {
     VLOG(3) << "Use synchronous batch norm";
     for (const Node *n : graph->Nodes()) {
-      if (n->IsOp()) {
+      if (n->IsOp() && n->Op()) {
         auto *op = n->Op();
         if (op->Type() == "batch_norm") {
           op->SetType("sync_batch_norm");
diff --git a/python/paddle/fluid/tests/unittests/test_buffer_shared_memory_reuse_pass.py b/python/paddle/fluid/tests/unittests/test_buffer_shared_memory_reuse_pass.py
index e5d1bdcf1aa41c5c99becfec73fa823ac5812d62..730dbe0aa4caf876567d8a6102096c42ce289950 100644
--- a/python/paddle/fluid/tests/unittests/test_buffer_shared_memory_reuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/test_buffer_shared_memory_reuse_pass.py
@@ -32,6 +32,7 @@ feed_dict = {
 class InplaceTestBase(unittest.TestCase):
     def initParameter(self):
         self.use_cuda = True
+        self.fuse_all_optimizer_ops = False
 
     def setUp(self):
         self.initParameter()
@@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase):
             self.device_count = fluid.core.get_cuda_device_count()
         else:
             self.device_count = 4
-        assert batch_size % self.device_count == 0
 
     def build_program_and_scope(self):
@@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase):
                 build_strategy = fluid.BuildStrategy()
                 build_strategy.memory_optimize = memory_optimize
                 build_strategy.enable_inplace = enable_inplace
+                build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
                 compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
                     loss_name=loss.name,
                     build_strategy=build_strategy,
@@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase):
                 build_strategy = fluid.BuildStrategy()
                 build_strategy.memory_optimize = memory_optimize
                 build_strategy.enable_inplace = enable_inplace
+                build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
                 compiled_program = fluid.CompiledProgram(
                     prog).with_data_parallel(
                         loss_name=loss.name,
@@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase):
 class CPUInplaceTest(InplaceTestBase):
     def initParameter(self):
         self.use_cuda = False
+        self.fuse_all_optimizer_ops = False
+
+
+class CUDAInplaceTestWithFuseOptimizationOps(InplaceTestBase):
+    def initParameter(self):
+        self.use_cuda = True
+        self.fuse_all_optimizer_ops = True
+
+
+class CPUInplaceTestWithFuseOptimizationOps(InplaceTestBase):
+    def initParameter(self):
+        self.use_cuda = False
+        self.fuse_all_optimizer_ops = True
 
 
 if __name__ == '__main__':
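
For reference, the Python-side switch that the updated unit test exercises can be driven as below. This is a minimal sketch, assuming a user-built `prog` and `loss` (for example, from something like InplaceTestBase.build_program_and_scope) and the fluid 1.x API shown in the test; the `places` argument is omitted.

    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    # Gates fuse_adam_op_pass / fuse_sgd_op_pass / fuse_momentum_op_pass on the
    # C++ side; True mirrors the new default of fuse_all_optimizer_ops_.
    build_strategy.fuse_all_optimizer_ops = True

    # `prog` and `loss` are assumed to exist; they are not defined in this sketch.
    compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
        loss_name=loss.name, build_strategy=build_strategy)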