Unverified commit eb097d64, authored by joanna.wozna.intel, committed by GitHub

Fix int8 performance drop in cpu_quantize_placement_pass (#26715)

* Fix cpu quantize placement pass

* Include string lib
Parent 02083bda
paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1879,6 +1879,19 @@ PDNode *patterns::MultipleQuantize::operator()() {
   return prev_out;
 }
 
+PDNode *patterns::QuantizePlacement::operator()(
+    const std::unordered_set<std::string> &quantize_enabled_op_types) {
+  std::unordered_set<std::string> supported_op_types =
+      std::unordered_set<std::string>({"concat", "conv2d", "elementwise_add",
+                                       "fc", "matmul", "pool2d", "prior_box",
+                                       "relu", "reshape2", "transpose2"});
+  if (!quantize_enabled_op_types.empty()) {
+    supported_op_types = quantize_enabled_op_types;
+  }
+  auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
+  return op;
+}
+
 PDNode *patterns::MKLDNNInPlace::operator()() {
   const std::unordered_set<std::string> &supported_op_types = {
       "abs",
......
paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1120,6 +1120,15 @@ struct MultipleQuantize : public PatternBase {
   PATTERN_DECL_NODE(prev_out);
 };
 
+struct QuantizePlacement : public PatternBase {
+  QuantizePlacement(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "quantize_placement") {}
+  PDNode* operator()(
+      const std::unordered_set<std::string>& quantize_enabled_op_types);
+
+  PATTERN_DECL_NODE(op);
+};
+
 // Pattern used for enforcing inplace computation for in-place computation
 // supporting DNNL ops. softmax, batch_norm and layer_norm
 struct MKLDNNInPlace : public PatternBase {
......
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc
@@ -26,27 +26,33 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
       Get<std::unordered_set<int>>("quantize_excluded_op_ids");
   const auto& op_types_list =
       Get<std::unordered_set<std::string>>("quantize_enabled_op_types");
-  for (const Node* n : graph->Nodes()) {
-    if (n->IsOp()) {
-      if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(),
-                    n->id()) != excluded_ids_list.end())
-        continue;
-      auto* op = n->Op();
-      if (op->HasAttr("mkldnn_data_type") ||
-          op->HasProtoAttr("mkldnn_data_type")) {
-        // use_quantizer is no longer used
-        // assign value for compatibility
-        if (op->GetAttrIfExists<bool>("use_quantizer")) {
-          op->SetAttr("mkldnn_data_type", std::string("int8"));
-        }
-        if (std::find(op_types_list.begin(), op_types_list.end(), op->Type()) !=
-            op_types_list.end()) {
-          op->SetAttr("mkldnn_data_type", std::string("int8"));
-          op->SetAttr("use_quantizer", true);
-        }
-      }
-    }
-  }
+  Init(name_scope_, graph);
+  GraphPatternDetector gpd;
+  patterns::QuantizePlacement quantize_placement_pattern{gpd.mutable_pattern(),
+                                                         "quantize_placement"};
+  quantize_placement_pattern(op_types_list);
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(op, op, quantize_placement_pattern);
+
+    if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(),
+                  op->id()) != excluded_ids_list.end()) {
+      return;
+    }
+
+    if (op->Op()->HasAttr("mkldnn_data_type") ||
+        op->Op()->HasProtoAttr("mkldnn_data_type")) {
+      // use_quantizer is no longer used
+      // assign value for compatibility
+      if (op->Op()->GetAttrIfExists<bool>("use_quantizer")) {
+        op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
+      }
+      op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
+      op->Op()->SetAttr("use_quantizer", true);
+    }
+  };
+  gpd(graph, handler);
 }
 
 } // namespace ir
@@ -58,10 +64,7 @@ REGISTER_PASS(cpu_quantize_placement_pass,
     // a vector of operator type names to be quantized ("conv2d" etc.)
     // the second param is the default value for this vector
     .DefaultPassAttr("quantize_enabled_op_types",
-                     new std::unordered_set<std::string>(
-                         {"concat", "conv2d", "elementwise_add", "fc", "matmul",
-                          "pool2d", "prior_box", "relu", "reshape2",
-                          "transpose2"}))
+                     new std::unordered_set<std::string>())
     // a vector of operator ids that are to be excluded from quantization
     // the second param is the default value for this vector
     .DefaultPassAttr("quantize_excluded_op_ids", new std::unordered_set<int>());
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h
@@ -15,7 +15,10 @@ limitations under the License. */
 #pragma once
 
 #include <memory>
+#include <string>
 
-#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 
 namespace paddle {
 namespace framework {
@@ -23,9 +26,10 @@ namespace ir {
 /*
  * Specifies which operators should be quantized.
  */
-class CPUQuantizePlacementPass : public Pass {
+class CPUQuantizePlacementPass : public FusePassBase {
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
+  const std::string name_scope_{"cpu_quantize_placement_pass"};
 };
 
 } // namespace ir
......
paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc
@@ -131,8 +131,8 @@ TEST(QuantizerPlacementPass, enabled_conv_excluded_one) {
 }
 
 TEST(QuantizerPlacementPass, empty_list) {
-  // no operator quantized
-  MainTest({}, {}, 0);
+  // all operators quantized
+  MainTest({}, {}, 6);
 }
 
 TEST(QuantizerPlacementPass, default_attr_value) {
......
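
The changed expectation reflects the new semantics: with an empty quantize_enabled_op_types list the pass now falls back to the full supported-op set, so all six quantizable ops in the test model get marked instead of none. Below is a hedged sketch of what MainTest presumably does; its real definition lives in the tester file, and BuildProgramDesc plus the counting loop are assumptions for illustration.

void MainTest(std::initializer_list<std::string> quantize_enabled_op_types,
              std::initializer_list<int> quantize_excluded_op_ids,
              unsigned expected_int8_op_count) {
  // Assumed helper: builds a small program containing the supported op types.
  auto prog = BuildProgramDesc();
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  auto pass = PassRegistry::Instance().Get("cpu_quantize_placement_pass");
  pass->Set("quantize_enabled_op_types",
            new std::unordered_set<std::string>(quantize_enabled_op_types));
  pass->Set("quantize_excluded_op_ids",
            new std::unordered_set<int>(quantize_excluded_op_ids));
  graph.reset(pass->Apply(graph.release()));

  // Count the ops the pass marked for int8 and compare with the expectation.
  unsigned int8_op_count = 0;
  for (auto* node : graph->Nodes()) {
    if (node->IsOp() &&
        node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8")
      ++int8_op_count;
  }
  EXPECT_EQ(int8_op_count, expected_int8_op_count);
}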