未验证 提交 c714926d 编写于 作者: T Tomasz Socha 提交者: GitHub

Enable bfloat16 for VIT-OCR model. (#42758)

* Clean-up bfloat16 tester

* New blacklist mechanizm for dequantization

* Style

* Style II

* Style III
上级 8501fb00
......@@ -186,9 +186,22 @@ class DeQuantizer final : public Quanter {
// Checking whether a reorder from BF16 to FP32
// should be added after the output to the operator
bool IsNotPermittedName(const std::string& output_name) const override {
// XShape is output in transpose2 and reshape2 operators used to store the
// shape and lod of X. So this output do not need dequantize before.
return (output_name == "XShape");
std::unordered_map<std::string, std::vector<std::string>> block_list{
{"layer_norm",
{"Mean", "Variance"}}}; // not used in inference in MKLDNN
std::vector<std::string> blocked_outputs{"XShape"}; // blocklist for any op
auto op_name = op->Name();
if (block_list.count(op_name)) {
const auto& op_blocklist = block_list[op_name];
blocked_outputs.insert(blocked_outputs.begin(), op_blocklist.begin(),
op_blocklist.end());
}
return std::any_of(blocked_outputs.begin(), blocked_outputs.end(),
[&output_name](const std::string& name) {
return name == output_name;
});
}
std::string get_op_type() const override { return "dequantize"; };
......
......@@ -65,22 +65,20 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
static const std::initializer_list<std::string> variable_names{
"z", "a", "b", "c", "d", "e", "f", "g", "h", "i"};
void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
const std::initializer_list<std::string> variable_names,
int* original_nodes_num, int* current_nodes_num) {
void PreparePass(std::unique_ptr<ir::Graph>& graph, int* original_nodes_num,
int* current_nodes_num) {
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_pass");
*original_nodes_num = (*graph)->Nodes().size();
(*graph).reset(pass->Apply((*graph).release()));
*current_nodes_num = (*graph)->Nodes().size();
*original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
*current_nodes_num = graph->Nodes().size();
}
void MainTest(const ProgramDesc& prog, const int& quant_count,
const int& dequant_count, const int& added_nodes_count) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto graph = std::make_unique<ir::Graph>(prog);
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names, &original_nodes_num,
&current_nodes_num);
PreparePass(graph, &original_nodes_num, &current_nodes_num);
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册