From fddea674452eb8dd3b028a9bf64bef03b5030522 Mon Sep 17 00:00:00 2001
From: "joanna.wozna.intel"
Date: Thu, 26 Nov 2020 04:10:08 +0100
Subject: [PATCH] Fix cpu_bfloat16_pass (#28730)

* Fix cpu_bfloat16_pass

* Add output_format

* Fix incorrect SetOutput

* Change formatting
---
 .../framework/ir/graph_pattern_detector.cc   |  30 +++
 .../framework/ir/graph_pattern_detector.h    |  20 ++
 .../framework/ir/mkldnn/cpu_bfloat16_pass.cc | 218 +++++++++++++-----
 .../ir/mkldnn/cpu_bfloat16_pass_tester.cc    | 156 +++++++++----
 4 files changed, 315 insertions(+), 109 deletions(-)

diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index e163f6c352d..c3f550c0ed8 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2181,6 +2181,36 @@ PDNode *patterns::FirstBfloat16Ops::operator()() {
   return op;
 }
 
+PDNode *patterns::DuplicatedInputs::operator()() {
+  auto op = pattern->NewNode(op_repr())->assert_is_ops({"concat", "sum"});
+  op->assert_more([&](Node *node) {
+    return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
+           "bfloat16";
+  });
+  return op;
+}
+
+PDNode *patterns::UnnecessaryReorders::operator()() {
+  auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
+  prev_op->assert_more([&](Node *node) {
+    return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
+           "bfloat16";
+  });
+
+  auto *quant_in = pattern->NewNode(quant_in_repr())
+                       ->assert_is_op_input("quantize", "Input");
+
+  auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
+
+  auto *quant_out = pattern->NewNode(quant_out_repr())
+                        ->assert_is_op_output("quantize", "Output");
+
+  prev_op->LinksTo({quant_in});
+  quant_op->LinksFrom({quant_in}).LinksTo({quant_out});
+
+  return quant_out;
+}
+
 PDNode *patterns::MKLDNNInPlace::operator()() {
   const std::unordered_set<std::string> &supported_op_types = {
       "abs",
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index a4e8d916e5b..491e896db48 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1273,6 +1273,26 @@ struct FirstBfloat16Ops : public PatternBase {
   PATTERN_DECL_NODE(op);
 };
 
+struct DuplicatedInputs : public PatternBase {
+  DuplicatedInputs(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "many_inputs_op") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(op);
+};
+
+struct UnnecessaryReorders : public PatternBase {
+  UnnecessaryReorders(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "unnecessary_reorders") {}
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(prev_op);
+  PATTERN_DECL_NODE(quant_in);
+  PATTERN_DECL_NODE(quant_op);
+  PATTERN_DECL_NODE(quant_out);
+};
+
 // Pattern used for enforcing inplace computation for in-place computation
 // supporting DNNL ops. softmax, batch_norm and layer_norm
 struct MKLDNNInPlace : public PatternBase {
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
index ae93025e784..9658d604520 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
@@ -33,58 +33,157 @@ void UnlinkNodes(ir::Node* a, ir::Node* b) {
                   b->inputs.end());
 }
 
-void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
+void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
+                 int* quantize_counter) {
+  VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
+  auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);
+
+  OpDesc q_desc;
+  q_desc.SetType("quantize");
+  q_desc.SetInput("Input", std::vector<std::string>({op_in->Name()}));
+  q_desc.SetOutput("Output",
+                   std::vector<std::string>({quantize_out_node->Name()}));
+  q_desc.SetAttr("Scale", 1.f);
+  q_desc.SetAttr("bfloat16", true);
+  q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
+                                      ? op->Op()->GetAttr("data_layout")
+                                      : std::string("NCHW"));
+  auto quantize_op = g->CreateOpNode(&q_desc);
+
+  std::vector<std::string> input_names;
+  for (auto name : op->Op()->InputNames()) {
+    for (auto input_name : op->Op()->Input(name)) {
+      if (input_name == op_in->Name()) input_names.push_back(name);
+    }
+  }
+
+  PADDLE_ENFORCE_NE(
+      input_names.empty(), true,
+      platform::errors::NotFound(
+          "Operator before operator should have input as op output"));
+
+  for (auto name = input_names.begin(); name < input_names.end(); name++)
+    op->Op()->SetInput(*name,
+                       std::vector<std::string>({quantize_out_node->Name()}));
+
+  UnlinkNodes(op_in, op);
+  IR_NODE_LINK_TO(op_in, quantize_op);
+  IR_NODE_LINK_TO(quantize_op, quantize_out_node);
+  IR_NODE_LINK_TO(quantize_out_node, op);
+  (*quantize_counter)++;
+}
+
+void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
+  auto inputs = op->inputs;
+  PADDLE_ENFORCE_GE(inputs.size(), 1,
+                    platform::errors::InvalidArgument(
+                        "OP(%s)'s inputs(%d) must be equal or greater than 1.",
+                        op->Name(), inputs.size()));
+  PADDLE_ENFORCE_EQ(op->outputs.size(), 1,
+                    platform::errors::InvalidArgument(
+                        "OP(%s)'s outputs(%d) must be equal to 1.", op->Name(),
+                        op->outputs.size()));
+
+  OpDesc q_desc;
+  q_desc.SetType("quantize");
+
+  std::vector<Node*> quantize_out_nodes(inputs.size());
+  std::vector<std::string> quantize_out_node_names(inputs.size());
+
+  for (size_t i = 0; i < inputs.size(); i++) {
+    VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
+    quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc);
+    quantize_out_node_names[i] = quantize_out_nodes[i]->Name();
+
+    q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
+    q_desc.SetOutput("Output",
+                     std::vector<std::string>({quantize_out_node_names[i]}));
+    q_desc.SetAttr("Scale", 1.f);
+    q_desc.SetAttr("bfloat16", true);
+    q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
+                                        ? op->Op()->GetAttr("data_layout")
+                                        : std::string("NCHW"));
+    auto quantize_op = g->CreateOpNode(&q_desc);
+
+    UnlinkNodes(inputs[i], op);
+    IR_NODE_LINK_TO(inputs[i], quantize_op);
+    IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
+    IR_NODE_LINK_TO(quantize_out_nodes[i], op);
+    (*quantize_counter)++;
+  }
+
+  op->Op()->SetInput("X", quantize_out_node_names);
+}
+
+void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
+  GraphPatternDetector gpd;
+  patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(),
+                                               "duplicated_inputs"};
+  duplicated_inputs();
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(op, op, duplicated_inputs);
+    AddQuantizes(g, op, quantize_counter);
+  };
+  gpd(graph, handler);
+}
+
+void RemoveUnnecessaryReorders(ir::Graph* graph, int* quantize_counter) {
+  GraphPatternDetector gpd;
+  patterns::UnnecessaryReorders unnecessary_reorders{gpd.mutable_pattern(),
+                                                     "unnecessary_reorders"};
+  unnecessary_reorders();
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, unnecessary_reorders);
+    GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, unnecessary_reorders);
+    GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, unnecessary_reorders);
+    GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, unnecessary_reorders);
+
+    std::string op_output_name;
+    for (auto name : prev_op->Op()->OutputNames())
+      for (auto output_name : prev_op->Op()->Output(name))
+        if (output_name == quant_in->Name()) op_output_name = name;
+
+    PADDLE_ENFORCE_NE(
+        op_output_name.empty(), true,
+        platform::errors::NotFound(
+            "Operator before operator should have input as op output"));
+
+    prev_op->Op()->SetOutput(op_output_name,
+                             std::vector<std::string>({quant_out->Name()}));
+
+    IR_NODE_LINK_TO(prev_op, quant_out);
+    GraphSafeRemoveNodes(graph, {quant_in, quant_op});
+    (*quantize_counter)--;
+  };
+  gpd(graph, handler);
+}
+
+void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
   GraphPatternDetector gpd;
   patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
                                           "first_bfloat16_ops"};
   bfloat16_ops();
-  int quantize_counter = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, bfloat16_ops);
     GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops);
     GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
-
-    if (op->Op()->Type() != "conv2d" && prev_op->Op()->Type() != "quantize") {
-      VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
-      auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);
-
-      // create a quantize op node
-      OpDesc q_desc;
-      q_desc.SetType("quantize");
-      q_desc.SetInput("Input", std::vector<std::string>({op_in->Name()}));
-      q_desc.SetOutput("Output",
-                       std::vector<std::string>({quantize_out_node->Name()}));
-      q_desc.SetAttr("Scale", 1.f);
-      q_desc.SetAttr("bfloat16", true);
-      q_desc.SetAttr("output_format", Has("data_layout")
-                                          ? Get<std::string>("data_layout")
-                                          : "NCHW");
-      auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.
-
-      std::string op_input_name;
-      for (auto name : op->Op()->InputNames()) {
-        for (auto input_name : op->Op()->Input(name)) {
-          if (input_name == op_in->Name()) op_input_name = name;
-        }
-      }
-
-      PADDLE_ENFORCE_NE(
-          op_input_name.empty(), true,
-          platform::errors::NotFound(
-              "Operator before operator should have input as op output"));
-
-      op->Op()->SetInput(op_input_name,
-                         std::vector<std::string>({quantize_out_node->Name()}));
-
-      UnlinkNodes(op_in, op);
-      IR_NODE_LINK_TO(op_in, quantize_op);
-      IR_NODE_LINK_TO(quantize_op, quantize_out_node);
-      IR_NODE_LINK_TO(quantize_out_node, op);
-      quantize_counter++;
+    auto prev_op_type = prev_op->Op()->Type();
+    if (op->Op()->Type() != "conv2d" && prev_op_type != "quantize" &&
+        prev_op_type != "sum" && prev_op_type != "concat") {
+      AddQuantize(g, op, op_in, quantize_counter);
     }
   };
   gpd(graph, handler);
+}
+
+void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
+  int quantize_counter = 0;
+  AddReoderBeforeDuplicatedInputs(graph, &quantize_counter);
+  RemoveUnnecessaryReorders(graph, &quantize_counter);
+  AddReoderBeforeSingleInputs(graph, &quantize_counter);
   PrettyLogDetail("--- added %d quantize op before bfloat16 op",
                   quantize_counter);
 }
@@ -101,45 +200,42 @@ void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
     GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);
     GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, bfloat16_ops);
-
     if ((op->Op()->HasAttr("force_fp32_output") ||
          op->Op()->HasProtoAttr("force_fp32_output")) &&
         !op->Op()->GetAttrIfExists<bool>("fuse_residual_connection")) {
       op->Op()->SetAttr("force_fp32_output", true);
       force_fp32_counter++;
     } else if (op->Op()->Type() != "prior_box") {
-      // Create dequantize input variable
-      VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
-      auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
+      VarDesc dequantize_out_desc(patterns::PDNodeName("dequantize", "out"));
+      auto* dequantize_out_node = g->CreateVarNode(&dequantize_out_desc);
 
-      // create a dequantize op node for output.
       OpDesc deq_desc;
       deq_desc.SetType("dequantize");
-      deq_desc.SetInput("Input",
-                        std::vector<std::string>({dequantize_in_node->Name()}));
-      deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
+      deq_desc.SetInput("Input", std::vector<std::string>({op_out->Name()}));
+      deq_desc.SetOutput(
+          "Output", std::vector<std::string>({dequantize_out_node->Name()}));
       deq_desc.SetAttr("Scale", 1.0f);
       auto dequantize_op = g->CreateOpNode(&deq_desc);
 
-      std::string op_output_name;
-      for (auto name : op->Op()->OutputNames()) {
-        for (auto output_name : op->Op()->Output(name)) {
-          if (output_name == op_out->Name()) op_output_name = name;
+      std::string next_op_input_name;
+      for (auto name : next_op->Op()->InputNames()) {
+        for (auto input_name : next_op->Op()->Input(name)) {
+          if (input_name == op_out->Name()) next_op_input_name = name;
         }
       }
 
       PADDLE_ENFORCE_NE(
-          op_output_name.empty(), true,
+          next_op_input_name.empty(), true,
           platform::errors::NotFound(
-              "Operator after operator should have input as op output"));
-
-      op->Op()->SetOutput(op_output_name, std::vector<std::string>(
-                                              {dequantize_in_node->Name()}));
+              "Operator before operator should have input as op output"));
 
-      UnlinkNodes(op, op_out);
-      IR_NODE_LINK_TO(op, dequantize_in_node);
-      IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
-      IR_NODE_LINK_TO(dequantize_op, op_out);
+      next_op->Op()->SetInput(
+          next_op_input_name,
+          std::vector<std::string>({dequantize_out_node->Name()}));
+      UnlinkNodes(op_out, next_op);
+      IR_NODE_LINK_TO(op_out, dequantize_op);
+      IR_NODE_LINK_TO(dequantize_op, dequantize_out_node);
+      IR_NODE_LINK_TO(dequantize_out_node, next_op);
       dequantize_counter++;
     }
   };
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc
index 15109db9832..ab8d3cbdfc0 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc
@@ -42,60 +42,45 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
            type == "dropout") {
     op->SetInput("X", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
-    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
+    if (type != "dropout") op->SetAttr("mkldnn_data_type", mkldnn_data_type);
   } else if (type == "fc") {
     op->SetInput("Input", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
     op->SetAttr("mkldnn_data_type", mkldnn_data_type);
-  } else if (type == "concat") {
+    op->SetAttr("force_fp32_output", force_fp32_output);
+  } else if (type == "concat" || type == "sum") {
     op->SetInput("X", inputs);
     op->SetOutput("Out", outputs);
     op->SetAttr("mkldnn_data_type", mkldnn_data_type);
-  } else if (type == "matmul" || type == "elementwise_add") {
+  } else if (type == "matmul" || type == "elementwise_add" ||
+             type == "elementwise_mul") {
     op->SetInput("X", {inputs[0]});
     if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
     op->SetOutput("Out", {outputs[0]});
     op->SetAttr("mkldnn_data_type", mkldnn_data_type);
+    if (type == "matmul") op->SetAttr("force_fp32_output", force_fp32_output);
+  } else if (type == "layer_norm") {
+    op->SetInput("X", {inputs[0]});
+    op->SetOutput("Y", {outputs[0]});
+    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
   }
 }
 
+static const std::initializer_list<std::string> variable_names{
+    "z", "a", "b", "c", "d", "e", "f", "g", "h", "i"};
+
 void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
                  const std::initializer_list<std::string> variable_names,
                  int* original_nodes_num, int* current_nodes_num) {
   auto pass = PassRegistry::Instance().Get("cpu_bfloat16_pass");
 
-  graph->reset(pass->Apply(graph->release()));
-
   *original_nodes_num = (*graph)->Nodes().size();
   (*graph).reset(pass->Apply((*graph).release()));
   *current_nodes_num = (*graph)->Nodes().size();
 }
 
-static const std::initializer_list<std::string> variable_names{
-    "z", "a", "b", "c", "d", "e", "f", "g", "h", "i"};
-
-ProgramDesc BuildProgramDesc(bool use_mkldnn) {
-  ProgramDesc prog;
-  for (auto& v : variable_names) {
-    prog.MutableBlock(0)->Var(v);
-  }
-  SetOp(&prog, "dropout", "Dropout1", {"z"}, {"a"}, use_mkldnn, "float32");
-  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, "bfloat16");
-  SetOp(&prog, "pool2d", "Pool1", {"b"}, {"c"}, use_mkldnn, "bfloat16");
-  SetOp(&prog, "conv2d", "Conv1", {"c"}, {"d"}, use_mkldnn, "bfloat16");
-  SetOp(&prog, "dropout", "Dropout2", {"d"}, {"e"}, use_mkldnn, "float32");
-  SetOp(&prog, "transpose2", "Transpose1", {"e"}, {"f"}, use_mkldnn,
-        "bfloat16");
-  SetOp(&prog, "reshape2", "Reshape1", {"f"}, {"g"}, use_mkldnn, "bfloat16");
-  SetOp(&prog, "concat", "Concat1", {"g"}, {"h"}, use_mkldnn, "bfloat16");
-  SetOp(&prog, "dropout", "Dropout3", {"h"}, {"i"}, use_mkldnn, "float32");
-
-  return prog;
-}
-
-void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
-              int transpose_count, int quant_count, int dequant_count,
-              int added_nodes_count) {
+void MainTest(const ProgramDesc& prog, int quant_count, int dequant_count,
+              int force_fp32_count, int added_nodes_count) {
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
   int original_nodes_num, current_nodes_num;
   PreparePass(&graph, prog, variable_names, &original_nodes_num,
@@ -103,39 +88,114 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
 
   int quantize_nodes_count = 0;
   int dequantize_nodes_count = 0;
-  int conv2d_nodes_count = 0;
-  int pool2d_nodes_count = 0;
-  int transpose2_nodes_count = 0;
-
+  int force_fp32_nodes_count = 0;
   for (auto* node : graph->Nodes()) {
     if (node->IsOp()) {
       auto* op = node->Op();
-      if (op->Type() == "conv2d") {
-        conv2d_nodes_count++;
-      } else if (op->Type() == "pool2d") {
-        pool2d_nodes_count++;
-      } else if (op->Type() == "transpose2") {
-        transpose2_nodes_count++;
-      } else if (op->Type() == "quantize") {
+      if (op->Type() == "quantize") {
         quantize_nodes_count++;
       } else if (op->Type() == "dequantize") {
         dequantize_nodes_count++;
+      } else if (op->Type() == "conv2d" || op->Type() == "matmul" ||
+                 op->Type() == "fc") {
+        if (op->GetAttrIfExists<bool>("force_fp32_output"))
+          force_fp32_nodes_count++;
       }
     }
   }
-  EXPECT_EQ(conv2d_nodes_count, conv_count);
-  EXPECT_EQ(pool2d_nodes_count, pool_count);
-  EXPECT_EQ(transpose2_nodes_count, transpose_count);
   EXPECT_EQ(quantize_nodes_count, quant_count);
   EXPECT_EQ(dequantize_nodes_count, dequant_count);
+  EXPECT_EQ(force_fp32_nodes_count, force_fp32_count);
   EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
 }
 
-TEST(CpuQuantizePass, quantize) {
+ProgramDesc BuildProgramDescConv(bool use_mkldnn) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, use_mkldnn, "float32");
+  SetOp(&prog, "conv2d", "Conv1", {"b"}, {"c"}, use_mkldnn, "bfloat16");
+  SetOp(&prog, "pool2d", "Pool", {"c"}, {"d"}, use_mkldnn, "bfloat16");
+  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, "bfloat16");
+  SetOp(&prog, "transpose2", "Transpose", {"e"}, {"f"}, use_mkldnn, "float32");
+
+  return prog;
+}
+
+TEST(CpuBfloat16Pass, convolution) {
+  bool use_mkldnn = true;
+  // 0 added + 1 force_fp32_output
+  int added_nodes = 0;
+  MainTest(BuildProgramDescConv(use_mkldnn), 0, 0, 1, added_nodes);
+}
+
+ProgramDesc BuildProgramDescDoubleInput(bool use_mkldnn) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, use_mkldnn, "float32");
+  SetOp(&prog, "matmul", "Matmul", {"b", "b"}, {"c"}, use_mkldnn, "bfloat16");
+  SetOp(&prog, "transpose2", "Transpose", {"d"}, {"e"}, use_mkldnn, "float32");
+  SetOp(&prog, "elementwise_add", "ElemetwiseAdd", {"c", "e"}, {"f"},
+        use_mkldnn, "bfloat16");
+  SetOp(&prog, "reshape2", "Reshape", {"f"}, {"g"}, use_mkldnn, "bfloat16");
+
+  return prog;
+}
+
+TEST(CpuBfloat16Pass, double_input_ops) {
+  bool use_mkldnn = true;
+  // 2 quant + 2 quant out
+  int added_nodes = 4;
+  MainTest(BuildProgramDescDoubleInput(use_mkldnn), 2, 0, 0, added_nodes);
+}
+
+ProgramDesc BuildProgramDescDuplicatedInput(bool use_mkldnn) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "dropout", "Dropout1", {"a"}, {"b"}, use_mkldnn, "float32");
+  SetOp(&prog, "dropout", "Dropout2", {"c"}, {"d"}, use_mkldnn, "float32");
+  SetOp(&prog, "concat", "Concat", {"b", "d"}, {"e"}, use_mkldnn, "bfloat16");
+  SetOp(&prog, "transpose2", "Transpose", {"f"}, {"g"}, use_mkldnn, "float32");
+  SetOp(&prog, "sum", "Sum", {"e", "g"}, {"h"}, use_mkldnn, "bfloat16");
+  SetOp(&prog, "reshape2", "Reshape", {"h"}, {"i"}, use_mkldnn, "bfloat16");
+
+  return prog;
+}
+
+TEST(CpuBfloat16Pass, duplicated_input_ops) {
+  bool use_mkldnn = true;
+  // 3 quant + 3 quant out
+  int added_nodes = 6;
+  MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), 3, 0, 0, added_nodes);
+}
+
+ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "layer_norm", "LayerNorm1", {"a"}, {"b"}, use_mkldnn,
+        "bfloat16");
+  SetOp(&prog, "dropout", "Dropout1", {"b"}, {"c"}, use_mkldnn, "float32");
+  SetOp(&prog, "transpose2", "Transpose", {"b"}, {"d"}, use_mkldnn, "bfloat16");
+  SetOp(&prog, "layer_norm", "LayerNorm2", {"d"}, {"e"}, use_mkldnn,
+        "bfloat16");
+  SetOp(&prog, "reshape2", "Reshape", {"e"}, {"f"}, use_mkldnn, "float32");
+  SetOp(&prog, "dropout", "Dropout2", {"e"}, {"g"}, use_mkldnn, "float32");
+
+  return prog;
+}
+
+TEST(CpuBfloat16Pass, double_outputs_ops) {
   bool use_mkldnn = true;
-  // 1 quantize + 1 dequantize
-  int added_nodes = 2;
-  MainTest(BuildProgramDesc(use_mkldnn), 2, 1, 1, 1, 2, added_nodes);
+  // 3 dequant + 3 dequant out
+  int added_nodes = 6;
+  MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), 0, 3, 0, added_nodes);
 }
 
 }  // namespace ir
-- 
GitLab
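
A minimal usage sketch (not part of the patch): the tester above drives the pass through PassRegistry, and the same calls can run cpu_bfloat16_pass on any ir::Graph built from a ProgramDesc. The helper name RunCpuBfloat16Pass is illustrative only, and the sketch assumes the pass is registered and linked into the binary (the tester does this with USE_PASS(cpu_bfloat16_pass)).

  #include <memory>

  #include "paddle/fluid/framework/ir/graph.h"
  #include "paddle/fluid/framework/ir/pass.h"
  #include "paddle/fluid/framework/program_desc.h"

  namespace fw = paddle::framework;

  // Build a graph from the program and apply cpu_bfloat16_pass to it,
  // mirroring PreparePass() in cpu_bfloat16_pass_tester.cc. The pass inserts
  // quantize ops before the first bfloat16 operators and dequantize ops (or
  // force_fp32_output) after the last ones.
  std::unique_ptr<fw::ir::Graph> RunCpuBfloat16Pass(const fw::ProgramDesc& prog) {
    std::unique_ptr<fw::ir::Graph> graph(new fw::ir::Graph(prog));
    auto pass = fw::ir::PassRegistry::Instance().Get("cpu_bfloat16_pass");
    graph.reset(pass->Apply(graph.release()));
    return graph;
  }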