提交 18442a60 编写于 作者: X Xin Pan

rename pass.h/.cc to analysis_pass

上级 9df2d8b5
......@@ -6,6 +6,7 @@ cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits
analyzer.cc
helper.cc
# passes
analysis_pass.cc
fluid_to_data_flow_graph_pass.cc
data_flow_graph_to_fluid_pass.cc
dfg_graphviz_draw_pass.cc
......
......@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
......@@ -28,10 +28,10 @@ namespace paddle {
namespace inference {
namespace analysis {
class Pass {
class AnalysisPass {
public:
Pass() = default;
virtual ~Pass() = default;
AnalysisPass() = default;
virtual ~AnalysisPass() = default;
// Mutable Pass.
virtual bool Initialize(Argument *argument) { return false; }
// Readonly Pass.
......@@ -42,23 +42,16 @@ class Pass {
virtual bool Finalize() { return false; }
// Get a Pass appropriate to print the Node this pass operates on.
virtual Pass *CreatePrinterPass(std::ostream &os,
virtual AnalysisPass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const {
return nullptr;
}
// Create a debugger Pass that draw the DFG by graphviz toolkit.
virtual Pass *CreateGraphvizDebugerPass() const { return nullptr; }
virtual AnalysisPass *CreateGraphvizDebugerPass() const { return nullptr; }
virtual void Run() { LOG(FATAL) << "not valid"; }
// Run on a single Node.
virtual void Run(Node *x) { LOG(FATAL) << "not valid"; }
// Run on a single Function.
virtual void Run(Function *x) { LOG(FATAL) << "not valid"; }
// Run on a single FunctionBlock.
virtual void Run(FunctionBlock *x) { LOG(FATAL) << "not valid"; }
// Run on a single DataFlowGraph.
virtual void Run(DataFlowGraph *x) { LOG(FATAL) << "not valid"; }
virtual void Run(DataFlowGraph *x) = 0;
// Human-readable short representation.
virtual std::string repr() const = 0;
......@@ -66,29 +59,8 @@ class Pass {
virtual std::string description() const { return "No DOC"; }
};
// NodePass process on any Node types.
class NodePass : public Pass {
public:
virtual void Run(Node *node) = 0;
};
// NodePass process on any Function node types.
class FunctionPass : public Pass {
public:
virtual void Run(Function *node) = 0;
};
// NodePass process on any FunctionBlock node types.
class FunctionBlockPass : public Pass {
public:
virtual void Run(FunctionBlock *node) = 0;
};
// GraphPass processes on any GraphType.
class DataFlowGraphPass : public Pass {
public:
virtual void Run(DataFlowGraph *graph) = 0;
};
class DataFlowGraphPass : public AnalysisPass {};
} // namespace analysis
} // namespace inference
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
......@@ -58,7 +59,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
std::string description() const override { return "DFG pass manager."; }
private:
void AddPass(const std::string& name, Pass* pass) {
void AddPass(const std::string& name, AnalysisPass* pass) {
VLOG(3) << "Adding pass " << name;
Register(name, pass);
AddGraphvizDebugerPass(pass);
......@@ -87,7 +88,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
}
// Add the graphviz debuger pass if the parent pass has one.
void AddGraphvizDebugerPass(Pass* pass) {
void AddGraphvizDebugerPass(AnalysisPass* pass) {
auto* debuger_pass = pass->CreateGraphvizDebugerPass();
if (debuger_pass) {
Register(debuger_pass->repr(), debuger_pass);
......
......@@ -36,8 +36,11 @@ limitations under the License. */
*/
#include <gflags/gflags.h>
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
namespace paddle {
......
......@@ -263,7 +263,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
};
} // namespace
Pass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
AnalysisPass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_IA_graphviz_log_root,
"data_flow_graph_to_fluid_graphviz_debugger"));
......
......@@ -21,8 +21,8 @@
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
......@@ -42,7 +42,7 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
return "Transform a DFG to a Fluid ProgramDesc";
}
Pass *CreateGraphvizDebugerPass() const override;
AnalysisPass *CreateGraphvizDebugerPass() const override;
protected:
// Add a Fluid Op into the ProgramDesc.
......
......@@ -21,8 +21,8 @@ limitations under the License. */
#include <fstream>
#include <string>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
......
......@@ -66,7 +66,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
};
}
Pass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const {
AnalysisPass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_IA_graphviz_log_root, "fluid-to-dfg-debuger"));
}
......
......@@ -22,8 +22,8 @@
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
......@@ -46,7 +46,7 @@ class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
return "transform a fluid ProgramDesc to a data flow graph.";
}
Pass *CreateGraphvizDebugerPass() const override;
AnalysisPass *CreateGraphvizDebugerPass() const override;
private:
framework::proto::ProgramDesc const *desc_;
......
......@@ -14,15 +14,17 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
namespace analysis {
using namespace framework;
static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
......@@ -48,7 +50,8 @@ class FluidToIrPass final : public DataFlowGraphPass {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
// Load program.
auto program = LoadProgramDesc(*argument->fluid_model_program_path);
argument->origin_program_desc.reset(new proto::ProgramDesc(program));
argument->origin_program_desc.reset(
new framework::proto::ProgramDesc(program));
// Create main data flow graph.
if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph);
......@@ -78,12 +81,13 @@ class FluidToIrPass final : public DataFlowGraphPass {
IRPassManager ir_passes(argument_->Get<ProgramDesc>("ir_program_desc"),
nullptr);
// Pass the scope from analysis to IR if needed.
if (argument_->Has(ir::kParamScopeAttr)) {
if (argument_->Has(framework::ir::kParamScopeAttr)) {
// Here the address is passed, attention that IR doesn't own the scope, so
// the real scope in analysis should live during the IR phase.
ir_passes.graph().Set(
ir::kParamScopeAttr,
new Scope *(&argument_->Get<Scope>(ir::kParamScopeAttr)));
framework::ir::kParamScopeAttr,
new framework::Scope *(&argument_->Get<framework::Scope>(
framework::ir::kParamScopeAttr)));
}
if (FLAGS_IA_enable_ir) {
......@@ -95,12 +99,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph());
// inherit the arguments from ir.
if (ir_passes.graph().Has(ir::kFuseStatisAttr)) {
if (ir_passes.graph().Has(framework::ir::kFuseStatisAttr)) {
argument_->Set(
ir::kFuseStatisAttr,
framework::ir::kFuseStatisAttr,
new std::unordered_map<std::string, int>(
ir_passes.graph().Get<std::unordered_map<std::string, int>>(
ir::kFuseStatisAttr)));
framework::ir::kFuseStatisAttr)));
}
}
......@@ -112,7 +116,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
private:
// Load parameters from a single file or from a directory.
bool LoadParams(Scope *scope, const std::string &dir,
bool LoadParams(framework::Scope *scope, const std::string &dir,
const std::string &prog_file, const std::string &param_file);
private:
......
......@@ -19,7 +19,7 @@
#pragma once
#include <string>
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle {
namespace inference {
......
......@@ -40,17 +40,6 @@ void DfgPassManager::RunAll() {
}
}
void NodePassManager::RunAll() {
PADDLE_ENFORCE(argument_);
PADDLE_ENFORCE(argument_->main_dfg.get());
auto trait = GraphTraits<DataFlowGraph>(*argument_->main_dfg).nodes_in_DFS();
for (auto& node : trait) {
for (auto& pass : data_) {
pass->Run(&node);
}
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -33,7 +33,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle {
namespace inference {
......@@ -43,7 +43,7 @@ namespace analysis {
* PassManager is the base class for all pass managers, a pass manager has
* several Pass-es registered, and execute them in the linear order.
*/
class PassManager : public OrderedRegistry<Pass> {
class PassManager : public OrderedRegistry<AnalysisPass> {
public:
PassManager() = default;
// Call all the passes' Initialize methods. The desc and data_flow_graph are
......@@ -89,18 +89,6 @@ class DfgPassManager : public PassManager {
virtual ~DfgPassManager() = default;
};
/*
* A pass manager that process a Node each time.
*/
class NodePassManager : public PassManager {
public:
NodePassManager() = default;
void RunAll() override;
virtual ~NodePassManager() = default;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -34,28 +34,6 @@ class TestDfgPassManager final : public DfgPassManager {
std::string description() const override { return "test doc"; }
};
class TestNodePassManager final : public NodePassManager {
public:
virtual ~TestNodePassManager() = default;
std::string repr() const override { return "test-node-pass-manager"; }
std::string description() const override { return "test doc"; }
};
class TestNodePass final : public NodePass {
public:
virtual ~TestNodePass() = default;
bool Initialize(Argument* argument) override { return true; }
void Run(Node* node) override {
LOG(INFO) << "- Processing node " << node->repr();
}
std::string repr() const override { return "test-node"; }
std::string description() const override { return "some doc"; }
};
TEST(PassManager, DFG_pass_manager) {
TestDfgPassManager manager;
DFG_GraphvizDrawPass::Config config("./", "dfg.dot");
......@@ -71,19 +49,6 @@ TEST(PassManager, DFG_pass_manager) {
manager.RunAll();
}
TEST(PassManager, Node_pass_manager) {
Argument argument(FLAGS_inference_model_dir);
// Pre-process: initialize the DFG with the ProgramDesc first.
FluidToDataFlowGraphPass pass0;
pass0.Initialize(&argument);
pass0.Run(argument.main_dfg.get());
TestNodePassManager manager;
manager.Register("test-node-pass", new TestNodePass);
ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -68,7 +68,7 @@ class DfgDebuggerPass : public DFG_GraphvizDrawPass {
}
};
Pass *TensorRTSubgraphNodeMarkPass::CreateGraphvizDebugerPass() const {
AnalysisPass *TensorRTSubgraphNodeMarkPass::CreateGraphvizDebugerPass() const {
DFG_GraphvizDrawPass::Config config(FLAGS_IA_graphviz_log_root,
"tensorrt_marked_node");
return new DfgDebuggerPass(config);
......
......@@ -20,7 +20,7 @@
#pragma once
#include <string>
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
namespace paddle {
......@@ -48,7 +48,7 @@ class TensorRTSubgraphNodeMarkPass : public DataFlowGraphPass {
return "tensorrt sub-graph mark pass";
}
Pass* CreateGraphvizDebugerPass() const override;
AnalysisPass* CreateGraphvizDebugerPass() const override;
bool Finalize() override;
private:
......
......@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册