From b43e49fa3152077e9e487c95cedbce7b4aa1119c Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Mon, 17 Sep 2018 15:49:13 +0800
Subject: [PATCH] fix

---
 paddle/fluid/API.spec                         |  2 +-
 .../fluid/framework/details/build_strategy.cc | 74 +++++++------
 .../fluid/framework/details/build_strategy.h  |  6 +-
 paddle/fluid/framework/ir/pass.h              | 17 +++++
 paddle/fluid/framework/ir/pass_test.cc        | 10 +--
 paddle/fluid/pybind/pybind.cc                 | 10 ++-
 .../tests/unittests/test_pass_builder.py      | 19 ++++-
 7 files changed, 70 insertions(+), 68 deletions(-)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 8dcc1358b85..7cda403f7fe 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -63,7 +63,7 @@ paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core.Executi
 paddle.fluid.BuildStrategy.GradientScaleStrategy.__init__ __init__(self: paddle.fluid.core.GradientScaleStrategy, arg0: int) -> None
 paddle.fluid.BuildStrategy.ReduceStrategy.__init__ __init__(self: paddle.fluid.core.ReduceStrategy, arg0: int) -> None
 paddle.fluid.BuildStrategy.__init__ __init__(self: paddle.fluid.core.BuildStrategy) -> None
-paddle.fluid.BuildStrategy.create_pass_builder create_pass_builder(self: paddle.fluid.core.BuildStrategy) -> paddle.fluid.core.PassBuilder
+paddle.fluid.BuildStrategy.create_passes_from_strategy create_passes_from_strategy(self: paddle.fluid.core.BuildStrategy) -> paddle.fluid.core.PassBuilder
 paddle.fluid.create_lod_tensor ArgSpec(args=['data', 'recursive_seq_lens', 'place'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.create_random_int_lodtensor ArgSpec(args=['recursive_seq_lens', 'base_shape', 'place', 'low', 'high'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.io.save_vars ArgSpec(args=['executor', 'dirname', 'main_program', 'vars', 'predicate', 'filename'], varargs=None, keywords=None, defaults=(None, None, None, None))
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 2a3bc85ff79..deeb18656b4 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -14,9 +14,6 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/details/build_strategy.h"
 
-#include <memory>
-#include <utility>
-
 #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -71,46 +68,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     AppendPass("multi_devices_check_pass");
   }
 
-  std::unique_ptr<ir::Graph> Build(
-      const ProgramDesc &main_program,
-      const std::vector<platform::Place> &places,
-      const std::string &loss_var_name,
-      const std::unordered_set<std::string> &param_names,
-      const std::vector<Scope *> &local_scopes,
-#ifdef PADDLE_WITH_CUDA
-      const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
-#else
-      const bool use_cuda) const {
-#endif
-    // Convert the program to graph.
-    std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
-
-    for (std::shared_ptr<ir::Pass> &pass : AllPasses()) {
-      if (pass->Type() == "multi_devices_pass") {
-        pass->SetNotOwned<const std::vector<platform::Place>>("places",
-                                                              &places);
-        pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
-        pass->SetNotOwned<const std::unordered_set<std::string>>("params",
-                                                                 &param_names);
-        pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
-                                                      &local_scopes);
-#ifdef PADDLE_WITH_CUDA
-        platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
-        pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
-#endif
-      }
-      graph = pass->Apply(std::move(graph));
-    }
-    return graph;
-  }
-
  private:
   BuildStrategy strategy_;
 };
 
-ir::PassBuilder *BuildStrategy::CreatePassBuilder() const {
+std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy()
+    const {
   pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
-  return pass_builder_.get();
+  return pass_builder_;
 }
 
 std::unique_ptr<ir::Graph> BuildStrategy::Apply(
@@ -123,20 +88,33 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
 #else
     const bool use_cuda) const {
 #endif
+  // Create a default one if not initialized by the user.
   if (!pass_builder_) {
-    CreatePassBuilder();
+    CreatePassesFromStrategy();
   }
-  // std::unique_ptr<ir::Graph> graph;
-  ParallelExecutorPassBuilder *builder =
-      reinterpret_cast<ParallelExecutorPassBuilder *>(pass_builder_.get());
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
+
+  for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
+    if (pass->Type() == "multi_devices_pass") {
+      pass->Erase("places");
+      pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
+      pass->Erase("loss_var_name");
+      pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
+      pass->Erase("params");
+      pass->SetNotOwned<const std::unordered_set<std::string>>("params",
+                                                               &param_names);
+      pass->Erase("local_scopes");
+      pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
+                                                    &local_scopes);
 #ifdef PADDLE_WITH_CUDA
-  std::unique_ptr<ir::Graph> graph =
-      builder->Build(main_program, places, loss_var_name, param_names,
-                     local_scopes, use_cuda, nccl_ctxs);
-#else
-  std::unique_ptr<ir::Graph> graph = builder->Build(
-      main_program, places, loss_var_name, param_names, local_scopes, use_cuda);
+      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
+      pass->Erase("nccl_ctxs");
+      pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
+    }
+    graph = pass->Apply(std::move(graph));
+  }
   return graph;
 }
 }  // namespace details
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index 4468708d09f..f75a1913b78 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -31,9 +31,6 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-class ParallelExecutorPassBuilder;
-struct BuildStrategy;
-
 struct BuildStrategy {
   // ParallelExecutor supports two modes of ReduceStrategy, kAllReduce and
   // kReduce, for CPU and GPU. If you use kAllReduce, different threads
@@ -72,7 +69,7 @@ struct BuildStrategy {
 
   bool enable_data_balance_{false};
 
-  ir::PassBuilder *CreatePassBuilder() const;
+  std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const;
 
   std::unique_ptr<ir::Graph> Apply(
       const ProgramDesc &main_program,
@@ -87,7 +84,6 @@ struct BuildStrategy {
 #endif
 
  private:
-  // TODO(panyx0718): This should probably be unique_ptr.
   mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
 };
 
diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h
index 042a7461b42..9570c59cff2 100644
--- a/paddle/fluid/framework/ir/pass.h
+++ b/paddle/fluid/framework/ir/pass.h
@@ -54,6 +54,21 @@ class Pass {
     return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
   }
 
+  bool Has(const std::string &attr_name) const {
+    return attrs_.find(attr_name) != attrs_.end();
+  }
+
+  void Erase(const std::string &attr_name) {
+    if (!Has(attr_name)) {
+      return;
+    }
+    if (attr_dels_.find(attr_name) != attr_dels_.end()) {
+      attr_dels_[attr_name]();
+      attr_dels_.erase(attr_name);
+    }
+    attrs_.erase(attr_name);
+  }
+
   // Set a pointer to the attribute. Pass takes ownership of the attribute.
   template <typename AttrType>
   void Set(const std::string &attr_name, AttrType *attr) {
@@ -70,6 +85,8 @@ class Pass {
   // should delete the attribute.
   template <typename AttrType>
   void SetNotOwned(const std::string &attr_name, AttrType *attr) {
+    PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
+                   attr_name);
     attrs_[attr_name] = attr;
   }
diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc
index 5b5011412ed..6ad7d1df8bd 100644
--- a/paddle/fluid/framework/ir/pass_test.cc
+++ b/paddle/fluid/framework/ir/pass_test.cc
@@ -82,12 +82,10 @@ TEST(PassTest, TestPassAttrCheck) {
   ASSERT_EQ(graph->Get<int>("copy_test_pass_attr"), 2);
   ASSERT_EQ(graph->Get<int>("copy_test_graph_attr"), 2);
 
-  try {
-    graph = pass->Apply(std::move(graph));
-  } catch (paddle::platform::EnforceNotMet e) {
-    exception = std::string(e.what());
-  }
-  ASSERT_TRUE(exception.find("Pass can only Apply() once") != exception.npos);
+  // Allow applying the same pass more than once.
+  graph.reset(new Graph(prog));
+  graph->Set<int>("test_graph_attr", new int);
+  graph = pass->Apply(std::move(graph));
 
   pass = PassRegistry::Instance().Get("test_pass");
   pass->SetNotOwned<int>("test_pass_attr", &val);
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index c14b893fa40..f4ccadccca0 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -603,7 +603,8 @@ All parameter, weight, gradient are variables in Paddle.
              self.Set<std::string>(name, new std::string(attr));
            });
 
-  py::class_<ir::PassBuilder> pb(m, "PassBuilder");
+  py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(
+      m, "PassBuilder");
   pb.def(py::init<>())
       .def("append_pass",
            [](ir::PassBuilder &self,
@@ -701,9 +702,10 @@ All parameter, weight, gradient are variables in Paddle.
            [](BuildStrategy &self, bool b) {
              self.fuse_elewise_add_act_ops_ = b;
            })
-      .def("create_pass_builder",
-           [](BuildStrategy &self) { return *self.CreatePassBuilder(); },
-           py::return_value_policy::reference);
+      .def("create_passes_from_strategy",
+           [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
+             return self.CreatePassesFromStrategy();
+           });
 
   pe.def(py::init<const std::vector<platform::Place> &,
                   const std::unordered_set<std::string> &,
diff --git a/python/paddle/fluid/tests/unittests/test_pass_builder.py b/python/paddle/fluid/tests/unittests/test_pass_builder.py
index 2da4c097d92..0abd6fe494e 100644
--- a/python/paddle/fluid/tests/unittests/test_pass_builder.py
+++ b/python/paddle/fluid/tests/unittests/test_pass_builder.py
@@ -94,16 +94,27 @@ class TestPassBuilder(unittest.TestCase):
 
     def test_parallel_testing_with_new_strategy(self):
         build_strategy = fluid.BuildStrategy()
-        pass_builder = build_strategy.create_pass_builder()
+        pass_builder = build_strategy.create_passes_from_strategy()
+        origin_len = len(pass_builder.all_passes())
+
         viz_pass = pass_builder.append_pass("graph_viz_pass")
-        all_passes = pass_builder.all_passes()
-        pass_builder.insert_pass(len(all_passes), "graph_viz_pass")
+        self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
+
+        pass_builder.insert_pass(
+            len(pass_builder.all_passes()), "graph_viz_pass")
+        self.assertEqual(origin_len + 2, len(pass_builder.all_passes()))
+
         pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
-        viz_pass.set_str("graph_viz_path", "/tmp/viz_pass")
+        self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
+        viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass")
 
         self.check_network_convergence(
             use_cuda=core.is_compiled_with_cuda(),
             build_strategy=build_strategy)
+        try:
+            os.stat("/tmp/test_viz_pass")
+        except os.error:
+            self.assertFalse(True)
 
 
 if __name__ == '__main__':
--
GitLab
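Note (not part of the patch): a minimal sketch of the renamed Python API, mirroring
what test_pass_builder.py above exercises. It assumes a Paddle build that includes
this change; the pass name and the "graph_viz_path" attribute are taken directly
from the test.

    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    # Materialize the passes implied by the strategy into an editable builder.
    pass_builder = build_strategy.create_passes_from_strategy()

    # Append a debugging pass and point its output at an inspectable location.
    viz_pass = pass_builder.append_pass("graph_viz_pass")
    viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass")

    # Passes can also be inserted at, or removed from, a given position.
    pass_builder.insert_pass(len(pass_builder.all_passes()), "graph_viz_pass")
    pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)

Because CreatePassesFromStrategy() now hands out the shared pass_builder_ that
BuildStrategy::Apply() itself iterates, edits made through the returned builder
take effect the next time the strategy is used, e.g. when a ParallelExecutor is
constructed with this build_strategy.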