From 99dffb91d668d70b7c110f76de70d9666c5dc7d4 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 8 Nov 2018 20:20:33 +0800 Subject: [PATCH] allow to repeatedly share and update BuildStrategy test=develop --- paddle/fluid/framework/details/build_strategy.cc | 16 ++++++++++------ paddle/fluid/framework/details/build_strategy.h | 4 +++- paddle/fluid/pybind/pybind.cc | 9 ++++++--- .../fluid/tests/unittests/test_pass_builder.py | 2 +- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 48f94a1f056..132725fa7e8 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -79,9 +79,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { BuildStrategy strategy_; }; -std::shared_ptr BuildStrategy::CreatePassesFromStrategy() - const { +std::shared_ptr BuildStrategy::CreatePassesFromStrategy( + bool from_user) const { + if (finalized_by_user_) { + return pass_builder_; + } pass_builder_.reset(new ParallelExecutorPassBuilder(*this)); + if (from_user) { + finalized_by_user_ = true; + } return pass_builder_; } @@ -95,10 +101,8 @@ std::unique_ptr BuildStrategy::Apply( #else const bool use_cuda) const { #endif - // Create a default one if not initialized by user. - if (!pass_builder_) { - CreatePassesFromStrategy(); - } + // Create a default one if not finalized by user. + CreatePassesFromStrategy(false); std::unique_ptr graph(new ir::Graph(main_program)); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 6c7b54db8f6..e9deebd504e 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -80,7 +80,8 @@ struct BuildStrategy { // from python side. // A new PassBuilder is created based on configs defined above and // passes are owned by the PassBuilder. - std::shared_ptr CreatePassesFromStrategy() const; + std::shared_ptr CreatePassesFromStrategy( + bool from_user) const; // Apply the passes built by the pass_builder_. The passes will be // applied to the Program and output an ir::Graph. @@ -97,6 +98,7 @@ struct BuildStrategy { #endif private: + mutable bool finalized_by_user_ = false; mutable std::shared_ptr pass_builder_; }; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 238cc19189c..b7776df9042 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -855,10 +855,13 @@ All parameter, weight, gradient are variables in Paddle. R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether to fuse elementwise_add_op and activation_op, it may make the execution faster. Default False)DOC") - .def("_create_passes_from_strategy", + .def("_finalize_strategy_and_create_passes", [](BuildStrategy &self) -> std::shared_ptr { - return self.CreatePassesFromStrategy(); - }); + return self.CreatePassesFromStrategy(true); + }, + R"DOC(Allow user to customized passes. Normally model-specific + optimization passes should be defined in this way. BuildStrategy + cannot be updated after being finalized.)DOC"); pe.def(py::init &, const std::unordered_set &, diff --git a/python/paddle/fluid/tests/unittests/test_pass_builder.py b/python/paddle/fluid/tests/unittests/test_pass_builder.py index 288c5f6a1f6..65ad63dc013 100644 --- a/python/paddle/fluid/tests/unittests/test_pass_builder.py +++ b/python/paddle/fluid/tests/unittests/test_pass_builder.py @@ -94,7 +94,7 @@ class TestPassBuilder(unittest.TestCase): def test_parallel_testing_with_new_strategy(self): build_strategy = fluid.BuildStrategy() - pass_builder = build_strategy._create_passes_from_strategy() + pass_builder = build_strategy._finalize_strategy_and_create_passes() origin_len = len(pass_builder.all_passes()) viz_pass = pass_builder.append_pass("graph_viz_pass") -- GitLab