未验证 提交 b2f5a149 编写于 作者: P Pei Yang 提交者: GitHub

[Paddle-TRT] Better Paddle-TensorRT support for PaddleSlim quant models (#25097)

* Paddle-TensorRT support slim QAT. test=develop

* add comments. test=develop

* use RenameInput instead of ResetInputs. test=develop
上级 a965ac4c
......@@ -1980,99 +1980,58 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
return concat_out;
}
void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times,
const std::string &quant_type,
const std::string &dequant_type) {
int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
const int kDequantOpWeightScaleOffset = 5;
// the quant op always be one.
auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
const std::string &quant_type) {
auto *input_scale_node = pattern->NewNode(GetNodeName("input_scale_node"))
->assert_is_op_input(quant_type, "InScale")
->AsInput();
auto quant_op =
pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);
PDNode *quant_op_out_scale = nullptr;
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
kNumFields += 1;
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_nth_input(dequant_type, "Scales", 1)
->AsIntermediate();
} else {
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
auto *quant_node =
pattern->NewNode(GetNodeName("quant_node"))->assert_is_op(quant_type);
auto *output_scale_node = pattern->NewNode(GetNodeName("output_scale_node"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input(dequant_type, "Scale")
->AsIntermediate();
}
auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
->AsOutput();
auto *output_act_node = pattern->NewNode(GetNodeName("output_act_node"))
->assert_is_op_output(quant_type, "Out")
->assert_is_op_input(op_type)
->AsIntermediate();
// there are 'times' quantized and dequant op
std::vector<PDNode *> nodes;
for (int i = 0; i < times; i++) {
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_weight") + std::to_string(i))
->assert_is_op_input(op_type, weight_name)
->AsInput());
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op") + std::to_string(i))
->assert_is_op(op_type));
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_out") + std::to_string(i))
->assert_is_op_output(op_type)
->assert_is_op_input(dequant_type, "X")
->AsIntermediate());
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op") + std::to_string(i))
->assert_is_op(dequant_type));
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op_out") + std::to_string(i))
->AsOutput();
quant_node->LinksFrom({input_scale_node, input_act_node});
output_scale_node->LinksFrom({quant_node});
output_act_node->LinksFrom({quant_node});
}
void patterns::DequantOpFuse::operator()(PDNode *quantized_op_input,
const std::string &quantized_op_type,
const std::string &dequant_type,
const std::string &weight_name) {
auto *quantized_op_weight =
pattern->NewNode(GetNodeName("quantized_op_weight"))
->assert_is_op_input(quantized_op_type, weight_name)
->AsInput();
auto *quantized_op = pattern->NewNode(GetNodeName("quantized_op"))
->assert_is_op(quantized_op_type);
auto *quantized_op_out = pattern->NewNode(GetNodeName("quantized_op_out"))
->assert_is_op_output(quantized_op_type)
->assert_is_op_input(dequant_type, "X");
auto *dequant_op =
pattern->NewNode(GetNodeName("dequant_op"))->assert_is_op(dequant_type);
auto *dequant_op_out = pattern->NewNode(GetNodeName("dequant_op_out"))
->assert_is_op_output(dequant_type, "Out")
->AsOutput());
->AsOutput();
PDNode *dequant_channel_scale = nullptr;
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes.push_back(pattern
->NewNode(GetNodeName("dequant_channel_scale") +
std::to_string(i))
dequant_channel_scale =
pattern->NewNode(GetNodeName("dequant_channel_scale"))
->assert_is_op_nth_input(dequant_type, "Scales", 0)
->AsInput());
}
->AsInput();
}
quantized_op->LinksFrom({quantized_op_input, quantized_op_weight});
quantized_op_out->LinksFrom({quantized_op});
quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
quant_op_out->LinksFrom({quant_op});
for (int i = 0; i < times; i++) {
nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale,
nodes[i * kNumFields + kDequantOpWeightScaleOffset]});
dequant_op->LinksFrom({quantized_op_out, dequant_channel_scale});
} else {
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
}
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
dequant_op->LinksFrom({quantized_op_out});
}
dequant_op_out->LinksFrom({dequant_op});
}
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
......
......@@ -1150,14 +1150,28 @@ struct TransposeFlattenConcat : public PatternBase {
}
};
struct QuantDequantOpFuse : public PatternBase {
QuantDequantOpFuse(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quant_dequant_fuse") {}
void operator()(PDNode* quant_op_input, const std::string& op_name,
const std::string& weight_name, int times,
const std::string& quant_type,
const std::string& dequant_type);
struct DeleteQuantOpFuse : public PatternBase {
DeleteQuantOpFuse(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_quant_fuse") {}
void operator()(PDNode* input_act_node, const std::string& quant_type);
std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
}
PDNode* GetPDNode(const std::string& op_type) {
return pattern->RetrieveNode(GetNodeName(op_type));
}
};
struct DequantOpFuse : public PatternBase {
DequantOpFuse(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "dequant_fuse") {}
void operator()(PDNode* quant_op_input, const std::string& quantized_op_type,
const std::string& dequant_type,
const std::string& weight_name);
std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
......
......@@ -24,159 +24,218 @@ namespace paddle {
namespace framework {
namespace ir {
void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
const std::string& op_type, const std::string& quant_type,
const std::string& dequant_type) {
const std::string pattern_name = "quant_dequant_fuse";
int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
const int kDequantOpWeightScaleOffset = 5;
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
kNumFields += 1;
}
// Delete quant op before quantized ops, and set input scale in the attr of
// quantized ops
void DeleteQuant(ir::Graph* graph, Scope* scope,
const std::string& quant_type) {
const std::string pattern_name = "delete_quant_fuse";
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
auto* input_act_node = gpd.mutable_pattern()
->NewNode("input_act_node")
->assert_is_op_input(quant_type, "X")
->AsInput();
std::string quantized_op_type = op_type;
// Create pattern
patterns::DeleteQuantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
pattern(input_act_node, quant_type);
// extract input scale from quant op input to set it in attr of all quantized
// ops linked from it
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
PADDLE_ENFORCE_EQ(subgraph.count(input_act_node), true,
platform::errors::NotFound(
"Input act node not found in Delete Quant fusion."));
Node* input_act = subgraph.at(input_act_node);
Node* input_scale = subgraph.at(pattern.GetPDNode("input_scale_node"));
Node* quant = subgraph.at(pattern.GetPDNode("quant_node"));
Node* output_scale = subgraph.at(pattern.GetPDNode("output_scale_node"));
Node* output_act = subgraph.at(pattern.GetPDNode("output_act_node"));
int bit_length = BOOST_GET_CONST(int, quant->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
// Get input scale from tensor
std::string input_scale_var_name = quant->Op()->Input("InScale").front();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument(
"scope in DeleteQuantOpFuse pass should not be null."));
const LoDTensor& input_scale_tensor =
scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
float in_scale = input_scale_data[0];
float scale_value = in_scale / range;
// Set input scale in attr, and relink nodes
std::string input_act_name = input_act->Var()->Name();
std::string output_act_name = output_act->Var()->Name();
auto outlinks = output_act->outputs;
for (auto* quantized_node : outlinks) {
auto op_desc = quantized_node->Op();
std::string quantized_op_type = op_desc->Type();
if (quantized_op_type == "conv2d" ||
quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d" ||
quantized_op_type == "fc") {
op_desc->SetAttr("Input_scale", scale_value);
} else if (quantized_op_type == "mul") {
op_desc->SetAttr("X_scale", scale_value);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type %s", quantized_op_type));
}
op_desc->SetAttr("bit_length", bit_length);
op_desc->RenameInput(output_act_name, input_act_name);
op_desc->Flush();
IR_NODE_LINK_TO(input_act, quantized_node);
}
// Delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {input_scale, quant,
output_scale, output_act};
GraphSafeRemoveNodes(graph, nodes2rm);
};
gpd(graph, handler);
}
// Delete dequant op after quantized ops, and convert weight from fp32 range to
// int8 range
void FuseDequant(ir::Graph* graph, Scope* scope,
const std::string& quantized_op_type,
const std::string& dequant_type) {
std::string weight_name = "";
if (op_type == "conv2d" || op_type == "depthwise_conv2d" ||
op_type == "conv2d_fusion") {
std::string input_name = "";
if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d" ||
quantized_op_type == "conv2d_fusion") {
weight_name = "Filter";
} else if (op_type == "mul") {
input_name = "Input";
} else if (quantized_op_type == "mul") {
weight_name = "Y";
} else if (op_type == "fc") {
input_name = "X";
} else if (quantized_op_type == "fc") {
weight_name = "W";
input_name = "Input";
} else {
PADDLE_ENFORCE(
"QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
"now.");
}
const std::string pattern_name = "dequant_fuse";
GraphPatternDetector gpd;
auto* quantized_op_input =
gpd.mutable_pattern()
->NewNode("quantized_op_input")
->assert_is_op_input(quantized_op_type, input_name)
->AsInput();
patterns::QuantDequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, quantized_op_type, weight_name, times, quant_type, dequant_type);
// Create pattern
patterns::DequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
pattern(quantized_op_input, quantized_op_type, dequant_type, weight_name);
// Create new op desc
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
PADDLE_ENFORCE(subgraph.count(x));
auto* input_node = subgraph.at(x);
Node* quant_op_in_scale =
subgraph.at(pattern.GetPDNode("quant_op_in_scale"));
Node* quant_op = subgraph.at(pattern.GetPDNode("quant_op"));
Node* quant_op_out_scale =
subgraph.at(pattern.GetPDNode("quant_op_out_scale"));
Node* quant_op_out = subgraph.at(pattern.GetPDNode("quant_op_out"));
std::vector<Node*> nodes;
for (int i = 0; i < times; i++) {
nodes.push_back(subgraph.at(
pattern.GetPDNode("quantized_op_weight" + std::to_string(i))));
nodes.push_back(
subgraph.at(pattern.GetPDNode("quantized_op" + std::to_string(i))));
nodes.push_back(subgraph.at(
pattern.GetPDNode("quantized_op_out" + std::to_string(i))));
nodes.push_back(
subgraph.at(pattern.GetPDNode("dequant_op" + std::to_string(i))));
nodes.push_back(
subgraph.at(pattern.GetPDNode("dequant_op_out" + std::to_string(i))));
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes.push_back(subgraph.at(
pattern.GetPDNode("dequant_channel_scale" + std::to_string(i))));
}
}
PADDLE_ENFORCE_EQ(
subgraph.count(quantized_op_input), true,
platform::errors::NotFound(
"Quantized op input node not found in Delete Quant fusion."));
Node* quantized_op_input_node = subgraph.at(quantized_op_input);
Node* quantized_op_weight_node =
subgraph.at(pattern.GetPDNode("quantized_op_weight"));
Node* quantized_op_node = subgraph.at(pattern.GetPDNode("quantized_op"));
Node* dequant_op_node = subgraph.at(pattern.GetPDNode("dequant_op"));
Node* dequant_op_out_node =
subgraph.at(pattern.GetPDNode("dequant_op_out"));
std::unordered_set<const Node*> nodes2rm = {};
int bit_length =
BOOST_GET_CONST(int, quant_op->Op()->GetAttr("bit_length"));
BOOST_GET_CONST(int, quantized_op_node->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
// Prepare input scale
std::string input_scale_var_name = quant_op->Op()->Input("InScale").front();
PADDLE_ENFORCE(scope);
const LoDTensor& input_scale_tensor =
scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
PADDLE_ENFORCE(paddle::platform::is_cpu_place(input_scale_tensor.place()));
const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0];
std::unordered_set<const Node*> delete_nodes;
for (int i = 0; i < times; i++) {
std::vector<float> weight_scale;
// Get weight scale from dequant op.
// Get weight scale
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
auto scales_name =
nodes[i * kNumFields + kDequantOpOffset]->Op()->Input("Scales");
PADDLE_ENFORCE(scales_name.size() == 2);
Node* dequant_channel_scale_node =
subgraph.at(pattern.GetPDNode("dequant_channel_scale"));
auto scales_name = dequant_op_node->Op()->Input("Scales");
PADDLE_ENFORCE_EQ(
scales_name.size(), 2,
platform::errors::InvalidArgument(
"Scales size in channel-wise dequantize op should be 2, got %d",
scales_name.size()));
const LoDTensor& channel_scale_tensor =
scope->FindVar(scales_name[0])->Get<LoDTensor>();
PADDLE_ENFORCE(
paddle::platform::is_cpu_place(channel_scale_tensor.place()));
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(channel_scale_tensor.place()), true,
platform::errors::InvalidArgument(
"Channel scale tensor's place should be CPU."));
const float* channel_scale_data = channel_scale_tensor.data<float>();
for (int i = 0; i < channel_scale_tensor.numel(); i++) {
weight_scale.push_back(channel_scale_data[i]);
weight_scale.push_back(channel_scale_data[i] / range);
}
nodes2rm.insert(dequant_channel_scale_node);
} else {
float max_range =
BOOST_GET_CONST(float, dequant_op_node->Op()->GetAttr("max_range"));
weight_scale.push_back((range * range) / max_range / range);
}
delete_nodes.insert(
nodes[i * kNumFields + kDequantOpWeightScaleOffset]);
// Convert weight to fp32 range
auto* weight_tensor =
scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims();
// If quantized op is fc, weight scale size = 1;
// If quantized op is conv, weight scale size = weight dims[0]
bool valid_scale_size =
(weight_scale.size() == 1 ||
weight_scale.size() == static_cast<size_t>(w_dims[0]));
PADDLE_ENFORCE_EQ(valid_scale_size, true,
platform::errors::InvalidArgument(
"TRT int8 quant: invalid scale size"));
float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
for (int j = 0; j < weight_tensor->numel(); j++) {
if (weight_scale.size() == 1) {
quantized_weight_data[j] *= weight_scale[0];
} else {
float max_range = BOOST_GET_CONST(
float, nodes[i * kNumFields + kDequantOpOffset]->Op()->GetAttr(
"max_range"));
weight_scale.push_back((range * range) / max_range);
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[j / inner_size];
}
}
// create new op_desc
auto base_op_desc =
*nodes[i * kNumFields + kQuantizedOpOffset]->Op()->Proto();
std::string new_input = input_node->Name();
std::string new_output =
nodes[i * kNumFields + kDequantOpOutOffset]->Name();
auto base_op_desc = *quantized_op_node->Op()->Proto();
std::string new_input = quantized_op_input_node->Name();
std::string new_output = dequant_op_out_node->Name();
framework::OpDesc new_op_desc(base_op_desc, nullptr);
new_op_desc.SetType(quantized_op_type);
new_op_desc.SetAttr("enable_int8", true);
if (quantized_op_type == "conv2d" ||
quantized_op_type == "conv2d_fusion" ||
if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetAttr("Input_scale", input_scale);
new_op_desc.SetOutput("Output", {new_output});
} else if (quantized_op_type == "fc") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetAttr("Input_scale", input_scale);
new_op_desc.SetOutput("Out", {new_output});
} else if (quantized_op_type == "mul") {
new_op_desc.SetInput("X", {new_input});
new_op_desc.SetAttr("X_scale", input_scale);
new_op_desc.SetOutput("Out", {new_output});
}
new_op_desc.SetAttr("weight_scale", weight_scale);
new_op_desc.Flush();
auto* new_op = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(input_node, new_op);
IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset], new_op);
IR_NODE_LINK_TO(new_op, nodes[i * kNumFields + kDequantOpOutOffset]);
delete_nodes.insert(nodes[i * kNumFields + kQuantizedOpOffset]);
delete_nodes.insert(nodes[i * kNumFields + kQuantizedOpOutOffset]);
delete_nodes.insert(nodes[i * kNumFields + kDequantOpOffset]);
}
delete_nodes.insert(quant_op_in_scale);
delete_nodes.insert(quant_op);
delete_nodes.insert(quant_op_out);
delete_nodes.insert(quant_op_out_scale);
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, delete_nodes);
IR_NODE_LINK_TO(quantized_op_input_node, new_op);
IR_NODE_LINK_TO(quantized_op_weight_node, new_op);
IR_NODE_LINK_TO(new_op, dequant_op_out_node);
// Delete nodes and edges
nodes2rm.insert(quantized_op_node);
nodes2rm.insert(dequant_op_node);
GraphSafeRemoveNodes(graph, nodes2rm);
};
gpd(graph, handler);
}
......@@ -186,19 +245,19 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(pattern_name, graph);
std::unordered_set<std::string> dequant_types = {
"fake_dequantize_max_abs", "fake_channel_wise_dequantize_max_abs"};
"fake_channel_wise_dequantize_max_abs", "fake_dequantize_max_abs"};
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul",
"depthwise_conv2d"};
std::unordered_set<std::string> quantized_op_types = {
"conv2d", "mul", "depthwise_conv2d", "fc"};
auto* scope = param_scope();
for (auto& dequant_type : dequant_types) {
for (auto& quant_type : quant_types) {
for (auto& op_type : quantized_op_types) {
for (int i = 6; i >= 1; i--) {
RunQuantDequant(graph, scope, i, op_type, quant_type, dequant_type);
}
DeleteQuant(graph, scope, quant_type);
}
for (auto& dequant_type : dequant_types) {
for (auto& quantized_op_type : quantized_op_types) {
FuseDequant(graph, scope, quantized_op_type, dequant_type);
}
}
}
......
......@@ -22,6 +22,9 @@ namespace paddle {
namespace framework {
namespace ir {
///
/// Fuse quant + conv2d/depthwise_conv2d/mul/fc + dequant
///
class QuantDequantFusePass : public FusePassBase {
public:
virtual ~QuantDequantFusePass() {}
......
......@@ -365,6 +365,10 @@ const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
return it->second;
}
bool OpDesc::HasOutput(const std::string &name) const {
return outputs_.find(name) != outputs_.end();
}
std::vector<std::string> OpDesc::OutputArgumentNames() const {
std::vector<std::string> retv;
for (auto &ipt : this->outputs_) {
......
......@@ -57,6 +57,8 @@ class OpDesc {
const std::vector<std::string> &Output(const std::string &name) const;
bool HasOutput(const std::string &name) const;
std::vector<std::string> OutputArgumentNames() const;
void SetOutput(const std::string &param_name,
......
......@@ -281,11 +281,8 @@ void AnalysisConfig::Update() {
if (use_tensorrt_) {
pass_builder()->ClearPasses();
bool use_calib_int8 =
(tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8) &&
trt_use_calib_mode_;
for (const auto &pass : kTRTSubgraphPasses) {
if (use_calib_int8 &&
if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
(pass == "conv_bn_fuse_pass" || pass == "fc_fuse_pass")) {
continue;
}
......
......@@ -52,7 +52,8 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("Input_scale"));
float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
float in_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t,
......
......@@ -62,7 +62,7 @@ class FcOpConverter : public OpConverter {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr(i_name + "_scale"));
float in_scale =
BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale"));
BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127;
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(),
......
......@@ -98,8 +98,33 @@ class OpConverter {
}
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type());
it->SetEngine(engine);
(*it)(op, scope, test_mode);
bool has_out_scale = op_desc.HasAttr("out_threshold");
if (has_out_scale) {
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
std::string output_name = "";
if (op_desc.HasOutput("Output")) {
output_name = op_desc.Output("Output").front();
} else if (op_desc.HasOutput("Out")) {
output_name = op_desc.Output("Out").front();
} else if (op_desc.HasOutput("Y")) {
output_name = op_desc.Output("Y").front();
} else {
PADDLE_THROW(
platform::errors::NotFound("Op %s has out threshold but doesn't "
"have an output named \"Output\", "
"\"Out\" or \"Y\".",
op_desc.Type()));
}
auto* output_itensor = engine->GetITensor(output_name);
engine->SetTensorDynamicRange(output_itensor, out_scale);
VLOG(1) << "Set out scale = " << out_scale << " for tensor "
<< output_name << ".";
}
}
// Convert a fluid block to tensorrt network, NOTE it just convert operators,
......
......@@ -124,23 +124,42 @@ void TensorRTEngine::FreezeNetwork() {
<< ", this might be ok when trt does not need this range";
}
}
std::unordered_set<std::string> all_out_t_name;
for (int i = 0; i < network()->getNbOutputs(); i++) {
auto *temp = network()->getOutput(i);
temp->setDynamicRange(-1, 1);
all_out_t_name.insert(temp->getName());
auto is_layer_int8 = [&](nvinfer1::ILayer *layer) -> bool {
for (int j = 0; j < layer->getNbInputs(); j++) {
auto *temp_in = layer->getInput(j);
if (!temp_in->dynamicRangeIsSet()) {
VLOG(1) << "Layer(Name: " << layer->getName()
<< ") is set to float32 because its input("
<< temp_in->getName() << ") doesn't have dynamic range.";
return false;
}
}
for (int i = 0; i < network()->getNbLayers(); i++) {
auto layer = network()->getLayer(i);
for (int j = 0; j < layer->getNbOutputs(); j++) {
auto *temp_out = layer->getOutput(j);
if (std::find(all_out_t_name.begin(), all_out_t_name.end(),
temp_out->getName()) != all_out_t_name.end()) {
layer->setPrecision(nvinfer1::DataType::kFLOAT);
layer->setOutputType(j, nvinfer1::DataType::kFLOAT);
if (temp_out->isNetworkOutput()) {
VLOG(1) << "Layer(Name: " << layer->getName()
<< ") is set to float32 because its output("
<< temp_out->getName() << ") is the output of the network.";
return false;
}
if (!temp_out->dynamicRangeIsSet()) {
VLOG(1) << "Layer(Name: " << layer->getName()
<< ") is set to float32 because its output("
<< temp_out->getName() << ") doesn't have dynamic range.";
return false;
}
}
return true;
};
// If a layer's output is the network's output, or not all of its inputs
// and outputs have scales,
// this layer's precision and output type are set to float32.
// This step has no effect if this layer is fused during TRT optimization.
for (int i = 0; i < network()->getNbLayers(); i++) {
auto layer = network()->getLayer(i);
if (!is_layer_int8(layer)) {
layer->setPrecision(nvinfer1::DataType::kFLOAT);
}
}
#endif
}
......@@ -237,7 +256,6 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
std::string name_suffix = std::to_string(name_suffix_counter);
std::string splitter = "__";
std::string name_with_suffix = name + splitter + name_suffix;
auto w_dims = weight_tensor->dims();
platform::CPUPlace cpu_place;
PADDLE_ENFORCE_EQ(
weight_map.count(name_with_suffix), 0,
......@@ -250,25 +268,6 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
float *weight_data =
weight_map[name_with_suffix]->mutable_data<float>(cpu_place);
name_suffix_counter += 1;
if (enable_int8) {
// when the op is fc, scale's size should be 1
// when the op is conv, scale's size should be w_dims[0]
bool valid_scale_size =
(scale.size() == 1 || scale.size() == static_cast<size_t>(w_dims[0]));
PADDLE_ENFORCE(valid_scale_size, "TRT int8 quant: invalid scale size");
for (int i = 0; i < weight_tensor->numel(); i++) {
if (scale.size() == 1) {
weight_data[i] *= (scale[0] / 127);
} else {
PADDLE_ENFORCE(w_dims.size() == 4,
"TRT int8 quant : We only use the channel quant for "
"conv op, so the weight dims should be 4.");
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
weight_data[i] *= (scale[i / inner_size] / 127);
}
}
}
return weight_data;
}
......
......@@ -43,11 +43,18 @@ struct SimpleOpTypeSetTeller : public Teller {
private:
// use this set for no calib int8.
std::unordered_set<std::string> int8_teller_set{
"mul", "conv2d", "pool2d",
"relu", "depthwise_conv2d", "softmax",
"batch_norm", "elementwise_add", "leaky_relu",
"fc"};
std::unordered_set<std::string> int8_teller_set{"mul",
"conv2d",
"pool2d",
"relu",
"depthwise_conv2d",
"softmax",
"batch_norm",
"elementwise_add",
"leaky_relu",
"fc",
"relu6",
"concat"};
std::unordered_set<std::string> teller_set{
"mul",
"conv2d",
......
......@@ -405,6 +405,14 @@ if(WITH_GPU AND TENSORRT_FOUND)
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR})
set(TRT_MODEL_QUANT_YOLOV3_DIR "${INFERENCE_DEMO_INSTALL_DIR}/yolov3_r50_quant_aware")
if (NOT EXISTS ${TRT_MODEL_QUANT_YOLOV3_DIR})
inference_download_and_uncompress(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "yolov3_r50_quant_aware.tgz")
endif()
inference_analysis_test(trt_quant_int8_yolov3_r50_test SRCS trt_quant_int8_yolov3_r50_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_QUANT_YOLOV3_DIR})
set(TEST_TRT_DYNAMIC_MODEL2 "${TRT_MODEL_INSTALL_DIR}/complex_model_dynamic")
if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL2})
inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL2} ${INFERENCE_URL}/tensorrt_test "complex_model_dynamic2.tar.gz")
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <numeric>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
TEST(quant_int8, yolov3_resnet50) {
AnalysisConfig config;
config.EnableUseGpu(100, 0);
config.SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params");
config.SwitchUseFeedFetchOps(false);
config.EnableTensorRtEngine(1 << 30, 1, 3, AnalysisConfig::Precision::kInt8,
false, false);
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
int channels = 3;
int height = 608;
int width = 608;
int input_num = channels * height * width * 1;
float *input = new float[input_num];
int32_t *im_shape = new int32_t[2];
im_shape[0] = 608;
im_shape[1] = 608;
memset(input, 1.0, input_num * sizeof(float));
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({1, channels, height, width});
input_t->copy_from_cpu(input);
auto input_t1 = predictor->GetInputTensor(input_names[1]);
input_t1->Reshape({1, 2});
input_t1->copy_from_cpu(im_shape);
ASSERT_TRUE(predictor->ZeroCopyRun());
std::vector<float> out_data;
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
out_data.resize(out_num);
output_t->copy_to_cpu(out_data.data());
}
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册