From 5cf0092825a9625018e8856931cbdb8ff15b71a5 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 7 Feb 2019 14:19:21 +0800 Subject: [PATCH] add more log and fix test_dist_base in multi_batch_merge_pass --- paddle/fluid/framework/details/build_strategy.cc | 2 ++ paddle/fluid/framework/ir/pass.cc | 1 + python/paddle/fluid/tests/unittests/test_dist_base.py | 3 +-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 51ce9732722..ca9843057d6 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -177,11 +177,13 @@ std::unique_ptr BuildStrategy::Apply( #else const bool use_cuda) const { #endif + VLOG(3) << "apply all passes"; // Create a default one if not finalized by user. CreatePassesFromStrategy(false); std::unique_ptr graph(new ir::Graph(main_program)); for (std::shared_ptr &pass : pass_builder_->AllPasses()) { + VLOG(3) << "apply " << pass->Type(); if (IsMultiDevPass(pass->Type())) { pass->Erase(kPlaces); pass->SetNotOwned>(kPlaces, &places); diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 33ccee6aa0a..823697495ed 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -19,6 +19,7 @@ namespace paddle { namespace framework { namespace ir { std::unique_ptr Pass::Apply(std::unique_ptr graph) const { + VLOG(3) << "apply pass -> " << Type(); PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty."); for (const std::string& attr : required_pass_attrs_) { PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 758c510dc75..98e6923c111 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -128,8 +128,7 @@ class TestDistRunnerBase(object): if args.batch_merge_repeat > 1: pass_builder = build_stra._finalize_strategy_and_create_passes() - mypass = pass_builder.insert_pass( - len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass") + mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass") mypass.set("num_repeats", args.batch_merge_repeat) if args.update_method == "nccl2": -- GitLab