提交 99dffb91 编写于 作者: X Xin Pan

allow to repeatedly share and update BuildStrategy

test=develop
上级 df826de7
...@@ -79,9 +79,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -79,9 +79,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
BuildStrategy strategy_; BuildStrategy strategy_;
}; };
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy() std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
const { bool from_user) const {
if (finalized_by_user_) {
return pass_builder_;
}
pass_builder_.reset(new ParallelExecutorPassBuilder(*this)); pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
if (from_user) {
finalized_by_user_ = true;
}
return pass_builder_; return pass_builder_;
} }
...@@ -95,10 +101,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -95,10 +101,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
#else #else
const bool use_cuda) const { const bool use_cuda) const {
#endif #endif
// Create a default one if not initialized by user. // Create a default one if not finalized by user.
if (!pass_builder_) { CreatePassesFromStrategy(false);
CreatePassesFromStrategy();
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program)); std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
......
...@@ -80,7 +80,8 @@ struct BuildStrategy { ...@@ -80,7 +80,8 @@ struct BuildStrategy {
// from python side. // from python side.
// A new PassBuilder is created based on configs defined above and // A new PassBuilder is created based on configs defined above and
// passes are owned by the PassBuilder. // passes are owned by the PassBuilder.
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const; std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy(
bool from_user) const;
// Apply the passes built by the pass_builder_. The passes will be // Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph. // applied to the Program and output an ir::Graph.
...@@ -97,6 +98,7 @@ struct BuildStrategy { ...@@ -97,6 +98,7 @@ struct BuildStrategy {
#endif #endif
private: private:
mutable bool finalized_by_user_ = false;
mutable std::shared_ptr<ir::PassBuilder> pass_builder_; mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
}; };
......
...@@ -855,10 +855,13 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -855,10 +855,13 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op, to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False)DOC") it may make the execution faster. Default False)DOC")
.def("_create_passes_from_strategy", .def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> { [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy(); return self.CreatePassesFromStrategy(true);
}); },
R"DOC(Allow user to customized passes. Normally model-specific
optimization passes should be defined in this way. BuildStrategy
cannot be updated after being finalized.)DOC");
pe.def(py::init<const std::vector<platform::Place> &, pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &, const std::unordered_set<std::string> &,
......
...@@ -94,7 +94,7 @@ class TestPassBuilder(unittest.TestCase): ...@@ -94,7 +94,7 @@ class TestPassBuilder(unittest.TestCase):
def test_parallel_testing_with_new_strategy(self): def test_parallel_testing_with_new_strategy(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
pass_builder = build_strategy._create_passes_from_strategy() pass_builder = build_strategy._finalize_strategy_and_create_passes()
origin_len = len(pass_builder.all_passes()) origin_len = len(pass_builder.all_passes())
viz_pass = pass_builder.append_pass("graph_viz_pass") viz_pass = pass_builder.append_pass("graph_viz_pass")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册