diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
index 4ce91999207a2b1a8ad2a3ab594aa74f9aece8e3..b9cc337df87929a9a1b26314a077d1c988c9d068 100644
--- a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
+++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
@@ -92,7 +92,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
   int range = ((1 << (bit_length - 1)) - 1);
   std::vector<float> weight_scale;
   std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();
-
   auto* any_op2_desc = any_op2->Op();
   auto var_map = any_op2_desc->Inputs();
   std::string arg_name = "";
@@ -106,43 +105,52 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_GT(arg_name.size(), 0,
                     platform::errors::InvalidArgument(
                         "can not find the input %s.", quant_dequant_op_out_name));
-  any_op2_desc->SetAttr("enable_int8", true);
+  // any_op2_desc->SetAttr("enable_int8", true);
   any_op2_desc->SetAttr("bit_length", bit_length);
+  // modify the any_op2's inputs
-  any_op2_desc->Flush();
   auto dequant_type = quant_dequant_op->Op()->Type();
-  auto quantized_op_type = any_op2_desc->Type();
+  // get weight tensor
   auto* weight_tensor =
       scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
   auto w_dims = weight_tensor->dims();
+  float* quantized_weight_data =
+      weight_tensor->mutable_data<float>(platform::CPUPlace());
   // Get weight scale
   if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
-    auto scales_name = quant_dequant_op->Op()->Output("OutScale");
+    int quant_axis =
+        BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
+    PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
+                      platform::errors::InvalidArgument(
+                          "'quant_axis' should be 0 or 1, but "
+                          "the received is %d",
+                          quant_axis));
+
+    // To Do @Wangzheee: use "OutScale" to quantdequant
+    /*auto scales_name = quant_dequant_op->Op()->Output("OutScale");
     PADDLE_ENFORCE_EQ(scales_name.size(), 1,
                       platform::errors::InvalidArgument(
                           "Scales size in channel-wise quant dequantize op "
                           "should be 1, got %d.",
                           scales_name.size()));
     const LoDTensor& channel_scale_tensor =
-        scope->GetVar(scales_name[0])->Get<LoDTensor>();
+        scope->FindVar(scales_name[0])->Get<LoDTensor>();
     PADDLE_ENFORCE(
         paddle::platform::is_cpu_place(channel_scale_tensor.place()),
         platform::errors::InvalidArgument(
             "Channel scale tensor's place should be CPU."));
     // compute the channel wise abs max of the weight tensor
-    int quant_axis =
-        BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
-    PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
-                      platform::errors::InvalidArgument(
-                          "'quant_axis' should be 0 or 1, but "
-                          "the received is %d",
-                          quant_axis));
+    const float* channel_scale_data = channel_scale_tensor.data<float>();
+    for (int i = 0; i < channel_scale_tensor.numel(); i++) {
+      weight_scale.push_back(channel_scale_data[i]);
+    }*/
+
+    // Implement channel_wise_quantize_dequantize_abs_max quantization
+    // algorithm
     const int64_t channel = w_dims[quant_axis];
     weight_scale.resize(channel, 0);
     if (quant_axis == 0) {
@@ -171,11 +179,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
       PADDLE_ENFORCE_NE(weight_scale[i], 0,
                         platform::errors::InvalidArgument(
                             "Weight scale should be nonzero, but get zero."));
-      weight_scale[i] = range / weight_scale[i];
+      weight_scale[i] = weight_scale[i] / range;
     }
   } else {
-    auto scale_name = quant_dequant_op_outscale->Name();
-    // compute the abs max of the weight tensor
+    // Implement quantize_dequantize_abs_max quantization algorithm
    float abs_max_weight = 0.;
    for (int j = 0; j < weight_tensor->numel(); j++) {
      abs_max_weight =
@@ -184,113 +191,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
     PADDLE_ENFORCE_NE(abs_max_weight, 0,
                       platform::errors::InvalidArgument(
                           "Weight scale should be nonzero, but get zero"));
-    weight_scale.push_back((range * range) / abs_max_weight / range);
+    weight_scale.push_back(abs_max_weight / range);
   }
 
   nodes2rm.insert(quant_dequant_op_outscale);
-
-  // perform quantize dequantize operations
-  // If quantized op is not channel wise, weight scale size = 1;
-  // If quantized op is conv2d, weight scale size = weight dims[0]
-  // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
-  if (dequant_type == "fake_quantize_dequantize_abs_max") {
-    PADDLE_ENFORCE_EQ(
-        weight_scale.size(), 1,
-        platform::errors::InvalidArgument(
-            "%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
-            "requires weight scale size = 1, but got %d.",
-            quantized_op_type, weight_scale.size()));
-    for (int j = 0; j < weight_tensor->numel(); j++) {
-      // quantized
-      quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
-      quantized_weight_data[j] = std::round(quantized_weight_data[j]);
-      // dequantized
-      quantized_weight_data[j] /= weight_scale[0];
-    }
-  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-             quantized_op_type == "fc") {
-    if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
-      PADDLE_ENFORCE_EQ(
-          weight_scale.size(), static_cast<size_t>(w_dims[1]),
-          platform::errors::InvalidArgument(
-              "mul op weight dequantized by "
-              "[fake_channel_wise_quantize_dequantize_abs_max] requires "
-              "weight scale "
-              "size = 2nd dim of mul's weight, which is %zu, but got %zu.",
-              static_cast<size_t>(w_dims[1]), weight_scale.size()));
-      for (int j = 0; j < weight_tensor->numel(); j++) {
-        // quantized
-        PADDLE_ENFORCE_NE(
-            weight_scale[j % w_dims[1]], 0,
-            platform::errors::InvalidArgument(
-                "fc op weight scale should be nonzero, but get zero"));
-        quantized_weight_data[j] =
-            quantized_weight_data[j] * weight_scale[j % w_dims[1]];
-        quantized_weight_data[j] = std::round(quantized_weight_data[j]);
-        // dequantized
-        quantized_weight_data[j] /= weight_scale[j % w_dims[1]];
-      }
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Unsupported quantized op type: %s", quantized_op_type));
-    }
-  } else if (quantized_op_type == "conv2d" ||
-             quantized_op_type == "depthwise_conv2d") {
-    if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
-      PADDLE_ENFORCE_EQ(
-          weight_scale.size(), static_cast<size_t>(w_dims[0]),
-          platform::errors::InvalidArgument(
-              "conv2d op requires weight scale size = channel size of the "
-              "weight, which is %zu, but got %zu.",
-              static_cast<size_t>(w_dims[0]), weight_scale.size()));
-      int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
-      for (int j = 0; j < weight_tensor->numel(); j++) {
-        // quantized
-        PADDLE_ENFORCE_NE(
-            weight_scale[j / inner_size], 0,
-            platform::errors::InvalidArgument(
-                "conv2d op weight scale should be nonzero, but get zero"));
-        quantized_weight_data[j] =
-            quantized_weight_data[j] * weight_scale[j / inner_size];
-        quantized_weight_data[j] = std::round(quantized_weight_data[j]);
-        // dequantized
-        quantized_weight_data[j] /= weight_scale[j / inner_size];
-      }
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Unsupported quantized op type: %s", quantized_op_type));
-    }
-  } else if (quantized_op_type == "conv2d_transpose") {
-    if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
-      PADDLE_ENFORCE_EQ(
-          weight_scale.size(), static_cast<size_t>(w_dims[0]),
-          platform::errors::InvalidArgument(
-              "conv2d_transpose op requires weight scale size = channel size "
-              "of the "
-              "weight, which is %zu, but got %zu.",
-              static_cast<size_t>(w_dims[1]), weight_scale.size()));
-      int inner_size = w_dims[2] * w_dims[3];
-      for (int j = 0; j < weight_tensor->numel(); j++) {
-        // quantized
-        PADDLE_ENFORCE_NE(weight_scale[(j / inner_size) % w_dims[1]], 0,
-                          platform::errors::InvalidArgument(
-                              "conv2d_transpose op weight scale should be "
-                              "nonzero, but get zero"));
-        quantized_weight_data[j] = quantized_weight_data[j] *
-                                   weight_scale[(j / inner_size) % w_dims[1]];
-        quantized_weight_data[j] = std::round(quantized_weight_data[j]);
-        // dequantized
-        quantized_weight_data[j] /=
-            weight_scale[(j / inner_size) % w_dims[1]];
-      }
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Unsupported quantized op type: %s", quantized_op_type));
-    }
-  } else {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "Unsupported quantized op type: %s", quantized_op_type));
-  }
   nodes2rm.insert(quant_dequant_op_out);
 
   // link weight in quant_dequant_op_x to any_op2
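Note on the filter pass above: the channel-wise branch now computes the per-channel abs-max of the weight inline (the commented-out block is the older path that read the `OutScale` variable), and it stores `absmax / range` instead of the old `range / absmax` multiplier. Below is a standalone sketch of that computation, not the pass's literal code; `ChannelWiseAbsMaxScales` and its signature are illustrative, assuming a conv-style weight flattened row-major over `[N, C, H, W]` with the `quant_axis` values (0 or 1) enforced in the diff:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical helper (not a Paddle API): per-channel abs-max scales for a
// weight tensor flattened row-major over dims = [N, C, H, W].
std::vector<float> ChannelWiseAbsMaxScales(const std::vector<float>& w,
                                           const std::vector<int64_t>& dims,
                                           int quant_axis, int bit_length) {
  const int range = (1 << (bit_length - 1)) - 1;  // 127 for 8 bits
  const int64_t channel = dims[quant_axis];       // mirrors w_dims[quant_axis]
  // quant_axis == 0: one contiguous block of `inner` elements per output
  // channel; quant_axis == 1: blocks of H * W elements repeating every C blocks.
  const int64_t inner = (quant_axis == 0)
                            ? static_cast<int64_t>(w.size()) / channel
                            : dims[2] * dims[3];
  std::vector<float> scale(channel, 0.f);
  for (int64_t j = 0; j < static_cast<int64_t>(w.size()); ++j) {
    const int64_t c = (quant_axis == 0) ? j / inner : (j / inner) % channel;
    scale[c] = std::max(scale[c], std::abs(w[j]));
  }
  for (auto& s : scale) s /= range;  // store absmax / range, as the new pass does
  return scale;
}
```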
diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
index 65e8b8fc80d104d34efc5ff863c0851709ef2abf..b99f2266f39b248ed705fa284ebf2e9a31006088 100644
--- a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
+++ b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
@@ -28,76 +28,85 @@ namespace ir {
 #define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
 #define GET_NODES                         \
-  GET_IR_NODE(any_op_out);                \
   GET_IR_NODE(quant_dequant_op_inscale);  \
   GET_IR_NODE(quant_dequant_op);          \
   GET_IR_NODE(quant_dequant_op_outscale); \
-  GET_IR_NODE(quant_dequant_op_out);      \
-  GET_IR_NODE(any_op2);
+  GET_IR_NODE(quant_dequant_op_out);
 
 void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "delete_quantdequant_op_pattern";
   FusePassBase::Init(pattern_name, graph);
-
   GraphPatternDetector gpd;
+
+  std::string quantdequant_types =
+      "fake_quantize_dequantize_moving_average_abs_max";
+
+  auto* input_node = gpd.mutable_pattern()
+                         ->NewNode("input_node")
+                         ->assert_is_op_input(quantdequant_types, "X")
+                         ->AsInput();
+
   patterns::DeleteQuantDequantOpPattern pattern(gpd.mutable_pattern(),
                                                 pattern_name);
-  pattern();
+  pattern(input_node, quantdequant_types);
   auto* scope = param_scope();
+  int found_count = 0;
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    PADDLE_ENFORCE_EQ(
+        subgraph.count(input_node), true,
+        platform::errors::NotFound(
+            "Input act node(%s) not found in QuantDequantFuse pass.",
+            input_node->name()));
+    Node* input = subgraph.at(input_node);
     GET_NODES;
-    IR_NODE_LINK_TO(any_op_out, any_op2);
-    std::string any_op_out_name = any_op_out->Var()->Name();
-    std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();
+    int bit_length =
+        BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
+    int range = ((1 << (bit_length - 1)) - 1);
 
+    // Get input scale from tensor
     std::string input_scale_var_name =
         quant_dequant_op->Op()->Input("InScale").front();
+    PADDLE_ENFORCE_NOT_NULL(
+        scope, platform::errors::InvalidArgument(
+                   "Scope in DeleteQuantDequantOpPass should not be null."));
     const LoDTensor& input_scale_tensor =
-        scope->GetVar(input_scale_var_name)->Get<LoDTensor>();
-
+        scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
+    PADDLE_ENFORCE_EQ(
+        paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
+        platform::errors::InvalidArgument(
+            "Input scale tensor's place should be CPU."));
     const float* input_scale_data = input_scale_tensor.data<float>();
-    float input_scale = input_scale_data[0] / 127.;
-    auto* any_op2_desc = any_op2->Op();
-    // auto input_args_names = any_op2_desc->InputArgumentNames();
-    auto var_map = any_op2_desc->Inputs();
-    std::string arg_name = "";
-    for (auto& name_m : var_map) {
-      if (std::find(name_m.second.begin(), name_m.second.end(),
-                    quant_dequant_op_out_name) != name_m.second.end()) {
-        arg_name = name_m.first;
-      }
-    }
-    CHECK(arg_name.size() > 0) << "can not find the input "
-                               << quant_dequant_op_out_name;
-    any_op2_desc->SetAttr("enable_int8", true);
-    any_op2_desc->SetAttr(arg_name + "_scale", input_scale);
+    float input_scale = input_scale_data[0] / range;
 
-    // modify the any_op2's inputs
-    for (auto& name_m : var_map) {
-      if (std::find(name_m.second.begin(), name_m.second.end(),
-                    quant_dequant_op_out_name) != name_m.second.end()) {
-        std::vector<std::string> new_inputs;
-        for (auto& i_n : name_m.second) {
-          if (i_n != quant_dequant_op_out_name) {
-            new_inputs.push_back(i_n);
-          }
-        }
-        new_inputs.push_back(any_op_out_name);
-        any_op2_desc->SetInput(name_m.first, new_inputs);
-        any_op2_desc->Flush();
+    // Set input scale in attr, and relink nodes
+    std::string input_name = input->Var()->Name();
+    std::string quant_dequant_output_name = quant_dequant_op_out->Var()->Name();
+    auto outlinks = quant_dequant_op_out->outputs;
+    for (auto* quantized_node : outlinks) {
+      auto op_desc = quantized_node->Op();
+      std::string quantized_op_type = op_desc->Type();
+      if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+          quantized_op_type == "matmul_v2") {
+        op_desc->SetAttr("X_scale", input_scale);
+      } else {
+        op_desc->SetAttr("Input_scale", input_scale);
       }
+      op_desc->SetAttr("bit_length", bit_length);
+      op_desc->RenameInput(quant_dequant_output_name, input_name);
+      op_desc->Flush();
+      IR_NODE_LINK_TO(input, quantized_node);
     }
-    any_op2_desc->Flush();
+
     // Delete the unneeded nodes.
     GraphSafeRemoveNodes(graph,
-                         {quant_dequant_op, quant_dequant_op_out,
-                          quant_dequant_op_inscale, quant_dequant_op_outscale});
+                         {quant_dequant_op_inscale, quant_dequant_op,
+                          quant_dequant_op_outscale, quant_dequant_op_out});
+    found_count++;
   };
 
-
   gpd(graph, handler);
+  AddStatis(found_count);
 }
 
 }  // namespace ir
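Note on the activation pass above: the rewritten handler derives `range` from the op's `bit_length` attribute instead of hard-coding `127.`, then writes the resulting scale onto each consumer op as `X_scale` (for `mul`/`matmul`/`matmul_v2`) or `Input_scale` (everything else). A minimal arithmetic sketch, assuming 8-bit quantization and an illustrative `InScale` value:

```cpp
#include <iostream>

int main() {
  const int bit_length = 8;                       // op attribute "bit_length"
  const int range = (1 << (bit_length - 1)) - 1;  // 127; previously hard-coded
  const float in_scale = 12.7f;                   // illustrative InScale[0] value
  const float input_scale = in_scale / range;     // 0.1f, written to the consumer
  std::cout << input_scale << std::endl;          // op as X_scale / Input_scale
  return 0;
}
```

Using the attribute rather than the literal `127.` keeps the pass correct for any bit width the quant-dequant op was configured with.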
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index e1b77a59911fbe06e0d829e36715be89ef1f656c..4150d0ca555c9d2ddc706ef3d17ff05bde02c360 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2547,39 +2547,28 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
   reshape2_out->LinksFrom({reshape2_op});
 }
 
-void patterns::DeleteQuantDequantOpPattern::operator()() {
-  auto any_op_out =
-      pattern->NewNode(any_op_out_repr())
-          ->assert_is_op_input(
-              "fake_quantize_dequantize_moving_average_abs_max", "X")
-          ->AsInput();
-
+void patterns::DeleteQuantDequantOpPattern::operator()(
+    PDNode *input_node, const std::string &quantdequant_types) {
   auto quant_dequant_op_inscale =
       pattern->NewNode(quant_dequant_op_inscale_repr())
-          ->assert_is_op_input(
-              "fake_quantize_dequantize_moving_average_abs_max", "InScale")
+          ->assert_is_op_input(quantdequant_types, "InScale")
          ->AsInput();
-  auto quant_dequant_op =
-      pattern->NewNode(quant_dequant_op_repr())
-          ->assert_is_op("fake_quantize_dequantize_moving_average_abs_max");
+  auto quant_dequant_op = pattern->NewNode(quant_dequant_op_repr())
+                              ->assert_is_op(quantdequant_types);
 
-  auto quant_dequant_out =
+  auto quant_dequant_op_out =
       pattern->NewNode(quant_dequant_op_out_repr())
-          ->assert_is_op_output(
-              "fake_quantize_dequantize_moving_average_abs_max", "Out")
-          ->AsIntermediate();
+          ->assert_is_op_output(quantdequant_types, "Out")
+          ->AsOutput();
 
   auto quant_dequant_op_outscale =
       pattern->NewNode(quant_dequant_op_outscale_repr())
-          ->assert_is_op_output(
-              "fake_quantize_dequantize_moving_average_abs_max", "OutScale")
+          ->assert_is_op_output(quantdequant_types, "OutScale")
          ->AsOutput();
-  auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
 
-  quant_dequant_op->LinksFrom({any_op_out, quant_dequant_op_inscale});
+  quant_dequant_op->LinksFrom({quant_dequant_op_inscale, input_node});
   quant_dequant_op_outscale->LinksFrom({quant_dequant_op});
-  quant_dequant_out->LinksFrom({quant_dequant_op});
-  any_op2->LinksFrom({quant_dequant_out});
+  quant_dequant_op_out->LinksFrom({quant_dequant_op});
 }
 
 void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
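Note on the pattern change above: the quant-dequant output is now `AsOutput()` rather than `AsIntermediate()`, and the trailing `any_op2` node is gone, because the pass itself relinks every consumer of the output. Removing the matched node is semantics-preserving at inference time because a fake quant-dequant op only simulates quantization: its output is the input snapped to an integer grid, and the scale it used is preserved on the consumer op. A sketch of that round trip, assuming the usual abs-max formulation (this is not Paddle's kernel source, and it assumes `scale > 0`):

```cpp
#include <algorithm>
#include <cmath>

// Assumed inference-time semantics of a fake quantize-dequantize op:
// clip, quantize to the [-range, range] integer grid, dequantize back.
float FakeQuantDequant(float x, float scale, int bit_length) {
  const int range = (1 << (bit_length - 1)) - 1;
  x = std::min(std::max(x, -scale), scale);        // clip to [-scale, scale]
  const float q = std::round(x * range / scale);   // quantize to integer grid
  return q * scale / range;                        // dequantize back to float
}
```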
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 3cfaa4661ae68e0359245a841aa40caf00329aff..40c3e4f59bf262ea260a3e9a784d9bc73696ed80 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1481,14 +1481,12 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
   DeleteQuantDequantOpPattern(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "delete_quantdequant_op_pattern") {}
 
-  void operator()();
+  void operator()(PDNode* input_node, const std::string& quantdequant_types);
 
-  PATTERN_DECL_NODE(any_op_out);
   PATTERN_DECL_NODE(quant_dequant_op_inscale);
   PATTERN_DECL_NODE(quant_dequant_op);
   PATTERN_DECL_NODE(quant_dequant_op_outscale);
   PATTERN_DECL_NODE(quant_dequant_op_out);
-  PATTERN_DECL_NODE(any_op2);
 };
 
 struct DeleteQuantDequantFilterOpPattern : public PatternBase {
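Final note on the scale convention used across these passes: the old filter pass stored the quantize multiplier `range / absmax` (and, in the per-tensor branch, `(range * range) / absmax / range`, which algebraically reduces to the same thing), while the new code stores the dequantize step `absmax / range`. The two are reciprocals; a numeric sketch with illustrative values:

```cpp
#include <iostream>

int main() {
  const float abs_max = 2.54f;  // illustrative per-tensor weight abs-max
  const int range = 127;        // (1 << 7) - 1 for bit_length = 8
  const float old_scale = range / abs_max;  // 50.0: multiplier, old convention
  const float new_scale = abs_max / range;  // 0.02: step size, new convention
  std::cout << old_scale * new_scale << std::endl;  // 1: reciprocal conventions
  return 0;
}
```

Any downstream code that consumed the old `weight_scale` values would need to account for this flip, which is consistent with the pass no longer performing the weight quantize-dequantize round trip itself (the large block deleted in the first file).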