Unverified commit b2f5a149, authored by Pei Yang, committed by GitHub

[Paddle-TRT] Better Paddle-TensorRT support for PaddleSlim quant models (#25097)

* Paddle-TensorRT supports PaddleSlim QAT. test=develop

* add comments. test=develop

* use RenameInput instead of ResetInputs. test=develop
Parent a965ac4c
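The end-to-end behavior this commit targets is exercised by the trt_quant_int8_yolov3_r50 test added at the bottom of the diff. As a rough sketch of how an application enables that path (model paths and sizes below are placeholders, not part of this commit):

#include <memory>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Sketch only: run a PaddleSlim quant-aware-trained (QAT) model through
// Paddle-TRT int8 without a calibration table. Paths are placeholders.
std::unique_ptr<paddle::PaddlePredictor> BuildQatPredictor() {
  paddle::AnalysisConfig config;
  config.EnableUseGpu(100 /* initial GPU memory in MB */, 0 /* device id */);
  config.SetModel("/path/to/__model__", "/path/to/__params__");
  config.SwitchUseFeedFetchOps(false);
  // Precision::kInt8 with use_calib_mode = false takes the QAT path touched by
  // this commit: scales are read from the quantized model, so no calibration
  // pass is required.
  config.EnableTensorRtEngine(1 << 30 /* workspace */, 1 /* max batch */,
                              3 /* min subgraph size */,
                              paddle::AnalysisConfig::Precision::kInt8,
                              false /* use_static */,
                              false /* use_calib_mode */);
  return paddle::CreatePaddlePredictor(config);
}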
...@@ -1980,99 +1980,58 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
  return concat_out;
}

void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
                                             const std::string &quant_type) {
  auto *input_scale_node = pattern->NewNode(GetNodeName("input_scale_node"))
                               ->assert_is_op_input(quant_type, "InScale")
                               ->AsInput();
  auto *quant_node =
      pattern->NewNode(GetNodeName("quant_node"))->assert_is_op(quant_type);
  auto *output_scale_node = pattern->NewNode(GetNodeName("output_scale_node"))
                                ->assert_is_op_output(quant_type, "OutScale")
                                ->AsOutput();
  auto *output_act_node = pattern->NewNode(GetNodeName("output_act_node"))
                              ->assert_is_op_output(quant_type, "Out")
                              ->AsOutput();
  quant_node->LinksFrom({input_scale_node, input_act_node});
  output_scale_node->LinksFrom({quant_node});
  output_act_node->LinksFrom({quant_node});
}

void patterns::DequantOpFuse::operator()(PDNode *quantized_op_input,
                                         const std::string &quantized_op_type,
                                         const std::string &dequant_type,
                                         const std::string &weight_name) {
  auto *quantized_op_weight =
      pattern->NewNode(GetNodeName("quantized_op_weight"))
          ->assert_is_op_input(quantized_op_type, weight_name)
          ->AsInput();
  auto *quantized_op = pattern->NewNode(GetNodeName("quantized_op"))
                           ->assert_is_op(quantized_op_type);
  auto *quantized_op_out = pattern->NewNode(GetNodeName("quantized_op_out"))
                               ->assert_is_op_output(quantized_op_type)
                               ->assert_is_op_input(dequant_type, "X");
  auto *dequant_op =
      pattern->NewNode(GetNodeName("dequant_op"))->assert_is_op(dequant_type);
  auto *dequant_op_out = pattern->NewNode(GetNodeName("dequant_op_out"))
                             ->assert_is_op_output(dequant_type, "Out")
                             ->AsOutput();
  PDNode *dequant_channel_scale = nullptr;
  if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
    dequant_channel_scale =
        pattern->NewNode(GetNodeName("dequant_channel_scale"))
            ->assert_is_op_nth_input(dequant_type, "Scales", 0)
            ->AsInput();
  }
  quantized_op->LinksFrom({quantized_op_input, quantized_op_weight});
  quantized_op_out->LinksFrom({quantized_op});
  if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
    dequant_op->LinksFrom({quantized_op_out, dequant_channel_scale});
  } else {
    dequant_op->LinksFrom({quantized_op_out});
  }
  dequant_op_out->LinksFrom({dequant_op});
}

void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
......
...@@ -1150,14 +1150,28 @@ struct TransposeFlattenConcat : public PatternBase {
  }
};

struct DeleteQuantOpFuse : public PatternBase {
  DeleteQuantOpFuse(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "delete_quant_fuse") {}

  void operator()(PDNode* input_act_node, const std::string& quant_type);

  std::string GetNodeName(const std::string& op_type) {
    return PDNodeName(name_scope_, repr_, id_, op_type);
  }

  PDNode* GetPDNode(const std::string& op_type) {
    return pattern->RetrieveNode(GetNodeName(op_type));
  }
};

struct DequantOpFuse : public PatternBase {
  DequantOpFuse(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "dequant_fuse") {}

  void operator()(PDNode* quant_op_input, const std::string& quantized_op_type,
                  const std::string& dequant_type,
                  const std::string& weight_name);

  std::string GetNodeName(const std::string& op_type) {
    return PDNodeName(name_scope_, repr_, id_, op_type);
......
...@@ -24,159 +24,218 @@ namespace paddle {
namespace framework {
namespace ir {

// Delete quant op before quantized ops, and set input scale in the attr of
// quantized ops
void DeleteQuant(ir::Graph* graph, Scope* scope,
                 const std::string& quant_type) {
  const std::string pattern_name = "delete_quant_fuse";
  GraphPatternDetector gpd;
  auto* input_act_node = gpd.mutable_pattern()
                             ->NewNode("input_act_node")
                             ->assert_is_op_input(quant_type, "X")
                             ->AsInput();

  // Create pattern
  patterns::DeleteQuantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
  pattern(input_act_node, quant_type);

  // extract input scale from quant op input to set it in attr of all quantized
  // ops linked from it
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    PADDLE_ENFORCE_EQ(subgraph.count(input_act_node), true,
                      platform::errors::NotFound(
                          "Input act node not found in Delete Quant fusion."));
    Node* input_act = subgraph.at(input_act_node);
    Node* input_scale = subgraph.at(pattern.GetPDNode("input_scale_node"));
    Node* quant = subgraph.at(pattern.GetPDNode("quant_node"));
    Node* output_scale = subgraph.at(pattern.GetPDNode("output_scale_node"));
    Node* output_act = subgraph.at(pattern.GetPDNode("output_act_node"));
    int bit_length = BOOST_GET_CONST(int, quant->Op()->GetAttr("bit_length"));
    int range = ((1 << (bit_length - 1)) - 1);
    // Get input scale from tensor
    std::string input_scale_var_name = quant->Op()->Input("InScale").front();
    PADDLE_ENFORCE_NOT_NULL(
        scope, platform::errors::InvalidArgument(
                   "scope in DeleteQuantOpFuse pass should not be null."));
    const LoDTensor& input_scale_tensor =
        scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
    PADDLE_ENFORCE_EQ(
        paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
        platform::errors::InvalidArgument(
            "Input scale tensor's place should be CPU."));
    const float* input_scale_data = input_scale_tensor.data<float>();
    float in_scale = input_scale_data[0];
    float scale_value = in_scale / range;

    // Set input scale in attr, and relink nodes
    std::string input_act_name = input_act->Var()->Name();
    std::string output_act_name = output_act->Var()->Name();
    auto outlinks = output_act->outputs;
    for (auto* quantized_node : outlinks) {
      auto op_desc = quantized_node->Op();
      std::string quantized_op_type = op_desc->Type();
      if (quantized_op_type == "conv2d" ||
          quantized_op_type == "conv2d_fusion" ||
          quantized_op_type == "depthwise_conv2d" ||
          quantized_op_type == "fc") {
        op_desc->SetAttr("Input_scale", scale_value);
      } else if (quantized_op_type == "mul") {
        op_desc->SetAttr("X_scale", scale_value);
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type %s", quantized_op_type));
      }
      op_desc->SetAttr("bit_length", bit_length);
      op_desc->RenameInput(output_act_name, input_act_name);
      op_desc->Flush();
      IR_NODE_LINK_TO(input_act, quantized_node);
    }
    // Delete nodes and edges
    std::unordered_set<const Node*> nodes2rm = {input_scale, quant,
                                                output_scale, output_act};
    GraphSafeRemoveNodes(graph, nodes2rm);
  };
  gpd(graph, handler);
}

// Delete dequant op after quantized ops, and convert weight from fp32 range to
// int8 range
void FuseDequant(ir::Graph* graph, Scope* scope,
                 const std::string& quantized_op_type,
                 const std::string& dequant_type) {
  std::string weight_name = "";
  std::string input_name = "";
  if (quantized_op_type == "conv2d" ||
      quantized_op_type == "depthwise_conv2d" ||
      quantized_op_type == "conv2d_fusion") {
    weight_name = "Filter";
    input_name = "Input";
  } else if (quantized_op_type == "mul") {
    weight_name = "Y";
    input_name = "X";
  } else if (quantized_op_type == "fc") {
    weight_name = "W";
    input_name = "Input";
  } else {
    PADDLE_ENFORCE(
        "QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
        "now.");
  }
  const std::string pattern_name = "dequant_fuse";
  GraphPatternDetector gpd;

  auto* quantized_op_input =
      gpd.mutable_pattern()
          ->NewNode("quantized_op_input")
          ->assert_is_op_input(quantized_op_type, input_name)
          ->AsInput();

  // Create pattern
  patterns::DequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
  pattern(quantized_op_input, quantized_op_type, dequant_type, weight_name);

  // Create new op desc
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    PADDLE_ENFORCE_EQ(
        subgraph.count(quantized_op_input), true,
        platform::errors::NotFound(
            "Quantized op input node not found in Delete Quant fusion."));
    Node* quantized_op_input_node = subgraph.at(quantized_op_input);
    Node* quantized_op_weight_node =
        subgraph.at(pattern.GetPDNode("quantized_op_weight"));
    Node* quantized_op_node = subgraph.at(pattern.GetPDNode("quantized_op"));
    Node* dequant_op_node = subgraph.at(pattern.GetPDNode("dequant_op"));
    Node* dequant_op_out_node =
        subgraph.at(pattern.GetPDNode("dequant_op_out"));

    std::unordered_set<const Node*> nodes2rm = {};
    int bit_length =
        BOOST_GET_CONST(int, quantized_op_node->Op()->GetAttr("bit_length"));
    int range = ((1 << (bit_length - 1)) - 1);
    std::vector<float> weight_scale;

    // Get weight scale
    if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
      Node* dequant_channel_scale_node =
          subgraph.at(pattern.GetPDNode("dequant_channel_scale"));
      auto scales_name = dequant_op_node->Op()->Input("Scales");
      PADDLE_ENFORCE_EQ(
          scales_name.size(), 2,
          platform::errors::InvalidArgument(
              "Scales size in channel-wise dequantize op should be 2, got %d",
              scales_name.size()));
      const LoDTensor& channel_scale_tensor =
          scope->FindVar(scales_name[0])->Get<LoDTensor>();
      PADDLE_ENFORCE_EQ(
          paddle::platform::is_cpu_place(channel_scale_tensor.place()), true,
          platform::errors::InvalidArgument(
              "Channel scale tensor's place should be CPU."));
      const float* channel_scale_data = channel_scale_tensor.data<float>();
      for (int i = 0; i < channel_scale_tensor.numel(); i++) {
        weight_scale.push_back(channel_scale_data[i] / range);
      }
      nodes2rm.insert(dequant_channel_scale_node);
    } else {
      float max_range =
          BOOST_GET_CONST(float, dequant_op_node->Op()->GetAttr("max_range"));
      weight_scale.push_back((range * range) / max_range / range);
    }

    // Convert weight to fp32 range
    auto* weight_tensor =
        scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
    auto w_dims = weight_tensor->dims();
    // If quantized op is fc, weight scale size = 1;
    // If quantized op is conv, weight scale size = weight dims[0]
    bool valid_scale_size =
        (weight_scale.size() == 1 ||
         weight_scale.size() == static_cast<size_t>(w_dims[0]));
    PADDLE_ENFORCE_EQ(valid_scale_size, true,
                      platform::errors::InvalidArgument(
                          "TRT int8 quant: invalid scale size"));
    float* quantized_weight_data =
        weight_tensor->mutable_data<float>(platform::CPUPlace());
    for (int j = 0; j < weight_tensor->numel(); j++) {
      if (weight_scale.size() == 1) {
        quantized_weight_data[j] *= weight_scale[0];
      } else {
        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
        quantized_weight_data[j] *= weight_scale[j / inner_size];
      }
    }

    // create new op_desc
    auto base_op_desc = *quantized_op_node->Op()->Proto();
    std::string new_input = quantized_op_input_node->Name();
    std::string new_output = dequant_op_out_node->Name();

    framework::OpDesc new_op_desc(base_op_desc, nullptr);
    new_op_desc.SetType(quantized_op_type);
    new_op_desc.SetAttr("enable_int8", true);
    if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" ||
        quantized_op_type == "depthwise_conv2d") {
      new_op_desc.SetInput("Input", {new_input});
      new_op_desc.SetOutput("Output", {new_output});
    } else if (quantized_op_type == "fc") {
      new_op_desc.SetInput("Input", {new_input});
      new_op_desc.SetOutput("Out", {new_output});
    } else if (quantized_op_type == "mul") {
      new_op_desc.SetInput("X", {new_input});
      new_op_desc.SetOutput("Out", {new_output});
    }
    new_op_desc.SetAttr("weight_scale", weight_scale);
    new_op_desc.Flush();
    auto* new_op = graph->CreateOpNode(&new_op_desc);
    IR_NODE_LINK_TO(quantized_op_input_node, new_op);
    IR_NODE_LINK_TO(quantized_op_weight_node, new_op);
    IR_NODE_LINK_TO(new_op, dequant_op_out_node);
    // Delete nodes and edges
    nodes2rm.insert(quantized_op_node);
    nodes2rm.insert(dequant_op_node);
    GraphSafeRemoveNodes(graph, nodes2rm);
  };
  gpd(graph, handler);
}
...@@ -186,19 +245,19 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
  FusePassBase::Init(pattern_name, graph);

  std::unordered_set<std::string> dequant_types = {
      "fake_channel_wise_dequantize_max_abs", "fake_dequantize_max_abs"};
  std::unordered_set<std::string> quant_types = {
      "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
  std::unordered_set<std::string> quantized_op_types = {
      "conv2d", "mul", "depthwise_conv2d", "fc"};
  auto* scope = param_scope();

  for (auto& quant_type : quant_types) {
    DeleteQuant(graph, scope, quant_type);
  }
  for (auto& dequant_type : dequant_types) {
    for (auto& quantized_op_type : quantized_op_types) {
      FuseDequant(graph, scope, quantized_op_type, dequant_type);
    }
  }
}
......
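To make the scale bookkeeping in DeleteQuant/FuseDequant and the matching "* 127" in the conv2d/fc converters below concrete, here is a small standalone sketch of the arithmetic; the numeric values are illustrative and not taken from this commit:

#include <cstdio>

int main() {
  const int bit_length = 8;
  const int range = (1 << (bit_length - 1)) - 1;  // 127 for 8-bit quantization

  // Activation: DeleteQuant stores in_scale / range in "Input_scale"/"X_scale";
  // the conv2d/fc converters multiply the attribute back by 127, recovering
  // the original in_scale read from the quant op's InScale tensor.
  float in_scale = 6.35f;                      // illustrative InScale value
  float input_scale_attr = in_scale / range;   // what the pass writes
  float recovered = input_scale_attr * 127;    // what the converter computes

  // Weight (channel-wise case): FuseDequant pushes channel_scale / range into
  // "weight_scale" and multiplies the int8-range weights by it, so the TRT
  // converters already receive fp32-range weights and the old rescaling in
  // GetWeightCPUData is removed.
  float channel_scale = 2.54f;                 // illustrative Scales[0] entry
  float weight_scale = channel_scale / range;  // stored in "weight_scale"
  float int8_weight = 100.0f;                  // weight value in [-127, 127]
  float fp32_weight = int8_weight * weight_scale;

  std::printf("recovered in_scale %.4f, fp32 weight %.4f\n", recovered,
              fp32_weight);
  return 0;
}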
...@@ -22,6 +22,9 @@ namespace paddle {
namespace framework {
namespace ir {

///
/// Fuse quant + conv2d/depthwise_conv2d/mul/fc + dequant
///
class QuantDequantFusePass : public FusePassBase {
 public:
  virtual ~QuantDequantFusePass() {}
......
...@@ -365,6 +365,10 @@ const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
  return it->second;
}

bool OpDesc::HasOutput(const std::string &name) const {
  return outputs_.find(name) != outputs_.end();
}

std::vector<std::string> OpDesc::OutputArgumentNames() const {
  std::vector<std::string> retv;
  for (auto &ipt : this->outputs_) {
......
...@@ -57,6 +57,8 @@ class OpDesc {
  const std::vector<std::string> &Output(const std::string &name) const;

  bool HasOutput(const std::string &name) const;

  std::vector<std::string> OutputArgumentNames() const;

  void SetOutput(const std::string &param_name,
......
...@@ -281,11 +281,8 @@ void AnalysisConfig::Update() {
  if (use_tensorrt_) {
    pass_builder()->ClearPasses();

    for (const auto &pass : kTRTSubgraphPasses) {
      if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
          (pass == "conv_bn_fuse_pass" || pass == "fc_fuse_pass")) {
        continue;
      }
......
...@@ -52,7 +52,8 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
  if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
    CHECK(op_desc.HasAttr("Input_scale"));
    float in_scale =
        BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
    auto weight_scale =
        BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
    weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t,
......
...@@ -62,7 +62,7 @@ class FcOpConverter : public OpConverter {
#if IS_TRT_VERSION_GE(5000)
      CHECK(op_desc.HasAttr(i_name + "_scale"));
      float in_scale =
          BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127;
      auto weight_scale =
          BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
      weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(),
......
...@@ -98,8 +98,33 @@ class OpConverter {
    }
    PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
                            op_desc.Type());

    it->SetEngine(engine);
    (*it)(op, scope, test_mode);

    bool has_out_scale = op_desc.HasAttr("out_threshold");
    if (has_out_scale) {
      float out_scale =
          BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
      std::string output_name = "";
      if (op_desc.HasOutput("Output")) {
        output_name = op_desc.Output("Output").front();
      } else if (op_desc.HasOutput("Out")) {
        output_name = op_desc.Output("Out").front();
      } else if (op_desc.HasOutput("Y")) {
        output_name = op_desc.Output("Y").front();
      } else {
        PADDLE_THROW(
            platform::errors::NotFound("Op %s has out threshold but doesn't "
                                       "have an output named \"Output\", "
                                       "\"Out\" or \"Y\".",
                                       op_desc.Type()));
      }
      auto* output_itensor = engine->GetITensor(output_name);
      engine->SetTensorDynamicRange(output_itensor, out_scale);
      VLOG(1) << "Set out scale = " << out_scale << " for tensor "
              << output_name << ".";
    }
  }

  // Convert a fluid block to tensorrt network, NOTE it just convert operators,
......
...@@ -124,23 +124,42 @@ void TensorRTEngine::FreezeNetwork() {
                << ", this might be ok when trt does not need this range";
      }
    }

    auto is_layer_int8 = [&](nvinfer1::ILayer *layer) -> bool {
      for (int j = 0; j < layer->getNbInputs(); j++) {
        auto *temp_in = layer->getInput(j);
        if (!temp_in->dynamicRangeIsSet()) {
          VLOG(1) << "Layer(Name: " << layer->getName()
                  << ") is set to float32 because its input("
                  << temp_in->getName() << ") doesn't have dynamic range.";
          return false;
        }
      }
      for (int j = 0; j < layer->getNbOutputs(); j++) {
        auto *temp_out = layer->getOutput(j);
        if (temp_out->isNetworkOutput()) {
          VLOG(1) << "Layer(Name: " << layer->getName()
                  << ") is set to float32 because its output("
                  << temp_out->getName() << ") is the output of the network.";
          return false;
        }
        if (!temp_out->dynamicRangeIsSet()) {
          VLOG(1) << "Layer(Name: " << layer->getName()
                  << ") is set to float32 because its output("
                  << temp_out->getName() << ") doesn't have dynamic range.";
          return false;
        }
      }
      return true;
    };
    // If a layer's output is the network's output, or not all of its inputs
    // and outputs have scales, this layer's precision and output type are set
    // to float32. This step has no effect if this layer is fused during TRT
    // optimization.
    for (int i = 0; i < network()->getNbLayers(); i++) {
      auto layer = network()->getLayer(i);
      if (!is_layer_int8(layer)) {
        layer->setPrecision(nvinfer1::DataType::kFLOAT);
      }
    }
#endif
  }
...@@ -237,7 +256,6 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
  std::string name_suffix = std::to_string(name_suffix_counter);
  std::string splitter = "__";
  std::string name_with_suffix = name + splitter + name_suffix;
  auto w_dims = weight_tensor->dims();
  platform::CPUPlace cpu_place;
  PADDLE_ENFORCE_EQ(
      weight_map.count(name_with_suffix), 0,
...@@ -250,25 +268,6 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
  float *weight_data =
      weight_map[name_with_suffix]->mutable_data<float>(cpu_place);
  name_suffix_counter += 1;
  if (enable_int8) {
    // when the op is fc, scale's size should be 1
    // when the op is conv, scale's size should be w_dims[0]
    bool valid_scale_size =
        (scale.size() == 1 || scale.size() == static_cast<size_t>(w_dims[0]));
    PADDLE_ENFORCE(valid_scale_size, "TRT int8 quant: invalid scale size");
    for (int i = 0; i < weight_tensor->numel(); i++) {
      if (scale.size() == 1) {
        weight_data[i] *= (scale[0] / 127);
      } else {
        PADDLE_ENFORCE(w_dims.size() == 4,
                       "TRT int8 quant : We only use the channel quant for "
                       "conv op, so the weight dims should be 4.");
        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
        weight_data[i] *= (scale[i / inner_size] / 127);
      }
    }
  }
  return weight_data;
}
......
...@@ -43,11 +43,18 @@ struct SimpleOpTypeSetTeller : public Teller {
 private:
  // use this set for no calib int8.
  std::unordered_set<std::string> int8_teller_set{"mul",
                                                  "conv2d",
                                                  "pool2d",
                                                  "relu",
                                                  "depthwise_conv2d",
                                                  "softmax",
                                                  "batch_norm",
                                                  "elementwise_add",
                                                  "leaky_relu",
                                                  "fc",
                                                  "relu6",
                                                  "concat"};
  std::unordered_set<std::string> teller_set{
      "mul",
      "conv2d",
......
...@@ -405,6 +405,14 @@ if(WITH_GPU AND TENSORRT_FOUND)
          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
          ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR})

  set(TRT_MODEL_QUANT_YOLOV3_DIR "${INFERENCE_DEMO_INSTALL_DIR}/yolov3_r50_quant_aware")
  if (NOT EXISTS ${TRT_MODEL_QUANT_YOLOV3_DIR})
    inference_download_and_uncompress(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "yolov3_r50_quant_aware.tgz")
  endif()
  inference_analysis_test(trt_quant_int8_yolov3_r50_test SRCS trt_quant_int8_yolov3_r50_test.cc
          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
          ARGS --infer_model=${TRT_MODEL_QUANT_YOLOV3_DIR})

  set(TEST_TRT_DYNAMIC_MODEL2 "${TRT_MODEL_INSTALL_DIR}/complex_model_dynamic")
  if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL2})
    inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL2} ${INFERENCE_URL}/tensorrt_test "complex_model_dynamic2.tar.gz")
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <numeric>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
TEST(quant_int8, yolov3_resnet50) {
AnalysisConfig config;
config.EnableUseGpu(100, 0);
config.SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params");
config.SwitchUseFeedFetchOps(false);
config.EnableTensorRtEngine(1 << 30, 1, 3, AnalysisConfig::Precision::kInt8,
false, false);
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
int channels = 3;
int height = 608;
int width = 608;
int input_num = channels * height * width * 1;
float *input = new float[input_num];
int32_t *im_shape = new int32_t[2];
im_shape[0] = 608;
im_shape[1] = 608;
memset(input, 1.0, input_num * sizeof(float));
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({1, channels, height, width});
input_t->copy_from_cpu(input);
auto input_t1 = predictor->GetInputTensor(input_names[1]);
input_t1->Reshape({1, 2});
input_t1->copy_from_cpu(im_shape);
ASSERT_TRUE(predictor->ZeroCopyRun());
std::vector<float> out_data;
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
out_data.resize(out_num);
output_t->copy_to_cpu(out_data.data());
}
} // namespace inference
} // namespace paddle