未验证 提交 e421c6a6 编写于 作者: Z zyfncg 提交者: GitHub

Fix performance problem in BF16 models (#50283)

* fix performance drop in BF16 models

* fix test_cpu_quantize_squash_pass
上级 b8713309
......@@ -2346,7 +2346,7 @@ PDNode *patterns::ScaleQuant::operator()() {
return quant_op;
}
PDNode *patterns::QuantConv::operator()() {
PDNode *patterns::QuantConv::operator()(const std::string &conv_type) {
auto quant_in = pattern->NewNode(quant_in_repr())
->AsInput()
->assert_is_op_input("quantize", "Input");
......@@ -2354,8 +2354,8 @@ PDNode *patterns::QuantConv::operator()() {
auto conv_in = pattern->NewNode(conv_in_repr())
->AsInput()
->assert_is_op_input("conv2d", "Input");
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
->assert_is_op_input(conv_type, "Input");
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op(conv_type);
conv_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
......@@ -2845,6 +2845,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
"clip",
"concat",
"conv2d",
"fused_conv2d",
"conv2d_transpose",
"elementwise_add",
"elementwise_mul",
......
......@@ -1402,7 +1402,7 @@ struct QuantConv : public PatternBase {
QuantConv(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quant_conv") {}
PDNode* operator()();
PDNode* operator()(const std::string& conv_type);
PATTERN_DECL_NODE(quant_in);
PATTERN_DECL_NODE(quant_op);
......
......@@ -48,6 +48,34 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsOptional()
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
......@@ -545,11 +573,17 @@ void CPUQuantizeSquashPass::ScaleQuantSquash(Graph* graph) const {
found_scale_quant_squash_count);
}
// squash quantize if is before bfloat16 conv2d
// squash quantize if is before bfloat16 conv2d or fused_conv2d
// Entry point for the quantize-before-bfloat16-conv squash: runs the
// pattern-based squash once per conv op flavor. The per-flavor work
// (pattern construction, subgraph handler, statistics) lives in
// QuantizeBf16ConvImpl, which takes the conv op type as a parameter.
// NOTE(review): "fused_conv2d" was added alongside "conv2d" in this
// commit — presumably because oneDNN fusion passes can rewrite conv2d
// into fused_conv2d before this pass runs; verify against pass order.
void CPUQuantizeSquashPass::QuantizeBf16Conv(Graph* graph) const {
// Plain conv2d first, then the fused variant; each call scans the whole
// graph independently for its own op type.
QuantizeBf16ConvImpl(graph, "conv2d");
QuantizeBf16ConvImpl(graph, "fused_conv2d");
}
void CPUQuantizeSquashPass::QuantizeBf16ConvImpl(
Graph* graph, const std::string& conv_type) const {
GraphPatternDetector gpd;
patterns::QuantConv pattern{gpd.mutable_pattern(), "quant_conv"};
pattern();
pattern(conv_type);
int found_quant_conv_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......@@ -577,8 +611,9 @@ void CPUQuantizeSquashPass::QuantizeBf16Conv(Graph* graph) const {
};
gpd(graph, handler);
AddStatis(found_quant_conv_squash_count);
PrettyLogDetail("--- squashed %d quantize with bfloat16 conv2d op",
found_quant_conv_squash_count);
PrettyLogDetail("--- squashed %d quantize with bfloat16 %s op",
found_quant_conv_squash_count,
conv_type);
}
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
......
......@@ -91,10 +91,12 @@ class CPUQuantizeSquashPass : public FusePassBase {
void ScaleQuantSquash(Graph* graph) const;
/*
* Squash quantize if is before bfloat16 conv2d
* Squash quantize if is before bfloat16 conv2d or fused_conv2d
*/
void QuantizeBf16Conv(Graph* graph) const;
void QuantizeBf16ConvImpl(Graph* graph, const std::string& conv_type) const;
const std::string name_scope_{"squash"};
};
......
......@@ -56,7 +56,6 @@ void SetOp(ProgramDesc* prog,
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
op->SetInput("Input", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
op->SetOutput("Output", {outputs[0]});
const std::vector<int> strides({1, 1});
const std::vector<int> paddings({1, 1});
......@@ -702,7 +701,7 @@ ProgramDesc BuildQuantConv2dProgramDesc(const bool& use_mkldnn,
SetOp(&prog,
"conv2d",
"Conv2d",
{"b", "filter", "bias"},
{"b", "filter"},
{"c"},
use_mkldnn,
{},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册