提交 c2b3838c 编写于 作者: X Xin Pan

add some comments

上级 0d9ee0dc
...@@ -27,7 +27,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -27,7 +27,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public: public:
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy) explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
: ir::PassBuilder(), strategy_(strategy) { : ir::PassBuilder(), strategy_(strategy) {
// Apply a graph viz pass to record a graph. // Add a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass"); auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf( const std::string graph_path = string::Sprintf(
...@@ -35,10 +35,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -35,10 +35,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path)); viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
} }
// Apply op fusion. // Add op fusion.
if (strategy.fuse_elewise_add_act_ops_) { if (strategy.fuse_elewise_add_act_ops_) {
auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass"); auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
// Apply a graph viz pass to record a graph. // Add a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) { if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass"); auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf( const std::string graph_path = string::Sprintf(
...@@ -53,7 +53,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -53,7 +53,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy", multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_); &strategy_);
// Apply a graph print pass to record a graph with device info. // Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass"); auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
multi_devices_print_pass->SetNotOwned<const std::string>( multi_devices_print_pass->SetNotOwned<const std::string>(
...@@ -86,7 +86,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -86,7 +86,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
#else #else
const bool use_cuda) const { const bool use_cuda) const {
#endif #endif
// Create a default one if not intialized by user. // Create a default one if not initialized by user.
if (!pass_builder_) { if (!pass_builder_) {
CreatePassesFromStrategy(); CreatePassesFromStrategy();
} }
......
...@@ -69,8 +69,16 @@ struct BuildStrategy { ...@@ -69,8 +69,16 @@ struct BuildStrategy {
bool enable_data_balance_{false}; bool enable_data_balance_{false};
// The PassBuilder assembles passes based on the configs defined above.
// For example, if fuse_elewise_add_act_ops_ is true, the corresponding
// fuse pass will be added.
// The PassBuilder allows for more customized insert, remove of passes
// from python.
// A new PassBuilder is created and passes are owned by the PassBuilder.
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const; std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const;
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> Apply(
const ProgramDesc &main_program, const ProgramDesc &main_program,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
......
...@@ -28,12 +28,16 @@ class PassBuilder { ...@@ -28,12 +28,16 @@ class PassBuilder {
virtual ~PassBuilder() {} virtual ~PassBuilder() {}
// Append a new pass to the end.
std::shared_ptr<Pass> AppendPass(const std::string& pass_type); std::shared_ptr<Pass> AppendPass(const std::string& pass_type);
// Insert a new pass after `idx`.
std::shared_ptr<Pass> InsertPass(size_t idx, const std::string& pass_type); std::shared_ptr<Pass> InsertPass(size_t idx, const std::string& pass_type);
// Remove a new pass at `idx`.
void RemovePass(size_t idx); void RemovePass(size_t idx);
// Returns a list of all passes.
std::vector<std::shared_ptr<Pass>> AllPasses() const { return passes_; } std::vector<std::shared_ptr<Pass>> AllPasses() const { return passes_; }
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册