提交 a6aa8ea7 编写于 作者: N nhzlx

faster rcnn input is presistable. (fix it in paddle-trt)

test=develop
上级 73b47df1
...@@ -1101,12 +1101,6 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) { ...@@ -1101,12 +1101,6 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
return out_var; return out_var;
} }
// only support "identity" and "relu" now.
/*
std::unordered_set<std::string> conv_act_set({"identity", "sigmoid", "relu",
"relu6", "relux", "tanh",
"band_pass"});
*/
std::unordered_set<std::string> conv_act_set({"identity", "relu"}); std::unordered_set<std::string> conv_act_set({"identity", "relu"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" #include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" #include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -197,10 +199,26 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, ...@@ -197,10 +199,26 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
std::vector<std::string> ExtractParameters( std::vector<std::string> ExtractParameters(
const std::unordered_set<Node *> &nodes) { const std::unordered_set<Node *> &nodes) {
// We can judge whether a variable is a parameter by
// its presistable property, but sometimes the presistable
// of the feed op output is true, so we have to identify it.
std::vector<std::string> feed_outputs;
for (const auto &node : nodes) {
if (!node->IsOp()) continue;
std::string op_type = node->Op()->Type();
if (op_type == "feed") {
std::vector<std::string> output_names = node->Op()->OutputArgumentNames();
std::copy(output_names.begin(), output_names.end(),
std::back_inserter(feed_outputs));
}
}
std::vector<std::string> parameters; std::vector<std::string> parameters;
for (const auto &node : nodes) { for (const auto &node : nodes) {
if (!node->IsVar()) continue; if (!node->IsVar()) continue;
if (node->Var()->Persistable()) { if (node->Var()->Persistable() &&
std::find(feed_outputs.begin(), feed_outputs.end(), node->Name()) ==
feed_outputs.end()) {
parameters.push_back(node->Name()); parameters.push_back(node->Name());
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册