提交 759ffca4 编写于 作者: X Xin Pan

some improvements

test=develop
上级 99dffb91
......@@ -80,13 +80,13 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
};
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
bool from_user) const {
if (finalized_by_user_) {
bool finalize_strategy) const {
if (is_finalized_) {
return pass_builder_;
}
pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
if (from_user) {
finalized_by_user_ = true;
if (finalize_strategy) {
is_finalized_ = true;
}
return pass_builder_;
}
......
......@@ -75,13 +75,20 @@ struct BuildStrategy {
bool remove_unnecessary_lock_{false};
// NOTE:
// Before you add new options, think if it's a general strategy that works
// with other strategy. If not, the strategy should be created through
// CreatePassesFromStrategy and the pass can be managed separately.
// User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes
// from python side.
// A new PassBuilder is created based on configs defined above and
// passes are owned by the PassBuilder.
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy(
bool from_user) const;
bool finalize_strategy) const;
bool IsFinalized() const { return is_finalized_; }
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
......@@ -98,7 +105,7 @@ struct BuildStrategy {
#endif
private:
mutable bool finalized_by_user_ = false;
mutable bool is_finalized_ = false;
mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
};
......
......@@ -791,6 +791,7 @@ All parameter, weight, gradient are variables in Paddle.
"reduce_strategy",
[](const BuildStrategy &self) { return self.reduce_; },
[](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.reduce_ = strategy;
},
R"DOC(The type is STR, there are two reduce strategies in ParallelExecutor,
......@@ -804,6 +805,7 @@ All parameter, weight, gradient are variables in Paddle.
[](const BuildStrategy &self) { return self.gradient_scale_; },
[](BuildStrategy &self,
BuildStrategy::GradientScaleStrategy strategy) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.gradient_scale_ = strategy;
},
R"DOC(The type is STR, there are three ways of defining :math:`loss@grad` in
......@@ -815,6 +817,7 @@ All parameter, weight, gradient are variables in Paddle.
"debug_graphviz_path",
[](const BuildStrategy &self) { return self.debug_graphviz_path_; },
[](BuildStrategy &self, const std::string &path) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.debug_graphviz_path_ = path;
},
R"DOC(The type is STR, debug_graphviz_path indicate the path that
......@@ -824,6 +827,7 @@ All parameter, weight, gradient are variables in Paddle.
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property(
......@@ -832,6 +836,7 @@ All parameter, weight, gradient are variables in Paddle.
return self.enable_sequential_execution_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.enable_sequential_execution_ = b;
},
R"DOC(The type is BOOL. If set True, the execution order of ops would be the same as what is in the program. Default False.)DOC")
......@@ -841,6 +846,7 @@ All parameter, weight, gradient are variables in Paddle.
return self.remove_unnecessary_lock_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.remove_unnecessary_lock_ = b;
},
R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC")
......@@ -850,6 +856,7 @@ All parameter, weight, gradient are variables in Paddle.
return self.fuse_elewise_add_act_ops_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.fuse_elewise_add_act_ops_ = b;
},
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册