提交 99dffb91 编写于 作者: X Xin Pan

allow to repeatedly share and update BuildStrategy

test=develop
上级 df826de7
......@@ -79,9 +79,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
BuildStrategy strategy_;
};
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy()
const {
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
bool from_user) const {
if (finalized_by_user_) {
return pass_builder_;
}
pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
if (from_user) {
finalized_by_user_ = true;
}
return pass_builder_;
}
......@@ -95,10 +101,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
#else
const bool use_cuda) const {
#endif
// Create a default one if not initialized by user.
if (!pass_builder_) {
CreatePassesFromStrategy();
}
// Create a default one if not finalized by user.
CreatePassesFromStrategy(false);
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
......
......@@ -80,7 +80,8 @@ struct BuildStrategy {
// from python side.
// A new PassBuilder is created based on configs defined above and
// passes are owned by the PassBuilder.
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const;
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy(
bool from_user) const;
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
......@@ -97,6 +98,7 @@ struct BuildStrategy {
#endif
private:
mutable bool finalized_by_user_ = false;
mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
};
......
......@@ -855,10 +855,13 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False)DOC")
.def("_create_passes_from_strategy",
.def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy();
});
return self.CreatePassesFromStrategy(true);
},
R"DOC(Allow user to customized passes. Normally model-specific
optimization passes should be defined in this way. BuildStrategy
cannot be updated after being finalized.)DOC");
pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &,
......
......@@ -94,7 +94,7 @@ class TestPassBuilder(unittest.TestCase):
def test_parallel_testing_with_new_strategy(self):
build_strategy = fluid.BuildStrategy()
pass_builder = build_strategy._create_passes_from_strategy()
pass_builder = build_strategy._finalize_strategy_and_create_passes()
origin_len = len(pass_builder.all_passes())
viz_pass = pass_builder.append_pass("graph_viz_pass")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册