From a6aa8ea7719f6664e5218bb13d3d1db691e4225f Mon Sep 17 00:00:00 2001 From: nhzlx Date: Wed, 26 Dec 2018 05:58:23 +0000 Subject: [PATCH] faster rcnn input is presistable. (fix it in paddle-trt) test=develop --- .../framework/ir/graph_pattern_detector.cc | 6 ----- .../ir_passes/tensorrt_subgraph_pass.cc | 22 +++++++++++++++++-- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 6ef341790..a826dfb27 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1101,12 +1101,6 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) { return out_var; } -// only support "identity" and "relu" now. -/* -std::unordered_set conv_act_set({"identity", "sigmoid", "relu", - "relu6", "relux", "tanh", - "band_pass"}); -*/ std::unordered_set conv_act_set({"identity", "relu"}); PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 9c42b83e7..5886868be 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" +#include #include #include + #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" +#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" namespace paddle { namespace inference { @@ -197,10 +199,26 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, std::vector ExtractParameters( const std::unordered_set &nodes) { + // We can judge whether a variable is a parameter by + // its persistable property, but sometimes the persistable flag + // of a feed op output is also true, so we have to identify those outputs. + std::vector feed_outputs; + for (const auto &node : nodes) { + if (!node->IsOp()) continue; + std::string op_type = node->Op()->Type(); + if (op_type == "feed") { + std::vector output_names = node->Op()->OutputArgumentNames(); + std::copy(output_names.begin(), output_names.end(), + std::back_inserter(feed_outputs)); + } + } + std::vector parameters; for (const auto &node : nodes) { if (!node->IsVar()) continue; - if (node->Var()->Persistable()) { + if (node->Var()->Persistable() && + std::find(feed_outputs.begin(), feed_outputs.end(), node->Name()) == + feed_outputs.end()) { parameters.push_back(node->Name()); } } -- GitLab