提交 b1401fb7 编写于 作者: Y Yiqun Liu 提交者: 石晓伟

Remove subgraph_detector from inference/analysis to the common framework/ir directory. (#22094)

test=develop
上级 50bee83f
......@@ -39,6 +39,7 @@ cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
......@@ -99,7 +100,7 @@ endif()
if(WITH_NGRAPH)
cc_library(ngraph_subgraph_pass SRCS ngraph_subgraph_pass.cc DEPS ngraph_bridge
analysis_helper subgraph_detector graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
subgraph_detector fuse_pass_base ${op_library_DEPS})
set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
file(APPEND ${pass_file} "USE_PASS(ngraph_subgraph_pass);\n")
set(INFER_IR_PASSES ${INFER_IR_PASSES} ngraph_subgraph_pass CACHE INTERNAL "")
......
......@@ -20,8 +20,7 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/ngraph_subgraph_pass.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
......@@ -30,8 +29,6 @@ namespace paddle {
namespace framework {
namespace ir {
namespace ANAT = paddle::inference::analysis;
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
const std::set<std::string> &engine_outputs,
const std::string &size) {
......@@ -59,19 +56,18 @@ void NgraphSubgraphPass::ApplyImpl(Graph *graph) const {
return !paddle::operators::NgraphBridge::isRegister(op_type);
};
ANAT::SubGraphFuser fuser(graph, teller, 0, "ngraph_engine");
SubGraphFuser fuser(graph, teller, 0, "ngraph_engine");
fuser();
for (auto *node : graph->Nodes()) {
if (node->IsOp() && !ANAT::Agent(node).subgraph()->empty()) {
if (node->IsOp() && !Agent(node).subgraph()->empty()) {
OpDesc *op_desc = node->Op();
op_desc->SetType("ngraph_engine");
CreateNgraphEngineOp(node, graph);
std::unordered_set<const Node *> nodes2remove(
ANAT::Agent(node).subgraph()->begin(),
ANAT::Agent(node).subgraph()->end());
Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
GraphSafeRemoveNodes(graph, nodes2remove);
}
......@@ -79,7 +75,7 @@ void NgraphSubgraphPass::ApplyImpl(Graph *graph) const {
std::unordered_set<const Node *> nodes2remove;
for (auto *node : graph->Nodes()) {
if (node->IsOp() && ANAT::Agent(node).deleted()) {
if (node->IsOp() && Agent(node).deleted()) {
nodes2remove.insert(node);
}
}
......@@ -116,7 +112,7 @@ void UpdateNgraphIO(Node *node, Graph *graph,
return;
}
auto &subgraph = *ANAT::Agent(node).subgraph();
auto &subgraph = *Agent(node).subgraph();
std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs;
for (auto *node : subgraph) {
......@@ -138,7 +134,7 @@ void UpdateNgraphIO(Node *node, Graph *graph,
}
void NgraphSubgraphPass::CreateNgraphEngineOp(Node *node, Graph *graph) const {
auto &subgraph = *ANAT::Agent(node).subgraph();
auto &subgraph = *Agent(node).subgraph();
PADDLE_ENFORCE_NE(subgraph.empty(), true, "subgraph cannot be empty");
framework::proto::BlockDesc block_proto;
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -24,10 +24,8 @@ limitations under the License. */
DECLARE_bool(use_ngraph);
namespace paddle {
namespace inference {
namespace analysis {
using framework::ir::Node;
namespace framework {
namespace ir {
std::pair<std::vector<Node *>, std::vector<Node *>>
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
......@@ -469,6 +467,6 @@ inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) {
return node.inputs.size() == n;
}
} // namespace analysis
} // namespace inference
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file defines the the class to partition a graph.
*/
#pragma once
#include <string>
......@@ -23,15 +19,10 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/helper.h"
namespace paddle {
namespace inference {
namespace analysis {
using framework::ir::Graph;
using framework::ir::NodesTSIterator;
namespace framework {
namespace ir {
const char kIsFunctionNode[] = "__is_function_node__";
const char kFunctionNodeSubGraph[] = "__function_node_sub_graph__";
......@@ -45,13 +36,12 @@ const char kSubgraphSplitterMarkerAttrName[] =
class SubgraphDetector {
public:
// Tell whether a node is inside a sub-graph.
using NodeInsideSubgraphTeller =
std::function<bool(const framework::ir::Node *)>;
using NodeInsideSubgraphTeller = std::function<bool(const Node *)>;
SubgraphDetector(Graph *graph, const NodeInsideSubgraphTeller &teller)
: graph_(graph), node_inside_subgraph_teller_(teller) {}
std::vector<std::vector<framework::ir::Node *>> operator()();
std::vector<std::vector<Node *>> operator()();
protected:
// Mark the nodes inside the accepted sub-graph using
......@@ -59,7 +49,7 @@ class SubgraphDetector {
void MarkNodesInsideSubGraph();
// Merge the marked nodes into sub-graphs and return the sub-graphs.
std::vector<std::vector<framework::ir::Node *>> ExtractSubGraphs();
std::vector<std::vector<Node *>> ExtractSubGraphs();
private:
Graph *graph_;
......@@ -99,14 +89,14 @@ struct NodeWrapper {
bool deleted{false};
bool marked{false};
int union_find_parent{-1};
std::vector<framework::ir::Node *> subgraph;
std::vector<Node *> subgraph;
};
/*
* ir::Node agent for subgraph detector.
*/
struct Agent {
explicit Agent(framework::ir::Node *x) : x_(x) {}
explicit Agent(Node *x) : x_(x) {}
NodeWrapper &wrapper() {
if (!x_->IsWrappedBy<NodeWrapper>()) {
......@@ -128,17 +118,17 @@ struct Agent {
int union_find_parent() { return wrapper().union_find_parent; }
void set_union_find_parent(int v) { wrapper().union_find_parent = v; }
std::vector<framework::ir::Node *> *subgraph() { return &wrapper().subgraph; }
std::vector<framework::ir::Node *> &inputs() { return x_->inputs; }
std::vector<framework::ir::Node *> &outputs() { return x_->outputs; }
std::vector<Node *> *subgraph() { return &wrapper().subgraph; }
std::vector<Node *> &inputs() { return x_->inputs; }
std::vector<Node *> &outputs() { return x_->outputs; }
private:
framework::ir::Node *x_;
Node *x_;
};
// The nodes those have no input will be treated as start points.
static std::vector<framework::ir::Node *> ExtractStartPoints(const Graph &g) {
std::vector<framework::ir::Node *> result;
static std::vector<Node *> ExtractStartPoints(const Graph &g) {
std::vector<Node *> result;
for (auto *node : g.Nodes()) {
if (node->inputs.empty()) {
result.push_back(node);
......@@ -149,12 +139,16 @@ static std::vector<framework::ir::Node *> ExtractStartPoints(const Graph &g) {
static iterator_range<NodesTSIterator> TopologicalSort(const Graph &g) {
auto start_points = ExtractStartPoints(g);
PADDLE_ENFORCE(!start_points.empty());
PADDLE_ENFORCE_GT(
start_points.size(), 0U,
platform::errors::InvalidArgument(
"Expected the number of graph's start points >= 1. Expected %d.",
start_points.size()));
NodesTSIterator x(start_points);
return iterator_range<NodesTSIterator>(NodesTSIterator(start_points),
NodesTSIterator());
}
} // namespace analysis
} // namespace inference
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -24,7 +24,6 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......
cc_library(subgraph_detector SRCS subgraph_detector.cc subgraph_util.cc DEPS proto_desc)
if(WITH_TESTING)
add_dependencies(subgraph_detector gtest)
endif()
cc_library(subgraph_util SRCS subgraph_util.cc DEPS subgraph_detector)
if (WITH_GPU AND TENSORRT_FOUND)
cc_library(tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_detector tensorrt_op_teller)
cc_library(tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_util tensorrt_op_teller)
set(analysis_deps ${analysis_deps}
subgraph_detector tensorrt_subgraph_pass
subgraph_util tensorrt_subgraph_pass
CACHE INTERNAL "")
set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
......@@ -16,10 +13,10 @@ if (WITH_GPU AND TENSORRT_FOUND)
endif()
if (ANAKIN_SUBGRAPH)
cc_library(anakin_subgraph_pass SRCS anakin_subgraph_pass.cc DEPS subgraph_detector anakin_op_teller)
cc_library(anakin_subgraph_pass SRCS anakin_subgraph_pass.cc DEPS subgraph_util anakin_op_teller)
set(analysis_deps ${analysis_deps}
subgraph_detector anakin_subgraph_pass
subgraph_util anakin_subgraph_pass
CACHE INTERNAL "")
set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
......
......@@ -22,11 +22,11 @@
#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/op_teller.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......@@ -50,7 +50,7 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
return anakin::OpTeller::Global().Tell(node->Op()->Type(), *node->Op());
};
SubGraphFuser fuser(graph, teller, 6 /* min_subgraph_size */);
framework::ir::SubGraphFuser fuser(graph, teller, 6 /* min_subgraph_size */);
fuser();
std::vector<std::string> graph_param_names =
......@@ -61,17 +61,18 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
std::vector<std::string> repetitive_params;
for (auto *node : graph->Nodes()) {
if (node->IsOp() && !Agent(node).subgraph()->empty()) {
if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) {
CreateAnakinOp(node, graph, graph_param_names, &repetitive_params);
std::unordered_set<const Node *> nodes2remove(
Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
framework::ir::Agent(node).subgraph()->begin(),
framework::ir::Agent(node).subgraph()->end());
framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
}
}
std::unordered_set<const Node *> nodes2remove;
for (auto *node : graph->Nodes()) {
if (node->IsOp() && Agent(node).deleted()) {
if (node->IsOp() && framework::ir::Agent(node).deleted()) {
nodes2remove.insert(node);
}
}
......@@ -96,11 +97,11 @@ std::string GenerateAnakinEngineKey(const std::set<std::string> &engine_inputs,
}
void AnakinSubgraphPass::CreateAnakinOp(
framework::ir::Node *node, Graph *graph,
framework::ir::Node *node, framework::ir::Graph *graph,
const std::vector<std::string> &graph_params,
std::vector<std::string> *repetitive_params) const {
auto *op_desc = node->Op();
auto &subgraph = *Agent(node).subgraph();
auto &subgraph = *framework::ir::Agent(node).subgraph();
PADDLE_ENFORCE(!subgraph.empty());
framework::ProgramDesc *program_desc =
......@@ -164,7 +165,7 @@ void AnakinSubgraphPass::CreateAnakinOp(
graph_var_map[node->Name()] = node;
}
}
auto &subgraph_nodes = *Agent(node).subgraph();
auto &subgraph_nodes = *framework::ir::Agent(node).subgraph();
// The following procedure is used to rename all the intermediate
// variables and the output variables of the subgraph.
......
......@@ -17,8 +17,8 @@
#include <set>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
......@@ -40,8 +40,8 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op());
};
SubGraphFuser fuser(graph, teller,
Get<int>("min_subgraph_size") /*min subgraph size*/,
framework::ir::SubGraphFuser fuser(
graph, teller, Get<int>("min_subgraph_size") /*min subgraph size*/,
"tensorrt_engine");
fuser();
......@@ -52,18 +52,19 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
std::vector<std::string> repetitive_params;
for (auto *node : graph->Nodes()) {
if (node->IsOp() && !Agent(node).subgraph()->empty()) {
if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) {
CreateTensorRTOp(node, graph, graph_param_names, &repetitive_params);
std::unordered_set<const Node *> nodes2remove(
Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
framework::ir::Agent(node).subgraph()->begin(),
framework::ir::Agent(node).subgraph()->end());
framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
}
}
std::unordered_set<const Node *> nodes2remove;
for (auto *node : graph->Nodes()) {
if (node->IsOp() && Agent(node).deleted()) {
if (node->IsOp() && framework::ir::Agent(node).deleted()) {
nodes2remove.insert(node);
}
}
......@@ -88,11 +89,11 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
}
void TensorRtSubgraphPass::CreateTensorRTOp(
framework::ir::Node *node, Graph *graph,
framework::ir::Node *node, framework::ir::Graph *graph,
const std::vector<std::string> &graph_params,
std::vector<std::string> *repetitive_params) const {
auto *op_desc = node->Op();
auto &subgraph = *Agent(node).subgraph();
auto &subgraph = *framework::ir::Agent(node).subgraph();
PADDLE_ENFORCE(!subgraph.empty());
framework::ProgramDesc *program_desc =
......@@ -161,7 +162,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
if (precision_mode == AnalysisConfig::Precision::kHalf) enable_fp16 = true;
auto enable_int8 = Get<bool>("enable_int8");
auto use_calib_mode = Get<bool>("use_calib_mode");
auto &subgraph_nodes = *Agent(node).subgraph();
auto &subgraph_nodes = *framework::ir::Agent(node).subgraph();
// The following procedure is used to rename all the intermediate
// variables and the output variables of the subgraph.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册