未验证 提交 8cfda7ee 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14382 from panyx0718/fix4

Refine the pass builder and buildstrategy
......@@ -79,9 +79,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
BuildStrategy strategy_;
};
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy()
const {
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
bool finalize_strategy) const {
if (is_finalized_) {
return pass_builder_;
}
pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
if (finalize_strategy) {
is_finalized_ = true;
}
return pass_builder_;
}
......@@ -95,10 +101,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
#else
const bool use_cuda) const {
#endif
// Create a default one if not initialized by user.
if (!pass_builder_) {
CreatePassesFromStrategy();
}
// Create a default one if not finalized by user.
CreatePassesFromStrategy(false);
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
......
......@@ -75,12 +75,20 @@ struct BuildStrategy {
bool remove_unnecessary_lock_{false};
// NOTE:
// Before you add new options, think if it's a general strategy that works
// with other strategy. If not, the strategy should be created through
// CreatePassesFromStrategy and the pass can be managed separately.
// User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes
// from python side.
// A new PassBuilder is created based on configs defined above and
// passes are owned by the PassBuilder.
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const;
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy(
bool finalize_strategy) const;
bool IsFinalized() const { return is_finalized_; }
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
......@@ -97,6 +105,7 @@ struct BuildStrategy {
#endif
private:
mutable bool is_finalized_ = false;
mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
};
......
......@@ -650,9 +650,9 @@ All parameter, weight, gradient are variables in Paddle.
[](ir::Pass &self, const std::string &name, const std::string &attr) {
self.Set<std::string>(name, new std::string(attr));
})
.def("set_int", [](ir::Pass &self, const std::string &name, int val) {
self.Set<const int>(name, new int(val));
});
.def("set_int", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("type", &ir::Pass::Type);
py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(
m, "PassBuilder");
......@@ -791,6 +791,7 @@ All parameter, weight, gradient are variables in Paddle.
"reduce_strategy",
[](const BuildStrategy &self) { return self.reduce_; },
[](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.reduce_ = strategy;
},
R"DOC(The type is STR, there are two reduce strategies in ParallelExecutor,
......@@ -804,6 +805,7 @@ All parameter, weight, gradient are variables in Paddle.
[](const BuildStrategy &self) { return self.gradient_scale_; },
[](BuildStrategy &self,
BuildStrategy::GradientScaleStrategy strategy) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.gradient_scale_ = strategy;
},
R"DOC(The type is STR, there are three ways of defining :math:`loss@grad` in
......@@ -815,6 +817,7 @@ All parameter, weight, gradient are variables in Paddle.
"debug_graphviz_path",
[](const BuildStrategy &self) { return self.debug_graphviz_path_; },
[](BuildStrategy &self, const std::string &path) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.debug_graphviz_path_ = path;
},
R"DOC(The type is STR, debug_graphviz_path indicate the path that
......@@ -824,6 +827,7 @@ All parameter, weight, gradient are variables in Paddle.
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property(
......@@ -832,6 +836,7 @@ All parameter, weight, gradient are variables in Paddle.
return self.enable_sequential_execution_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.enable_sequential_execution_ = b;
},
R"DOC(The type is BOOL. If set True, the execution order of ops would be the same as what is in the program. Default False.)DOC")
......@@ -841,6 +846,7 @@ All parameter, weight, gradient are variables in Paddle.
return self.remove_unnecessary_lock_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.remove_unnecessary_lock_ = b;
},
R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC")
......@@ -850,15 +856,19 @@ All parameter, weight, gradient are variables in Paddle.
return self.fuse_elewise_add_act_ops_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.fuse_elewise_add_act_ops_ = b;
},
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False)DOC")
.def("_create_passes_from_strategy",
.def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy();
});
return self.CreatePassesFromStrategy(true);
},
R"DOC(Allow user to customized passes. Normally model-specific
optimization passes should be defined in this way. BuildStrategy
cannot be updated after being finalized.)DOC");
pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &,
......
......@@ -105,7 +105,7 @@ class TestDistRunnerBase(object):
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
if args.batch_merge_repeat > 1:
pass_builder = build_stra._create_passes_from_strategy()
pass_builder = build_stra._finalize_strategy_and_create_passes()
mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat)
......
......@@ -94,7 +94,12 @@ class TestPassBuilder(unittest.TestCase):
def test_parallel_testing_with_new_strategy(self):
build_strategy = fluid.BuildStrategy()
pass_builder = build_strategy._create_passes_from_strategy()
self.assertFalse(build_strategy.fuse_elewise_add_act_ops)
build_strategy.fuse_elewise_add_act_ops = True
pass_builder = build_strategy._finalize_strategy_and_create_passes()
self.assertTrue("fuse_elewise_add_act_pass" in
[p.type() for p in pass_builder.all_passes()])
origin_len = len(pass_builder.all_passes())
viz_pass = pass_builder.append_pass("graph_viz_pass")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册