diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 81b8ffa83f612f5b67cd91a7a2c1228519a1fbb7..ba1d7379c56d953a0f37d03deed6c47e46cbf129 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -68,21 +68,12 @@ pass_library(transpose_flatten_concat_fuse_pass inference) pass_library(identity_scale_op_clean_pass base) pass_library(sync_batch_norm_pass base) pass_library(runtime_context_cache_pass base) -pass_library(simplify_anakin_detection_pattern_pass inference) -pass_library(anakin_fillconstant_elementwisemul_fuse inference) +pass_library(quant_conv2d_dequant_fuse_pass inference) +pass_library(fillconstant_elementwisemul_fuse inference) -# There may be many transpose-flatten structures in a model, and the output of -# these structures will be used as inputs to the concat Op. This pattern will -# be detected by our pass. The index here represents the number of structures in the -# pattern. We use index 3 ~ 6, because these quantities of structures are -# common in the models. -foreach (index RANGE 2 6) - file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n") -endforeach() - -foreach (index RANGE 2 6) - file(APPEND ${pass_file} "USE_PASS(simplify_anakin_detection_pattern_pass${index});\n") -endforeach() +if(ANAKIN_FOUND) +pass_library(simplify_anakin_priorbox_detection_out_pass inference) +endif() if(WITH_MKLDNN) pass_library(mkldnn_placement_pass base mkldnn) diff --git a/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.cc b/paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.cc similarity index 82% rename from paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.cc rename to paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.cc index 39077f6420613e115fff828eefc295769c187833..915a2f62bafa2baf98b7407cd87d3e69f20b44d2 100644 --- a/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.cc +++ b/paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.cc @@ -15,7 +15,7 @@ #include #include -#include "paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h" +#include "paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h" namespace paddle { @@ -29,8 +29,8 @@ namespace ir { GET_IR_NODE(elementwise_mul); \ GET_IR_NODE(elementwise_mul_out); -void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const { - const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse"; +void FillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = "fillconstant_elementwisemul_fuse"; FusePassBase::Init(pattern_name, graph); GraphPatternDetector gpd; @@ -39,8 +39,8 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const { ->assert_is_op_input("elementwise_mul", "X") ->AsInput(); - patterns::AnakinFillConstantElementWiseMulFuse pattern(gpd.mutable_pattern(), - pattern_name); + patterns::FillConstantElementWiseMulFuse pattern(gpd.mutable_pattern(), + pattern_name); pattern(x); auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -79,5 +79,5 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const { } // namespace framework } // namespace paddle -REGISTER_PASS(anakin_fillconstant_elementwisemul_fuse, - paddle::framework::ir::AnakinFillconstantElementwisemulFuse); +REGISTER_PASS(fillconstant_elementwisemul_fuse, + paddle::framework::ir::FillconstantElementwisemulFuse); diff --git a/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h b/paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.h similarity index 89% rename from paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h rename to paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.h index 14c07c5884ebeda602953704de6db42f16441d6e..ab66fb4a46a8a5b60b3bf95e27ae24c7217a5a3a 100644 --- a/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h +++ b/paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.h @@ -21,9 +21,9 @@ namespace paddle { namespace framework { namespace ir { -class AnakinFillconstantElementwisemulFuse : public FusePassBase { +class FillconstantElementwisemulFuse : public FusePassBase { public: - virtual ~AnakinFillconstantElementwisemulFuse() {} + virtual ~FillconstantElementwisemulFuse() {} protected: void ApplyImpl(ir::Graph* graph) const override; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 555fdc7b7a03ebc99fcc77a26341d291dac2c308..8468f9ccc12a017ebe4fe73581e7bbce00dd626d 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1471,7 +1471,8 @@ PDNode *patterns::TransposeFlattenConcat::operator()( } PDNode *patterns::AnakinDetectionPattern::operator()( - std::vector conv_in, int times) { + std::vector conv_in, int times, std::string priorbox_type, + bool is_reshape) { // The times represents the repeat times of the // {prior_box, prior_box_loc_out, flatten, prior_box_var_out, reshape} const int kNumFields = 7; @@ -1486,37 +1487,38 @@ PDNode *patterns::AnakinDetectionPattern::operator()( const int kMultiClassSecondInputNmsOffset = times + 1; std::vector nodes; + std::string op_after_priorbox = is_reshape ? "reshape2" : "flatten2"; for (int i = 0; i < times; i++) { nodes.push_back( pattern->NewNode(GetNodeName("prior_box" + std::to_string(i))) - ->assert_is_op("density_prior_box")); + ->assert_is_op(priorbox_type)); nodes.push_back(pattern->NewNode(GetNodeName("box_out" + std::to_string(i))) - ->assert_is_op_output("density_prior_box", "Boxes") - ->assert_is_op_input("reshape2", "X") + ->assert_is_op_output(priorbox_type, "Boxes") + ->assert_is_op_input(op_after_priorbox, "X") ->AsIntermediate()); nodes.push_back( pattern->NewNode(GetNodeName("reshape1" + std::to_string(i))) - ->assert_is_op("reshape2")); + ->assert_is_op(op_after_priorbox)); nodes.push_back( pattern->NewNode(GetNodeName("reshape1_out" + std::to_string(i))) - ->assert_is_op_output("reshape2") + ->assert_is_op_output(op_after_priorbox) ->assert_is_op_nth_input("concat", "X", i) ->AsIntermediate()); nodes.push_back( pattern->NewNode(GetNodeName("box_var_out" + std::to_string(i))) - ->assert_is_op_output("density_prior_box", "Variances") - ->assert_is_op_input("reshape2", "X") + ->assert_is_op_output(priorbox_type, "Variances") + ->assert_is_op_input(op_after_priorbox, "X") ->AsIntermediate()); nodes.push_back( pattern->NewNode(GetNodeName("reshape2" + std::to_string(i))) - ->assert_is_op("reshape2")); + ->assert_is_op(op_after_priorbox)); nodes.push_back( pattern->NewNode(GetNodeName("reshape2_out" + std::to_string(i))) - ->assert_is_op_output("reshape2") + ->assert_is_op_output(op_after_priorbox) ->assert_is_op_nth_input("concat", "X", i) ->AsIntermediate()); } @@ -1612,7 +1614,7 @@ PDNode *patterns::AnakinDetectionPattern::operator()( return multiclass_nms_out; } -PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()( +PDNode *patterns::FillConstantElementWiseMulFuse::operator()( PDNode *elementwise_op_input) { auto fill_constant = pattern->NewNode(fill_constant_repr())->assert_is_op("fill_constant"); @@ -1635,6 +1637,76 @@ PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()( return elementwise_mul_out; } +void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input, + const std::string &op_type, + const std::string &weight_name, + int times) { + const int kNumFields = 5; + const int kQuantizedWeightOffset = 0; + const int kQuantizedOpOffset = 1; + const int kQuantizedOpOutOffset = 2; + const int kDequantOpOffset = 3; + const int kDequantOpOutOffset = 4; + // the quant op always be one. + auto quant_op_in_scale = + pattern->NewNode(GetNodeName("quant_op_in_scale")) + ->assert_is_op_input("fake_quantize_range_abs_max", "InScale") + ->AsInput(); + auto quant_op = pattern->NewNode(GetNodeName("quant_op")) + ->assert_is_op("fake_quantize_range_abs_max"); + + auto quant_op_out_scale = + pattern->NewNode(GetNodeName("quant_op_out_scale")) + ->assert_is_op_output("fake_quantize_range_abs_max", "OutScale") + ->assert_is_op_input("fake_dequantize_max_abs", "Scale") + ->AsIntermediate(); + + auto quant_op_out = + pattern->NewNode(GetNodeName("quant_op_out")) + ->assert_is_op_output("fake_quantize_range_abs_max", "Out") + ->assert_is_op_input(op_type) + ->AsIntermediate(); + + // there are 'times' quantized and dequant op + std::vector nodes; + for (int i = 0; i < times; i++) { + nodes.push_back( + pattern->NewNode(GetNodeName("quantized_op_weight") + std::to_string(i)) + ->assert_is_op_input(op_type, weight_name) + ->AsInput()); + nodes.push_back( + pattern->NewNode(GetNodeName("quantized_op") + std::to_string(i)) + ->assert_is_op(op_type)); + + nodes.push_back( + pattern->NewNode(GetNodeName("quantized_op_out") + std::to_string(i)) + ->assert_is_op_output(op_type) + ->assert_is_op_input("fake_dequantize_max_abs", "X") + ->AsIntermediate()); + + nodes.push_back( + pattern->NewNode(GetNodeName("dequant_op") + std::to_string(i)) + ->assert_is_op("fake_dequantize_max_abs")); + nodes.push_back( + pattern->NewNode(GetNodeName("dequant_op_out") + std::to_string(i)) + ->assert_is_op_output("fake_dequantize_max_abs", "Out") + ->AsOutput()); + } + + quant_op->LinksFrom({quant_op_input, quant_op_in_scale}); + quant_op_out->LinksFrom({quant_op}); + for (int i = 0; i < times; i++) { + nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom( + {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]}); + nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom( + {nodes[i * kNumFields + kQuantizedOpOffset]}); + nodes[i * kNumFields + kDequantOpOffset]->LinksFrom( + {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale}); + nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom( + {nodes[i * kNumFields + kDequantOpOffset]}); + } +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 130ddeac4cd1a38516540d175e17d46f877bd909..a5ac3a0c3733cf610159c6367d04f3323b797c50 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -848,7 +848,8 @@ struct AnakinDetectionPattern : public PatternBase { AnakinDetectionPattern(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "anakin_detect_pattern") {} - PDNode* operator()(std::vector conv_inputs, int times); + PDNode* operator()(std::vector conv_inputs, int times, + std::string priorbox_type, bool is_reshape); std::string GetNodeName(const std::string& op_type) { return PDNodeName(name_scope_, repr_, id_, op_type); @@ -859,9 +860,9 @@ struct AnakinDetectionPattern : public PatternBase { } }; -struct AnakinFillConstantElementWiseMulFuse : public PatternBase { - AnakinFillConstantElementWiseMulFuse(PDPattern* pattern, - const std::string& name_scope) +struct FillConstantElementWiseMulFuse : public PatternBase { + FillConstantElementWiseMulFuse(PDPattern* pattern, + const std::string& name_scope) : PatternBase(pattern, name_scope, "anakin_fillconstant_elementwisemul_fuse") {} @@ -874,6 +875,22 @@ struct AnakinFillConstantElementWiseMulFuse : public PatternBase { PATTERN_DECL_NODE(elementwise_mul_out); }; +struct QuantDequantOpFuse : public PatternBase { + QuantDequantOpFuse(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "quant_dequant_fuse") {} + + void operator()(PDNode* quant_op_input, const std::string& op_name, + const std::string& weight_name, int times = 1); + + std::string GetNodeName(const std::string& op_type) { + return PDNodeName(name_scope_, repr_, id_, op_type); + } + + PDNode* GetPDNode(const std::string& op_type) { + return pattern->RetrieveNode(GetNodeName(op_type)); + } +}; + } // namespace patterns // Link two ir::Nodes from each other. diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..7cab9c353d35cb6d725d787986e992b6853d42ce --- /dev/null +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -0,0 +1,173 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +void RunQuantDequant(ir::Graph* graph, Scope* scope, int times, + std::string op_type) { + const std::string pattern_name = "quant_dequant_fuse"; + // FusePassBase::Init(pattern_name, graph); + const int kNumFields = 5; + const int kQuantizedWeightOffset = 0; + const int kQuantizedOpOffset = 1; + const int kQuantizedOpOutOffset = 2; + const int kDequantOpOffset = 3; + const int kDequantOpOutOffset = 4; + + GraphPatternDetector gpd; + auto* x = gpd.mutable_pattern() + ->NewNode("x") + ->assert_is_op_input("fake_quantize_range_abs_max", "X") + ->AsInput(); + + std::string quantized_op_type = ""; + std::string weight_name = ""; + if (op_type == "conv2d") { + quantized_op_type = "conv2d"; + weight_name = "Filter"; + } else if (op_type == "conv2d_fusion") { + quantized_op_type = "conv2d_fusion"; + weight_name = "Filter"; + } else if (op_type == "mul") { + quantized_op_type = "mul"; + weight_name = "Y"; + } else if (op_type == "fc") { + quantized_op_type = "fc"; + weight_name = "W"; + } else { + PADDLE_ENFORCE( + "QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for " + "now."); + } + + patterns::QuantDequantOpFuse pattern(gpd.mutable_pattern(), pattern_name); + pattern(x, quantized_op_type, weight_name, times); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + PADDLE_ENFORCE(subgraph.count(x)); + auto* input_node = subgraph.at(x); + Node* quant_op_in_scale = + subgraph.at(pattern.GetPDNode("quant_op_in_scale")); + Node* quant_op = subgraph.at(pattern.GetPDNode("quant_op")); + Node* quant_op_out_scale = + subgraph.at(pattern.GetPDNode("quant_op_out_scale")); + Node* quant_op_out = subgraph.at(pattern.GetPDNode("quant_op_out")); + + std::vector nodes; + for (int i = 0; i < times; i++) { + nodes.push_back(subgraph.at( + pattern.GetPDNode("quantized_op_weight" + std::to_string(i)))); + nodes.push_back( + subgraph.at(pattern.GetPDNode("quantized_op" + std::to_string(i)))); + nodes.push_back(subgraph.at( + pattern.GetPDNode("quantized_op_out" + std::to_string(i)))); + nodes.push_back( + subgraph.at(pattern.GetPDNode("dequant_op" + std::to_string(i)))); + nodes.push_back( + subgraph.at(pattern.GetPDNode("dequant_op_out" + std::to_string(i)))); + } + + int bit_length = boost::get(quant_op->Op()->GetAttr("bit_length")); + int range = ((1 << (bit_length - 1)) - 1); + // Prepare input scale + std::string input_scale_var_name = quant_op->Op()->Input("InScale").front(); + PADDLE_ENFORCE(scope); + const LoDTensor& input_scale_tensor = + scope->FindVar(input_scale_var_name)->Get(); + + PADDLE_ENFORCE(paddle::platform::is_cpu_place(input_scale_tensor.place())); + const float* input_scale_data = input_scale_tensor.data(); + float input_scale = input_scale_data[0]; + std::unordered_set delete_nodes; + + for (int i = 0; i < times; i++) { + // max_range = (range * range) / weight_scale + float max_range = boost::get( + nodes[i * kNumFields + kDequantOpOffset]->Op()->GetAttr("max_range")); + float weight_scale = (range * range) / max_range; + + auto base_op_desc = + *nodes[i * kNumFields + kQuantizedOpOffset]->Op()->Proto(); + std::string new_input = input_node->Name(); + std::string new_output = + nodes[i * kNumFields + kDequantOpOutOffset]->Name(); + + framework::OpDesc new_op_desc(base_op_desc, nullptr); + new_op_desc.SetType(quantized_op_type); + + if (quantized_op_type == "conv2d" || + quantized_op_type == "conv2d_fusion") { + new_op_desc.SetInput("Input", {new_input}); + new_op_desc.SetOutput("Output", {new_output}); + } else if (quantized_op_type == "fc") { + new_op_desc.SetInput("Input", {new_input}); + new_op_desc.SetOutput("Out", {new_output}); + } else if (quantized_op_type == "mul") { + new_op_desc.SetInput("X", {new_input}); + new_op_desc.SetOutput("Out", {new_output}); + } + + new_op_desc.SetAttr("enable_int8", true); + new_op_desc.SetAttr("input_scale", input_scale); + new_op_desc.SetAttr("weight_scale", weight_scale); + new_op_desc.Flush(); + auto* new_op = graph->CreateOpNode(&new_op_desc); + IR_NODE_LINK_TO(input_node, new_op); + IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset], new_op); + IR_NODE_LINK_TO(new_op, nodes[i * kNumFields + kDequantOpOutOffset]); + delete_nodes.insert(nodes[i * kNumFields + kQuantizedOpOffset]); + delete_nodes.insert(nodes[i * kNumFields + kQuantizedOpOutOffset]); + delete_nodes.insert(nodes[i * kNumFields + kDequantOpOffset]); + } + + delete_nodes.insert(quant_op_in_scale); + delete_nodes.insert(quant_op); + delete_nodes.insert(quant_op_out); + delete_nodes.insert(quant_op_out_scale); + // Delete the unneeded nodes. + GraphSafeRemoveNodes(graph, delete_nodes); + }; + gpd(graph, handler); +} + +void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = "quant_dequant_fuse"; + FusePassBase::Init(pattern_name, graph); + + std::unordered_set quantized_op_types = {"conv2d", "mul"}; + auto* scope = param_scope(); + for (auto& op_type : quantized_op_types) { + for (int i = 1; i <= 6; i++) { + RunQuantDequant(graph, scope, i, op_type); + } + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(quant_conv2d_dequant_fuse_pass, + paddle::framework::ir::QuantDequantFusePass); diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..a61b34563acc4cbcee778509a097587222579295 --- /dev/null +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +class QuantDequantFusePass : public FusePassBase { + public: + virtual ~QuantDequantFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc b/paddle/fluid/framework/ir/simplify_anakin_priorbox_detection_out_pass.cc similarity index 84% rename from paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc rename to paddle/fluid/framework/ir/simplify_anakin_priorbox_detection_out_pass.cc index e1ddc444707148b1b781a922429de13a715f3b60..b3606e4d922cc8f59dca90904466a889f83f6094 100644 --- a/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc +++ b/paddle/fluid/framework/ir/simplify_anakin_priorbox_detection_out_pass.cc @@ -17,25 +17,24 @@ #include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/node.h" -#include "paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.h" +#include "paddle/fluid/framework/ir/simplify_anakin_priorbox_detection_out_pass.h" namespace paddle { namespace framework { namespace ir { -template -void SimplifyAnakinDetectionPatternPass::ApplyImpl( - ir::Graph *graph) const { +void RunSimplifyAnakinDetection(ir::Graph *graph, int times, bool is_density, + bool is_reshape) { const std::string pattern_name = "simplify_anakin_detection_pattern_pass" + std::to_string(times); - FusePassBase::Init(pattern_name, graph); + std::string priorbox_type = is_density ? "density_prior_box" : "prior_box"; GraphPatternDetector gpd; std::vector input_nodes; for (int i = 0; i < times; i++) { input_nodes.push_back(gpd.mutable_pattern() ->NewNode("x" + std::to_string(i)) - ->assert_is_op_input("density_prior_box", "Input") + ->assert_is_op_input(priorbox_type, "Input") ->AsInput()); } input_nodes.push_back(gpd.mutable_pattern() @@ -49,7 +48,7 @@ void SimplifyAnakinDetectionPatternPass::ApplyImpl( ->AsInput()); patterns::AnakinDetectionPattern pattern(gpd.mutable_pattern(), pattern_name); - pattern(input_nodes, times); + pattern(input_nodes, times, priorbox_type, is_reshape); auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { @@ -119,8 +118,7 @@ void SimplifyAnakinDetectionPatternPass::ApplyImpl( boost::get(box_coder_op->Op()->GetAttr("code_type")); bool box_normalized = boost::get(box_coder_op->Op()->GetAttr("box_normalized")); - // auto variance = - // boost::get>(box_coder_op->Op()->GetAttr("variance")); + int background_label = boost::get(multiclass_nms->Op()->GetAttr("background_label")); float score_threshold = @@ -138,7 +136,6 @@ void SimplifyAnakinDetectionPatternPass::ApplyImpl( nodes[i * kNumFields + kPriorBoxLocOffset]->Name()); } - // int axis = boost::get(concat_op1->Op()->GetAttr("axis")); framework::OpDesc concat1_desc; concat1_desc.SetType("concat"); concat1_desc.SetInput("X", concat1_input_names); @@ -213,31 +210,24 @@ void SimplifyAnakinDetectionPatternPass::ApplyImpl( gpd(graph, handler); } -template class SimplifyAnakinDetectionPatternPass<1>; -template class SimplifyAnakinDetectionPatternPass<2>; -template class SimplifyAnakinDetectionPatternPass<3>; -template class SimplifyAnakinDetectionPatternPass<4>; -template class SimplifyAnakinDetectionPatternPass<5>; -template class SimplifyAnakinDetectionPatternPass<6>; +void SimplifyAnakinDetectionPatternPass::ApplyImpl(ir::Graph *graph) const { + const int pattern_nums = 6; + const std::string pattern_name = "simplify_anakin_detection_pattern_pass"; + FusePassBase::Init(pattern_name, graph); + std::vector options = {true, false}; + for (const auto &is_density : options) { + for (const auto &is_reshape : options) { + for (int i = 1; i <= pattern_nums; i++) { + RunSimplifyAnakinDetection(graph, i, is_density, is_reshape); + } + } + } +} } // namespace ir } // namespace framework } // namespace paddle -REGISTER_PASS(simplify_anakin_detection_pattern_pass, - paddle::framework::ir::SimplifyAnakinDetectionPatternPass<1>); - -REGISTER_PASS(simplify_anakin_detection_pattern_pass2, - paddle::framework::ir::SimplifyAnakinDetectionPatternPass<2>); - -REGISTER_PASS(simplify_anakin_detection_pattern_pass3, - paddle::framework::ir::SimplifyAnakinDetectionPatternPass<3>); - -REGISTER_PASS(simplify_anakin_detection_pattern_pass4, - paddle::framework::ir::SimplifyAnakinDetectionPatternPass<4>); - -REGISTER_PASS(simplify_anakin_detection_pattern_pass5, - paddle::framework::ir::SimplifyAnakinDetectionPatternPass<5>); - -REGISTER_PASS(simplify_anakin_detection_pattern_pass6, - paddle::framework::ir::SimplifyAnakinDetectionPatternPass<6>); +typedef paddle::framework::ir::SimplifyAnakinDetectionPatternPass + priorbox_pattern; +REGISTER_PASS(simplify_anakin_priorbox_detection_out_pass, priorbox_pattern); diff --git a/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.h b/paddle/fluid/framework/ir/simplify_anakin_priorbox_detection_out_pass.h similarity index 98% rename from paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.h rename to paddle/fluid/framework/ir/simplify_anakin_priorbox_detection_out_pass.h index e4a266cbe843ac56a8c0e4fb1e6f166afea6bfac..e882b9dc252e61a2e9e4e3666de49b7eee6d714a 100644 --- a/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.h +++ b/paddle/fluid/framework/ir/simplify_anakin_priorbox_detection_out_pass.h @@ -26,7 +26,6 @@ namespace ir { // these structures will be used as inputs to the concat Op. This pattern will // be detected by our pass. The times here represents the repeat times of this // structure. -template class SimplifyAnakinDetectionPatternPass : public FusePassBase { public: virtual ~SimplifyAnakinDetectionPatternPass() {} diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc index 61c12d4b6e76bf3021a92aa99953df626b0e45e7..a984a4942b374c3e2c5f148f8147c55d0f5deb24 100644 --- a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc @@ -25,11 +25,9 @@ namespace paddle { namespace framework { namespace ir { -template -void TransposeFlattenConcatFusePass::ApplyImpl(ir::Graph *graph) const { +void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) { const std::string pattern_name = "transpose_flatten" + std::to_string(times) + "_concat_fuse"; - FusePassBase::Init(pattern_name, graph); GraphPatternDetector gpd; std::vector input_nodes; @@ -122,31 +120,18 @@ void TransposeFlattenConcatFusePass::ApplyImpl(ir::Graph *graph) const { gpd(graph, handler); } -template class TransposeFlattenConcatFusePass<1>; -template class TransposeFlattenConcatFusePass<2>; -template class TransposeFlattenConcatFusePass<3>; -template class TransposeFlattenConcatFusePass<4>; -template class TransposeFlattenConcatFusePass<5>; -template class TransposeFlattenConcatFusePass<6>; +void TransposeFlattenConcatFusePass::ApplyImpl(ir::Graph *graph) const { + const int pattern_nums = 6; + const std::string pattern_name = "transpose_flatten_concat_fuse"; + FusePassBase::Init(pattern_name, graph); + for (int i = 1; i <= pattern_nums; i++) { + RunTransposeFlattenConcatFuse(graph, i); + } +} } // namespace ir } // namespace framework } // namespace paddle REGISTER_PASS(transpose_flatten_concat_fuse_pass, - paddle::framework::ir::TransposeFlattenConcatFusePass<1>); - -REGISTER_PASS(transpose_flatten2_concat_fuse_pass, - paddle::framework::ir::TransposeFlattenConcatFusePass<2>); - -REGISTER_PASS(transpose_flatten3_concat_fuse_pass, - paddle::framework::ir::TransposeFlattenConcatFusePass<3>); - -REGISTER_PASS(transpose_flatten4_concat_fuse_pass, - paddle::framework::ir::TransposeFlattenConcatFusePass<4>); - -REGISTER_PASS(transpose_flatten5_concat_fuse_pass, - paddle::framework::ir::TransposeFlattenConcatFusePass<5>); - -REGISTER_PASS(transpose_flatten6_concat_fuse_pass, - paddle::framework::ir::TransposeFlattenConcatFusePass<6>); + paddle::framework::ir::TransposeFlattenConcatFusePass); diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h index 366d26d800c9899c455a3699f3f73f6e481aa0e0..939a8c31e5501e23968f9b44b4fe09e78280fd07 100644 --- a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h @@ -13,6 +13,8 @@ // limitations under the License. #pragma once +#include + #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" @@ -24,7 +26,6 @@ namespace ir { // these structures will be used as inputs to the concat Op. This pattern will // be detected by our pass. The times here represents the repeat times of this // structure. -template class TransposeFlattenConcatFusePass : public FusePassBase { public: virtual ~TransposeFlattenConcatFusePass() {} diff --git a/paddle/fluid/inference/anakin/convert/density_prior_box.cc b/paddle/fluid/inference/anakin/convert/density_prior_box.cc index a55c153f99a815c0e0092b69b8e181630aed16bf..35e02919aa70c211da5d4a5785a9833747d99ce2 100644 --- a/paddle/fluid/inference/anakin/convert/density_prior_box.cc +++ b/paddle/fluid/inference/anakin/convert/density_prior_box.cc @@ -34,25 +34,41 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op, auto input_name = op_desc.Input("Input").front(); auto image_name = op_desc.Input("Image").front(); auto output_name = op_desc.Output("Boxes").front(); + auto op_type = op_desc.Type(); + auto op_name = op_type + ":" + op_desc.Output("Boxes").front(); - auto op_name = op_desc.Type() + ":" + op_desc.Output("Boxes").front(); + // only for density_prior_box + std::vector fixed_sizes = {}; + std::vector fixed_ratios = {}; + std::vector densities = {}; - auto fixed_sizes = - boost::get>(op_desc.GetAttr("fixed_sizes")); - auto fixed_ratios = - boost::get>(op_desc.GetAttr("fixed_ratios")); - auto densities = boost::get>(op_desc.GetAttr("densities")); + std::vector min_sizes = {}; + std::vector max_sizes = {}; + std::vector aspect_ratios = {}; + bool is_clip = false; + bool is_flip = false; + + if (op_type == "density_prior_box") { + fixed_sizes = + boost::get>(op_desc.GetAttr("fixed_sizes")); + fixed_ratios = + boost::get>(op_desc.GetAttr("fixed_ratios")); + densities = boost::get>(op_desc.GetAttr("densities")); + is_clip = boost::get(op_desc.GetAttr("clip")); + } else if (op_type == "prior_box") { + min_sizes = boost::get>(op_desc.GetAttr("min_sizes")); + max_sizes = boost::get>(op_desc.GetAttr("max_sizes")); + aspect_ratios = + boost::get>(op_desc.GetAttr("aspect_ratios")); + is_clip = boost::get(op_desc.GetAttr("clip")); + is_flip = boost::get(op_desc.GetAttr("flip")); + } std::vector dens; for (auto& ele : densities) { dens.push_back(static_cast(ele)); } - // lack flip - // auto clip = boost::get(op_desc.GetAttr("clip")); auto variances = boost::get>(op_desc.GetAttr("variances")); - for (auto& ele : variances) { - LOG(INFO) << ele; - } // lack img_h, img_w auto step_h = boost::get(op_desc.GetAttr("step_h")); @@ -66,14 +82,14 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op, std::vector temp_v = {}; engine_->AddOp(op_name, "PriorBox", {input_name, image_name}, {output_name}); - engine_->AddOpAttr>(op_name, "min_size", temp_v); - engine_->AddOpAttr>(op_name, "max_size", temp_v); - engine_->AddOpAttr>(op_name, "aspect_ratio", temp_v); + engine_->AddOpAttr>(op_name, "min_size", min_sizes); + engine_->AddOpAttr>(op_name, "max_size", max_sizes); + engine_->AddOpAttr>(op_name, "aspect_ratio", aspect_ratios); engine_->AddOpAttr>(op_name, "fixed_size", fixed_sizes); engine_->AddOpAttr>(op_name, "fixed_ratio", fixed_ratios); engine_->AddOpAttr>(op_name, "density", dens); - engine_->AddOpAttr(op_name, "is_flip", static_cast(false)); - engine_->AddOpAttr(op_name, "is_clip", static_cast(false)); + engine_->AddOpAttr(op_name, "is_flip", is_flip); + engine_->AddOpAttr(op_name, "is_clip", is_clip); engine_->AddOpAttr>(op_name, "variance", variances); engine_->AddOpAttr(op_name, "img_h", static_cast(0)); engine_->AddOpAttr(op_name, "img_w", static_cast(0)); @@ -88,3 +104,4 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op, } // namespace paddle REGISTER_ANAKIN_OP_CONVERTER(density_prior_box, DensityPriorBoxOpConverter); +REGISTER_ANAKIN_OP_CONVERTER(prior_box, DensityPriorBoxOpConverter); diff --git a/paddle/fluid/inference/anakin/convert/op_converter.h b/paddle/fluid/inference/anakin/convert/op_converter.h index 4603681e1e8a3c2841a62cc88b49a84950910e73..45db4221747128cd7f6d26c8830fa75ebf81ac72 100644 --- a/paddle/fluid/inference/anakin/convert/op_converter.h +++ b/paddle/fluid/inference/anakin/convert/op_converter.h @@ -48,7 +48,7 @@ class AnakinOpConverter { framework::OpDesc op_desc(op, nullptr); std::string op_type = op_desc.Type(); AnakinOpConverter *it = nullptr; - + if (op_type == "depthwise_conv2d") op_type = "conv2d"; if (op_type == "reshape2") op_type = "reshape"; if (op_type == "transpose2") op_type = "transpose"; if (op_type == "flatten2") op_type = "flatten"; diff --git a/paddle/fluid/inference/anakin/op_teller.cc b/paddle/fluid/inference/anakin/op_teller.cc index 90cf021de2f9d365fd1fa21f7d189d3fcd9d3ab2..2042fb18ea41f8b41fc35543c7e1b642c4f2fa7c 100644 --- a/paddle/fluid/inference/anakin/op_teller.cc +++ b/paddle/fluid/inference/anakin/op_teller.cc @@ -42,6 +42,8 @@ struct SimpleOpTypeSetTeller : public Teller { teller_set.insert("dropout"); teller_set.insert("sigmoid"); teller_set.insert("sum"); + teller_set.insert("depthwise_conv2d"); + teller_set.insert("prior_box"); } bool operator()(const std::string& op_type, diff --git a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc index 9e05aa5c16186d67200c4630619cc53fa241aa1b..38612d5cc3d093885144f3b1cd6107232885b645 100644 --- a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc @@ -37,14 +37,14 @@ using framework::ir::Node; void analysis::AnakinSubgraphPass::ApplyImpl( framework::ir::Graph *graph) const { - framework::ir::FusePassBase::Init("anakin_subgraph_pass", graph.get()); + framework::ir::FusePassBase::Init("anakin_subgraph_pass", graph); auto teller = [](const framework::ir::Node *node) { if (!node->IsOp() || !node->Op()) return false; return anakin::OpTeller::Global().Tell(node->Op()->Type(), *node->Op()); }; - SubGraphFuser fuser(graph.get(), teller, 6 /* min_subgraph_size */); + SubGraphFuser fuser(graph, teller, 6 /* min_subgraph_size */); fuser(); std::vector graph_param_names = @@ -56,10 +56,10 @@ void analysis::AnakinSubgraphPass::ApplyImpl( for (auto *node : graph->Nodes()) { if (node->IsOp() && !Agent(node).subgraph()->empty()) { - CreateAnakinOp(node, graph.get(), graph_param_names, &repetitive_params); + CreateAnakinOp(node, graph, graph_param_names, &repetitive_params); std::unordered_set nodes2remove( Agent(node).subgraph()->begin(), Agent(node).subgraph()->end()); - framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove); + framework::ir::GraphSafeRemoveNodes(graph, nodes2remove); } } @@ -69,7 +69,7 @@ void analysis::AnakinSubgraphPass::ApplyImpl( nodes2remove.insert(node); } } - framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove); + framework::ir::GraphSafeRemoveNodes(graph, nodes2remove); graph->Set(framework::ir::kRepetitiveParamAttr, new std::vector(repetitive_params)); } diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index ef5872c52c6a1b3f3ade40ea43e78e2120fa6643..019098a5dd0d372a690955698a2ab6a4039a2416 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -192,6 +192,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( block_desc.Proto()->SerializeAsString()); SetAttr(op_desc->Proto(), "max_batch_size", Get("max_batch_size")); SetAttr(op_desc->Proto(), "workspace_size", Get("workspace_size")); + SetAttr(op_desc->Proto(), "gpu_id", Get("gpu_device_id")); SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping); SetAttr(op_desc->Proto(), "parameters", params); diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc index d13ec7608c3e8075c1ef62fd4d47fbeee06e9005..1f27e80cf49f49863cf000d71369512242afb7b4 100644 --- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc @@ -52,6 +52,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) { for (auto &var_name : all_vars) { if (std::count(repetitive_params.begin(), repetitive_params.end(), var_name)) { + scope->EraseVars({var_name}); continue; } auto *var = scope->FindLocalVar(var_name); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index f7260561547bb0bd7aea1590239e38090953f6fc..7d8e9fe8bfada743388afd3ae4eedb5d84961706 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -886,4 +886,5 @@ USE_ANAKIN_CONVERTER(detection_out); USE_ANAKIN_CONVERTER(density_prior_box); USE_ANAKIN_CONVERTER(dropout); USE_ANAKIN_CONVERTER(sum); +USE_ANAKIN_CONVERTER(prior_box); #endif diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 8ec32b3a0b7fe459518e269fc72b182bc168435f..1d1d39e44096b9f50e5bc9603fa12aba92b0e8e2 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -70,17 +70,15 @@ void GpuPassStrategy::EnableMKLDNN() { // The following passes works for Anakin sub-graph engine. const std::vector kAnakinSubgraphPasses({ - "infer_clean_graph_pass", // - "simplify_anakin_detection_pattern_pass5", // - "simplify_anakin_detection_pattern_pass4", // - "simplify_anakin_detection_pattern_pass3", // - "simplify_anakin_detection_pattern_pass2", // - "anakin_fillconstant_elementwisemul_fuse", // - "fc_fuse_pass", // - "conv_elementwise_add_fuse_pass", // - "conv_bn_fuse_pass", // - "conv_elementwise_add_fuse_pass", // - "fc_gru_fuse_pass", // + "infer_clean_graph_pass", // + "simplify_anakin_priorbox_detection_out_pass", // + "fillconstant_elementwisemul_fuse", // + "fc_fuse_pass", // + "conv_elementwise_add_fuse_pass", // + "conv_bn_fuse_pass", // + "conv_elementwise_add_fuse_pass", // + "fc_gru_fuse_pass", // + "quant_conv2d_dequant_fuse_pass", // "anakin_subgraph_pass", }); @@ -97,13 +95,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_elementwise_add2_act_fuse_pass", // "conv_elementwise_add_fuse_pass", // "runtime_context_cache_pass", // -#endif +#endif // + "transpose_flatten_concat_fuse_pass", }); - for (int i = 6; i >= 2; i--) { - passes_.push_back("transpose_flatten" + std::to_string(i) + - "_concat_fuse_pass"); - } use_gpu_ = true; } diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index c36673312489738ad0475a0b70a23a1c6c948b9d..7f470924b337d59943c04ab0ff2820555f961732 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -52,6 +52,7 @@ class TensorRTEngineOp : public framework::OperatorBase { std::string engine_key_; std::string engine_serialized_data_; bool calibration_mode_; + int device_id_; public: TensorRTEngineOp(const std::string &type, @@ -62,6 +63,7 @@ class TensorRTEngineOp : public framework::OperatorBase { input_names_ = Inputs("Xs"); max_batch_size_ = Attr("max_batch_size"); workspace_size_ = Attr("workspace_size"); + device_id_ = Attr("gpu_id"); enable_int8_ = Attr("enable_int8"); calibration_data_ = Attr("calibration_data"); engine_key_ = Attr("engine_key"); @@ -79,6 +81,17 @@ class TensorRTEngineOp : public framework::OperatorBase { if (enable_int8_ && calibration_data_.size()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data_)); } + + if (!calibration_mode_ && !engine_serialized_data_.empty()) { + trt_engine_.reset(new inference::tensorrt::TensorRTEngine( + max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(), + device_id_)); + PADDLE_ENFORCE(engine_serialized_data_.size(), + "TRT serialized data should not be empty here," + "there must be error when generate serialized data in TRT " + "subgraph detect pass."); + trt_engine_->Deserialize(engine_serialized_data_); + } } protected: @@ -225,12 +238,8 @@ class TensorRTEngineOp : public framework::OperatorBase { if (!trt_engine_) { trt_engine_.reset(new inference::tensorrt::TensorRTEngine( max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(), - boost::get(dev_place).device)); - if (!engine_serialized_data_.empty()) { - trt_engine_->Deserialize(engine_serialized_data_); - } else { - PrepareTRTEngine(scope, trt_engine_.get()); - } + device_id_)); + PrepareTRTEngine(scope, trt_engine_.get()); } return trt_engine_.get(); } diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc index e7ad2f4fe0c654d8928f5793c1ad8052ab766fb5..cc4d8d6e6f7e24dcb04ed0f58e63cb13ce176bdb 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc @@ -108,6 +108,8 @@ TEST(TensorRTEngineOp, manual) { std::vector({"z0"})); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("engine_serialized_data", std::string("")); + int device_id = 0; + engine_op_desc.SetAttr("gpu_id", device_id); LOG(INFO) << "create engine op"; auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc); @@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { std::vector({"z3"})); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("engine_serialized_data", std::string("")); + int device_id = 0; + engine_op_desc.SetAttr("gpu_id", device_id); auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);