提交 18442a60 编写于 作者: X Xin Pan

rename pass.h/.cc to analysis_pass

上级 9df2d8b5
...@@ -6,6 +6,7 @@ cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits ...@@ -6,6 +6,7 @@ cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits
analyzer.cc analyzer.cc
helper.cc helper.cc
# passes # passes
analysis_pass.cc
fluid_to_data_flow_graph_pass.cc fluid_to_data_flow_graph_pass.cc
data_flow_graph_to_fluid_pass.cc data_flow_graph_to_fluid_pass.cc
dfg_graphviz_draw_pass.cc dfg_graphviz_draw_pass.cc
......
...@@ -12,4 +12,4 @@ ...@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/analysis_pass.h"
...@@ -28,10 +28,10 @@ namespace paddle { ...@@ -28,10 +28,10 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
class Pass { class AnalysisPass {
public: public:
Pass() = default; AnalysisPass() = default;
virtual ~Pass() = default; virtual ~AnalysisPass() = default;
// Mutable Pass. // Mutable Pass.
virtual bool Initialize(Argument *argument) { return false; } virtual bool Initialize(Argument *argument) { return false; }
// Readonly Pass. // Readonly Pass.
...@@ -42,23 +42,16 @@ class Pass { ...@@ -42,23 +42,16 @@ class Pass {
virtual bool Finalize() { return false; } virtual bool Finalize() { return false; }
// Get a Pass appropriate to print the Node this pass operates on. // Get a Pass appropriate to print the Node this pass operates on.
virtual Pass *CreatePrinterPass(std::ostream &os, virtual AnalysisPass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const { const std::string &banner) const {
return nullptr; return nullptr;
} }
// Create a debugger Pass that draw the DFG by graphviz toolkit. // Create a debugger Pass that draw the DFG by graphviz toolkit.
virtual Pass *CreateGraphvizDebugerPass() const { return nullptr; } virtual AnalysisPass *CreateGraphvizDebugerPass() const { return nullptr; }
virtual void Run() { LOG(FATAL) << "not valid"; }
// Run on a single Node.
virtual void Run(Node *x) { LOG(FATAL) << "not valid"; }
// Run on a single Function.
virtual void Run(Function *x) { LOG(FATAL) << "not valid"; }
// Run on a single FunctionBlock.
virtual void Run(FunctionBlock *x) { LOG(FATAL) << "not valid"; }
// Run on a single DataFlowGraph. // Run on a single DataFlowGraph.
virtual void Run(DataFlowGraph *x) { LOG(FATAL) << "not valid"; } virtual void Run(DataFlowGraph *x) = 0;
// Human-readable short representation. // Human-readable short representation.
virtual std::string repr() const = 0; virtual std::string repr() const = 0;
...@@ -66,29 +59,8 @@ class Pass { ...@@ -66,29 +59,8 @@ class Pass {
virtual std::string description() const { return "No DOC"; } virtual std::string description() const { return "No DOC"; }
}; };
// NodePass process on any Node types.
class NodePass : public Pass {
public:
virtual void Run(Node *node) = 0;
};
// NodePass process on any Function node types.
class FunctionPass : public Pass {
public:
virtual void Run(Function *node) = 0;
};
// NodePass process on any FunctionBlock node types.
class FunctionBlockPass : public Pass {
public:
virtual void Run(FunctionBlock *node) = 0;
};
// GraphPass processes on any GraphType. // GraphPass processes on any GraphType.
class DataFlowGraphPass : public Pass { class DataFlowGraphPass : public AnalysisPass {};
public:
virtual void Run(DataFlowGraph *graph) = 0;
};
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
...@@ -58,7 +59,7 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -58,7 +59,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
std::string description() const override { return "DFG pass manager."; } std::string description() const override { return "DFG pass manager."; }
private: private:
void AddPass(const std::string& name, Pass* pass) { void AddPass(const std::string& name, AnalysisPass* pass) {
VLOG(3) << "Adding pass " << name; VLOG(3) << "Adding pass " << name;
Register(name, pass); Register(name, pass);
AddGraphvizDebugerPass(pass); AddGraphvizDebugerPass(pass);
...@@ -87,7 +88,7 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -87,7 +88,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
} }
// Add the graphviz debuger pass if the parent pass has one. // Add the graphviz debuger pass if the parent pass has one.
void AddGraphvizDebugerPass(Pass* pass) { void AddGraphvizDebugerPass(AnalysisPass* pass) {
auto* debuger_pass = pass->CreateGraphvizDebugerPass(); auto* debuger_pass = pass->CreateGraphvizDebugerPass();
if (debuger_pass) { if (debuger_pass) {
Register(debuger_pass->repr(), debuger_pass); Register(debuger_pass->repr(), debuger_pass);
......
...@@ -36,8 +36,11 @@ limitations under the License. */ ...@@ -36,8 +36,11 @@ limitations under the License. */
*/ */
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/flags.h" #include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h" #include "paddle/fluid/inference/analysis/pass_manager.h"
namespace paddle { namespace paddle {
......
...@@ -263,7 +263,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass { ...@@ -263,7 +263,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
}; };
} // namespace } // namespace
Pass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const { AnalysisPass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config( return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_IA_graphviz_log_root, FLAGS_IA_graphviz_log_root,
"data_flow_graph_to_fluid_graphviz_debugger")); "data_flow_graph_to_fluid_graphviz_debugger"));
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
#include <string> #include <string>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -42,7 +42,7 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass { ...@@ -42,7 +42,7 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
return "Transform a DFG to a Fluid ProgramDesc"; return "Transform a DFG to a Fluid ProgramDesc";
} }
Pass *CreateGraphvizDebugerPass() const override; AnalysisPass *CreateGraphvizDebugerPass() const override;
protected: protected:
// Add a Fluid Op into the ProgramDesc. // Add a Fluid Op into the ProgramDesc.
......
...@@ -21,8 +21,8 @@ limitations under the License. */ ...@@ -21,8 +21,8 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include <string> #include <string>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
...@@ -66,7 +66,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass { ...@@ -66,7 +66,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
}; };
} }
Pass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const { AnalysisPass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config( return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_IA_graphviz_log_root, "fluid-to-dfg-debuger")); FLAGS_IA_graphviz_log_root, "fluid-to-dfg-debuger"));
} }
......
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
#include <string> #include <string>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -46,7 +46,7 @@ class FluidToDataFlowGraphPass final : public DataFlowGraphPass { ...@@ -46,7 +46,7 @@ class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
return "transform a fluid ProgramDesc to a data flow graph."; return "transform a fluid ProgramDesc to a data flow graph.";
} }
Pass *CreateGraphvizDebugerPass() const override; AnalysisPass *CreateGraphvizDebugerPass() const override;
private: private:
framework::proto::ProgramDesc const *desc_; framework::proto::ProgramDesc const *desc_;
......
...@@ -14,15 +14,17 @@ ...@@ -14,15 +14,17 @@
#pragma once #pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/flags.h" #include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
using namespace framework;
static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__"; static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
...@@ -48,7 +50,8 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -48,7 +50,8 @@ class FluidToIrPass final : public DataFlowGraphPass {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path); ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
// Load program. // Load program.
auto program = LoadProgramDesc(*argument->fluid_model_program_path); auto program = LoadProgramDesc(*argument->fluid_model_program_path);
argument->origin_program_desc.reset(new proto::ProgramDesc(program)); argument->origin_program_desc.reset(
new framework::proto::ProgramDesc(program));
// Create main data flow graph. // Create main data flow graph.
if (!argument->main_dfg) { if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph); argument->main_dfg.reset(new DataFlowGraph);
...@@ -78,12 +81,13 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -78,12 +81,13 @@ class FluidToIrPass final : public DataFlowGraphPass {
IRPassManager ir_passes(argument_->Get<ProgramDesc>("ir_program_desc"), IRPassManager ir_passes(argument_->Get<ProgramDesc>("ir_program_desc"),
nullptr); nullptr);
// Pass the scope from analysis to IR if needed. // Pass the scope from analysis to IR if needed.
if (argument_->Has(ir::kParamScopeAttr)) { if (argument_->Has(framework::ir::kParamScopeAttr)) {
// Here the address is passed, attention that IR doesn't own the scope, so // Here the address is passed, attention that IR doesn't own the scope, so
// the real scope in analysis should live during the IR phase. // the real scope in analysis should live during the IR phase.
ir_passes.graph().Set( ir_passes.graph().Set(
ir::kParamScopeAttr, framework::ir::kParamScopeAttr,
new Scope *(&argument_->Get<Scope>(ir::kParamScopeAttr))); new framework::Scope *(&argument_->Get<framework::Scope>(
framework::ir::kParamScopeAttr)));
} }
if (FLAGS_IA_enable_ir) { if (FLAGS_IA_enable_ir) {
...@@ -95,12 +99,12 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -95,12 +99,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
PADDLE_ENFORCE(argument_->main_dfg.get()); PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph()); argument_->main_dfg->Build(ir_passes.graph());
// inherit the arguments from ir. // inherit the arguments from ir.
if (ir_passes.graph().Has(ir::kFuseStatisAttr)) { if (ir_passes.graph().Has(framework::ir::kFuseStatisAttr)) {
argument_->Set( argument_->Set(
ir::kFuseStatisAttr, framework::ir::kFuseStatisAttr,
new std::unordered_map<std::string, int>( new std::unordered_map<std::string, int>(
ir_passes.graph().Get<std::unordered_map<std::string, int>>( ir_passes.graph().Get<std::unordered_map<std::string, int>>(
ir::kFuseStatisAttr))); framework::ir::kFuseStatisAttr)));
} }
} }
...@@ -112,7 +116,7 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -112,7 +116,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
private: private:
// Load parameters from a single file or from a directory. // Load parameters from a single file or from a directory.
bool LoadParams(Scope *scope, const std::string &dir, bool LoadParams(framework::Scope *scope, const std::string &dir,
const std::string &prog_file, const std::string &param_file); const std::string &prog_file, const std::string &param_file);
private: private:
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
...@@ -40,17 +40,6 @@ void DfgPassManager::RunAll() { ...@@ -40,17 +40,6 @@ void DfgPassManager::RunAll() {
} }
} }
void NodePassManager::RunAll() {
PADDLE_ENFORCE(argument_);
PADDLE_ENFORCE(argument_->main_dfg.get());
auto trait = GraphTraits<DataFlowGraph>(*argument_->main_dfg).nodes_in_DFS();
for (auto& node : trait) {
for (auto& pass : data_) {
pass->Run(&node);
}
}
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -33,7 +33,7 @@ limitations under the License. */ ...@@ -33,7 +33,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -43,7 +43,7 @@ namespace analysis { ...@@ -43,7 +43,7 @@ namespace analysis {
* PassManager is the base class for all pass managers, a pass manager has * PassManager is the base class for all pass managers, a pass manager has
* several Pass-es registered, and execute them in the linear order. * several Pass-es registered, and execute them in the linear order.
*/ */
class PassManager : public OrderedRegistry<Pass> { class PassManager : public OrderedRegistry<AnalysisPass> {
public: public:
PassManager() = default; PassManager() = default;
// Call all the passes' Initialize methods. The desc and data_flow_graph are // Call all the passes' Initialize methods. The desc and data_flow_graph are
...@@ -89,18 +89,6 @@ class DfgPassManager : public PassManager { ...@@ -89,18 +89,6 @@ class DfgPassManager : public PassManager {
virtual ~DfgPassManager() = default; virtual ~DfgPassManager() = default;
}; };
/*
* A pass manager that process a Node each time.
*/
class NodePassManager : public PassManager {
public:
NodePassManager() = default;
void RunAll() override;
virtual ~NodePassManager() = default;
};
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -34,28 +34,6 @@ class TestDfgPassManager final : public DfgPassManager { ...@@ -34,28 +34,6 @@ class TestDfgPassManager final : public DfgPassManager {
std::string description() const override { return "test doc"; } std::string description() const override { return "test doc"; }
}; };
class TestNodePassManager final : public NodePassManager {
public:
virtual ~TestNodePassManager() = default;
std::string repr() const override { return "test-node-pass-manager"; }
std::string description() const override { return "test doc"; }
};
class TestNodePass final : public NodePass {
public:
virtual ~TestNodePass() = default;
bool Initialize(Argument* argument) override { return true; }
void Run(Node* node) override {
LOG(INFO) << "- Processing node " << node->repr();
}
std::string repr() const override { return "test-node"; }
std::string description() const override { return "some doc"; }
};
TEST(PassManager, DFG_pass_manager) { TEST(PassManager, DFG_pass_manager) {
TestDfgPassManager manager; TestDfgPassManager manager;
DFG_GraphvizDrawPass::Config config("./", "dfg.dot"); DFG_GraphvizDrawPass::Config config("./", "dfg.dot");
...@@ -71,19 +49,6 @@ TEST(PassManager, DFG_pass_manager) { ...@@ -71,19 +49,6 @@ TEST(PassManager, DFG_pass_manager) {
manager.RunAll(); manager.RunAll();
} }
TEST(PassManager, Node_pass_manager) {
Argument argument(FLAGS_inference_model_dir);
// Pre-process: initialize the DFG with the ProgramDesc first.
FluidToDataFlowGraphPass pass0;
pass0.Initialize(&argument);
pass0.Run(argument.main_dfg.get());
TestNodePassManager manager;
manager.Register("test-node-pass", new TestNodePass);
ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll();
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -68,7 +68,7 @@ class DfgDebuggerPass : public DFG_GraphvizDrawPass { ...@@ -68,7 +68,7 @@ class DfgDebuggerPass : public DFG_GraphvizDrawPass {
} }
}; };
Pass *TensorRTSubgraphNodeMarkPass::CreateGraphvizDebugerPass() const { AnalysisPass *TensorRTSubgraphNodeMarkPass::CreateGraphvizDebugerPass() const {
DFG_GraphvizDrawPass::Config config(FLAGS_IA_graphviz_log_root, DFG_GraphvizDrawPass::Config config(FLAGS_IA_graphviz_log_root,
"tensorrt_marked_node"); "tensorrt_marked_node");
return new DfgDebuggerPass(config); return new DfgDebuggerPass(config);
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/subgraph_splitter.h" #include "paddle/fluid/inference/analysis/subgraph_splitter.h"
namespace paddle { namespace paddle {
...@@ -48,7 +48,7 @@ class TensorRTSubgraphNodeMarkPass : public DataFlowGraphPass { ...@@ -48,7 +48,7 @@ class TensorRTSubgraphNodeMarkPass : public DataFlowGraphPass {
return "tensorrt sub-graph mark pass"; return "tensorrt sub-graph mark pass";
} }
Pass* CreateGraphvizDebugerPass() const override; AnalysisPass* CreateGraphvizDebugerPass() const override;
bool Finalize() override; bool Finalize() override;
private: private:
......
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/subgraph_splitter.h" #include "paddle/fluid/inference/analysis/subgraph_splitter.h"
namespace paddle { namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册