提交 c2b3838c 编写于 作者: X Xin Pan

add some comments

上级 0d9ee0dc
......@@ -27,7 +27,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public:
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
: ir::PassBuilder(), strategy_(strategy) {
// Apply a graph viz pass to record a graph.
// Add a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
......@@ -35,10 +35,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
// Apply op fusion.
// Add op fusion.
if (strategy.fuse_elewise_add_act_ops_) {
auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
// Apply a graph viz pass to record a graph.
// Add a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
......@@ -53,7 +53,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
// Apply a graph print pass to record a graph with device info.
// Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
multi_devices_print_pass->SetNotOwned<const std::string>(
......@@ -86,7 +86,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
#else
const bool use_cuda) const {
#endif
// Create a default one if not intialized by user.
// Create a default one if not initialized by user.
if (!pass_builder_) {
CreatePassesFromStrategy();
}
......
......@@ -69,8 +69,16 @@ struct BuildStrategy {
bool enable_data_balance_{false};
// The PassBuilder assembles passes based on the configs defined above.
// For example, if fuse_elewise_add_act_ops_ is true, the corresponding
// fuse pass will be added.
// The PassBuilder allows for more customized insert, remove of passes
// from python.
// A new PassBuilder is created and passes are owned by the PassBuilder.
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const;
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply(
const ProgramDesc &main_program,
const std::vector<platform::Place> &places,
......
......@@ -28,12 +28,16 @@ class PassBuilder {
virtual ~PassBuilder() {}
// Append a new pass to the end.
std::shared_ptr<Pass> AppendPass(const std::string& pass_type);
// Insert a new pass after `idx`.
std::shared_ptr<Pass> InsertPass(size_t idx, const std::string& pass_type);
// Remove a new pass at `idx`.
void RemovePass(size_t idx);
// Returns a list of all passes.
std::vector<std::shared_ptr<Pass>> AllPasses() const { return passes_; }
protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册