Unverified · Commit 8686a745 · authored by joanna.wozna.intel · committed by GitHub

Add matmul_v2 and fused_matmul to the quantization process and adjust Ernie model test (#50354)

* Add matmul_v2 to the quantization process and adjust Ernie model test

* Correct cpu_quantize_pass test

* Move the op-to-fused-op transformation to the placement pass

* Correct test
Parent 6fa29c55
@@ -2046,11 +2046,9 @@ PDNode *patterns::Reshape2Matmul::operator()() {
   return matmul_out;
 }
 
-PDNode *patterns::MatmulWithInputOps::operator()(bool with_residual) {
-  auto prev_op_x = pattern->NewNode(prev_op_x_repr())->assert_is_op();
-  auto prev_op_y = pattern->NewNode(prev_op_y_repr())->assert_is_op();
-  auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
+PDNode *patterns::FusedMatmul::operator()(bool with_residual) {
+  auto matmul_op =
+      pattern->NewNode(matmul_op_repr())->assert_is_op("fused_matmul");
   if (!with_residual) {
     matmul_op->assert_more([&](Node *x) {
@@ -2061,26 +2059,24 @@ PDNode *patterns::MatmulWithInputOps::operator()(bool with_residual) {
   auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
                          ->AsInput()
-                         ->assert_is_op_input("matmul", "X");
+                         ->assert_is_op_input("fused_matmul", "X");
   auto matmul_in_y = pattern->NewNode(matmul_in_y_repr())
                          ->AsInput()
-                         ->assert_is_op_input("matmul", "Y");
+                         ->assert_is_op_input("fused_matmul", "Y");
   auto matmul_out = pattern->NewNode(matmul_out_repr())
                         ->AsOutput()
-                        ->assert_is_op_output("matmul", "Out")
-                        ->assert_is_only_output_of_op("matmul");
+                        ->assert_is_op_output("fused_matmul", "Out")
+                        ->assert_is_only_output_of_op("fused_matmul");
 
   std::vector<PDNode *> links_from{matmul_in_x, matmul_in_y};
   if (with_residual) {
     auto matmul_residual_data =
         pattern->NewNode(matmul_residual_data_repr())
             ->AsInput()
-            ->assert_is_op_input("matmul", "ResidualData");
+            ->assert_is_op_input("fused_matmul", "ResidualData");
     links_from.push_back(matmul_residual_data);
   }
 
-  prev_op_x->LinksTo({matmul_in_x});
-  prev_op_y->LinksTo({matmul_in_y});
   matmul_op->LinksFrom(links_from).LinksTo({matmul_out});
 
   return matmul_out;
 }
@@ -2835,6 +2831,9 @@ PDNode *patterns::QuantizePlacement::operator()(
     const std::unordered_set<std::string> &quantize_enabled_op_types) {
   auto *op =
       pattern->NewNode(op_repr())->assert_is_ops(quantize_enabled_op_types);
+  op->assert_more([&](Node *node) {
+    return node->Op()->GetAttrIfExists<bool>("use_mkldnn");
+  });
   return op;
 }
...
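The new assert_more predicate means the placement pattern only matches operators whose use_mkldnn attribute is set: GetAttrIfExists<bool> yields a default-constructed bool (false) when the attribute is absent, so ops without the flag are filtered out of quantize placement. A minimal standalone sketch of that lookup semantics, assuming a plain std::map in place of Paddle's attribute storage:

    // Sketch of the GetAttrIfExists<bool>("use_mkldnn") behavior relied on
    // above: a missing attribute behaves as false. Not the Paddle code.
    #include <iostream>
    #include <map>
    #include <string>

    template <typename T>
    T GetAttrIfExists(const std::map<std::string, T>& attrs,
                      const std::string& name) {
      auto it = attrs.find(name);
      return it == attrs.end() ? T{} : it->second;  // T{} is false for bool
    }

    int main() {
      std::map<std::string, bool> marked{{"use_mkldnn", true}};
      std::map<std::string, bool> unmarked;
      std::cout << GetAttrIfExists(marked, "use_mkldnn") << "\n";    // 1
      std::cout << GetAttrIfExists(unmarked, "use_mkldnn") << "\n";  // 0
    }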
@@ -1281,15 +1281,13 @@ struct Reshape2Matmul : public PatternBase {
   PATTERN_DECL_NODE(matmul_out);
 };
 
-// Forward pass for two input ops and matmul op.
+// Forward pass for two input ops and fused_matmul op.
 // matmul_out is a result of the operator.
-struct MatmulWithInputOps : public PatternBase {
-  MatmulWithInputOps(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "matmul_with_input_ops") {}
+struct FusedMatmul : public PatternBase {
+  FusedMatmul(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "fused_matmul") {}
 
   PDNode* operator()(bool with_residual);
-  PATTERN_DECL_NODE(prev_op_x);
-  PATTERN_DECL_NODE(prev_op_y);
   PATTERN_DECL_NODE(matmul_in_x);
   PATTERN_DECL_NODE(matmul_in_y);
   PATTERN_DECL_NODE(matmul_op);
...
@@ -880,7 +880,7 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph,
 void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const {
   GraphPatternDetector gpd;
   auto pattern = gpd.mutable_pattern();
-  patterns::MatmulWithInputOps matmul_pattern{pattern, name_scope_};
+  patterns::FusedMatmul matmul_pattern{pattern, name_scope_};
   matmul_pattern(with_residual);
 
   int quantize_matmul_count = 0;
@@ -894,15 +894,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const {
       LogQuantizationDisabled(matmul_op);
       return;
     }
-    GET_IR_NODE_FROM_SUBGRAPH(prev_op_x, prev_op_x, matmul_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(prev_op_y, prev_op_y, matmul_pattern);
-    // skip if prev ops are not quantized
-    if (!IsOpDequantized(prev_op_x) && !IsOpDequantized(prev_op_y)) {
-      MarkAndLogCannotQuantizeOp(matmul_op,
-                                 "No other quantizable operators nearby");
-      return;
-    }
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
...
@@ -87,7 +87,8 @@ void SetOp(ProgramDesc* prog,
     op->SetInput("Input", {inputs[0]});
     op->SetOutput("Output", {outputs[0]});
     op->SetAttr("Scale", 1.0f);
-  } else if (type == "matmul") {
+  } else if (type == "matmul" || type == "matmul_v2" ||
+             type == "fused_matmul") {
     op->SetInput("X", {inputs[0]});
     if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
     if (inputs.size() > 2) op->SetInput("ResidualData", {inputs[2]});
@@ -176,12 +177,12 @@ void CheckScales(const OpDesc* op, float scale, float shift) {
         scale);
     scale_names.push_back("Scale_in");
     scale_names.push_back("Scale_out");
-  } else if (type == "matmul" || type == "elementwise_add" ||
+  } else if (type == "fused_matmul" || type == "elementwise_add" ||
              type == "elementwise_mul" || type == "elementwise_sub") {
     scale_names.push_back("Scale_x");
    scale_names.push_back("Scale_y");
     scale_names.push_back("Scale_out");
-    if (type == "matmul") {
+    if (type == "fused_matmul") {
       auto const& names = op->InputNames();
       if (std::find(names.begin(), names.end(), "ResidualData") != names.end())
         scale_names.push_back("Scale_in_eltwise");
@@ -594,20 +595,7 @@ ProgramDesc BuildProgramDescMatmul() {
   }
   SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
   SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true);
-  SetOp(&prog, "matmul", "Matmul", {"b", "d"}, {"e"}, true, "int8");
-  SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32");
-  return prog;
-}
-
-ProgramDesc BuildProgramDescMatmulNotQuantized() {
-  ProgramDesc prog;
-  for (auto& v : variable_names_matmul) {
-    prog.MutableBlock(0)->Var(v);
-  }
-
-  SetOp(&prog, "dropout", "Dropout1", {"a"}, {"b"}, false);
-  SetOp(&prog, "dropout", "Dropout2", {"c"}, {"d"}, false);
-  SetOp(&prog, "matmul", "Matmul", {"b", "d"}, {"e"}, true, "int8");
+  SetOp(&prog, "fused_matmul", "FusedMatmul", {"b", "d"}, {"e"}, true, "int8");
   SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32");
   return prog;
@@ -621,7 +609,13 @@ ProgramDesc BuildProgramDescMatmulResidual() {
   SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
   SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true);
   SetOp(&prog, "dequantize", "Dequantize3", {"e"}, {"f"}, true);
-  SetOp(&prog, "matmul", "Matmul", {"b", "d", "f"}, {"g"}, true, "int8");
+  SetOp(&prog,
+        "fused_matmul",
+        "FusedMatmul",
+        {"b", "d", "f"},
+        {"g"},
+        true,
+        "int8");
   SetOp(&prog, "dropout", "Dropout", {"g"}, {"h"}, true, "float32");
   return prog;
@@ -631,7 +625,7 @@ TEST(CpuQuantizePass, matmul) {
   // 2 Quant + 2 IN + 1 DeQuant + 1 OUT
   int added_nodes = 6;
   std::unordered_map<std::string, int> expected_operators = {
-      {"matmul", 1}, {"quantize", 2}, {"dequantize", 3}};
+      {"fused_matmul", 1}, {"quantize", 2}, {"dequantize", 3}};
   MainTest(BuildProgramDescMatmul(),
            variable_names_matmul,
            expected_operators,
@@ -639,23 +633,11 @@ TEST(CpuQuantizePass, matmul) {
            SCALE * S8_MAX);
 }
 
-TEST(CpuQuantizePass, matmul_not_quantized) {
-  // nothing change
-  int added_nodes = 0;
-  std::unordered_map<std::string, int> expected_operators = {
-      {"matmul", 1}, {"quantize", 0}, {"dequantize", 0}};
-  MainTest(BuildProgramDescMatmulNotQuantized(),
-           variable_names_matmul,
-           expected_operators,
-           added_nodes,
-           1.0f);
-}
-
 TEST(CpuQuantizePass, matmul_residual) {
   // 3 Quant + 3 IN + 1 DeQuant + 1 OUT
   int added_nodes = 8;
   std::unordered_map<std::string, int> expected_operators = {
-      {"matmul", 1}, {"quantize", 3}, {"dequantize", 4}};
+      {"fused_matmul", 1}, {"quantize", 3}, {"dequantize", 4}};
   MainTest(BuildProgramDescMatmulResidual(),
            variable_names_matmul,
            expected_operators,
...
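These tests count the quantize/dequantize ops inserted around the new fused_matmul and check that the attached scales equal SCALE * S8_MAX. A minimal sketch of the symmetric int8 arithmetic those scale values refer to (an illustration of the math only, not Paddle's oneDNN kernels):

    // Symmetric int8 (de)quantization with S8_MAX = 127: quantize maps a
    // float tensor value to int8 using scale * S8_MAX, dequantize maps back.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <iostream>

    constexpr float S8_MAX = 127.0f;

    int8_t Quantize(float x, float scale) {
      float q = std::round(x * scale * S8_MAX);
      return static_cast<int8_t>(std::clamp(q, -128.0f, 127.0f));
    }

    float Dequantize(int8_t q, float scale) { return q / (scale * S8_MAX); }

    int main() {
      const float scale = 2.0f;  // plays the role of SCALE in the tests
      const float x = 0.37f;
      const int8_t q = Quantize(x, scale);  // round(0.37 * 254) = 94
      std::cout << "q=" << static_cast<int>(q)
                << " back=" << Dequantize(q, scale) << "\n";
    }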
@@ -22,17 +22,44 @@ namespace ir {
 
 class Graph;
 
+void ReplaceWithFusedOp(Node* op) {
+  const std::string matmul_type = op->Op()->Type();
+  if (matmul_type == "matmul" || matmul_type == "matmul_v2") {
+    op->Op()->SetType("fused_matmul");
+    if (matmul_type == "matmul") {
+      op->Op()->SetAttr("trans_x", op->Op()->GetAttr("transpose_X"));
+      op->Op()->SetAttr("trans_y", op->Op()->GetAttr("transpose_Y"));
+      op->Op()->SetAttr("matmul_alpha", op->Op()->GetAttr("alpha"));
+    }
+  }
+}
+
 void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Marks operators which are to be quantized.";
   std::unordered_set<std::string> supported_op_types =
-      std::unordered_set<std::string>(
-          {"concat",          "conv2d",          "depthwise_conv2d",
-           "fused_conv2d",    "fused_conv3d",    "elementwise_add",
-           "elementwise_mul", "elementwise_sub", "fc",
-           "matmul",          "nearest_interp",  "nearest_interp_v2",
-           "pool2d",          "prior_box",       "reshape2",
-           "transpose2",      "fusion_gru",      "fusion_lstm",
-           "multi_gru",       "slice",           "split"});
+      std::unordered_set<std::string>({"concat",
+                                       "conv2d",
+                                       "depthwise_conv2d",
+                                       "fused_conv2d",
+                                       "fused_conv3d",
+                                       "fused_matmul",
+                                       "elementwise_add",
+                                       "elementwise_mul",
+                                       "elementwise_sub",
+                                       "fc",
+                                       "matmul",
+                                       "matmul_v2",
+                                       "nearest_interp",
+                                       "nearest_interp_v2",
+                                       "pool2d",
+                                       "prior_box",
+                                       "reshape2",
+                                       "transpose2",
+                                       "fusion_gru",
+                                       "fusion_lstm",
+                                       "multi_gru",
+                                       "slice",
+                                       "split"});
   const auto& excluded_ids_list =
       Get<std::unordered_set<int>>("quantize_excluded_op_ids");
   const auto& op_types_list =
@@ -69,6 +96,8 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
     if (op->Op()->GetAttrIfExists<int>("skip_quant") == 1) {
       return;
     }
+
+    ReplaceWithFusedOp(op);
     op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
   };
   gpd(graph, handler);
...
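ReplaceWithFusedOp is the heart of the change: once an operator is marked for int8, matmul and matmul_v2 are renamed to fused_matmul, and the legacy matmul's attribute names are mapped onto the fused op's (transpose_X to trans_x, transpose_Y to trans_y, alpha to matmul_alpha; matmul_v2 already uses trans_x/trans_y). A self-contained sketch of that mapping using only standard types (FakeOp stands in for Paddle's Node/OpDesc and does not replicate their API):

    // Illustrative model of the attribute remapping performed by
    // ReplaceWithFusedOp above; not the Paddle implementation.
    #include <iostream>
    #include <map>
    #include <string>
    #include <variant>

    struct FakeOp {
      std::string type;
      std::map<std::string, std::variant<bool, float>> attrs;
    };

    void ReplaceWithFusedOpSketch(FakeOp* op) {
      if (op->type == "matmul" || op->type == "matmul_v2") {
        const bool legacy = (op->type == "matmul");
        op->type = "fused_matmul";
        if (legacy) {  // only legacy matmul needs its attrs renamed
          op->attrs["trans_x"] = op->attrs["transpose_X"];
          op->attrs["trans_y"] = op->attrs["transpose_Y"];
          op->attrs["matmul_alpha"] = op->attrs["alpha"];
        }
      }
    }

    int main() {
      FakeOp op{"matmul",
                {{"transpose_X", false}, {"transpose_Y", true}, {"alpha", 1.0f}}};
      ReplaceWithFusedOpSketch(&op);
      std::cout << op.type << " trans_y="
                << std::get<bool>(op.attrs["trans_y"]) << "\n";
    }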
@@ -30,6 +30,7 @@ void SetOp(ProgramDesc* prog,
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
+  op->SetAttr("use_mkldnn", true);
   op->SetAttr("mkldnn_data_type", mkldnn_data_type);
 
   if (type == "conv2d") {
...
@@ -36,7 +36,8 @@ static void SaveInfoInTheFirstOp(
   for (auto* op_node :
        ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
     if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
-        op_node->Op()->Type() == "fetch")
+        op_node->Op()->Type() == "fetch" ||
+        op_node->Op()->Type() == "fill_constant")
       continue;
 
     op_node->Op()->SetAttr(flag, true);
@@ -57,7 +58,8 @@ static void SaveInfoInTheFirstOp(ir::Graph* graph,
   for (auto* op_node :
        ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
     if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
-        op_node->Op()->Type() == "fetch")
+        op_node->Op()->Type() == "fetch" ||
+        op_node->Op()->Type() == "fill_constant")
       continue;
 
     op_node->Op()->SetAttr(flag, true);
...
@@ -672,22 +672,8 @@ void AnalysisConfig::EnableMkldnnInt8(
 #ifdef PADDLE_WITH_MKLDNN
   use_mkldnn_int8_ = true;
   use_fc_padding_ = false;
-  if (!op_list.empty()) {
-    for (auto &type : op_list) {
-      if (!quantize_enabled_op_types_.count(type)) {
-        LOG(ERROR) << "There are unsupported operators in the configured "
-                      "quantization operator list. The unsupported operator "
-                      "is: "
-                   << type;
-        use_mkldnn_int8_ = false;
-        break;
-      }
-    }
-    if (use_mkldnn_int8_) {
-      quantize_enabled_op_types_.clear();
-      quantize_enabled_op_types_.insert(op_list.begin(), op_list.end());
-    }
-  }
+  if (!op_list.empty())
+    quantize_enabled_op_types_.insert(op_list.begin(), op_list.end());
 #else
   LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnInt8";
   use_mkldnn_int8_ = false;
...
@@ -1191,26 +1191,7 @@ struct PD_INFER_DECL AnalysisConfig {
   std::unordered_set<std::string> bfloat16_enabled_op_types_;
   bool use_mkldnn_int8_{false};
   std::unordered_set<int> quantize_excluded_op_ids_{};
-  std::unordered_set<std::string> quantize_enabled_op_types_{
-      "concat",
-      "conv2d",
-      "depthwise_conv2d",
-      "fused_conv2d",
-      "elementwise_add",
-      "elementwise_mul",
-      "fc",
-      "matmul",
-      "nearest_interp",
-      "nearest_interp_v2",
-      "pool2d",
-      "prior_box",
-      "reshape2",
-      "transpose2",
-      "fusion_gru",
-      "fusion_lstm",
-      "multi_gru",
-      "slice",
-      "split"};
+  std::unordered_set<std::string> quantize_enabled_op_types_{};
 
   bool disable_mkldnn_fc_passes_{false};
...
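Together, these two changes drop the hard-coded allow-list: EnableMkldnnInt8 now records whatever operator types the caller passes, and whether they can actually be quantized is decided later by cpu_quantize_placement_pass. A hedged usage sketch (the include path and model directory are illustrative):

    // Previously, passing an op type outside the default list (for example
    // "matmul_v2") logged an error and disabled int8 entirely; now the
    // caller's list is taken as-is.
    #include "paddle/fluid/inference/api/paddle_inference_api.h"

    int main() {
      paddle::AnalysisConfig config("./ernie_int8_model");  // illustrative path
      config.EnableMKLDNN();
      config.EnableMkldnnInt8({"fc", "matmul_v2", "fused_matmul"});
      return 0;
    }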
@@ -29,11 +29,16 @@ void SetInt8Config(AnalysisConfig *cfg,
                    std::vector<paddle::PaddleTensor> data) {
   cfg->SetModel(FLAGS_infer_model);
   cfg->EnableMKLDNN();
+  cfg->DisableMkldnnFcPasses();  // fc passes caused loss in accuracy
   cfg->EnableMkldnnQuantizer();
   auto pass_builder = cfg->pass_builder();
   pass_builder->DeletePass("constant_folding_pass");
   auto warmup_data = std::make_shared<std::vector<PaddleTensor>>(data);
+  cfg->mkldnn_quantizer_config()->SetEnabledOpTypes(
+      {"elementwise_add", "matmul", "matmul_v2", "fused_matmul"});
+  // Exclude several matmuls that should not be quantized because they
+  // reduce the accuracy of the model
+  cfg->mkldnn_quantizer_config()->SetExcludedOpIds(
+      {75, 172, 269, 366, 463, 560, 657, 754, 851, 948, 1045, 1142});
   cfg->mkldnn_quantizer_config()->SetWarmupData(warmup_data);
   cfg->mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_batch_size);
   cfg->SwitchSpecifyInputNames();
...