diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index ff6dffa704eeceeabfc5eb1d6786f40b2e523e98..3d65fe595373fa98ba237f04134c75d4a60a7242 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1879,6 +1879,19 @@ PDNode *patterns::MultipleQuantize::operator()() { return prev_out; } +PDNode *patterns::QuantizePlacement::operator()( + const std::unordered_set &quantize_enabled_op_types) { + std::unordered_set supported_op_types = + std::unordered_set({"concat", "conv2d", "elementwise_add", + "fc", "matmul", "pool2d", "prior_box", + "relu", "reshape2", "transpose2"}); + if (!quantize_enabled_op_types.empty()) { + supported_op_types = quantize_enabled_op_types; + } + auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types); + return op; +} + PDNode *patterns::MKLDNNInPlace::operator()() { const std::unordered_set &supported_op_types = { "abs", diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index e1cce7848dd54b02a540b144ca1088f62eeb52cb..0803265884165bc754489b18d07c0d277a4bd92b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1120,6 +1120,15 @@ struct MultipleQuantize : public PatternBase { PATTERN_DECL_NODE(prev_out); }; +struct QuantizePlacement : public PatternBase { + QuantizePlacement(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "quantize_placement") {} + PDNode* operator()( + const std::unordered_set& quantize_enabled_op_types); + + PATTERN_DECL_NODE(op); +}; + // Pattern used for enforcing inplace computation for in-place computation // supporting DNNL ops. softmax, batch_norm and layer_norm struct MKLDNNInPlace : public PatternBase { diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc index 0644cf9bb6575462d2d8362713a4720d2684bf8d..bc268a834780cad843a18a74bb7f50a639db103d 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc @@ -26,27 +26,33 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { Get>("quantize_excluded_op_ids"); const auto& op_types_list = Get>("quantize_enabled_op_types"); - for (const Node* n : graph->Nodes()) { - if (n->IsOp()) { - if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(), - n->id()) != excluded_ids_list.end()) - continue; - auto* op = n->Op(); - if (op->HasAttr("mkldnn_data_type") || - op->HasProtoAttr("mkldnn_data_type")) { - // use_quantizer is no longer used - // assign value for compatibility - if (op->GetAttrIfExists("use_quantizer")) { - op->SetAttr("mkldnn_data_type", std::string("int8")); - } - if (std::find(op_types_list.begin(), op_types_list.end(), op->Type()) != - op_types_list.end()) { - op->SetAttr("mkldnn_data_type", std::string("int8")); - op->SetAttr("use_quantizer", true); - } + Init(name_scope_, graph); + GraphPatternDetector gpd; + patterns::QuantizePlacement quantize_placement_pattern{gpd.mutable_pattern(), + "quantize_placement"}; + quantize_placement_pattern(op_types_list); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(op, op, quantize_placement_pattern); + + if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(), + op->id()) != excluded_ids_list.end()) { + return; + } + + if (op->Op()->HasAttr("mkldnn_data_type") || + op->Op()->HasProtoAttr("mkldnn_data_type")) { + // use_quantizer is no longer used + // assign value for compatibility + if (op->Op()->GetAttrIfExists("use_quantizer")) { + op->Op()->SetAttr("mkldnn_data_type", std::string("int8")); } + op->Op()->SetAttr("mkldnn_data_type", std::string("int8")); + op->Op()->SetAttr("use_quantizer", true); } - } + }; + gpd(graph, handler); } } // namespace ir @@ -58,10 +64,7 @@ REGISTER_PASS(cpu_quantize_placement_pass, // a vector of operator type names to be quantized ("conv2d" etc.) // the second param is the default value for this vector .DefaultPassAttr("quantize_enabled_op_types", - new std::unordered_set( - {"concat", "conv2d", "elementwise_add", "fc", "matmul", - "pool2d", "prior_box", "relu", "reshape2", - "transpose2"})) + new std::unordered_set()) // a vector of operator ids that are to be excluded from quantization // the second param is the default value for this vector .DefaultPassAttr("quantize_excluded_op_ids", new std::unordered_set()); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h index 008a462dc414c04f53315a8f262de15ab8fb7fb5..f3229e59d6ffb97514adb9c871d4fb981fc964e0 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h @@ -15,7 +15,10 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/framework/ir/pass.h" +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { @@ -23,9 +26,10 @@ namespace ir { /* * Specifies which operators should be quantized. */ -class CPUQuantizePlacementPass : public Pass { +class CPUQuantizePlacementPass : public FusePassBase { protected: void ApplyImpl(ir::Graph* graph) const override; + const std::string name_scope_{"cpu_quantize_placement_pass"}; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc index 6977a9495853f9aa9a0680cafc51a170b848bb37..761defc25ff5c89b740ccd5adff7d613beccd9d4 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc @@ -131,8 +131,8 @@ TEST(QuantizerPlacementPass, enabled_conv_excluded_one) { } TEST(QuantizerPlacementPass, empty_list) { - // no operator quantized - MainTest({}, {}, 0); + // all operators quantized + MainTest({}, {}, 6); } TEST(QuantizerPlacementPass, default_attr_value) {