提交 d065b5bf 编写于 作者: N nhzlx

Anakin ssd support

refine trt first run
add quant dequant fuse pass
omit simplify_anakin_priorbox_detection template
omit transpose_flatten_concat_fuse template
test=develop
上级 ed61d67c
...@@ -68,21 +68,12 @@ pass_library(transpose_flatten_concat_fuse_pass inference) ...@@ -68,21 +68,12 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base) pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base) pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base) pass_library(runtime_context_cache_pass base)
pass_library(simplify_anakin_detection_pattern_pass inference) pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(anakin_fillconstant_elementwisemul_fuse inference) pass_library(fillconstant_elementwisemul_fuse inference)
# There may be many transpose-flatten structures in a model, and the output of if(ANAKIN_FOUND)
# these structures will be used as inputs to the concat Op. This pattern will pass_library(simplify_anakin_priorbox_detection_out_pass inference)
# be detected by our pass. The index here represents the number of structures in the endif()
# pattern. We use index 3 ~ 6, because these quantities of structures are
# common in the models.
foreach (index RANGE 2 6)
file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n")
endforeach()
foreach (index RANGE 2 6)
file(APPEND ${pass_file} "USE_PASS(simplify_anakin_detection_pattern_pass${index});\n")
endforeach()
if(WITH_MKLDNN) if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base mkldnn) pass_library(mkldnn_placement_pass base mkldnn)
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h" #include "paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle { namespace paddle {
...@@ -29,8 +29,8 @@ namespace ir { ...@@ -29,8 +29,8 @@ namespace ir {
GET_IR_NODE(elementwise_mul); \ GET_IR_NODE(elementwise_mul); \
GET_IR_NODE(elementwise_mul_out); GET_IR_NODE(elementwise_mul_out);
void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const { void FillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse"; const std::string pattern_name = "fillconstant_elementwisemul_fuse";
FusePassBase::Init(pattern_name, graph); FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -39,7 +39,7 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const { ...@@ -39,7 +39,7 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
->assert_is_op_input("elementwise_mul", "X") ->assert_is_op_input("elementwise_mul", "X")
->AsInput(); ->AsInput();
patterns::AnakinFillConstantElementWiseMulFuse pattern(gpd.mutable_pattern(), patterns::FillConstantElementWiseMulFuse pattern(gpd.mutable_pattern(),
pattern_name); pattern_name);
pattern(x); pattern(x);
...@@ -79,5 +79,5 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const { ...@@ -79,5 +79,5 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(anakin_fillconstant_elementwisemul_fuse, REGISTER_PASS(fillconstant_elementwisemul_fuse,
paddle::framework::ir::AnakinFillconstantElementwisemulFuse); paddle::framework::ir::FillconstantElementwisemulFuse);
...@@ -21,9 +21,9 @@ namespace paddle { ...@@ -21,9 +21,9 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class AnakinFillconstantElementwisemulFuse : public FusePassBase { class FillconstantElementwisemulFuse : public FusePassBase {
public: public:
virtual ~AnakinFillconstantElementwisemulFuse() {} virtual ~FillconstantElementwisemulFuse() {}
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
......
...@@ -1471,7 +1471,8 @@ PDNode *patterns::TransposeFlattenConcat::operator()( ...@@ -1471,7 +1471,8 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
} }
PDNode *patterns::AnakinDetectionPattern::operator()( PDNode *patterns::AnakinDetectionPattern::operator()(
std::vector<PDNode *> conv_in, int times) { std::vector<PDNode *> conv_in, int times, std::string priorbox_type,
bool is_reshape) {
// The times represents the repeat times of the // The times represents the repeat times of the
// {prior_box, prior_box_loc_out, flatten, prior_box_var_out, reshape} // {prior_box, prior_box_loc_out, flatten, prior_box_var_out, reshape}
const int kNumFields = 7; const int kNumFields = 7;
...@@ -1486,37 +1487,38 @@ PDNode *patterns::AnakinDetectionPattern::operator()( ...@@ -1486,37 +1487,38 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
const int kMultiClassSecondInputNmsOffset = times + 1; const int kMultiClassSecondInputNmsOffset = times + 1;
std::vector<PDNode *> nodes; std::vector<PDNode *> nodes;
std::string op_after_priorbox = is_reshape ? "reshape2" : "flatten2";
for (int i = 0; i < times; i++) { for (int i = 0; i < times; i++) {
nodes.push_back( nodes.push_back(
pattern->NewNode(GetNodeName("prior_box" + std::to_string(i))) pattern->NewNode(GetNodeName("prior_box" + std::to_string(i)))
->assert_is_op("density_prior_box")); ->assert_is_op(priorbox_type));
nodes.push_back(pattern->NewNode(GetNodeName("box_out" + std::to_string(i))) nodes.push_back(pattern->NewNode(GetNodeName("box_out" + std::to_string(i)))
->assert_is_op_output("density_prior_box", "Boxes") ->assert_is_op_output(priorbox_type, "Boxes")
->assert_is_op_input("reshape2", "X") ->assert_is_op_input(op_after_priorbox, "X")
->AsIntermediate()); ->AsIntermediate());
nodes.push_back( nodes.push_back(
pattern->NewNode(GetNodeName("reshape1" + std::to_string(i))) pattern->NewNode(GetNodeName("reshape1" + std::to_string(i)))
->assert_is_op("reshape2")); ->assert_is_op(op_after_priorbox));
nodes.push_back( nodes.push_back(
pattern->NewNode(GetNodeName("reshape1_out" + std::to_string(i))) pattern->NewNode(GetNodeName("reshape1_out" + std::to_string(i)))
->assert_is_op_output("reshape2") ->assert_is_op_output(op_after_priorbox)
->assert_is_op_nth_input("concat", "X", i) ->assert_is_op_nth_input("concat", "X", i)
->AsIntermediate()); ->AsIntermediate());
nodes.push_back( nodes.push_back(
pattern->NewNode(GetNodeName("box_var_out" + std::to_string(i))) pattern->NewNode(GetNodeName("box_var_out" + std::to_string(i)))
->assert_is_op_output("density_prior_box", "Variances") ->assert_is_op_output(priorbox_type, "Variances")
->assert_is_op_input("reshape2", "X") ->assert_is_op_input(op_after_priorbox, "X")
->AsIntermediate()); ->AsIntermediate());
nodes.push_back( nodes.push_back(
pattern->NewNode(GetNodeName("reshape2" + std::to_string(i))) pattern->NewNode(GetNodeName("reshape2" + std::to_string(i)))
->assert_is_op("reshape2")); ->assert_is_op(op_after_priorbox));
nodes.push_back( nodes.push_back(
pattern->NewNode(GetNodeName("reshape2_out" + std::to_string(i))) pattern->NewNode(GetNodeName("reshape2_out" + std::to_string(i)))
->assert_is_op_output("reshape2") ->assert_is_op_output(op_after_priorbox)
->assert_is_op_nth_input("concat", "X", i) ->assert_is_op_nth_input("concat", "X", i)
->AsIntermediate()); ->AsIntermediate());
} }
...@@ -1612,7 +1614,7 @@ PDNode *patterns::AnakinDetectionPattern::operator()( ...@@ -1612,7 +1614,7 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
return multiclass_nms_out; return multiclass_nms_out;
} }
PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()( PDNode *patterns::FillConstantElementWiseMulFuse::operator()(
PDNode *elementwise_op_input) { PDNode *elementwise_op_input) {
auto fill_constant = auto fill_constant =
pattern->NewNode(fill_constant_repr())->assert_is_op("fill_constant"); pattern->NewNode(fill_constant_repr())->assert_is_op("fill_constant");
...@@ -1635,6 +1637,76 @@ PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()( ...@@ -1635,6 +1637,76 @@ PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()(
return elementwise_mul_out; return elementwise_mul_out;
} }
void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times) {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
// the quant op always be one.
auto quant_op_in_scale =
pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input("fake_quantize_range_abs_max", "InScale")
->AsInput();
auto quant_op = pattern->NewNode(GetNodeName("quant_op"))
->assert_is_op("fake_quantize_range_abs_max");
auto quant_op_out_scale =
pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output("fake_quantize_range_abs_max", "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
auto quant_op_out =
pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output("fake_quantize_range_abs_max", "Out")
->assert_is_op_input(op_type)
->AsIntermediate();
// there are 'times' quantized and dequant op
std::vector<PDNode *> nodes;
for (int i = 0; i < times; i++) {
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_weight") + std::to_string(i))
->assert_is_op_input(op_type, weight_name)
->AsInput());
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op") + std::to_string(i))
->assert_is_op(op_type));
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_out") + std::to_string(i))
->assert_is_op_output(op_type)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->AsIntermediate());
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op") + std::to_string(i))
->assert_is_op("fake_dequantize_max_abs"));
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op_out") + std::to_string(i))
->assert_is_op_output("fake_dequantize_max_abs", "Out")
->AsOutput());
}
quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
quant_op_out->LinksFrom({quant_op});
for (int i = 0; i < times; i++) {
nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
}
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -848,7 +848,8 @@ struct AnakinDetectionPattern : public PatternBase { ...@@ -848,7 +848,8 @@ struct AnakinDetectionPattern : public PatternBase {
AnakinDetectionPattern(PDPattern* pattern, const std::string& name_scope) AnakinDetectionPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "anakin_detect_pattern") {} : PatternBase(pattern, name_scope, "anakin_detect_pattern") {}
PDNode* operator()(std::vector<PDNode*> conv_inputs, int times); PDNode* operator()(std::vector<PDNode*> conv_inputs, int times,
std::string priorbox_type, bool is_reshape);
std::string GetNodeName(const std::string& op_type) { std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type); return PDNodeName(name_scope_, repr_, id_, op_type);
...@@ -859,8 +860,8 @@ struct AnakinDetectionPattern : public PatternBase { ...@@ -859,8 +860,8 @@ struct AnakinDetectionPattern : public PatternBase {
} }
}; };
struct AnakinFillConstantElementWiseMulFuse : public PatternBase { struct FillConstantElementWiseMulFuse : public PatternBase {
AnakinFillConstantElementWiseMulFuse(PDPattern* pattern, FillConstantElementWiseMulFuse(PDPattern* pattern,
const std::string& name_scope) const std::string& name_scope)
: PatternBase(pattern, name_scope, : PatternBase(pattern, name_scope,
"anakin_fillconstant_elementwisemul_fuse") {} "anakin_fillconstant_elementwisemul_fuse") {}
...@@ -874,6 +875,22 @@ struct AnakinFillConstantElementWiseMulFuse : public PatternBase { ...@@ -874,6 +875,22 @@ struct AnakinFillConstantElementWiseMulFuse : public PatternBase {
PATTERN_DECL_NODE(elementwise_mul_out); PATTERN_DECL_NODE(elementwise_mul_out);
}; };
struct QuantDequantOpFuse : public PatternBase {
QuantDequantOpFuse(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quant_dequant_fuse") {}
void operator()(PDNode* quant_op_input, const std::string& op_name,
const std::string& weight_name, int times = 1);
std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
}
PDNode* GetPDNode(const std::string& op_type) {
return pattern->RetrieveNode(GetNodeName(op_type));
}
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other. // Link two ir::Nodes from each other.
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
std::string op_type) {
const std::string pattern_name = "quant_dequant_fuse";
// FusePassBase::Init(pattern_name, graph);
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("fake_quantize_range_abs_max", "X")
->AsInput();
std::string quantized_op_type = "";
std::string weight_name = "";
if (op_type == "conv2d") {
quantized_op_type = "conv2d";
weight_name = "Filter";
} else if (op_type == "conv2d_fusion") {
quantized_op_type = "conv2d_fusion";
weight_name = "Filter";
} else if (op_type == "mul") {
quantized_op_type = "mul";
weight_name = "Y";
} else if (op_type == "fc") {
quantized_op_type = "fc";
weight_name = "W";
} else {
PADDLE_ENFORCE(
"QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
"now.");
}
patterns::QuantDequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, quantized_op_type, weight_name, times);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
PADDLE_ENFORCE(subgraph.count(x));
auto* input_node = subgraph.at(x);
Node* quant_op_in_scale =
subgraph.at(pattern.GetPDNode("quant_op_in_scale"));
Node* quant_op = subgraph.at(pattern.GetPDNode("quant_op"));
Node* quant_op_out_scale =
subgraph.at(pattern.GetPDNode("quant_op_out_scale"));
Node* quant_op_out = subgraph.at(pattern.GetPDNode("quant_op_out"));
std::vector<Node*> nodes;
for (int i = 0; i < times; i++) {
nodes.push_back(subgraph.at(
pattern.GetPDNode("quantized_op_weight" + std::to_string(i))));
nodes.push_back(
subgraph.at(pattern.GetPDNode("quantized_op" + std::to_string(i))));
nodes.push_back(subgraph.at(
pattern.GetPDNode("quantized_op_out" + std::to_string(i))));
nodes.push_back(
subgraph.at(pattern.GetPDNode("dequant_op" + std::to_string(i))));
nodes.push_back(
subgraph.at(pattern.GetPDNode("dequant_op_out" + std::to_string(i))));
}
int bit_length = boost::get<int>(quant_op->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
// Prepare input scale
std::string input_scale_var_name = quant_op->Op()->Input("InScale").front();
PADDLE_ENFORCE(scope);
const LoDTensor& input_scale_tensor =
scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
PADDLE_ENFORCE(paddle::platform::is_cpu_place(input_scale_tensor.place()));
const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0];
std::unordered_set<const Node*> delete_nodes;
for (int i = 0; i < times; i++) {
// max_range = (range * range) / weight_scale
float max_range = boost::get<float>(
nodes[i * kNumFields + kDequantOpOffset]->Op()->GetAttr("max_range"));
float weight_scale = (range * range) / max_range;
auto base_op_desc =
*nodes[i * kNumFields + kQuantizedOpOffset]->Op()->Proto();
std::string new_input = input_node->Name();
std::string new_output =
nodes[i * kNumFields + kDequantOpOutOffset]->Name();
framework::OpDesc new_op_desc(base_op_desc, nullptr);
new_op_desc.SetType(quantized_op_type);
if (quantized_op_type == "conv2d" ||
quantized_op_type == "conv2d_fusion") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetOutput("Output", {new_output});
} else if (quantized_op_type == "fc") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetOutput("Out", {new_output});
} else if (quantized_op_type == "mul") {
new_op_desc.SetInput("X", {new_input});
new_op_desc.SetOutput("Out", {new_output});
}
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr("input_scale", input_scale);
new_op_desc.SetAttr("weight_scale", weight_scale);
new_op_desc.Flush();
auto* new_op = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(input_node, new_op);
IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset], new_op);
IR_NODE_LINK_TO(new_op, nodes[i * kNumFields + kDequantOpOutOffset]);
delete_nodes.insert(nodes[i * kNumFields + kQuantizedOpOffset]);
delete_nodes.insert(nodes[i * kNumFields + kQuantizedOpOutOffset]);
delete_nodes.insert(nodes[i * kNumFields + kDequantOpOffset]);
}
delete_nodes.insert(quant_op_in_scale);
delete_nodes.insert(quant_op);
delete_nodes.insert(quant_op_out);
delete_nodes.insert(quant_op_out_scale);
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, delete_nodes);
};
gpd(graph, handler);
}
void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "quant_dequant_fuse";
FusePassBase::Init(pattern_name, graph);
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul"};
auto* scope = param_scope();
for (auto& op_type : quantized_op_types) {
for (int i = 1; i <= 6; i++) {
RunQuantDequant(graph, scope, i, op_type);
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(quant_conv2d_dequant_fuse_pass,
paddle::framework::ir::QuantDequantFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class QuantDequantFusePass : public FusePassBase {
public:
virtual ~QuantDequantFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -17,25 +17,24 @@ ...@@ -17,25 +17,24 @@
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.h" #include "paddle/fluid/framework/ir/simplify_anakin_priorbox_detection_out_pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
template <int times> void RunSimplifyAnakinDetection(ir::Graph *graph, int times, bool is_density,
void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl( bool is_reshape) {
ir::Graph *graph) const {
const std::string pattern_name = const std::string pattern_name =
"simplify_anakin_detection_pattern_pass" + std::to_string(times); "simplify_anakin_detection_pattern_pass" + std::to_string(times);
FusePassBase::Init(pattern_name, graph); std::string priorbox_type = is_density ? "density_prior_box" : "prior_box";
GraphPatternDetector gpd; GraphPatternDetector gpd;
std::vector<PDNode *> input_nodes; std::vector<PDNode *> input_nodes;
for (int i = 0; i < times; i++) { for (int i = 0; i < times; i++) {
input_nodes.push_back(gpd.mutable_pattern() input_nodes.push_back(gpd.mutable_pattern()
->NewNode("x" + std::to_string(i)) ->NewNode("x" + std::to_string(i))
->assert_is_op_input("density_prior_box", "Input") ->assert_is_op_input(priorbox_type, "Input")
->AsInput()); ->AsInput());
} }
input_nodes.push_back(gpd.mutable_pattern() input_nodes.push_back(gpd.mutable_pattern()
...@@ -49,7 +48,7 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl( ...@@ -49,7 +48,7 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
->AsInput()); ->AsInput());
patterns::AnakinDetectionPattern pattern(gpd.mutable_pattern(), pattern_name); patterns::AnakinDetectionPattern pattern(gpd.mutable_pattern(), pattern_name);
pattern(input_nodes, times); pattern(input_nodes, times, priorbox_type, is_reshape);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) { Graph *g) {
...@@ -119,8 +118,7 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl( ...@@ -119,8 +118,7 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
boost::get<std::string>(box_coder_op->Op()->GetAttr("code_type")); boost::get<std::string>(box_coder_op->Op()->GetAttr("code_type"));
bool box_normalized = bool box_normalized =
boost::get<bool>(box_coder_op->Op()->GetAttr("box_normalized")); boost::get<bool>(box_coder_op->Op()->GetAttr("box_normalized"));
// auto variance =
// boost::get<std::vector<float>>(box_coder_op->Op()->GetAttr("variance"));
int background_label = int background_label =
boost::get<int>(multiclass_nms->Op()->GetAttr("background_label")); boost::get<int>(multiclass_nms->Op()->GetAttr("background_label"));
float score_threshold = float score_threshold =
...@@ -138,7 +136,6 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl( ...@@ -138,7 +136,6 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
nodes[i * kNumFields + kPriorBoxLocOffset]->Name()); nodes[i * kNumFields + kPriorBoxLocOffset]->Name());
} }
// int axis = boost::get<int>(concat_op1->Op()->GetAttr("axis"));
framework::OpDesc concat1_desc; framework::OpDesc concat1_desc;
concat1_desc.SetType("concat"); concat1_desc.SetType("concat");
concat1_desc.SetInput("X", concat1_input_names); concat1_desc.SetInput("X", concat1_input_names);
...@@ -213,31 +210,24 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl( ...@@ -213,31 +210,24 @@ void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
gpd(graph, handler); gpd(graph, handler);
} }
template class SimplifyAnakinDetectionPatternPass<1>; void SimplifyAnakinDetectionPatternPass::ApplyImpl(ir::Graph *graph) const {
template class SimplifyAnakinDetectionPatternPass<2>; const int pattern_nums = 6;
template class SimplifyAnakinDetectionPatternPass<3>; const std::string pattern_name = "simplify_anakin_detection_pattern_pass";
template class SimplifyAnakinDetectionPatternPass<4>; FusePassBase::Init(pattern_name, graph);
template class SimplifyAnakinDetectionPatternPass<5>; std::vector<bool> options = {true, false};
template class SimplifyAnakinDetectionPatternPass<6>; for (const auto &is_density : options) {
for (const auto &is_reshape : options) {
for (int i = 1; i <= pattern_nums; i++) {
RunSimplifyAnakinDetection(graph, i, is_density, is_reshape);
}
}
}
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(simplify_anakin_detection_pattern_pass, typedef paddle::framework::ir::SimplifyAnakinDetectionPatternPass
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<1>); priorbox_pattern;
REGISTER_PASS(simplify_anakin_priorbox_detection_out_pass, priorbox_pattern);
REGISTER_PASS(simplify_anakin_detection_pattern_pass2,
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<2>);
REGISTER_PASS(simplify_anakin_detection_pattern_pass3,
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<3>);
REGISTER_PASS(simplify_anakin_detection_pattern_pass4,
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<4>);
REGISTER_PASS(simplify_anakin_detection_pattern_pass5,
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<5>);
REGISTER_PASS(simplify_anakin_detection_pattern_pass6,
paddle::framework::ir::SimplifyAnakinDetectionPatternPass<6>);
...@@ -26,7 +26,6 @@ namespace ir { ...@@ -26,7 +26,6 @@ namespace ir {
// these structures will be used as inputs to the concat Op. This pattern will // these structures will be used as inputs to the concat Op. This pattern will
// be detected by our pass. The times here represents the repeat times of this // be detected by our pass. The times here represents the repeat times of this
// structure. // structure.
template <int times>
class SimplifyAnakinDetectionPatternPass : public FusePassBase { class SimplifyAnakinDetectionPatternPass : public FusePassBase {
public: public:
virtual ~SimplifyAnakinDetectionPatternPass() {} virtual ~SimplifyAnakinDetectionPatternPass() {}
......
...@@ -25,11 +25,9 @@ namespace paddle { ...@@ -25,11 +25,9 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
template <int times> void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) {
void TransposeFlattenConcatFusePass<times>::ApplyImpl(ir::Graph *graph) const {
const std::string pattern_name = const std::string pattern_name =
"transpose_flatten" + std::to_string(times) + "_concat_fuse"; "transpose_flatten" + std::to_string(times) + "_concat_fuse";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
std::vector<PDNode *> input_nodes; std::vector<PDNode *> input_nodes;
...@@ -122,31 +120,18 @@ void TransposeFlattenConcatFusePass<times>::ApplyImpl(ir::Graph *graph) const { ...@@ -122,31 +120,18 @@ void TransposeFlattenConcatFusePass<times>::ApplyImpl(ir::Graph *graph) const {
gpd(graph, handler); gpd(graph, handler);
} }
template class TransposeFlattenConcatFusePass<1>; void TransposeFlattenConcatFusePass::ApplyImpl(ir::Graph *graph) const {
template class TransposeFlattenConcatFusePass<2>; const int pattern_nums = 6;
template class TransposeFlattenConcatFusePass<3>; const std::string pattern_name = "transpose_flatten_concat_fuse";
template class TransposeFlattenConcatFusePass<4>; FusePassBase::Init(pattern_name, graph);
template class TransposeFlattenConcatFusePass<5>; for (int i = 1; i <= pattern_nums; i++) {
template class TransposeFlattenConcatFusePass<6>; RunTransposeFlattenConcatFuse(graph, i);
}
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(transpose_flatten_concat_fuse_pass, REGISTER_PASS(transpose_flatten_concat_fuse_pass,
paddle::framework::ir::TransposeFlattenConcatFusePass<1>); paddle::framework::ir::TransposeFlattenConcatFusePass);
REGISTER_PASS(transpose_flatten2_concat_fuse_pass,
paddle::framework::ir::TransposeFlattenConcatFusePass<2>);
REGISTER_PASS(transpose_flatten3_concat_fuse_pass,
paddle::framework::ir::TransposeFlattenConcatFusePass<3>);
REGISTER_PASS(transpose_flatten4_concat_fuse_pass,
paddle::framework::ir::TransposeFlattenConcatFusePass<4>);
REGISTER_PASS(transpose_flatten5_concat_fuse_pass,
paddle::framework::ir::TransposeFlattenConcatFusePass<5>);
REGISTER_PASS(transpose_flatten6_concat_fuse_pass,
paddle::framework::ir::TransposeFlattenConcatFusePass<6>);
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...@@ -24,7 +26,6 @@ namespace ir { ...@@ -24,7 +26,6 @@ namespace ir {
// these structures will be used as inputs to the concat Op. This pattern will // these structures will be used as inputs to the concat Op. This pattern will
// be detected by our pass. The times here represents the repeat times of this // be detected by our pass. The times here represents the repeat times of this
// structure. // structure.
template <int times>
class TransposeFlattenConcatFusePass : public FusePassBase { class TransposeFlattenConcatFusePass : public FusePassBase {
public: public:
virtual ~TransposeFlattenConcatFusePass() {} virtual ~TransposeFlattenConcatFusePass() {}
......
...@@ -34,25 +34,41 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op, ...@@ -34,25 +34,41 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
auto input_name = op_desc.Input("Input").front(); auto input_name = op_desc.Input("Input").front();
auto image_name = op_desc.Input("Image").front(); auto image_name = op_desc.Input("Image").front();
auto output_name = op_desc.Output("Boxes").front(); auto output_name = op_desc.Output("Boxes").front();
auto op_type = op_desc.Type();
auto op_name = op_type + ":" + op_desc.Output("Boxes").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Boxes").front(); // only for density_prior_box
std::vector<float> fixed_sizes = {};
std::vector<float> fixed_ratios = {};
std::vector<int> densities = {};
auto fixed_sizes = std::vector<float> min_sizes = {};
std::vector<float> max_sizes = {};
std::vector<float> aspect_ratios = {};
bool is_clip = false;
bool is_flip = false;
if (op_type == "density_prior_box") {
fixed_sizes =
boost::get<std::vector<float>>(op_desc.GetAttr("fixed_sizes")); boost::get<std::vector<float>>(op_desc.GetAttr("fixed_sizes"));
auto fixed_ratios = fixed_ratios =
boost::get<std::vector<float>>(op_desc.GetAttr("fixed_ratios")); boost::get<std::vector<float>>(op_desc.GetAttr("fixed_ratios"));
auto densities = boost::get<std::vector<int>>(op_desc.GetAttr("densities")); densities = boost::get<std::vector<int>>(op_desc.GetAttr("densities"));
is_clip = boost::get<bool>(op_desc.GetAttr("clip"));
} else if (op_type == "prior_box") {
min_sizes = boost::get<std::vector<float>>(op_desc.GetAttr("min_sizes"));
max_sizes = boost::get<std::vector<float>>(op_desc.GetAttr("max_sizes"));
aspect_ratios =
boost::get<std::vector<float>>(op_desc.GetAttr("aspect_ratios"));
is_clip = boost::get<bool>(op_desc.GetAttr("clip"));
is_flip = boost::get<bool>(op_desc.GetAttr("flip"));
}
std::vector<float> dens; std::vector<float> dens;
for (auto& ele : densities) { for (auto& ele : densities) {
dens.push_back(static_cast<float>(ele)); dens.push_back(static_cast<float>(ele));
} }
// lack flip
// auto clip = boost::get<bool>(op_desc.GetAttr("clip"));
auto variances = boost::get<std::vector<float>>(op_desc.GetAttr("variances")); auto variances = boost::get<std::vector<float>>(op_desc.GetAttr("variances"));
for (auto& ele : variances) {
LOG(INFO) << ele;
}
// lack img_h, img_w // lack img_h, img_w
auto step_h = boost::get<float>(op_desc.GetAttr("step_h")); auto step_h = boost::get<float>(op_desc.GetAttr("step_h"));
...@@ -66,14 +82,14 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op, ...@@ -66,14 +82,14 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
std::vector<float> temp_v = {}; std::vector<float> temp_v = {};
engine_->AddOp(op_name, "PriorBox", {input_name, image_name}, {output_name}); engine_->AddOp(op_name, "PriorBox", {input_name, image_name}, {output_name});
engine_->AddOpAttr<PTuple<float>>(op_name, "min_size", temp_v); engine_->AddOpAttr<PTuple<float>>(op_name, "min_size", min_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "max_size", temp_v); engine_->AddOpAttr<PTuple<float>>(op_name, "max_size", max_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "aspect_ratio", temp_v); engine_->AddOpAttr<PTuple<float>>(op_name, "aspect_ratio", aspect_ratios);
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_size", fixed_sizes); engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_size", fixed_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_ratio", fixed_ratios); engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_ratio", fixed_ratios);
engine_->AddOpAttr<PTuple<float>>(op_name, "density", dens); engine_->AddOpAttr<PTuple<float>>(op_name, "density", dens);
engine_->AddOpAttr(op_name, "is_flip", static_cast<bool>(false)); engine_->AddOpAttr(op_name, "is_flip", is_flip);
engine_->AddOpAttr(op_name, "is_clip", static_cast<bool>(false)); engine_->AddOpAttr(op_name, "is_clip", is_clip);
engine_->AddOpAttr<PTuple<float>>(op_name, "variance", variances); engine_->AddOpAttr<PTuple<float>>(op_name, "variance", variances);
engine_->AddOpAttr(op_name, "img_h", static_cast<int>(0)); engine_->AddOpAttr(op_name, "img_h", static_cast<int>(0));
engine_->AddOpAttr(op_name, "img_w", static_cast<int>(0)); engine_->AddOpAttr(op_name, "img_w", static_cast<int>(0));
...@@ -88,3 +104,4 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op, ...@@ -88,3 +104,4 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
} // namespace paddle } // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(density_prior_box, DensityPriorBoxOpConverter); REGISTER_ANAKIN_OP_CONVERTER(density_prior_box, DensityPriorBoxOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(prior_box, DensityPriorBoxOpConverter);
...@@ -48,7 +48,7 @@ class AnakinOpConverter { ...@@ -48,7 +48,7 @@ class AnakinOpConverter {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
std::string op_type = op_desc.Type(); std::string op_type = op_desc.Type();
AnakinOpConverter *it = nullptr; AnakinOpConverter *it = nullptr;
if (op_type == "depthwise_conv2d") op_type = "conv2d";
if (op_type == "reshape2") op_type = "reshape"; if (op_type == "reshape2") op_type = "reshape";
if (op_type == "transpose2") op_type = "transpose"; if (op_type == "transpose2") op_type = "transpose";
if (op_type == "flatten2") op_type = "flatten"; if (op_type == "flatten2") op_type = "flatten";
......
...@@ -42,6 +42,8 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -42,6 +42,8 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("dropout"); teller_set.insert("dropout");
teller_set.insert("sigmoid"); teller_set.insert("sigmoid");
teller_set.insert("sum"); teller_set.insert("sum");
teller_set.insert("depthwise_conv2d");
teller_set.insert("prior_box");
} }
bool operator()(const std::string& op_type, bool operator()(const std::string& op_type,
......
...@@ -37,14 +37,14 @@ using framework::ir::Node; ...@@ -37,14 +37,14 @@ using framework::ir::Node;
void analysis::AnakinSubgraphPass::ApplyImpl( void analysis::AnakinSubgraphPass::ApplyImpl(
framework::ir::Graph *graph) const { framework::ir::Graph *graph) const {
framework::ir::FusePassBase::Init("anakin_subgraph_pass", graph.get()); framework::ir::FusePassBase::Init("anakin_subgraph_pass", graph);
auto teller = [](const framework::ir::Node *node) { auto teller = [](const framework::ir::Node *node) {
if (!node->IsOp() || !node->Op()) return false; if (!node->IsOp() || !node->Op()) return false;
return anakin::OpTeller::Global().Tell(node->Op()->Type(), *node->Op()); return anakin::OpTeller::Global().Tell(node->Op()->Type(), *node->Op());
}; };
SubGraphFuser fuser(graph.get(), teller, 6 /* min_subgraph_size */); SubGraphFuser fuser(graph, teller, 6 /* min_subgraph_size */);
fuser(); fuser();
std::vector<std::string> graph_param_names = std::vector<std::string> graph_param_names =
...@@ -56,10 +56,10 @@ void analysis::AnakinSubgraphPass::ApplyImpl( ...@@ -56,10 +56,10 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
for (auto *node : graph->Nodes()) { for (auto *node : graph->Nodes()) {
if (node->IsOp() && !Agent(node).subgraph()->empty()) { if (node->IsOp() && !Agent(node).subgraph()->empty()) {
CreateAnakinOp(node, graph.get(), graph_param_names, &repetitive_params); CreateAnakinOp(node, graph, graph_param_names, &repetitive_params);
std::unordered_set<const Node *> nodes2remove( std::unordered_set<const Node *> nodes2remove(
Agent(node).subgraph()->begin(), Agent(node).subgraph()->end()); Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove); framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
} }
} }
...@@ -69,7 +69,7 @@ void analysis::AnakinSubgraphPass::ApplyImpl( ...@@ -69,7 +69,7 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
nodes2remove.insert(node); nodes2remove.insert(node);
} }
} }
framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove); framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
graph->Set(framework::ir::kRepetitiveParamAttr, graph->Set(framework::ir::kRepetitiveParamAttr,
new std::vector<std::string>(repetitive_params)); new std::vector<std::string>(repetitive_params));
} }
......
...@@ -192,6 +192,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -192,6 +192,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
block_desc.Proto()->SerializeAsString()); block_desc.Proto()->SerializeAsString());
SetAttr(op_desc->Proto(), "max_batch_size", Get<int>("max_batch_size")); SetAttr(op_desc->Proto(), "max_batch_size", Get<int>("max_batch_size"));
SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size")); SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size"));
SetAttr(op_desc->Proto(), "gpu_id", Get<int>("gpu_device_id"));
SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping); SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
SetAttr(op_desc->Proto(), "parameters", params); SetAttr(op_desc->Proto(), "parameters", params);
......
...@@ -52,6 +52,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) { ...@@ -52,6 +52,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
for (auto &var_name : all_vars) { for (auto &var_name : all_vars) {
if (std::count(repetitive_params.begin(), repetitive_params.end(), if (std::count(repetitive_params.begin(), repetitive_params.end(),
var_name)) { var_name)) {
scope->EraseVars({var_name});
continue; continue;
} }
auto *var = scope->FindLocalVar(var_name); auto *var = scope->FindLocalVar(var_name);
......
...@@ -886,4 +886,5 @@ USE_ANAKIN_CONVERTER(detection_out); ...@@ -886,4 +886,5 @@ USE_ANAKIN_CONVERTER(detection_out);
USE_ANAKIN_CONVERTER(density_prior_box); USE_ANAKIN_CONVERTER(density_prior_box);
USE_ANAKIN_CONVERTER(dropout); USE_ANAKIN_CONVERTER(dropout);
USE_ANAKIN_CONVERTER(sum); USE_ANAKIN_CONVERTER(sum);
USE_ANAKIN_CONVERTER(prior_box);
#endif #endif
...@@ -71,16 +71,14 @@ void GpuPassStrategy::EnableMKLDNN() { ...@@ -71,16 +71,14 @@ void GpuPassStrategy::EnableMKLDNN() {
// The following passes works for Anakin sub-graph engine. // The following passes works for Anakin sub-graph engine.
const std::vector<std::string> kAnakinSubgraphPasses({ const std::vector<std::string> kAnakinSubgraphPasses({
"infer_clean_graph_pass", // "infer_clean_graph_pass", //
"simplify_anakin_detection_pattern_pass5", // "simplify_anakin_priorbox_detection_out_pass", //
"simplify_anakin_detection_pattern_pass4", // "fillconstant_elementwisemul_fuse", //
"simplify_anakin_detection_pattern_pass3", //
"simplify_anakin_detection_pattern_pass2", //
"anakin_fillconstant_elementwisemul_fuse", //
"fc_fuse_pass", // "fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
"fc_gru_fuse_pass", // "fc_gru_fuse_pass", //
"quant_conv2d_dequant_fuse_pass", //
"anakin_subgraph_pass", "anakin_subgraph_pass",
}); });
...@@ -97,13 +95,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -97,13 +95,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add2_act_fuse_pass", // "conv_elementwise_add2_act_fuse_pass", //
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
"runtime_context_cache_pass", // "runtime_context_cache_pass", //
#endif #endif //
"transpose_flatten_concat_fuse_pass",
}); });
for (int i = 6; i >= 2; i--) {
passes_.push_back("transpose_flatten" + std::to_string(i) +
"_concat_fuse_pass");
}
use_gpu_ = true; use_gpu_ = true;
} }
......
...@@ -52,6 +52,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -52,6 +52,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
std::string engine_key_; std::string engine_key_;
std::string engine_serialized_data_; std::string engine_serialized_data_;
bool calibration_mode_; bool calibration_mode_;
int device_id_;
public: public:
TensorRTEngineOp(const std::string &type, TensorRTEngineOp(const std::string &type,
...@@ -62,6 +63,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -62,6 +63,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
input_names_ = Inputs("Xs"); input_names_ = Inputs("Xs");
max_batch_size_ = Attr<int>("max_batch_size"); max_batch_size_ = Attr<int>("max_batch_size");
workspace_size_ = Attr<int>("workspace_size"); workspace_size_ = Attr<int>("workspace_size");
device_id_ = Attr<int>("gpu_id");
enable_int8_ = Attr<bool>("enable_int8"); enable_int8_ = Attr<bool>("enable_int8");
calibration_data_ = Attr<std::string>("calibration_data"); calibration_data_ = Attr<std::string>("calibration_data");
engine_key_ = Attr<std::string>("engine_key"); engine_key_ = Attr<std::string>("engine_key");
...@@ -79,6 +81,17 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -79,6 +81,17 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (enable_int8_ && calibration_data_.size()) { if (enable_int8_ && calibration_data_.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_)); calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
} }
if (!calibration_mode_) {
trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
device_id_));
PADDLE_ENFORCE(engine_serialized_data_.size(),
"TRT serialized data should not be empty here,"
"there must be error when generate serialized data in TRT "
"subgraph detect pass.");
trt_engine_->Deserialize(engine_serialized_data_);
}
} }
protected: protected:
...@@ -223,15 +236,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -223,15 +236,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
TensorRTEngine *GetEngine(const framework::Scope &scope, TensorRTEngine *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
if (!trt_engine_) { if (!trt_engine_) {
trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
boost::get<platform::CUDAPlace>(dev_place).device));
if (!engine_serialized_data_.empty()) {
trt_engine_->Deserialize(engine_serialized_data_);
} else {
PrepareTRTEngine(scope, trt_engine_.get()); PrepareTRTEngine(scope, trt_engine_.get());
} }
}
return trt_engine_.get(); return trt_engine_.get();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册