未验证 提交 c9995289 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #13124 from NHZlX/fix_subgraph_bug

Fix tensorrt subgraph bug
...@@ -440,6 +440,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT ...@@ -440,6 +440,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
} }
return false; return false;
}; };
for (auto &node : graph) { for (auto &node : graph) {
for (auto *in : node->inlinks) { for (auto *in : node->inlinks) {
// The Value that is written by nodes inside a sub-graph shouldn't be the // The Value that is written by nodes inside a sub-graph shouldn't be the
...@@ -459,6 +460,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT ...@@ -459,6 +460,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
std::vector<Node *>(outputs.begin(), outputs.end())); std::vector<Node *>(outputs.begin(), outputs.end()));
} }
// Filter the Intermediate results of the subgraph node.
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) { void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
std::vector<Node *> op_nodes; std::vector<Node *> op_nodes;
for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) {
...@@ -480,9 +482,11 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) { ...@@ -480,9 +482,11 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
for (auto *out : op_nodes[i]->outlinks) { for (auto *out : op_nodes[i]->outlinks) {
if (follow_up_input_names.count(out->name())) { if (follow_up_input_names.count(out->name())) {
filtered_subgraph_outlinks.push_back(out); filtered_subgraph_outlinks.push_back(out);
} else {
out->SetDeleted();
} }
} }
PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL); // The filtered_subgraph_outlinks may be empty.
op_nodes[i]->outlinks = filtered_subgraph_outlinks; op_nodes[i]->outlinks = filtered_subgraph_outlinks;
} }
} }
......
...@@ -106,20 +106,23 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, ...@@ -106,20 +106,23 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
// collect inputs // collect inputs
std::unordered_set<std::string> input_names; std::unordered_set<std::string> input_names;
std::unordered_set<std::string> input_names_with_id;
for (auto *x : func->inlinks) { for (auto *x : func->inlinks) {
input_names.insert(x->name()); input_names.insert(x->name());
input_names_with_id.insert(x->name() + std::to_string(x->id()));
} }
desc.SetInput( desc.SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end())); "Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
std::unordered_set<std::string> output_names; std::unordered_set<std::string> output_names;
std::unordered_set<std::string> output_names_with_id;
for (auto *x : func->outlinks) { for (auto *x : func->outlinks) {
output_names.insert(x->name()); output_names.insert(x->name());
output_names_with_id.insert(x->name() + std::to_string(x->id()));
} }
std::vector<std::string> output_temp(output_names.begin(), desc.SetOutput(
output_names.end()); "Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
desc.SetOutput("Ys", output_temp);
desc.SetType("tensorrt_engine"); desc.SetType("tensorrt_engine");
std::unordered_map<std::string, std::string> output_name_map; std::unordered_map<std::string, std::string> output_name_map;
...@@ -153,11 +156,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, ...@@ -153,11 +156,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
std::vector<std::string> replaced_names; std::vector<std::string> replaced_names;
for (int k = 0; k < in_var->arguments_size(); k++) { for (int k = 0; k < in_var->arguments_size(); k++) {
std::string arg_value = in_var->arguments(k); std::string arg_value = in_var->arguments(k);
if (input_names.count(arg_value)) { std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]);
if (input_names_with_id.count(arg_value_with_id)) {
replaced_names.push_back(arg_value); replaced_names.push_back(arg_value);
} else { } else {
replaced_names.push_back(arg_value + replaced_names.push_back(arg_value_with_id);
std::to_string(var2id[arg_value]));
} }
} }
in_var->clear_arguments(); in_var->clear_arguments();
...@@ -176,11 +180,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, ...@@ -176,11 +180,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
std::vector<std::string> replaced_names; std::vector<std::string> replaced_names;
for (int k = 0; k < out_var->arguments_size(); k++) { for (int k = 0; k < out_var->arguments_size(); k++) {
std::string arg_value = out_var->arguments(k); std::string arg_value = out_var->arguments(k);
if (output_names.count(arg_value)) { std::string arg_value_with_id =
output_name_map[arg_value] = arg_value + std::to_string(var2id[arg_value]);
arg_value + std::to_string(var2id[arg_value]); if (output_names_with_id.count(arg_value_with_id)) {
output_name_map[arg_value] = arg_value_with_id;
} }
replaced_names.push_back(arg_value + std::to_string(var2id[arg_value])); replaced_names.push_back(arg_value_with_id);
} }
out_var->clear_arguments(); out_var->clear_arguments();
for (size_t k = 0; k < replaced_names.size(); k++) { for (size_t k = 0; k < replaced_names.size(); k++) {
......
...@@ -74,13 +74,134 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) { ...@@ -74,13 +74,134 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor; node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor;
} }
// This is a simple representation of a graph.
// The BriefNode hold the pointer of the Node.
// This is to avoid changing the original graph
// in the process of trt graph analysis.
struct BriefNode {
explicit BriefNode(Node *n) { node = n; }
Node *node;
std::vector<BriefNode *> inlinks;
std::vector<BriefNode *> outlinks;
};
// Union two adjacent BriefNode.
// Suppose we have two adjacent nodes src and dst.
// We will perform the following operations:
// 1. add all inputs(except src) of dst to src inlinks.
// 2. add all outputs of dst to src outlinks.
// 3. change all the dst's inputs and outputs
// corresponding inlinks and outlinks to src node.
// 4. delete all dst's inlinks and outlinks.
void UnionContractedNodes(const std::unordered_map<int, BriefNode *> &node_map,
int src_id, int dst_id) {
// merge the two adjacent nodes into one node.
BriefNode *src_node = node_map.at(src_id);
BriefNode *dst_node = node_map.at(dst_id);
std::unordered_set<BriefNode *> inputs(src_node->inlinks.begin(),
src_node->inlinks.end());
std::unordered_set<BriefNode *> outputs;
for (auto *n : src_node->outlinks) {
if (n != dst_node) outputs.insert(n);
}
// Add the inlinks and outlinks of dst node to src node.
std::vector<BriefNode *> dst_in_nodes = dst_node->inlinks;
for (BriefNode *node : dst_in_nodes) {
if (node != src_node) {
inputs.insert(node);
}
}
std::vector<BriefNode *> dst_out_nodes = dst_node->outlinks;
for (BriefNode *node : dst_out_nodes) {
outputs.insert(node);
}
// update the dst and src node's inlinks and outlinks.
src_node->inlinks =
std::move(std::vector<BriefNode *>(inputs.begin(), inputs.end()));
src_node->outlinks =
std::move(std::vector<BriefNode *>(outputs.begin(), outputs.end()));
dst_node->inlinks.clear();
dst_node->outlinks.clear();
auto inlink_or_outlink_cleaner = [&](std::vector<BriefNode *> &nodes) {
for (auto *&n : nodes) {
if (n == src_node || n == dst_node) {
n = src_node;
}
}
};
// Change all the dst inputs and outputs corresponding inlink and
// outlink to the src node.
for (auto *node : src_node->inlinks) {
inlink_or_outlink_cleaner(node->outlinks);
}
for (auto *node : src_node->outlinks) {
inlink_or_outlink_cleaner(node->inlinks);
}
}
// FlexibleDFS
// If reverse is true, do reverse dfs.
// If enter func is not nullptr, calls enter(node) before visiting any children
// of node.
// If leave func not nullptr, calls leave(node) after visiting all parents of
// node.
void FlexibleDFS(const std::vector<BriefNode *> &source, bool reverse,
const std::function<bool(const BriefNode *)> &enter,
const std::function<bool(const BriefNode *)> &leave) {
typedef struct {
const BriefNode *node;
bool leave;
} FNode;
std::vector<FNode> stack;
for (auto &node : source) {
stack.push_back(FNode{node, false});
}
std::unordered_set<const BriefNode *> visited;
while (!stack.empty()) {
auto fnode = stack.back();
stack.pop_back();
if (fnode.leave) {
if (leave && !leave(fnode.node)) return;
}
if (visited.count(fnode.node)) continue;
visited.insert(fnode.node);
if (enter && !enter(fnode.node)) return;
if (leave) stack.push_back(FNode{fnode.node, true});
const std::vector<BriefNode *> iter_nodes =
reverse == true ? fnode.node->inlinks : fnode.node->outlinks;
for (const BriefNode *node : iter_nodes) {
if (!visited.count(node)) {
stack.push_back(FNode{node, false});
}
}
}
}
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
// Run the Extract algorithm to find all subgraphs.
std::vector<Node *> marked_nodes; std::vector<Node *> marked_nodes;
// We use brief_node_map to represent the original graph in order to avoid
// changing the original graph.
std::unordered_map<int, BriefNode *> brief_node_map;
for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS()) {
brief_node_map[node.id()] = new BriefNode(&node);
if (node.attr(kMarkerAttrName).Bool()) { if (node.attr(kMarkerAttrName).Bool()) {
marked_nodes.push_back(&node); marked_nodes.push_back(&node);
} }
} }
// extract sub-graphs in the marked node set, use Union Find algorithm. // extract sub-graphs in the marked node set, use Union Find algorithm.
node_map_t node_map; // id to ptr node_map_t node_map; // id to ptr
for (auto *n : marked_nodes) { for (auto *n : marked_nodes) {
...@@ -88,11 +209,73 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { ...@@ -88,11 +209,73 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
n->attr(kUnionFindParent).Int32() = n->id(); n->attr(kUnionFindParent).Int32() = n->id();
node_map[n->id()] = n; node_map[n->id()] = n;
} }
std::unordered_set<Node *> visited;
for (auto *n : marked_nodes) { // create breif node map
for (auto *out : n->outlinks) { for (auto &itr : brief_node_map) {
if (node_map.count(out->id())) { for (Node *node : itr.second->node->inlinks) {
UnionFindCombine(node_map, n->id(), out->id()); itr.second->inlinks.push_back(brief_node_map[node->id()]);
}
for (Node *node : itr.second->node->outlinks) {
itr.second->outlinks.push_back(brief_node_map[node->id()]);
}
}
for (auto &itr : brief_node_map) {
BriefNode *brief_node = itr.second;
if (!brief_node->node->attr(kMarkerAttrName).Bool()) {
VLOG(4) << brief_node->node->id() << " node not a trt candicate.";
continue;
}
// Our algorithm must guarantee that:
// 1. The graph is always directed acyclic graph(DAG).
// 2. If there is a path in the subgraph from X to Y (X and Y are both
// nodes in the subgraph), then all paths from X to Y are in the
// subgraph.
//
// In order to achieve the above guarantee.
// For adjacent nodes src -> dst.
// 1. Get all dst input nodes except src.
// 2. Reverse DFS from those input nodes
// 3. If there is a path from input nodes to src,
// then the src and dst nodes can not be fused into one node,
// otherwise it can be done.
while (true) {
std::unordered_set<BriefNode *> contract_nodes;
for (auto *out : brief_node->outlinks) {
// must be an trt candidate
if (!out->node->attr(kMarkerAttrName).Bool()) continue;
// get all dst input nodes except src.
std::vector<BriefNode *> source_nodes;
for (auto *n : out->inlinks) {
if (n != brief_node) {
source_nodes.push_back(n);
}
}
// Reverse DFS from the source_nodes.
bool have_excess_path = false;
FlexibleDFS(source_nodes, true, nullptr,
[&have_excess_path, brief_node](const BriefNode *n) {
if (n == brief_node) {
have_excess_path = true;
return false;
}
return true;
});
if (have_excess_path) continue;
contract_nodes.insert(out);
}
if (contract_nodes.empty()) break;
for (auto dst_node : contract_nodes) {
UnionFindCombine(node_map, brief_node->node->id(),
dst_node->node->id());
UnionContractedNodes(brief_node_map, brief_node->node->id(),
dst_node->node->id());
} }
} }
} }
...@@ -128,6 +311,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() { ...@@ -128,6 +311,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
auto io = ExtractInputAndOutputOfSubGraph(subgraph); auto io = ExtractInputAndOutputOfSubGraph(subgraph);
block_node->inlinks = std::move(io.first); block_node->inlinks = std::move(io.first);
block_node->outlinks = std::move(io.second); block_node->outlinks = std::move(io.second);
for (auto *node : subgraph) { for (auto *node : subgraph) {
// TODO(Superjomn) need a unified mechanism to treat deleted node in each // TODO(Superjomn) need a unified mechanism to treat deleted node in each
// pass. // pass.
......
...@@ -82,7 +82,7 @@ TEST(SubGraphSplitter, Fuse) { ...@@ -82,7 +82,7 @@ TEST(SubGraphSplitter, Fuse) {
// At least one nodes should be deleted. // At least one nodes should be deleted.
ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock
ASSERT_EQ(6, count1); ASSERT_EQ(11, count1);
} }
} // namespace analysis } // namespace analysis
......
...@@ -35,6 +35,8 @@ class ReluOpConverter : public OpConverter { ...@@ -35,6 +35,8 @@ class ReluOpConverter : public OpConverter {
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor), engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
nvinfer1::ActivationType::kRELU); nvinfer1::ActivationType::kRELU);
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
layer->setName(("relu (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0)); engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside. // output, so place the declaration inside.
......
...@@ -116,6 +116,8 @@ class BatchNormOpConverter : public OpConverter { ...@@ -116,6 +116,8 @@ class BatchNormOpConverter : public OpConverter {
scale_weights.get(), power_weights.get()); scale_weights.get(), power_weights.get());
auto output_name = op_desc.Output("Y").front(); auto output_name = op_desc.Output("Y").front();
layer->setName(("batch_norm (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine_->weight_map[op_desc.Input("Bias").front()] = engine_->weight_map[op_desc.Input("Bias").front()] =
std::move(combile_bias_tensor); std::move(combile_bias_tensor);
engine_->weight_map[op_desc.Input("Scale").front()] = engine_->weight_map[op_desc.Input("Scale").front()] =
......
...@@ -42,6 +42,8 @@ class ConcatOpConverter : public OpConverter { ...@@ -42,6 +42,8 @@ class ConcatOpConverter : public OpConverter {
axis = axis - 1; // Remove batch dim axis = axis - 1; // Remove batch dim
layer->setAxis(axis); layer->setAxis(axis);
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
layer->setName(("concat (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0)); engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside. // output, so place the declaration inside.
......
...@@ -78,8 +78,10 @@ class Conv2dOpConverter : public OpConverter { ...@@ -78,8 +78,10 @@ class Conv2dOpConverter : public OpConverter {
layer->setNbGroups(groups); layer->setNbGroups(groups);
auto output_name = op_desc.Output("Output").front(); auto output_name = op_desc.Output("Output").front();
layer->setName(("conv2d (Output: " + output_name + ")").c_str());
engine_->weight_map[op_desc.Input("Filter").front()] = engine_->weight_map[op_desc.Input("Filter").front()] =
std::move(weight_tensor); std::move(weight_tensor);
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0)); engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { if (test_mode) {
engine_->DeclareOutput(output_name); engine_->DeclareOutput(output_name);
......
...@@ -89,6 +89,8 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -89,6 +89,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
shift_weights.get(), scale_weights.get(), power_weights.get()); shift_weights.get(), scale_weights.get(), power_weights.get());
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
layer->setName(("elementwise_add (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine_->weight_map[op_desc.Input("Y").front()] = std::move(weight_tensor); engine_->weight_map[op_desc.Input("Y").front()] = std::move(weight_tensor);
engine_->SetITensor(output_name, layer->getOutput(0)); engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the if (test_mode) { // the test framework can not determine which is the
...@@ -137,6 +139,8 @@ class ElementwiseTensorOpConverter : public OpConverter { ...@@ -137,6 +139,8 @@ class ElementwiseTensorOpConverter : public OpConverter {
*const_cast<nvinfer1::ITensor*>(Y), op_pair->second); *const_cast<nvinfer1::ITensor*>(Y), op_pair->second);
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
layer->setName(("elementwise (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0)); engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside. // output, so place the declaration inside.
......
...@@ -107,6 +107,8 @@ class FcOpConverter : public OpConverter { ...@@ -107,6 +107,8 @@ class FcOpConverter : public OpConverter {
n_output, tmp_weight.get(), bias.get()); n_output, tmp_weight.get(), bias.get());
auto output_name = op_desc.Output("Out").front(); auto output_name = op_desc.Output("Out").front();
layer->setName(("fc (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0)); engine_->SetITensor(output_name, layer->getOutput(0));
engine_->weight_map[op_desc.Input("Y").front()] = std::move(tmp); engine_->weight_map[op_desc.Input("Y").front()] = std::move(tmp);
if (test_mode) { if (test_mode) {
......
...@@ -72,6 +72,8 @@ class Pool2dOpConverter : public OpConverter { ...@@ -72,6 +72,8 @@ class Pool2dOpConverter : public OpConverter {
layer->setPadding(nv_paddings); layer->setPadding(nv_paddings);
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
layer->setName(("pool2d (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0)); engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { if (test_mode) {
engine_->DeclareOutput(output_name); engine_->DeclareOutput(output_name);
......
...@@ -160,11 +160,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -160,11 +160,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
fluid_t->mutable_data<float>(platform::CUDAPlace( fluid_t->mutable_data<float>(platform::CUDAPlace(
boost::get<platform::CUDAPlace>(context.GetPlace()).device)), boost::get<platform::CUDAPlace>(context.GetPlace()).device)),
size * sizeof(float)); size * sizeof(float));
//} else {
// engine->GetOutputInGPU(
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
// size * sizeof(float));
//}
output_index += 1; output_index += 1;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册