diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 5c2301d6e007510a75b9c3107d781efa53765f2a..f58e6c8bff93da0b27cd147108bc57a452269188 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1980,99 +1980,58 @@ PDNode *patterns::TransposeFlattenConcat::operator()( return concat_out; } -void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input, - const std::string &op_type, - const std::string &weight_name, - int times, - const std::string &quant_type, - const std::string &dequant_type) { - int kNumFields = 5; - const int kQuantizedWeightOffset = 0; - const int kQuantizedOpOffset = 1; - const int kQuantizedOpOutOffset = 2; - const int kDequantOpOffset = 3; - const int kDequantOpOutOffset = 4; - const int kDequantOpWeightScaleOffset = 5; - - // the quant op always be one. - auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale")) +void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node, + const std::string &quant_type) { + auto *input_scale_node = pattern->NewNode(GetNodeName("input_scale_node")) ->assert_is_op_input(quant_type, "InScale") ->AsInput(); - auto quant_op = - pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type); - - PDNode *quant_op_out_scale = nullptr; + auto *quant_node = + pattern->NewNode(GetNodeName("quant_node"))->assert_is_op(quant_type); + auto *output_scale_node = pattern->NewNode(GetNodeName("output_scale_node")) + ->assert_is_op_output(quant_type, "OutScale") + ->AsOutput(); + auto *output_act_node = pattern->NewNode(GetNodeName("output_act_node")) + ->assert_is_op_output(quant_type, "Out") + ->AsOutput(); + quant_node->LinksFrom({input_scale_node, input_act_node}); + output_scale_node->LinksFrom({quant_node}); + output_act_node->LinksFrom({quant_node}); +} + +void patterns::DequantOpFuse::operator()(PDNode *quantized_op_input, + const std::string &quantized_op_type, + const std::string &dequant_type, + const std::string &weight_name) { + auto *quantized_op_weight = + pattern->NewNode(GetNodeName("quantized_op_weight")) + ->assert_is_op_input(quantized_op_type, weight_name) + ->AsInput(); + auto *quantized_op = pattern->NewNode(GetNodeName("quantized_op")) + ->assert_is_op(quantized_op_type); + auto *quantized_op_out = pattern->NewNode(GetNodeName("quantized_op_out")) + ->assert_is_op_output(quantized_op_type) + ->assert_is_op_input(dequant_type, "X"); + auto *dequant_op = + pattern->NewNode(GetNodeName("dequant_op"))->assert_is_op(dequant_type); + auto *dequant_op_out = pattern->NewNode(GetNodeName("dequant_op_out")) + ->assert_is_op_output(dequant_type, "Out") + ->AsOutput(); + PDNode *dequant_channel_scale = nullptr; if (dequant_type == "fake_channel_wise_dequantize_max_abs") { - kNumFields += 1; - quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale")) - ->assert_is_op_output(quant_type, "OutScale") - ->assert_is_op_nth_input(dequant_type, "Scales", 1) - ->AsIntermediate(); - } else { - quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale")) - ->assert_is_op_output(quant_type, "OutScale") - ->assert_is_op_input(dequant_type, "Scale") - ->AsIntermediate(); + dequant_channel_scale = + pattern->NewNode(GetNodeName("dequant_channel_scale")) + ->assert_is_op_nth_input(dequant_type, "Scales", 0) + ->AsInput(); } + quantized_op->LinksFrom({quantized_op_input, quantized_op_weight}); + quantized_op_out->LinksFrom({quantized_op}); - auto 
quant_op_out = pattern->NewNode(GetNodeName("quant_op_out")) - ->assert_is_op_output(quant_type, "Out") - ->assert_is_op_input(op_type) - ->AsIntermediate(); - - // there are 'times' quantized and dequant op - std::vector nodes; - for (int i = 0; i < times; i++) { - nodes.push_back( - pattern->NewNode(GetNodeName("quantized_op_weight") + std::to_string(i)) - ->assert_is_op_input(op_type, weight_name) - ->AsInput()); - nodes.push_back( - pattern->NewNode(GetNodeName("quantized_op") + std::to_string(i)) - ->assert_is_op(op_type)); - - nodes.push_back( - pattern->NewNode(GetNodeName("quantized_op_out") + std::to_string(i)) - ->assert_is_op_output(op_type) - ->assert_is_op_input(dequant_type, "X") - ->AsIntermediate()); - - nodes.push_back( - pattern->NewNode(GetNodeName("dequant_op") + std::to_string(i)) - ->assert_is_op(dequant_type)); - - nodes.push_back( - pattern->NewNode(GetNodeName("dequant_op_out") + std::to_string(i)) - ->assert_is_op_output(dequant_type, "Out") - ->AsOutput()); - - if (dequant_type == "fake_channel_wise_dequantize_max_abs") { - nodes.push_back(pattern - ->NewNode(GetNodeName("dequant_channel_scale") + - std::to_string(i)) - ->assert_is_op_nth_input(dequant_type, "Scales", 0) - ->AsInput()); - } - } - - quant_op->LinksFrom({quant_op_input, quant_op_in_scale}); - quant_op_out->LinksFrom({quant_op}); - for (int i = 0; i < times; i++) { - nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom( - {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]}); - nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom( - {nodes[i * kNumFields + kQuantizedOpOffset]}); - if (dequant_type == "fake_channel_wise_dequantize_max_abs") { - nodes[i * kNumFields + kDequantOpOffset]->LinksFrom( - {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale, - nodes[i * kNumFields + kDequantOpWeightScaleOffset]}); - } else { - nodes[i * kNumFields + kDequantOpOffset]->LinksFrom( - {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale}); - } - nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom( - {nodes[i * kNumFields + kDequantOpOffset]}); + if (dequant_type == "fake_channel_wise_dequantize_max_abs") { + dequant_op->LinksFrom({quantized_op_out, dequant_channel_scale}); + } else { + dequant_op->LinksFrom({quantized_op_out}); } + dequant_op_out->LinksFrom({dequant_op}); } void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 65f7eeebd228ed032262055c1ba83bd43a3eb4a3..422ad1ef47a84ff21a2568a2773c899733f34dc7 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1150,14 +1150,28 @@ struct TransposeFlattenConcat : public PatternBase { } }; -struct QuantDequantOpFuse : public PatternBase { - QuantDequantOpFuse(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "quant_dequant_fuse") {} - - void operator()(PDNode* quant_op_input, const std::string& op_name, - const std::string& weight_name, int times, - const std::string& quant_type, - const std::string& dequant_type); +struct DeleteQuantOpFuse : public PatternBase { + DeleteQuantOpFuse(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "delete_quant_fuse") {} + + void operator()(PDNode* input_act_node, const std::string& quant_type); + + std::string GetNodeName(const std::string& op_type) { + return PDNodeName(name_scope_, repr_, id_, 
op_type); + } + + PDNode* GetPDNode(const std::string& op_type) { + return pattern->RetrieveNode(GetNodeName(op_type)); + } +}; + +struct DequantOpFuse : public PatternBase { + DequantOpFuse(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dequant_fuse") {} + + void operator()(PDNode* quant_op_input, const std::string& quantized_op_type, + const std::string& dequant_type, + const std::string& weight_name); std::string GetNodeName(const std::string& op_type) { return PDNodeName(name_scope_, repr_, id_, op_type); diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 0c266fbc2019156713ac51a7f915b25f3b723560..1f1a54f140b0d0fde18529708b0ea920a52ee466 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -24,159 +24,218 @@ namespace paddle { namespace framework { namespace ir { -void RunQuantDequant(ir::Graph* graph, Scope* scope, int times, - const std::string& op_type, const std::string& quant_type, - const std::string& dequant_type) { - const std::string pattern_name = "quant_dequant_fuse"; - int kNumFields = 5; - const int kQuantizedWeightOffset = 0; - const int kQuantizedOpOffset = 1; - const int kQuantizedOpOutOffset = 2; - const int kDequantOpOffset = 3; - const int kDequantOpOutOffset = 4; - const int kDequantOpWeightScaleOffset = 5; - - if (dequant_type == "fake_channel_wise_dequantize_max_abs") { - kNumFields += 1; - } - +// Delete quant op before quantized ops, and set input scale in the attr of +// quantized ops +void DeleteQuant(ir::Graph* graph, Scope* scope, + const std::string& quant_type) { + const std::string pattern_name = "delete_quant_fuse"; GraphPatternDetector gpd; - auto* x = gpd.mutable_pattern() - ->NewNode("x") - ->assert_is_op_input(quant_type, "X") - ->AsInput(); + auto* input_act_node = gpd.mutable_pattern() + ->NewNode("input_act_node") + ->assert_is_op_input(quant_type, "X") + ->AsInput(); + + // Create pattern + patterns::DeleteQuantOpFuse pattern(gpd.mutable_pattern(), pattern_name); + pattern(input_act_node, quant_type); + + // extract input scale from quant op input to set it in attr of all quantized + // ops linked from it + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + PADDLE_ENFORCE_EQ(subgraph.count(input_act_node), true, + platform::errors::NotFound( + "Input act node not found in Delete Quant fusion.")); + Node* input_act = subgraph.at(input_act_node); + Node* input_scale = subgraph.at(pattern.GetPDNode("input_scale_node")); + Node* quant = subgraph.at(pattern.GetPDNode("quant_node")); + Node* output_scale = subgraph.at(pattern.GetPDNode("output_scale_node")); + Node* output_act = subgraph.at(pattern.GetPDNode("output_act_node")); + int bit_length = BOOST_GET_CONST(int, quant->Op()->GetAttr("bit_length")); + int range = ((1 << (bit_length - 1)) - 1); + + // Get input scale from tensor + std::string input_scale_var_name = quant->Op()->Input("InScale").front(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument( + "scope in DeleteQuantOpFuse pass should not be null.")); + const LoDTensor& input_scale_tensor = + scope->FindVar(input_scale_var_name)->Get(); + PADDLE_ENFORCE_EQ( + paddle::platform::is_cpu_place(input_scale_tensor.place()), true, + platform::errors::InvalidArgument( + "Input scale tensor's place should be CPU.")); + const float* input_scale_data = input_scale_tensor.data(); + 
float in_scale = input_scale_data[0]; + float scale_value = in_scale / range; + + // Set input scale in attr, and relink nodes + std::string input_act_name = input_act->Var()->Name(); + std::string output_act_name = output_act->Var()->Name(); + auto outlinks = output_act->outputs; + for (auto* quantized_node : outlinks) { + auto op_desc = quantized_node->Op(); + std::string quantized_op_type = op_desc->Type(); + if (quantized_op_type == "conv2d" || + quantized_op_type == "conv2d_fusion" || + quantized_op_type == "depthwise_conv2d" || + quantized_op_type == "fc") { + op_desc->SetAttr("Input_scale", scale_value); + } else if (quantized_op_type == "mul") { + op_desc->SetAttr("X_scale", scale_value); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported quantized op type %s", quantized_op_type)); + } + op_desc->SetAttr("bit_length", bit_length); + op_desc->RenameInput(output_act_name, input_act_name); + op_desc->Flush(); + IR_NODE_LINK_TO(input_act, quantized_node); + } + // Delete nodes and edges + std::unordered_set nodes2rm = {input_scale, quant, + output_scale, output_act}; + GraphSafeRemoveNodes(graph, nodes2rm); + }; + gpd(graph, handler); +} - std::string quantized_op_type = op_type; +// Delete dequant op after quantized ops, and convert weight from fp32 range to +// int8 range +void FuseDequant(ir::Graph* graph, Scope* scope, + const std::string& quantized_op_type, + const std::string& dequant_type) { std::string weight_name = ""; - if (op_type == "conv2d" || op_type == "depthwise_conv2d" || - op_type == "conv2d_fusion") { + std::string input_name = ""; + if (quantized_op_type == "conv2d" || + quantized_op_type == "depthwise_conv2d" || + quantized_op_type == "conv2d_fusion") { weight_name = "Filter"; - } else if (op_type == "mul") { + input_name = "Input"; + } else if (quantized_op_type == "mul") { weight_name = "Y"; - } else if (op_type == "fc") { + input_name = "X"; + } else if (quantized_op_type == "fc") { weight_name = "W"; + input_name = "Input"; } else { PADDLE_ENFORCE( "QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for " "now."); } + const std::string pattern_name = "dequant_fuse"; + GraphPatternDetector gpd; + + auto* quantized_op_input = + gpd.mutable_pattern() + ->NewNode("quantized_op_input") + ->assert_is_op_input(quantized_op_type, input_name) + ->AsInput(); - patterns::QuantDequantOpFuse pattern(gpd.mutable_pattern(), pattern_name); - pattern(x, quantized_op_type, weight_name, times, quant_type, dequant_type); + // Create pattern + patterns::DequantOpFuse pattern(gpd.mutable_pattern(), pattern_name); + pattern(quantized_op_input, quantized_op_type, dequant_type, weight_name); + // Create new op desc auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - PADDLE_ENFORCE(subgraph.count(x)); - auto* input_node = subgraph.at(x); - Node* quant_op_in_scale = - subgraph.at(pattern.GetPDNode("quant_op_in_scale")); - Node* quant_op = subgraph.at(pattern.GetPDNode("quant_op")); - Node* quant_op_out_scale = - subgraph.at(pattern.GetPDNode("quant_op_out_scale")); - Node* quant_op_out = subgraph.at(pattern.GetPDNode("quant_op_out")); - - std::vector nodes; - for (int i = 0; i < times; i++) { - nodes.push_back(subgraph.at( - pattern.GetPDNode("quantized_op_weight" + std::to_string(i)))); - nodes.push_back( - subgraph.at(pattern.GetPDNode("quantized_op" + std::to_string(i)))); - nodes.push_back(subgraph.at( - pattern.GetPDNode("quantized_op_out" + std::to_string(i)))); - nodes.push_back( - 
subgraph.at(pattern.GetPDNode("dequant_op" + std::to_string(i))));
-      nodes.push_back(
-          subgraph.at(pattern.GetPDNode("dequant_op_out" + std::to_string(i))));
-      if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
-        nodes.push_back(subgraph.at(
-            pattern.GetPDNode("dequant_channel_scale" + std::to_string(i))));
-      }
-    }
+    PADDLE_ENFORCE_EQ(
+        subgraph.count(quantized_op_input), true,
+        platform::errors::NotFound(
+            "Quantized op input node not found in Delete Quant fusion."));
+    Node* quantized_op_input_node = subgraph.at(quantized_op_input);
+    Node* quantized_op_weight_node =
+        subgraph.at(pattern.GetPDNode("quantized_op_weight"));
+    Node* quantized_op_node = subgraph.at(pattern.GetPDNode("quantized_op"));
+    Node* dequant_op_node = subgraph.at(pattern.GetPDNode("dequant_op"));
+    Node* dequant_op_out_node =
+        subgraph.at(pattern.GetPDNode("dequant_op_out"));
+    std::unordered_set<const Node*> nodes2rm = {};
     int bit_length =
-        BOOST_GET_CONST(int, quant_op->Op()->GetAttr("bit_length"));
+        BOOST_GET_CONST(int, quantized_op_node->Op()->GetAttr("bit_length"));
     int range = ((1 << (bit_length - 1)) - 1);
-    // Prepare input scale
-    std::string input_scale_var_name = quant_op->Op()->Input("InScale").front();
-    PADDLE_ENFORCE(scope);
-    const LoDTensor& input_scale_tensor =
-        scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
+    std::vector<float> weight_scale;
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(input_scale_tensor.place()));
-    const float* input_scale_data = input_scale_tensor.data<float>();
-    float input_scale = input_scale_data[0];
-    std::unordered_set<const Node*> delete_nodes;
-
-    for (int i = 0; i < times; i++) {
-      std::vector<float> weight_scale;
-
-      // Get weight scale from dequant op.
-      if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
-        auto scales_name =
-            nodes[i * kNumFields + kDequantOpOffset]->Op()->Input("Scales");
-        PADDLE_ENFORCE(scales_name.size() == 2);
-        const LoDTensor& channel_scale_tensor =
-            scope->FindVar(scales_name[0])->Get<LoDTensor>();
-        PADDLE_ENFORCE(
-            paddle::platform::is_cpu_place(channel_scale_tensor.place()));
-        const float* channel_scale_data = channel_scale_tensor.data<float>();
-        for (int i = 0; i < channel_scale_tensor.numel(); i++) {
-          weight_scale.push_back(channel_scale_data[i]);
-        }
-        delete_nodes.insert(
-            nodes[i * kNumFields + kDequantOpWeightScaleOffset]);
-      } else {
-        float max_range = BOOST_GET_CONST(
-            float, nodes[i * kNumFields + kDequantOpOffset]->Op()->GetAttr(
-                       "max_range"));
-        weight_scale.push_back((range * range) / max_range);
+    // Get weight scale
+    if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
+      Node* dequant_channel_scale_node =
+          subgraph.at(pattern.GetPDNode("dequant_channel_scale"));
+      auto scales_name = dequant_op_node->Op()->Input("Scales");
+      PADDLE_ENFORCE_EQ(
+          scales_name.size(), 2,
+          platform::errors::InvalidArgument(
+              "Scales size in channel-wise dequantize op should be 2, got %d",
+              scales_name.size()));
+      const LoDTensor& channel_scale_tensor =
+          scope->FindVar(scales_name[0])->Get<LoDTensor>();
+      PADDLE_ENFORCE_EQ(
+          paddle::platform::is_cpu_place(channel_scale_tensor.place()), true,
+          platform::errors::InvalidArgument(
+              "Channel scale tensor's place should be CPU."));
+      const float* channel_scale_data = channel_scale_tensor.data<float>();
+      for (int i = 0; i < channel_scale_tensor.numel(); i++) {
+        weight_scale.push_back(channel_scale_data[i] / range);
      }
+      nodes2rm.insert(dequant_channel_scale_node);
+    } else {
+      float max_range =
+          BOOST_GET_CONST(float, dequant_op_node->Op()->GetAttr("max_range"));
+      weight_scale.push_back((range * range) / max_range / range);
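+      // Note: the extra "/ range" here folds in the "scale / 127" factor that
+      // TensorRTEngine::GetWeightCPUData used to apply (removed further down),
+      // since this pass now rescales the weight tensor in place.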
+    }
-
-      // create new op_desc
-      auto base_op_desc =
-          *nodes[i * kNumFields + kQuantizedOpOffset]->Op()->Proto();
-      std::string new_input = input_node->Name();
-      std::string new_output =
-          nodes[i * kNumFields + kDequantOpOutOffset]->Name();
-
-      framework::OpDesc new_op_desc(base_op_desc, nullptr);
-      new_op_desc.SetType(quantized_op_type);
-      new_op_desc.SetAttr("enable_int8", true);
-
-      if (quantized_op_type == "conv2d" ||
-          quantized_op_type == "conv2d_fusion" ||
-          quantized_op_type == "depthwise_conv2d") {
-        new_op_desc.SetInput("Input", {new_input});
-        new_op_desc.SetAttr("Input_scale", input_scale);
-        new_op_desc.SetOutput("Output", {new_output});
-      } else if (quantized_op_type == "fc") {
-        new_op_desc.SetInput("Input", {new_input});
-        new_op_desc.SetAttr("Input_scale", input_scale);
-        new_op_desc.SetOutput("Out", {new_output});
-      } else if (quantized_op_type == "mul") {
-        new_op_desc.SetInput("X", {new_input});
-        new_op_desc.SetAttr("X_scale", input_scale);
-        new_op_desc.SetOutput("Out", {new_output});
+    // Convert weight to fp32 range
+    auto* weight_tensor =
+        scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
+    auto w_dims = weight_tensor->dims();
+    // If quantized op is fc, weight scale size = 1;
+    // If quantized op is conv, weight scale size = weight dims[0]
+    bool valid_scale_size =
+        (weight_scale.size() == 1 ||
+         weight_scale.size() == static_cast<size_t>(w_dims[0]));
+    PADDLE_ENFORCE_EQ(valid_scale_size, true,
+                      platform::errors::InvalidArgument(
+                          "TRT int8 quant: invalid scale size"));
+    float* quantized_weight_data =
+        weight_tensor->mutable_data<float>(platform::CPUPlace());
+    for (int j = 0; j < weight_tensor->numel(); j++) {
+      if (weight_scale.size() == 1) {
+        quantized_weight_data[j] *= weight_scale[0];
+      } else {
+        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
+        quantized_weight_data[j] *= weight_scale[j / inner_size];
      }
+    }
-      new_op_desc.SetAttr("weight_scale", weight_scale);
-      new_op_desc.Flush();
-      auto* new_op = graph->CreateOpNode(&new_op_desc);
-      IR_NODE_LINK_TO(input_node, new_op);
-      IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset], new_op);
-      IR_NODE_LINK_TO(new_op, nodes[i * kNumFields + kDequantOpOutOffset]);
+    // create new op_desc
+    auto base_op_desc = *quantized_op_node->Op()->Proto();
+    std::string new_input = quantized_op_input_node->Name();
+    std::string new_output = dequant_op_out_node->Name();
-      delete_nodes.insert(nodes[i * kNumFields + kQuantizedOpOffset]);
-      delete_nodes.insert(nodes[i * kNumFields + kQuantizedOpOutOffset]);
-      delete_nodes.insert(nodes[i * kNumFields + kDequantOpOffset]);
+    framework::OpDesc new_op_desc(base_op_desc, nullptr);
+    new_op_desc.SetType(quantized_op_type);
+    new_op_desc.SetAttr("enable_int8", true);
+    if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" ||
+        quantized_op_type == "depthwise_conv2d") {
+      new_op_desc.SetInput("Input", {new_input});
+      new_op_desc.SetOutput("Output", {new_output});
+    } else if (quantized_op_type == "fc") {
+      new_op_desc.SetInput("Input", {new_input});
+      new_op_desc.SetOutput("Out", {new_output});
+    } else if (quantized_op_type == "mul") {
+      new_op_desc.SetInput("X", {new_input});
+      new_op_desc.SetOutput("Out", {new_output});
    }
-
-    delete_nodes.insert(quant_op_in_scale);
-    delete_nodes.insert(quant_op);
-    delete_nodes.insert(quant_op_out);
-    delete_nodes.insert(quant_op_out_scale);
-    // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph, delete_nodes);
+    new_op_desc.SetAttr("weight_scale", weight_scale);
+    new_op_desc.Flush();
+    auto* new_op = graph->CreateOpNode(&new_op_desc);
+    IR_NODE_LINK_TO(quantized_op_input_node, new_op);
+    IR_NODE_LINK_TO(quantized_op_weight_node, new_op);
+    IR_NODE_LINK_TO(new_op, dequant_op_out_node);
+    // Delete nodes and edges
+    nodes2rm.insert(quantized_op_node);
+    nodes2rm.insert(dequant_op_node);
+    GraphSafeRemoveNodes(graph, nodes2rm);
  };
  gpd(graph, handler);
}
@@ -186,19 +245,19 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
  FusePassBase::Init(pattern_name, graph);
  std::unordered_set<std::string> dequant_types = {
-      "fake_dequantize_max_abs", "fake_channel_wise_dequantize_max_abs"};
+      "fake_channel_wise_dequantize_max_abs", "fake_dequantize_max_abs"};
  std::unordered_set<std::string> quant_types = {
      "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
-  std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul",
-                                                        "depthwise_conv2d"};
+  std::unordered_set<std::string> quantized_op_types = {
+      "conv2d", "mul", "depthwise_conv2d", "fc"};
  auto* scope = param_scope();
+
+  for (auto& quant_type : quant_types) {
+    DeleteQuant(graph, scope, quant_type);
+  }
  for (auto& dequant_type : dequant_types) {
-    for (auto& quant_type : quant_types) {
-      for (auto& op_type : quantized_op_types) {
-        for (int i = 6; i >= 1; i--) {
-          RunQuantDequant(graph, scope, i, op_type, quant_type, dequant_type);
-        }
-      }
+    for (auto& quantized_op_type : quantized_op_types) {
+      FuseDequant(graph, scope, quantized_op_type, dequant_type);
    }
  }
}
diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
index a61b34563acc4cbcee778509a097587222579295..826278afc70039e5a4eed2a18b2c0a29061824d0 100644
--- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
+++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
@@ -22,6 +22,9 @@ namespace paddle {
 namespace framework {
 namespace ir {
+///
+/// Fuse quant + conv2d/depthwise_conv2d/mul/fc + dequant
+///
 class QuantDequantFusePass : public FusePassBase {
  public:
  virtual ~QuantDequantFusePass() {}
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 22c82de57c5d8e5b06ccf45c1a1a4dbdd3ae573a..e490d571a699e38d4762cb1d1771fb15639e8e13 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -365,6 +365,10 @@ const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
  return it->second;
}
+bool OpDesc::HasOutput(const std::string &name) const {
+  return outputs_.find(name) != outputs_.end();
+}
+
 std::vector<std::string> OpDesc::OutputArgumentNames() const {
  std::vector<std::string> retv;
  for (auto &ipt : this->outputs_) {
diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h
index 2183ff251e9a7811fae95a7823ff282e167491e1..e15f0012fdc2ebfecec8daebbd3c04b917cd7a84 100644
--- a/paddle/fluid/framework/op_desc.h
+++ b/paddle/fluid/framework/op_desc.h
@@ -57,6 +57,8 @@ class OpDesc {
  const std::vector<std::string> &Output(const std::string &name) const;
+  bool HasOutput(const std::string &name) const;
+
  std::vector<std::string> OutputArgumentNames() const;
  void SetOutput(const std::string &param_name,
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 23d964c798ebb982e3f4ffcd424cdda6d4c5d4e6..994f7c95352631b657edc3709f8f141cd68b3660 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -281,11 +281,8 @@ void 
AnalysisConfig::Update() { if (use_tensorrt_) { pass_builder()->ClearPasses(); - bool use_calib_int8 = - (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8) && - trt_use_calib_mode_; for (const auto &pass : kTRTSubgraphPasses) { - if (use_calib_int8 && + if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 && (pass == "conv_bn_fuse_pass" || pass == "fc_fuse_pass")) { continue; } diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 90773482332bfab87e1b13900308de3513fe437a..97d09925b19c4911a6b412518dc58fe88da16f64 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -52,7 +52,8 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, if (enable_int8) { #if IS_TRT_VERSION_GE(5000) CHECK(op_desc.HasAttr("Input_scale")); - float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); + float in_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; auto weight_scale = BOOST_GET_CONST(std::vector, op_desc.GetAttr("weight_scale")); weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index 511dd9d54fbcace63855d91d1a82318e7c2fde98..0dccd3cc6390af736aa4d205a12577fd9ee14f11 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -62,7 +62,7 @@ class FcOpConverter : public OpConverter { #if IS_TRT_VERSION_GE(5000) CHECK(op_desc.HasAttr(i_name + "_scale")); float in_scale = - BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")); + BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127; auto weight_scale = BOOST_GET_CONST(std::vector, op_desc.GetAttr("weight_scale")); weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index a299d845662c10a4ee29f119b806b367f4a6cd83..f4b0f5f23d8fda064c29534b56868beae79f65c0 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -98,8 +98,33 @@ class OpConverter { } PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", op_desc.Type()); + it->SetEngine(engine); (*it)(op, scope, test_mode); + + bool has_out_scale = op_desc.HasAttr("out_threshold"); + if (has_out_scale) { + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + std::string output_name = ""; + if (op_desc.HasOutput("Output")) { + output_name = op_desc.Output("Output").front(); + } else if (op_desc.HasOutput("Out")) { + output_name = op_desc.Output("Out").front(); + } else if (op_desc.HasOutput("Y")) { + output_name = op_desc.Output("Y").front(); + } else { + PADDLE_THROW( + platform::errors::NotFound("Op %s has out threshold but doesn't " + "have an output named \"Output\", " + "\"Out\" or \"Y\".", + op_desc.Type())); + } + auto* output_itensor = engine->GetITensor(output_name); + engine->SetTensorDynamicRange(output_itensor, out_scale); + VLOG(1) << "Set out scale = " << out_scale << " for tensor " + << output_name << "."; + } } // Convert a fluid block to tensorrt network, NOTE it just convert operators, diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 
b6c23d0d7b8eb899c31e3e6a6db04316788ba7e4..e1e1be683123966235c7e3b00fe894ff2c841c94 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -124,23 +124,42 @@ void TensorRTEngine::FreezeNetwork() { << ", this might be ok when trt does not need this range"; } } - std::unordered_set all_out_t_name; - for (int i = 0; i < network()->getNbOutputs(); i++) { - auto *temp = network()->getOutput(i); - temp->setDynamicRange(-1, 1); - all_out_t_name.insert(temp->getName()); - } - - for (int i = 0; i < network()->getNbLayers(); i++) { - auto layer = network()->getLayer(i); + auto is_layer_int8 = [&](nvinfer1::ILayer *layer) -> bool { + for (int j = 0; j < layer->getNbInputs(); j++) { + auto *temp_in = layer->getInput(j); + if (!temp_in->dynamicRangeIsSet()) { + VLOG(1) << "Layer(Name: " << layer->getName() + << ") is set to float32 because its input(" + << temp_in->getName() << ") doesn't have dynamic range."; + return false; + } + } for (int j = 0; j < layer->getNbOutputs(); j++) { auto *temp_out = layer->getOutput(j); - if (std::find(all_out_t_name.begin(), all_out_t_name.end(), - temp_out->getName()) != all_out_t_name.end()) { - layer->setPrecision(nvinfer1::DataType::kFLOAT); - layer->setOutputType(j, nvinfer1::DataType::kFLOAT); + if (temp_out->isNetworkOutput()) { + VLOG(1) << "Layer(Name: " << layer->getName() + << ") is set to float32 because its output(" + << temp_out->getName() << ") is the output of the network."; + return false; + } + if (!temp_out->dynamicRangeIsSet()) { + VLOG(1) << "Layer(Name: " << layer->getName() + << ") is set to float32 because its output(" + << temp_out->getName() << ") doesn't have dynamic range."; + return false; } } + return true; + }; + // If a layer's output is the network's output, or not all of its inputs + // and outputs have scales, + // this layer's precision and output type are set to float32. + // This step has no effect if this layer is fused during TRT optimization. 
+ for (int i = 0; i < network()->getNbLayers(); i++) { + auto layer = network()->getLayer(i); + if (!is_layer_int8(layer)) { + layer->setPrecision(nvinfer1::DataType::kFLOAT); + } } #endif } @@ -237,7 +256,6 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name, std::string name_suffix = std::to_string(name_suffix_counter); std::string splitter = "__"; std::string name_with_suffix = name + splitter + name_suffix; - auto w_dims = weight_tensor->dims(); platform::CPUPlace cpu_place; PADDLE_ENFORCE_EQ( weight_map.count(name_with_suffix), 0, @@ -250,25 +268,6 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name, float *weight_data = weight_map[name_with_suffix]->mutable_data(cpu_place); name_suffix_counter += 1; - - if (enable_int8) { - // when the op is fc, scale's size should be 1 - // when the op is conv, scale's size should be w_dims[0] - bool valid_scale_size = - (scale.size() == 1 || scale.size() == static_cast(w_dims[0])); - PADDLE_ENFORCE(valid_scale_size, "TRT int8 quant: invalid scale size"); - for (int i = 0; i < weight_tensor->numel(); i++) { - if (scale.size() == 1) { - weight_data[i] *= (scale[0] / 127); - } else { - PADDLE_ENFORCE(w_dims.size() == 4, - "TRT int8 quant : We only use the channel quant for " - "conv op, so the weight dims should be 4."); - int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; - weight_data[i] *= (scale[i / inner_size] / 127); - } - } - } return weight_data; } diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index db077ff306f6e0263398d6ac41afd8789aa8646a..a7bb7c8c4fceb191c11b52ae4ff5574e5e47abd2 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -43,11 +43,18 @@ struct SimpleOpTypeSetTeller : public Teller { private: // use this set for no calib int8. 
- std::unordered_set int8_teller_set{ - "mul", "conv2d", "pool2d", - "relu", "depthwise_conv2d", "softmax", - "batch_norm", "elementwise_add", "leaky_relu", - "fc"}; + std::unordered_set int8_teller_set{"mul", + "conv2d", + "pool2d", + "relu", + "depthwise_conv2d", + "softmax", + "batch_norm", + "elementwise_add", + "leaky_relu", + "fc", + "relu6", + "concat"}; std::unordered_set teller_set{ "mul", "conv2d", diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index c55955f514700981d1b34f688f57e99196ac7ea3..b2f42d39e885fe93352e323627bbd532ddfe773d 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -405,6 +405,14 @@ if(WITH_GPU AND TENSORRT_FOUND) EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR}) + set(TRT_MODEL_QUANT_YOLOV3_DIR "${INFERENCE_DEMO_INSTALL_DIR}/yolov3_r50_quant_aware") + if (NOT EXISTS ${TRT_MODEL_QUANT_YOLOV3_DIR}) + inference_download_and_uncompress(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "yolov3_r50_quant_aware.tgz") + endif() + inference_analysis_test(trt_quant_int8_yolov3_r50_test SRCS trt_quant_int8_yolov3_r50_test.cc + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} + ARGS --infer_model=${TRT_MODEL_QUANT_YOLOV3_DIR}) + set(TEST_TRT_DYNAMIC_MODEL2 "${TRT_MODEL_INSTALL_DIR}/complex_model_dynamic") if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL2}) inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL2} ${INFERENCE_URL}/tensorrt_test "complex_model_dynamic2.tar.gz") diff --git a/paddle/fluid/inference/tests/api/trt_quant_int8_yolov3_r50_test.cc b/paddle/fluid/inference/tests/api/trt_quant_int8_yolov3_r50_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4239c138aef2007aaf181c75a83e85ba288fef49 --- /dev/null +++ b/paddle/fluid/inference/tests/api/trt_quant_int8_yolov3_r50_test.cc @@ -0,0 +1,63 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include <numeric>
+
+#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
+
+namespace paddle {
+namespace inference {
+
+TEST(quant_int8, yolov3_resnet50) {
+  AnalysisConfig config;
+  config.EnableUseGpu(100, 0);
+  config.SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params");
+  config.SwitchUseFeedFetchOps(false);
+  config.EnableTensorRtEngine(1 << 30, 1, 3, AnalysisConfig::Precision::kInt8,
+                              false, false);
+
+  auto predictor = CreatePaddlePredictor(config);
+  auto input_names = predictor->GetInputNames();
+  int channels = 3;
+  int height = 608;
+  int width = 608;
+  int input_num = channels * height * width * 1;
+
+  float *input = new float[input_num];
+  int32_t *im_shape = new int32_t[2];
+  im_shape[0] = 608;
+  im_shape[1] = 608;
+  memset(input, 1.0, input_num * sizeof(float));
+  auto input_t = predictor->GetInputTensor(input_names[0]);
+  input_t->Reshape({1, channels, height, width});
+  input_t->copy_from_cpu(input);
+
+  auto input_t1 = predictor->GetInputTensor(input_names[1]);
+  input_t1->Reshape({1, 2});
+  input_t1->copy_from_cpu(im_shape);
+
+  ASSERT_TRUE(predictor->ZeroCopyRun());
+
+  std::vector<float> out_data;
+  auto output_names = predictor->GetOutputNames();
+  auto output_t = predictor->GetOutputTensor(output_names[0]);
+  std::vector<int> output_shape = output_t->shape();
+  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
+                                std::multiplies<int>());
+  out_data.resize(out_num);
+  output_t->copy_to_cpu(out_data.data());
+}
+
+}  // namespace inference
+}  // namespace paddle
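A minimal standalone sketch (not part of the patch, file name assumed) of the scale bookkeeping introduced above: DeleteQuant stores Input_scale/X_scale as in_scale / range, and the conv2d/fc TensorRT converters now multiply the attribute by 127, so the dynamic range handed to TensorRT still equals the activation scale read from the InScale tensor.

// scale_roundtrip_sketch.cc -- illustrative only, not part of the patch
#include <cassert>
#include <cmath>

int main() {
  const int bit_length = 8;
  const int range = (1 << (bit_length - 1)) - 1;  // 127 for 8-bit quantization

  float in_scale = 2.5f;                       // value read from the InScale tensor
  float attr_scale = in_scale / range;         // stored by DeleteQuant in "Input_scale"/"X_scale"
  float trt_dynamic_range = attr_scale * 127;  // recovered in conv2d_op.cc / fc_op.cc

  // The converter-side "* 127" undoes the pass-side "/ range".
  assert(std::fabs(trt_dynamic_range - in_scale) < 1e-6f);
  return 0;
}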