提交 03ff4f68 编写于 作者: N nhzlx

fix subgraph bug!

上级 5ec2fb0c
...@@ -440,6 +440,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT ...@@ -440,6 +440,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
} }
return false; return false;
}; };
for (auto &node : graph) { for (auto &node : graph) {
for (auto *in : node->inlinks) { for (auto *in : node->inlinks) {
// The Value that is written by nodes inside a sub-graph shouldn't be the // The Value that is written by nodes inside a sub-graph shouldn't be the
...@@ -459,6 +460,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT ...@@ -459,6 +460,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
std::vector<Node *>(outputs.begin(), outputs.end())); std::vector<Node *>(outputs.begin(), outputs.end()));
} }
// Filter the Intermediate results of the subgraph node.
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) { void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
std::vector<Node *> op_nodes; std::vector<Node *> op_nodes;
for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) {
...@@ -484,46 +486,11 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) { ...@@ -484,46 +486,11 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
out->SetDeleted(); out->SetDeleted();
} }
} }
PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL); // The filtered_subgraph_outlinks may be empty.
op_nodes[i]->outlinks = filtered_subgraph_outlinks; op_nodes[i]->outlinks = filtered_subgraph_outlinks;
} }
} }
void FlexibleDFS(const std::vector<Node *> &source, bool reverse,
const std::function<bool(const Node *)> &enter,
const std::function<bool(const Node *)> &leave) {
typedef struct {
const Node *node;
bool leave;
} FNode;
std::vector<FNode> stack;
for (auto &node : source) {
stack.push_back(FNode{node, false});
}
std::unordered_set<const Node *> visited;
while (!stack.empty()) {
auto fnode = stack.back();
stack.pop_back();
if (fnode.leave) {
if (leave && !leave(fnode.node)) return;
}
if (visited.count(fnode.node)) continue;
visited.insert(fnode.node);
if (enter && !enter(fnode.node)) return;
if (leave) stack.push_back(FNode{fnode.node, true});
const std::vector<Node *> iter_nodes =
reverse == true ? fnode.node->inlinks : fnode.node->outlinks;
for (const Node *node : iter_nodes) {
if (!visited.count(node)) {
stack.push_back(FNode{node, false});
}
}
}
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -204,9 +204,6 @@ std::pair<std::vector<Node *>, std::vector<Node *>> ...@@ -204,9 +204,6 @@ std::pair<std::vector<Node *>, std::vector<Node *>>
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph); // NOLINT ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph); // NOLINT
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph); void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph);
void FlexibleDFS(const std::vector<Node *> &source, bool reverse,
const std::function<bool(const Node *)> &enter,
const std::function<bool(const Node *)> &leave);
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -106,20 +106,23 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, ...@@ -106,20 +106,23 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
// collect inputs // collect inputs
std::unordered_set<std::string> input_names; std::unordered_set<std::string> input_names;
std::unordered_set<std::string> input_names_with_id;
for (auto *x : func->inlinks) { for (auto *x : func->inlinks) {
input_names.insert(x->name()); input_names.insert(x->name());
input_names_with_id.insert(x->name() + std::to_string(x->id()));
} }
desc.SetInput( desc.SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end())); "Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
std::unordered_set<std::string> output_names; std::unordered_set<std::string> output_names;
std::unordered_set<std::string> output_names_with_id;
for (auto *x : func->outlinks) { for (auto *x : func->outlinks) {
output_names.insert(x->name()); output_names.insert(x->name());
output_names_with_id.insert(x->name() + std::to_string(x->id()));
} }
std::vector<std::string> output_temp(output_names.begin(), desc.SetOutput(
output_names.end()); "Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
desc.SetOutput("Ys", output_temp);
desc.SetType("tensorrt_engine"); desc.SetType("tensorrt_engine");
std::unordered_map<std::string, std::string> output_name_map; std::unordered_map<std::string, std::string> output_name_map;
...@@ -153,11 +156,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, ...@@ -153,11 +156,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
std::vector<std::string> replaced_names; std::vector<std::string> replaced_names;
for (int k = 0; k < in_var->arguments_size(); k++) { for (int k = 0; k < in_var->arguments_size(); k++) {
std::string arg_value = in_var->arguments(k); std::string arg_value = in_var->arguments(k);
if (input_names.count(arg_value)) { std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]);
if (input_names_with_id.count(arg_value_with_id)) {
replaced_names.push_back(arg_value); replaced_names.push_back(arg_value);
} else { } else {
replaced_names.push_back(arg_value + replaced_names.push_back(arg_value_with_id);
std::to_string(var2id[arg_value]));
} }
} }
in_var->clear_arguments(); in_var->clear_arguments();
...@@ -176,11 +180,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, ...@@ -176,11 +180,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
std::vector<std::string> replaced_names; std::vector<std::string> replaced_names;
for (int k = 0; k < out_var->arguments_size(); k++) { for (int k = 0; k < out_var->arguments_size(); k++) {
std::string arg_value = out_var->arguments(k); std::string arg_value = out_var->arguments(k);
if (output_names.count(arg_value)) { std::string arg_value_with_id =
output_name_map[arg_value] = arg_value + std::to_string(var2id[arg_value]);
arg_value + std::to_string(var2id[arg_value]); if (output_names_with_id.count(arg_value_with_id)) {
output_name_map[arg_value] = arg_value_with_id;
} }
replaced_names.push_back(arg_value + std::to_string(var2id[arg_value])); replaced_names.push_back(arg_value_with_id);
} }
out_var->clear_arguments(); out_var->clear_arguments();
for (size_t k = 0; k < replaced_names.size(); k++) { for (size_t k = 0; k < replaced_names.size(); k++) {
......
...@@ -74,13 +74,126 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) { ...@@ -74,13 +74,126 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor; node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor;
} }
// This is a simple representation of a graph.
// The BriefNode hold the pointer of the Node.
// This is to avoid changing the original graph
// in the process of trt graph analysis.
struct BriefNode {
explicit BriefNode(Node *n) { node = n; }
Node *node;
std::vector<BriefNode *> inlinks;
std::vector<BriefNode *> outlinks;
};
void UnionContractedNodes(const std::unordered_map<int, BriefNode *> &node_map,
int src_id, int dst_id) {
// merge the two adjacent nodes into one node.
BriefNode *src_node = node_map.at(src_id);
BriefNode *dst_node = node_map.at(dst_id);
std::unordered_set<BriefNode *> inputs(src_node->inlinks.begin(),
src_node->inlinks.end());
std::unordered_set<BriefNode *> outputs;
for (auto *n : src_node->outlinks) {
if (n != dst_node) outputs.insert(n);
}
// Add the inlinks and outlinks of dst node to src node.
std::vector<BriefNode *> dst_in_nodes = dst_node->inlinks;
for (BriefNode *node : dst_in_nodes) {
if (node != src_node) {
inputs.insert(node);
}
}
std::vector<BriefNode *> dst_out_nodes = dst_node->outlinks;
for (BriefNode *node : dst_out_nodes) {
outputs.insert(node);
}
// update the dst and src node's inlinks and outlinks.
src_node->inlinks =
std::move(std::vector<BriefNode *>(inputs.begin(), inputs.end()));
src_node->outlinks =
std::move(std::vector<BriefNode *>(outputs.begin(), outputs.end()));
dst_node->inlinks.clear();
dst_node->outlinks.clear();
auto inlink_or_outlink_cleaner = [&](std::vector<BriefNode *> &nodes) {
for (auto *&n : nodes) {
if (n == src_node || n == dst_node) {
n = src_node;
}
}
};
// Change all the dst inputs and outputs corresponding inlink and
// outlink to the src node.
for (auto *node : src_node->inlinks) {
inlink_or_outlink_cleaner(node->outlinks);
}
for (auto *node : src_node->outlinks) {
inlink_or_outlink_cleaner(node->inlinks);
}
}
// FlexibleDfS
// If reverse is true, do reverse dfs.
// If enter func is not nullptr, calls enter(node) before visiting any children
// of node.
// If leave func not nullptr, calls leave(node) after visiting all parents of
// node.
void FlexibleDFS(const std::vector<BriefNode *> &source, bool reverse,
const std::function<bool(const BriefNode *)> &enter,
const std::function<bool(const BriefNode *)> &leave) {
typedef struct {
const BriefNode *node;
bool leave;
} FNode;
std::vector<FNode> stack;
for (auto &node : source) {
stack.push_back(FNode{node, false});
}
std::unordered_set<const BriefNode *> visited;
while (!stack.empty()) {
auto fnode = stack.back();
stack.pop_back();
if (fnode.leave) {
if (leave && !leave(fnode.node)) return;
}
if (visited.count(fnode.node)) continue;
visited.insert(fnode.node);
if (enter && !enter(fnode.node)) return;
if (leave) stack.push_back(FNode{fnode.node, true});
const std::vector<BriefNode *> iter_nodes =
reverse == true ? fnode.node->inlinks : fnode.node->outlinks;
for (const BriefNode *node : iter_nodes) {
if (!visited.count(node)) {
stack.push_back(FNode{node, false});
}
}
}
}
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
// Run the Extract algorithm to find all subgraphs.
std::vector<Node *> marked_nodes; std::vector<Node *> marked_nodes;
// We use brief_node_map to represent the original graph in order to avoid
// changing the original graph.
std::unordered_map<int, BriefNode *> brief_node_map;
for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS()) {
brief_node_map[node.id()] = new BriefNode(&node);
if (node.attr(kMarkerAttrName).Bool()) { if (node.attr(kMarkerAttrName).Bool()) {
marked_nodes.push_back(&node); marked_nodes.push_back(&node);
} }
} }
// extract sub-graphs in the marked node set, use Union Find algorithm. // extract sub-graphs in the marked node set, use Union Find algorithm.
node_map_t node_map; // id to ptr node_map_t node_map; // id to ptr
for (auto *n : marked_nodes) { for (auto *n : marked_nodes) {
...@@ -88,11 +201,73 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { ...@@ -88,11 +201,73 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
n->attr(kUnionFindParent).Int32() = n->id(); n->attr(kUnionFindParent).Int32() = n->id();
node_map[n->id()] = n; node_map[n->id()] = n;
} }
std::unordered_set<Node *> visited;
for (auto *n : marked_nodes) { // create breif node map
for (auto *out : n->outlinks) { for (auto &itr : brief_node_map) {
if (node_map.count(out->id())) { for (Node *node : itr.second->node->inlinks) {
UnionFindCombine(node_map, n->id(), out->id()); itr.second->inlinks.push_back(brief_node_map[node->id()]);
}
for (Node *node : itr.second->node->outlinks) {
itr.second->outlinks.push_back(brief_node_map[node->id()]);
}
}
for (auto &itr : brief_node_map) {
BriefNode *brief_node = itr.second;
if (!brief_node->node->attr(kMarkerAttrName).Bool()) {
VLOG(4) << brief_node->node->id() << " node not a trt candicate.";
continue;
}
// Our algorithm must guarantee that:
// 1. The graph is always directed acyclic graph(DAG).
// 2. If there is a path in the subgraph from X to Y (X and Y are both
// nodes
// in the subgraph), then all paths from X to Y are in the subgraph.
//
// In order to achieve the above guarantee.
// For adjacent nodes src -> dst.
// 1. Get all dst input nodes except src.
// 2. Reverse DFS from those input nodes
// 3. If there is a path from input nodes to src,
// then the src and dst nodes can not be fused into one node,
// otherwise it can be done.
while (true) {
std::unordered_set<BriefNode *> contract_nodes;
for (auto *out : brief_node->outlinks) {
// must be an trt candidate
if (!out->node->attr(kMarkerAttrName).Bool()) continue;
// get all dst input nodes except src.
std::vector<BriefNode *> source_nodes;
for (auto *n : out->inlinks) {
if (n != brief_node) {
source_nodes.push_back(n);
}
}
// Reverse DFS from the source_nodes.
bool have_excess_path = false;
FlexibleDFS(source_nodes, true, nullptr,
[&have_excess_path, brief_node](const BriefNode *n) {
if (n == brief_node) {
have_excess_path = true;
return false;
}
return true;
});
if (have_excess_path) continue;
contract_nodes.insert(out);
}
if (contract_nodes.empty()) break;
for (auto dst_node : contract_nodes) {
UnionFindCombine(node_map, brief_node->node->id(),
dst_node->node->id());
UnionContractedNodes(brief_node_map, brief_node->node->id(),
dst_node->node->id());
} }
} }
} }
...@@ -128,6 +303,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() { ...@@ -128,6 +303,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
auto io = ExtractInputAndOutputOfSubGraph(subgraph); auto io = ExtractInputAndOutputOfSubGraph(subgraph);
block_node->inlinks = std::move(io.first); block_node->inlinks = std::move(io.first);
block_node->outlinks = std::move(io.second); block_node->outlinks = std::move(io.second);
for (auto *node : subgraph) { for (auto *node : subgraph) {
// TODO(Superjomn) need a unified mechanism to treat deleted node in each // TODO(Superjomn) need a unified mechanism to treat deleted node in each
// pass. // pass.
......
...@@ -82,7 +82,7 @@ TEST(SubGraphSplitter, Fuse) { ...@@ -82,7 +82,7 @@ TEST(SubGraphSplitter, Fuse) {
// At least one nodes should be deleted. // At least one nodes should be deleted.
ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock
ASSERT_EQ(6, count1); ASSERT_EQ(11, count1);
} }
} // namespace analysis } // namespace analysis
......
...@@ -160,11 +160,21 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -160,11 +160,21 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
fluid_t->mutable_data<float>(platform::CUDAPlace( fluid_t->mutable_data<float>(platform::CUDAPlace(
boost::get<platform::CUDAPlace>(context.GetPlace()).device)), boost::get<platform::CUDAPlace>(context.GetPlace()).device)),
size * sizeof(float)); size * sizeof(float));
//} else {
// engine->GetOutputInGPU( // TODO(zhaolong) : delete it sometimes
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()), /* THIS CODE JUST FOR TEST
// size * sizeof(float)); std::cout << output_maps[output_index] << std::endl;
//} platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(framework::make_ddim(ddim));
auto* temp_data = temp_tensor.mutable_data<float>(cpu_place);
TensorCopySync(*fluid_t, cpu_place ,&temp_tensor);
for(int i = 0; i < size; i++) {
std::cout << temp_data[i] << " " ;
}
std::cout << std::endl;
*/
output_index += 1; output_index += 1;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册