diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 51ce9732722efa44d2489f5b77694094e58c8775..ca9843057d628635d9a50edf8deba658f3574c3f 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -177,11 +177,13 @@ std::unique_ptr BuildStrategy::Apply( #else const bool use_cuda) const { #endif + VLOG(3) << "apply all passes"; // Create a default one if not finalized by user. CreatePassesFromStrategy(false); std::unique_ptr graph(new ir::Graph(main_program)); for (std::shared_ptr &pass : pass_builder_->AllPasses()) { + VLOG(3) << "apply " << pass->Type(); if (IsMultiDevPass(pass->Type())) { pass->Erase(kPlaces); pass->SetNotOwned>(kPlaces, &places); diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 33ccee6aa0a94b8fd8308214d6144ae832d40bab..823697495edf32c9f2d6339f44f84be551f6474c 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -19,6 +19,7 @@ namespace paddle { namespace framework { namespace ir { std::unique_ptr Pass::Apply(std::unique_ptr graph) const { + VLOG(3) << "apply pass -> " << Type(); PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty."); for (const std::string& attr : required_pass_attrs_) { PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 758c510dc757bd4106d6258a6c2de021bb4aefbc..98e6923c1115142d3a9c90073898cd03db716398 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -128,8 +128,7 @@ class TestDistRunnerBase(object): if args.batch_merge_repeat > 1: pass_builder = build_stra._finalize_strategy_and_create_passes() - mypass = pass_builder.insert_pass( - len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass") + mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass") mypass.set("num_repeats", args.batch_merge_repeat) if args.update_method == "nccl2":