未验证 提交 c714926d 编写于 作者: T Tomasz Socha 提交者: GitHub

Enable bfloat16 for VIT-OCR model. (#42758)

* Clean-up bfloat16 tester

* New blacklist mechanism for dequantization

* Style

* Style II

* Style III
上级 8501fb00
...@@ -186,9 +186,22 @@ class DeQuantizer final : public Quanter { ...@@ -186,9 +186,22 @@ class DeQuantizer final : public Quanter {
// Checking whether a reorder from BF16 to FP32 // Checking whether a reorder from BF16 to FP32
// should be added after the output to the operator // should be added after the output to the operator
bool IsNotPermittedName(const std::string& output_name) const override { bool IsNotPermittedName(const std::string& output_name) const override {
// XShape is output in transpose2 and reshape2 operators used to store the std::unordered_map<std::string, std::vector<std::string>> block_list{
// shape and lod of X. So this output do not need dequantize before. {"layer_norm",
return (output_name == "XShape"); {"Mean", "Variance"}}}; // not used in inference in MKLDNN
std::vector<std::string> blocked_outputs{"XShape"}; // blocklist for any op
auto op_name = op->Name();
if (block_list.count(op_name)) {
const auto& op_blocklist = block_list[op_name];
blocked_outputs.insert(blocked_outputs.begin(), op_blocklist.begin(),
op_blocklist.end());
}
return std::any_of(blocked_outputs.begin(), blocked_outputs.end(),
[&output_name](const std::string& name) {
return name == output_name;
});
} }
std::string get_op_type() const override { return "dequantize"; }; std::string get_op_type() const override { return "dequantize"; };
......
...@@ -65,22 +65,20 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -65,22 +65,20 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
static const std::initializer_list<std::string> variable_names{ static const std::initializer_list<std::string> variable_names{
"z", "a", "b", "c", "d", "e", "f", "g", "h", "i"}; "z", "a", "b", "c", "d", "e", "f", "g", "h", "i"};
void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog, void PreparePass(std::unique_ptr<ir::Graph>& graph, int* original_nodes_num,
const std::initializer_list<std::string> variable_names, int* current_nodes_num) {
int* original_nodes_num, int* current_nodes_num) {
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_pass"); auto pass = PassRegistry::Instance().Get("cpu_bfloat16_pass");
*original_nodes_num = (*graph)->Nodes().size(); *original_nodes_num = graph->Nodes().size();
(*graph).reset(pass->Apply((*graph).release())); graph.reset(pass->Apply(graph.release()));
*current_nodes_num = (*graph)->Nodes().size(); *current_nodes_num = graph->Nodes().size();
} }
void MainTest(const ProgramDesc& prog, const int& quant_count, void MainTest(const ProgramDesc& prog, const int& quant_count,
const int& dequant_count, const int& added_nodes_count) { const int& dequant_count, const int& added_nodes_count) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); auto graph = std::make_unique<ir::Graph>(prog);
int original_nodes_num, current_nodes_num; int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names, &original_nodes_num, PreparePass(graph, &original_nodes_num, &current_nodes_num);
&current_nodes_num);
int quantize_nodes_count = 0; int quantize_nodes_count = 0;
int dequantize_nodes_count = 0; int dequantize_nodes_count = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册