diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 6ef3417901f445c455293949369ad8df6651996b..a826dfb275ca3ca53223cca30380d3c37a8e78f1 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1101,12 +1101,6 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) { return out_var; } -// only support "identity" and "relu" now. -/* -std::unordered_set conv_act_set({"identity", "sigmoid", "relu", - "relu6", "relux", "tanh", - "band_pass"}); -*/ std::unordered_set conv_act_set({"identity", "relu"}); PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 9c42b83e7add348433635b1899087324e4e370d4..5886868be0a90cb57ca8cc6bbf595b95b349e4b2 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" +#include #include #include + #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" +#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" namespace paddle { namespace inference { @@ -197,10 +199,26 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, std::vector ExtractParameters( const std::unordered_set &nodes) { + // We can judge whether a variable is a parameter by + // its presistable property, but sometimes the presistable + // of the feed op output is true, so we have to identify it. + std::vector feed_outputs; + for (const auto &node : nodes) { + if (!node->IsOp()) continue; + std::string op_type = node->Op()->Type(); + if (op_type == "feed") { + std::vector output_names = node->Op()->OutputArgumentNames(); + std::copy(output_names.begin(), output_names.end(), + std::back_inserter(feed_outputs)); + } + } + std::vector parameters; for (const auto &node : nodes) { if (!node->IsVar()) continue; - if (node->Var()->Persistable()) { + if (node->Var()->Persistable() && + std::find(feed_outputs.begin(), feed_outputs.end(), node->Name()) == + feed_outputs.end()) { parameters.push_back(node->Name()); } }