未验证 提交 31f81685 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Correct lstm qat test (#51499)

上级 6ba0507d
......@@ -37,7 +37,7 @@ void ComputePropagateScalesMkldnnPass::GetTensorFromVector(
void ComputePropagateScalesMkldnnPass::GetQuantInfo(
ir::Graph* graph, StringPairMap* var_quant_scales) const {
std::unordered_map<std::string, std::vector<float>> info_map{};
GetInfoFromTheFirstOp(graph, "has_quant_info", "var_quant_scales", &info_map);
GetInfoFromTheTmpOp(graph, "has_quant_info", "var_quant_scales", &info_map);
for (auto iter = info_map.begin(); iter != info_map.end(); iter++) {
phi::DenseTensor tensor;
......@@ -510,9 +510,9 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
UpdateReluOutputScales(graph, &var_quant_scales);
PropagateScales(graph, &var_quant_scales, scale_immutable_ops);
// save var_quant_scales in the first op's attr
// save var_quant_scales in the temporary save op's attr
// for cpu_quantize_pass
SaveInfoInTheFirstOp(
SaveInfoInTheTmpOp(
graph, "has_quant_info", "var_quant_scales", var_quant_scales);
}
......
......@@ -435,7 +435,7 @@ bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
}
void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
GetInfoFromTheFirstOp(
GetInfoFromTheTmpOp(
graph, "has_quant_info", "var_quant_scales", var_quant_scales_);
}
......@@ -1250,6 +1250,10 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
bool is_x_unsigned{false};
auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);
// In the QAT process scales are prepared for only int8 data type,
// lstm scales should behave as input is int8 to get correct accuracy
is_x_unsigned = false;
double input_x_shift{128.};
if (is_x_unsigned) input_x_shift = 0.;
......
......@@ -25,7 +25,7 @@ namespace ir {
using StringPairMap =
std::unordered_map<std::string, std::pair<bool, phi::DenseTensor>>;
static void SaveInfoInTheFirstOp(
static void SaveInfoInTheTmpOp(
ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
......@@ -33,34 +33,27 @@ static void SaveInfoInTheFirstOp(
VLOG(3) << "save variables in the first op's attr";
const std::string suffix = "_" + key_suffix + "_" + flag;
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch" ||
op_node->Op()->Type() == "fill_constant")
continue;
OpDesc op_desc;
op_desc.SetType("save");
auto* op_node = graph->CreateOpNode(&op_desc);
op_node->Op()->SetAttr(flag, true);
for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) {
op_node->Op()->SetAttr(iter->first + suffix, iter->second);
}
break;
}
}
static void SaveInfoInTheFirstOp(ir::Graph* graph,
static void SaveInfoInTheTmpOp(ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
const StringPairMap& info_map) {
VLOG(3) << "save variables in the first op's attr";
const std::string suffix = "_" + key_suffix + "_" + flag;
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch" ||
op_node->Op()->Type() == "fill_constant")
continue;
OpDesc op_desc;
op_desc.SetType("save");
auto* op_node = graph->CreateOpNode(&op_desc);
op_node->Op()->SetAttr(flag, true);
for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) {
......@@ -70,11 +63,9 @@ static void SaveInfoInTheFirstOp(ir::Graph* graph,
iter->second.first);
op_node->Op()->SetAttr(iter->first + suffix, data_v);
}
break;
}
}
static void GetInfoFromTheFirstOp(
static void GetInfoFromTheTmpOp(
ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
......@@ -84,9 +75,7 @@ static void GetInfoFromTheFirstOp(
const std::string suffix = "_" + key_suffix + "_" + flag;
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;
if (!op_node->IsOp() || op_node->Op()->Type() != "save") continue;
auto* op_desc = op_node->Op();
if (op_desc->GetAttrIfExists<bool>(flag)) {
......@@ -102,12 +91,13 @@ static void GetInfoFromTheFirstOp(
op_desc->RemoveAttr(fake_name);
}
}
graph->RemoveNode(op_node);
break;
}
}
}
static void GetInfoFromTheFirstOp(ir::Graph* graph,
static void GetInfoFromTheTmpOp(ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
StringPairMap* info_map) {
......@@ -117,9 +107,7 @@ static void GetInfoFromTheFirstOp(ir::Graph* graph,
const std::string suffix_is_unsigned = suffix + unsigned_flag;
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;
if (!op_node->IsOp() || op_node->Op()->Type() != "save") continue;
auto* op_desc = op_node->Op();
if (op_desc->GetAttrIfExists<bool>(flag)) {
......@@ -150,6 +138,7 @@ static void GetInfoFromTheFirstOp(ir::Graph* graph,
op_desc->RemoveAttr(vector_name);
}
}
graph->RemoveNode(op_node);
break;
}
}
......
......@@ -754,9 +754,9 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const {
UpdateActivations(graph);
RemoveCtrlVars(graph);
// save var_quant_scales in the first op's attr
// save var_quant_scales in the temporary save op's attr
// for compute_propagate_scales_mkldnn_pass
SaveInfoInTheFirstOp(
SaveInfoInTheTmpOp(
graph, "has_quant_info", "var_quant_scales", var_quant_scales);
}
......
......@@ -430,7 +430,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("quant_dequant_mkldnn_pass");
passes_.push_back("mkldnn_placement_pass");
passes_.push_back("constant_folding_pass");
passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass");
......@@ -485,6 +484,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("quant_transpose2_dequant_onednn_fuse_pass");
passes_.push_back("int8_scale_calculation_mkldnn_pass");
passes_.push_back("params_quantization_mkldnn_pass");
passes_.push_back("constant_folding_pass");
}
use_mkldnn_int8_ = true;
#else
......
......@@ -116,10 +116,9 @@ class TestLstmModelPTQ(unittest.TestCase):
config.switch_ir_optim(True)
config.enable_mkldnn()
config.disable_mkldnn_fc_passes() # fc passes caused dnnl error
config.pass_builder().insert_pass(5, "fc_lstm_fuse_pass")
config.set_mkldnn_cache_capacity(mkldnn_cache_capacity)
if mode == "ptq":
# This pass to work properly, must be added before fc_fuse_pass
config.pass_builder().insert_pass(5, "fc_lstm_fuse_pass")
config.enable_quantizer()
config.quantizer_config().set_quant_data(warmup_data)
config.quantizer_config().set_quant_batch_size(1)
......@@ -244,7 +243,7 @@ class TestLstmModelPTQ(unittest.TestCase):
)
(quant_hx_acc, quant_ctc_acc, quant_fps) = self.run_program(
quant_model + "_int8",
quant_model,
infer_data,
num_threads,
mkldnn_cache_capacity,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册