未验证 提交 eb097d64 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Fix int8 performace drop cpu_quantize_placement_pass (#26715)

* Fix cpu quantize placement pass

* Include string lib
上级 02083bda
......@@ -1879,6 +1879,19 @@ PDNode *patterns::MultipleQuantize::operator()() {
return prev_out;
}
PDNode *patterns::QuantizePlacement::operator()(
const std::unordered_set<std::string> &quantize_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>({"concat", "conv2d", "elementwise_add",
"fc", "matmul", "pool2d", "prior_box",
"relu", "reshape2", "transpose2"});
if (!quantize_enabled_op_types.empty()) {
supported_op_types = quantize_enabled_op_types;
}
auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
return op;
}
PDNode *patterns::MKLDNNInPlace::operator()() {
const std::unordered_set<std::string> &supported_op_types = {
"abs",
......
......@@ -1120,6 +1120,15 @@ struct MultipleQuantize : public PatternBase {
PATTERN_DECL_NODE(prev_out);
};
struct QuantizePlacement : public PatternBase {
QuantizePlacement(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quantize_placement") {}
PDNode* operator()(
const std::unordered_set<std::string>& quantize_enabled_op_types);
PATTERN_DECL_NODE(op);
};
// Pattern used for enforcing inplace computation for in-place computation
// supporting DNNL ops. softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase {
......
......@@ -26,27 +26,33 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
Get<std::unordered_set<int>>("quantize_excluded_op_ids");
const auto& op_types_list =
Get<std::unordered_set<std::string>>("quantize_enabled_op_types");
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::QuantizePlacement quantize_placement_pattern{gpd.mutable_pattern(),
"quantize_placement"};
quantize_placement_pattern(op_types_list);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, quantize_placement_pattern);
if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(),
n->id()) != excluded_ids_list.end())
continue;
auto* op = n->Op();
if (op->HasAttr("mkldnn_data_type") ||
op->HasProtoAttr("mkldnn_data_type")) {
op->id()) != excluded_ids_list.end()) {
return;
}
if (op->Op()->HasAttr("mkldnn_data_type") ||
op->Op()->HasProtoAttr("mkldnn_data_type")) {
// use_quantizer is no longer used
// assign value for compatibility
if (op->GetAttrIfExists<bool>("use_quantizer")) {
op->SetAttr("mkldnn_data_type", std::string("int8"));
}
if (std::find(op_types_list.begin(), op_types_list.end(), op->Type()) !=
op_types_list.end()) {
op->SetAttr("mkldnn_data_type", std::string("int8"));
op->SetAttr("use_quantizer", true);
}
}
if (op->Op()->GetAttrIfExists<bool>("use_quantizer")) {
op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
}
op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
op->Op()->SetAttr("use_quantizer", true);
}
};
gpd(graph, handler);
}
} // namespace ir
......@@ -58,10 +64,7 @@ REGISTER_PASS(cpu_quantize_placement_pass,
// a vector of operator type names to be quantized ("conv2d" etc.)
// the second param is the default value for this vector
.DefaultPassAttr("quantize_enabled_op_types",
new std::unordered_set<std::string>(
{"concat", "conv2d", "elementwise_add", "fc", "matmul",
"pool2d", "prior_box", "relu", "reshape2",
"transpose2"}))
new std::unordered_set<std::string>())
// a vector of operator ids that are to be excluded from quantization
// the second param is the default value for this vector
.DefaultPassAttr("quantize_excluded_op_ids", new std::unordered_set<int>());
......@@ -15,7 +15,10 @@ limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
......@@ -23,9 +26,10 @@ namespace ir {
/*
* Specifies which operators should be quantized.
*/
class CPUQuantizePlacementPass : public Pass {
class CPUQuantizePlacementPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"cpu_quantize_placement_pass"};
};
} // namespace ir
......
......@@ -131,8 +131,8 @@ TEST(QuantizerPlacementPass, enabled_conv_excluded_one) {
}
TEST(QuantizerPlacementPass, empty_list) {
// no operator quantized
MainTest({}, {}, 0);
// all operators quantized
MainTest({}, {}, 6);
}
TEST(QuantizerPlacementPass, default_attr_value) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册