From c714926d2ceacc372c8b1e863c40213fbd397801 Mon Sep 17 00:00:00 2001
From: Tomasz Socha
Date: Mon, 16 May 2022 15:32:13 +0200
Subject: [PATCH] Enable bfloat16 for VIT-OCR model. (#42758)

* Clean-up bfloat16 tester

* New blacklist mechanism for dequantization

* Style

* Style II

* Style III
---
 .../framework/ir/mkldnn/cpu_bfloat16_pass.cc  | 19 ++++++++++++++++---
 .../ir/mkldnn/cpu_bfloat16_pass_tester.cc     | 16 +++++++---------
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
index 62b2be712b..eebc87f5d9 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
@@ -186,9 +186,22 @@ class DeQuantizer final : public Quanter {
   // Checking whether a reorder from BF16 to FP32
   // should be added after the output to the operator
   bool IsNotPermittedName(const std::string& output_name) const override {
-    // XShape is output in transpose2 and reshape2 operators used to store the
-    // shape and lod of X. So this output do not need dequantize before.
-    return (output_name == "XShape");
+    std::unordered_map<std::string, std::vector<std::string>> block_list{
+        {"layer_norm",
+         {"Mean", "Variance"}}};  // not used in inference in MKLDNN
+
+    std::vector<std::string> blocked_outputs{"XShape"};  // blocklist for any op
+    auto op_name = op->Name();
+    if (block_list.count(op_name)) {
+      const auto& op_blocklist = block_list[op_name];
+      blocked_outputs.insert(blocked_outputs.begin(), op_blocklist.begin(),
+                             op_blocklist.end());
+    }
+
+    return std::any_of(blocked_outputs.begin(), blocked_outputs.end(),
+                       [&output_name](const std::string& name) {
+                         return name == output_name;
+                       });
   }
 
   std::string get_op_type() const override { return "dequantize"; };
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc
index 877ee71fc2..3f5e9a1484 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc
@@ -65,22 +65,20 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
 static const std::initializer_list<std::string> variable_names{
     "z", "a", "b", "c", "d", "e", "f", "g", "h", "i"};
 
-void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
-                 const std::initializer_list<std::string> variable_names,
-                 int* original_nodes_num, int* current_nodes_num) {
+void PreparePass(std::unique_ptr<ir::Graph>& graph, int* original_nodes_num,
+                 int* current_nodes_num) {
   auto pass = PassRegistry::Instance().Get("cpu_bfloat16_pass");
 
-  *original_nodes_num = (*graph)->Nodes().size();
-  (*graph).reset(pass->Apply((*graph).release()));
-  *current_nodes_num = (*graph)->Nodes().size();
+  *original_nodes_num = graph->Nodes().size();
+  graph.reset(pass->Apply(graph.release()));
+  *current_nodes_num = graph->Nodes().size();
 }
 
 void MainTest(const ProgramDesc& prog, const int& quant_count,
               const int& dequant_count, const int& added_nodes_count) {
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  auto graph = std::make_unique<ir::Graph>(prog);
   int original_nodes_num, current_nodes_num;
-  PreparePass(&graph, prog, variable_names, &original_nodes_num,
-              &current_nodes_num);
+  PreparePass(graph, &original_nodes_num, &current_nodes_num);
 
   int quantize_nodes_count = 0;
   int dequantize_nodes_count = 0;
-- 
GitLab
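
Reviewer note: the snippet below is a minimal standalone sketch, for illustration only, of the blocklist lookup this patch adds to DeQuantizer::IsNotPermittedName. The free function IsBlockedOutput, the plain-string op_type/output_name parameters, and the main() driver are assumptions made so the example compiles on its own; in the pass itself the operator name comes from the surrounding Quanter/Node state rather than a function argument.

// Standalone sketch (not Paddle code): mirrors the blocklist logic the patch
// introduces. Function and parameter names here are hypothetical.
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Returns true when `output_name` of an operator of type `op_type` should
// NOT have a bfloat16 -> fp32 dequantize reorder appended after it.
bool IsBlockedOutput(const std::string& op_type,
                     const std::string& output_name) {
  // Per-operator blocklist: layer_norm's Mean/Variance outputs are not used
  // in MKL-DNN inference, so dequantizing them would be wasted work.
  static const std::unordered_map<std::string, std::vector<std::string>>
      block_list{{"layer_norm", {"Mean", "Variance"}}};

  // Outputs blocked for every operator, e.g. the shape-only XShape output
  // produced by transpose2/reshape2.
  std::vector<std::string> blocked_outputs{"XShape"};

  auto it = block_list.find(op_type);
  if (it != block_list.end()) {
    blocked_outputs.insert(blocked_outputs.end(), it->second.begin(),
                           it->second.end());
  }

  return std::any_of(blocked_outputs.begin(), blocked_outputs.end(),
                     [&output_name](const std::string& name) {
                       return name == output_name;
                     });
}

int main() {
  std::cout << std::boolalpha
            << IsBlockedOutput("layer_norm", "Mean") << "\n"    // true
            << IsBlockedOutput("layer_norm", "Y") << "\n"       // false
            << IsBlockedOutput("reshape2", "XShape") << "\n";   // true
}

Keeping the per-operator exceptions in a map keyed by operator type lets further cases be added later without touching the generic XShape rule that applies to every op.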