提交 a6aa8ea7 编写于 作者: N nhzlx

faster rcnn input is presistable. (fix it in paddle-trt)

test=develop
上级 73b47df1
......@@ -1101,12 +1101,6 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
return out_var;
}
// only support "identity" and "relu" now.
/*
std::unordered_set<std::string> conv_act_set({"identity", "sigmoid", "relu",
"relu6", "relux", "tanh",
"band_pass"});
*/
std::unordered_set<std::string> conv_act_set({"identity", "relu"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
namespace paddle {
namespace inference {
......@@ -197,10 +199,26 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
std::vector<std::string> ExtractParameters(
const std::unordered_set<Node *> &nodes) {
// We can judge whether a variable is a parameter by
// its presistable property, but sometimes the presistable
// of the feed op output is true, so we have to identify it.
std::vector<std::string> feed_outputs;
for (const auto &node : nodes) {
if (!node->IsOp()) continue;
std::string op_type = node->Op()->Type();
if (op_type == "feed") {
std::vector<std::string> output_names = node->Op()->OutputArgumentNames();
std::copy(output_names.begin(), output_names.end(),
std::back_inserter(feed_outputs));
}
}
std::vector<std::string> parameters;
for (const auto &node : nodes) {
if (!node->IsVar()) continue;
if (node->Var()->Persistable()) {
if (node->Var()->Persistable() &&
std::find(feed_outputs.begin(), feed_outputs.end(), node->Name()) ==
feed_outputs.end()) {
parameters.push_back(node->Name());
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册