diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 48f94a1f05614d4b797562ac67cdb9828fd0456e..132725fa7e89a9218b64470ea5fded62e8f26c9e 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -79,9 +79,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { BuildStrategy strategy_; }; -std::shared_ptr BuildStrategy::CreatePassesFromStrategy() - const { +std::shared_ptr BuildStrategy::CreatePassesFromStrategy( + bool from_user) const { + if (finalized_by_user_) { + return pass_builder_; + } pass_builder_.reset(new ParallelExecutorPassBuilder(*this)); + if (from_user) { + finalized_by_user_ = true; + } return pass_builder_; } @@ -95,10 +101,8 @@ std::unique_ptr BuildStrategy::Apply( #else const bool use_cuda) const { #endif - // Create a default one if not initialized by user. - if (!pass_builder_) { - CreatePassesFromStrategy(); - } + // Create a default one if not finalized by user. + CreatePassesFromStrategy(false); std::unique_ptr graph(new ir::Graph(main_program)); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 6c7b54db8f610aa34cd51dcbc13063290cae3ac0..e9deebd504eeb6632b32045ea16918a83f7c80f4 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -80,7 +80,8 @@ struct BuildStrategy { // from python side. // A new PassBuilder is created based on configs defined above and // passes are owned by the PassBuilder. - std::shared_ptr CreatePassesFromStrategy() const; + std::shared_ptr CreatePassesFromStrategy( + bool from_user) const; // Apply the passes built by the pass_builder_. The passes will be // applied to the Program and output an ir::Graph. @@ -97,6 +98,7 @@ struct BuildStrategy { #endif private: + mutable bool finalized_by_user_ = false; mutable std::shared_ptr pass_builder_; }; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 238cc19189cfd74afa38bdcb5f5c802f9521dfea..b7776df9042364e5006b5f497eb04fea9132f860 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -855,10 +855,13 @@ All parameter, weight, gradient are variables in Paddle. R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether to fuse elementwise_add_op and activation_op, it may make the execution faster. Default False)DOC") - .def("_create_passes_from_strategy", + .def("_finalize_strategy_and_create_passes", [](BuildStrategy &self) -> std::shared_ptr { - return self.CreatePassesFromStrategy(); - }); + return self.CreatePassesFromStrategy(true); + }, + R"DOC(Allow user to customized passes. Normally model-specific + optimization passes should be defined in this way. BuildStrategy + cannot be updated after being finalized.)DOC"); pe.def(py::init &, const std::unordered_set &, diff --git a/python/paddle/fluid/tests/unittests/test_pass_builder.py b/python/paddle/fluid/tests/unittests/test_pass_builder.py index 288c5f6a1f6b1760ca40c0c653e4c0726b799519..65ad63dc013dde2f2e4503890f6b986e91d03690 100644 --- a/python/paddle/fluid/tests/unittests/test_pass_builder.py +++ b/python/paddle/fluid/tests/unittests/test_pass_builder.py @@ -94,7 +94,7 @@ class TestPassBuilder(unittest.TestCase): def test_parallel_testing_with_new_strategy(self): build_strategy = fluid.BuildStrategy() - pass_builder = build_strategy._create_passes_from_strategy() + pass_builder = build_strategy._finalize_strategy_and_create_passes() origin_len = len(pass_builder.all_passes()) viz_pass = pass_builder.append_pass("graph_viz_pass")