未验证 提交 9e5f3a38 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add int8 support for matmul+elementwise_add fuse pass (#45077)

* Add int8 support for matmul+elementwiae_add fuse

* Corrections after review and ernie test fix
上级 d03ef054
...@@ -1892,11 +1892,19 @@ PDNode *patterns::Reshape2Matmul::operator()() { ...@@ -1892,11 +1892,19 @@ PDNode *patterns::Reshape2Matmul::operator()() {
return matmul_out; return matmul_out;
} }
PDNode *patterns::MatmulWithInputOps::operator()() { PDNode *patterns::MatmulWithInputOps::operator()(bool with_residual) {
auto prev_op_x = pattern->NewNode(prev_op_x_repr())->assert_is_op(); auto prev_op_x = pattern->NewNode(prev_op_x_repr())->assert_is_op();
auto prev_op_y = pattern->NewNode(prev_op_y_repr())->assert_is_op(); auto prev_op_y = pattern->NewNode(prev_op_y_repr())->assert_is_op();
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
if (!with_residual) {
matmul_op->assert_more([&](Node *x) {
return (!HasInput(x, "ResidualData") ||
x->Op()->Input("ResidualData").size() == 0);
});
}
auto matmul_in_x = pattern->NewNode(matmul_in_x_repr()) auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->AsInput() ->AsInput()
->assert_is_op_input("matmul", "X"); ->assert_is_op_input("matmul", "X");
...@@ -1905,11 +1913,21 @@ PDNode *patterns::MatmulWithInputOps::operator()() { ...@@ -1905,11 +1913,21 @@ PDNode *patterns::MatmulWithInputOps::operator()() {
->assert_is_op_input("matmul", "Y"); ->assert_is_op_input("matmul", "Y");
auto matmul_out = pattern->NewNode(matmul_out_repr()) auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("matmul", "Out"); ->assert_is_op_output("matmul", "Out")
->assert_is_only_output_of_op("matmul");
std::vector<PDNode *> links_from{matmul_in_x, matmul_in_y};
if (with_residual) {
auto matmul_residual_data =
pattern->NewNode(matmul_residual_data_repr())
->AsInput()
->assert_is_op_input("matmul", "ResidualData");
links_from.push_back(matmul_residual_data);
}
prev_op_x->LinksTo({matmul_in_x}); prev_op_x->LinksTo({matmul_in_x});
prev_op_y->LinksTo({matmul_in_y}); prev_op_y->LinksTo({matmul_in_y});
matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out}); matmul_op->LinksFrom(links_from).LinksTo({matmul_out});
return matmul_out; return matmul_out;
} }
......
...@@ -1191,12 +1191,13 @@ struct MatmulWithInputOps : public PatternBase { ...@@ -1191,12 +1191,13 @@ struct MatmulWithInputOps : public PatternBase {
MatmulWithInputOps(PDPattern* pattern, const std::string& name_scope) MatmulWithInputOps(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_with_input_ops") {} : PatternBase(pattern, name_scope, "matmul_with_input_ops") {}
PDNode* operator()(); PDNode* operator()(bool with_residual);
PATTERN_DECL_NODE(prev_op_x); PATTERN_DECL_NODE(prev_op_x);
PATTERN_DECL_NODE(prev_op_y); PATTERN_DECL_NODE(prev_op_y);
PATTERN_DECL_NODE(matmul_in_x); PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y); PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op); PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_residual_data);
PATTERN_DECL_NODE(matmul_out); PATTERN_DECL_NODE(matmul_out);
}; };
......
...@@ -733,11 +733,11 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph, ...@@ -733,11 +733,11 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph,
LogQuantizedOpsCounter(immutable_type, quantize_immutable_count); LogQuantizedOpsCounter(immutable_type, quantize_immutable_count);
} }
void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::MatmulWithInputOps matmul_pattern{pattern, name_scope_}; patterns::MatmulWithInputOps matmul_pattern{pattern, name_scope_};
matmul_pattern(); matmul_pattern(with_residual);
int quantize_matmul_count = 0; int quantize_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
...@@ -754,7 +754,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -754,7 +754,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(prev_op_y, prev_op_y, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(prev_op_y, prev_op_y, matmul_pattern);
// skip if prev ops are not quantized // skip if prev ops are not quantized
if (!IsOpDequantized(prev_op_x) || !IsOpDequantized(prev_op_y)) { if (!IsOpDequantized(prev_op_x) && !IsOpDequantized(prev_op_y)) {
MarkAndLogCannotQuantizeOp(matmul_op, MarkAndLogCannotQuantizeOp(matmul_op,
"No other quantizable operators nearby"); "No other quantizable operators nearby");
return; return;
...@@ -763,6 +763,15 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -763,6 +763,15 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
auto has_output_scale = AreScalesPresentForNodes({matmul_out});
if (with_residual && !has_output_scale) {
MarkAndLogCannotQuantizeOp(
matmul_op,
"Matmul op with ResidualData input cannot be quantized "
"without output scale.");
return;
}
if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) { if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) {
MarkAndLogCannotQuantizeOp(matmul_op, MarkAndLogCannotQuantizeOp(matmul_op,
"No scale available for the operator"); "No scale available for the operator");
...@@ -780,6 +789,28 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -780,6 +789,28 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
"are different: x(%d), y(%d).", "are different: x(%d), y(%d).",
is_x_unsigned, is_x_unsigned,
is_y_unsigned)); is_y_unsigned));
if (with_residual) {
GET_IR_NODE_FROM_SUBGRAPH(
matmul_residual_data, matmul_residual_data, matmul_pattern);
if (!AreScalesPresentForNodes({matmul_residual_data})) {
MarkAndLogCannotQuantizeOp(matmul_op,
"No scale available for the operator");
return;
}
bool is_residual_unsigned{false};
auto residual_scale =
GetScaleValueForNode(matmul_residual_data, &is_residual_unsigned);
QuantizeInput(g,
matmul_op,
matmul_residual_data,
"ResidualData",
residual_scale,
is_residual_unsigned,
"Scale_in_eltwise");
}
QuantizeInput(g, QuantizeInput(g,
matmul_op, matmul_op,
matmul_in_x, matmul_in_x,
...@@ -814,7 +845,9 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -814,7 +845,9 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_matmul_count); AddStatis(quantize_matmul_count);
LogQuantizedOpsCounter("matmul", quantize_matmul_count); LogQuantizedOpsCounter("matmul",
quantize_matmul_count,
(with_residual ? "with residual connection" : ""));
} }
void CPUQuantizePass::QuantizeElementwise( void CPUQuantizePass::QuantizeElementwise(
...@@ -1132,7 +1165,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -1132,7 +1165,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeConcat(graph); QuantizeConcat(graph);
QuantizePriorBox(graph); QuantizePriorBox(graph);
QuantizeFc(graph); QuantizeFc(graph);
QuantizeMatmul(graph); QuantizeMatmul(graph, false /* with_residual_data */);
QuantizeMatmul(graph, true /* with_residual_data */);
QuantizeImmutable(graph, "reshape2", "X"); QuantizeImmutable(graph, "reshape2", "X");
QuantizeImmutable(graph, "transpose2", "X"); QuantizeImmutable(graph, "transpose2", "X");
QuantizeImmutable(graph, "slice", "Input"); QuantizeImmutable(graph, "slice", "Input");
......
...@@ -54,7 +54,7 @@ class CPUQuantizePass : public FusePassBase { ...@@ -54,7 +54,7 @@ class CPUQuantizePass : public FusePassBase {
void QuantizePool(Graph* graph) const; void QuantizePool(Graph* graph) const;
void QuantizeConcat(Graph* graph) const; void QuantizeConcat(Graph* graph) const;
void QuantizePriorBox(Graph* graph) const; void QuantizePriorBox(Graph* graph) const;
void QuantizeMatmul(Graph* graph) const; void QuantizeMatmul(Graph* graph, bool with_residual) const;
void QuantizeElementwise(Graph* graph, void QuantizeElementwise(Graph* graph,
const std::string& elementwise_type) const; const std::string& elementwise_type) const;
void QuantizeFusionGru(Graph* graph) const; void QuantizeFusionGru(Graph* graph) const;
......
...@@ -90,6 +90,7 @@ void SetOp(ProgramDesc* prog, ...@@ -90,6 +90,7 @@ void SetOp(ProgramDesc* prog,
} else if (type == "matmul") { } else if (type == "matmul") {
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]}); if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
if (inputs.size() > 2) op->SetInput("ResidualData", {inputs[2]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
op->SetAttr("Scale_x", 1.0f); op->SetAttr("Scale_x", 1.0f);
op->SetAttr("Scale_y", 1.0f); op->SetAttr("Scale_y", 1.0f);
...@@ -180,6 +181,11 @@ void CheckScales(const OpDesc* op, float scale, float shift) { ...@@ -180,6 +181,11 @@ void CheckScales(const OpDesc* op, float scale, float shift) {
scale_names.push_back("Scale_x"); scale_names.push_back("Scale_x");
scale_names.push_back("Scale_y"); scale_names.push_back("Scale_y");
scale_names.push_back("Scale_out"); scale_names.push_back("Scale_out");
if (type == "matmul") {
auto const& names = op->InputNames();
if (std::find(names.begin(), names.end(), "ResidualData") != names.end())
scale_names.push_back("Scale_in_eltwise");
}
} else if (type == "fusion_gru" || type == "fusion_lstm") { } else if (type == "fusion_gru" || type == "fusion_lstm") {
EXPECT_EQ(op->GetAttrIfExists<float>("Shift_data"), shift); EXPECT_EQ(op->GetAttrIfExists<float>("Shift_data"), shift);
EXPECT_EQ(op->GetAttrIfExists<std::vector<float>>("Scale_weights")[0], EXPECT_EQ(op->GetAttrIfExists<std::vector<float>>("Scale_weights")[0],
...@@ -579,7 +585,7 @@ INSTANTIATE_TEST_CASE_P( ...@@ -579,7 +585,7 @@ INSTANTIATE_TEST_CASE_P(
}); });
static const std::initializer_list<std::string> variable_names_matmul = { static const std::initializer_list<std::string> variable_names_matmul = {
"a", "b", "c", "d", "e", "f"}; "a", "b", "c", "d", "e", "f", "g", "h"};
ProgramDesc BuildProgramDescMatmul() { ProgramDesc BuildProgramDescMatmul() {
ProgramDesc prog; ProgramDesc prog;
...@@ -599,14 +605,28 @@ ProgramDesc BuildProgramDescMatmulNotQuantized() { ...@@ -599,14 +605,28 @@ ProgramDesc BuildProgramDescMatmulNotQuantized() {
for (auto& v : variable_names_matmul) { for (auto& v : variable_names_matmul) {
prog.MutableBlock(0)->Var(v); prog.MutableBlock(0)->Var(v);
} }
SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, false); SetOp(&prog, "dropout", "Dropout1", {"a"}, {"b"}, false);
SetOp(&prog, "dequantize", "Dequantize", {"c"}, {"d"}, true); SetOp(&prog, "dropout", "Dropout2", {"c"}, {"d"}, false);
SetOp(&prog, "matmul", "Matmul", {"b", "d"}, {"e"}, true, "int8"); SetOp(&prog, "matmul", "Matmul", {"b", "d"}, {"e"}, true, "int8");
SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32"); SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32");
return prog; return prog;
} }
ProgramDesc BuildProgramDescMatmulResidual() {
ProgramDesc prog;
for (auto& v : variable_names_matmul) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true);
SetOp(&prog, "dequantize", "Dequantize3", {"e"}, {"f"}, true);
SetOp(&prog, "matmul", "Matmul", {"b", "d", "f"}, {"g"}, true, "int8");
SetOp(&prog, "dropout", "Dropout", {"g"}, {"h"}, true, "float32");
return prog;
}
TEST(CpuQuantizePass, matmul) { TEST(CpuQuantizePass, matmul) {
// 2 Quant + 2 IN + 1 DeQuant + 1 OUT // 2 Quant + 2 IN + 1 DeQuant + 1 OUT
int added_nodes = 6; int added_nodes = 6;
...@@ -623,7 +643,7 @@ TEST(CpuQuantizePass, matmul_not_quantized) { ...@@ -623,7 +643,7 @@ TEST(CpuQuantizePass, matmul_not_quantized) {
// nothing change // nothing change
int added_nodes = 0; int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{"matmul", 1}, {"quantize", 0}, {"dequantize", 1}}; {"matmul", 1}, {"quantize", 0}, {"dequantize", 0}};
MainTest(BuildProgramDescMatmulNotQuantized(), MainTest(BuildProgramDescMatmulNotQuantized(),
variable_names_matmul, variable_names_matmul,
expected_operators, expected_operators,
...@@ -631,6 +651,18 @@ TEST(CpuQuantizePass, matmul_not_quantized) { ...@@ -631,6 +651,18 @@ TEST(CpuQuantizePass, matmul_not_quantized) {
1.0f); 1.0f);
} }
TEST(CpuQuantizePass, matmul_residual) {
// 3 Quant + 3 IN + 1 DeQuant + 1 OUT
int added_nodes = 8;
std::unordered_map<std::string, int> expected_operators = {
{"matmul", 1}, {"quantize", 3}, {"dequantize", 4}};
MainTest(BuildProgramDescMatmulResidual(),
variable_names_matmul,
expected_operators,
added_nodes,
SCALE * S8_MAX);
}
static const std::initializer_list<std::string> variable_names_elementwise = { static const std::initializer_list<std::string> variable_names_elementwise = {
"a", "b", "c", "d", "e", "f"}; "a", "b", "c", "d", "e", "f"};
......
...@@ -365,6 +365,8 @@ void CpuPassStrategy::EnableMkldnnInt8() { ...@@ -365,6 +365,8 @@ void CpuPassStrategy::EnableMkldnnInt8() {
if (!use_mkldnn_int8_) { if (!use_mkldnn_int8_) {
passes_.clear(); passes_.clear();
passes_.push_back("quant_dequant_mkldnn_pass"); passes_.push_back("quant_dequant_mkldnn_pass");
passes_.push_back("mkldnn_placement_pass");
passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("layer_norm_fuse_pass"); passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass"); passes_.push_back("attention_lstm_fuse_pass");
passes_.push_back("seqconv_eltadd_relu_fuse_pass"); passes_.push_back("seqconv_eltadd_relu_fuse_pass");
...@@ -386,10 +388,10 @@ void CpuPassStrategy::EnableMkldnnInt8() { ...@@ -386,10 +388,10 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("matmul_scale_fuse_pass"); passes_.push_back("matmul_scale_fuse_pass");
passes_.push_back("gpu_cpu_map_matmul_to_mul_pass"); passes_.push_back("gpu_cpu_map_matmul_to_mul_pass");
passes_.push_back("repeated_fc_relu_fuse_pass"); passes_.push_back("repeated_fc_relu_fuse_pass");
passes_.push_back("mkldnn_placement_pass");
passes_.push_back("depthwise_conv_mkldnn_pass"); passes_.push_back("depthwise_conv_mkldnn_pass");
passes_.push_back("conv_bn_fuse_pass"); passes_.push_back("conv_bn_fuse_pass");
passes_.push_back("conv_eltwiseadd_bn_fuse_pass"); passes_.push_back("conv_eltwiseadd_bn_fuse_pass");
passes_.push_back("conv_affine_channel_mkldnn_fuse_pass");
passes_.push_back("conv_transpose_bn_fuse_pass"); passes_.push_back("conv_transpose_bn_fuse_pass");
passes_.push_back("conv_transpose_eltwiseadd_bn_fuse_pass"); passes_.push_back("conv_transpose_eltwiseadd_bn_fuse_pass");
passes_.push_back("conv_bias_mkldnn_fuse_pass"); passes_.push_back("conv_bias_mkldnn_fuse_pass");
...@@ -406,10 +408,10 @@ void CpuPassStrategy::EnableMkldnnInt8() { ...@@ -406,10 +408,10 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("compute_propagate_scales_mkldnn_pass"); passes_.push_back("compute_propagate_scales_mkldnn_pass");
passes_.push_back("scale_matmul_fuse_pass"); passes_.push_back("scale_matmul_fuse_pass");
passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass"); passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
passes_.push_back("cpu_quantize_placement_pass"); passes_.push_back("cpu_quantize_placement_pass");
passes_.push_back("cpu_quantize_pass"); passes_.push_back("cpu_quantize_pass");
passes_.push_back("cpu_quantize_squash_pass"); passes_.push_back("cpu_quantize_squash_pass");
passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("mkldnn_inplace_pass"); passes_.push_back("mkldnn_inplace_pass");
passes_.push_back("runtime_context_cache_pass"); passes_.push_back("runtime_context_cache_pass");
} }
......
...@@ -779,10 +779,14 @@ class MatMulV2MKLDNNHandler ...@@ -779,10 +779,14 @@ class MatMulV2MKLDNNHandler
auto* residual_data = ctx.Input<Tensor>("ResidualData"); auto* residual_data = ctx.Input<Tensor>("ResidualData");
auto residual_data_tz = phi::vectorize(residual_data->dims()); auto residual_data_tz = phi::vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz, auto residual_data_md = memory::desc(residual_data_tz,
dnnl::memory::data_type::f32, MKLDNNGetDataType<OT>(),
dnnl::memory::format_tag::abcd); dnnl::memory::format_tag::any);
post_operations.append_binary(dnnl::algorithm::binary_add, post_operations.append_binary(dnnl::algorithm::binary_add,
residual_data_md); residual_data_md);
if (ctx.HasAttr("Scale_in_eltwise")) {
float sum_scale = scale_out / ctx.Attr<float>("Scale_in_eltwise");
post_operations.append_sum(sum_scale);
}
} }
AppendActivation(ctx, post_operations); AppendActivation(ctx, post_operations);
......
...@@ -403,12 +403,13 @@ class Quant2Int8MkldnnPass(object): ...@@ -403,12 +403,13 @@ class Quant2Int8MkldnnPass(object):
def _optimize_fp32_graph(self, graph): def _optimize_fp32_graph(self, graph):
graph = self._update_activations(graph) graph = self._update_activations(graph)
graph = self._remove_ctrl_vars(graph) graph = self._remove_ctrl_vars(graph)
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
# remove dropout ops
graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')
graph = self._apply_pass(graph, 'layer_norm_fuse_pass') graph = self._apply_pass(graph, 'layer_norm_fuse_pass')
graph = self._apply_pass(graph, 'attention_lstm_fuse_pass') graph = self._apply_pass(graph, 'attention_lstm_fuse_pass')
graph = self._apply_pass(graph, 'seqconv_eltadd_relu_fuse_pass') graph = self._apply_pass(graph, 'seqconv_eltadd_relu_fuse_pass')
# graph = self._apply_pass(graph, 'seqpool_concat_fuse_pass')
graph = self._apply_pass(graph, 'seqpool_cvm_concat_fuse_pass')
# graph = self._apply_pass(graph, 'embedding_fc_lstm_fuse_pass')
graph = self._apply_pass(graph, 'fc_lstm_fuse_pass') graph = self._apply_pass(graph, 'fc_lstm_fuse_pass')
graph = self._apply_pass(graph, 'mul_lstm_fuse_pass') graph = self._apply_pass(graph, 'mul_lstm_fuse_pass')
graph = self._apply_pass(graph, 'fc_gru_fuse_pass') graph = self._apply_pass(graph, 'fc_gru_fuse_pass')
...@@ -427,8 +428,6 @@ class Quant2Int8MkldnnPass(object): ...@@ -427,8 +428,6 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'matmul_scale_fuse_pass') graph = self._apply_pass(graph, 'matmul_scale_fuse_pass')
graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_to_mul_pass') graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_to_mul_pass')
graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass') graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass') graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass') graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass') graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
...@@ -451,6 +450,11 @@ class Quant2Int8MkldnnPass(object): ...@@ -451,6 +450,11 @@ class Quant2Int8MkldnnPass(object):
'matmul_transpose_reshape_mkldnn_fuse_pass') 'matmul_transpose_reshape_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass') graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass')
graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
graph = self._apply_pass(graph,
'reshape_transpose_matmul_mkldnn_fuse_pass')
graph = self._apply_pass(graph,
'matmul_elementwise_add_mkldnn_fuse_pass')
# the following pass should be the last one since it will work on all fused ops. # the following pass should be the last one since it will work on all fused ops.
graph = self._apply_pass(graph, 'runtime_context_cache_pass') graph = self._apply_pass(graph, 'runtime_context_cache_pass')
return graph return graph
...@@ -476,8 +480,6 @@ class Quant2Int8MkldnnPass(object): ...@@ -476,8 +480,6 @@ class Quant2Int8MkldnnPass(object):
return graph return graph
def _final_optimizations(self, graph): def _final_optimizations(self, graph):
# remove dropout ops
graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')
# make some MKL-DNN ops working inplace # make some MKL-DNN ops working inplace
graph = self._apply_pass(graph, 'mkldnn_inplace_pass') graph = self._apply_pass(graph, 'mkldnn_inplace_pass')
return graph return graph
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册